├── .gitignore ├── LICENSE ├── README.md └── nmp ├── __init__.py ├── aggregator.py ├── config ├── digit │ ├── mix_digits.yaml │ ├── mix_digits_offline.yaml │ ├── noisy_digit.yaml │ └── one_digit.yaml └── robot_push │ ├── robot_push_bhc.yaml │ ├── robot_push_cnmp.yaml │ ├── robot_push_prodmp.yaml │ └── robot_push_promp.yaml ├── data_process.py ├── decoder.py ├── encoder.py ├── experiment ├── digit │ ├── __init__.py │ ├── digit.py │ └── digit_cw.py └── robot_push │ ├── __init__.py │ ├── robot_push_bhc.py │ ├── robot_push_cnmp.py │ ├── robot_push_prodmp.py │ └── robot_push_promp.py ├── logger.py ├── loss.py ├── net.py ├── nn_base.py ├── others ├── __init__.py └── ellipses_noise.py └── util ├── __init__.py ├── util.py ├── util_data_structure.py ├── util_debug.py ├── util_file.py ├── util_geometry.py ├── util_hyperparams.py ├── util_learning.py ├── util_matrix.py ├── util_media.py ├── util_numerical.py └── util_string.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | # Pycharm config files
133 | /.idea
134 |
135 | # Tensorboard
136 | runs/
137 |
138 | # Dataset
139 | dataset/
140 |
141 | # log
142 | log/
143 |
144 | # figures
145 | figure/
146 |
147 | # results
148 | result/
149 | /.pycharmrc
150 |
151 | # wandb
152 | **/wandb/
153 |
154 | # Tmp
155 | tmp_video/
156 | tmp/
157 | media/
158 | numerical_result/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ge Li (Bruce)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProDMP_RAL
2 | This is the imitation learning code base of ProDMP, which combines ProDMP with an encoder-decoder deep neural network.
3 | ProDMP can also be used as a standalone module; we implemented it together with DMP and ProMP in https://github.com/ALRhub/MP_PyTorch.
4 |
5 |
6 | ## Pre-requisites
7 | conda or pip:
8 | pytorch (ML framework),
9 | wandb (online ML experiment logger)
10 |
11 | pip:
12 |
13 | MP_PyTorch (trajectory generator): https://pypi.org/project/mp-pytorch/,
14 |
15 | cw2 (deploys experiments locally or on a cluster): https://pypi.org/project/cw2/,
16 |
17 | pyyaml (parses yaml files),
18 |
19 | tabulate (utility package),
20 |
21 | natsort (utility package),
22 |
23 | python-mnist (utility package)
24 |
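(For a quick setup, the pip-installable pieces above can be grabbed in one line; the package names are exactly those listed, and this assumes pytorch and wandb are already set up via conda or pip: 'pip install mp-pytorch cw2 pyyaml tabulate natsort python-mnist')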
25 | ## Run Exp:
26 | First, initialize your wandb account: 'wandb init'
27 |
28 | Then replace my wandb username in all config files with your own:
29 | 'gelikit' -> 'my_wandb_username'
30 |
31 | Run an experiment like this:
32 |
33 | 'python exp.py config.yaml -o --nocodecopy'
34 |
35 | E.g. 'python digit_cw.py one_digit.yaml -o --nocodecopy'
36 |
37 | ## Dataset
38 | https://drive.google.com/drive/folders/1N_WomzuY2wDX5lOGVg5PvjTqkVGlA4pl?usp=sharing
--------------------------------------------------------------------------------
/nmp/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import *
2 | from .aggregator import *
3 | from .decoder import *
4 | from .nn_base import *
5 | from .net import *
6 | from .util import *
7 | from .logger import *
8 | from .data_process import *
9 | from .loss import *
--------------------------------------------------------------------------------
/nmp/aggregator.py:
--------------------------------------------------------------------------------
1 | """
2 | @brief: Aggregator classes in PyTorch
3 | """
4 | from typing import Optional
5 |
6 | import torch
7 |
8 |
9 | class BayesianAggregator:
10 |     """A Bayesian Aggregator"""
11 |
12 |     def __init__(self, **kwargs):
13 |         """
14 |         Bayesian Aggregator constructor
15 |         Args:
16 |             **kwargs: aggregator configuration
17 |         """
18 |
19 |         # Aggregator dimension
20 |         self.dim_lat_obs: int = kwargs["dim_lat"]
21 |         self.dim_lat: int = kwargs["dim_lat"]
22 |         self.multiple_steps: bool = kwargs.get("multiple_steps", False)
23 |
24 |         # Scalar prior
25 |         self.prior_mean_init = kwargs["prior_mean"]
26 |         self.prior_var_init = kwargs["prior_var"]
27 |         assert self.prior_var_init >= 0  # We only consider diagonal
28 |         # terms, which must be non-negative
29 |
30 |         # Number of trajectories, i.e. equal to the batch size
31 |         self.num_traj = None
32 |
33 |         # Number of aggregated subsets, each subset may contain more than 1 obs
34 |         self.num_agg = 0
35 |
36 |         # Number of aggregated observations
37 |         self.num_agg_obs = 0
38 |
39 |         # Aggregation history of latent variables
40 |         self.mean_lat_var_state = None
41 |         self.variance_lat_var_state = None
42 |
43 |     def reset(self, num_traj: int):
44 |         """
45 |         Reset aggregator
46 |
47 |         Args:
48 |             num_traj: batch size
49 |
50 |         Returns:
51 |             None
52 |
53 |         """
54 |
55 |         # Reset num_traj, i.e. equal to the batch size
56 |         self.num_traj = num_traj
57 |
58 |         # Reset counters
59 |         self.num_agg = 0
60 |         self.num_agg_obs = 0
61 |
62 |         # Reset aggregation history of latent variables,
63 |         # i.e. mean_lat_var_state and variance_lat_var_state
64 |         #
65 |         # Note its shape[1] = num_agg + 1, which tells how many context sets
66 |         # have been aggregated by the aggregator, e.g. index 0 denotes the prior
67 |         # distribution of the latent variable, index -1 denotes the current
68 |         # distribution of the latent variable. Note in each aggregation, the latent
69 |         # observation may have a different number of samples.
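        # (Sketch of the math, for reference: `aggregate` below implements
        #  factorized Gaussian conditioning per latent dimension,
        #      var_new  = 1 / (1 / var_old + sum_n 1 / var_obs_n)
        #      mean_new = mean_old + var_new * sum_n (lat_obs_n - mean_old) / var_obs_n
        #  with one (mean, var) pair appended to the state per aggregation step.)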
70 | # 71 | # Shape of mean_lat_var_state: [num_traj, num_agg + 1, dim_lat] 72 | # Shape of variance_lat_var_state: [num_traj, num_agg + 1, dim_lat] 73 | 74 | # Get prior tensors from scalar 75 | prior_mean, prior_var = self.generate_prior(self.prior_mean_init, 76 | self.prior_var_init) 77 | 78 | # Add one axis (record number of aggregation) 79 | self.mean_lat_var_state = prior_mean[:, None, :] 80 | self.variance_lat_var_state = prior_var[:, None, :] 81 | 82 | def generate_prior(self, mean: float, cov: float) \ 83 | -> (torch.Tensor, torch.Tensor): 84 | """ 85 | Given scalar values of mean and covariance, generate prior tensor 86 | Args: 87 | mean: scalar value of mean 88 | cov: scalar value of covariance 89 | 90 | Returns: tensors of prior's mean and prior's variance 91 | """ 92 | # Shape of prior_mean, prior_var: 93 | # [num_traj, dim_lat] 94 | 95 | prior_mean = torch.full(size=(self.num_traj, self.dim_lat), 96 | fill_value=mean) 97 | prior_var = torch.full(size=(self.num_traj, self.dim_lat), 98 | fill_value=cov) 99 | return prior_mean, prior_var 100 | 101 | def aggregate(self, lat_obs: torch.Tensor, var_lat_obs: torch.Tensor): 102 | """ 103 | Aggregate info of latent observation and compute new latent variable. 104 | If there's no latent observation, then return prior of latent variable. 105 | 106 | Args: 107 | lat_obs: latent observations of samples in certain trajectories 108 | var_lat_obs: covariance (uncertainty) of latent observations 109 | 110 | Returns: 111 | None 112 | """ 113 | 114 | # Shape of lat_obs: 115 | # [num_traj, num_obs, dim_lat] 116 | # 117 | # Shape of var_lat_obs: 118 | # [num_traj, num_obs, dim_lat] 119 | 120 | # Case without latent observation 121 | if lat_obs.shape[1] == 0 and var_lat_obs.shape[1] == 0: 122 | # No latent observation, do not update the latent variable state 123 | pass 124 | 125 | # Case with latent observation 126 | else: 127 | # Check input shapes 128 | assert lat_obs.ndim == var_lat_obs.ndim == 3 129 | assert lat_obs.shape == var_lat_obs.shape 130 | assert lat_obs.shape[0] == self.num_traj 131 | assert lat_obs.shape[2] == self.dim_lat_obs 132 | 133 | # number of observations 134 | num_obs = lat_obs.shape[1] 135 | 136 | # Get the latest latent variable distribution 137 | mean_lat_var = self.mean_lat_var_state[:, -1, :] 138 | variance_lat_var = self.variance_lat_var_state[:, -1, :] 139 | 140 | # Aggregate 141 | agg_step = 1 if self.multiple_steps else num_obs 142 | for idx in range(0, num_obs, agg_step): 143 | # Update uncertainty of latent variable 144 | variance_lat_var = \ 145 | 1 / (1 / variance_lat_var 146 | + torch.sum(1 / var_lat_obs[:, idx:idx + agg_step, :], 147 | dim=1)) 148 | # Update mean of latent variable 149 | mean_lat_var = mean_lat_var + variance_lat_var * torch.sum( 150 | 1 / var_lat_obs[:, idx:idx + agg_step, :] 151 | * (lat_obs[:, idx:idx + agg_step, :] 152 | - mean_lat_var[:, None, :]), dim=1) 153 | 154 | # Append to latent variable state 155 | self.mean_lat_var_state = torch.cat( 156 | (self.mean_lat_var_state, mean_lat_var[:, None, :]), dim=1) 157 | self.variance_lat_var_state = \ 158 | torch.cat((self.variance_lat_var_state, 159 | variance_lat_var[:, None, :]), dim=1) 160 | 161 | # Update counters 162 | self.num_agg += 1 163 | self.num_agg_obs += agg_step 164 | 165 | def get_agg_state(self, index: Optional[int]) \ 166 | -> (torch.Tensor, torch.Tensor): 167 | """ 168 | Return all latent variable state, or the one at given index. 169 | E.g. 
index -1 denotes the last latent variable state; index 0 the prior 170 | 171 | Returns: 172 | mean_lat_var_state: mean of the latent variable state 173 | variance_lat_var_state: covariance of the latent variable state 174 | """ 175 | 176 | # Shape of mean_lat_var_state: 177 | # [num_traj, num_agg, dim_lat] 178 | # 179 | # Shape of variance_lat_var_state: 180 | # [num_traj, num_agg, dim_lat] 181 | # 182 | # num_agg = 1 if index is not None 183 | 184 | if index is None: 185 | # Full case 186 | return self.mean_lat_var_state, self.variance_lat_var_state 187 | 188 | elif index == -1 or index + 1 == self.mean_lat_var_state.shape[1]: 189 | # Index case -1 190 | return self.mean_lat_var_state[:, index:, :], \ 191 | self.variance_lat_var_state[:, index:, :] 192 | else: 193 | # Other index cases 194 | return self.mean_lat_var_state[:, index: index + 1, :], \ 195 | self.variance_lat_var_state[:, index:index + 1, :] 196 | 197 | 198 | class MeanAggregator: 199 | """A mean aggregator""" 200 | 201 | def __init__(self, **kwargs): 202 | """ 203 | Mean aggregator constructor 204 | Args: 205 | **kwargs: aggregator configuration 206 | """ 207 | # Aggregator dimension 208 | self.dim_lat_obs: int = kwargs["dim_lat"] 209 | self.multiple_steps: bool = kwargs.get("multiple_steps", False) 210 | 211 | # Scalar prior 212 | self.prior_mean_init = kwargs["prior_mean"] 213 | 214 | # Number of trajectories, i.e. equals to batch size 215 | self.num_traj = None 216 | 217 | # Number of aggregated subsets, each subset may have more than 1 obs 218 | self.num_agg = 0 219 | 220 | # Number of aggregated obs 221 | self.num_agg_obs = 0 222 | 223 | # Aggregation history of latent observation 224 | self.mean_lat_obs_state = None 225 | 226 | def reset(self, num_traj: int): 227 | """ 228 | Reset aggregator 229 | 230 | Args: 231 | num_traj: batch size 232 | 233 | Returns: 234 | None 235 | 236 | """ 237 | # Reset num_traj, i.e. equals to batch size 238 | self.num_traj = num_traj 239 | 240 | # Reset counters 241 | self.num_agg = 0 242 | self.num_agg_obs = 0 243 | 244 | # Reset aggregation history of latent observation 245 | # i.e. mean_lat_rep_state 246 | # 247 | # Note its shape[1] = num_agg + 1, which tells how many context 248 | # "sets" have been aggregated by the aggregator, e.g. index 0 249 | # denotes the prior mean of latent observation, index -1 denotes the 250 | # current mean of latent observation. Note in each aggregation, the 251 | # latent observation to be aggregated may have different number of 252 | # samples. 
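        # (Sketch of the math, for reference: `aggregate` below maintains a
        #  running mean; with n = num_agg_obs observations aggregated so far
        #  and a new batch z_1, ..., z_k, per latent dimension
        #      mean_new = (n * mean_old + sum_i z_i) / (n + k).)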
253 | # 254 | # Shape of mean_lat_obs_state: [num_traj, num_agg + 1, dim_lat] 255 | 256 | # Get prior tensors from scalar 257 | prior_mean = self.generate_prior(self.prior_mean_init) 258 | 259 | # Add one axis (record number of aggregation) 260 | self.mean_lat_obs_state = prior_mean[:, None, :] 261 | 262 | def generate_prior(self, mean: float) -> torch.Tensor: 263 | """ 264 | Given scalar value of mean, generate prior tensor 265 | Args: 266 | mean: scalar value of mean 267 | 268 | Returns: tensors of prior's mean for mean latent observation 269 | """ 270 | # Shape of prior_mean: 271 | # [num_traj, dim_lat] 272 | 273 | prior_mean = torch.full(size=(self.num_traj, self.dim_lat_obs), 274 | fill_value=mean) 275 | 276 | return prior_mean 277 | 278 | def aggregate(self, lat_obs: torch.Tensor): 279 | """ 280 | Aggregate info of latent observation 281 | 282 | Args: 283 | lat_obs: latent observations of samples in certain trajectories 284 | 285 | Returns: 286 | None 287 | """ 288 | 289 | # Shape of lat_obs: 290 | # [num_traj, num_obs, dim_lat] 291 | 292 | # Case without latent observation 293 | if lat_obs.shape[1] == 0: 294 | # No latent observation, do not update the latent observations 295 | pass 296 | 297 | else: 298 | # Check input shapes 299 | assert lat_obs.ndim == 3 300 | assert lat_obs.shape[0] == self.num_traj 301 | assert lat_obs.shape[2] == self.dim_lat_obs 302 | 303 | # Number of observations 304 | num_obs = lat_obs.shape[1] 305 | 306 | # Get latest latent obs 307 | mean_lat_obs = self.mean_lat_obs_state[:, -1, :] 308 | 309 | # Aggregate 310 | agg_step = 1 if self.multiple_steps else num_obs 311 | for step in range(0, num_obs, agg_step): 312 | # Compute new mean 313 | mean_lat_obs = \ 314 | (mean_lat_obs * self.num_agg_obs + 315 | torch.sum(lat_obs[:, step:step + agg_step, :], dim=1)) \ 316 | / (self.num_agg_obs + agg_step) 317 | 318 | # Append 319 | self.mean_lat_obs_state = torch.cat( 320 | (self.mean_lat_obs_state, mean_lat_obs[:, None, :]), dim=1) 321 | 322 | # Update counters 323 | self.num_agg += 1 324 | self.num_agg_obs += agg_step 325 | 326 | def get_agg_state(self, index: Optional[int]) -> torch.Tensor: 327 | """ 328 | Return all latent observation state, or the one at given index. 329 | E.g. 
index -1 denotes the last latent obs state; index 0 the prior
330 |
331 |         Returns:
332 |             mean_lat_obs_state: mean of the latent observation state
333 |         """
334 |
335 |         # Shape of mean_lat_obs_state:
336 |         # [num_traj, num_agg, dim_lat]
337 |         #
338 |         # num_agg = 1 if index is not None
339 |
340 |         if index is None:
341 |             # Full case
342 |             return self.mean_lat_obs_state
343 |
344 |         elif index == -1 or index + 1 == self.mean_lat_obs_state.shape[1]:
345 |             # Index case -1
346 |             return self.mean_lat_obs_state[:, index:, :]
347 |         else:
348 |             # Other index cases
349 |             return self.mean_lat_obs_state[:, index: index + 1, :]
350 |
351 | # End of class MeanAggregator
352 |
353 |
354 | class AggregatorFactory:
355 |
356 |     @staticmethod
357 |     def get_aggregator(aggregator_type: str, **kwargs):
358 |         return eval(aggregator_type + "(**kwargs)")
359 |
--------------------------------------------------------------------------------
/nmp/config/digit/mix_digits.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | # Slurm config bwuni gpu
3 | name: "SLURM" # MUST BE "SLURM"
4 | partition: "gpu" # "single" for cpu, "gpu_4" or "gpu_8" for gpu
5 | job-name: "mnist" # this will be the experiment's name in slurm
6 | num_parallel_jobs: 20 # max number of jobs executed in parallel
7 | ntasks: 1 # leave this as it is
8 | cpus-per-task: 2 # there are 10 cores for each GPU
9 | mem-per-cpu: 10000 # in MB
10 | time: 1000 # in minutes
11 | sbatch_args: # gpus need to be explicitly requested using this
12 |   gres=gpu:1: "" # and this
13 | #  nodelist: "node3"
14 |
15 | ---
16 | name: &name "pronmp_mix_digits"
17 |
18 | # Required: Can also be set in DEFAULT
19 | path: /tmp/result/mix_digits # path for saving the results
20 | repetitions: 1 # number of repeated runs for each parameter combination
21 |
22 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT
23 | iterations: 5000 # number of iterations per repetition.
24 |
25 | # Optional: Can also be set in DEFAULT
26 | # Only change these values if you are sure you know what you are doing.
27 | reps_per_job: 1 # number of repetitions in each job. useful for parallelization. defaults to 1.
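# (A note on how these configs are consumed, inferred from the nmp factory
# classes: each `type` under `encoders`, `aggregator`, and `decoder` below
# names a class in the nmp package, and its `args` dict is passed through as
# **kwargs, e.g. AggregatorFactory.get_aggregator("BayesianAggregator", **args).
# The factories resolve the name via eval, so `type` must match the class name
# exactly.)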
28 | reps_in_parallel: 20 29 | 30 | wandb: 31 | project: *name 32 | group: training 33 | entity: gelikit 34 | log_interval: &interval 20 35 | log_model: true 36 | model_name: test_model_name 37 | 38 | params: 39 | dim_lat: &dim_lat 128 40 | act_func: &act_func leaky_relu 41 | lr: 2e-4 42 | wd: 5e-5 43 | vali_log_interval: *interval 44 | save_model_interval: 200 45 | seed: 1234 46 | runtime_noise: true 47 | max_norm: 20 48 | 49 | encoders: 50 | cnn: 51 | type: ProNMPEncoderCnnMlp 52 | args: 53 | image_size: 54 | - 40 55 | - 40 56 | kernel_size: 5 57 | num_cnn: 2 58 | cnn_channels: 59 | - 1 60 | - 10 61 | - 20 62 | dim_lat: *dim_lat 63 | obs_hidden: 64 | avg_neuron: 128 65 | num_hidden: 2 66 | shape: 0.0 67 | unc_hidden: 68 | avg_neuron: 128 69 | num_hidden: 3 70 | shape: 0.0 71 | act_func: *act_func 72 | 73 | aggregator: 74 | type: BayesianAggregator 75 | args: 76 | dim_lat: *dim_lat 77 | multiple_steps: true 78 | prior_mean: 0.0 79 | prior_var: 1 80 | 81 | decoder: 82 | type: PBDecoder 83 | args: 84 | dim_add_in: 0 85 | dim_val: 54 86 | dim_lat: *dim_lat 87 | std_only: False 88 | mean_hidden: 89 | avg_neuron: 128 90 | num_hidden: 3 91 | shape: 0.0 92 | variance_hidden: 93 | avg_neuron: 256 94 | num_hidden: 4 95 | shape: 0.0 96 | act_func: *act_func 97 | 98 | dataset: 99 | name: s_mnist_25_mix_0_only 100 | partition: 101 | train: 0.7 102 | validate: 0.15 103 | test: 0.15 104 | shuffle_set: False 105 | batch_size: 512 106 | shuffle_train_loader: True 107 | transform: null 108 | time_min: 0 109 | time_max: 3 110 | save_type: tensor 111 | data: 112 | images: 113 | time_dependent: false 114 | normalize: false 115 | trajs: 116 | time_dependent: true 117 | init_x_y_dmp_w_g: 118 | time_dependent: false 119 | 120 | mp: 121 | num_dof: 2 122 | tau: 3.0 123 | mp_type: prodmp 124 | mp_args: 125 | alpha_phase: 2.0 126 | num_basis: 25 127 | basis_bandwidth_factor: 2 128 | num_basis_outside: 0 129 | alpha: 25 130 | dt: 0.01 131 | assign_config: 132 | num_ctx: 0 133 | num_select: 10 134 | num_all: 301 135 | -------------------------------------------------------------------------------- /nmp/config/digit/mix_digits_offline.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "mnist" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 20 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | # nodelist: "node5" 14 | 15 | --- 16 | name: &name "pronmp_mix_digits_offline" 17 | 18 | # Required: Can also be set in DEFAULT 19 | path: /tmp/result/mix_digits # path for saving the results 20 | repetitions: 1 # number of repeated runs for each parameter combination 21 | 22 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 23 | iterations: 5000 # number of iterations per repetition. 24 | 25 | # Optional: Can also be set in DEFAULT 26 | # Only change these values if you are sure you know what you are doing. 27 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
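# (Relative to mix_digits.yaml, this offline variant differs mainly in
# runtime_noise: false, batch_size: 64, a narrower decoder variance_hidden
# net (128 vs. 256 neurons), and the dataset name, all set further below.)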
28 | reps_in_parallel: 20 29 | 30 | wandb: 31 | project: *name 32 | group: training 33 | entity: gelikit 34 | log_interval: &interval 20 35 | log_model: true 36 | model_name: test_model_name 37 | 38 | params: 39 | dim_lat: &dim_lat 128 40 | act_func: &act_func leaky_relu 41 | lr: 2e-4 42 | wd: 5e-5 43 | vali_log_interval: *interval 44 | save_model_interval: 200 45 | seed: 1234 46 | runtime_noise: false 47 | max_norm: 20 48 | 49 | encoders: 50 | cnn: 51 | type: ProNMPEncoderCnnMlp 52 | args: 53 | image_size: 54 | - 40 55 | - 40 56 | kernel_size: 5 57 | num_cnn: 2 58 | cnn_channels: 59 | - 1 60 | - 10 61 | - 20 62 | dim_lat: *dim_lat 63 | obs_hidden: 64 | avg_neuron: 128 65 | num_hidden: 2 66 | shape: 0.0 67 | unc_hidden: 68 | avg_neuron: 128 69 | num_hidden: 3 70 | shape: 0.0 71 | act_func: *act_func 72 | 73 | aggregator: 74 | type: BayesianAggregator 75 | args: 76 | dim_lat: *dim_lat 77 | multiple_steps: true 78 | prior_mean: 0.0 79 | prior_var: 1 80 | 81 | decoder: 82 | type: PBDecoder 83 | args: 84 | dim_add_in: 0 85 | dim_val: 54 86 | dim_lat: *dim_lat 87 | std_only: False 88 | mean_hidden: 89 | avg_neuron: 128 90 | num_hidden: 3 91 | shape: 0.0 92 | variance_hidden: 93 | avg_neuron: 128 94 | num_hidden: 4 95 | shape: 0.0 96 | act_func: *act_func 97 | 98 | dataset: 99 | name: s_mnist_25_mix_2_plus_3_offline 100 | partition: 101 | train: 0.7 102 | validate: 0.15 103 | test: 0.15 104 | shuffle_set: False 105 | batch_size: 64 106 | shuffle_train_loader: True 107 | transform: null 108 | time_min: 0 109 | time_max: 3 110 | save_type: tensor 111 | data: 112 | images: 113 | time_dependent: false 114 | normalize: false 115 | trajs: 116 | time_dependent: true 117 | init_x_y_dmp_w_g: 118 | time_dependent: false 119 | 120 | mp: 121 | num_dof: 2 122 | tau: 3.0 123 | mp_type: prodmp 124 | mp_args: 125 | alpha_phase: 2.0 126 | num_basis: 25 127 | basis_bandwidth_factor: 2 128 | num_basis_outside: 0 129 | alpha: 25 130 | dt: 0.01 131 | assign_config: 132 | num_ctx: 0 133 | num_select: 10 134 | num_all: 301 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /nmp/config/digit/noisy_digit.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "mnist" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 5 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | # nodelist: "node3" 14 | 15 | --- 16 | name: &name "noisy_digit_online" 17 | 18 | # Required: Can also be set in DEFAULT 19 | path: /tmp/result/noisy_digit_online # path for saving the results 20 | repetitions: 1 # number of repeated runs for each parameter combination 21 | 22 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 23 | iterations: 10000 # number of iterations per repetition. 24 | 25 | # Optional: Can also be set in DEFAULT 26 | # Only change these values if you are sure you know what you are doing. 27 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
28 | reps_in_parallel: 1 29 | 30 | wandb: 31 | project: *name 32 | group: training 33 | entity: gelikit 34 | log_interval: &interval 20 35 | log_model: true 36 | model_name: test_model_name 37 | 38 | params: 39 | dim_lat: &dim_lat 128 40 | act_func: &act_func leaky_relu 41 | lr: 2e-4 42 | wd: 5e-5 43 | vali_log_interval: *interval 44 | save_model_interval: 200 45 | seed: 1234 46 | runtime_noise: true 47 | max_norm: 20 48 | 49 | encoders: 50 | cnn: 51 | type: ProNMPEncoderCnnMlp 52 | args: 53 | image_size: 54 | - 40 55 | - 40 56 | kernel_size: 5 57 | num_cnn: 2 58 | cnn_channels: 59 | - 1 60 | - 10 61 | - 20 62 | dim_lat: *dim_lat 63 | obs_hidden: 64 | avg_neuron: 128 65 | num_hidden: 2 66 | shape: 0.0 67 | unc_hidden: 68 | avg_neuron: 128 69 | num_hidden: 3 70 | shape: 0.0 71 | act_func: *act_func 72 | 73 | aggregator: 74 | type: BayesianAggregator 75 | args: 76 | dim_lat: *dim_lat 77 | multiple_steps: true 78 | prior_mean: 0.0 79 | prior_var: 1 80 | 81 | decoder: 82 | type: PBDecoder 83 | args: 84 | dim_add_in: 0 85 | dim_val: 54 86 | dim_lat: *dim_lat 87 | std_only: False 88 | mean_hidden: 89 | avg_neuron: 128 90 | num_hidden: 3 91 | shape: 0.0 92 | variance_hidden: 93 | avg_neuron: 256 94 | num_hidden: 4 95 | shape: 0.0 96 | act_func: *act_func 97 | 98 | dataset: 99 | name: s_mnist_25_new 100 | partition: 101 | train: 0.7 102 | validate: 0.15 103 | test: 0.15 104 | shuffle_set: False 105 | batch_size: 512 106 | shuffle_train_loader: True 107 | transform: null 108 | time_min: 0 109 | time_max: 3 110 | save_type: tensor 111 | data: 112 | images: 113 | time_dependent: false 114 | normalize: false 115 | trajs: 116 | time_dependent: true 117 | init_x_y_dmp_w_g: 118 | time_dependent: false 119 | 120 | mp: 121 | num_dof: 2 122 | tau: 3.0 123 | mp_type: prodmp 124 | mp_args: 125 | alpha_phase: 2.0 126 | num_basis: 25 127 | basis_bandwidth_factor: 2 128 | num_basis_outside: 0 129 | alpha: 25 130 | dt: 0.01 131 | assign_config: 132 | num_ctx: 0 133 | num_select: 10 134 | num_all: 301 135 | 136 | 137 | -------------------------------------------------------------------------------- /nmp/config/digit/one_digit.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "mnist" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 5 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | 14 | --- 15 | name: &name "pronmp_one_digit" 16 | 17 | # Required: Can also be set in DEFAULT 18 | path: /tmp/result/one_digit # path for saving the results 19 | repetitions: 1 # number of repeated runs for each parameter combination 20 | 21 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 22 | iterations: 5000 # number of iterations per repetition. 23 | 24 | # Optional: Can also be set in DEFAULT 25 | # Only change these values if you are sure you know what you are doing. 26 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
27 | reps_in_parallel: 1 28 | 29 | wandb: 30 | project: *name 31 | group: training 32 | entity: gelikit 33 | log_interval: &interval 20 34 | log_model: true 35 | model_name: test_model_name 36 | 37 | params: 38 | dim_lat: &dim_lat 128 39 | act_func: &act_func leaky_relu 40 | lr: 2e-4 41 | wd: 5e-5 42 | vali_log_interval: *interval 43 | save_model_interval: 200 44 | seed: 1234 45 | max_norm: 25 46 | 47 | encoders: 48 | cnn: 49 | type: ProNMPEncoderCnnMlp 50 | args: 51 | image_size: 52 | - 40 53 | - 40 54 | kernel_size: 5 55 | num_cnn: 2 56 | cnn_channels: 57 | - 1 58 | - 10 59 | - 20 60 | dim_lat: *dim_lat 61 | obs_hidden: 62 | avg_neuron: 128 63 | num_hidden: 2 64 | shape: 0.0 65 | unc_hidden: 66 | avg_neuron: 128 67 | num_hidden: 3 68 | shape: 0.0 69 | act_func: *act_func 70 | 71 | aggregator: 72 | type: BayesianAggregator 73 | args: 74 | dim_lat: *dim_lat 75 | multiple_steps: true 76 | prior_mean: 0.0 77 | prior_var: 1 78 | 79 | decoder: 80 | type: PBDecoder 81 | args: 82 | dim_add_in: 0 83 | dim_val: 54 84 | dim_lat: *dim_lat 85 | std_only: False 86 | mean_hidden: 87 | avg_neuron: 128 88 | num_hidden: 3 89 | shape: 0.0 90 | variance_hidden: 91 | avg_neuron: 256 92 | num_hidden: 4 93 | shape: 0.0 94 | act_func: *act_func 95 | 96 | dataset: 97 | name: s_mnist_25_new 98 | partition: 99 | train: 0.7 100 | validate: 0.15 101 | test: 0.15 102 | shuffle_set: False 103 | batch_size: 512 104 | shuffle_train_loader: True 105 | transform: null 106 | time_min: 0 107 | time_max: 3 108 | save_type: tensor 109 | data: 110 | images: 111 | time_dependent: false 112 | normalize: false 113 | trajs: 114 | time_dependent: true 115 | init_x_y_dmp_w_g: 116 | time_dependent: false 117 | 118 | mp: 119 | num_dof: 2 120 | tau: 3.0 121 | mp_type: prodmp 122 | mp_args: 123 | alpha_phase: 2.0 124 | num_basis: 25 125 | basis_bandwidth_factor: 2 126 | num_basis_outside: 0 127 | alpha: 25 128 | dt: 0.01 129 | assign_config: 130 | num_ctx: 0 131 | num_select: 10 132 | num_all: 301 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /nmp/config/robot_push/robot_push_bhc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "robot_push_bhc" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 20 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | 14 | 15 | 16 | --- 17 | name: &name robot_push_bhc 18 | 19 | # Required: Can also be set in DEFAULT 20 | path: /tmp/result/robot_push_bhc # path for saving the results 21 | repetitions: 1 # number of repeated runs for each parameter combination 22 | 23 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 24 | iterations: 20000 # number of iterations per repetition. 25 | 26 | # Optional: Can also be set in DEFAULT 27 | # Only change these values if you are sure you know what you are doing. 28 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
29 | reps_in_parallel: 20 30 | 31 | wandb: 32 | project: robot_push_bhc 33 | group: *name 34 | entity: gelikit 35 | log_interval: &interval 20 36 | log_model: true 37 | model_name: model 38 | 39 | params: 40 | act_func: &act_func leaky_relu 41 | lr: 2e-4 42 | wd: 5e-5 43 | vali_log_interval: *interval 44 | save_model_interval: 500 45 | seed: 1234 46 | max_norm: 150 47 | 48 | mlp_net: 49 | type: CNMPEncoderMlp # Here just use this class as the network, it is not CNMP! 50 | args: 51 | name: BehaviorCloning 52 | dim_obs: 6 53 | dim_lat: 2 # Here is just the output size of the network, there is no latent space 54 | obs_hidden: 55 | avg_neuron: 128 56 | num_hidden: 3 57 | shape: 0.0 58 | act_func: *act_func 59 | 60 | dataset: 61 | name: robot_push 62 | partition: 63 | train: 0.7 64 | validate: 0.2 65 | test: 0.1 66 | shuffle_set: True 67 | batch_size: 512 68 | shuffle_train_loader: True 69 | transform: null 70 | time_min: 0 71 | time_max: 3 72 | save_type: tensor 73 | data: 74 | object_pos_ori: 75 | time_dependent: true 76 | des_cart_pos_vel: 77 | time_dependent: true 78 | box_robot_state: 79 | time_dependent: true 80 | file_index: 81 | time_dependent: false 82 | normalize: false 83 | 84 | -------------------------------------------------------------------------------- /nmp/config/robot_push/robot_push_cnmp.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "robot_push_cnmp" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 20 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | 14 | 15 | 16 | --- 17 | name: &name robot_push_cnmp 18 | 19 | # Required: Can also be set in DEFAULT 20 | path: /tmp/result/robot_push_cnmp # path for saving the results 21 | repetitions: 1 # number of repeated runs for each parameter combination 22 | 23 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 24 | iterations: 10000 # number of iterations per repetition. 25 | 26 | # Optional: Can also be set in DEFAULT 27 | # Only change these values if you are sure you know what you are doing. 28 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
29 | reps_in_parallel: 20 30 | 31 | wandb: 32 | project: robot_push_cnmp 33 | group: *name 34 | entity: gelikit 35 | log_interval: &interval 20 36 | log_model: true 37 | model_name: model 38 | 39 | 40 | params: 41 | dim_lat: &dim_lat 128 42 | act_func: &act_func leaky_relu 43 | lr: 2e-4 44 | wd: 5e-5 45 | vali_log_interval: *interval 46 | save_model_interval: 500 47 | seed: 1234 48 | max_norm: 150 49 | 50 | encoders: 51 | ctx: 52 | type: CNMPEncoderMlp 53 | args: 54 | dim_obs: 8 55 | dim_lat: *dim_lat 56 | obs_hidden: 57 | avg_neuron: 128 58 | num_hidden: 3 59 | shape: 0.0 60 | act_func: *act_func 61 | 62 | aggregator: 63 | type: MeanAggregator 64 | args: 65 | dim_lat: *dim_lat 66 | multiple_steps: false 67 | prior_mean: 0.0 68 | 69 | decoder: 70 | type: CNPDecoder 71 | args: 72 | dim_add_in: 1 73 | dim_val: 2 74 | dim_lat: *dim_lat 75 | std_only: true 76 | mean_hidden: 77 | avg_neuron: 128 78 | num_hidden: 3 79 | shape: 0.0 80 | variance_hidden: 81 | avg_neuron: 128 82 | num_hidden: 3 83 | shape: 0.0 84 | act_func: *act_func 85 | 86 | dataset: 87 | name: robot_push 88 | partition: 89 | train: 0.7 90 | validate: 0.2 91 | test: 0.1 92 | shuffle_set: True 93 | batch_size: 48 94 | shuffle_train_loader: True 95 | transform: null 96 | time_min: 0 97 | time_max: 3 98 | save_type: tensor 99 | data: 100 | object_pos_ori: 101 | time_dependent: true 102 | des_cart_pos_vel: 103 | time_dependent: true 104 | box_robot_state: 105 | time_dependent: true 106 | file_index: 107 | time_dependent: false 108 | normalize: false 109 | 110 | assign_config: 111 | num_ctx_min: 1 112 | num_ctx_max: 10 113 | pred_range_min: 50 114 | pred_range_max: 50 115 | 116 | -------------------------------------------------------------------------------- /nmp/config/robot_push/robot_push_prodmp.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "robot_push_prodmp" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 20 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | 14 | 15 | 16 | --- 17 | name: &name robot_push_prodmp 18 | 19 | # Required: Can also be set in DEFAULT 20 | path: /tmp/result/robot_push_prodmp # path for saving the results 21 | repetitions: 1 # number of repeated runs for each parameter combination 22 | 23 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 24 | iterations: 10000 # number of iterations per repetition. 25 | 26 | # Optional: Can also be set in DEFAULT 27 | # Only change these values if you are sure you know what you are doing. 28 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
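# (A consistency note, inferred from the configs rather than stated anywhere:
# the decoder's `dim_val: 52` below matches the ProDMP parameter count
# num_dof * (num_basis + 1 goal weight) = 2 * (25 + 1) = 52, so the `mp`
# block and `dim_val` have to be changed together.)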
29 | reps_in_parallel: 20 30 | 31 | wandb: 32 | project: robot_push_prodmp 33 | group: *name 34 | entity: gelikit 35 | log_interval: &interval 20 36 | log_model: true 37 | model_name: model 38 | 39 | 40 | params: 41 | dim_lat: &dim_lat 128 42 | act_func: &act_func leaky_relu 43 | lr: 2e-4 44 | wd: 5e-5 45 | vali_log_interval: *interval 46 | save_model_interval: 500 47 | seed: 1234 48 | max_norm: 150 49 | 50 | encoders: 51 | ctx: 52 | type: ProNMPEncoderMlp 53 | args: 54 | dim_obs: 8 55 | dim_lat: *dim_lat 56 | obs_hidden: 57 | avg_neuron: 128 58 | num_hidden: 3 59 | shape: 0.0 60 | unc_hidden: 61 | avg_neuron: 128 62 | num_hidden: 3 63 | shape: 0.0 64 | act_func: *act_func 65 | 66 | aggregator: 67 | type: BayesianAggregator 68 | args: 69 | dim_lat: *dim_lat 70 | multiple_steps: false 71 | prior_mean: 0.0 72 | prior_var: 1 73 | 74 | decoder: 75 | type: PBDecoder 76 | args: 77 | dim_add_in: 0 78 | dim_val: 52 79 | dim_lat: *dim_lat 80 | std_only: false 81 | mean_hidden: 82 | avg_neuron: 128 83 | num_hidden: 3 84 | shape: 0.0 85 | variance_hidden: 86 | avg_neuron: 128 87 | num_hidden: 3 88 | shape: 0.0 89 | act_func: *act_func 90 | 91 | dataset: 92 | name: robot_push 93 | partition: 94 | train: 0.7 95 | validate: 0.2 96 | test: 0.1 97 | shuffle_set: True 98 | batch_size: 48 99 | shuffle_train_loader: True 100 | transform: null 101 | time_min: 0 102 | time_max: 3 103 | save_type: tensor 104 | data: 105 | object_pos_ori: 106 | time_dependent: true 107 | des_cart_pos_vel: 108 | time_dependent: true 109 | box_robot_state: 110 | time_dependent: true 111 | file_index: 112 | time_dependent: false 113 | normalize: false 114 | 115 | mp: 116 | num_dof: 2 117 | tau: 0.5 118 | mp_type: prodmp 119 | mp_args: 120 | alpha_phase: 0.5 121 | num_basis: 25 122 | basis_bandwidth_factor: 3 123 | num_basis_outside: 0 124 | alpha: 25 125 | dt: 0.01 126 | 127 | assign_config: 128 | num_ctx_min: 1 129 | num_ctx_max: 10 130 | pred_range_min: 50 131 | pred_range_max: 50 132 | 133 | -------------------------------------------------------------------------------- /nmp/config/robot_push/robot_push_promp.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | # Slurm config bwuni gpu 3 | name: "SLURM" # MUST BE "SLURM" 4 | partition: "gpu" # "single" for cpu, "gpu_4" or gpu_8" for gpu 5 | job-name: "robot_push_promp" # this will be the experiment's name in slurm 6 | num_parallel_jobs: 20 # max number of jobs executed in parallel 7 | ntasks: 1 # leave that like it is 8 | cpus-per-task: 2 # there are 10 cores for each GPU 9 | mem-per-cpu: 10000 # in MB 10 | time: 1000 # in minutes 11 | sbatch_args: # gpus need to be explicitly requested using this 12 | gres=gpu:1: "" #and this 13 | 14 | 15 | 16 | --- 17 | name: &name robot_push_promp 18 | 19 | # Required: Can also be set in DEFAULT 20 | path: /tmp/result/robot_push_promp # path for saving the results 21 | repetitions: 1 # number of repeated runs for each parameter combination 22 | 23 | # Required for AbstractIterativeExperiments only. Can also be set in DEFAULT 24 | iterations: 10000 # number of iterations per repetition. 25 | 26 | # Optional: Can also be set in DEFAULT 27 | # Only change these values if you are sure you know what you are doing. 28 | reps_per_job: 1 # number of repetitions in each job. useful for paralellization. defaults to 1. 
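# (As in the ProDMP config above, the decoder's `dim_val: 20` below equals the
# ProMP weight count num_dof * num_basis = 2 * 10 = 20.)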
29 | reps_in_parallel: 20 30 | 31 | wandb: 32 | project: robot_push_promp 33 | group: *name 34 | entity: gelikit 35 | log_interval: &interval 20 36 | log_model: true 37 | model_name: model 38 | 39 | 40 | params: 41 | dim_lat: &dim_lat 128 42 | act_func: &act_func leaky_relu 43 | lr: 2e-4 44 | wd: 5e-5 45 | vali_log_interval: *interval 46 | save_model_interval: 500 47 | seed: 1234 48 | max_norm: 150 49 | 50 | encoders: 51 | ctx: 52 | type: ProNMPEncoderMlp 53 | args: 54 | dim_obs: 8 55 | dim_lat: *dim_lat 56 | obs_hidden: 57 | avg_neuron: 128 58 | num_hidden: 3 59 | shape: 0.0 60 | unc_hidden: 61 | avg_neuron: 128 62 | num_hidden: 3 63 | shape: 0.0 64 | act_func: *act_func 65 | 66 | aggregator: 67 | type: BayesianAggregator 68 | args: 69 | dim_lat: *dim_lat 70 | multiple_steps: false 71 | prior_mean: 0.0 72 | prior_var: 1 73 | 74 | decoder: 75 | type: PBDecoder 76 | args: 77 | dim_add_in: 0 78 | dim_val: 20 79 | dim_lat: *dim_lat 80 | std_only: false 81 | mean_hidden: 82 | avg_neuron: 128 83 | num_hidden: 3 84 | shape: 0.0 85 | variance_hidden: 86 | avg_neuron: 128 87 | num_hidden: 3 88 | shape: 0.0 89 | act_func: *act_func 90 | 91 | dataset: 92 | name: robot_push 93 | partition: 94 | train: 0.7 95 | validate: 0.2 96 | test: 0.1 97 | shuffle_set: True 98 | batch_size: 48 99 | shuffle_train_loader: True 100 | transform: null 101 | time_min: 0 102 | time_max: 3 103 | save_type: tensor 104 | data: 105 | object_pos_ori: 106 | time_dependent: true 107 | des_cart_pos_vel: 108 | time_dependent: true 109 | box_robot_state: 110 | time_dependent: true 111 | file_index: 112 | time_dependent: false 113 | normalize: false 114 | 115 | mp: 116 | num_dof: 2 117 | tau: 0.5 118 | mp_type: promp 119 | mp_args: 120 | num_basis: 10 121 | basis_bandwidth_factor: 3 122 | num_basis_outside: 0 123 | dt: 0.01 124 | 125 | assign_config: 126 | num_ctx_min: 1 127 | num_ctx_max: 10 128 | pred_range_min: 50 129 | pred_range_max: 50 130 | 131 | -------------------------------------------------------------------------------- /nmp/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Decoder classes in PyTorch 3 | """ 4 | 5 | from abc import ABC 6 | from abc import abstractmethod 7 | # Import Python libs 8 | from typing import Optional 9 | 10 | import torch 11 | 12 | import nmp.util as util 13 | from nmp.nn_base import MLP 14 | from nmp.util import mlp_arch_3_params 15 | 16 | 17 | class Decoder(ABC): 18 | """Decoder class interface""" 19 | 20 | def __init__(self, **kwargs): 21 | """ 22 | Constructor 23 | 24 | Args: 25 | **kwargs: Decoder configuration 26 | """ 27 | 28 | # MLP configuration 29 | self.dim_add_in: int = kwargs["dim_add_in"] 30 | self.dim_val: int = kwargs["dim_val"] 31 | self.dim_lat: int = kwargs["dim_lat"] 32 | self.std_only: bool = kwargs["std_only"] 33 | 34 | self.mean_hidden: dict = kwargs["mean_hidden"] 35 | self.variance_hidden: dict = kwargs["variance_hidden"] 36 | 37 | self.act_func: str = kwargs["act_func"] 38 | 39 | # Decoders 40 | self.mean_val_net = None 41 | self.cov_val_net = None 42 | 43 | # Create decoders 44 | self._create_network() 45 | 46 | @property 47 | def _decoder_type(self) -> str: 48 | """ 49 | Returns: string of decoder type 50 | """ 51 | return self.__class__.__name__ 52 | 53 | def _create_network(self): 54 | """ 55 | Create decoder with given configuration 56 | 57 | Returns: 58 | None 59 | """ 60 | 61 | # compute the output dimension of covariance network 62 | if self.std_only: 63 | # Only has diagonal elements 64 | 
dim_out_cov = self.dim_val 65 | else: 66 | # Diagonal + Non-diagonal elements, form up Cholesky Decomposition 67 | dim_out_cov = self.dim_val \ 68 | + (self.dim_val * (self.dim_val - 1)) // 2 69 | 70 | # Two separate value decoders: mean_val_net + cov_val_net 71 | self.mean_val_net = MLP(name=self._decoder_type + "_mean_val", 72 | dim_in=self.dim_add_in + self.dim_lat, 73 | dim_out=self.dim_val, 74 | hidden_layers= 75 | mlp_arch_3_params(**self.mean_hidden), 76 | act_func=self.act_func) 77 | 78 | self.cov_val_net = MLP(name=self._decoder_type + "_cov_val", 79 | dim_in=self.dim_add_in + self.dim_lat, 80 | dim_out=dim_out_cov, 81 | hidden_layers= 82 | mlp_arch_3_params(**self.variance_hidden), 83 | act_func=self.act_func) 84 | 85 | @property 86 | def network(self): 87 | """ 88 | Return decoder networks 89 | 90 | Returns: 91 | """ 92 | return self.mean_val_net, self.cov_val_net 93 | 94 | @property 95 | def parameters(self) -> []: 96 | """ 97 | Get network parameters 98 | Returns: 99 | parameters 100 | """ 101 | return list(self.mean_val_net.parameters()) + \ 102 | list(self.cov_val_net.parameters()) 103 | 104 | def save_weights(self, log_dir: str, epoch: int): 105 | """ 106 | Save NN weights to file 107 | Args: 108 | log_dir: directory to save weights to 109 | epoch: training epoch 110 | 111 | Returns: 112 | None 113 | """ 114 | self.mean_val_net.save(log_dir, epoch) 115 | self.cov_val_net.save(log_dir, epoch) 116 | 117 | def load_weights(self, log_dir: str, epoch: int): 118 | """ 119 | Load NN weights from file 120 | Args: 121 | log_dir: directory stored weights 122 | epoch: training epoch 123 | 124 | Returns: 125 | None 126 | """ 127 | self.mean_val_net.load(log_dir, epoch) 128 | self.cov_val_net.load(log_dir, epoch) 129 | 130 | def _process_cov_net_output(self, cov_val: torch.Tensor): 131 | """ 132 | Divide diagonal and off-diagonal elements of cov-net output, 133 | apply reverse "Log-Cholesky to diagonal elements" 134 | Args: 135 | cov_val: output of covariance network 136 | 137 | Returns: diagonal and off-diagonal tensors 138 | 139 | """ 140 | # Decompose diagonal and off-diagonal elements 141 | diag_cov_val = cov_val[..., :self.dim_val] 142 | off_diag_cov_val = None if self.std_only \ 143 | else cov_val[..., self.dim_val:] 144 | 145 | # De-parametrize Log-Cholesky for diagonal elements 146 | diag_cov_val = util.to_softplus_space(diag_cov_val, lower_bound=None) 147 | 148 | # Return 149 | return diag_cov_val, off_diag_cov_val 150 | 151 | @abstractmethod 152 | def decode(self, *args, **kwargs): 153 | pass 154 | 155 | 156 | class PBDecoder(Decoder): 157 | """Parameter based decoder""" 158 | 159 | def decode(self, 160 | add_inputs: Optional[torch.Tensor], 161 | mean_lat_var: torch.Tensor, 162 | variance_lat_var: torch.Tensor) \ 163 | -> [torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 164 | """ 165 | Decode and compute target value's distribution 166 | 167 | Here, target value to be predicted is a 4th order tensor with axes: 168 | 169 | traj: this target value is on which trajectory 170 | aggr: based on how much aggregated context do we make this prediction? 171 | tar: this target value is on which target time? 
172 | value: vector to be predicted 173 | 174 | Args: 175 | add_inputs: additional inputs, can be None 176 | mean_lat_var: mean of latent variable 177 | variance_lat_var: variance of latent variable 178 | 179 | Returns: 180 | mean_val: mean of target value 181 | 182 | diag_cov_val: diagonal elements of Cholesky Decomposition of 183 | covariance of target value 184 | 185 | off_diag_cov_val: None, or off-diagonal elements of Cholesky 186 | Decomposition of covariance of target value 187 | 188 | """ 189 | 190 | # Shape of mean_lat_var: 191 | # [num_traj, num_agg, dim_lat] 192 | # 193 | # Shape of variance_lat_var: 194 | # [num_traj, num_agg, dim_lat] 195 | # 196 | # Shape of add_inputs: 197 | # [num_traj, num_time_pts, dim_add_in=1] 198 | # 199 | # Shape of mean_val: 200 | # [num_traj, num_agg, num_time_pts, dim_val] 201 | # 202 | # Shape of diag_cov_val: 203 | # [num_traj, num_agg, num_time_pts, dim_val] 204 | # 205 | # Shape of off_diag_cov_val: 206 | # [num_traj, num_agg, num_time_pts, (dim_val * (dim_val - 1) // 2)] 207 | 208 | # Dimension check 209 | assert mean_lat_var.ndim == variance_lat_var.ndim == 3 210 | num_agg = mean_lat_var.shape[1] 211 | 212 | # Process add_inputs 213 | if add_inputs is not None: 214 | assert add_inputs.ndim == 3 215 | num_time_pts = add_inputs.shape[1] 216 | # Add one axis (aggregation-wise batch dimension) to add_inputs 217 | add_inputs = util.add_expand_dim(add_inputs, [1], [num_agg]) 218 | else: 219 | num_time_pts = 1 220 | 221 | # Parametrize variance 222 | variance_lat_var = util.to_log_space(variance_lat_var, 223 | lower_bound=None) 224 | 225 | # Add one axis (time-scale-wise batch dimension) to latent variable 226 | mean_lat_var = util.add_expand_dim(mean_lat_var, [2], [num_time_pts]) 227 | variance_lat_var = util.add_expand_dim(variance_lat_var, [2], 228 | [num_time_pts]) 229 | 230 | # Prepare input to decoder networks 231 | mean_net_input = mean_lat_var 232 | cov_net_input = variance_lat_var 233 | if add_inputs is not None: 234 | mean_net_input = torch.cat((add_inputs, mean_net_input), dim=-1) 235 | cov_net_input = torch.cat((add_inputs, cov_net_input), dim=-1) 236 | 237 | # Decode 238 | mean_val = self.mean_val_net(mean_net_input) 239 | cov_val = self.cov_val_net(cov_net_input) 240 | 241 | # Process cov net prediction 242 | diag_cov_val, off_diag_cov_val = self._process_cov_net_output(cov_val) 243 | 244 | # Return 245 | return mean_val, diag_cov_val, off_diag_cov_val 246 | 247 | 248 | class CNPDecoder(Decoder): 249 | """Conditional Neural Processes decoder""" 250 | 251 | def decode(self, 252 | add_inputs: torch.Tensor, 253 | mean_lat_obs: torch.Tensor) \ 254 | -> [torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 255 | """ 256 | Decode and compute target value's distribution at target add_inputs 257 | 258 | Here, target value to be predicted is a 4th order tensor with axes: 259 | 260 | traj: this target value is on which trajectory 261 | aggr: based on how much aggregated context do we make this prediction? 262 | tar: this target value is on which target time? 
263 | value: vector to be predicted 264 | 265 | Args: 266 | add_inputs: additional inputs, can be None 267 | mean_lat_obs: mean of latent observation 268 | 269 | Returns: 270 | mean_val: mean of target value 271 | 272 | diag_cov_val: diagonal elements of Cholesky Decomposition of 273 | covariance of target value 274 | 275 | off_diag_cov_val: None, or off-diagonal elements of Cholesky 276 | Decomposition of covariance of target value 277 | 278 | """ 279 | 280 | # Shape of mean_lat_obs: 281 | # [num_traj, num_agg, dim_lat] 282 | # 283 | # Shape of add_inputs: 284 | # [num_traj, num_time_pts, dim_add_in=1] if add_inputs not None 285 | # 286 | # Shape of mean_val: 287 | # [num_traj, num_agg, num_time_pts, dim_val] 288 | # 289 | # Shape of diag_cov_val: 290 | # [num_traj, num_agg, num_time_pts, dim_val] 291 | # 292 | # Shape of off_diag_cov_val: 293 | # [num_traj, num_agg, num_time_pts, (dim_val * (dim_val - 1) // 2)] 294 | 295 | # Dimension check 296 | assert mean_lat_obs.ndim == 3 297 | num_agg = mean_lat_obs.shape[1] 298 | 299 | # Process add_inputs 300 | if add_inputs is not None: 301 | assert add_inputs.ndim == 3 302 | # Get dimensions 303 | num_time_pts = add_inputs.shape[1] 304 | # Add one axis (aggregation-wise batch dimension) to add_inputs 305 | add_inputs = util.add_expand_dim(add_inputs, [1], [num_agg]) 306 | else: 307 | num_time_pts = 1 308 | 309 | # Add one axis (time-scale-wise batch dimension) to latent observation 310 | mean_lat_obs = util.add_expand_dim(mean_lat_obs, [2], [num_time_pts]) 311 | 312 | # Prepare input to decoder network 313 | net_input = mean_lat_obs 314 | if add_inputs is not None: 315 | net_input = torch.cat((add_inputs, net_input), dim=-1) 316 | 317 | # Decode 318 | mean_val = self.mean_val_net(net_input) 319 | cov_val = self.cov_val_net(net_input) 320 | 321 | # Process cov net prediction 322 | diag_cov_val, off_diag_cov_val = self._process_cov_net_output(cov_val) 323 | 324 | # Return 325 | return mean_val, diag_cov_val, off_diag_cov_val 326 | 327 | 328 | class MCDecoder(Decoder): 329 | """Monte-Carlo decoder""" 330 | 331 | def decode(self, 332 | add_inputs: torch.Tensor, 333 | sampled_lat_var: torch.Tensor, 334 | variance_lat_var: torch.Tensor) \ 335 | -> [torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 336 | """ 337 | Decode and compute target value's distribution 338 | 339 | Here, target value to be predicted is a 5th order tensor with axes: 340 | 341 | traj: this target value is on which trajectory 342 | aggr: based on how much aggregated context do we make this prediction? 343 | sample: latent variable samples for Monte-Carlo 344 | tar: this target value is on which target time? 
345 | value: vector to be predicted 346 | 347 | Args: 348 | add_inputs: additional inputs, can be None 349 | sampled_lat_var: sampled latent variable 350 | variance_lat_var: variance of latent variable 351 | 352 | Returns: 353 | mean_val: mean of target value 354 | 355 | diag_cov_val: diagonal elements of Cholesky Decomposition of 356 | covariance of target value 357 | 358 | off_diag_cov_val: None, or off-diagonal elements of Cholesky 359 | Decomposition of covariance of target value 360 | """ 361 | 362 | # Shape of sampled_lat_var: 363 | # [num_traj, num_agg, num_smp, dim_lat] 364 | # 365 | # Shape of variance_lat_var: 366 | # [num_traj, num_agg, num_smp, dim_lat] 367 | # 368 | # Shape of add_inputs: 369 | # [num_traj, num_time_pts, dim_add_in=1] if add_inputs not None 370 | # 371 | # Shape of mean_val: 372 | # [num_traj, num_agg, num_smp, num_time_pts, dim_val] 373 | # 374 | # Shape of diag_cov_val: 375 | # [num_traj, num_agg, num_smp, num_time_pts, dim_val] 376 | # 377 | # Shape of off_diag_cov_val: 378 | # [num_traj, num_agg, num_smp, num_time_pts, 379 | # (dim_val * (dim_val - 1) // 2)] 380 | 381 | # Dimension check 382 | assert sampled_lat_var.ndim == variance_lat_var.ndim == 4 383 | num_agg = sampled_lat_var.shape[1] 384 | num_smp = sampled_lat_var.shape[2] 385 | 386 | # Process add_inputs 387 | if add_inputs is not None: 388 | assert add_inputs.ndim == 3 389 | # Get dimensions 390 | num_time_pts = add_inputs.shape[1] 391 | # Add one axis (aggregation-wise batch dimension) to add_inputs 392 | add_inputs = util.add_expand_dim(add_inputs, [1, 2], 393 | [num_agg, num_smp]) 394 | 395 | else: 396 | num_time_pts = 1 397 | 398 | # Parametrize variance 399 | variance_lat_var = util.to_log_space(variance_lat_var, lower_bound=None) 400 | 401 | # Add one axis (time-scale-wise batch dimension) to latent observation 402 | sampled_lat_var = util.add_expand_dim(sampled_lat_var, 403 | [-2], [num_time_pts]) 404 | variance_lat_var = util.add_expand_dim(variance_lat_var, 405 | [-2], [num_time_pts]) 406 | 407 | # Prepare input to decoder network 408 | mean_net_input = sampled_lat_var 409 | cov_net_input = variance_lat_var 410 | if add_inputs is not None: 411 | mean_net_input = torch.cat((add_inputs, sampled_lat_var), dim=-1) 412 | cov_net_input = torch.cat((add_inputs, variance_lat_var), dim=-1) 413 | 414 | # Decode 415 | mean_val = self.mean_val_net(mean_net_input) 416 | cov_val = self.cov_val_net(cov_net_input) 417 | 418 | # Process cov net prediction 419 | diag_cov_val, off_diag_cov_val = self._process_cov_net_output(cov_val) 420 | 421 | # Return 422 | return mean_val, diag_cov_val, off_diag_cov_val 423 | 424 | 425 | class DecoderFactory: 426 | 427 | @staticmethod 428 | def get_decoder(decoder_type: str, **kwargs): 429 | return eval(decoder_type + "(**kwargs)") 430 | -------------------------------------------------------------------------------- /nmp/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Encoder classes in PyTorch 3 | """ 4 | from abc import ABC 5 | from abc import abstractmethod 6 | 7 | from nmp.nn_base import CNNMLP 8 | from nmp.nn_base import MLP 9 | from nmp.util import mlp_arch_3_params 10 | from nmp.util import to_softplus_space 11 | 12 | 13 | class ProNMPEncoder(ABC): 14 | 15 | def __init__(self, **kwargs): 16 | """ 17 | NMP encoder constructor 18 | Args: 19 | **kwargs: Encoder configuration 20 | """ 21 | 22 | # Encoders 23 | self.lat_mean_net = None 24 | self.lat_var_net = None 25 | self._create_network(**kwargs) 26 
| 27 | @abstractmethod 28 | def _create_network(self, *args, **kwargs): 29 | """ 30 | Create encoder with given configuration 31 | 32 | Returns: 33 | None 34 | """ 35 | pass 36 | 37 | @property 38 | def network(self): 39 | """ 40 | Return encoder networks 41 | 42 | Returns: 43 | """ 44 | return self.lat_mean_net, self.lat_var_net 45 | 46 | @property 47 | def parameters(self): 48 | """ 49 | Get network parameters 50 | Returns: 51 | parameters 52 | """ 53 | return list(self.lat_mean_net.parameters()) + \ 54 | list(self.lat_var_net.parameters()) 55 | 56 | def save_weights(self, log_dir: str, epoch: int): 57 | """ 58 | Save NN weights to file 59 | Args: 60 | log_dir: directory to save weights to 61 | epoch: training epoch 62 | 63 | Returns: 64 | None 65 | """ 66 | self.lat_mean_net.save(log_dir, epoch) 67 | self.lat_var_net.save(log_dir, epoch) 68 | 69 | def load_weights(self, log_dir: str, epoch: int): 70 | """ 71 | Load NN weights from file 72 | Args: 73 | log_dir: directory stored weights 74 | epoch: training epoch 75 | 76 | Returns: 77 | None 78 | """ 79 | self.lat_mean_net.load(log_dir, epoch) 80 | self.lat_var_net.load(log_dir, epoch) 81 | 82 | @abstractmethod 83 | def encode(self, *args, **kwargs): 84 | """ 85 | Encode observations 86 | 87 | Returns: 88 | lat_obs: latent observations 89 | var_lat_obs: variance of latent observations 90 | """ 91 | pass 92 | 93 | 94 | class ProNMPEncoderMlp(ProNMPEncoder): 95 | def _create_network(self, **kwargs): 96 | """ 97 | Create encoder with given configuration 98 | 99 | Returns: 100 | None 101 | """ 102 | # MLP configuration 103 | self.name: str = kwargs["name"] 104 | self.dim_obs: int = kwargs["dim_obs"] 105 | self.dim_lat_obs: int = kwargs["dim_lat"] 106 | 107 | self.obs_hidden: dict = kwargs["obs_hidden"] 108 | self.unc_hidden: dict = kwargs["unc_hidden"] 109 | 110 | self.act_func: str = kwargs["act_func"] 111 | 112 | # Two separate latent observation encoders 113 | # lat_mean_net + lat_var_net 114 | self.lat_mean_net = MLP(name="ProNMPEncoder_lat_mean_" + self.name, 115 | dim_in=self.dim_obs, 116 | dim_out=self.dim_lat_obs, 117 | hidden_layers= 118 | mlp_arch_3_params(**self.obs_hidden), 119 | act_func=self.act_func) 120 | 121 | self.lat_var_net = \ 122 | MLP(name="ProNMPEncoder_lat_var_" + self.name, 123 | dim_in=self.dim_obs, 124 | dim_out=self.dim_lat_obs, 125 | hidden_layers=mlp_arch_3_params(**self.unc_hidden), 126 | act_func=self.act_func) 127 | 128 | def encode(self, obs): 129 | """ 130 | Encode observations 131 | 132 | Args: 133 | obs: observations 134 | 135 | Returns: 136 | lat_obs: latent observations 137 | var_lat_obs: variance of latent observations 138 | """ 139 | 140 | # Shape of obs: 141 | # [num_traj, num_obs, dim_obs], 142 | # 143 | # Shape of lat_obs: 144 | # [num_traj, num_obs, dim_lat] 145 | # 146 | # Shape of var_lat_obs: 147 | # [num_traj, num_obs, dim_lat] 148 | 149 | # Check input shapes 150 | assert obs.ndim == 3 151 | 152 | # Encode 153 | return self.lat_mean_net(obs), \ 154 | to_softplus_space(self.lat_var_net(obs), lower_bound=None) 155 | 156 | 157 | class ProNMPEncoderCnnMlp(ProNMPEncoder): 158 | def _create_network(self, **kwargs): 159 | """ 160 | Create encoder with given configuration 161 | 162 | Returns: 163 | None 164 | """ 165 | # configuration 166 | self.name = kwargs["name"] 167 | self.image_size = kwargs["image_size"] 168 | self.kernel_size = kwargs["kernel_size"] 169 | self.num_cnn = kwargs["num_cnn"] 170 | self.cnn_channels = kwargs["cnn_channels"] 171 | self.dim_lat_obs: int = kwargs["dim_lat"] 172 | 
self.obs_hidden: dict = kwargs["obs_hidden"] 173 | self.unc_hidden: dict = kwargs["unc_hidden"] 174 | self.act_func = kwargs["act_func"] 175 | 176 | # Two separate latent observation encoders 177 | # lat_mean_net + lat_var_net 178 | 179 | self.lat_mean_net = CNNMLP(name="ProNMPEncoder_lat_mean_" + self.name, 180 | image_size=self.image_size, 181 | kernel_size=self.kernel_size, 182 | num_cnn=self.num_cnn, 183 | cnn_channels=self.cnn_channels, 184 | hidden_layers= 185 | mlp_arch_3_params(**self.obs_hidden), 186 | dim_out=self.dim_lat_obs, 187 | act_func=self.act_func) 188 | 189 | self.lat_var_net = CNNMLP(name="ProNMPEncoder_lat_var_" + self.name, 190 | image_size=self.image_size, 191 | kernel_size=self.kernel_size, 192 | num_cnn=self.num_cnn, 193 | cnn_channels=self.cnn_channels, 194 | hidden_layers= 195 | mlp_arch_3_params(**self.unc_hidden), 196 | dim_out=self.dim_lat_obs, 197 | act_func=self.act_func) 198 | 199 | def encode(self, obs): 200 | """ 201 | Encode observations 202 | 203 | Args: 204 | obs: observations 205 | 206 | Returns: 207 | lat_obs: latent observations 208 | var_lat_obs: variance of latent observations 209 | """ 210 | 211 | # Shape of obs: 212 | # [num_traj, num_obs, C, H, W], 213 | # 214 | # Shape of lat_obs: 215 | # [num_traj, num_obs, dim_lat] 216 | # 217 | # Shape of var_lat_obs: 218 | # [num_traj, num_obs, dim_lat] 219 | 220 | # Check input shapes 221 | assert obs.ndim == 5 222 | 223 | # Encode 224 | return self.lat_mean_net(obs), \ 225 | to_softplus_space(self.lat_var_net(obs), lower_bound=None) 226 | 227 | 228 | class CNMPEncoder(ABC): 229 | 230 | def __init__(self, **kwargs): 231 | """ 232 | CNMP encoder constructor 233 | 234 | Args: 235 | **kwargs: Encoder configuration 236 | """ 237 | 238 | # Encoder 239 | self.lat_obs_net = None 240 | self._create_network(**kwargs) 241 | 242 | @abstractmethod 243 | def _create_network(self, *args, **kwargs): 244 | """ 245 | Create encoder network with given configuration 246 | 247 | Returns: 248 | None 249 | """ 250 | pass 251 | 252 | @property 253 | def network(self): 254 | """ 255 | Return encoder network 256 | 257 | Returns: 258 | """ 259 | return self.lat_obs_net 260 | 261 | @property 262 | def parameters(self): 263 | """ 264 | Get network parameters 265 | Returns: 266 | parameters 267 | """ 268 | 269 | return list(self.lat_obs_net.parameters()) 270 | 271 | def save_weights(self, log_dir: str, epoch: int): 272 | """ 273 | Save NN weights to file 274 | Args: 275 | log_dir: directory to save weights to 276 | epoch: training epoch 277 | 278 | Returns: 279 | None 280 | """ 281 | 282 | self.lat_obs_net.save(log_dir, epoch) 283 | 284 | def load_weights(self, log_dir: str, epoch: int): 285 | """ 286 | Load NN weights from file 287 | Args: 288 | log_dir: directory stored weights 289 | epoch: training epoch 290 | 291 | Returns: 292 | None 293 | """ 294 | self.lat_obs_net.load(log_dir, epoch) 295 | 296 | @abstractmethod 297 | def encode(self, *args, **kwargs): 298 | """ 299 | Encode observations 300 | 301 | Returns: 302 | lat_obs: latent observations 303 | """ 304 | pass 305 | 306 | 307 | class CNMPEncoderMlp(CNMPEncoder): 308 | 309 | def _create_network(self, **kwargs): 310 | """ 311 | Create encoder network with given configuration 312 | 313 | Returns: 314 | None 315 | """ 316 | 317 | # MLP configuration 318 | self.name: str = kwargs["name"] 319 | self.dim_obs: int = kwargs["dim_obs"] 320 | self.dim_lat_obs: int = kwargs["dim_lat"] 321 | self.obs_hidden: dict = kwargs["obs_hidden"] 322 | self.act_func: str = kwargs["act_func"] 323 | 
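        # Construction sketch (hypothetical values; the real ones live in the
        # YAML files under nmp/config): `mlp_arch_3_params` expands the
        # 3-parameter hidden-layer spec in `obs_hidden` into the explicit
        # layer list consumed by the MLP below. Roughly:
        #
        #   enc = CNMPEncoderMlp(name="ctx",
        #                        dim_obs=17,        # e.g. time + state dims
        #                        dim_lat=128,
        #                        obs_hidden={...},  # 3-param arch spec
        #                        act_func="leaky_relu")
        #   lat_obs = enc.encode(obs)  # obs: [num_traj, num_obs, dim_obs]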
324 | self.lat_obs_net = MLP(name="CNMPEncoder_lat_obs_" + self.name, 325 | dim_in=self.dim_obs, 326 | dim_out=self.dim_lat_obs, 327 | hidden_layers= 328 | mlp_arch_3_params(**self.obs_hidden), 329 | act_func=self.act_func) 330 | 331 | def encode(self, obs): 332 | """ 333 | Encode observations 334 | 335 | Args: 336 | obs: observations 337 | 338 | Returns: 339 | lat_obs: latent observations 340 | """ 341 | 342 | # Shape of obs: 343 | # [num_traj, num_obs, dim_obs], 344 | # 345 | # Shape of lat_obs: 346 | # [num_traj, num_obs, dim_lat] 347 | 348 | # Check input shapes 349 | assert obs.ndim == 3 350 | 351 | # Encode 352 | return self.lat_obs_net(obs) 353 | 354 | 355 | class CNMPEncoderCnnMlp(CNMPEncoder): 356 | 357 | def _create_network(self, **kwargs): 358 | """ 359 | Create encoder with given configuration 360 | 361 | Returns: 362 | None 363 | """ 364 | 365 | # configuration 366 | self.name = kwargs["name"] 367 | self.image_size = kwargs["image_size"] 368 | self.kernel_size = kwargs["kernel_size"] 369 | self.num_cnn = kwargs["num_cnn"] 370 | self.cnn_channels = kwargs["cnn_channels"] 371 | self.dim_lat_obs: int = kwargs["dim_lat"] 372 | self.obs_hidden: dict = kwargs["obs_hidden"] 373 | self.act_func = kwargs["act_func"] 374 | 375 | # lat_obs_net 376 | self.lat_obs_net = CNNMLP(name="CNMPEncoder_lat_obs_" + self.name, 377 | image_size=self.image_size, 378 | kernel_size=self.kernel_size, 379 | num_cnn=self.num_cnn, 380 | cnn_channels=self.cnn_channels, 381 | hidden_layers= 382 | mlp_arch_3_params(**self.obs_hidden), 383 | dim_out=self.dim_lat_obs, 384 | act_func=self.act_func) 385 | 386 | def encode(self, obs): 387 | """ 388 | Encode observations 389 | 390 | Args: 391 | obs: observations 392 | 393 | Returns: 394 | lat_obs: latent observations 395 | """ 396 | 397 | # Shape of obs: 398 | # [num_traj, num_obs, C, H, W], 399 | # 400 | # Shape of lat_obs: 401 | # [num_traj, num_obs, dim_lat] 402 | # 403 | # Shape of var_lat_obs: 404 | # [num_traj, num_obs, dim_lat] 405 | 406 | # Check input shapes 407 | assert obs.ndim == 5 408 | 409 | # Encode 410 | return self.lat_obs_net(obs) 411 | 412 | 413 | class EncoderFactory: 414 | 415 | @staticmethod 416 | def get_encoder(encoder_type: str, **kwargs): 417 | return eval(encoder_type + "(**kwargs)") 418 | 419 | @staticmethod 420 | def get_encoders(**config) -> dict: 421 | encoder_dict = dict() 422 | for encoder_name, encoder_info in config.items(): 423 | encoder_info["args"]["name"] = encoder_name 424 | encoder = EncoderFactory.get_encoder(encoder_info["type"], 425 | **(encoder_info["args"])) 426 | encoder_dict[encoder_name] = encoder 427 | return encoder_dict 428 | -------------------------------------------------------------------------------- /nmp/experiment/digit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/ProDMP_RAL/78063bdb4c9ad04e8a16d7b5a14d4077de774082/nmp/experiment/digit/__init__.py -------------------------------------------------------------------------------- /nmp/experiment/digit/digit.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from mp_pytorch.mp import MPFactory 7 | from nmp import get_data_loaders_and_normalizer 8 | from nmp import nll_loss 9 | from nmp import select_ctx_pred_pts 10 | from nmp import util 11 | from nmp.aggregator import AggregatorFactory 12 | from nmp.data_process import NormProcess 13 | from nmp.decoder import 
DecoderFactory 14 | from nmp.encoder import EncoderFactory 15 | from nmp.net import MPNet 16 | from nmp.others.ellipses_noise import EllipseNoiseTransform 17 | 18 | 19 | class OneDigit: 20 | def __init__(self, cfg): 21 | random.seed(cfg["seed"]) 22 | np.random.seed(cfg["seed"]) 23 | torch.manual_seed(cfg["seed"]) 24 | 25 | # Net 26 | self.encoder_dict = EncoderFactory.get_encoders(**cfg["encoders"]) 27 | self.aggregator = AggregatorFactory.get_aggregator(cfg["aggregator"]["type"], 28 | **cfg["aggregator"]["args"]) 29 | self.decoder = DecoderFactory.get_decoder(cfg["decoder"]["type"], 30 | **cfg["decoder"]["args"]) 31 | self.net = MPNet(self.encoder_dict, self.aggregator, self.decoder) 32 | 33 | # Dataset and Dataloader 34 | dataset = util.load_npz_dataset(cfg["dataset"]["name"]) 35 | self.train_loader, self.vali_loader, self.test_loader, self.normalizer \ 36 | = get_data_loaders_and_normalizer(dataset, **cfg["dataset"], 37 | seed=cfg["seed"]) 38 | 39 | # Data assignment config 40 | self.assign_config = cfg["assign_config"] 41 | 42 | # Reconstructor 43 | self.mp = MPFactory.init_mp(device="cuda", **cfg["mp"]) 44 | 45 | # Optimizer 46 | self.optimizer = torch.optim.Adam(params=self.net.get_net_params(), 47 | lr=float(cfg["lr"]), 48 | weight_decay=float(cfg["wd"])) 49 | self.net_params = self.net.get_net_params() 50 | 51 | # Runtime noise 52 | self.runtime_noise = cfg.get("runtime_noise", False) 53 | 54 | # Denormalize 55 | self.denormalize = cfg.get("denormalize", True) 56 | 57 | # Zero start 58 | self.zero_start = cfg.get("zero_start", False) 59 | 60 | def compute_loss(self, batch): 61 | 62 | if self.runtime_noise: 63 | batch = self.add_runtime_noise(batch) 64 | 65 | _, pred_index = select_ctx_pred_pts(**self.assign_config) 66 | pred_pairs = torch.combinations(pred_index, 2) 67 | 68 | num_traj = batch["images"]["value"].shape[0] 69 | num_agg = batch["images"]["value"].shape[1] + 1 70 | num_pred_pairs = pred_pairs.shape[0] 71 | 72 | # Get encoder input 73 | num_total_ctx = batch["images"]["value"].shape[1] 74 | if num_total_ctx == 1: 75 | ctx = {"cnn": batch["images"]["value"]} 76 | elif num_total_ctx > 1: 77 | # Remove original img in noise case 78 | # Only use the first 3 noisy images 79 | ctx = {"cnn": batch["images"]["value"][:, :-1]} 80 | num_agg -= 1 81 | 82 | # Reconstructor input 83 | 84 | init_time = torch.zeros(num_traj, num_agg, num_pred_pairs) 85 | init_vel = torch.zeros(num_traj, num_agg, num_pred_pairs, self.mp.num_dof) 86 | times = util.add_expand_dim(batch["trajs"]["time"][:, pred_pairs], 87 | add_dim_indices=[1], 88 | add_dim_sizes=[num_agg]) 89 | 90 | # Ground-truth 91 | gt = util.add_expand_dim(batch["trajs"]["value"][:, pred_pairs], 92 | add_dim_indices=[1], add_dim_sizes=[num_agg]) 93 | # Switch the time and dof dimension 94 | gt = torch.einsum('...ji->...ij', gt) 95 | # Make the time and dof dimensions flat 96 | gt = gt.reshape(*gt.shape[:-2], -1) 97 | 98 | # Predict 99 | mean, diag, off_diag = self.net.predict(num_traj=num_traj, 100 | enc_inputs=ctx, 101 | dec_input=None) 102 | # Denormalize prediction 103 | if self.denormalize: 104 | mean, L = NormProcess.distribution_denormalize(self.normalizer, 105 | "init_x_y_dmp_w_g", 106 | mean, diag, off_diag) 107 | else: 108 | L = util.build_lower_matrix(diag, off_diag) 109 | 110 | # Split initial position and DMP weights 111 | start_point = mean[..., 0, :self.mp.num_dof] 112 | mean = mean[..., self.mp.num_dof:].squeeze(-2) 113 | L = L[..., self.mp.num_dof:, self.mp.num_dof:].squeeze(-3) 114 | assert mean.ndim == 3 115 | 116 | 
# Add dim of time group 117 | mean = util.add_expand_dim(data=mean, 118 | add_dim_indices=[-2], 119 | add_dim_sizes=[num_pred_pairs]) 120 | L = util.add_expand_dim(data=L, 121 | add_dim_indices=[-3], 122 | add_dim_sizes=[num_pred_pairs]) 123 | start_point = util.add_expand_dim(data=start_point, 124 | add_dim_indices=[-2], 125 | add_dim_sizes=[num_pred_pairs]) 126 | 127 | # Reconstruct predicted trajectories 128 | if self.zero_start: 129 | init_pos = torch.zeros_like(start_point) 130 | else: 131 | init_pos = start_point 132 | self.mp.update_inputs(times=times, params=mean, params_L=L, 133 | init_time=init_time, init_pos=init_pos, 134 | init_vel=init_vel) 135 | traj_pos_mean = self.mp.get_traj_pos(flat_shape=True) 136 | 137 | if self.zero_start: 138 | start_point = util.add_expand_dim(start_point, [-1], 139 | [times.shape[-1]]) 140 | traj_pos_mean += start_point.reshape(*start_point.shape[:-2], -1) 141 | 142 | traj_pos_L = torch.linalg.cholesky(self.mp.get_traj_pos_cov()) 143 | 144 | # Loss 145 | loss = nll_loss(gt, traj_pos_mean, traj_pos_L) 146 | return loss 147 | 148 | @staticmethod 149 | def add_runtime_noise(batch): 150 | transform = EllipseNoiseTransform() 151 | batch["images"]["value"] = transform(batch["images"]["value"]) 152 | return batch 153 | -------------------------------------------------------------------------------- /nmp/experiment/digit/digit_cw.py: -------------------------------------------------------------------------------- 1 | from cw2 import cluster_work 2 | from cw2 import cw_error 3 | from cw2 import experiment 4 | from cw2.cw_data import cw_logging 5 | from cw2.cw_data.cw_wandb_logger import WandBLogger 6 | 7 | from nmp import util 8 | from nmp.experiment.digit.digit import OneDigit 9 | from nmp.net import avg_batch_loss 10 | 11 | 12 | class OneDigitCW(experiment.AbstractIterativeExperiment): 13 | def initialize(self, cw_config: dict, 14 | rep: int, logger: cw_logging.LoggerArray) -> None: 15 | # Device 16 | util.use_cuda() 17 | 18 | # Random seed 19 | cfg = cw_config["params"] 20 | 21 | self.exp = OneDigit(cfg) 22 | 23 | # Log interval 24 | self.vali_interval = cfg["vali_log_interval"] 25 | self.save_model_interval = cfg["save_model_interval"] 26 | 27 | # Logger 28 | self.save_model_dir = cw_config.get("save_model_dir", None) 29 | if self.save_model_dir: 30 | util.remove_file_dir(self.save_model_dir) 31 | util.mkdir(self.save_model_dir, overwrite=True) 32 | # Save configuration 33 | util.dump_config(dict(cfg), "config", self.save_model_dir) 34 | 35 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 36 | max_norm = cw_config["params"].get("max_norm", None) 37 | train_loss = avg_batch_loss(self.exp.train_loader, 38 | self.exp.compute_loss, self.exp.optimizer, 39 | self.exp.net_params, max_norm) 40 | 41 | if n % self.vali_interval == 0: 42 | vali_loss = avg_batch_loss(self.exp.vali_loader, 43 | self.exp.compute_loss, None, None, None) 44 | else: 45 | vali_loss = None 46 | print(n) 47 | 48 | return {"train_loss": train_loss, "vali_loss": vali_loss} 49 | 50 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 51 | if self.save_model_dir and ((n + 1) % self.save_model_interval == 0 52 | or (n + 1) == cw_config["iterations"]): 53 | self.exp.net.save_weights(log_dir=self.save_model_dir, 54 | epoch=n + 1) 55 | 56 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, 57 | crash: bool = False): 58 | pass 59 | 60 | 61 | if __name__ == "__main__": 62 | cw = cluster_work.ClusterWork(OneDigitCW) 63 | cw.add_logger(WandBLogger()) 64 | 
cw.run() 65 | -------------------------------------------------------------------------------- /nmp/experiment/robot_push/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALRhub/ProDMP_RAL/78063bdb4c9ad04e8a16d7b5a14d4077de774082/nmp/experiment/robot_push/__init__.py -------------------------------------------------------------------------------- /nmp/experiment/robot_push/robot_push_bhc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from cw2 import cluster_work 6 | from cw2 import cw_error 7 | from cw2 import experiment 8 | from cw2.cw_data import cw_logging 9 | from cw2.cw_data.cw_wandb_logger import WandBLogger 10 | 11 | from nmp import BehaviorCloningNet 12 | from nmp import get_data_loaders_and_normalizer 13 | from nmp import mse_loss 14 | from nmp import util 15 | from nmp.data_process import NormProcess 16 | from nmp.encoder import EncoderFactory 17 | from nmp.net import avg_batch_loss 18 | 19 | 20 | class RobotPushBehaviorCloning(experiment.AbstractIterativeExperiment): 21 | def initialize(self, cw_config: dict, 22 | rep: int, logger: cw_logging.LoggerArray) -> None: 23 | # Device 24 | util.use_cuda() 25 | 26 | # Random seed 27 | cfg = cw_config["params"] 28 | cfg["seed"] = rep 29 | random.seed(cfg["seed"]) 30 | np.random.seed(cfg["seed"]) 31 | torch.manual_seed(cfg["seed"]) 32 | 33 | # Net 34 | self.net = BehaviorCloningNet( 35 | EncoderFactory.get_encoder(cfg["mlp_net"]["type"], 36 | **cfg["mlp_net"]["args"])) 37 | 38 | # Dataset and Dataloader 39 | dataset = util.load_npz_dataset(cfg["dataset"]["name"]) 40 | self.train_loader, self.vali_loader, self.test_loader, self.normalizer \ 41 | = get_data_loaders_and_normalizer(dataset, **cfg["dataset"], 42 | seed=cfg["seed"]) 43 | 44 | # Optimizer 45 | self.optimizer = torch.optim.Adam(params=self.net.get_net_params(), 46 | lr=float(cfg["lr"]), 47 | weight_decay=float(cfg["wd"])) 48 | self.net_params = self.net.get_net_params() 49 | 50 | # Log interval 51 | self.vali_interval = cfg["vali_log_interval"] 52 | self.save_model_interval = cfg["save_model_interval"] 53 | 54 | # Logger 55 | self.save_model_dir = cw_config.get("save_model_dir", None) 56 | if self.save_model_dir: 57 | util.remove_file_dir(self.save_model_dir) 58 | util.mkdir(self.save_model_dir, overwrite=True) 59 | 60 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 61 | max_norm = cw_config["params"].get("max_norm", None) 62 | train_loss = \ 63 | avg_batch_loss(self.train_loader, self.compute_loss, self.optimizer, 64 | self.net_params, max_norm) 65 | 66 | if n % self.vali_interval == 0: 67 | vali_loss = avg_batch_loss(self.vali_loader, self.compute_loss, 68 | None, None, None) 69 | else: 70 | vali_loss = None 71 | print(n) 72 | 73 | return {"train_loss": train_loss, "vali_loss": vali_loss} 74 | 75 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 76 | if self.save_model_dir and ((n + 1) % self.save_model_interval == 0 77 | or (n + 1) == cw_config["iterations"]): 78 | self.net.save_weights(log_dir=self.save_model_dir, 79 | epoch=n + 1) 80 | 81 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, 82 | crash: bool = False): 83 | torch.cuda.empty_cache() 84 | 85 | def compute_loss(self, batch): 86 | # Choose the point to start obs 87 | 88 | # Get net input 89 | ctx_times = batch["box_robot_state"]["time"][..., None] 90 | ctx_values = 
batch["box_robot_state"]["value"] 91 | norm_ctx_dict = NormProcess.batch_normalize(self.normalizer, 92 | {"box_robot_state": 93 | {"time": ctx_times, 94 | "value": ctx_values}}) 95 | ctx_times = norm_ctx_dict["box_robot_state"]["time"] 96 | ctx_values = norm_ctx_dict["box_robot_state"]["value"] 97 | 98 | ctx = torch.cat([ctx_times, ctx_values[..., :-2]], dim=-1) 99 | # ctx = ctx_values[..., :-2] 100 | 101 | # Ground-truth, dof=2 102 | gt = batch["des_cart_pos_vel"]["value"][..., :2] 103 | 104 | # Predict 105 | pred_traj = self.net.predict(net_input=ctx) 106 | 107 | # Loss 108 | loss = mse_loss(gt, pred_traj) 109 | 110 | return loss 111 | 112 | 113 | if __name__ == "__main__": 114 | cw = cluster_work.ClusterWork(RobotPushBehaviorCloning) 115 | 116 | # Optional: Add loggers 117 | cw.add_logger(WandBLogger()) 118 | cw.run() 119 | -------------------------------------------------------------------------------- /nmp/experiment/robot_push/robot_push_cnmp.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from cw2 import cluster_work 6 | from cw2 import cw_error 7 | from cw2 import experiment 8 | from cw2.cw_data import cw_logging 9 | from cw2.cw_data.cw_wandb_logger import WandBLogger 10 | from nmp import get_data_loaders_and_normalizer 11 | from nmp import nll_loss 12 | from nmp import util 13 | from nmp.aggregator import AggregatorFactory 14 | from nmp.data_process import NormProcess 15 | from nmp.decoder import DecoderFactory 16 | from nmp.encoder import EncoderFactory 17 | from nmp.net import MPNet 18 | from nmp.net import avg_batch_loss 19 | 20 | 21 | class RobotPushCNMP(experiment.AbstractIterativeExperiment): 22 | def initialize(self, cw_config: dict, 23 | rep: int, logger: cw_logging.LoggerArray) -> None: 24 | # Device 25 | util.use_cuda() 26 | 27 | # Random seed 28 | cfg = cw_config["params"] 29 | cfg["seed"] = rep 30 | random.seed(cfg["seed"]) 31 | np.random.seed(cfg["seed"]) 32 | torch.manual_seed(cfg["seed"]) 33 | 34 | # Net 35 | self.encoder_dict = EncoderFactory.get_encoders(**cfg["encoders"]) 36 | self.aggregator = AggregatorFactory.get_aggregator( 37 | cfg["aggregator"]["type"], 38 | **cfg["aggregator"]["args"]) 39 | self.decoder = DecoderFactory.get_decoder(cfg["decoder"]["type"], 40 | **cfg["decoder"]["args"]) 41 | self.net = MPNet(self.encoder_dict, self.aggregator, self.decoder) 42 | 43 | # Dataset and Dataloader 44 | dataset = util.load_npz_dataset(cfg["dataset"]["name"]) 45 | self.train_loader, self.vali_loader, self.test_loader, self.normalizer \ 46 | = get_data_loaders_and_normalizer(dataset, **cfg["dataset"], 47 | seed=cfg["seed"]) 48 | 49 | # Data assignment config 50 | self.assign_config = cfg["assign_config"] 51 | 52 | # Optimizer 53 | self.optimizer = torch.optim.Adam(params=self.net.get_net_params(), 54 | lr=float(cfg["lr"]), 55 | weight_decay=float(cfg["wd"])) 56 | self.net_params = self.net.get_net_params() 57 | 58 | # Log interval 59 | self.vali_interval = cfg["vali_log_interval"] 60 | self.save_model_interval = cfg["save_model_interval"] 61 | 62 | # Logger 63 | self.save_model_dir = cw_config.get("save_model_dir", None) 64 | if self.save_model_dir: 65 | util.remove_file_dir(self.save_model_dir) 66 | util.mkdir(self.save_model_dir, overwrite=True) 67 | 68 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 69 | max_norm = cw_config["params"].get("max_norm", None) 70 | train_loss = \ 71 | avg_batch_loss(self.train_loader, self.compute_loss, 
self.optimizer, 72 | self.net_params, max_norm) 73 | 74 | if n % self.vali_interval == 0: 75 | vali_loss = avg_batch_loss(self.vali_loader, self.compute_loss, 76 | None, None, None) 77 | else: 78 | vali_loss = None 79 | print(n) 80 | 81 | return {"train_loss": train_loss, "vali_loss": vali_loss} 82 | 83 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 84 | if self.save_model_dir and ((n + 1) % self.save_model_interval == 0 85 | or (n + 1) == cw_config["iterations"]): 86 | self.net.save_weights(log_dir=self.save_model_dir, 87 | epoch=n + 1) 88 | 89 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, 90 | crash: bool = False): 91 | torch.cuda.empty_cache() 92 | 93 | def compute_loss(self, batch): 94 | # Choose the point to start obs 95 | num_ctx_min = self.assign_config["num_ctx_min"] 96 | num_ctx_max = self.assign_config["num_ctx_max"] 97 | num_ctx = torch.randint(num_ctx_min, num_ctx_max + 1, []) 98 | pred_range_min = self.assign_config["pred_range_min"] 99 | pred_range_max = self.assign_config["pred_range_max"] 100 | pred_range = torch.randint(pred_range_min, pred_range_max + 1, []) 101 | # pred_step = pred_range // 10 + 1 102 | # num_total = num_ctx + pred_range.item() 103 | start_obs_idx_max = 301 - num_ctx - 10 104 | start_obs_idx = torch.randint(0, start_obs_idx_max + 1, [], 105 | dtype=torch.long) 106 | ctx_index = torch.arange(start_obs_idx, start_obs_idx + num_ctx, step=1, 107 | dtype=torch.long) 108 | pred_index_max = min(300, ctx_index[-1] + 1 + pred_range) 109 | pred_index = torch.arange(ctx_index[-1] + 1, pred_index_max, 110 | dtype=torch.long) 111 | 112 | # Get encoder input 113 | time_ctx_last = batch["box_robot_state"]["time"][:, ctx_index[-1]] 114 | ctx_times = (batch["box_robot_state"]["time"] 115 | - time_ctx_last[:, None])[:, ctx_index][..., None] 116 | ctx_values = batch["box_robot_state"]["value"][:, ctx_index] 117 | norm_ctx_dict = NormProcess.batch_normalize(self.normalizer, 118 | {"box_robot_state": 119 | {"time": ctx_times, 120 | "value": ctx_values}}) 121 | # ctx_times = norm_ctx_dict["box_robot_state"]["time"] 122 | ctx_values = norm_ctx_dict["box_robot_state"]["value"] 123 | 124 | ctx = {"ctx": torch.cat([ctx_times, ctx_values], dim=-1)} 125 | 126 | num_traj = batch["box_robot_state"]["value"].shape[0] 127 | 128 | times = (batch["des_cart_pos_vel"]["time"] - time_ctx_last[:, None])[:, 129 | pred_index] 130 | 131 | # Ground-truth, dof=2 132 | gt = batch["des_cart_pos_vel"]["value"][:, pred_index, :2] 133 | 134 | # Predict 135 | mean, diag, off_diag = self.net.predict(num_traj=num_traj, 136 | enc_inputs=ctx, 137 | dec_input=times[..., None]) 138 | assert off_diag is None 139 | 140 | # Denormalize prediction 141 | L = util.build_lower_matrix(diag, off_diag) 142 | 143 | mean = mean.squeeze(1) 144 | L = L.squeeze(1) 145 | 146 | assert mean.ndim == 3 147 | 148 | # Loss 149 | loss = nll_loss(gt, mean, L) 150 | return loss 151 | 152 | 153 | if __name__ == "__main__": 154 | cw = cluster_work.ClusterWork(RobotPushCNMP) 155 | 156 | # Optional: Add loggers 157 | cw.add_logger(WandBLogger()) 158 | cw.run() 159 | -------------------------------------------------------------------------------- /nmp/experiment/robot_push/robot_push_prodmp.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from cw2 import cluster_work 6 | from cw2 import cw_error 7 | from cw2 import experiment 8 | from cw2.cw_data import cw_logging 9 | from 
cw2.cw_data.cw_wandb_logger import WandBLogger 10 | from mp_pytorch.mp import MPFactory 11 | from nmp import get_data_loaders_and_normalizer 12 | from nmp import nll_loss 13 | from nmp import util 14 | from nmp.aggregator import AggregatorFactory 15 | from nmp.data_process import NormProcess 16 | from nmp.decoder import DecoderFactory 17 | from nmp.encoder import EncoderFactory 18 | from nmp.net import MPNet 19 | from nmp.net import avg_batch_loss 20 | 21 | 22 | class RobotPush(experiment.AbstractIterativeExperiment): 23 | def initialize(self, cw_config: dict, 24 | rep: int, logger: cw_logging.LoggerArray) -> None: 25 | # Device 26 | util.use_cuda() 27 | 28 | # Random seed 29 | cfg = cw_config["params"] 30 | cfg["seed"] = rep 31 | random.seed(cfg["seed"]) 32 | np.random.seed(cfg["seed"]) 33 | torch.manual_seed(cfg["seed"]) 34 | 35 | # Net 36 | self.encoder_dict = EncoderFactory.get_encoders(**cfg["encoders"]) 37 | self.aggregator = AggregatorFactory.get_aggregator( 38 | cfg["aggregator"]["type"], 39 | **cfg["aggregator"]["args"]) 40 | self.decoder = DecoderFactory.get_decoder(cfg["decoder"]["type"], 41 | **cfg["decoder"]["args"]) 42 | self.net = MPNet(self.encoder_dict, self.aggregator, self.decoder) 43 | 44 | # Dataset and Dataloader 45 | dataset = util.load_npz_dataset(cfg["dataset"]["name"]) 46 | self.train_loader, self.vali_loader, self.test_loader, self.normalizer \ 47 | = get_data_loaders_and_normalizer(dataset, **cfg["dataset"], 48 | seed=cfg["seed"]) 49 | 50 | # Data assignment config 51 | self.assign_config = cfg["assign_config"] 52 | 53 | # Reconstructor 54 | self.mp = MPFactory.init_mp(device="cuda", **cfg["mp"]) 55 | 56 | # Optimizer 57 | self.optimizer = torch.optim.Adam(params=self.net.get_net_params(), 58 | lr=float(cfg["lr"]), 59 | weight_decay=float(cfg["wd"])) 60 | self.net_params = self.net.get_net_params() 61 | 62 | # Log interval 63 | self.vali_interval = cfg["vali_log_interval"] 64 | self.save_model_interval = cfg["save_model_interval"] 65 | 66 | # Logger 67 | self.save_model_dir = cw_config.get("save_model_dir", None) 68 | if self.save_model_dir: 69 | util.remove_file_dir(self.save_model_dir) 70 | util.mkdir(self.save_model_dir, overwrite=True) 71 | 72 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 73 | max_norm = cw_config["params"].get("max_norm", None) 74 | train_loss = \ 75 | avg_batch_loss(self.train_loader, self.compute_loss, self.optimizer, 76 | self.net_params, max_norm) 77 | 78 | if n % self.vali_interval == 0: 79 | vali_loss = avg_batch_loss(self.vali_loader, self.compute_loss, 80 | None, None, None) 81 | else: 82 | vali_loss = None 83 | print(n) 84 | 85 | return {"train_loss": train_loss, "vali_loss": vali_loss} 86 | 87 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 88 | if self.save_model_dir and ((n + 1) % self.save_model_interval == 0 89 | or (n + 1) == cw_config["iterations"]): 90 | self.net.save_weights(log_dir=self.save_model_dir, 91 | epoch=n + 1) 92 | 93 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, 94 | crash: bool = False): 95 | torch.cuda.empty_cache() 96 | 97 | def compute_loss(self, batch): 98 | # Choose the point to start obs 99 | num_ctx_min = self.assign_config["num_ctx_min"] 100 | num_ctx_max = self.assign_config["num_ctx_max"] 101 | num_ctx = torch.randint(num_ctx_min, num_ctx_max + 1, []) 102 | pred_range_min = self.assign_config["pred_range_min"] 103 | pred_range_max = self.assign_config["pred_range_max"] 104 | pred_range = torch.randint(pred_range_min, pred_range_max + 1, []) 
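        # Worked example of the index selection below (illustrative numbers;
        # the hard-coded 301/300 bounds imply 301 recorded steps per
        # trajectory). With num_ctx = 5, start_obs_idx = 40, pred_range = 20:
        #
        #   ctx_index  = torch.arange(40, 45)                 # steps 40..44
        #   pred_index = torch.arange(44, min(300, 45 + 20))  # starts at the
        #                                                     # last ctx step
        #   pred_pairs = torch.combinations(pred_index, 2)    # 21 indices ->
        #                                                     # 210 time pairs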
105 | # pred_step = pred_range // 10 + 1 106 | # num_total = num_ctx + pred_range.item() 107 | start_obs_idx_max = 301 - num_ctx - 10 108 | start_obs_idx = torch.randint(0, start_obs_idx_max + 1, [], 109 | dtype=torch.long) 110 | ctx_index = torch.arange(start_obs_idx, start_obs_idx + num_ctx, step=1, 111 | dtype=torch.long) 112 | pred_index_max = min(300, ctx_index[-1] + 1 + pred_range) 113 | pred_index = torch.arange(ctx_index[-1], pred_index_max, 114 | dtype=torch.long) 115 | 116 | pred_pairs = torch.combinations(pred_index, 2).long() 117 | 118 | # Get encoder input 119 | time_ctx_last = batch["box_robot_state"]["time"][:, ctx_index[-1]] 120 | ctx_times = (batch["box_robot_state"]["time"] 121 | - time_ctx_last[:, None])[:, ctx_index][..., None] 122 | ctx_values = batch["box_robot_state"]["value"][:, ctx_index] 123 | norm_ctx_dict = NormProcess.batch_normalize(self.normalizer, 124 | {"box_robot_state": 125 | {"time": ctx_times, 126 | "value": ctx_values}}) 127 | # ctx_times = norm_ctx_dict["box_robot_state"]["time"] 128 | ctx_values = norm_ctx_dict["box_robot_state"]["value"] 129 | 130 | ctx = {"ctx": torch.cat([ctx_times, ctx_values], dim=-1)} 131 | 132 | # Reconstructor input 133 | num_traj = batch["box_robot_state"]["value"].shape[0] 134 | # num_agg = len(ctx_index) + 1 135 | num_agg = 1 136 | num_pred_pairs = pred_pairs.shape[0] 137 | 138 | # init_time = batch["des_cart_pos_vel"]["time"][:, ctx_index[-1]] 139 | init_time = torch.zeros([num_traj]) 140 | init_time = util.add_expand_dim(init_time, [1, -1], 141 | [num_agg, num_pred_pairs]) 142 | init_pos = batch["des_cart_pos_vel"]["value"][:, ctx_index[-1], 143 | :self.mp.num_dof] 144 | init_pos = util.add_expand_dim(init_pos, [1, -2], 145 | [num_agg, num_pred_pairs]) 146 | 147 | init_vel = batch["des_cart_pos_vel"]["value"][:, ctx_index[-1], 148 | self.mp.num_dof:] 149 | init_vel = util.add_expand_dim(init_vel, [1, -2], 150 | [num_agg, num_pred_pairs]) 151 | 152 | times = util.add_expand_dim( 153 | (batch["des_cart_pos_vel"]["time"] - time_ctx_last[:, None])[:, 154 | pred_pairs], 155 | add_dim_indices=[1], add_dim_sizes=[num_agg]) 156 | 157 | # Ground-truth 158 | gt = util.add_expand_dim( 159 | batch["des_cart_pos_vel"]["value"][:, pred_pairs, :self.mp.num_dof], 160 | add_dim_indices=[1], add_dim_sizes=[num_agg]) 161 | 162 | # Switch the time and dof dimension 163 | gt = torch.einsum('...ji->...ij', gt) 164 | # Make the time and dof dimensions flat 165 | gt = gt.reshape(*gt.shape[:-2], -1) 166 | 167 | # Predict 168 | mean, diag, off_diag = self.net.predict(num_traj=num_traj, 169 | enc_inputs=ctx, 170 | dec_input=None) 171 | 172 | # Denormalize prediction 173 | # mean, L = NormProcess.distribution_denormalize(self.normalizer, 174 | # "idmp", 175 | # mean, diag, off_diag) 176 | L = util.build_lower_matrix(diag, off_diag) 177 | 178 | mean = mean.squeeze(-2) 179 | L = L.squeeze(-3) 180 | 181 | assert mean.ndim == 3 182 | 183 | # Add dim of time group 184 | mean = util.add_expand_dim(data=mean, 185 | add_dim_indices=[-2], 186 | add_dim_sizes=[num_pred_pairs]) 187 | L = util.add_expand_dim(data=L, 188 | add_dim_indices=[-3], 189 | add_dim_sizes=[num_pred_pairs]) 190 | 191 | # Reconstruct predicted trajectories 192 | self.mp.update_inputs(times=times, params=mean, params_L=L, 193 | init_time=init_time, init_pos=init_pos, 194 | init_vel=init_vel) 195 | traj_pos_mean = self.mp.get_traj_pos(flat_shape=True) 196 | traj_pos_L = torch.linalg.cholesky(self.mp.get_traj_pos_cov()) 197 | 198 | # Loss 199 | loss = nll_loss(gt, traj_pos_mean, traj_pos_L) 
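        # Shape sketch of the loss inputs (assuming MP_PyTorch's flat layout
        # matches the dof-major flattening of gt above; each pair covers
        # 2 time points per DoF):
        #   gt:            [num_traj, num_agg, num_pred_pairs, num_dof * 2]
        #   traj_pos_mean: [num_traj, num_agg, num_pred_pairs, num_dof * 2]
        #   traj_pos_L:    [..., num_dof * 2, num_dof * 2]
        # nll_loss builds a MultivariateNormal from (mean, L) and returns the
        # negated mean log-probability of gt over all leading dimensions.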
200 | return loss 201 | 202 | 203 | if __name__ == "__main__": 204 | cw = cluster_work.ClusterWork(RobotPush) 205 | 206 | # Optional: Add loggers 207 | cw.add_logger(WandBLogger()) 208 | cw.run() 209 | -------------------------------------------------------------------------------- /nmp/experiment/robot_push/robot_push_promp.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from cw2 import cluster_work 6 | from cw2 import cw_error 7 | from cw2 import experiment 8 | from cw2.cw_data import cw_logging 9 | from cw2.cw_data.cw_wandb_logger import WandBLogger 10 | from mp_pytorch.mp import MPFactory 11 | from nmp import get_data_loaders_and_normalizer 12 | from nmp import nll_loss 13 | from nmp import util 14 | from nmp.aggregator import AggregatorFactory 15 | from nmp.data_process import NormProcess 16 | from nmp.decoder import DecoderFactory 17 | from nmp.encoder import EncoderFactory 18 | from nmp.net import MPNet 19 | from nmp.net import avg_batch_loss 20 | 21 | 22 | class RobotPushProMP(experiment.AbstractIterativeExperiment): 23 | def initialize(self, cw_config: dict, 24 | rep: int, logger: cw_logging.LoggerArray) -> None: 25 | # Device 26 | util.use_cuda() 27 | 28 | # Random seed 29 | cfg = cw_config["params"] 30 | cfg["seed"] = rep 31 | random.seed(cfg["seed"]) 32 | np.random.seed(cfg["seed"]) 33 | torch.manual_seed(cfg["seed"]) 34 | 35 | # Net 36 | self.encoder_dict = EncoderFactory.get_encoders(**cfg["encoders"]) 37 | self.aggregator = AggregatorFactory.get_aggregator( 38 | cfg["aggregator"]["type"], 39 | **cfg["aggregator"]["args"]) 40 | self.decoder = DecoderFactory.get_decoder(cfg["decoder"]["type"], 41 | **cfg["decoder"]["args"]) 42 | self.net = MPNet(self.encoder_dict, self.aggregator, self.decoder) 43 | 44 | # Dataset and Dataloader 45 | dataset = util.load_npz_dataset(cfg["dataset"]["name"]) 46 | self.train_loader, self.vali_loader, self.test_loader, self.normalizer \ 47 | = get_data_loaders_and_normalizer(dataset, **cfg["dataset"], 48 | seed=cfg["seed"]) 49 | 50 | # Data assignment config 51 | self.assign_config = cfg["assign_config"] 52 | 53 | # Reconstructor 54 | self.mp = MPFactory.init_mp(device="cuda", **cfg["mp"]) 55 | 56 | # Optimizer 57 | self.optimizer = torch.optim.Adam(params=self.net.get_net_params(), 58 | lr=float(cfg["lr"]), 59 | weight_decay=float(cfg["wd"])) 60 | self.net_params = self.net.get_net_params() 61 | 62 | # Log interval 63 | self.vali_interval = cfg["vali_log_interval"] 64 | self.save_model_interval = cfg["save_model_interval"] 65 | 66 | # Logger 67 | self.save_model_dir = cw_config.get("save_model_dir", None) 68 | if self.save_model_dir: 69 | util.remove_file_dir(self.save_model_dir) 70 | util.mkdir(self.save_model_dir, overwrite=True) 71 | 72 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 73 | max_norm = cw_config["params"].get("max_norm", None) 74 | train_loss = \ 75 | avg_batch_loss(self.train_loader, self.compute_loss, self.optimizer, 76 | self.net_params, max_norm) 77 | 78 | if n % self.vali_interval == 0: 79 | vali_loss = avg_batch_loss(self.vali_loader, self.compute_loss, 80 | None, None, None) 81 | else: 82 | vali_loss = None 83 | print(n) 84 | 85 | return {"train_loss": train_loss, "vali_loss": vali_loss} 86 | 87 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 88 | if self.save_model_dir and ((n + 1) % self.save_model_interval == 0 89 | or (n + 1) == cw_config["iterations"]): 90 | 
self.net.save_weights(log_dir=self.save_model_dir, 91 | epoch=n + 1) 92 | 93 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, 94 | crash: bool = False): 95 | torch.cuda.empty_cache() 96 | 97 | def compute_loss(self, batch): 98 | # Choose the point to start obs 99 | num_ctx_min = self.assign_config["num_ctx_min"] 100 | num_ctx_max = self.assign_config["num_ctx_max"] 101 | num_ctx = torch.randint(num_ctx_min, num_ctx_max + 1, []) 102 | pred_range_min = self.assign_config["pred_range_min"] 103 | pred_range_max = self.assign_config["pred_range_max"] 104 | pred_range = torch.randint(pred_range_min, pred_range_max + 1, []) 105 | # pred_step = pred_range // 10 + 1 106 | # num_total = num_ctx + pred_range.item() 107 | start_obs_idx_max = 301 - num_ctx - 10 108 | start_obs_idx = torch.randint(0, start_obs_idx_max + 1, [], 109 | dtype=torch.long) 110 | ctx_index = torch.arange(start_obs_idx, start_obs_idx + num_ctx, step=1, 111 | dtype=torch.long) 112 | pred_index_max = min(300, ctx_index[-1] + 1 + pred_range) 113 | pred_index = torch.arange(ctx_index[-1], pred_index_max, 114 | dtype=torch.long) 115 | 116 | pred_pairs = torch.combinations(pred_index, 2).long() 117 | 118 | # Get encoder input 119 | time_ctx_last = batch["box_robot_state"]["time"][:, ctx_index[-1]] 120 | ctx_times = (batch["box_robot_state"]["time"] 121 | - time_ctx_last[:, None])[:, ctx_index][..., None] 122 | ctx_values = batch["box_robot_state"]["value"][:, ctx_index] 123 | norm_ctx_dict = NormProcess.batch_normalize(self.normalizer, 124 | {"box_robot_state": 125 | {"time": ctx_times, 126 | "value": ctx_values}}) 127 | # ctx_times = norm_ctx_dict["box_robot_state"]["time"] 128 | ctx_values = norm_ctx_dict["box_robot_state"]["value"] 129 | 130 | ctx = {"ctx": torch.cat([ctx_times, ctx_values], dim=-1)} 131 | 132 | # Reconstructor input 133 | num_traj = batch["box_robot_state"]["value"].shape[0] 134 | # num_agg = len(ctx_index) + 1 135 | num_agg = 1 136 | num_pred_pairs = pred_pairs.shape[0] 137 | 138 | # init_time = batch["des_cart_pos_vel"]["time"][:, ctx_index[-1]] 139 | init_time = torch.zeros([num_traj]) 140 | init_time = util.add_expand_dim(init_time, [1, -1], 141 | [num_agg, num_pred_pairs]) 142 | init_pos = batch["des_cart_pos_vel"]["value"][:, ctx_index[-1], 143 | :self.mp.num_dof] 144 | init_pos = util.add_expand_dim(init_pos, [1, -2], 145 | [num_agg, num_pred_pairs]) 146 | 147 | init_vel = batch["des_cart_pos_vel"]["value"][:, ctx_index[-1], 148 | self.mp.num_dof:] 149 | init_vel = util.add_expand_dim(init_vel, [1, -2], 150 | [num_agg, num_pred_pairs]) 151 | 152 | times = util.add_expand_dim( 153 | (batch["des_cart_pos_vel"]["time"] - time_ctx_last[:, None])[:, 154 | pred_pairs], 155 | add_dim_indices=[1], add_dim_sizes=[num_agg]) 156 | 157 | # Ground-truth 158 | gt = util.add_expand_dim( 159 | batch["des_cart_pos_vel"]["value"][:, pred_pairs, :self.mp.num_dof], 160 | add_dim_indices=[1], add_dim_sizes=[num_agg]) 161 | 162 | # Switch the time and dof dimension 163 | gt = torch.einsum('...ji->...ij', gt) 164 | # Make the time and dof dimensions flat 165 | gt = gt.reshape(*gt.shape[:-2], -1) 166 | 167 | # Predict 168 | mean, diag, off_diag = self.net.predict(num_traj=num_traj, 169 | enc_inputs=ctx, 170 | dec_input=None) 171 | 172 | # Denormalize prediction 173 | # mean, L = NormProcess.distribution_denormalize(self.normalizer, 174 | # "idmp", 175 | # mean, diag, off_diag) 176 | L = util.build_lower_matrix(diag, off_diag) 177 | 178 | mean = mean.squeeze(-2) 179 | L = L.squeeze(-3) 180 | 181 | assert 
mean.ndim == 3 182 | 183 | # Add dim of time group 184 | mean = util.add_expand_dim(data=mean, 185 | add_dim_indices=[-2], 186 | add_dim_sizes=[num_pred_pairs]) 187 | L = util.add_expand_dim(data=L, 188 | add_dim_indices=[-3], 189 | add_dim_sizes=[num_pred_pairs]) 190 | 191 | # Reconstruct predicted trajectories 192 | self.mp.update_inputs(times=times, params=mean, params_L=L, 193 | init_time=init_time, init_pos=init_pos, 194 | init_vel=init_vel) 195 | traj_pos_mean = self.mp.get_traj_pos(flat_shape=True) 196 | traj_pos_L = torch.linalg.cholesky(self.mp.get_traj_pos_cov()) 197 | 198 | # Loss 199 | loss = nll_loss(gt, traj_pos_mean, traj_pos_L) 200 | return loss 201 | 202 | 203 | if __name__ == "__main__": 204 | cw = cluster_work.ClusterWork(RobotPushProMP) 205 | 206 | # Optional: Add loggers 207 | cw.add_logger(WandBLogger()) 208 | cw.run() 209 | -------------------------------------------------------------------------------- /nmp/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Logger 3 | """ 4 | import csv 5 | import os 6 | 7 | import matplotlib.pyplot as plt 8 | import wandb 9 | 10 | import nmp.util as util 11 | 12 | 13 | class WandbLogger: 14 | def __init__(self, config): 15 | """ 16 | Initialize wandb logger 17 | Args: 18 | config: config file of current task 19 | """ 20 | self.project_name = config["logger"]["log_name"] 21 | entity = config["logger"].get("entity") 22 | group = config["logger"].get("group") 23 | self._initialize_log_dir() 24 | self._run = wandb.init(project=self.project_name, entity=entity, 25 | group=group, config=config) 26 | 27 | def _initialize_log_dir(self): 28 | """ 29 | Clean and initialize local log directory 30 | Returns: 31 | True if successfully cleaned 32 | """ 33 | # Clean old log 34 | util.remove_file_dir(self.log_dataset_dir) 35 | util.remove_file_dir(self.log_model_dir) 36 | util.remove_file_dir(self.log_dir) 37 | 38 | # Make new directory 39 | os.makedirs(self.log_dir) 40 | os.makedirs(self.log_dataset_dir) 41 | os.makedirs(self.log_model_dir) 42 | 43 | @property 44 | def config(self): 45 | """ 46 | Log configuration file 47 | 48 | Returns: 49 | synchronized config from wandb server 50 | 51 | """ 52 | return wandb.config 53 | 54 | @property 55 | def log_dir(self): 56 | """ 57 | Get local log saving directory 58 | Returns: 59 | log directory 60 | """ 61 | if not hasattr(self, "_log_dir"): 62 | self._log_dir = util.make_log_dir_with_time_stamp(self.project_name) 63 | 64 | return self._log_dir 65 | 66 | @property 67 | def log_dataset_dir(self): 68 | """ 69 | Get downloaded logged dataset directory 70 | Returns: 71 | logged dataset directory 72 | """ 73 | return os.path.join(self.log_dir, "dataset") 74 | 75 | @property 76 | def log_model_dir(self): 77 | """ 78 | Get downloaded logged model directory 79 | Returns: 80 | logged model directory 81 | """ 82 | return os.path.join(self.log_dir, "model") 83 | 84 | def log_dataset(self, 85 | dataset_name, 86 | pd_df_dict: dict): 87 | """ 88 | Log raw dataset to Artifact 89 | 90 | Args: 91 | dataset_name: Name of dataset 92 | pd_df_dict: dictionary of train, validate and test sets 93 | 94 | Returns: 95 | None 96 | """ 97 | 98 | # Initialize wandb Artifact 99 | raw_data = wandb.Artifact(name=dataset_name + "_dataset", 100 | type="dataset", 101 | description="dataset") 102 | 103 | # Save DataFrames in Artifact 104 | for key, value in pd_df_dict.items(): 105 | for index, pd_df in enumerate(value): 106 | with raw_data.new_file(key + "_{}.csv".format(index), 
107 | mode="w") as file: 108 | file.write(pd_df.to_csv(path_or_buf=None, 109 | index=False, 110 | quoting=csv.QUOTE_ALL)) 111 | 112 | # Log Artifact 113 | self._run.log_artifact(raw_data) 114 | 115 | def log_info(self, 116 | epoch, 117 | key, 118 | value): 119 | self._run.log({"Epoch": epoch, 120 | key: value}) 121 | 122 | def log_model(self, 123 | finished: bool = False): 124 | """ 125 | Log model into Artifact 126 | 127 | Args: 128 | finished: True if current training is finished, this will clean 129 | the old model version without any special aliass 130 | 131 | Returns: 132 | None 133 | """ 134 | # Initialize wandb artifact 135 | model_artifact = wandb.Artifact(name="model", type="model") 136 | 137 | # Get all file names in log dir 138 | file_names = util.get_file_names_in_directory(self.log_model_dir) 139 | 140 | # Add files into artifact 141 | for file in file_names: 142 | path = os.path.join(self.log_model_dir, file) 143 | model_artifact.add_file(path) 144 | 145 | if finished: 146 | aliases = ["latest", 147 | "finished-{}".format(util.get_formatted_date_time())] 148 | else: 149 | aliases = ["latest"] 150 | 151 | # Log and upload 152 | self._run.log_artifact(model_artifact, aliases=aliases) 153 | 154 | if finished: 155 | self.delete_useless_model() 156 | 157 | def delete_useless_model(self): 158 | """ 159 | Delete useless models in WandB server 160 | Returns: 161 | None 162 | 163 | """ 164 | api = wandb.Api() 165 | 166 | artifact_type = "model" 167 | artifact_name = "{}/{}/model".format(self._run.entity, 168 | self._run.project) 169 | 170 | for version in api.artifact_versions(artifact_type, artifact_name): 171 | # Clean up all versions that don't have an alias such as 'latest'. 172 | if len(version.aliases) == 0: 173 | version.delete() 174 | 175 | def load_model(self, 176 | model_api: str): 177 | """ 178 | Load model from Artifact 179 | 180 | model_api: the string for load the model if init_epoch is not zero 181 | 182 | Returns: 183 | model_dir: Model's directory 184 | 185 | """ 186 | model_api = "self._" + model_api[11:] 187 | artifact = eval(model_api) 188 | artifact.download(root=self.log_model_dir) 189 | file_names = util.get_file_names_in_directory(self.log_model_dir) 190 | file_names.sort() 191 | util.print_line_title(title="Download model files from WandB") 192 | for file in file_names: 193 | print(file) 194 | return self.log_model_dir 195 | 196 | def watch_networks(self, 197 | networks, 198 | log_freq): 199 | """ 200 | Watch Neural Network weights and gradients 201 | Args: 202 | networks: network to being watched 203 | log_freq: frequency for logging 204 | 205 | Returns: 206 | None 207 | 208 | """ 209 | for idx, net in enumerate(networks): 210 | self._run.watch(net, 211 | log="all", 212 | log_freq=log_freq, 213 | idx=idx) 214 | 215 | def log_figure(self, 216 | figure_obj: plt.Figure, 217 | figure_name: str = "Unnamed Figure"): 218 | """ 219 | Log figure 220 | Args: 221 | figure_obj: Matplotlib Figure object 222 | figure_name: name of the figure 223 | 224 | Returns: 225 | None 226 | 227 | """ 228 | self._run.log({figure_name: wandb.Image(figure_obj)}) 229 | 230 | def log_video(self, 231 | path_to_video: str, 232 | video_name: str = "Unnamed Video"): 233 | """ 234 | Log video 235 | Args: 236 | path_to_video: path where the video is stored 237 | video_name: name of the video 238 | 239 | Returns: 240 | None 241 | """ 242 | self._run.log({video_name: wandb.Video(path_to_video)}) 243 | 244 | def log_data_dict(self, 245 | data_dict: dict): 246 | """ 247 | Log data in dictionary 248 
| Args: 249 | data_dict: dictionary to log 250 | 251 | Returns: 252 | None 253 | """ 254 | self._run.log(data_dict) 255 | 256 | 257 | def get_logger_dict(): 258 | return {"wandb": WandbLogger} 259 | -------------------------------------------------------------------------------- /nmp/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Custom loss functions in PyTorch 3 | """ 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.distributions import MultivariateNormal 8 | 9 | 10 | def nll_loss(true_val, 11 | pred_mean, 12 | pred_L): 13 | """ 14 | Negative log likelihood loss 15 | Args: 16 | true_val: true target values 17 | pred_mean: predicted mean of the Normal distribution 18 | pred_L: predicted Cholesky decomposition of the covariance 19 | 20 | Returns: 21 | negative log likelihood 22 | 23 | """ 24 | # Shape of true_val: 25 | # [*add_dim, dim_val] 26 | # 27 | # Shape of pred_mean: 28 | # [*add_dim, dim_val] 29 | # 30 | # Shape of pred_L: 31 | # [*add_dim, dim_val, dim_val] 32 | 33 | # Construct distribution 34 | mvn = MultivariateNormal(loc=pred_mean, scale_tril=pred_L, 35 | validate_args=False) 36 | 37 | # Compute log likelihood 38 | ll = mvn.log_prob(true_val).mean() 39 | 40 | # Loss 41 | ll_loss = -ll 42 | return ll_loss 43 | 44 | 45 | def nmll_loss(true_val, 46 | pred_mean, 47 | pred_L, 48 | mc_smp_dim: int = -3): 49 | """ 50 | Negative marginal log likelihood loss 51 | Args: 52 | true_val: true target values 53 | pred_mean: predicted mean of the Normal distribution 54 | pred_L: predicted Cholesky decomposition of the covariance 55 | mc_smp_dim: index of the Monte-Carlo sample dimension 56 | 57 | Returns: 58 | negative marginal log likelihood 59 | """ 60 | # Shape of true_val: 61 | # [*add_dim_1, num_mc_smp, *add_dim_2, dim_val] 62 | # 63 | # Shape of pred_mean: 64 | # [*add_dim_1, num_mc_smp, *add_dim_2, dim_val] 65 | # 66 | # Shape of pred_L: 67 | # [*add_dim_1, num_mc_smp, *add_dim_2, dim_val, dim_val] 68 | 69 | # Check dimensions 70 | if mc_smp_dim < 0: 71 | mc_smp_dim = pred_mean.ndim + mc_smp_dim 72 | shapes = pred_mean.shape 73 | dimensions = list(range(pred_mean.ndim)) 74 | add_dim_1 = torch.tensor(shapes[:mc_smp_dim]) 75 | num_mc_smp = shapes[mc_smp_dim] 76 | add_dim_2 = torch.tensor(shapes[mc_smp_dim + 1:-1]) 77 | 78 | # Construct distribution 79 | mvn = MultivariateNormal(loc=pred_mean, scale_tril=pred_L, 80 | validate_args=False) 81 | 82 | # Compute log likelihood loss for each trajectory 83 | ll = mvn.log_prob(true_val) 84 | 85 | # Sum over additional dimensions part 2 86 | ll = torch.sum(ll, dim=dimensions[mc_smp_dim + 1:-1]) 87 | 88 | # MC average 89 | ll = torch.logsumexp(ll, dim=mc_smp_dim) 90 | 91 | # Sum over additional dimensions part 1 92 | ll = torch.sum(ll, dim=dimensions[:mc_smp_dim]) 93 | assert ll.ndim == 0 94 | 95 | # Get marginal log likelihood 96 | ll = ll - torch.prod(add_dim_1) * np.log(num_mc_smp) 97 | mll = ll / (torch.prod(add_dim_1) * torch.prod(add_dim_2)) 98 | 99 | # Loss 100 | mll_loss = -mll 101 | return mll_loss 102 | 103 | 104 | def mse_loss(true_val, pred): 105 | """ 106 | Mean squared error 107 | 108 | Args: 109 | true_val: Ground truth 110 | pred: predicted value 111 | 112 | Returns: 113 | mse 114 | """ 115 | mse = nn.MSELoss() 116 | return mse(pred, true_val) 117 | -------------------------------------------------------------------------------- /nmp/net.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: Network class 3 | """ 4 | from typing import Callable 5 | 
from typing import Tuple 6 | from typing import Union 7 | 8 | import torch.optim 9 | from torch import nn 10 | 11 | from nmp import CNMPEncoder 12 | from nmp import util 13 | from nmp.aggregator import * 14 | from nmp.decoder import Decoder 15 | from torch.utils.data import DataLoader 16 | 17 | 18 | class MPNet: 19 | def __init__(self, encoder_dict: dict, 20 | aggregator: Union[MeanAggregator, BayesianAggregator], 21 | decoder: Decoder): 22 | self.encoder_dict = encoder_dict 23 | self.aggregator = aggregator 24 | self.decoder = decoder 25 | 26 | def get_net_params(self): 27 | """ 28 | Get parameters to be optimized 29 | Returns: 30 | Tuple of parameters of neural networks 31 | 32 | """ 33 | # Decoder 34 | parameters = self.decoder.parameters 35 | 36 | # Encoders 37 | for encoder in self.encoder_dict.values(): 38 | parameters += encoder.parameters 39 | 40 | return (parameters) 41 | 42 | def save_weights(self, log_dir: str, epoch: int): 43 | """ 44 | Save parameters 45 | Args: 46 | log_dir: directory to save weights to 47 | epoch: training epoch 48 | 49 | Returns: 50 | None 51 | 52 | """ 53 | 54 | # Encoder 55 | for encoder in self.encoder_dict.values(): 56 | encoder.save_weights(log_dir, epoch) 57 | 58 | # Decoder 59 | self.decoder.save_weights(log_dir, epoch) 60 | 61 | def load_weights(self, log_dir: str, epoch: int): 62 | """ 63 | Load parameters 64 | Args: 65 | log_dir: directory stored weights 66 | epoch: training epoch 67 | 68 | Returns: 69 | None 70 | """ 71 | # Encoder 72 | for encoder in self.encoder_dict.values(): 73 | encoder.load_weights(log_dir, epoch) 74 | 75 | # Decoder 76 | self.decoder.load_weights(log_dir, epoch) 77 | 78 | def predict(self, num_traj: int, enc_inputs: dict, 79 | dec_input: Optional[torch.Tensor], 80 | **kwargs): 81 | """ 82 | Predict using the network 83 | 84 | Args: 85 | num_traj: batch size of the number of trajectories 86 | enc_inputs: input of encoder 87 | dec_input: output of encoder 88 | **kwargs: keyword arguments 89 | 90 | Returns: 91 | mean, variance of the predicted values 92 | """ 93 | # Reset aggregator 94 | self.aggregator.reset(num_traj=num_traj) 95 | 96 | # Loop over all encoders 97 | for encoder_name, encoder in self.encoder_dict.items(): 98 | # Get data assigned to it 99 | encoder_input = enc_inputs[encoder_name] 100 | 101 | # Encode, make result in tuple and store 102 | lat_obs = util.make_iterable(encoder.encode(encoder_input)) 103 | 104 | # Aggregate 105 | self.aggregator.aggregate(*lat_obs) 106 | 107 | # Get latent variable 108 | index = None if self.aggregator.multiple_steps else -1 109 | lat_var = util.make_iterable(self.aggregator.get_agg_state(index=index)) 110 | 111 | # Sample latent variables if necessary, todo remove this 112 | num_mc_smp = kwargs.get("num_mc_smp", 0) 113 | if num_mc_smp != 0: 114 | lat_var = self.sample_latent_variable(num_mc_smp, *lat_var) 115 | 116 | # Decode 117 | mean, diag, off_diag = self.decoder.decode(dec_input, *lat_var) 118 | 119 | # Return 120 | return mean, diag, off_diag 121 | 122 | @staticmethod 123 | def sample_latent_variable(num_mc_smp: int, lat_mean, lat_var): 124 | """ 125 | Sample latent variable for Monte-Carlo approximation 126 | 127 | Args: 128 | num_mc_smp: num of Monte-Carlo samples when necessary 129 | lat_mean: mean of latent variable 130 | lat_var: variance of latent variable 131 | 132 | Returns: 133 | sampled latent variable, shape: 134 | [num_traj, num_agg, num_smp, dim_lat] 135 | 136 | variance of latent variable, shape: 137 | [num_traj, num_agg, num_smp, dim_lat] 138 | """ 139 | 
gaussian = torch.distributions.normal.Normal(loc=lat_mean, 140 | scale=lat_var, 141 | validate_args=False) 142 | 143 | mc_smp = torch.einsum('kij...->ijk...', gaussian.rsample([num_mc_smp])) 144 | # lat_var = lat_var[..., None, :] 145 | lat_var = util.add_expand_dim(lat_var, [-2], [num_mc_smp]) 146 | return mc_smp, lat_var 147 | 148 | 149 | def avg_batch_loss(data_loader: DataLoader, 150 | loss_func: Callable, 151 | optimizer: Optional[torch.optim.Optimizer], 152 | params: Optional[Tuple[torch.Tensor]], 153 | max_norm: Optional[int] = 2): 154 | loss = 0.0 155 | num_batch = 0 156 | 157 | gradient_norm_list = [] 158 | for batch in data_loader: 159 | num_batch += 1 160 | 161 | if optimizer is not None: 162 | # Training 163 | batch_loss = loss_func(batch) 164 | 165 | # Optimize 166 | optimizer.zero_grad(set_to_none=True) 167 | batch_loss.backward() 168 | if max_norm: 169 | total_norm = 0 170 | parameters = [p for p in params if 171 | p.grad is not None and p.requires_grad] 172 | for p in parameters: 173 | param_norm = p.grad.detach().data.norm(2) 174 | total_norm += param_norm.item() ** 2 175 | total_norm = total_norm ** 0.5 176 | gradient_norm_list.append(total_norm) 177 | 178 | nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=2) 179 | 180 | optimizer.step() 181 | 182 | else: 183 | # Validation or Testing 184 | with torch.no_grad(): 185 | batch_loss = loss_func(batch) 186 | 187 | # Sum up batch loss 188 | loss += batch_loss.item() 189 | 190 | # Compute average batch loss 191 | avg_loss = loss / num_batch 192 | 193 | if optimizer is not None: 194 | print(gradient_norm_list) 195 | 196 | # Return 197 | return avg_loss 198 | 199 | 200 | class BehaviorCloningNet: 201 | def __init__(self, mlp_net: CNMPEncoder): 202 | # Here CNMP encoder can be used as a Behavior cloning net directly. 
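        # Usage sketch (hypothetical dimensions): no aggregator or decoder is
        # involved; the MLP maps the normalized context straight to the
        # 2-DoF action used in robot_push_bhc.py:
        #
        #   bc_net = BehaviorCloningNet(
        #       EncoderFactory.get_encoder("CNMPEncoderMlp", name="bc",
        #                                  dim_obs=17, dim_lat=2,
        #                                  obs_hidden={...},
        #                                  act_func="leaky_relu"))
        #   pred = bc_net.predict(net_input=ctx)  # [num_traj, num_pts, 2]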
200 | class BehaviorCloningNet:
201 |     def __init__(self, mlp_net: CNMPEncoder):
202 |         # Here the CNMP encoder can be used as a behavior cloning net directly.
203 |         self.mlp_net = mlp_net
204 |
205 |     def get_net_params(self):
206 |         """
207 |         Get parameters to be optimized
208 |         Returns:
209 |             Parameters of the MLP net to be optimized
210 |
211 |         """
212 |         # MLP net
213 |         parameters = self.mlp_net.parameters
214 |         return parameters
215 |
216 |     def save_weights(self, log_dir: str, epoch: int):
217 |         """
218 |         Save parameters
219 |         Args:
220 |             log_dir: directory to save weights to
221 |             epoch: training epoch
222 |
223 |         Returns:
224 |             None
225 |
226 |         """
227 |         self.mlp_net.save_weights(log_dir, epoch)
228 |
229 |     def load_weights(self, log_dir: str, epoch: int):
230 |         """
231 |         Load parameters
232 |         Args:
233 |             log_dir: directory storing the weights
234 |             epoch: training epoch
235 |
236 |         Returns:
237 |             None
238 |         """
239 |         self.mlp_net.load_weights(log_dir, epoch)
240 |
241 |     def predict(self, net_input: torch.Tensor):
242 |         """
243 |         Predict using the network
244 |
245 |         Args:
246 |             net_input: input of network
247 |
248 |         Returns:
249 |             predicted values
250 |         """
251 |
252 |         return self.mlp_net.encode(net_input)
253 |
254 |
255 |
-------------------------------------------------------------------------------- /nmp/nn_base.py: --------------------------------------------------------------------------------
1 | """
2 | @brief: Classes of Neural Network Bases
3 | """
4 | import pickle as pkl
5 | from typing import Callable
6 | from typing import Optional
7 |
8 | import torch
9 | from torch import nn as nn
10 | from torch.nn import ModuleList
11 | from torch.nn import functional as F
12 |
13 | import nmp.util as util
14 |
15 |
16 | def get_act_func(key: str) -> Optional[Callable]:
17 |     func_dict = dict()
18 |     func_dict["tanh"] = torch.tanh
19 |     func_dict["relu"] = F.relu
20 |     func_dict["leaky_relu"] = F.leaky_relu
21 |     func_dict["softplus"] = F.softplus
22 |     func_dict["None"] = None
23 |     return func_dict[key]
24 |
25 |
26 | class MLP(nn.Module):
27 |     def __init__(self,
28 |                  name: str,
29 |                  dim_in: int,
30 |                  dim_out: int,
31 |                  hidden_layers: list,
32 |                  act_func: str):
33 |         """
34 |         Multi-layer Perceptron Constructor
35 |
36 |         Args:
37 |             name: name of the MLP
38 |             dim_in: dimension of the input
39 |             dim_out: dimension of the output
40 |             hidden_layers: a list containing hidden layers' dimensions
41 |             act_func: activation function
42 |         """
43 |
44 |         super(MLP, self).__init__()
45 |
46 |         self.mlp_name = name + "_mlp"
47 |
48 |         # Initialize the MLP
49 |         self.dim_in = dim_in
50 |         self.dim_out = dim_out
51 |         self.hidden_layers = hidden_layers
52 |         self.act_func_type = act_func
53 |         self.act_func = get_act_func(act_func)
54 |
55 |         # Create networks
56 |         # Ugly but useful to distinguish networks in gradient watch
57 |         # e.g.
if self.mlp_name is "encoder_mlp" 58 | # Then below will lead to self.encoder_mlp = self._create_network() 59 | setattr(self, self.mlp_name, self._create_network()) 60 | 61 | def _create_network(self): 62 | """ 63 | Create MLP Network 64 | 65 | Returns: 66 | MLP Network 67 | """ 68 | 69 | # Total layers (n+1) = hidden layers (n) + output layer (1) 70 | 71 | # Add first hidden layer 72 | mlp = ModuleList([nn.Linear(in_features=self.dim_in, 73 | out_features=self.hidden_layers[0])]) 74 | 75 | # Add other hidden layers 76 | for i in range(1, len(self.hidden_layers)): 77 | mlp.append(nn.Linear(in_features=mlp[-1].out_features, 78 | out_features=self.hidden_layers[i])) 79 | 80 | # Add output layer 81 | mlp.append(nn.Linear(in_features=mlp[-1].out_features, 82 | out_features=self.dim_out)) 83 | 84 | return mlp 85 | 86 | def save(self, log_dir: str, epoch: int): 87 | """ 88 | Save NN structure and weights to file 89 | Args: 90 | log_dir: directory to save weights to 91 | epoch: training epoch 92 | 93 | Returns: 94 | None 95 | """ 96 | 97 | # Get paths to structure parameters and weights respectively 98 | s_path, w_path = util.get_nn_save_paths(log_dir, self.mlp_name, epoch) 99 | 100 | # Store structure parameters 101 | with open(s_path, "wb") as f: 102 | parameters = { 103 | "dim_in": self.dim_in, 104 | "dim_out": self.dim_out, 105 | "hidden_layers": self.hidden_layers, 106 | "act_func_type": self.act_func_type, 107 | } 108 | pkl.dump(parameters, f) 109 | 110 | # Store NN weights 111 | with open(w_path, "wb") as f: 112 | torch.save(self.state_dict(), f) 113 | 114 | def load(self, log_dir: str, epoch: int): 115 | """ 116 | Load NN structure and weights from file 117 | Args: 118 | log_dir: directory stored weights 119 | epoch: training epoch 120 | 121 | Returns: 122 | None 123 | """ 124 | # Get paths to structure parameters and weights respectively 125 | s_path, w_path = util.get_nn_save_paths(log_dir, self.mlp_name, epoch) 126 | 127 | # Check structure parameters 128 | with open(s_path, "rb") as f: 129 | parameters = pkl.load(f) 130 | assert self.dim_in == parameters["dim_in"] \ 131 | and self.dim_out == parameters["dim_out"] \ 132 | and self.hidden_layers == parameters["hidden_layers"] \ 133 | and self.act_func_type == parameters["act_func_type"], \ 134 | "NN structure parameters do not match" 135 | 136 | # Load NN weights 137 | self.load_state_dict(torch.load(w_path)) 138 | 139 | def forward(self, data): 140 | """ 141 | Network forward function 142 | 143 | Args: 144 | data: input data 145 | 146 | Returns: MLP output 147 | 148 | """ 149 | 150 | # Hidden layers (n) + output layer (1) 151 | mlp = eval("self." 
+ self.mlp_name) 152 | for i in range(len(self.hidden_layers)): 153 | data = self.act_func(mlp[i](data)) 154 | data = mlp[-1](data) 155 | 156 | # Return 157 | return data 158 | 159 | 160 | class CNNMLP(nn.Module): 161 | def __init__(self, 162 | name: str, 163 | image_size: list, 164 | kernel_size: int, 165 | num_cnn: int, 166 | cnn_channels: list, 167 | hidden_layers: list, 168 | dim_out: int, 169 | act_func: str): 170 | """ 171 | CNN, MLP constructor 172 | 173 | Args: 174 | name: name of the MLP 175 | image_size: w and h of input images size 176 | kernel_size: size of cnn kernel 177 | num_cnn: number of cnn layers 178 | cnn_channels: a list containing cnn in and out channels 179 | hidden_layers: a list containing hidden layers' dimensions 180 | dim_out: dimension of the output 181 | act_func: activation function 182 | """ 183 | super(CNNMLP, self).__init__() 184 | 185 | self.name = name 186 | self.cnn_mlp_name = name + "_cnn_mlp" 187 | 188 | self.image_size = image_size 189 | self.kernel_size = kernel_size 190 | assert num_cnn + 1 == len(cnn_channels) 191 | self.num_cnn = num_cnn 192 | self.cnn_channels = cnn_channels 193 | self.dim_in = self.get_mlp_dim_in() 194 | self.hidden_layers = hidden_layers 195 | self.dim_out = dim_out 196 | self.act_func_type = act_func 197 | self.act_func = get_act_func(act_func) 198 | 199 | # Initialize the CNN and MLP 200 | setattr(self, self.cnn_mlp_name, self._create_network()) 201 | 202 | def get_mlp_dim_in(self) -> int: 203 | """ 204 | Compute the input size of mlp layers 205 | Returns: 206 | dim_in 207 | """ 208 | image_out_size = \ 209 | [util.image_output_size(size=s, 210 | num_cnn=self.num_cnn, 211 | cnn_kernel_size=self.kernel_size) 212 | for s in self.image_size] 213 | # dim_in = channel * w * h 214 | dim_in = self.cnn_channels[-1] 215 | for s in image_out_size: 216 | dim_in *= s 217 | return dim_in 218 | 219 | def _create_network(self): 220 | """ 221 | Create CNNs and MLP 222 | 223 | Returns: cnn_mlp 224 | """ 225 | cnn_mlp = ModuleList() 226 | for i in range(self.num_cnn): 227 | in_channel = self.cnn_channels[i] 228 | out_channel = self.cnn_channels[i + 1] 229 | cnn_mlp.append(nn.Conv2d(in_channel, out_channel, self.kernel_size)) 230 | 231 | # Initialize the MLP 232 | cnn_mlp.append(MLP(name=self.name, 233 | dim_in=self.dim_in, 234 | dim_out=self.dim_out, 235 | hidden_layers=self.hidden_layers, 236 | act_func=self.act_func_type)) 237 | return cnn_mlp 238 | 239 | def save(self, log_dir: str, epoch: int): 240 | """ 241 | Save NN structure and weights to file 242 | Args: 243 | log_dir: directory to save weights to 244 | epoch: training epoch 245 | 246 | Returns: 247 | None 248 | """ 249 | 250 | # Get paths to structure parameters and weights respectively 251 | s_path, w_path = util.get_nn_save_paths(log_dir, 252 | self.cnn_mlp_name, 253 | epoch) 254 | 255 | # Store structure parameters 256 | with open(s_path, "wb") as f: 257 | parameters = { 258 | "num_cnn": self.num_cnn, 259 | "cnn_channels": self.cnn_channels, 260 | "kernel_size": self.kernel_size, 261 | "image_size": self.image_size, 262 | "dim_in": self.dim_in, 263 | "hidden_layers": self.hidden_layers, 264 | "dim_out": self.dim_out, 265 | "act_func_type": self.act_func_type 266 | } 267 | pkl.dump(parameters, f) 268 | 269 | # Store NN weights 270 | with open(w_path, "wb") as f: 271 | torch.save(self.state_dict(), f) 272 | 273 | def load(self, log_dir: str, epoch: int): 274 | """ 275 | Load NN structure and weights from file 276 | Args: 277 | log_dir: directory stored weights 278 | epoch: training 
epoch 279 | 280 | Returns: 281 | None 282 | """ 283 | # Get paths to structure parameters and weights respectively 284 | s_path, w_path = util.get_nn_save_paths(log_dir, 285 | self.cnn_mlp_name, 286 | epoch) 287 | 288 | # Load structure parameters 289 | with open(s_path, "rb") as f: 290 | parameters = pkl.load(f) 291 | assert self.num_cnn == parameters["num_cnn"] \ 292 | and self.cnn_channels == parameters["cnn_channels"] \ 293 | and self.kernel_size == parameters["kernel_size"] \ 294 | and self.image_size == parameters["image_size"] \ 295 | and self.dim_in == parameters["dim_in"] \ 296 | and self.hidden_layers == parameters["hidden_layers"] \ 297 | and self.dim_out == parameters["dim_out"] \ 298 | and self.act_func_type == parameters["act_func_type"], \ 299 | "NN structure parameters do not match" 300 | 301 | # Load NN weights 302 | self.load_state_dict(torch.load(w_path)) 303 | 304 | def forward(self, data): 305 | """ 306 | Network forward function 307 | 308 | Args: 309 | data: input data 310 | 311 | Returns: CNN + MLP output 312 | """ 313 | 314 | # Reshape images batch to [num_traj * num_obs, C, H, W] 315 | num_traj, num_obs = data.shape[:2] 316 | data = data.reshape(-1, *data.shape[2:]) 317 | 318 | cnns = eval("self." + self.cnn_mlp_name)[:-1] 319 | mlp = eval("self." + self.cnn_mlp_name)[-1] 320 | 321 | # Forward pass in CNNs 322 | # todo, check if dropout is critical to training case 323 | for i in range(len(cnns)-1): 324 | data = self.act_func(F.max_pool2d(cnns[i](data), 2)) 325 | data = self.act_func(F.max_pool2d( 326 | F.dropout2d(cnns[-1](data), training=self.training), 2)) 327 | 328 | # Flatten 329 | data = data.view(num_traj, num_obs, self.dim_in) 330 | 331 | # Forward pass in MLPs 332 | data = mlp(data) 333 | 334 | # Return 335 | return data 336 | 337 | 338 | class GruRnn(nn.Module): 339 | def __init__(self, 340 | name: str, 341 | dim_in: int, 342 | dim_out: int, 343 | num_layers: int, 344 | seed: int): 345 | """ 346 | Gated Recurrent Unit of RNN 347 | 348 | Args: 349 | name: name of the GRU 350 | dim_in: dimension of the input 351 | dim_out: dimension of the output 352 | num_layers: number of hidden layers 353 | seed: seed for random behaviours 354 | """ 355 | 356 | super(GruRnn, self).__init__() 357 | 358 | self.name = name 359 | self.gru_name = name + "_gru" 360 | 361 | self.dim_in = dim_in 362 | self.dim_out = dim_out 363 | self.num_layers = num_layers 364 | self.seed = seed 365 | 366 | # Create networks 367 | setattr(self, self.gru_name, self._create_network()) 368 | 369 | def _create_network(self): 370 | """ 371 | Create GRU Network 372 | 373 | Returns: 374 | GRU Network 375 | """ 376 | gru = nn.GRU(input_size=self.dim_in, 377 | hidden_size=self.dim_out, 378 | num_layers=self.num_layers, 379 | batch_first=True) 380 | 381 | return gru 382 | 383 | def save(self, log_dir: str, epoch: int): 384 | """ 385 | Save NN structure and weights to file 386 | Args: 387 | log_dir: directory to save weights to 388 | epoch: training epoch 389 | 390 | Returns: 391 | None 392 | """ 393 | 394 | # Get paths to structure parameters and weights respectively 395 | s_path, w_path = util.get_nn_save_paths(log_dir, self.gru_name, epoch) 396 | 397 | # Store structure parameters 398 | with open(s_path, "wb") as f: 399 | parameters = { 400 | "dim_in": self.dim_in, 401 | "dim_out": self.dim_out, 402 | "num_layers": self.num_layers, 403 | "seed": self.seed, 404 | } 405 | pkl.dump(parameters, f) 406 | 407 | # Store NN weights 408 | with open(w_path, "wb") as f: 409 | torch.save(self.state_dict(), f) 410 
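As a quick shape check for the `MLP` and `CNNMLP` classes defined above; all dimensions below are illustrative, not taken from any experiment config:

```python
# Illustrative shape check; all values are made up.
import torch

mlp = MLP(name="demo", dim_in=3, dim_out=8,
          hidden_layers=[64, 64], act_func="relu")
print(mlp(torch.randn(16, 3)).shape)  # torch.Size([16, 8])

# 32x32 image, kernel 5, two conv layers, channels [1, 8, 16]:
# conv 32->28, maxpool ->14, conv ->10, maxpool ->5, so dim_in = 16*5*5 = 400
cnn_mlp = CNNMLP(name="img", image_size=[32, 32], kernel_size=5, num_cnn=2,
                 cnn_channels=[1, 8, 16], hidden_layers=[128], dim_out=32,
                 act_func="relu")
print(cnn_mlp(torch.randn(4, 7, 1, 32, 32)).shape)  # torch.Size([4, 7, 32])
```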
411 |     def load(self, log_dir: str, epoch: int):
412 |         """
413 |         Load NN structure and weights from file
414 |         Args:
415 |             log_dir: directory storing the weights
416 |             epoch: training epoch
417 |
418 |         Returns:
419 |             None
420 |         """
421 |         # Get paths to structure parameters and weights respectively
422 |         s_path, w_path = util.get_nn_save_paths(log_dir, self.gru_name, epoch)
423 |
424 |         # Load structure parameters
425 |         with open(s_path, "rb") as f:
426 |             parameters = pkl.load(f)
427 |             assert self.dim_in == parameters["dim_in"] \
428 |                    and self.dim_out == parameters["dim_out"] \
429 |                    and self.num_layers == parameters["num_layers"] \
430 |                    and self.seed == parameters["seed"], \
431 |                 "NN structure parameters do not match"
432 |
433 |         # Load NN weights
434 |         self.load_state_dict(torch.load(w_path))
435 |
436 |     def forward(self, input_data):
437 |         """
438 |         Network forward function
439 |
440 |         Args:
441 |             input_data: input data
442 |
443 |         Returns: GRU output sequence and final hidden state, as a tuple
444 |
445 |         """
446 |         data = input_data
447 |
448 |         gru = eval("self." + self.gru_name)
449 |         data = gru(data)  # nn.GRU returns (output, h_n)
450 |
451 |         # Return
452 |         return data
453 |
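A small usage sketch for `GruRnn` (dimensions made up); note that, like `torch.nn.GRU` with `batch_first=True`, the forward call returns a tuple of output sequence and final hidden state:

```python
# Illustrative only; dimensions are made up.
import torch

gru = GruRnn(name="demo", dim_in=3, dim_out=8, num_layers=2, seed=0)
out, h_n = gru(torch.randn(16, 10, 3))  # out: [16, 10, 8], h_n: [2, 16, 8]
```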
-------------------------------------------------------------------------------- /nmp/others/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALRhub/ProDMP_RAL/78063bdb4c9ad04e8a16d7b5a14d4077de774082/nmp/others/__init__.py
-------------------------------------------------------------------------------- /nmp/others/ellipses_noise.py: --------------------------------------------------------------------------------
1 | import torch
2 | from addict import Dict
3 |
4 | # create dictionary with parameters
5 | cfg = Dict()
6 |
7 | cfg.args.n_noisy = 3 # number of noisy images
8 | cfg.noise.n_ellipses = 2
9 | cfg.noise.radius.low = 5
10 | cfg.noise.radius.high = 10
11 | cfg.noise.gaussian_var = 0.25
12 |
13 | cfg.ds.res = 32 # image resolution
14 | cfg.ds.n_channels = 1 # number of color channels
15 |
16 |
17 | # @torch.jit.script
18 | def create_elliptic_mask(size: int, center: torch.Tensor, radius: torch.Tensor,
19 |                          ellip: torch.Tensor):
20 |     """
21 |
22 |     :param size: (scalar), e.g. x_res=y_res=32
23 |     :param center: (n_ellipses=4, n_noisy=3, n_dim=2 (xy))
24 |     :param radius: (n_ellipses=4, n_noisy=3)
25 |     :param ellip: (n_ellipses=4, n_noisy=3), ellip=1 creates a circle
26 |     :return: (n_noisy=3, size=64, size=64)
27 |     """
28 |     x = torch.arange(size, dtype=torch.float32)[:, None] # (64, 1)
29 |     y = torch.arange(size, dtype=torch.float32)[None] # (1, 64)
30 |
31 |     # distance of each pixel to the ellipses' centers (4, 3, 64, 64)
32 |     dist_from_center = torch.sqrt(
33 |         ellip[:, :, :, None, None] * (x - center[:, :, :, 0:1, None]) ** 2
34 |         + (y - center[:, :, :, 1:2, None]) ** 2 / ellip[:, :, :, None, None])
35 |     # dist_from_center = torch.sqrt(ellip*(x - center[0])**2 + (y - center[1])**2/ellip)
36 |
37 |     masks = dist_from_center <= radius[:, :, :, None, None]
38 |     mask, _ = torch.max(masks, dim=1)
39 |     return mask # (n_noisy=3, size=64, size=64)
40 |
41 |
42 | # @torch.jit.script
43 | def apply_mask_and_noise(mask: torch.Tensor, noise: torch.Tensor,
44 |                          img: torch.Tensor, n_noisy: int,
45 |                          n_channels: int): # , translation: torch.Tensor
46 |     imgs = img.repeat(1, n_noisy + 1, 1, 1, 1)
47 |
48 |     if n_channels == 3:
49 |         # apply noise and mask on all RGB color channels equally
50 |         noise = noise.repeat(1, 3, 1, 1)
51 |         mask = mask[:, None].repeat(1, 3, 1, 1)
52 |     else:
53 |         mask = mask[:, :, None]
54 |
55 |     imgs[:, 0:n_noisy] *= mask # apply noise mask
56 |     imgs[:, 0:n_noisy] += noise # apply additive (Gaussian) noise
57 |     imgs[:, 0:n_noisy] = imgs[:, 0:n_noisy].clamp_(min=0, max=1)
58 |     return imgs
59 |
60 |
61 | class EllipseNoiseTransform:
62 |     def __init__(self, seed=None):
63 |         self.seed = seed
64 |         if seed:
65 |             print("Init EllipseNoiseTransform with seed", seed)
66 |             self.gen = torch.Generator()
67 |             self.gen.manual_seed(seed)
68 |         else:
69 |             self.gen = None
70 |
71 |     def reset_random_generator(self):
72 |         self.gen.manual_seed(self.seed)
73 |
74 |     def __call__(self, img):
75 |         # img: torch tensor (3, 64, 64), float32
76 |
77 |         n_noisy = cfg.args.n_noisy
78 |
79 |         # imgs = torch.zeros((n_noisy + 1, img.size(0), img.size(1), img.size(2)))
80 |
81 |         radius = torch.randint(low=cfg.noise.radius.low,
82 |                                high=cfg.noise.radius.high,
83 |                                size=(
84 |                                    img.shape[0], cfg.noise.n_ellipses, n_noisy),
85 |                                generator=self.gen)
86 |         center = torch.randint(low=1, high=cfg.ds.res - 2,
87 |                                size=(
88 |                                    img.shape[0], cfg.noise.n_ellipses, n_noisy,
89 |                                    2),
90 |                                generator=self.gen)
91 |         ellip = torch.rand(size=(img.shape[0], cfg.noise.n_ellipses, n_noisy),
92 |                            generator=self.gen) + 0.5
93 |         # translation = torch.randint(low=-cfg.noise.translation.abs, high=cfg.noise.translation.abs, size=(2, ), generator=self.gen)
94 |         gaussian_noise = cfg.noise.gaussian_var * torch.randn(
95 |             size=(img.shape[0], n_noisy, 1, img.shape[-2], img.shape[-1]),
96 |             generator=self.gen) if cfg.noise.gaussian_var else torch.tensor(0)
97 |
98 |         # imgs[-1] = img
99 |         mask = create_elliptic_mask(size=img.shape[-1], center=center,
100 |                                     radius=radius,
101 |                                     ellip=ellip) # (n_ellipses=4, n_noisy=3, size=64, size=64)
102 |
103 |         return apply_mask_and_noise(mask, gaussian_noise, img, n_noisy,
104 |                                     n_channels=cfg.ds.n_channels) # (4, 1, 64, 64)
105 |
106 |
107 | if __name__ == '__main__':
108 |     transform = EllipseNoiseTransform()
109 |
110 |     img = torch.ones(size=(1, 28, 28)) * 0.5 # create gray example image
111 |
112 |     img_transformed = transform(img)
113 |
114 |     # visualization
115 |     import matplotlib.pyplot as plt
116 |
117 |     fig, axes = plt.subplots(1, cfg.args.n_noisy)
118 |     for i in range(cfg.args.n_noisy):
119 |         axes[i].imshow(img_transformed[i, 0], cmap='gray')
120 |
121 |     plt.imshow(img_transformed[0, 0], cmap='gray')
122 |     plt.show()
123 |
-------------------------------------------------------------------------------- /nmp/util/__init__.py: --------------------------------------------------------------------------------
1 | from .util_data_structure import *
2 | from .util_debug import *
3 | from .util_file import *
4 | from .util_geometry import *
5 | from .util_learning import *
6 | from .util_matrix import *
7 | from .util_numerical import *
8 | from .util_media import *
9 | from .util_string import *
10 | from .util_hyperparams import *
11 | from .util import *
-------------------------------------------------------------------------------- /nmp/util/util.py: --------------------------------------------------------------------------------
1 | """
2 | @author: Ge Li, ge.li@kit.edu
3 | @brief: Utilities
4 | """
5 |
6 | # Import Python libs
7 | import csv
8 | import json
9 | import os
10 | import random
11 |
12 | import numpy as np
13 | import pandas as pd
14 | from mnist import MNIST
15 | from natsort import os_sorted
16 | import nmp.util as util
17 |
18 |
19 |
20 |
21 |
22 | def read_dataset(dataset_name: str,
23 |                  shuffle: bool = False,
24 |                  seed=None) -> (list, list):
25 |     """
26 |     Read raw data from files
27 |
28 |     Args:
29 |         dataset_name: name of dataset to be read
30 |         shuffle: shuffle the order of dataset files when reading
31 |         seed: random seed
32 |
33 |     Returns:
34 |         list_pd_df: a list of pandas DataFrames with time-dependent data
35 |         list_pd_df_static: a list of pandas DataFrames with time-independent data, can be empty
36 |
37 |     """
38 |     # Get dir to dataset
39 |     dataset_dir = util.get_dataset_dir(dataset_name)
40 |
41 |     # Get all data-file names
42 |     file_names = util.get_file_names_in_directory(dataset_dir)
43 |
44 |     # Count the data files and sort their names
45 |     num_files = len(file_names)
46 |     file_names = os_sorted(file_names)
47 |
48 |     # Check if both time-dependent and time-independent dataset exist
49 |     if all(['static' in name for name in file_names]):
50 |         # Only time-independent dataset
51 |         list_pd_df = [pd.DataFrame() for data_file in file_names]
52 |         # Read the time-independent (static) data from files
53 |         list_pd_df_static = [pd.read_csv(os.path.join(dataset_dir, data_file),
54 |                                          quoting=csv.QUOTE_ALL)
55 |                              for data_file in file_names]
56 |     elif all(['static' not in name for name in file_names]):
57 |         # Only time-dependent dataset
58 |         list_pd_df = [pd.read_csv(os.path.join(dataset_dir, data_file),
59 |                                   quoting=csv.QUOTE_ALL)
60 |                       for data_file in file_names]
61 |         # Construct an empty dataset for time-independent data
62 |         list_pd_df_static = [pd.DataFrame() for data_file in file_names]
63 |     else:
64 |         # Both exist
65 |         assert \
66 |             all(['static' not in name for name in file_names[:num_files // 2]])
67 |         assert all(['static' in name for name in file_names[num_files // 2:]])
68 |
69 |         # Read data from files and generate list of pandas DataFrame
70 |         list_pd_df = [pd.read_csv(os.path.join(dataset_dir, data_file),
71 |                                   quoting=csv.QUOTE_ALL)
72 |                       for data_file in file_names[:num_files // 2]]
73 |         list_pd_df_static = [pd.read_csv(os.path.join(dataset_dir, data_file),
74 |                                          quoting=csv.QUOTE_ALL)
75 |                              for data_file in file_names[num_files // 2:]]
76 |
77 |     if shuffle:
78 |         list_zip = list(zip(list_pd_df, list_pd_df_static))
79 |         random.seed(seed)
80 |         random.shuffle(list_zip)
81 |         list_pd_df, list_pd_df_static = zip(*list_zip)
82 |
83 |     # Return
84 |     return list_pd_df, list_pd_df_static
85 |
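A usage sketch for `read_dataset`; the dataset name is hypothetical and assumes CSV files laid out as described in the docstring above:

```python
# "my_dataset" is a hypothetical folder under the dataset directory.
list_pd_df, list_pd_df_static = read_dataset("my_dataset",
                                             shuffle=True, seed=1234)
print(len(list_pd_df), list_pd_df[0].columns)
```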
-------------------------------------------------------------------------------- /nmp/util/util_data_structure.py: --------------------------------------------------------------------------------
1 | """
2 | Utilities of data type and structure
3 | """
4 | from typing import List
5 | from typing import Literal
6 | from typing import Tuple
7 | from typing import Union
8 |
9 | import numpy as np
10 | import torch
11 |
12 | import nmp.util as util
13 |
14 |
15 | def current_device():
16 |     """
17 |     Return current torch default device
18 |
19 |     Returns: "cpu" or "cuda"
20 |
21 |     """
22 |     if not hasattr(current_device, "device"):
23 |         return "cpu"
24 |     else:
25 |         return current_device.device
26 |
27 |
28 | def use_cpu():
29 |     """
30 |     Switch to cpu tensor
31 |     Returns:
32 |         None
33 |     """
34 |     torch.set_default_tensor_type('torch.FloatTensor')
35 |     current_device.device = "cpu"
36 |
37 |
38 | def use_cuda() -> bool:
39 |     """
40 |     Check if GPU is available and set default torch datatype
41 |
42 |     Returns:
43 |         True if CUDA is available, otherwise False
44 |     """
45 |     if torch.cuda.is_available():
46 |         torch.set_default_tensor_type('torch.cuda.FloatTensor')
47 |         current_device.device = "cuda"
48 |         # torch.multiprocessing.set_start_method(method="spawn")
49 |
50 |         return True
51 |     else:
52 |         current_device.device = "cpu"
53 |         return False
54 |
55 |
56 | def make_iterable(data: any, default: Literal['tuple', 'list'] = 'tuple') \
57 |         -> Union[Tuple, List]:
58 |     """
59 |     Make data a tuple or list, i.e. (data) or [data]
60 |     Args:
61 |         data: some data
62 |         default: default type
63 |     Returns:
64 |         (data,) or [data] if data is not already a tuple or list
65 |     """
66 |     if isinstance(data, tuple):
67 |         return data
68 |     elif isinstance(data, list):
69 |         return data
70 |     else:
71 |         if default == 'tuple':
72 |             return (data,)  # Do not use tuple()
73 |         elif default == 'list':
74 |             return [data, ]
75 |         else:
76 |             raise NotImplementedError
77 |
78 |
79 | def from_string_to_array(s: str) -> np.ndarray:
80 |     """
81 |     Convert string in Pandas DataFrame cell to numpy array
82 |     Args:
83 |         s: string, e.g. "[1.0 2.3 4.5 \n 5.3 5.6]"
84 |
85 |     Returns:
86 |         1D numpy array
87 |     """
88 |     return np.asarray(s[1:-1].split(),
89 |                       dtype=np.float64)
90 |
91 |
92 | def to_np(tensor: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
93 |     """
94 |     Transfer any type and device of tensor to a numpy ndarray
95 |     Args:
96 |         tensor: np.ndarray, cpu tensor or gpu tensor
97 |
98 |     Returns:
99 |         tensor in np.ndarray
100 |     """
101 |     if is_np(tensor):
102 |         return tensor
103 |     elif is_ts(tensor):
104 |         if tensor.device.type == "cpu":
105 |             return tensor.numpy()
106 |         elif tensor.device.type == "cuda":
107 |             return tensor.cpu().numpy()
108 |     raise NotImplementedError
109 |
110 |
111 | def to_nps(*tensors: [Union[np.ndarray, torch.Tensor]]) -> [np.ndarray]:
112 |     """
113 |     Transfer a list of any type of tensors to np.ndarray
114 |     Args:
115 |         tensors: a list of tensors
116 |
117 |     Returns:
118 |         a list of np.ndarray
119 |     """
120 |     return [to_np(tensor) for tensor in tensors]
121 |
122 |
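A small round-trip sketch for the converters above (input values illustrative):

```python
import numpy as np
import torch

t = to_ts(np.array([1.0, 2.0]))  # -> float32 tensor on "cpu"
a = to_np(t)                     # -> np.ndarray
t1, t2 = to_tss(3.5, np.zeros(3), dtype="float64")  # several inputs at once
```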
123 | def is_np(data: any) -> bool:
124 |     """
125 |     Is data a numpy array?
126 |     """
127 |     return isinstance(data, np.ndarray)
128 |
129 |
130 | def to_ts(data: Union[int, float, np.ndarray, torch.Tensor],
131 |           dtype: Literal["float32", "float64"] = "float32",
132 |           device: Literal["cpu", "cuda"] = "cpu") -> torch.Tensor:
133 |     """
134 |     Transfer any numerical input to a torch tensor in default data type + device
135 |
136 |     Args:
137 |         device: device of the tensor, default: cpu
138 |         dtype: data type of tensor, float 32 or float 64 (double)
139 |         data: float, np.ndarray, torch.Tensor
140 |
141 |     Returns:
142 |         tensor in torch.Tensor
143 |     """
144 |     if dtype == "float32":
145 |         data_type = torch.float32
146 |     elif dtype == "float64":
147 |         data_type = torch.float64
148 |     else:
149 |         raise NotImplementedError
150 |
151 |     if isinstance(data, float) or isinstance(data, int):
152 |         return torch.tensor(data, dtype=data_type, device=device)
153 |     elif is_ts(data):
154 |         return data.clone().detach().to(device).type(data_type)
155 |
156 |     elif is_np(data):
157 |         return torch.tensor(data, dtype=data_type, device=device)
158 |     else:
159 |         raise NotImplementedError
160 |
161 |
162 | def to_tss(*datas: [Union[int, float, np.ndarray, torch.Tensor]],
163 |            dtype: Literal["float32", "float64"] = "float32",
164 |            device: Literal["cpu", "cuda"] = "cpu") \
165 |         -> [torch.Tensor]:
166 |     """
167 |     Transfer a list of any type of numerical input to a list of tensors in given
168 |     data type and device
169 |
170 |     Args:
171 |         datas: a list of data
172 |         dtype: data type of tensor, float 32 or float 64 (double)
173 |         device: device of the tensor, default: cpu
174 |
175 |     Returns:
176 |         a list of torch.Tensor
177 |     """
178 |     return [to_ts(data, dtype, device) for data in datas]
179 |
180 |
181 | def is_ts(data: any) -> bool:
182 |     """
183 |     Is data a torch Tensor?
184 | """ 185 | return isinstance(data, torch.Tensor) 186 | 187 | 188 | def to_tensor_dict(np_dict: dict) -> dict: 189 | """ 190 | Transform a nested dict of np.ndarray into a dict of torch tensor 191 | The default tensor device and type shall be used 192 | 193 | Args: 194 | np_dict: np dict 195 | 196 | Returns: 197 | ts_dict: torch dict 198 | """ 199 | ts_dict = dict() 200 | 201 | for name, data in np_dict.items(): 202 | if isinstance(data, dict): 203 | ts_dict[name] = to_tensor_dict(data) 204 | elif is_np(data) or isinstance(data, (list, tuple)): 205 | ts_dict[name] = torch.Tensor(data) 206 | else: 207 | raise NotImplementedError 208 | return ts_dict 209 | 210 | 211 | def to_numpy_dict(ts_dict: dict) -> dict: 212 | """ 213 | Transform a nested dict of torch tensor into a dict of np.ndarray 214 | 215 | Args: 216 | ts_dict: torch dict 217 | 218 | Returns: 219 | np_dict: np dict 220 | """ 221 | np_dict = dict() 222 | 223 | for name, data in ts_dict.items(): 224 | if isinstance(data, dict): 225 | np_dict[name] = to_numpy_dict(data) 226 | elif is_ts(data) or isinstance(data, (list, tuple)): 227 | np_dict[name] = util.to_np(torch.Tensor(data)) 228 | else: 229 | raise NotImplementedError 230 | return np_dict 231 | 232 | 233 | def conv2d_size_out(size: int, kernel_size: int = 5, stride=1) -> int: 234 | """ 235 | Get output size of cnn 236 | 237 | Args: 238 | size: size of input image 239 | kernel_size: kernel size 240 | stride: stride 241 | 242 | Returns: 243 | output size 244 | """ 245 | return (size - (kernel_size - 1) - 1) // stride + 1 246 | 247 | 248 | def maxpool2d_size_out(size: int, kernel_size: int = 2, stride=None) -> int: 249 | """ 250 | Get output size of max-pooling 251 | 252 | Args: 253 | size: size of input image 254 | kernel_size: kernel size 255 | stride: stride 256 | 257 | Returns: 258 | output size 259 | """ 260 | if stride is None: 261 | stride = kernel_size 262 | return conv2d_size_out(size, kernel_size=kernel_size, stride=stride) 263 | 264 | 265 | def image_output_size(size: int, 266 | num_cnn: int, 267 | cnn_kernel_size: int = 5, 268 | cnn_stride: int = 1, 269 | max_pool: bool = True, 270 | maxpool_kernel_size: int = 2, 271 | max_pool_stride: int = None): 272 | """ 273 | Get output size of multiple cnn-maxpool layers 274 | Args: 275 | size: size of input image 276 | num_cnn: number of cnns 277 | cnn_kernel_size 278 | cnn_stride 279 | max_pool 280 | maxpool_kernel_size 281 | max_pool_stride 282 | 283 | Returns: 284 | 285 | """ 286 | for _ in range(num_cnn): 287 | size = conv2d_size_out(size, cnn_kernel_size, cnn_stride) 288 | if max_pool: 289 | size = maxpool2d_size_out(size, maxpool_kernel_size, 290 | max_pool_stride) 291 | 292 | return size 293 | -------------------------------------------------------------------------------- /nmp/util/util_debug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for debugging 3 | """ 4 | 5 | import time 6 | from typing import Callable 7 | from typing import Optional 8 | from typing import Union 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | 14 | import nmp.util as util 15 | 16 | 17 | def how_fast(repeat: int, func: Callable, *args, **kwargs): 18 | """ 19 | Test how fast a given function call is 20 | Args: 21 | repeat: number of times to run the function 22 | func: function to be tested 23 | *args: list of arguments used in the function call 24 | 25 | Returns: 26 | avg duration function call 27 | 28 | Raise: 29 | any type of exception when 
test the function call 30 | """ 31 | run_time_test(lock=True) 32 | try: 33 | for i in range(repeat): 34 | func(*args, **kwargs) 35 | duration = run_time_test(lock=False) 36 | if duration is not None: 37 | print(f"total_time of {repeat} runs: {duration} s") 38 | print(f"avg_time of each run: {duration / repeat} s") 39 | return duration / repeat 40 | except RuntimeError: 41 | raise 42 | except Exception: 43 | raise 44 | 45 | def run_time_test(lock: bool) -> Optional[float]: 46 | """ 47 | A manual running time computing function. It will print the running time 48 | for every second call 49 | 50 | E.g.: 51 | run_time_test(lock=True) 52 | some_func1() 53 | some_func2() 54 | ... 55 | run_time_test(lock=False) 56 | 57 | Args: 58 | lock: flag indicating if time counter starts 59 | 60 | Returns: 61 | None (every first call) or duration (every second call) 62 | 63 | Raise: 64 | RuntimeError if is used in a wrong way 65 | """ 66 | # Initialize function attribute 67 | if not hasattr(run_time_test, "lock_state"): 68 | run_time_test.lock_state = False 69 | run_time_test.last_run_time = time.time() 70 | run_time_test.duration_list = list() 71 | 72 | # Check correct usage 73 | if run_time_test.lock_state == lock: 74 | run_time_test.lock_state = False 75 | raise RuntimeError("run_time_test is wrongly used.") 76 | 77 | # Setup lock 78 | run_time_test.lock_state = lock 79 | 80 | # Update time 81 | if lock is False: 82 | duration = time.time() - run_time_test.last_run_time 83 | run_time_test.duration_list.append(duration) 84 | run_time_test.last_run_time = time.time() 85 | print("duration", duration) 86 | return duration 87 | else: 88 | run_time_test.last_run_time = time.time() 89 | return None 90 | 91 | 92 | def debug_plot(x: Union[np.ndarray, torch.Tensor], 93 | y: [], labels: [] = None, title="debug_plot", grid=True) -> \ 94 | plt.Figure: 95 | """ 96 | One line to plot some variable for debugging, numpy + torch 97 | Args: 98 | x: data used for x-axis, can be None 99 | y: list of data used for y-axis 100 | labels: labels in plots 101 | title: title of current plot 102 | grid: show grid or not 103 | 104 | Returns: 105 | None 106 | """ 107 | fig = plt.figure() 108 | y = util.make_iterable(y) 109 | if labels is not None: 110 | labels = util.make_iterable(labels) 111 | for i, yi in enumerate(y): 112 | yi = util.to_np(yi) 113 | label = labels[i] if labels is not None else None 114 | if x is not None: 115 | x = util.to_np(x) 116 | plt.plot(x, yi, label=label) 117 | else: 118 | plt.plot(yi, label=label) 119 | plt.title(title) 120 | if labels is not None: 121 | plt.legend() 122 | if grid: 123 | plt.grid(alpha=0.5) 124 | plt.show() 125 | return fig 126 | -------------------------------------------------------------------------------- /nmp/util/util_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities of files operation 3 | """ 4 | 5 | import os 6 | import shutil 7 | from pathlib import Path 8 | from typing import Optional 9 | from typing import Tuple 10 | from typing import Union 11 | 12 | import numpy as np 13 | import yaml 14 | from natsort import os_sorted 15 | 16 | import nmp.util as util 17 | 18 | 19 | def join_path(*paths: Union[str]) -> str: 20 | """ 21 | 22 | Args: 23 | *paths: paths to join 24 | 25 | Returns: 26 | joined path 27 | """ 28 | return os.path.join(*paths) 29 | 30 | 31 | def mkdir(directory: str, overwrite: bool = False): 32 | """ 33 | 34 | Args: 35 | directory: dir path to make 36 | overwrite: overwrite exist dir 37 | 38 | Returns: 39 | 
None 40 | 41 | Raise: 42 | FileExistsError if dir exists and overwrite is False 43 | """ 44 | path = Path(directory) 45 | try: 46 | path.mkdir(parents=True, exist_ok=overwrite) 47 | except FileExistsError: 48 | util.error("Directory already exists, remove it before make a new one.") 49 | raise 50 | 51 | 52 | def remove_file_dir(path: str) -> bool: 53 | """ 54 | Remove file or directory 55 | Args: 56 | path: path to directory or file 57 | 58 | Returns: 59 | True if successfully remove file or directory 60 | 61 | """ 62 | if not os.path.exists(path): 63 | return False 64 | elif os.path.isfile(path) or os.path.islink(path): 65 | os.unlink(path) 66 | return True 67 | else: 68 | shutil.rmtree(path) 69 | return True 70 | 71 | 72 | def dir_go_up(num_level: int = 2, current_file_dir: str = "default") -> str: 73 | """ 74 | Go to upper n level of current file directory 75 | Args: 76 | num_level: number of level to go up 77 | current_file_dir: current dir 78 | 79 | Returns: 80 | dir n level up 81 | """ 82 | if current_file_dir == "default": 83 | current_file_dir = os.path.realpath(__file__) 84 | while num_level != 0: 85 | current_file_dir = os.path.dirname(current_file_dir) 86 | num_level -= 1 87 | return current_file_dir 88 | 89 | 90 | def get_dataset_dir(dataset_name: str) -> str: 91 | """ 92 | Get the path to the directory storing the dataset 93 | Args: 94 | dataset_name: name of the dataset 95 | 96 | Returns: 97 | path to the directory storing the dataset 98 | """ 99 | return os.path.join(dir_go_up(2), "dataset", dataset_name) 100 | 101 | 102 | def get_media_dir(media_name: str) -> str: 103 | """ 104 | Get the path to the directory storing the media 105 | Args: 106 | media_name: name of the media 107 | 108 | Returns: 109 | path to the directory storing the media_name 110 | """ 111 | return os.path.join(dir_go_up(2), "media", media_name) 112 | 113 | 114 | def get_config_type() -> set: 115 | """ 116 | Register current config type 117 | Returns: 118 | 119 | """ 120 | return {"local", "mp", "cluster"} 121 | 122 | 123 | def get_config_path(config_name: str, config_type: str = "local") -> str: 124 | """ 125 | Get the path to the config file 126 | Args: 127 | config_name: name of the config file 128 | config_type: configuration type 129 | 130 | Returns: 131 | path to the config file 132 | """ 133 | # Check config type 134 | assert config_type in get_config_type(), \ 135 | "Unknown config type." 
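For illustration, the directory conventions these path helpers produce; the names are hypothetical and the resolved root depends on where the package lives:

```python
# Illustrative results only; "..." stands for the resolved package root:
# get_dataset_dir("my_dataset")          -> .../dataset/my_dataset
# get_config_path("my_cfg", "local")     -> .../config/local/my_cfg.yaml
# make_log_dir_with_time_stamp("my_run") -> .../log/my_run/<date_time>
```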
136 | return os.path.join(dir_go_up(2), "config", config_type, 137 | config_name + ".yaml") 138 | 139 | 140 | def make_log_dir_with_time_stamp(log_name: str) -> str: 141 | """ 142 | Get the dir to the log 143 | Args: 144 | log_name: log's name 145 | 146 | Returns: 147 | directory to log file 148 | """ 149 | 150 | return os.path.join(dir_go_up(2), "log", log_name, 151 | util.get_formatted_date_time()) 152 | 153 | 154 | def parse_config(config_path: str, config_type: str = "local") -> dict: 155 | """ 156 | Parse config file into a dictionary 157 | Args: 158 | config_path: path to config file 159 | config_type: configuration type 160 | 161 | Returns: 162 | configuration in dictionary 163 | """ 164 | assert config_type in get_config_type(), \ 165 | "Unknown config type" 166 | 167 | all_config = list() 168 | with open(config_path, "r") as f: 169 | for config in yaml.load_all(f, yaml.FullLoader): 170 | all_config.append(config) 171 | if config_type == "cluster": 172 | return all_config 173 | else: 174 | return all_config[0] 175 | 176 | 177 | def dump_config(config_dict: dict, config_name: str, dump_dir: str): 178 | """ 179 | Dump configuration into yaml file 180 | Args: 181 | config_dict: config dictionary to be dumped 182 | config_name: config file name 183 | dump_dir: dir to dump 184 | Returns: 185 | None 186 | """ 187 | 188 | # Generate config path 189 | dump_path = util.join_path(dump_dir, config_name + ".yaml") 190 | 191 | # Remove old config if exists 192 | remove_file_dir(dump_path) 193 | 194 | # Write new config to file 195 | with open(dump_path, "w") as f: 196 | yaml.dump(config_dict, f) 197 | 198 | 199 | def get_file_names_in_directory(directory: str) -> [str]: 200 | """ 201 | Get file names in given directory 202 | Args: 203 | directory: directory where you want to explore 204 | 205 | Returns: 206 | file names in a list 207 | 208 | """ 209 | file_names = None 210 | try: 211 | (_, _, file_names) = next(os.walk(directory)) 212 | if len(file_names) == 0: 213 | file_names = None 214 | except StopIteration as e: 215 | print("Cannot read files from directory: ", directory) 216 | raise StopIteration("Cannot read files from directory") 217 | return os_sorted(file_names) 218 | 219 | 220 | def move_files_from_to(from_dir: str, 221 | to_dir: str, 222 | copy=False): 223 | """ 224 | Move or copy files from one directory to another 225 | Args: 226 | from_dir: from directory A 227 | to_dir: to directory B 228 | copy: True if copy instead of move 229 | 230 | Returns: 231 | None 232 | """ 233 | file_names = get_file_names_in_directory(from_dir) 234 | for file in file_names: 235 | from_path = os.path.join(from_dir, file) 236 | to_path = os.path.join(to_dir, file) 237 | if copy: 238 | shutil.copy(from_path, to_path) 239 | else: 240 | shutil.move(from_path, to_path) 241 | 242 | 243 | def clean_and_get_tmp_dir() -> str: 244 | """ 245 | Get the path to the tmp folder 246 | 247 | Returns: 248 | path to the tmp directory 249 | """ 250 | tmp_path = os.path.join(dir_go_up(2), "tmp") 251 | remove_file_dir(tmp_path) 252 | util.mkdir(tmp_path) 253 | return tmp_path 254 | 255 | 256 | def get_nn_save_paths(log_dir: str, nn_name: str, 257 | epoch: Optional[int]) -> Tuple[str, str]: 258 | """ 259 | Get path storing nn structure parameters and nn weights 260 | Args: 261 | log_dir: directory to log 262 | nn_name: name of NN 263 | epoch: number of training epoch 264 | 265 | Returns: 266 | path to nn structure parameters 267 | path to nn weights 268 | """ 269 | s_path = os.path.join(log_dir, nn_name + "_parameters.pkl") 
270 | w_path = os.path.join(log_dir, nn_name + "_weights") 271 | if epoch is not None: 272 | w_path = w_path + "_{:d}".format(epoch) 273 | 274 | return s_path, w_path 275 | 276 | 277 | def save_npz_dataset(dataset_name: str, name: str = None, 278 | overwrite: bool = False, **data_dict): 279 | if name is None: 280 | name = dataset_name 281 | save_dir = get_dataset_dir(dataset_name) 282 | mkdir(save_dir, overwrite=overwrite) 283 | np.savez(join_path(save_dir, name + ".npz"), 284 | **data_dict) 285 | 286 | 287 | def load_npz_dataset(dataset_name: str, name: str = None) -> dict: 288 | if name is None: 289 | name = dataset_name 290 | load_dir = get_dataset_dir(dataset_name) 291 | load_path = join_path(load_dir, name + ".npz") 292 | data_dict = dict(np.load(load_path, allow_pickle=True)) 293 | 294 | for key, value in data_dict.items(): 295 | if value.shape == (): 296 | data_dict[key] = value.item() 297 | 298 | return data_dict 299 | -------------------------------------------------------------------------------- /nmp/util/util_geometry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities of geometry computation 3 | """ 4 | from typing import Union 5 | import numpy as np 6 | import torch 7 | import nmp.util as util 8 | 9 | # For testing whether a number is close to zero 10 | _FLOAT_EPS = np.finfo(np.float64).eps 11 | _EPS4 = _FLOAT_EPS * 4.0 12 | 13 | 14 | def euler2quat(euler: Union[np.ndarray, torch.Tensor]) \ 15 | -> Union[np.ndarray, torch.Tensor]: 16 | """ 17 | Convert Euler Angles to Quaternions. See rotation.py for notes 18 | Args: 19 | euler: Euler angle 20 | 21 | Returns: 22 | Quaternion, WXYZ 23 | """ 24 | assert euler.shape[-1] == 3, "Invalid shape euler {}".format(euler) 25 | 26 | ai, aj, ak = euler[..., 2] / 2, -euler[..., 1] / 2, euler[..., 0] / 2 27 | si, sj, sk = np.sin(ai), np.sin(aj), np.sin(ak) 28 | ci, cj, ck = np.cos(ai), np.cos(aj), np.cos(ak) 29 | cc, cs = ci * ck, ci * sk 30 | sc, ss = si * ck, si * sk 31 | if util.is_np(euler): 32 | quat = np.zeros(euler.shape[:-1] + (4,), dtype=np.float64) 33 | elif util.is_ts(euler): 34 | quat = torch.zeros(euler.shape[:-1] + (4,)) 35 | else: 36 | raise NotImplementedError 37 | quat[..., 0] = cj * cc + sj * ss 38 | quat[..., 3] = cj * sc - sj * cs 39 | quat[..., 2] = -(cj * ss + sj * cc) 40 | quat[..., 1] = cj * cs - sj * sc 41 | return quat 42 | 43 | 44 | def mat2euler(mat: Union[np.ndarray, torch.Tensor]) \ 45 | -> Union[np.ndarray, torch.Tensor]: 46 | """ 47 | Convert Rotation Matrix to Euler Angles. 
48 | 49 | Args: 50 | mat: rotation matrix 51 | 52 | Returns: 53 | euler angle 54 | """ 55 | use_torch = False 56 | if util.is_ts(mat): 57 | mat = util.to_np(mat) 58 | use_torch = True 59 | 60 | assert mat.shape[-2:] == (3, 3), "Invalid shape matrix {}".format(mat) 61 | 62 | cy = np.sqrt( 63 | mat[..., 2, 2] * mat[..., 2, 2] + mat[..., 1, 2] * mat[..., 1, 2]) 64 | condition = cy > _EPS4 65 | euler = np.zeros(mat.shape[:-1], dtype=np.float64) 66 | euler[..., 2] = np.where( 67 | condition, 68 | -np.arctan2(mat[..., 0, 1], mat[..., 0, 0]), 69 | -np.arctan2(-mat[..., 1, 0], mat[..., 1, 1]), 70 | ) 71 | euler[..., 1] = np.where( 72 | condition, -np.arctan2(-mat[..., 0, 2], cy), 73 | -np.arctan2(-mat[..., 0, 2], cy) 74 | ) 75 | euler[..., 0] = np.where( 76 | condition, -np.arctan2(mat[..., 1, 2], mat[..., 2, 2]), 0.0 77 | ) 78 | if use_torch: 79 | euler = torch.Tensor(euler) 80 | return euler 81 | 82 | 83 | def quat2mat(quat: Union[np.ndarray, torch.Tensor]) \ 84 | -> Union[np.ndarray, torch.Tensor]: 85 | """ 86 | Convert Quaternion to Euler Angles. 87 | 88 | Args: 89 | quat: quaternion, WXYZ 90 | 91 | Returns: 92 | rotation matrix 93 | """ 94 | use_torch = False 95 | if util.is_ts(quat): 96 | quat = util.to_np(quat) 97 | use_torch = True 98 | 99 | assert quat.shape[-1] == 4, "Invalid shape quat {}".format(quat) 100 | 101 | w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] 102 | Nq = np.sum(quat * quat, axis=-1) 103 | s = 2.0 / Nq 104 | X, Y, Z = x * s, y * s, z * s 105 | wX, wY, wZ = w * X, w * Y, w * Z 106 | xX, xY, xZ = x * X, x * Y, x * Z 107 | yY, yZ, zZ = y * Y, y * Z, z * Z 108 | 109 | mat = np.empty(quat.shape[:-1] + (3, 3), dtype=np.float64) 110 | mat[..., 0, 0] = 1.0 - (yY + zZ) 111 | mat[..., 0, 1] = xY - wZ 112 | mat[..., 0, 2] = xZ + wY 113 | mat[..., 1, 0] = xY + wZ 114 | mat[..., 1, 1] = 1.0 - (xX + zZ) 115 | mat[..., 1, 2] = yZ - wX 116 | mat[..., 2, 0] = xZ - wY 117 | mat[..., 2, 1] = yZ + wX 118 | mat[..., 2, 2] = 1.0 - (xX + yY) 119 | result = np.where((Nq > _FLOAT_EPS)[..., np.newaxis, np.newaxis], mat, 120 | np.eye(3)) 121 | 122 | if use_torch: 123 | return torch.Tensor(result) 124 | else: 125 | return result 126 | 127 | 128 | def quat2euler(quat: Union[np.ndarray, torch.Tensor]) \ 129 | -> Union[np.ndarray, torch.Tensor]: 130 | """ 131 | Convert Quaternion to Euler Angles. 
132 | Args: 133 | quat: quaternion, WXYZ 134 | 135 | Returns: 136 | euler angles 137 | """ 138 | return mat2euler(quat2mat(quat)) 139 | -------------------------------------------------------------------------------- /nmp/util/util_hyperparams.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities of hyper-parameters and randomness 3 | """ 4 | 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | from addict import Dict 10 | 11 | 12 | class HyperParametersPool: 13 | def __init__(self): 14 | raise RuntimeError("Do not instantiate this class.") 15 | 16 | @staticmethod 17 | def set_hyperparameters(hp_dict: Dict): 18 | """ 19 | Set runtime hyper-parameters 20 | Args: 21 | hp_dict: dictionary of hyper-parameters 22 | 23 | Returns: 24 | None 25 | """ 26 | if hasattr(HyperParametersPool, "_hp_dict"): 27 | raise RuntimeError("Hyper-parameters already exist") 28 | else: 29 | # Initialize hyper-parameters dictionary 30 | HyperParametersPool._hp_dict = hp_dict 31 | 32 | # Setup random seeds globally 33 | seed = hp_dict.get("seed", 1234) 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | 38 | @staticmethod 39 | def hp_dict(): 40 | """ 41 | Get runtime hyper-parameters 42 | Returns: 43 | hp_dict: dictionary of hyper-parameters 44 | """ 45 | if not hasattr(HyperParametersPool, "_hp_dict"): 46 | return None 47 | else: 48 | hp_dict = HyperParametersPool._hp_dict 49 | return hp_dict 50 | 51 | 52 | def decide_hyperparameter(obj: any, 53 | run_time_value: any, 54 | parameter_key: str, 55 | parameter_default: any) -> any: 56 | """ 57 | A helper function to determine function's hyper-parameter 58 | Args: 59 | obj: the object asking for hyper-parameter 60 | run_time_value: runtime value, will be used if it is not None 61 | parameter_key: the key to search in the hyper-parameters pool 62 | parameter_default: use this value if neither runtime nor config value 63 | 64 | Returns: 65 | the parameter following the preference 66 | - if runtime value is given, use it 67 | - else if find it in the config pool, use that one 68 | - else use the default value 69 | """ 70 | if run_time_value is not None: 71 | return run_time_value 72 | elif hasattr(obj, parameter_key): 73 | return getattr(obj, parameter_key) 74 | else: 75 | hp_dict = HyperParametersPool.hp_dict() 76 | if hp_dict is not None \ 77 | and parameter_key in hp_dict.keys(): 78 | actual_value = hp_dict.get(parameter_key) 79 | setattr(obj, parameter_key, actual_value) 80 | return actual_value 81 | else: 82 | return parameter_default 83 | 84 | 85 | def mlp_arch_3_params(avg_neuron: int, num_hidden: int, shape: float) -> [int]: 86 | """ 87 | 3 params way of specifying dense net, mostly for hyperparameter optimization 88 | Originally from Optuna work 89 | 90 | Args: 91 | avg_neuron: average number of neurons per layer 92 | num_hidden: number of layers 93 | shape: parameters between -1 and 1: 94 | shape < 0: "contracting" network, i.e, layers get smaller, 95 | for extrem case (shape = -1): 96 | first layer 2 * avg_neuron neurons, 97 | last layer 1 neuron, rest interpolating 98 | shape 0: all layers avg_neuron neurons 99 | shape > 0: "expanding" network, i.e., representation gets larger, 100 | for extrem case (shape = 1) 101 | first layer 1 neuron, 102 | last layer 2 * avg_neuron neurons, rest interpolating 103 | 104 | Returns: 105 | architecture: list of integers representing the number of neurons of 106 | each layer 107 | """ 108 | 109 | assert avg_neuron >= 0 110 | assert 
-1.0 <= shape <= 1.0 111 | assert num_hidden >= 1 112 | shape = shape * avg_neuron # we want the user to provide shape \in [-1, +1] 113 | architecture = [] 114 | for i in range(num_hidden): 115 | # compute real-valued 'position' x of current layer (x \in (-1, 1)) 116 | x = 2 * i / (num_hidden - 1) - 1 if num_hidden != 1 else 0.0 117 | # compute number of units in current layer 118 | d = shape * x + avg_neuron 119 | d = int(np.floor(d)) 120 | if d == 0: # occurs if shape == -avg_neuron or shape == avg_neuron 121 | d = 1 122 | architecture.append(d) 123 | return architecture 124 | -------------------------------------------------------------------------------- /nmp/util/util_learning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities of learning operation 3 | """ 4 | from typing import Union 5 | import numpy as np 6 | import torch 7 | import nmp.util as util 8 | 9 | 10 | def joint_to_conditional(joint_mean: Union[np.ndarray, torch.Tensor], 11 | joint_L: Union[np.ndarray, torch.Tensor], 12 | sample_x: Union[np.ndarray, torch.Tensor]) -> \ 13 | [Union[np.ndarray, torch.Tensor]]: 14 | """ 15 | Given joint distribution p(x,y), and a sample of x, do: 16 | Compute conditional distribution p(y|x) 17 | Args: 18 | joint_mean: mean of joint distribution 19 | joint_L: cholesky distribution of joint distribution 20 | sample_x: samples of x 21 | 22 | Returns: 23 | conditional mean and L 24 | """ 25 | 26 | # Shape of joint_mean: 27 | # [*add_dim, dim_x + dim_y] 28 | # 29 | # Shape of joint_L: 30 | # [*add_dim, dim_x + dim_y, dim_x + dim_y] 31 | # 32 | # Shape of sample_x: 33 | # [*add_dim, dim_x] 34 | # 35 | # Shape of conditional_mean: 36 | # [*add_dim, dim_y] 37 | # 38 | # Shape of conditional_cov: 39 | # [*add_dim, dim_y, dim_y] 40 | 41 | # Check dimension 42 | dim_x = sample_x.shape[-1] 43 | # dim_y = joint_mean.shape[-1] - dim_x 44 | 45 | # Decompose joint distribution parameters 46 | mu_x = joint_mean[..., :dim_x] 47 | mu_y = joint_mean[..., dim_x:] 48 | 49 | L_x = joint_L[..., :dim_x, :dim_x] 50 | L_y = joint_L[..., dim_x:, dim_x:] 51 | L_x_y = joint_L[..., dim_x:, :dim_x] 52 | 53 | if util.is_ts(joint_mean): 54 | cond_mean = mu_y + \ 55 | torch.einsum('...ik,...lk,...lm,...m->...i', L_x_y, L_x, 56 | torch.cholesky_inverse(L_x), sample_x - mu_x) 57 | elif util.is_np(joint_mean): 58 | # Scipy cho_solve does not support batch operation 59 | cond_mean = mu_y + \ 60 | np.einsum('...ik,...lk,...lm,...m->...i', L_x_y, L_x, 61 | torch.cholesky_inverse(torch.from_numpy( 62 | L_x)).numpy(), 63 | sample_x - mu_x) 64 | else: 65 | raise NotImplementedError 66 | 67 | cond_L = L_y 68 | 69 | return cond_mean, cond_L 70 | -------------------------------------------------------------------------------- /nmp/util/util_matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities of matrix operation 3 | """ 4 | from typing import Optional 5 | from typing import Union 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def build_lower_matrix(param_diag: torch.Tensor, 12 | param_off_diag: Optional[torch.Tensor]) -> torch.Tensor: 13 | """ 14 | Compose the lower triangular matrix L from diag and off-diag elements 15 | It seems like faster than using the cholesky transformation from PyTorch 16 | Args: 17 | param_diag: diagonal parameters 18 | param_off_diag: off-diagonal parameters 19 | 20 | Returns: 21 | Lower triangular matrix L 22 | 23 | """ 24 | dim_pred = param_diag.shape[-1] 25 | # Fill diagonal terms 26 
| L = param_diag.diag_embed() 27 | if param_off_diag is not None: 28 | # Fill off-diagonal terms 29 | [row, col] = torch.tril_indices(dim_pred, dim_pred, -1) 30 | L[..., row, col] = param_off_diag[..., :] 31 | 32 | return L 33 | 34 | 35 | def transform_to_cholesky(mat: torch.Tensor) -> torch.Tensor: 36 | """ 37 | Transform an unconstrained matrix to cholesky, will abandon half of the data 38 | Args: 39 | mat: an unconstrained square matrix 40 | 41 | Returns: 42 | lower triangle matrix as Cholesky 43 | """ 44 | lct = torch.distributions.transforms.LowerCholeskyTransform(cache_size=0) 45 | return lct(mat) 46 | 47 | 48 | def add_expand_dim(data: Union[torch.Tensor, np.ndarray], 49 | add_dim_indices: [int], 50 | add_dim_sizes: [int]) -> Union[torch.Tensor, np.ndarray]: 51 | """ 52 | Add additional dimensions to tensor and expand accordingly 53 | Args: 54 | data: tensor to be operated. Torch.Tensor or numpy.ndarray 55 | add_dim_indices: the indices of added dimensions in the result tensor 56 | add_dim_sizes: the expanding size of the additional dimensions 57 | 58 | Returns: 59 | result: result tensor after adding and expanding 60 | """ 61 | num_data_dim = data.ndim 62 | num_dim_to_add = len(add_dim_indices) 63 | 64 | add_dim_reverse_indices = [num_data_dim + num_dim_to_add + idx 65 | for idx in add_dim_indices] 66 | 67 | str_add_dim = "" 68 | str_expand = "" 69 | add_dim_index = 0 70 | for dim in range(num_data_dim + num_dim_to_add): 71 | if dim in add_dim_indices or dim in add_dim_reverse_indices: 72 | str_add_dim += "None, " 73 | str_expand += str(add_dim_sizes[add_dim_index]) + ", " 74 | add_dim_index += 1 75 | else: 76 | str_add_dim += ":, " 77 | if type(data) == torch.Tensor: 78 | str_expand += "-1, " 79 | elif type(data) == np.ndarray: 80 | str_expand += "1, " 81 | else: 82 | raise NotImplementedError 83 | 84 | str_add_dime_eval = "data[" + str_add_dim + "]" 85 | if type(data) == torch.Tensor: 86 | return eval("eval(str_add_dime_eval).expand(" + str_expand + ")") 87 | else: 88 | return eval("np.tile(eval(str_add_dime_eval),[" + str_expand + "])") 89 | 90 | 91 | def to_cholesky(diag_vector=None, off_diag_vector=None, 92 | L=None, cov_matrix=None): 93 | """ 94 | Compute Cholesky matrix 95 | Args: 96 | diag_vector: diag elements in a vector 97 | off_diag_vector: off-diagonal elements in a vector 98 | L: Cholesky matrix 99 | cov_matrix: Covariance matrix 100 | 101 | Returns: 102 | Cholesky matrix L 103 | """ 104 | if L is not None: 105 | pass 106 | elif cov_matrix is None: 107 | assert diag_vector is not None and off_diag_vector is not None 108 | L = build_lower_matrix(diag_vector, off_diag_vector) 109 | elif diag_vector is None and off_diag_vector is None: 110 | L = torch.linalg.cholesky(cov_matrix) 111 | else: 112 | raise RuntimeError("Unexpected behaviours") 113 | return L 114 | 115 | 116 | def tensor_linspace(start: Union[float, int, torch.Tensor], 117 | end: Union[float, int, torch.Tensor], 118 | steps: int) -> torch.Tensor: 119 | """ 120 | Vectorized version of torch.linspace. 
121 | Modified from: 122 | https://github.com/zhaobozb/layout2im/blob/master/models/bilinear.py#L246 123 | 124 | Args: 125 | start: start value, scalar or tensor 126 | end: end value, scalar or tensor 127 | steps: num of steps 128 | 129 | Returns: 130 | linspace tensor 131 | """ 132 | # Shape of start: 133 | # [*add_dim, dim_data] or a scalar 134 | # 135 | # Shape of end: 136 | # [*add_dim, dim_data] or a scalar 137 | # 138 | # Shape of out: 139 | # [*add_dim, steps, dim_data] 140 | 141 | # - out: Tensor of shape start.size() + (steps,), such that 142 | # out.select(-1, 0) == start, out.select(-1, -1) == end, 143 | # and the other elements of out linearly interpolate between 144 | # start and end. 145 | 146 | if isinstance(start, torch.Tensor) and not isinstance(end, torch.Tensor): 147 | end += torch.zeros_like(start) 148 | elif not isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor): 149 | start += torch.zeros_like(end) 150 | elif isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor): 151 | assert start.size() == end.size() 152 | else: 153 | return torch.linspace(start, end, steps) 154 | 155 | view_size = start.size() + (1,) 156 | w_size = (1,) * start.dim() + (steps,) 157 | out_size = start.size() + (steps,) 158 | 159 | start_w = torch.linspace(1, 0, steps=steps).to(start) 160 | start_w = start_w.view(w_size).expand(out_size) 161 | end_w = torch.linspace(0, 1, steps=steps).to(start) 162 | end_w = end_w.view(w_size).expand(out_size) 163 | 164 | start = start.contiguous().view(view_size).expand(out_size) 165 | end = end.contiguous().view(view_size).expand(out_size) 166 | 167 | out = start_w * start + end_w * end 168 | out = torch.einsum('...ji->...ij', out) 169 | return out 170 | 171 | 172 | def indexing_interpolate(data: torch.Tensor, 173 | indices: torch.Tensor) -> torch.Tensor: 174 | """ 175 | Indexing values from a given tensor's data, using non-integer indices and 176 | thus apply interpolation. 
177 | 178 | Args: 179 | data: data tensor from where indexing happens 180 | indices: float indices tensor 181 | 182 | Returns: 183 | indexed and interpolated data 184 | """ 185 | # Shape of data: 186 | # [num_data, *dim_data] 187 | # 188 | # Shape of indices: 189 | # [*add_dim, num_indices] 190 | # 191 | # Shape of interpolate_result: 192 | # [*add_dim, num_indices, *dim_data] 193 | 194 | ndim_data = data.ndim - 1 195 | indices_0 = torch.clip(indices.floor().long(), 0, 196 | data.shape[-data.ndim] - 2) 197 | indices_1 = indices_0 + 1 198 | weights = indices - indices_0 199 | if ndim_data > 0: 200 | weights = add_expand_dim(weights, 201 | range(indices.ndim, indices.ndim + ndim_data), 202 | [-1] * ndim_data) 203 | interpolate_result = torch.lerp(data[indices_0], data[indices_1], weights) 204 | return interpolate_result 205 | -------------------------------------------------------------------------------- /nmp/util/util_media.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for generating media stuff 3 | """ 4 | 5 | from typing import List 6 | from typing import Literal 7 | from typing import Union 8 | 9 | import numpy as np 10 | import torch 11 | from matplotlib import animation 12 | from matplotlib import pyplot as plt 13 | 14 | import nmp.util as util 15 | 16 | 17 | def savefig(figs: Union[plt.Figure, List[plt.Figure]], media_name, 18 | fmt=Literal['pdf', 'png', 'jpeg'], dpi=200, overwrite=False): 19 | """ 20 | 21 | Args: 22 | figs: figure object or a list of figures 23 | media_name: name of the media 24 | fmt: format of the figures 25 | dpi: resolution 26 | overwrite: if overwrite when old exists 27 | 28 | Returns: 29 | None 30 | 31 | """ 32 | path = util.get_media_dir(media_name) 33 | util.mkdir(path, overwrite=overwrite) 34 | 35 | figs = util.make_iterable(figs) 36 | 37 | for i, fig in enumerate(figs): 38 | fig_path = util.join_path(path, str(i) + '.' + fmt) 39 | fig.savefig(fig_path, dpi=dpi, bbox_inches="tight") 40 | 41 | 42 | def save_subfig(fig: plt.Figure, axes: np.ndarray, 43 | ax_coordinate: [List[int], List[List[int]]], 44 | media_name: str, fmt=Literal['pdf', 'png', 'jpeg'], dpi=200, 45 | overwrite=False, 46 | x_scale=0.2, y_scale=0.2, x_offset=-0.1, y_offset=-0.1): 47 | """ 48 | Save subplots as individual plots 49 | Args: 50 | fig: figure object 51 | axes: axes in np array 52 | ax_coordinate: which subplots you want to save 53 | media_name: name of the save dir 54 | fmt: format to save as 55 | dpi: resolution 56 | overwrite: overwrite the dir if it exists already 57 | x_scale: scale of the bounding box in x direction 58 | y_scale: scale of the bounding box in y direction 59 | x_offset: offset of the bounding box in x direction 60 | y_offset: offset of the bounding box in y direction 61 | 62 | Returns: 63 | None 64 | """ 65 | if isinstance(ax_coordinate[0], int): 66 | ax_coordinate = [ax_coordinate, ] 67 | fig.tight_layout() 68 | path = util.get_media_dir(media_name) 69 | util.mkdir(path, overwrite=overwrite) 70 | 71 | for coord in ax_coordinate: 72 | ax = axes[coord[0], coord[1]] 73 | ext = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) 74 | subfig_name = str(coord).replace(", ", "_") + '.' 
--------------------------------------------------------------------------------
/nmp/util/util_media.py:
--------------------------------------------------------------------------------
"""
Utilities for generating media stuff
"""

from typing import List
from typing import Literal
from typing import Union

import numpy as np
import torch
from matplotlib import animation
from matplotlib import pyplot as plt

import nmp.util as util


def savefig(figs: Union[plt.Figure, List[plt.Figure]], media_name: str,
            fmt: Literal['pdf', 'png', 'jpeg'] = 'pdf', dpi=200,
            overwrite=False):
    """
    Save one figure or a list of figures to the media directory

    Args:
        figs: figure object or a list of figures
        media_name: name of the media
        fmt: format of the figures, 'pdf' is assumed as the default
        dpi: resolution
        overwrite: if overwrite when old exists

    Returns:
        None

    """
    path = util.get_media_dir(media_name)
    util.mkdir(path, overwrite=overwrite)

    figs = util.make_iterable(figs)

    for i, fig in enumerate(figs):
        fig_path = util.join_path(path, str(i) + '.' + fmt)
        fig.savefig(fig_path, dpi=dpi, bbox_inches="tight")


def save_subfig(fig: plt.Figure, axes: np.ndarray,
                ax_coordinate: Union[List[int], List[List[int]]],
                media_name: str, fmt: Literal['pdf', 'png', 'jpeg'] = 'pdf',
                dpi=200, overwrite=False,
                x_scale=0.2, y_scale=0.2, x_offset=-0.1, y_offset=-0.1):
    """
    Save subplots as individual plots
    Args:
        fig: figure object
        axes: axes in np array
        ax_coordinate: which subplots to save, e.g. [0, 1] or [[0, 1], [1, 0]]
        media_name: name of the save dir
        fmt: format to save as, 'pdf' is assumed as the default
        dpi: resolution
        overwrite: overwrite the dir if it exists already
        x_scale: scale of the bounding box in x direction
        y_scale: scale of the bounding box in y direction
        x_offset: offset of the bounding box in x direction
        y_offset: offset of the bounding box in y direction

    Returns:
        None
    """
    if isinstance(ax_coordinate[0], int):
        ax_coordinate = [ax_coordinate, ]
    fig.tight_layout()
    path = util.get_media_dir(media_name)
    util.mkdir(path, overwrite=overwrite)

    for coord in ax_coordinate:
        ax = axes[coord[0], coord[1]]
        # Bounding box of the subplot, in figure inches
        ext = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())
        subfig_name = str(coord).replace(", ", "_") + '.' + fmt
        fig_path = util.join_path(path, subfig_name)
        bbox_inches = ext.expanded(1 + x_scale,
                                   1 + y_scale).translated(x_offset, y_offset)
        fig.savefig(fig_path, dpi=dpi, bbox_inches=bbox_inches)


def from_figures_to_video(figure_list: List[plt.Figure], video_name: str,
                          interval: int = 2000, overwrite=False) -> str:
    """
    Generate and save a video given a list of figures
    Args:
        figure_list: list of matplotlib figure objects
        video_name: name of video
        interval: interval between two figures in [ms]
        overwrite: if overwrite when old exists
    Returns:
        path to the saved video
    """
    figure, ax = plt.subplots()
    figure.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0,
                           wspace=0)
    ax.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())

    frames = []

    video_path = util.get_media_dir(video_name)
    util.mkdir(video_path, overwrite)

    # Dump every figure as a png file, then read the images back as frames
    for i, fig in enumerate(figure_list):
        fig.savefig(util.join_path(video_path, "{}.png".format(i)), dpi=300,
                    bbox_inches="tight")

    for j in range(len(figure_list)):
        image = plt.imread(util.join_path(video_path, "{}.png".format(j)))
        img = plt.imshow(image, animated=True)
        plt.axis('off')
        plt.gca().set_axis_off()

        frames.append([img])

    ani = animation.ArtistAnimation(figure, frames, interval=interval,
                                    blit=True,
                                    repeat=False)
    save_path = util.join_path(video_path, video_name + '.mp4')
    ani.save(save_path, dpi=300)

    return save_path


def fill_between(x: Union[np.ndarray, torch.Tensor],
                 y_mean: Union[np.ndarray, torch.Tensor],
                 y_std: Union[np.ndarray, torch.Tensor],
                 axis=None, std_scale: int = 2, draw_mean: bool = False,
                 alpha=0.2, color='gray'):
    """
    Utility to draw a standard-deviation band around a mean curve
    Args:
        x: x value
        y_mean: y mean value
        y_std: standard deviation of y
        axis: figure axis to draw
        std_scale: filling range of [-scale * std, scale * std]
        draw_mean: plot mean curve as well
        alpha: transparency of std plot
        color: color to fill

    Returns:
        None
    """
    x, y_mean, y_std = util.to_nps(x, y_mean, y_std)
    if axis is None:
        axis = plt.gca()
    if draw_mean:
        axis.plot(x, y_mean)
    axis.fill_between(x=x,
                      y1=y_mean - std_scale * y_std,
                      y2=y_mean + std_scale * y_std,
                      alpha=alpha, color=color)
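
# A minimal usage sketch of fill_between (values are illustrative only):
#   x = np.linspace(0, 1, 100)
#   y_mean = np.sin(2 * np.pi * x)
#   y_std = 0.1 * np.ones_like(x)
#   fig, ax = plt.subplots()
#   fill_between(x, y_mean, y_std, axis=ax, std_scale=2, draw_mean=True)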
--------------------------------------------------------------------------------
/nmp/util/util_numerical.py:
--------------------------------------------------------------------------------
"""
Utilities of numerical computation
"""
from typing import Union, Optional

import numpy as np
import torch

import nmp.util as util


def to_log_space(data: Union[np.ndarray, torch.Tensor],
                 lower_bound: Optional[float]) \
        -> Union[np.ndarray, torch.Tensor]:
    """
    Project data to log space

    Args:
        data: original data
        lower_bound: customized lower bound given at runtime, overrides the
            default value

    Returns: log(data + lower_bound)

    """
    # Resolve the lower bound: runtime argument > config value > default
    actual_lower_bound = util.decide_hyperparameter(to_log_space, lower_bound,
                                                    "log_lower_bound", 1e-8)
    # Compute
    assert data.min() >= 0
    if isinstance(data, np.ndarray):
        log_data = np.log(data + actual_lower_bound)
    elif isinstance(data, torch.Tensor):
        log_data = torch.log(data + actual_lower_bound)
    else:
        raise NotImplementedError
    return log_data


def to_softplus_space(data: Union[np.ndarray, torch.Tensor],
                      lower_bound: Optional[float]) -> \
        Union[np.ndarray, torch.Tensor]:
    """
    Project data to softplus space

    Args:
        data: original data
        lower_bound: runtime lower bound of the result

    Returns: softplus(data) + lower_bound

    """
    # TODO: should we use a fixed lower bound or one adaptive to the values?
    # Resolve the lower bound: runtime argument > config value > default
    actual_lower_bound = \
        util.decide_hyperparameter(to_softplus_space, lower_bound,
                                   "softplus_lower_bound", 1e-2)
    # Compute; note this relies on torch's Softplus, so numpy inputs would
    # need to be converted to tensors first
    softplus = torch.nn.Softplus()
    sp_result = softplus(data) + actual_lower_bound
    return sp_result


def interpolate(x_ori: np.ndarray, y_ori: np.ndarray,
                num_tar: int) -> np.ndarray:
    """
    Interpolate trajectories to the desired length and data density

    Args:
        x_ori: original data time, shape [num_x]
        y_ori: original data value, shape [num_x, dim_y]
        num_tar: number of target sequence points

    Returns:
        interpolated y data, [num_tar, dim_y]
    """

    # Set up the interpolation scale
    start, stop = x_ori[0], x_ori[-1]
    x_tar = np.linspace(start, stop, num_tar)

    # Check y dim
    if y_ori.ndim == 1:
        y_tar = np.interp(x_tar, x_ori, y_ori)
    else:
        # Initialize result array
        y_tar = np.zeros((num_tar, y_ori.shape[1]))

        # Loop over the dimensions of y
        for k in range(y_ori.shape[1]):
            y_tar[:, k] = np.interp(x_tar, x_ori, y_ori[:, k])

    return y_tar
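
# A minimal usage sketch of interpolate (values are illustrative only):
#   x_ori = np.array([0.0, 0.5, 1.0])
#   y_ori = np.array([[0.0, 1.0], [1.0, 2.0], [0.0, 3.0]])  # [num_x, dim_y]
#   y_tar = interpolate(x_ori, y_ori, num_tar=101)          # [101, 2]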
end="") 45 | if middle is True: 46 | print(char * len_before_title, "", end="") 47 | print(title, end="") 48 | print("", char * len_after_title) 49 | else: 50 | print(title, end="") 51 | print(" ", char * (length - len(title) - 1)) 52 | print("\n" * after, end="") 53 | # End of function print_line_title 54 | 55 | 56 | def print_wrap_title(title: str = "", char: str = "*", length: int = 60, 57 | wrap: int = 1, before: int = 1, after: int = 1) -> None: 58 | """ 59 | Print title with wrapped box 60 | Args: 61 | title: title to print 62 | char: char for print the line 63 | length: length of line 64 | wrap: number of wrapped layers 65 | before: number of new lines before print line 66 | after: number of new lines after print line 67 | 68 | Returns: None 69 | """ 70 | 71 | assert len(title) < length - 4, "Title is longer than line length - 4" 72 | 73 | len_before_title = (length - len(title)) // 2 - 1 74 | len_after_title = length - len(title) - (length - len(title)) // 2 - 1 75 | 76 | print_line(char=char, length=length, before=before) 77 | for _ in range(wrap - 1): 78 | print(char, " " * (length - 2), char, sep="") 79 | print(char, " " * len_before_title, title, " " * len_after_title, char, 80 | sep="") 81 | 82 | for _ in range(wrap - 1): 83 | print(char, " " * (length - 2), char, sep="") 84 | print_line(char=char, length=length, after=after) 85 | # End of function print_wrap_title 86 | 87 | 88 | def print_table(tabular_data: list, headers: list, 89 | table_format: str = "grid") -> None: 90 | """ 91 | Print nice table in using tabulate 92 | 93 | Example: 94 | print_table(tabular_data=[["value1", "value2"], ["value3", "value4"]], 95 | headers=["headers 1", "headers 2"], 96 | table_format="grid")) 97 | 98 | Args: 99 | tabular_data: data in table 100 | headers: column headers 101 | table_format: format 102 | 103 | Returns: 104 | 105 | """ 106 | print(tabulate(tabular_data, headers, table_format)) 107 | 108 | 109 | def get_formatted_date_time() -> str: 110 | """ 111 | Get formatted date and time, e.g. May-01-2021 22:14:31 112 | Returns: 113 | dt_string: date time string 114 | """ 115 | now = datetime.now() 116 | dt_string = now.strftime("%b-%d-%Y %-H:%-M:%-S") 117 | return dt_string 118 | 119 | 120 | class BColors: 121 | """ 122 | Colors 123 | """ 124 | HEADER = '\033[95m' 125 | OKBLUE = '\033[94m' 126 | OKCYAN = '\033[96m' 127 | OKGREEN = '\033[92m' 128 | WARNING = '\033[93m' 129 | FAIL = '\033[91m' 130 | ENDC = '\033[0m' 131 | BOLD = '\033[1m' 132 | UNDERLINE = '\033[4m' 133 | 134 | 135 | def warn(warn_str: str): 136 | """ 137 | Print a warning string in console 138 | Args: 139 | warn_str: string to be printed 140 | 141 | Returns: 142 | None 143 | """ 144 | print(f"{BColors.WARNING}Warning: " + warn_str + f"{BColors.ENDC}") 145 | 146 | 147 | def error(error_str: str): 148 | """ 149 | Print an error string in console 150 | Args: 151 | error_str: string to be printed 152 | 153 | Returns: 154 | None 155 | """ 156 | print(f"{BColors.FAIL}Error: " + error_str + f"{BColors.ENDC}") 157 | --------------------------------------------------------------------------------