├── .gitignore ├── README.md ├── configs ├── config.yaml ├── dataset │ └── bouncingball.yaml ├── model │ ├── architecture │ │ └── default.yaml │ ├── dkf.yaml │ ├── kvae.yaml │ ├── node.yaml │ ├── rgn.yaml │ ├── rgnres.yaml │ └── vrnn.yaml └── training │ ├── forecast_forecast.yaml │ ├── forecast_recon.yaml │ ├── recon_forecast.yaml │ └── recon_recon.yaml ├── data ├── README.md ├── generate_bouncingball.py ├── generate_hamiltonian.py └── visualize_data.py ├── experiments └── README.md ├── main.py ├── models ├── CommonDynamics.py ├── CommonVAE.py ├── group_a │ ├── DKF.py │ └── VRNN.py ├── group_b1 │ └── KVAE.py └── group_b2 │ ├── NeuralODE.py │ ├── RGN.py │ └── RGNRes.py ├── requirements.txt ├── scripts ├── ablation_generation_length.py └── ablation_odeintegrator.py └── utils ├── dataloader.py ├── layers.py ├── metrics.py ├── plotting.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/bouncingball 132 | .idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

torch-neural-ssm

2 |

Neural State-Space Models and Latent Dynamic Functions
for High-Dimensional Generative Time-Series Modeling

3 | 4 | 5 | ## About this Repository 6 | 7 | This repository is meant to conceptually introduce and highlight implementation considerations for the recent class of models called Neural State-Space Models (Neural SSMs). They leverage the classic state-space model with the flexibility of deep learning to approach high-dimensional generative time-series modeling and learning latent dynamics functions. 8 | 9 | Included is an abstract PyTorch-Lightning training class with several latent dynamic functions that inherit it, as well as common metrics used in their evaluation and training examples on common datasets. Further broken down via implementation is the distinction between system identification and state estimation approaches, which are reminiscent of their classic SSM counterparts and arise from fundamental differences in the underlying choice of their probabilistic graphical model (PGM). 10 | This repository (currently) focuses primarily on considerations related to training dynamics models for system identification and forecasting rather than per-frame state estimation or filtering. 11 | 12 | Note: This repo is not fully finished and some of the experiments/sections may be incomplete. It is released publicly in order to maximize the potential benefit of this repo and hopefully inspire collaboration in improving it. Feel free to check out the "To-Do" section if you're interested in contributing! 13 | 14 | 15 |

[Image: pgm schematic]

16 |

Fig 1. Schematic of the two PGM forms of Neural SSMs.

17 | 18 | 19 | 20 | 21 | ## Citation 22 | If you found the information helpful for your work or use portions of this repo in research development, please consider 23 | citing one of the following works: 24 | ``` 25 | @misc{missel2022torchneuralssm, 26 | title={TorchNeuralSSM}, 27 | author={Missel, Ryan}, 28 | publisher={Github}, 29 | journal={Github repository}, 30 | howpublished={\url{https://github.com/qu-gg/torchssm}}, 31 | year={2022}, 32 | } 33 | 34 | @inproceedings{jiangsequentialLVM, 35 | title={Sequential Latent Variable Models for Few-Shot High-Dimensional Time-Series Forecasting}, 36 | author={Jiang, Xiajun and Missel, Ryan and Li, Zhiyuan and Wang, Linwei}, 37 | booktitle={The Eleventh International Conference on Learning Representations} 38 | } 39 | ``` 40 | 41 | 42 | ## Table of Contents 43 | - [About](#about) 44 | - [Citation](#citation) 45 | - [Table of Contents](#toc) 46 | - [Background](#background) 47 | - [Preliminaries](#prelims) 48 | - [What are Neural SSMs?](#neuralSSMwhat) 49 | - [Choice of SSM PGM](#pgmChoice) 50 | - [Latent Initial State Z0 / Zk / Zinit](#initialState) 51 | - [Reconstruction vs. Extrapolation](#reconstructionExtrapolation) 52 | - [System Controls, ut](#ssmControls) 53 | - [Implementation](#implementation) 54 | - [Datasets](#data) 55 | - [Models](#models) 56 | - [Metrics](#metrics) 57 | - [Experiments](#experiments) 58 | - [Hyperparameter Tuning](#hyperparameters) 59 | - [Hamiltionian Systems](#hamiltonian) 60 | - [Miscellaneous](#misc) 61 | - [To-Do](#todo) 62 | - [Contributions](#contributions) 63 | - [References](#references) 64 | 65 | 66 | 67 | # Background 68 | 69 | This section provides an introduction to the concept of Neural SSMs, some common considerations and limitations, and active areas of research. This section assumes some familiarity with state-space models, though little background is needed to gain a conceptual understanding if one is already coming from a latent modeling perspective or Bayesian learning. Resources are available in abundance considering the width and depth of state-space usage, however, this video and modern textbook are good starting points. 70 | 71 | 72 | 73 | ## Preliminaries 74 | Variational Auto-encoders (VAEs): VAEs[28] provide a principled and popular framework to learn the generative model pθ(x|z) behind data x, involving latent variables z that follows a prior distribution p(z). Variational inference over the generative model is facilitated by a variational approximation of the posterior density of latent variables z, in the form of a recognition model qφ(z|x). Parameters of both the generative and recognition models are optimized with the objective to maximize the evidence lower bound (ELBO) of the marginal data likelihood: 75 |

[Image: vae equation]

76 | where the first term encourages the reconstruction of the observed data, and the second term of 77 | Kullback–Leibler (KL) divergence constrains the estimated posterior density of qφ(z|x) with a pre-defined prior p(z), often assumed to be a zero-mean isotropic Gaussian density. 78 | 79 | 80 | 81 | ## What are Neural SSMs? 82 | An extension of classic state-space models, neural state-space models - at their core - consist of a dynamic function of some latent states z_k and their emission to observations x_k, realized through the equations: 83 | 84 |

[Image: neural ssm equations]

85 | where θz represents the parameters of the latent dynamic function. The precise form of these functions can vary significantly - from deterministic or stochastic, linear or non-linear, and discrete or continuous. 86 |
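To make these two components concrete, below is a minimal PyTorch sketch of one discrete, deterministic instance of the equations above: a latent transition network f(z_k) rolled forward for several steps and an emission network mapping each latent state back to an observation. The module names and layer sizes are illustrative assumptions, not the repository's exact implementation.

```python
import torch
import torch.nn as nn

class ToyNeuralSSM(nn.Module):
    def __init__(self, latent_dim=8, obs_dim=32 * 32):
        super().__init__()
        # Latent dynamics z_{k+1} = f(z_k); this could equally be stochastic,
        # linear, or a continuous-time vector field integrated by an ODE solver.
        self.transition = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.SiLU(), nn.Linear(128, latent_dim)
        )
        # Emission p(x_k | z_k), here a simple decoder to flattened pixels.
        self.emission = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.SiLU(), nn.Linear(256, obs_dim)
        )

    def forward(self, z0, timesteps):
        z, states = z0, []
        for _ in range(timesteps):
            z = self.transition(z)               # propagate the latent state
            states.append(z)
        states = torch.stack(states, dim=1)      # [batch, T, latent_dim]
        return self.emission(states), states     # emissions and latent trajectory

x_hat, zs = ToyNeuralSSM()(torch.randn(16, 8), timesteps=20)
```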

87 | Due to their explicit differentiation of transition and emission and leveraging of structured equations, they have found success in learning interpretable latent dynamic spaces[1,2,3], identifying physical systems from non-direct features[4,5,6], and uses in counterfactual forecasting[7,8,14]. 88 |

89 | Given the fast pace of progress in latent dynamics modeling over recent years, many models have been presented under a variety of terminologies and proposed frameworks - examples being variational latent recurrent models[5,9,10,11,12,22], deep state-space models[1,2,3,7,13,14], and deterministic encoding-decoding models[4,15,16]. Despite differences in appearance, they all adhere to the same conceptual framework of latent variable modeling and state-space disentanglement. As such, here we unify them under the terminology of Neural SSMs and segment them into the two base choices of probabilistic graphical models that they adhere to: system identification and state estimation. We highlight each PGM's properties and limitations with experimental evaluations on benchmark datasets. 90 | 91 | 92 | 93 | 94 | ## Choice of PGM - System Identification vs State Estimation 95 | The PGM associated with each approach is determined by the latent variable chosen for inference. 96 | 97 | 98 |

[Image: latent variable schematic]

99 |

Fig 2. Schematic of latent variable PGMs in Neural SSMs.

100 | 101 | System states as latent variables (State Estimation): The intuitive choice for the latent variable is the 102 | latent state z_k that underlies x_k, given that it is already latent in the system and is directly 103 | associated with the observations. The PGM of this form is shown under Fig. [1A](#pgmSchematic) where its marginal 104 | likelihood over an observed sequence x0:T is written as: 105 | 106 |

[Image: state likelihood]

107 | 108 | where p(xi | zi) describes the emission model and 109 | p(zi | z, x<i) describes the latent dynamics 110 | function. Given the common intractability of the posterior, parameter inference is performed through a variational 111 | approximation of the posterior density q(z0:T | x0:T), expressed as: 112 | 113 |

[Image: state variational posterior]

114 | 115 | Given these two components, the standard training objective of the Evidence Lower Bound Objective (ELBO) is thus 116 | derived with the form: 117 | 118 |

[Image: state ELBO]

119 | 120 | where the first term represents a reconstruction likelihood term over the sequence and the second is a Kullback-Leibler 121 | divergence loss between the variational posterior approximation and some assumed prior of the latent dynamics. This 122 | prior can come in many forms, either being the standard Gaussian Normal in variational inference, flow-based priors 123 | from ODE settings[5], or physics-based priors in problem-specific situations[20]. This is 124 | the primary design choice that separates current works in this area, specifically the modeling of the dynamics prior 125 | and its learned approximation. Many works draw inspiration for modeling this interaction by filtering techniques 126 | in standard SSMs, where a divergence term is constructed between the dynamics-predicted latent state and the 127 | data-corrected observation[7,18]. 128 | 129 | With this formulation, it is easy to see how dynamics models of this type can have a strong reconstructive capacity for 130 | the high-dimensional outputs and contain strong short-term predictions. In addition, input-influenced dynamics are 131 | inherent to the prediction task, as errors in the predictions of the latent dynamics are corrected by true observations 132 | every step. However, given this data-based correction, the resulting inference of 133 | q(zi | z, x<i) is weakened, and without 134 | near-term observations to guide the dynamics function, its long-horizon forecasting is limited[1,3]. 135 | 136 | System parameters as latent variables (System Identification): Rather than system states, one can instead choose 137 | to select the parameters (denoted as θz in Equation [1](#ssmEQ)). With this change, the resulting PGM 138 | is represented in Fig. [1B](#pgmSchematic) and its marginal likelihood over x0:T is represented now by: 139 | 140 |

[Image: sysID likelihood]

141 | 142 | where the resulting output observations are derived from an initial latent state z0 and the dynamics 143 | parameters θz. As before, a variational approximation is considered for inference in place of an 144 | intractable posterior but now for the density q(θz, z0) instead. Given 145 | prior density assumptions of p(θz) and p(z0) in a similar vein as 146 | above, the ELBO function for this PGM is constructed as: 147 | 148 |

[Image: sysID ELBO]

149 | where again the first term is a reconstruction likelihood and the terms following represent KL-Divergence losses over 150 | the inferred variables. 151 | 152 |

153 | The given formulation here is the most general form for this line of models and other works can be covered under 154 | special assumptions or restrictions of how q(θz) and p(θz) are 155 | modeled. Original Neural SSM parameter works consider Linear-Gaussian SSMs as the transition function and 156 | introduce non-linearity by varying the transition parameters over time as θz0:T[1,2,3]. 157 | However, as shown in Fig. [2B1](#latentSchematic), the result of this results in convoluted temporal 158 | modeling and devolves into the same state estimation problem as now the time-varying parameters rely on near-term 159 | observations for correctness[8,20]. Rather than time-varying, the system parameters could be considered 160 | an optimized global variable, in which the underlying dynamics function becomes a Bayesian neural network in a VAE's 161 | latent space[5] and is shown in Fig. [2B2](#latentSchematic). Restricting these parameters to 162 | be deterministic results in a model of the form presented in Latent ODE[10]. The furthest restriction in 163 | forgoing stochasticity in the inference of z0 results in the suite of models as presented in [4]. 164 | 165 |

166 | Regardless of the precise assumptions, this framework builds a strong latent dynamics function that enables long-term 167 | forecasting and, in some settings, even full-scale system identification[1,4] of physical systems. This is 168 | done at the cost of a harder inference task given no access to dynamics correction during generation and for full 169 | identification tasks, often requires a large number of training samples over the potential system state space[4,5]. 170 | 171 | 172 | 173 | ## Latent Initial State Z0 / Zinit 174 | 175 | As the transition dynamics and the observation space are intentionally disconnected in this framework, 176 | the problem of inferring a strong initial latent state from which to forecast is an important consideration 177 | when designing a neural state-space model[30]. This is primarily a task- and data-dependent choice, 178 | in which the architecture follows the data structure. Thankfully, much work has been done in other research 179 | directions on designing good latent encoding models. As such, works in this area often draw from them. 180 | This section is split into three parts - one on the usual architecture for high-dimensional image tasks, 181 | one on lower-dimensional and/or miscellaneous encoders, and one on the different forms of inference for the initial 182 | state depending on which sequence portions are observed. 183 | 184 | Image-based Encoders: Unsurprisingly, the common architecture used in latent image encoding is a convolutional 185 | neural network (CNN) given its inherent bias toward spatial feature extraction[1,3,4,5]. Works are mixed 186 | between either having the sequential input reshaped as frames stacked over the channel dimension or simply running 187 | the CNN over each observed frame separately and passing the concatenated embedding into an output layer. Regardless 188 | of methodology, a few frames are assumed as observations for initialization, as multiple timesteps are required to 189 | infer the initial system movement. A subset of works considers second-order latent vector spaces, in which the 190 | encoder is explicitly split into two individual position and momenta functions, taking single and multiple frames 191 | respectively[5]. 192 | 193 | 194 |

[Image: initial state visualization]

195 |

Fig N. Visualization of the stacked initial state encoder, modified from [23].
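As a rough sketch of the "frames stacked over the channel dimension" variant described above, the encoder below maps k observed grayscale frames to the mean and log-variance of the initial latent state. The layer sizes and the lazy linear heads are assumptions for illustration rather than the repository's exact CommonVAE.py architecture.

```python
import torch
import torch.nn as nn

class StackedFrameEncoder(nn.Module):
    def __init__(self, k_frames=5, latent_dim=8):
        super().__init__()
        # The k observed frames are treated as input channels to a small CNN.
        self.conv = nn.Sequential(
            nn.Conv2d(k_frames, 32, kernel_size=4, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.SiLU(),
            nn.Flatten(),
        )
        self.mean = nn.LazyLinear(latent_dim)      # q(z0 | x_{0:k}) mean
        self.logvar = nn.LazyLinear(latent_dim)    # q(z0 | x_{0:k}) log-variance

    def forward(self, frames):                     # frames: [batch, k, H, W]
        h = self.conv(frames)
        return self.mean(h), self.logvar(h)

mu, logvar = StackedFrameEncoder()(torch.randn(16, 5, 32, 32))
```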

196 | 197 | 198 | Alternate Encoders: In settings with non-image-based inputs, the initial latent encoder can take on a large variety of forms, ranging anywhere from simple linear/MLP networks in physical systems[5] to graph convolution networks for latent medical image forecasting[20]. Multi-modal and dynamics conditioning inputs can be leveraged via combinations of encoders whose embeddings go through a shared linear function. 199 | 200 | 201 |

[Image: alternate encoder visualization]

202 |

Fig N. Visualization of the stacked graph convolutional encoder, modified from [24].

203 | 204 | 205 | Variables z0, zk, and zinit: 206 | Beyond just the inference of this latent variable, there is one more variation that can be seen throughout literature - 207 | that of which portions of the input sequence are observed and used in the initial state inference. 208 | 209 | Generally, there are 3 forms seen: 210 |
    211 |
  • z0 - which uses a sequence x0:k to get z0.
  • zk - which uses the previous frames x-k:k to get zk, from which forecasting proceeds past the observed frames.
  • zinit - which uses the sequence x0:k to get an abstract initial vector state that the dynamics function starts from.

Throughout the literature, these variable names as shown here aren't used (most works simply call it z0 and describe its inference), but we differentiate them specifically to highlight the distinctions. For training purposes, it is a subtle distinction, but it potentially has implications for the resulting likelihood optimization and learned vector space.
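One way to read the three forms, purely as an illustration, is in terms of which frames are encoded and which frames the rollout (and therefore the likelihood) is aligned to. The sketch below assumes k observed frames out of a T-frame trajectory and hypothetical `encode` and `rollout` helpers; the slicing and alignment shown are an interpretation, not the repository's implementation.

```python
def initial_state_rollout(x, k, form, encode, rollout):
    """x: [batch, T, ...] frames; k: number of frames observed for initial-state inference."""
    T = x.shape[1]
    z = encode(x[:, :k])                 # all three forms encode a short window of frames
    if form == "z0":
        z_seq = rollout(z, steps=T)      # z0 is tied to frame 0; the rollout overlaps x_{0:k}
        targets = x
    elif form == "zk":
        z_seq = rollout(z, steps=T - k)  # zk is tied to frame k; the rollout starts past the observed window
        targets = x[:, k:]
    else:  # "zinit"
        z_seq = rollout(z, steps=T)      # zinit is an abstract starting vector, never decoded itself
        targets = x
    return z_seq, targets
```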

[Image: initial latent variable comparison]

223 |

Fig N. Schematic of the difference between z0 and zinit formulations. 224 | 225 | Saying that, generally there is a lack of work exploring the considerations for each approach, besides ad-hoc solutions to bridge 226 | the gap between the latent encoder and dynamics function distributions[5]. This gap can stem from 227 | optimization problems caused by imbalanced reconstruction terms between dynamics and initial states or in cases 228 | where the initial state distribution is far enough away from the data distribution of downstream frames. 229 | 230 | However, a recent work "Learning Neural State-Space Models: Do we need a state estimator?" [30] is the first detailed 231 | study into the considerations of initial state inference, providing ablations across increasing difficulties of 232 | datasets and inference forms. In their work, they found that to get competitive performance of neural SSMs on some 233 | dynamical systems, more advanced architectures were required (feed-forward or LSTM networks). Notably, they only evaluate 234 | on the zk form, varying architectural choices. 235 | 236 | A variety of empirical techniques have been proposed to tackle this gap, much in the same spirit of empirical 237 | VAE stability 'tricks.' These include separated x0 and x1:T terms 238 | (where x0 has a positive weighting coefficient), VAE pre-training for x0, 239 | and KL-regularization terms between the output distributions of the encoder and the dynamics flow[1,5]. 240 | One personal intuition regarding these two variable approaches and the tricks applied is that there exists 241 | a theoretical trade-off between the two formulations and the tricks applied help to empirically alleviate the 242 | shortcomings of either approach. This, however, requires experimentation and validation before any claims can be made. 243 | 244 | 245 | 246 | ## Reconstruction vs. Extrapolation 247 | 248 | There are three important phases during the forecasting for a neural SSM, that of initial state inference, reconstruction, 249 | and extrapolation. 250 | 251 |

[Image: reconstruction vs extrapolation]

252 |

Fig N. Breakdown of the three forecasting phases - initial state inference, reconstruction, and extrapolation.

253 | 254 | Initial State Inference: Inferring the initial state and how many frames are required to get a good 255 | initialization is fairly domain/problem specific, as each problem may require more or less time to highlight 256 | distinctive patterns that enable effective dynamics separation. 257 | 258 | Reconstruction: The former refers to the number of timesteps that are used in training, from which the 259 | likelihood term is calculated. So far in works, there is no generally agreed upon standard on how many steps to use 260 | in this and works can be seen using anywhere from 1 (i.e. next-step prediction) to 60 frames in this portion[4]. 261 | Some works frame this as a hyper-parameter to tune in experiments and there is a consideration of computational cost 262 | when scaling up to longer sequences. In our experiments, we've noticed a linear scaling in training time w.r.t. 263 | this sequence length. (TO-DO) In the Experiments section, we perform an ablation study on how fixed lengths of reconstruction affects the 264 | extrapolation ability of models on Hamiltonian systems. 265 | 266 | Extrapolation: This phase refers to the arbitrarily long forecasting of frames that goes beyond the length used 267 | in the likelihood term during training. It represents whether a model has captured the system dynamics sufficiently to 268 | enable long-term forecasting or model energy decay in non-conserving systems. For specific 269 | dynamical systems, this can be a difficult task as, at base, there is no training signal to inform the model to learn 270 | good extrapolation. Works often highlight metrics independently on reconstruction and extrapolation phases to highlight 271 | a model's strength of identification[4]. 272 | 273 | Training Considerations: It is important to note that the exact structure of how the likelihood loss is formulated 274 | plays a role in how this sequence length may affect extrapolation. Having your likelihood incorporate temporal information 275 | (e.g. summation over the sequence, trajectory mean, etc.) can have a detrimental effect on extrapolation as the model 276 | optimizes with respect to the fixed reconstruction length. Figure N highlights an example of using temporal information 277 | in a likelihood term, where there is near flawless reconstruction but immediate forecasting failure when going towards 278 | extrapolation. 279 | 280 |

[Image: reconstruction vs extrapolation]

281 |

Fig N. Example of failed extrapolation given an incorrect likelihood term. Red highlights the beginning of extrapolation.

282 | 283 | As well, it is often the case where the reconstruction training metrics (e.g. likelihood/pixel MSE) and visualizations 284 | will often show strong convergence despite still poor extrapolation. It can sometimes be the case, especially in Neural ODE 285 | latent dynamics, that more training than expected is required to enable strong extrapolation. It is an intuition that 286 | the vector field may require a longer optimization than just the reconstruction convergence to be 287 | robust against error accumulation that impacts long-horizon forecasting. 288 | 289 |

[Image: reconstruction vs extrapolation]

290 |

Fig N. Training vs. Validation pixel MSE metrics, highlighting the continued extrapolation learning past training "convergence."

291 | 292 | 293 | Tips for training good extrapolation in these models include: 294 |
    295 |
  1. Perform extrapolation in your validation steps such that there is a metric to highlight extrapolation learning over training.
  2. Use per-frame averages in the likelihood function rather than any form with temporal information (see the sketch below).
  3. Use variable lengths of reconstruction during training, sampling 1-T frames to reconstruct in a given batch.
  4. If you have long sequences, especially in non-conserving systems, sample a random starting point per batch.
  5. Train for longer than you might expect, even when training metrics have converged for "reconstruction."
  6. The integrator choice can affect this, as non-symplectic integrators have known error accumulation which affects the vector state over long horizons[4].
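As a small sketch of tips 2 and 3 above - a per-frame-averaged likelihood and a reconstruction length sampled per batch - under assumed tensor shapes rather than the repository's exact loss code:

```python
import torch
import torch.nn.functional as F

def per_frame_likelihood(preds, targets):
    # Tip 2: average over batch, time, and pixels alike so the loss scale does
    # not grow with the reconstruction length (no summation over the sequence).
    return F.mse_loss(preds, targets, reduction="mean")

def sample_recon_length(max_length):
    # Tip 3: vary how many frames are reconstructed in a given batch.
    return int(torch.randint(1, max_length + 1, (1,)))

recon_len = sample_recon_length(20)
# loss = per_frame_likelihood(preds[:, :recon_len], frames[:, :recon_len])
```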
302 | 303 | 304 | 305 | ## Use of System Controls, ut 306 | 307 | Insofar we have ignored another common and important component of state-space modeling, the incorporation of external 308 | controls u that affect the transition function of the state. Controls represent factors that influence the 309 | trajectory of a system but are not direct features of the object/system being modeled. For example, an external 310 | force such as friction acting on a moving ball or medications given to a patient could be considered 311 | controls[8,14]. These allow an additional layer of interpretability in SSMs and even enable counterfactual 312 | reasoning; i.e., given the current state, what does its trajectory look like under varying control inputs going 313 | forwards? This has myriad uses in medical modeling with counterfactual medicine[14] or physical system 314 | simulations[8]. 315 |

316 | 317 | For Neural SSMs, a variety of approaches have been taken thus far depending on the type of latent transition function used. 318 |

319 | 320 | Linear Dynamics: In latent dynamics still parameterized by traditional linear gaussian transition functions, 321 | control incorporation is as easy as the addition of another transition matrix Bt that modifies a 322 | control input ut at each timestep[1,2,4,7]. 323 | 324 |

[Image: linear control]

325 |

Fig N. Example of control input in a linear transition function[1].

326 |

327 | 328 | Non-Linear Dynamics: In discrete non-linear transition matrices using either multi-layer perceptrons or 329 | recurrent cells, these can be leveraged by either concatenating it to the input vector before the network forward 330 | pass or as a data transformation in the form of element-wise addition and a weighted combination[10]. 331 | 332 |

[Image: non-linear control]

333 |

Fig N. Example of control input in a non-linear transition function[1].
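A compact sketch of the two discrete incorporation styles above - a linear transition with an added B u_t term, and a non-linear MLP transition that either concatenates the control or adds a projected control to the state. Dimensions and module names are illustrative assumptions, not the repository's modules.

```python
import torch
import torch.nn as nn

latent_dim, control_dim = 8, 2

# Linear-Gaussian style: z_{t+1} = A z_t + B u_t (noise omitted for brevity).
A = nn.Parameter(torch.randn(latent_dim, latent_dim))
B = nn.Parameter(torch.randn(latent_dim, control_dim))

def linear_step(z_t, u_t):
    return z_t @ A.T + u_t @ B.T

# Non-linear style: concatenation before the forward pass, or element-wise
# addition of a transformed control to the latent state.
concat_mlp = nn.Sequential(nn.Linear(latent_dim + control_dim, 64), nn.SiLU(), nn.Linear(64, latent_dim))
state_mlp = nn.Sequential(nn.Linear(latent_dim, 64), nn.SiLU(), nn.Linear(64, latent_dim))
u_proj = nn.Linear(control_dim, latent_dim)

def concat_step(z_t, u_t):
    return concat_mlp(torch.cat([z_t, u_t], dim=-1))

def additive_step(z_t, u_t):
    return state_mlp(z_t + u_proj(u_t))
```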

334 |

335 | 336 | Continuous Dynamics: For incorporation into continuous latent dynamics functions, finding the best approaches 337 | is an ongoing topic of interest. Thus far, the reigning approaches are: 338 | 339 | 1. Directly jumping the vector field state with recurrent cells[18] 340 |

[Image: jump control]

341 | 342 | 2. Influencing the vector field gradient (e.g. neural controlled differential equations)[17] 343 |

[Image: gradient control]

344 | 345 | 3. Introducing another dynamics mechanism, continuous or otherwise (e.g. neural ODE or attention blocks), that is combined with the latent trajectory z1:T into an auxiliary state h1:T[8,14,25]. 346 |

[Image: continuous control]

347 |

Fig N. Visualization of the IMODE architecture, taken from [8].

348 | 349 | 350 | 352 | 353 | 354 | 358 | 359 | 360 | 361 | 362 | # Implementation 363 | 364 | In this section, specifics on model implementation and the datasets/metrics used are detailed. Specific data generation details are available in the URLs provided for each dataset. The models and datasets used throughout this repo are solely grayscale physics datasets with underlying Hamiltonian laws, such as pendulum and mass-spring sets. Extensions to color images and non-pixel-based tasks (or even graph-based data!) are easily done in this framework, as the only architecture change needed is the structure of the encoder and decoder networks as the state propagation happens solely in a latent space. 365 | 366 | The project's folder structure is as follows: 367 | 368 | ``` 369 | torchssm/ 370 | │ 371 | ├── main.py # Training entry point that takes in user args and handles boilerplate 372 | ├── tune.py # Performs a hyperparameter search for a given dataset using Ray[Tune] 373 | ├── README.md # What you're reading right now :^) 374 | ├── requirements.txt # Anaconda requirements file to enable easy setup 375 | ├── configs/ 376 | │ ├── dataset/ # Dataset config files 377 | │ ├── model/ # Model-specific hyperparameters 378 | │ ├── training/ # Training-setting parameters 379 | | └── config.yaml # Base config file 380 | ├── data/ 381 | | ├── # Name of the stored dynamics dataset (e.g. pendulum) 382 | | ├── generate_bouncingball.py # Dataset generation script for Bouncing Ball 383 | | ├── generate_hamiltonian.py # Dataset generation script for Hamiltonian Dynamics 384 | | └── visualize_data.py # Visualize trajectories of a dataset 385 | ├── experiments/ 386 | | └── # Name of the dynamics model run 387 | | └── # Given name for the ran experiment 388 | | └── / # Each experiment type has its sequential lightning logs saved 389 | ├── models/ 390 | │ ├── CommonDynamics.py # Abstract PyTorch-Lightning Module to handle train/test loops 391 | │ ├── CommonVAE.py # Shared encoder/decoder Modules for the VAE portion 392 | │ ├── group_a/ 393 | │ └── ... # State-Estimation Examples 394 | │ └── group_b1/ 395 | │ └── ... # Time-Varying System-Identification Examples 396 | │ └── group_b2/ 397 | │ └── ... # Time-Invariant System-Identification Examples 398 | ├── utils/ 399 | │ ├── dataloader.py # Dataloaders used in train/val/testing 400 | │ ├── layers.py # PyTorch Modules that represent general network layers 401 | │ ├── metrics.py # Metric functions for evaluation 402 | │ ├── plotting.py # Plotting functions for visualizatin 403 | | └── utils.py # General utility functions (e.g. argparsing, etc) 404 | └── 405 | ``` 406 | 407 | 408 | 409 | ## Config Management with Hydra 410 | 411 | This repository uses [Hydra](https://hydra.cc/) to manage configurations for running experiments, ensuring modularity, scalability, and ease of experimentation. Hydra allows for a hierarchical organization of configuration files, making it straightforward to adjust parameters for datasets, models, and training setups without editing the core code. 412 | 413 | The configuration files are organized under the `configs/` directory and follow a modular design: 414 | 415 | - **`dataset/`**: Contains configuration files specific to datasets, such as paths, preprocessing steps, and dataset-specific parameters. For example, `pendulum.yaml` might specify the length of the pendulum and time step resolution. 416 | - **`model/`**: Defines model-specific parameters like architecture, latent dimension size, and dynamics function hyperparameters. 
417 | - **`training/`**: Manages training-specific settings, such as learning rate, batch size, number of epochs, and checkpointing options. 418 | - **`config.yaml`**: The base configuration file that aggregates default values and sets up common parameters shared across experiments. 419 | 420 | Usage: Hydra facilitates experiment customization and parameter sweeping with ease. By specifying the desired configuration components, users can dynamically compose configurations at runtime. For instance: 421 | 422 | ```bash 423 | python main.py dataset=pendulum model=node training=default 424 | ``` 425 | 426 | 427 | 428 | 429 | ## Data 430 | 431 | All data used throughout these experiments are available for download here on Google Drive, in which they already come in their .npz forms. Feel free to generate your own sets using the provided data scripts! 432 | 433 | Hamiltonian Dynamics: Provided are a dataloader and generation scripts for DeepMind's Hamiltonian Dynamics 434 | suite[4], a simulation library for 17 different physics datasets that have known underlying Hamiltonian dynamics. 435 | It comes in the form of color image sequences of arbitrary length, coupled with the system's ground truth states (e.g., for pendulum, angular velocity and angle). It is well-benchmarked and customizable, making it a perfect testbed for latent dynamics function evaluation. For each setting, the physical parameters are tweakable alongside an optional friction coefficient to construct non-energy conserving systems. The location of focal points and the color of the objects are all individually tuneable, enabling mixed and complex visual datasets of varying latent dynamics. 436 | 437 |

[Image: pendulum examples]

438 |

Fig N. Pendulum-Colors Examples.

439 | 440 | For the base presented experiments of this dataset, we consider and evaluate grayscale versions of pendulum and 441 | mass-spring - which conveniently are just the sliced red channel of the original sets. Each set has 10000 442 | training and 1000 testing trajectories sampled at Δt = 1 time intervals. Energy conservation 443 | is preserved without friction and we assume constant placement of focal points for simplicity. Note that the 444 | modification to color outputs in this framework is as simple as modifying the number of channels in the 445 | encoder and decoder. 446 | 447 |

448 | Bouncing Balls: Additionally, we provide a dataloader and generation scripts for the standard latent dynamics 449 | dataset of bouncing balls[1,2,5,7,8], modified from the implementation in 450 | [1]. It consists of a ball or multiple 451 | balls moving within a bounding box while being affected by potential external effects, e.g. gravitational 452 | forces[1,2,5], pong[2], and interventions[8]. The starting position, angle, 453 | and velocity of the ball(s) are sampled uniformly between a set range. It is generated with the 454 | PyMunk and PyGame libraries. 455 | In this repository, we consider two sets - a simple set of one gravitational force and a mixed set of 4 gravitational 456 | forces in the cardinal directions with varying strengths. We similarly generate 10000 training and 457 | 1000 testing trajectories, however sample them at Δt = 0.1 intervals. 458 | 459 |

[Image: bouncing ball examples]

460 |

Fig N. Single Gravity Bouncing Ball Example.

461 | 462 |

463 | Notably, this system is surprisingly difficult to successfully perform long-term generation on, especially in cases 464 | of mixed gravities or multiple objects. Most works only report on generation within 5-15 timesteps following a 465 | period of 3-5 observation timesteps[1,2,7] and farther timesteps show lost trajectories and/or 466 | incoherent reconstructions. 467 | 468 |

469 | Meta-Learning Datasets: One of the latest research directions for neural SSMs is evaluating the potential of 470 | meta-learning to build domain-adaptable latent dynamics functions[26,27,29]. A representative dataset 471 | example for this task is the Turbulent Flow dataset that is affected by various buoyancy forces, highlighting a 472 | task with partially shared yet heterogeneous dynamics[27]. 473 | 474 |

[Image: turbulent flow examples]

475 |

Fig N. Turbulent Flow Example, sourced from [27].

476 | 477 |

478 | Multi-System Dynamics: So far in the literature, the majority of works only consider training Neural SSMs on 479 | one system of dynamics at a time - with the most variety lying in that of differing trajectory hyper-parameters. 480 | The ability to infer multiple dynamical systems under one model (or learn to output dynamical functions given 481 | system observations) and leverage similarities between the sets is an ongoing research pursuit - with applications 482 | of neural unit hypernetworks[27] and dynamics functions conditioned on sequences via 483 | meta-learning[26,29] being the first dives into this. 484 | 485 |

486 | Other Sets in Literature: Outside of the previous sets, there are a plethora of other datasets that have been 487 | explored in relevant work. The popular task of human motion prediction in both the pose estimation and video 488 | generation setting has been considered via datasets 489 | Human3.6Mil, 490 | CMU Mocap, and 491 | Weizzman-Action[5,19], 492 | though proper experimentation into this area would require problem-specific architectures given the depth of the 493 | existing field. Past high-dimensionality and image-space, standard benchmark datasets in time-series forecasting 494 | have also been considered, including the M4, 495 | Electricity Transformer Temperature (ETT), and 496 | the NASA Turbofan Degradation set. 497 | Recent works have begun looking at medical applications in inverse image reconstruction and the incorporation of 498 | physics-inspired priors[20,29y ]. Regardless of the success of Neural SSMs on these tasks, the task-agnostic 499 | factor and principled structure of this framework make it a versatile and appealing option for generative time-series modeling. 500 | 501 | 502 | 503 | 504 | ## Models 505 | 506 | Here, details on how the model implementation is structured and running experiments locally are given. As well, 507 | an overview of the abstract class implementation for a general Neural SSM and its types are explained. 508 | 509 | ### Implementation Structure 510 | Provided within this repository is a PyTorch class structure in which an abstract PyTorch-Lightning Module is shared 511 | across all the given models, from which the specific VAE and dynamics functions inherit and override the relevant 512 | forward functions for training and evaluation. Swapping between dynamics functions and PGM type is as easy as passing 513 | in the model's name for arguments, e.g. `python3 main.py model=node`. As the implementation is provided in 514 | PyTorch-Lightning, an optimization and boilerplate 515 | library for PyTorch, it is recommended to be familiar at face-level. 516 | 517 |

518 | For every model run, a new experiment version under `experiments/` related to the passed in naming arguments. Hyperparameters passed in for this run are both stored in the Tensorboard instance created as well as in the local files hparams.yaml, config.json. During training and validation sequences, all of the metrics below are automatically tracked and saved into a Tensorboard instance 519 | which can be used to compare different model runs following. Every 500 batches, reconstruction sequences against the 520 | ground truth for a set of samples are saved to the `experiments/` folder. Currently, only one checkpoint is saved based 521 | on the last epoch ran rather than checkpoints based on the best validation score or at set epochs. Restarting training 522 | from a checkpoint or loading in a model for testing is done currently by specifying the ckpt_path to the 523 | base experiment folder and the checkpt filename. 524 | 525 |

526 | The implemented dynamics functions are each separated into their respective PGM groups, however, they can still share 527 | the same general classes. Each dynamics subclass has a model_specific_loss function that allows it to 528 | return additional loss values without interrupting the abstract flow. For example, this could be used in a flow-based 529 | prior that has additional KL terms over ODE flow density without needing to override the training_step 530 | function with a duplicate copy. As well, there is additionally model_specific_plotting to enable custom 531 | plots every training epoch end. 532 | 533 | ### Implemented Dynamics 534 | 535 | Group A, 'State-Estimation': For the Group A PGM category, we provide a reimplementation of the classic Neural SSM work Deep Kalman Filter[7] alongside state estimation versions of the above, provided in Fig. N below. The DKF model modifies the standard Kalman Filter Gaussian transition function to incorporate non-linearity and expressivity by parameterizing the distribution parameters with neural networks ztN(G(zt−1,∆t), S(zt−1,∆t))[7]. 536 | Additionally, we provide a reimplementation of the Variational Recurrent Neural 537 | Network (VRNN), one of the starting state estimation works in Neural SSMs[22]. For the latent correction 538 | step, we leverage a standard Gated Recurrent Unit (GRU) cell and the corrected latent state is what is passed to the 539 | decoder and likelihood function. Notably, there are two settings these models can be run under: reconstruction 540 | and generation. Reconstruction is used for training and incorporates ground truth observations to correct 541 | the latent state while generation is used to test the forecasting abilities of the model, both short- and long-term. 542 | 543 |

544 |

[Image: Group A PGM models]

545 |

Fig N. Model schematics for Group A's implemented dynamics functions.
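Below is a rough sketch of a DKF-style stochastic transition as described above, with neural networks parameterizing the Gaussian mean and variance of the next latent state, zt ∼ N(G(zt−1), S(zt−1)). Layer sizes are assumptions; the repository's DKF.py and VRNN.py differ in details such as the GRU-based correction step.

```python
import torch
import torch.nn as nn

class GaussianTransition(nn.Module):
    def __init__(self, latent_dim=8, hidden=128):
        super().__init__()
        self.mean_net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.SiLU(), nn.Linear(hidden, latent_dim))
        self.logvar_net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.SiLU(), nn.Linear(hidden, latent_dim))

    def forward(self, z_prev):
        mean, logvar = self.mean_net(z_prev), self.logvar_net(z_prev)
        z_next = mean + torch.randn_like(mean) * (0.5 * logvar).exp()  # reparameterized sample
        return z_next, mean, logvar
```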

546 | 547 | 548 | Group B1, 'Time-Varying System-Identification': For the Group B1 PGM category, we present a reimplementation of the Kalman Variational Autoencoder (KVAE)[2], a hybrid model combining Kalman Filter dynamics with a deep recognition network. The KVAE disentangles latent states into dynamics and representations by using a structured transition model paired with a learned recognition function. Specifically, the latent dynamics are modeled as xt ∼ N(A xt−1 + B ut, Q), where A, B, and Q parameterize a linear dynamical system, while the recognition network maps observations into latent space. 549 | We provide both state estimation and generation capabilities. State estimation incorporates observed data into the inference process at every timestep. For generation, the model tests its forecasting ability in both short- and long-term scenarios by propagating dynamics without direct observation, instead using prior reconstructions as input. 550 | 551 |

552 |

[Image: Group B1 PGM models]

553 |

Fig N. Model schematics for Group B1's implemented dynamics functions.

554 | 555 | 556 | Group B2, 'Time-Invariant System-Identification': For the system identification models, we provide a variety of dynamics functions that resemble the general and special 557 | cases detailed above, which are provided in Fig N. below. The most general version is that of the Bayesian Neural ODE, 558 | in which a neural ordinary differential equation[21] is sampled from a set of optimized distributional 559 | parameters and used as the latent dynamics function 560 | z't = fp(θ)(zs)[5]. A deterministic version 561 | of a standard Neural ODE is similarly provided, e.g. 562 | z't = fθ(zs)[10,21]. Following that, two forms of a 563 | Recurrent Generative Network are provided, a residual version (RGN-Res) and a full-step version (RGN), that represent 564 | deterministic and discrete non-linear transition functions. RGN-Res is the equivalent of a Neural ODE using a fixed 565 | step Euler integrator while RGN is just a recurrent forward step function. 566 | 567 |

568 | Training for these models has one mode, that of taking in several observational frames to infer z0 569 | and then outputting a full sequence autonomously without access to subsequent observations. A likelihood function is 570 | compared over the full reconstructed sequence and optimized over. Testing and generation in this setting can be done 571 | out to any horizon easily. 572 | 573 |

574 |

[Image: sysID models]

575 |

Fig N. Model schematics for system identification's implemented dynamics functions.
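For a concrete feel of the Group B2 family, here is a hedged sketch of a deterministic Neural ODE latent dynamics function using torchdiffeq (the solver library referenced by the repository's ablations); the rk4 integrator and 0.5 step size mirror configs/model/node.yaml, while the network itself is an illustrative stand-in rather than the exact NeuralODE.py implementation.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

class LatentODEFunc(nn.Module):
    def __init__(self, latent_dim=8, hidden=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, hidden), nn.SiLU(), nn.Linear(hidden, latent_dim))

    def forward(self, t, z):
        return self.net(z)                       # dz/dt = f_theta(z)

func = LatentODEFunc()
z0 = torch.randn(16, 8)                          # initial latent state from the encoder
t = torch.linspace(0, 1, steps=20)               # evaluation times for the trajectory
zs = odeint(func, z0, t, method="rk4", options={"step_size": 0.5})   # [T, batch, latent_dim]
```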

576 | 577 | 578 | 579 | 580 | 581 | 582 | ## Metrics 583 | 584 | Mean Squared Error (MSE): A common metric used in video and image tasks where its use is in per-frame average over individual pixel error. While a multitude of papers solely uses plots of frame MSE over time as an evaluation metric, it is insufficient for comparison between models - especially in cases where the dataset contains a small object for reconstruction[4]. This is especially prominent in tasks of system identification where a model that fails to predict long-term may end up with a lower average MSE than a model that has better generation but is slightly off in its object placement. 585 | 586 |

[Image: mse equation]

587 |

Fig N. Per-Frame MSE Equation.
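A minimal sketch of the per-frame pixel MSE above, assuming image sequences shaped [batch, T, C, H, W]; it is not necessarily the exact implementation in utils/metrics.py.

```python
import torch

def per_frame_mse(preds, targets):
    """Returns the average pixel error per frame, shaped [batch, T]."""
    return ((preds - targets) ** 2).mean(dim=(2, 3, 4))
```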

588 | 589 | Valid Prediction Time (VPT): Introduced in [4], the VPT metric is an advance on latent dynamics evaluation over pure pixel-based MSE metrics. For each prediction sequence, the per-pixel MSE is taken over the frames individually, and the minimum timestep in which the MSE surpasses a pre-defined epsilon is considered the 'valid prediction time.' The resulting mean number over the samples is often normalized over the total prediction timesteps to get a percentage of valid predictions. 590 | 591 |

[Image: vpt equation]

592 |

Fig N. Per-Sequence VPT Equation.
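A sketch of VPT under the same assumed shapes: the length of the initial run of frames whose MSE stays below epsilon, normalized by the prediction horizon. The epsilon value here is illustrative and should be chosen per dataset.

```python
import torch

def valid_prediction_time(preds, targets, epsilon=0.01):
    mse = ((preds - targets) ** 2).mean(dim=(2, 3, 4))   # per-frame MSE, [batch, T]
    valid = (mse <= epsilon).float()
    vpt = valid.cumprod(dim=1).sum(dim=1)                # frames until the first violation
    return (vpt / mse.shape[1]).mean()                   # normalized over the horizon
```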

593 | 594 | Object Distance (DST): Another potential metric to support evaluation (useful in image-based physics forecasting tasks) is using the Euclidean distance between the estimated center of the predicted object and its ground truth center. Otsu's Thresholding method can be applied to grayscale output images to get binary predictions of each pixel and then the average pixel location of all the "active" pixels can be calculated. This approach can help alleviate the prior MSE issues of metric imbalance as the maximum Euclidean error of a given image space can be applied to model predictions that fail to have any pixels over Otsu's threshold. 595 | 596 |

[Image: dst equation]

597 |

Fig N. Per-Frame DST Equation.

598 | where R^N is the dimension of the output (e.g. number of image channels) and s, ŝ are the subsets of "active" predicted pixels. 599 | 600 |
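A sketch of the DST computation for a single grayscale frame pair, using scikit-image's Otsu threshold as one possible implementation of the thresholding step; the maximum-distance assignment for empty predictions follows the description above. The function names and the skimage dependency are assumptions, not necessarily what utils/metrics.py does.

```python
import numpy as np
from skimage.filters import threshold_otsu

def object_center(frame):
    """frame: [H, W] array in [0, 1]; returns the mean (row, col) of 'active' pixels or None."""
    mask = frame > threshold_otsu(frame)
    if mask.sum() == 0:
        return None
    return np.argwhere(mask).mean(axis=0)

def object_distance(pred_frame, true_frame, max_dist):
    pred_c, true_c = object_center(pred_frame), object_center(true_frame)
    if pred_c is None:            # no pixels above threshold -> assign the maximum possible error
        return float(max_dist)
    return float(np.linalg.norm(pred_c - true_c))
```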

601 | Valid Prediction Distance (VPD): Similar in spirit to how VPT leverages MSE, VPD is the minimum timestep in which the DST metric surpasses a pre-defined epsilon[29]. This is useful in tracking how long a model can generate an object in a physical system before either incorrect trajectories and/or error accumulation cause significant divergence. 602 | 603 |

[Image: vpd equation]

604 |

Fig N. Per-Sequence VPD Equation.
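Mirroring the VPT sketch, a minimal VPD sketch over one sequence of per-frame DST values (e.g., produced by the DST sketch above); epsilon is again dataset-specific and assumed.

```python
import numpy as np

def valid_prediction_distance(per_frame_dst, epsilon):
    """per_frame_dst: [T] array of object distances for one sequence."""
    exceeded = np.asarray(per_frame_dst) > epsilon
    vpd = int(np.argmax(exceeded)) if exceeded.any() else len(exceeded)
    return vpd / len(exceeded)
```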

605 | 606 |

607 | R2 Score: For evaluating systems where the full underlying latent system is available and known (e.g. image translations of Hamiltonian systems), the goodness-of-fit score R2 can be used per dimension to show how well the latent system of the Neural SSM captures the dynamics in an interpretable way[1,3]. This is easiest to leverage in linear transition dynamics. 608 | 609 | Ref. [1], while containing linear transition dynamics, mentioned the possibility of non-linear regression via vanilla neural networks, though this may run into concerns of regressor capacity and data sizes. Additionally, incorporating metrics derived from latent disentanglement learning may provide stronger evaluation capabilities. 610 | 611 |

[Image: dvbf latent interpretability]

612 |

Fig N. DVBF Latent Space Visualization for R2 scores, sourced from [1,3].

613 | 614 | 615 | 616 | 617 | # Experiments 618 | 619 | This section details some experiments that evaluate the fundamental aspects of Neural SSMs and the effects of the 620 | framework decisions one can take. Trained model checkpoints and hyperparameter files are provided for each experiment 621 | under experiments/model. Evaluations are done with the metrics discussed above, as well as visualizations of 622 | animated trajectories over time and latent walk visualizations. 623 | 624 | 625 | 626 | 627 | ## Pendulum Dynamics 628 | 629 | Here we report the results of tuning each of the models on the Hamiltonian physics dataset Pendulum. For each model, 630 | we highlight their best-performing hyperparameters with respect to the validation extrapolation MSE. For experiment 631 | going forwards, these hyperparameters will be used in experiments of similar complexity. 632 | 633 | We test two environments for the Pendulum dataset, a fixed-point one-color pendulum and a multi-point multi-color 634 | pendulum set of increased complexity. As described in [4], each individual sequence is sampled from a uniform 635 | distribution over physical parameters like mass, gravity, and pivot length. 636 | We describe data generation above in the Data section. 637 | 638 |
[TO-DO: Click to show the results for Fixed-Point Pendulum] 639 | Coming soon. 640 |
641 | 642 |
643 | 644 |
[TO-DO: Click to show the results for Multi-Point Pendulum tuning] 645 | Coming soon. 646 |
647 | 648 | 649 | 650 | 651 | ## Bouncing Ball Dynamics 652 | 653 | Similar to above, we highlight the results and hyperparameters of each model for the Bouncing Ball dataset. 654 | 655 |
[TO-DO: Click to show the results for Bouncing Ball tuning] 656 | Coming soon. 657 |
658 | 659 | 660 | 661 | ## Ablation Studies 662 | 663 | ODE Solvers: To measure the impact that ODE Solvers have on the optimized dynamics models, we performed an ablation on the available 664 | solvers that exist within the torchdiffeq library, including both fixed and adaptive solvers. We make note of 665 | their respective training times due to increased solution complexity and train each ODE solver over a variety of parameters 666 | depending on their type (e.g. step size or solution tolerances). 667 | 668 |
[Click to show the results for the ODE Solver ablation] 669 | Coming soon. 670 |
671 | 672 | 673 | 674 | # Miscellaneous 675 | 676 | This section just consists of to-do's within the repo, contribution guidelines, 677 | and a section on how to find the references used throughout the repo. 678 | 679 | 680 | 681 | ## To-Do 682 | 683 |

Ablations-to-do

684 | 685 | - Generation lengths used in training (e.g. 1/2/3/5/10 frames) 686 | - Fixed vs variable generation lengths 687 | - z0 inference scheme (no overlap, overlap-by-one, full overlap) 688 | - Use of ODE solvers (fixed, adaptive, tolerances) 689 | - Different forms of learning rate schedulers 690 | - Linear versus CNN decoder 691 | - Activation functions in the latent dynamics function 692 | 693 |

Repository-wise

694 | 695 |

Model-wise

696 | 697 |

Evaluation-wise

698 | 699 | - Implement latent walk visualizations against data-space observations (like in DVBF) 700 | 701 |

README-wise

702 | 703 | - Add guidelines for an ```Experiment``` section highlighting experiments performed in validating the models 704 | - Add a section explaining for ```Meta-Learning``` works in Neural SSMs 705 | - Add a section explaining for ```ODE Integrator``` considerations in Neural SSMs 706 | 707 | 708 | 709 | ## Contributions 710 | Contributions are welcome and encouraged! If you have an implementation of a latent dynamics function you think 711 | would be relevant and add to the conversation, feel free to submit an Issue or PR and we can discuss its 712 | incorporation. Similarly, if you feel an area of the README is lacking or contains errors, please put up a 713 | README editing PR with your suggested updates. Even tackling items on the To-Do would be massively helpful! 714 | 715 | 716 | 717 | ## References 718 | 1. Maximilian Karl, Maximilian Soelch, Justin Bayer, and Patrick van der Smagt. Deep variational bayes filters: Unsupervised learning of state space models from raw data. In International Conference on Learning Representations, 2017. 719 | 2. Marco Fraccaro, Simon Kamronn, Ulrich Paquetz, and OleWinthery. A disentangled recognition and nonlinear dynamics model for unsupervised learning. In Advances in Neural Information Processing Systems, 2017. 720 | 3. Alexej Klushyn, Richard Kurle, Maximilian Soelch, Botond Cseke, and Patrick van der Smagt. Latent matters: Learning deep state-space models. Advances in Neural Information Processing Systems, 34, 2021. 721 | 4. Aleksandar Botev, Andrew Jaegle, Peter Wirnsberger, Daniel Hennes, and Irina Higgins. Which priors matter? benchmarking models for learning latent dynamics. In Advances in Neural Information Processing Systems, 2021. 722 | 5. C. Yildiz, M. Heinonen, and H. Lahdesmaki. ODE2VAE: Deep generative second order odes with bayesian neural networks. In Neural Information Processing Systems, 2020. 723 | 6. Batuhan Koyuncu. Analysis of ode2vae with examples. arXiv preprint arXiv:2108.04899, 2021. 724 | 7. Rahul G. Krishnan, Uri Shalit, and David Sontag. Structured inference networks for nonlinear state space models. In Association for the Advancement of Artificial Intelligence, 2017. 725 | 8. Daehoon Gwak, Gyuhyeon Sim, Michael Poli, Stefano Massaroli, Jaegul Choo, and Edward Choi. Neural ordinary differential equations for intervention modeling. arXiv preprint arXiv:2010.08304, 2020. 726 | 9. Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, 2015. 727 | 10. Yulia Rubanova, Ricky T. Q. Chen, and David Duvenaud. Latent odes for irregularly-sampled time series. In Neural Information Processing Systems, 2019. 728 | 11. Tsuyoshi Ishizone, Tomoyuki Higuchi, and Kazuyuki Nakamura. Ensemble kalman variational objectives: Nonlinear latent trajectory inference with a hybrid of variational inference and ensemble kalman filter. arXiv preprint arXiv:2010.08729, 2020. 729 | 12. Justin Bayer, Maximilian Soelch, Atanas Mirchev, Baris Kayalibay, and Patrick van der Smagt. Mind the gap when conditioning amortised inference in sequential latent-variable models. arXiv preprint arXiv:2101.07046, 2021. 730 | 13. Ðor ̄de Miladinovi ́c, Muhammad Waleed Gondal, Bernhard Schölkopf, Joachim M Buhmann, and Stefan Bauer. Disentangled state space representations. arXiv preprint arXiv:1906.03255, 2019. 731 | 14. Zeshan Hussain, Rahul G. Krishnan, and David Sontag. 
Neural pharmacodynamic state space modeling, 2021. 732 | 15. Francesco Paolo Casale, Adrian Dalca, Luca Saglietti, Jennifer Listgarten, and Nicolo Fusi.Gaussian process prior variational autoencoders. Advances in neural information processing systems, 31, 2018. 733 | 16. Yingzhen Li and Stephan Mandt. Disentangled sequential autoencoder. arXiv preprint arXiv:1803.02991, 2018. 734 | 17. Patrick Kidger, James Morrill, James Foster, and Terry Lyons. Neural controlled differential equations for irregular time series. Advances in Neural Information Processing Systems, 33:6696-6707, 2020. 735 | 18. Edward De Brouwer, Jaak Simm, Adam Arany, and Yves Moreau. Gru-ode-bayes: Continuous modeling of sporadically-observed time series. Advances in neural information processing systems, 32, 2019. 736 | 19. Ruben Villegas, Jimei Yang, Yuliang Zou, Sungryull Sohn, Xunyu Lin, and Honglak Lee. Learning to generate long-term future via hierarchical prediction. In international conference on machine learning, pages 3560–3569. PMLR, 2017 737 | 20. Xiajun Jiang, Ryan Missel, Maryam Toloubidokhti, Zhiyuan Li, Omar Gharbia, John L Sapp, and Linwei Wang. Label-free physics-informed image sequence reconstruction with disentangled spatial-temporal modeling. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 361–371. Springer, 2021. 738 | 21. Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018. 739 | 22. Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. Advances in neural information processing systems, 28, 2015. 740 | 23. Junbo Zhang, Yu Zheng, and Dekang Qi. Deep spatio-temporal residual networks for citywide crowd flows prediction. In Thirty-first AAAI conference on artificial intelligence, 2017. 741 | 24. Yong Han, Shukang Wang, Yibin Ren, Cheng Wang, Peng Gao, and Ge Chen. Predicting station-level short-term passenger flow in a citywide metro network using spatiotemporal graph convolutional neural networks. ISPRS International Journal of Geo-Information, 8(6):243, 2019 742 | 25. Maryam Toloubidokhti, Ryan Missel, Xiajun Jiang, Niels Otani, and Linwei Wang. Neural state-space modeling with latent causal-effect disentanglement. In International Workshop on Machine Learning in Medical Imaging, 2022. 743 | 26. Matthieu Kirchmeyer, Yuan Yin, J ́er ́emie Don`a, Nicolas Baskiotis, Alain Rakotomamonjy, and Patrick Gallinari. Generalizing to new physical systems via context-informed dynamics model. arXiv preprint arXiv:2202.01889, 2022. 744 | 27. Rui Wang, Robin Walters, and Rose Yu. Meta-learning dynamics forecasting using task inference. arXiv preprint arXiv:2102.10271, 2021. 745 | 28. Kingma Diederik P, Welling Max. Auto-encoding variational bayes // arXiv preprint arXiv:1312.6114.2013. 746 | 29. Xiajun Jiang, Zhiyuan Li, Ryan Missel, Md Shakil Zaman, Brian Zenger, Wilson W Good, Rob S MacLeod, John L Sapp, and Linwei Wang. Few-shot generation of personalized neural surrogates for cardiac simulation via bayesian meta-learning. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 46–56. Springer, 2022. 747 | 30. Marco Forgione, Manas Mejari, and Dario Piga. Learning neural state-space models: do we need a state estimator? arXiv preprint arXiv:2206.12928, 2022. 
748 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: node 4 | - dataset: bouncingball 5 | - training: forecast_forecast 6 | 7 | hydra: 8 | output_subdir: null 9 | run: 10 | dir: . 11 | 12 | # Random seed of the run 13 | seed: 125125125 14 | devices: [0] 15 | 16 | # Experiment folder naming 17 | expname: ${dataset.dataset}_${model.model_type}_${training.z_amort_train}ztrain_${training.z_amort_test}ztest_${training.num_steps}iterations_${seed}seed 18 | model_path: "" 19 | checkpt: "" 20 | 21 | # For training, overrideable by cmd 22 | train: true 23 | resume: false 24 | -------------------------------------------------------------------------------- /configs/dataset/bouncingball.yaml: -------------------------------------------------------------------------------- 1 | dataset: bouncingball 2 | dataset_percent: 1.0 3 | batches_to_save: 50 4 | -------------------------------------------------------------------------------- /configs/model/architecture/default.yaml: -------------------------------------------------------------------------------- 1 | # Dimension and channels of the input image 2 | dim: 32 3 | num_channels: 1 4 | 5 | # Dynamics MLP function parameters 6 | num_layers: 2 7 | num_hidden: 128 8 | latent_dim: 8 9 | latent_act: swish 10 | 11 | # VAE parameters 12 | num_filters: 24 13 | -------------------------------------------------------------------------------- /configs/model/dkf.yaml: -------------------------------------------------------------------------------- 1 | model_type: dkf 2 | stochastic: true 3 | 4 | defaults: 5 | - architecture: default -------------------------------------------------------------------------------- /configs/model/kvae.yaml: -------------------------------------------------------------------------------- 1 | model_type: kvae 2 | stochastic: true 3 | 4 | defaults: 5 | - architecture: default -------------------------------------------------------------------------------- /configs/model/node.yaml: -------------------------------------------------------------------------------- 1 | model_type: node 2 | stochastic: true 3 | 4 | integrator: rk4 5 | integrator_step_size: 0.5 6 | 7 | defaults: 8 | - architecture: default -------------------------------------------------------------------------------- /configs/model/rgn.yaml: -------------------------------------------------------------------------------- 1 | model_type: rgn 2 | stochastic: true 3 | 4 | defaults: 5 | - architecture: default -------------------------------------------------------------------------------- /configs/model/rgnres.yaml: -------------------------------------------------------------------------------- 1 | model_type: rgnres 2 | stochastic: true 3 | 4 | defaults: 5 | - architecture: default -------------------------------------------------------------------------------- /configs/model/vrnn.yaml: -------------------------------------------------------------------------------- 1 | model_type: vrnn 2 | stochastic: true 3 | 4 | defaults: 5 | - architecture: default -------------------------------------------------------------------------------- /configs/training/forecast_forecast.yaml: -------------------------------------------------------------------------------- 1 | # PyTorch-Lightning hardware params 2 | accelerator: gpu 3 | num_workers: 0 4 | 5 | # Which GPU ID to run on 6 | devices: [0] 7 | 8 | # Number of 
steps per task 9 | num_steps: 50001 10 | 11 | # How often to log metrics and how often to save image reconstructions 12 | metric_interval: 50 13 | image_interval: 500 14 | 15 | # What metrics to evaluate on 16 | metrics: 17 | - vpt 18 | - reconstruction_mse 19 | 20 | test_metrics: 21 | - vpt 22 | - dst 23 | - vpd 24 | - reconstruction_mse 25 | - extrapolation_mse 26 | 27 | # Batch size 28 | batch_size: 64 29 | 30 | # Learning rate and cosine annealing scheduler 31 | # We use CosineAnnealing with WarmRestarts 32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup 33 | learning_rate: 1e-3 34 | scheduler_use: true 35 | scheduler_restart_interval: 5000 36 | scheduler_warmup_steps: 200 37 | scheduler_decay: 0.90 38 | 39 | # KL loss betas 40 | beta_z0: 1e-2 41 | beta_kl: 1e-3 42 | 43 | # How many steps are given as observed data 44 | # For forecasting, this will be small (e.g., 3 to 5) 45 | # For reconstruction, this will be the train_length 46 | z_amort_train: 5 47 | z_amort_test: 5 48 | 49 | # Total steps to either reconstruct or forecast for 50 | train_length: 20 51 | val_length: 20 52 | test_length: 20 -------------------------------------------------------------------------------- /configs/training/forecast_recon.yaml: -------------------------------------------------------------------------------- 1 | # PyTorch-Lightning hardware params 2 | accelerator: gpu 3 | num_workers: 0 4 | 5 | # Which GPU ID to run on 6 | devices: [0] 7 | 8 | # Number of steps per task 9 | num_steps: 50001 10 | 11 | # How often to log metrics and how often to save image reconstructions 12 | metric_interval: 50 13 | image_interval: 500 14 | 15 | # What metrics to evaluate on 16 | metrics: 17 | - vpt 18 | - reconstruction_mse 19 | 20 | test_metrics: 21 | - vpt 22 | - dst 23 | - vpd 24 | - reconstruction_mse 25 | - extrapolation_mse 26 | 27 | # Batch size 28 | batch_size: 64 29 | 30 | # Learning rate and cosine annealing scheduler 31 | # We use CosineAnnealing with WarmRestarts 32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup 33 | learning_rate: 1e-3 34 | scheduler_use: true 35 | scheduler_restart_interval: 5000 36 | scheduler_warmup_steps: 200 37 | scheduler_decay: 0.90 38 | 39 | # KL loss betas 40 | beta_z0: 1e-2 41 | beta_kl: 1e-3 42 | 43 | # How many steps are given as observed data 44 | # For forecasting, this will be small (e.g., 3 to 5) 45 | # For reconstruction, this will be the train_length 46 | z_amort_train: 5 47 | z_amort_test: 20 48 | 49 | # Total steps to either reconstruct or forecast for 50 | train_length: 20 51 | val_length: 20 52 | test_length: 20 -------------------------------------------------------------------------------- /configs/training/recon_forecast.yaml: -------------------------------------------------------------------------------- 1 | # PyTorch-Lightning hardware params 2 | accelerator: gpu 3 | num_workers: 0 4 | 5 | # Which GPU ID to run on 6 | devices: [0] 7 | 8 | # Number of steps per task 9 | num_steps: 50001 10 | 11 | # How often to log metrics and how often to save image reconstructions 12 | metric_interval: 50 13 | image_interval: 500 14 | 15 | # What metrics to evaluate on 16 | metrics: 17 | - vpt 18 | - reconstruction_mse 19 | 20 | test_metrics: 21 | - vpt 22 | - dst 23 | - vpd 24 | - reconstruction_mse 25 | - extrapolation_mse 26 | 27 | # Batch size 28 | batch_size: 64 29 | 30 | # Learning rate and cosine annealing scheduler 31 | # We use CosineAnnealing with 
WarmRestarts 32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup 33 | learning_rate: 1e-3 34 | scheduler_use: true 35 | scheduler_restart_interval: 5000 36 | scheduler_warmup_steps: 200 37 | scheduler_decay: 0.90 38 | 39 | # KL loss betas 40 | beta_z0: 1e-2 41 | beta_kl: 1e-3 42 | 43 | # How many steps are given as observed data 44 | # For forecasting, this will be small (e.g., 3 to 5) 45 | # For reconstruction, this will be the train_length 46 | z_amort_train: 20 47 | z_amort_test: 5 48 | 49 | # Total steps to either reconstruct or forecast for 50 | train_length: 20 51 | val_length: 20 52 | test_length: 20 -------------------------------------------------------------------------------- /configs/training/recon_recon.yaml: -------------------------------------------------------------------------------- 1 | # PyTorch-Lightning hardware params 2 | accelerator: gpu 3 | num_workers: 0 4 | 5 | # Which GPU ID to run on 6 | devices: [0] 7 | 8 | # Number of steps per task 9 | num_steps: 50001 10 | 11 | # How often to log metrics and how often to save image reconstructions 12 | metric_interval: 50 13 | image_interval: 500 14 | 15 | # What metrics to evaluate on 16 | metrics: 17 | - vpt 18 | - reconstruction_mse 19 | 20 | test_metrics: 21 | - vpt 22 | - dst 23 | - vpd 24 | - reconstruction_mse 25 | - extrapolation_mse 26 | 27 | # Batch size 28 | batch_size: 64 29 | 30 | # Learning rate and cosine annealing scheduler 31 | # We use CosineAnnealing with WarmRestarts 32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup 33 | learning_rate: 1e-3 34 | scheduler_use: true 35 | scheduler_restart_interval: 5000 36 | scheduler_warmup_steps: 200 37 | scheduler_decay: 0.90 38 | 39 | # KL loss betas 40 | beta_z0: 1e-2 41 | beta_kl: 1e-3 42 | 43 | # How many steps are given as observed data 44 | # For forecasting, this will be small (e.g., 3 to 5) 45 | # For reconstruction, this will be the train_length 46 | z_amort_train: 20 47 | z_amort_test: 20 48 | 49 | # Total steps to either reconstruct or forecast for 50 | train_length: 20 51 | val_length: 20 52 | test_length: 20 -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qu-gg/torch-neural-ssm/13813f75e18b9efd5cedc40bd7d707b0a1e7870e/data/README.md -------------------------------------------------------------------------------- /data/generate_bouncingball.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file twomix_gravity.py 3 | 4 | Handles generating a mixed set of a simple static velocity set (left to right, same initial velocities) alongside 5 | a single gravity 6 | """ 7 | import os 8 | import pygame 9 | import random 10 | import tarfile 11 | import numpy as np 12 | import pymunk.pygame_util 13 | import matplotlib.pyplot as plt 14 | 15 | from tqdm import tqdm 16 | 17 | 18 | class BallBox: 19 | def __init__(self, dt=0.2, res=(32, 32), init_pos=(3, 3), init_std=0, wall=None, gravity=(0.0, 0.0), ball_color="white"): 20 | pygame.init() 21 | 22 | self.ball_color = ball_color 23 | 24 | self.dt = dt 25 | self.res = res 26 | if os.environ.get('SDL_VIDEODRIVER', '') == 'dummy': 27 | pygame.display.set_mode(res, 0, 24) 28 | self.screen = pygame.Surface(res, pygame.SRCCOLORKEY, 24) 29 | pygame.draw.rect(self.screen, (0, 0, 0), (0, 0, res[0], res[1]), 
0) 30 | else: 31 | self.screen = pygame.display.set_mode(res, 0, 24) 32 | self.gravity = gravity 33 | self.initial_position = init_pos 34 | self.initial_std = init_std 35 | self.space = pymunk.Space() 36 | self.space.gravity = self.gravity 37 | self.draw_options = pymunk.pygame_util.DrawOptions(self.screen) 38 | self.clock = pygame.time.Clock() 39 | self.wall = wall 40 | self.static_lines = None 41 | 42 | self.dd = 2 43 | 44 | def _clear(self): 45 | self.screen.fill((0.3176, 0.3451, 0.3647)) 46 | 47 | def create_ball(self, radius=3): 48 | inertia = pymunk.moment_for_circle(1, 0, radius, (0, 0)) 49 | body = pymunk.Body(1, inertia) 50 | position = np.array(self.initial_position) + self.initial_std * np.random.normal(size=(2,)) 51 | position = np.clip(position, self.dd + radius + 1, self.res[0]-self.dd-radius-1) 52 | position = position.tolist() 53 | body.position = position 54 | 55 | shape = pymunk.Circle(body, radius, (0, 0)) 56 | shape.elasticity = 1.0 57 | 58 | shape.color = pygame.color.THECOLORS[self.ball_color] 59 | return shape 60 | 61 | def fire(self, angle=50, velocity=20, radius=3): 62 | speedX = velocity * np.cos(angle * np.pi / 180) 63 | speedY = velocity * np.sin(angle * np.pi / 180) 64 | 65 | ball = self.create_ball(radius) 66 | ball.body.velocity = (speedX, speedY) 67 | 68 | self.space.add(ball, ball.body) 69 | return ball 70 | 71 | def run(self, iterations=20, sequences=500, angle_limits=(0, 360), velocity_limits=(10, 25), radius=3, 72 | flip_gravity=None, save=None, filepath='../../data/balls.npz', delay=None): 73 | if save: 74 | images = np.empty((sequences, iterations, self.res[0], self.res[1]), dtype=np.float32) 75 | state = np.empty((sequences, iterations, 2), dtype=np.float32) 76 | 77 | dd = 0 78 | self.static_lines = [pymunk.Segment(self.space.static_body, (-1, -1), (-1, self.res[1]-dd), 0.0), 79 | pymunk.Segment(self.space.static_body, (-1, -1), (self.res[0]-dd, -1), 0.0), 80 | pymunk.Segment(self.space.static_body, (self.res[0] - dd, self.res[1] - dd), 81 | (-1, self.res[1]-dd), 0.0), 82 | pymunk.Segment(self.space.static_body, (self.res[0] - dd, self.res[1] - dd), 83 | (self.res[0]-dd, -1), 0.0)] 84 | for line in self.static_lines: 85 | line.elasticity = 1.0 86 | 87 | if self.ball_color == "white2": 88 | line.color = pygame.color.THECOLORS["white"] 89 | else: 90 | line.color = pygame.color.THECOLORS[self.ball_color] 91 | # self.space.add(self.static_lines) 92 | 93 | for sl in self.static_lines: 94 | self.space.add(sl) 95 | 96 | for s in range(sequences): 97 | if s % 100 == 0: 98 | print(s) 99 | 100 | angle = np.random.uniform(*angle_limits) 101 | velocity = np.random.uniform(*velocity_limits) 102 | # controls[:, s] = np.array([angle, velocity]) 103 | ball = self.fire(angle, velocity, radius) 104 | for i in range(iterations): 105 | self._clear() 106 | self.space.debug_draw(self.draw_options) 107 | self.space.step(self.dt) 108 | pygame.display.flip() 109 | 110 | if delay: 111 | self.clock.tick(delay) 112 | 113 | if save == 'png': 114 | pygame.image.save(self.screen, os.path.join(filepath, "bouncing_balls_%02d_%02d.png" % (s, i))) 115 | elif save == 'npz': 116 | images[s, i] = pygame.surfarray.array2d(self.screen).swapaxes(1, 0).astype(np.float32) / (2**24 - 1) 117 | state[s, i] = list(ball.body.velocity) # list(ball.body.position) + # Note that this is done for compatibility with the combined dataset 118 | 119 | # Remove the ball and the wall from the space 120 | self.space.remove(ball, ball.body) 121 | 122 | return images, state 123 | 124 | 125 | if __name__ == 
'__main__': 126 | os.environ['SDL_VIDEODRIVER'] = 'dummy' 127 | 128 | # Parameters of generation, resolution and number of samples 129 | scale = 1 130 | timesteps = 40 131 | training_size = 10000 132 | validation_size = 1500 133 | testing_size = 2500 134 | total_size = training_size + validation_size + testing_size 135 | 136 | base_dir = f"bouncingball_{training_size}samples_{timesteps}steps/" 137 | 138 | # Arrays to hold sets 139 | train_images, train_states = [], [] 140 | test_images, test_states = [], [] 141 | np.random.seed(1234) 142 | 143 | # Generate the data sequences 144 | cannon = BallBox(dt=0.25, res=(32*scale, 32*scale), init_pos=(16*scale, 16*scale), init_std=8, 145 | wall=None, gravity=(0.0, -5.0), ball_color="white") 146 | i, s = cannon.run(delay=None, iterations=timesteps, sequences=total_size + 5000, 147 | radius=3, angle_limits=(0, 360), velocity_limits=(5.0, 10.0), save='npz') 148 | 149 | # Setting the pixels to a uniform background and foreground (for simplicity of training) 150 | i = (i > 0).astype(np.float32) 151 | 152 | # Brute force check for any bad trajectories where the ball leaves 153 | bad_indices = [] 154 | for seq_idx, sequence in enumerate(i): 155 | sums = np.sum(sequence, axis=(1, 2)) 156 | if np.where(sums == 0.0)[0].shape[0] > 0: 157 | bad_indices.append(seq_idx) 158 | 159 | print(f"Bad indices: {len(bad_indices)}") 160 | 161 | i = np.delete(i, bad_indices, 0) 162 | s = np.delete(s, bad_indices, 0) 163 | if i.shape[0] > total_size: 164 | i = i[:total_size, :] 165 | s = s[:total_size, :] 166 | 167 | # Break into train and test sets, adding in generic labels 168 | train_images = i[:training_size] 169 | train_states = s[:training_size] 170 | train_classes = np.full([train_images.shape[0], 1], fill_value=2) 171 | 172 | val_images = i[training_size:training_size + validation_size] 173 | val_states = s[training_size:training_size + validation_size] 174 | val_classes = np.full([val_images.shape[0], 1], fill_value=2) 175 | 176 | test_images = i[training_size + validation_size:] 177 | test_states = s[training_size + validation_size:] 178 | test_classes = np.full([test_images.shape[0], 1], fill_value=2) 179 | 180 | print(f"Train - Images: {train_images.shape} | States: {train_states.shape} | Classes: {train_classes.shape}") 181 | print(f"Val - Images: {val_images.shape} | States: {val_states.shape} | Classes: {val_classes.shape}") 182 | print(f"Test - Images: {test_images.shape} | States: {test_states.shape} | Classes: {test_classes.shape}") 183 | 184 | # Make sure all directories are made beforehand 185 | if not os.path.exists(f"{base_dir}/"): 186 | os.mkdir(f"{base_dir}/") 187 | 188 | # Permute the sets and states together 189 | p = np.random.permutation(train_images.shape[0]) 190 | train_images, train_classes, train_states = train_images[p], train_classes[p], train_states[p] 191 | 192 | p = np.random.permutation(val_images.shape[0]) 193 | val_images, val_classes, val_states = val_images[p], val_classes[p], val_states[p] 194 | 195 | p = np.random.permutation(test_images.shape[0]) 196 | test_images, test_classes, test_states = test_images[p],test_classes[p], test_states[p] 197 | 198 | # Save sets 199 | np.savez(os.path.abspath(f"{base_dir}/train.npz"), image=train_images, state=train_states, label=train_classes) 200 | np.savez(os.path.abspath(f"{base_dir}/val.npz"), image=val_images, state=val_states, label=val_classes) 201 | np.savez(os.path.abspath(f"{base_dir}/test.npz"), image=test_images, state=test_states, label=test_classes) 202 | 
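A minimal usage sketch (not part of the repository source): the generated splits can be sanity-checked by loading them back with NumPy. The folder name assumes the default `bouncingball_10000samples_40steps/` output, the keys match the `image`/`state`/`label` names written by the `np.savez` calls above, and the shapes noted in the comments are assumptions based on the default generation parameters.

import numpy as np

# Load the training split produced by the generation script (path assumes default settings)
train = np.load("bouncingball_10000samples_40steps/train.npz", allow_pickle=True)
images, states, labels = train["image"], train["state"], train["label"]

# Expected under the defaults: images ~[10000, 40, 32, 32] binarized frames,
# states ~[10000, 40, 2] per-step ball velocities, labels ~[10000, 1] placeholder class ids
print(images.shape, states.shape, labels.shape)
print("Pixel range:", images.min(), images.max())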
-------------------------------------------------------------------------------- /data/generate_hamiltonian.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file generate_date.py 3 | @url https://github.com/deepmind/dm_hamiltonian_dynamics_suite/tree/master/dm_hamiltonian_dynamics_suite 4 | 5 | Adapted from the DeepMind Hamiltonian Dynamics Suite for usage with WebDataset 6 | Requires Linux and JAX to run, as well as a local folder structure of dm_hamiltonian_dynamics_suite. 7 | As such, it won't run straight in this repository, however is here for demonstration purposes. 8 | """ 9 | from subprocess import getstatusoutput 10 | from matplotlib import pyplot as plt 11 | from matplotlib import animation as plt_animation 12 | import numpy as np 13 | from jax import config as jax_config 14 | 15 | import os 16 | import random 17 | import tarfile 18 | from tqdm import tqdm 19 | 20 | jax_config.update("jax_enable_x64", True) 21 | 22 | from dm_hamiltonian_dynamics_suite import load_datasets 23 | from dm_hamiltonian_dynamics_suite import datasets 24 | 25 | # @title Helper functions 26 | DATASETS_URL = "gs://dm-hamiltonian-dynamics-suite" 27 | DATASETS_FOLDER = "./datasets" # @param {type: "string"} 28 | os.makedirs(DATASETS_FOLDER, exist_ok=True) 29 | 30 | 31 | def download_file(file_url, destination_file): 32 | print("Downloading", file_url, "to", destination_file) 33 | command = f"gsutil cp {file_url} {destination_file}" 34 | status_code, output = getstatusoutput(command) 35 | if status_code != 0: 36 | raise ValueError(output) 37 | 38 | 39 | def download_dataset(dataset_name: str): 40 | """Downloads the provided dataset from the DM Hamiltonian Dataset Suite""" 41 | destination_folder = os.path.join(DATASETS_FOLDER, dataset_name) 42 | dataset_url = os.path.join(DATASETS_URL, dataset_name) 43 | os.makedirs(destination_folder, exist_ok=True) 44 | if "long_trajectory" in dataset_name: 45 | files = ("features.txt", "test.tfrecord") 46 | else: 47 | files = ("features.txt", "train.tfrecord", "test.tfrecord") 48 | for file_name in files: 49 | file_url = os.path.join(dataset_url, file_name) 50 | destination_file = os.path.join(destination_folder, file_name) 51 | if os.path.exists(destination_file): 52 | print("File", file_url, "already present.") 53 | continue 54 | download_file(file_url, destination_file) 55 | 56 | 57 | def unstack(value: np.ndarray, axis: int = 0): 58 | """Unstacks an array along an axis into a list""" 59 | split = np.split(value, value.shape[axis], axis=axis) 60 | return [np.squeeze(v, axis=axis) for v in split] 61 | 62 | 63 | def make_batch_grid( 64 | batch: np.ndarray, 65 | grid_height: int, 66 | grid_width: int, 67 | with_padding: bool = True): 68 | """Makes a single grid image from a batch of multiple images.""" 69 | assert batch.ndim == 5 70 | assert grid_height * grid_width >= batch.shape[0] 71 | batch = batch[:grid_height * grid_width] 72 | batch = batch.reshape((grid_height, grid_width) + batch.shape[1:]) 73 | if with_padding: 74 | batch = np.pad(batch, pad_width=[[0, 0], [0, 0], [0, 0], 75 | [1, 0], [1, 0], [0, 0]], 76 | mode="constant", constant_values=1.0) 77 | batch = np.concatenate(unstack(batch), axis=-3) 78 | batch = np.concatenate(unstack(batch), axis=-2) 79 | if with_padding: 80 | batch = batch[:, 1:, 1:] 81 | return batch 82 | 83 | 84 | def plot_animattion_from_batch( 85 | batch: np.ndarray, 86 | grid_height, 87 | grid_width, 88 | with_padding=True, 89 | figsize=None): 90 | """Plots an animation of the batch of 
sequences.""" 91 | if figsize is None: 92 | figsize = (grid_width, grid_height) 93 | batch = make_batch_grid(batch, grid_height, grid_width, with_padding) 94 | batch = batch[:, ::-1] 95 | fig = plt.figure(figsize=figsize) 96 | plt.close() 97 | ax = fig.add_subplot(1, 1, 1) 98 | ax.axis('off') 99 | img = ax.imshow(batch[0]) 100 | 101 | def frame_update(i): 102 | i = int(np.floor(i).astype("int64")) 103 | img.set_data(batch[i]) 104 | return [img] 105 | 106 | anim = plt_animation.FuncAnimation( 107 | fig=fig, 108 | func=frame_update, 109 | frames=np.linspace(0.0, len(batch), len(batch) * 5 + 1)[:-1], 110 | save_count=len(batch), 111 | interval=10, 112 | blit=True 113 | ) 114 | return anim 115 | 116 | 117 | def plot_sequence_from_batch( 118 | batch: np.ndarray, 119 | t_start: int = 0, 120 | with_padding: bool = True, 121 | fontsize: int = 20): 122 | """Plots all of the sequences in the batch.""" 123 | n, t, dx, dy = batch.shape[:-1] 124 | xticks = np.linspace(dx // 2, t * (dx + 1) - 1 - dx // 2, t) 125 | xtick_labels = np.arange(t) + t_start 126 | yticks = np.linspace(dy // 2, n * (dy + 1) - 1 - dy // 2, n) 127 | ytick_labels = np.arange(n) 128 | batch = batch.reshape((n * t, 1) + batch.shape[2:]) 129 | batch = make_batch_grid(batch, n, t, with_padding)[0] 130 | plt.imshow(batch.squeeze()) 131 | plt.xticks(ticks=xticks, labels=xtick_labels, fontsize=fontsize) 132 | plt.yticks(ticks=yticks, labels=ytick_labels, fontsize=fontsize) 133 | 134 | 135 | def visalize_dataset( 136 | dataset_path: str, 137 | sequence_lengths: int = 60, 138 | grid_height: int = 2, 139 | grid_width: int = 5): 140 | """Visualizes a dataset loaded from the path provided.""" 141 | split = "test" 142 | batch_size = grid_height * grid_width 143 | dataset = load_datasets.load_dataset( 144 | path=dataset_path, 145 | tfrecord_prefix=split, 146 | sub_sample_length=sequence_lengths, 147 | per_device_batch_size=batch_size, 148 | num_epochs=None, 149 | drop_remainder=True, 150 | shuffle=False, 151 | shuffle_buffer=100 152 | ) 153 | sample = next(iter(dataset)) 154 | batch_x = sample['x'].numpy() 155 | batch_image = sample['image'].numpy() 156 | # Plot real system dimensions 157 | plt.figure(figsize=(24, 8)) 158 | for i in range(batch_x.shape[-1]): 159 | plt.subplot(1, batch_x.shape[-1], i + 1) 160 | plt.title(f"Samples from dimension {i + 1}") 161 | plt.plot(batch_x[:, :, i].T) 162 | plt.show() 163 | # Plot a sequence of 50 images 164 | plt.figure(figsize=(30, 10)) 165 | plt.title("Samples from 50 steps sub sequences.") 166 | plot_sequence_from_batch(batch_image[:, :50]) 167 | plt.show() 168 | # Plot animation 169 | return plot_animattion_from_batch(batch_image, grid_height, grid_width) 170 | 171 | 172 | # Generate dataset 173 | print("Generating datasets....") 174 | folder_to_store = "./generated_datasets" 175 | dataset = "pendulum" 176 | class_id = np.array([1]) 177 | dt = 0.1 178 | num_steps = 1000 179 | steps_per_dt = 1 180 | num_train = 1000 181 | num_test = 2000 182 | overwrite = True 183 | datasets.generate_full_dataset( 184 | folder=folder_to_store, 185 | dataset=dataset, 186 | dt=dt, 187 | num_steps=num_steps, 188 | steps_per_dt=steps_per_dt, 189 | num_train=num_train, 190 | num_test=num_test, 191 | overwrite=overwrite, 192 | ) 193 | dataset_full_name = dataset + "_dt_" + str(dt).replace(".", "_") 194 | dataset_output_path = dataset + "_{}samples_{}steps".format(num_train, num_steps) + "_dt" + str(dt).replace(".", "") 195 | dataset_path = os.path.join(folder_to_store, dataset_full_name) 196 | visalize_dataset(dataset_path) 
197 | 198 | if not os.path.exists("data_out/{}".format(dataset_output_path)): 199 | os.mkdir("data_out/{}".format(dataset_output_path)) 200 | 201 | """ 202 | Training Generation 203 | """ 204 | print("Converting training files...") 205 | if not os.path.exists("data_out/{}/train/".format(dataset_output_path)): 206 | os.mkdir("data_out/{}/train/".format(dataset_output_path)) 207 | 208 | loaded_dataset = load_datasets.load_dataset( 209 | path=dataset_path, 210 | tfrecord_prefix="train", 211 | sub_sample_length=num_steps, 212 | per_device_batch_size=1, 213 | num_epochs=1, 214 | drop_remainder=True, 215 | shuffle=True, 216 | shuffle_buffer=100 217 | ) 218 | 219 | images = None 220 | xs = None 221 | for idx, sample in enumerate(loaded_dataset): 222 | image = sample['image'][0].numpy() 223 | 224 | # (32, 32, 3) -> (3, 32, 32) 225 | image = np.swapaxes(image, 2, 3) 226 | image = np.swapaxes(image, 1, 2) 227 | 228 | # Just grab the R channel 229 | image = image[:, 0, :, :] 230 | 231 | x = sample['x'].numpy() 232 | 233 | np.savez("data_out/{}/train/{}.npz".format(dataset_output_path, idx), image=image, x=x, class_id=class_id) 234 | 235 | # Get file list and then shuffle it 236 | file_list = os.listdir("data_out/" + dataset_output_path + "/train/") 237 | random.shuffle(file_list) 238 | 239 | if not os.path.exists("data_out/" + dataset_output_path + '/train_tars/'): 240 | os.mkdir("data_out/" + dataset_output_path + '/train_tars/') 241 | 242 | n_shards = 200 243 | elements_per_shard = len(file_list) // n_shards 244 | 245 | for n in tqdm(range(n_shards)): 246 | with tarfile.open("data_out/" + dataset_output_path + "/train_tars/train{0:03}.tar".format(n), "w:gz") as tar: 247 | for file in file_list[n * elements_per_shard: (n + 1) * elements_per_shard]: 248 | tar.add("data_out/" + dataset_output_path + "/train/{}".format(file)) 249 | 250 | """ 251 | Testing Generation 252 | """ 253 | print("Converting testing files...") 254 | if not os.path.exists("data_out/{}/test/".format(dataset_output_path)): 255 | os.mkdir("data_out/{}/test/".format(dataset_output_path)) 256 | loaded_dataset = load_datasets.load_dataset( 257 | path=dataset_path, 258 | tfrecord_prefix="test", 259 | sub_sample_length=num_steps, 260 | per_device_batch_size=1, 261 | num_epochs=1, 262 | drop_remainder=True, 263 | shuffle=True, 264 | shuffle_buffer=100 265 | ) 266 | 267 | images = None 268 | xs = None 269 | for idx, sample in enumerate(loaded_dataset): 270 | image = sample['image'][0].numpy() 271 | 272 | # (32, 32, 3) -> (3, 32, 32) 273 | image = np.swapaxes(image, 2, 3) 274 | image = np.swapaxes(image, 1, 2) 275 | 276 | # Just grab the R channel 277 | image = image[:, 0, :, :] 278 | 279 | x = sample['x'].numpy() 280 | np.savez("data_out/{}/test/{}.npz".format(dataset_output_path, idx), image=image, x=x, class_id=class_id) 281 | 282 | # Get file list and then shuffle it 283 | file_list = os.listdir("data_out/" + dataset_output_path + "/test/") 284 | random.shuffle(file_list) 285 | 286 | if not os.path.exists("data_out/" + dataset_output_path + '/test_tars/'): 287 | os.mkdir("data_out/" + dataset_output_path + '/test_tars/') 288 | 289 | n_shards = 200 290 | elements_per_shard = len(file_list) // n_shards 291 | 292 | for n in tqdm(range(n_shards)): 293 | with tarfile.open("data_out/" + dataset_output_path + "/test_tars/test{0:03}.tar".format(n), "w:gz") as tar: 294 | for file in file_list[n * elements_per_shard: (n + 1) * elements_per_shard]: 295 | tar.add("data_out/" + dataset_output_path + "/test/{}".format(file)) 296 | 
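A brief sketch, offered as an assumption rather than part of the suite, of how one converted sample could be inspected before sharding into tars. The path follows the `dataset_output_path` naming above under the pendulum defaults, and `image`/`x`/`class_id` are the keys written by `np.savez` in this script; exact array shapes depend on the chosen `dataset`, `num_steps`, and image resolution.

import numpy as np

# Load the first converted training sample (path assumes the pendulum defaults above)
sample = np.load("data_out/pendulum_1000samples_1000steps_dt01/train/0.npz", allow_pickle=True)
image, x, class_id = sample["image"], sample["x"], sample["class_id"]

# image holds the per-step frames with only the R channel kept (roughly [num_steps, H, W]),
# x holds the underlying state trajectory, and class_id is the dataset label (here [1])
print(image.shape, x.shape, class_id)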
-------------------------------------------------------------------------------- /data/visualize_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.colors import Normalize 4 | from matplotlib.cm import ScalarMappable 5 | 6 | def plot_ball_trajectory(images): 7 | # Define number of timesteps and image dimensions 8 | timesteps, dim1, dim2 = images.shape 9 | 10 | # Use 'turbo' colormap for a bright and high-contrast gradient 11 | cmap = plt.get_cmap("turbo") 12 | norm = Normalize(vmin=2, vmax=timesteps - 1) 13 | sm = ScalarMappable(cmap=cmap, norm=norm) 14 | 15 | # Create a base figure for the trajectory plot 16 | plt.figure(figsize=(6, 6)) 17 | plt.axis("off") 18 | 19 | # Plot each timestep's image on the plot with a color gradient 20 | for t in range(timesteps): 21 | # Get color for the current timestep 22 | color = cmap(norm(t))[:3] # RGB color for the current timestep 23 | 24 | # Create an RGB mask for the current frame 25 | mask = images[t] > 0 # Assuming ball is represented by non-zero values 26 | colored_image = np.zeros((dim1, dim2, 3)) # Create an RGB image 27 | for c in range(3): # Assign the color to masked areas only 28 | colored_image[:, :, c] = mask * color[c] 29 | 30 | # Plot the colored mask with some transparency 31 | plt.imshow(1 - colored_image, alpha=0.3) 32 | 33 | # Add color bar to indicate time progression 34 | plt.colorbar(sm, label="Time Steps") 35 | plt.title("Ball Trajectory Over Time", color='white') 36 | plt.gca().set_facecolor('black') # Set the background to black for contrast 37 | plt.show() 38 | 39 | images = np.load("bouncingball_10000samples_40steps/train.npz", allow_pickle=True)["image"] 40 | plot_ball_trajectory(images[0]) 41 | plot_ball_trajectory(images[1]) 42 | plot_ball_trajectory(images[2]) -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qu-gg/torch-neural-ssm/13813f75e18b9efd5cedc40bd7d707b0a1e7870e/experiments/README.md -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file main.py 3 | 4 | Main entrypoint for the training and testing environments. Takes in a configuration file 5 | of arguments and either trains a model or tests a given model and checkpoint. 
6 | """ 7 | import torch 8 | import hydra 9 | import pytorch_lightning 10 | import pytorch_lightning.loggers as pl_loggers 11 | 12 | from omegaconf import DictConfig 13 | from utils.dataloader import SSMDataModule 14 | from utils.utils import get_model, flatten_cfg 15 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor 16 | 17 | 18 | @hydra.main(version_base="1.3", config_path="configs", config_name="config") 19 | def main(cfg: DictConfig): 20 | # Flatten the cfg down a level 21 | cfg.expname = cfg.expname 22 | cfg = flatten_cfg(cfg) 23 | 24 | # Set a consistent seed over the full set for consistent analysis 25 | pytorch_lightning.seed_everything(cfg.seed, workers=True) 26 | 27 | # Enable fp16 training 28 | torch.backends.cudnn.benchmark = True 29 | torch.backends.cudnn.allow_tf32 = True 30 | torch.set_float32_matmul_precision('medium') 31 | 32 | # Limit number of CPU workers 33 | torch.set_num_threads(8) 34 | 35 | # Building the PL-DataModule for all splits 36 | datamodule = SSMDataModule(cfg=cfg) 37 | 38 | # Initialize model type and initialize 39 | model = get_model(cfg.model_type)(cfg) 40 | 41 | # Tensorboard Logger 42 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=f"experiments/{cfg.expname}/", name=f"{cfg.model_type}") 43 | 44 | # Callbacks for checkpointing and early stopping 45 | checkpoint_callback = ModelCheckpoint(monitor='val_reconstruction_mse', 46 | filename='step{step:02d}-val_reconstruction_mse{val_reconstruction_mse:.4f}', 47 | auto_insert_metric_name=False, save_last=True) 48 | early_stop_callback = EarlyStopping(monitor="val_reconstruction_mse", min_delta=0.000001, patience=15, mode="min") 49 | lr_monitor = LearningRateMonitor(logging_interval='step') 50 | 51 | # Initialize trainer 52 | trainer = pytorch_lightning.Trainer( 53 | callbacks=[ 54 | # early_stop_callback, 55 | lr_monitor, 56 | checkpoint_callback 57 | ], 58 | accelerator=cfg.accelerator, 59 | devices=cfg.devices, 60 | deterministic=True, 61 | max_steps=cfg.num_steps * cfg.batch_size, 62 | max_epochs=1, 63 | gradient_clip_val=5.0, 64 | val_check_interval=cfg.metric_interval, 65 | num_sanity_val_steps=0, 66 | logger=tb_logger 67 | ) 68 | 69 | # Training from scratch 70 | if cfg.train is True and cfg.resume is False: 71 | trainer.fit(model, datamodule) 72 | 73 | # Training from the last epoch 74 | elif cfg.train is True and cfg.resume is True: 75 | ckpt_path = tb_logger.log_dir + "/checkpoints/" + cfg.checkpt if cfg.checkpt != "" \ 76 | else f"{tb_logger.log_dir}/checkpoints/last.ckpt" 77 | 78 | trainer.fit(model, datamodule, ckpt_path=ckpt_path) 79 | 80 | # Testing the model on each split 81 | ckpt_path = tb_logger.log_dir + "/checkpoints/" + cfg.checkpt if cfg.checkpt != "" \ 82 | else f"{tb_logger.log_dir}/checkpoints/last.ckpt" 83 | 84 | model.setting = 'train' 85 | trainer.test(model, datamodule.evaluate_train_dataloader(), ckpt_path=ckpt_path) 86 | 87 | model.setting = 'val' 88 | trainer.test(model, datamodule.val_dataloader(), ckpt_path=ckpt_path) 89 | 90 | model.setting = 'test' 91 | trainer.test(model, datamodule.test_dataloader(), ckpt_path=ckpt_path) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /models/CommonDynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file CommonDynamics.py 3 | 4 | A common class that each latent dynamics function inherits. 
5 | Holds the training + validation step logic and the VAE components for reconstructions. 6 | """ 7 | import os 8 | import json 9 | import torch 10 | import numpy as np 11 | import torch.nn as nn 12 | import pytorch_lightning 13 | import utils.metrics as metrics 14 | import matplotlib.pyplot as plt 15 | 16 | from sklearn.manifold import TSNE 17 | from models.CommonVAE import LatentStateEncoder, EmissionDecoder 18 | from utils.plotting import show_images, get_embedding_trajectories 19 | from utils.utils import determine_annealing_factor, CosineAnnealingWarmRestartsWithDecayAndLinearWarmup 20 | 21 | 22 | class LatentDynamicsModel(pytorch_lightning.LightningModule): 23 | def __init__(self, cfg): 24 | """ 25 | Generic implementation of a Latent Dynamics Model 26 | Holds the training and testing boilerplate, as well as experiment tracking 27 | :param cfg: passed in hydra configdict 28 | """ 29 | super().__init__() 30 | # Config 31 | self.cfg = cfg 32 | self.setting = 'train' 33 | 34 | # Encoder + Decoder 35 | self.encoder = LatentStateEncoder(cfg) 36 | self.decoder = EmissionDecoder(cfg) 37 | 38 | # Recurrent dynamics function 39 | self.dynamics_func = None 40 | 41 | # Number of steps for training 42 | self.n_updates = 0 43 | 44 | # Loss function 45 | self.reconstruction_loss = nn.MSELoss(reduction='none') 46 | 47 | # Variable to hold batch outputs to manually log 48 | self.outputs = list() 49 | 50 | def forward(self, x, generation_len): 51 | """ Placeholder function for the dynamics forward pass """ 52 | raise NotImplementedError("In forward function: Latent Dynamics function not specified.") 53 | 54 | def model_specific_loss(self, x, x_rec, train=True): 55 | """ Placeholder function for any additional loss terms a dynamics function may have """ 56 | return 0.0 57 | 58 | def model_specific_plotting(self, version_path, outputs): 59 | """ Placeholder function for any additional plots a dynamics function may have """ 60 | return None 61 | 62 | def configure_optimizers(self): 63 | """ 64 | Most standard NSSM models have a joint optimization step under one ELBO, however there is room 65 | for EM-optimization procedures based on the PGM. 66 | 67 | By default, we assume a joint optim with the Adam Optimizer. We additionally include LR Warmup and 68 | CosineAnnealing with decay for standard learning rate care during training. 69 | 70 | For CosineAnnealing, we set the LR bounds to be [LR * 1e-2, LR] 71 | """ 72 | optim = torch.optim.AdamW(self.parameters(), lr=self.cfg.learning_rate) 73 | 74 | # Build the scheduler if using it 75 | if self.cfg.scheduler_use is True: 76 | scheduler = CosineAnnealingWarmRestartsWithDecayAndLinearWarmup( 77 | optim, 78 | T_0=self.cfg.scheduler_restart_interval, T_mult=1, 79 | eta_min=self.cfg.learning_rate * 1e-2, 80 | warmup_steps=self.cfg.scheduler_warmup_steps, 81 | decay=self.cfg.scheduler_decay 82 | ) 83 | 84 | # Explicit dictionary to state how often to ping the scheduler 85 | scheduler = { 86 | 'scheduler': scheduler, 87 | 'frequency': 1, 88 | 'interval': 'step' 89 | } 90 | 91 | return [optim], [scheduler] 92 | 93 | # Otherwise just return the optimizer 94 | return optim 95 | 96 | def on_train_start(self): 97 | """ 98 | Before a training session starts, we set some model variables and save a JSON configuration of the 99 | used hyperparameters to allow for easy load-in at test-time. 
100 | """ 101 | # Get total number of parameters for the model and save 102 | self.log("total_num_parameters", float(sum(p.numel() for p in self.parameters() if p.requires_grad)), prog_bar=False) 103 | 104 | # Make image dir in lightning experiment folder if it doesn't exist 105 | if not os.path.exists(f"{self.logger.log_dir}/images/"): 106 | os.mkdir(f"{self.logger.log_dir}/images/") 107 | 108 | def get_step_outputs(self, batch, generation_len): 109 | """ 110 | Handles the process of pre-processing and subsequence sampling a batch, 111 | as well as getting the outputs from the models regardless of step 112 | :param batch: list of dictionary objects representing a single image 113 | :param generation_len: how far out to generate for, dependent on the step (train/val) 114 | :return: processed model outputs 115 | """ 116 | # Deconstruct batch 117 | _, images, states, _, labels = batch 118 | 119 | # Set the length of z_amort depending on training/testing 120 | if self.trainer.training: 121 | self.z_amort = self.cfg.z_amort_train 122 | else: 123 | self.z_amort = self.cfg.z_amort_test 124 | 125 | # Same random portion of the sequence over generation_len, saving room for backwards solving 126 | if max(images.shape[1] - self.z_amort - generation_len, 0) > 0: 127 | random_start = np.random.randint(0, images.shape[1] - self.z_amort - generation_len) 128 | images = images[:, random_start:random_start + generation_len + self.z_amort] 129 | states = states[:, random_start:random_start + generation_len + self.z_amort] 130 | 131 | # Get predictions 132 | preds, embeddings = self(images, generation_len) 133 | 134 | # Restrict images to start from after inference, for metrics and likelihood 135 | images = images[:, self.z_amort:] 136 | states = states[:, self.z_amort:] 137 | return images, states, labels, preds, embeddings 138 | 139 | def get_step_losses(self, images, preds): 140 | """ 141 | Handles getting the ELBO terms for the given step 142 | :param images: ground truth images 143 | :param images_rev: grouth truth images, reversed for some models' secondary TRS loss 144 | :param preds: forward predictions from the model 145 | :return: likelihood, kl on z0, model-specific dynamics loss 146 | """ 147 | # Reconstruction loss for the sequence and z0 148 | likelihood = self.reconstruction_loss(preds, images) 149 | likelihood = likelihood.reshape([likelihood.shape[0] * likelihood.shape[1], -1]).sum([-1]).mean() 150 | 151 | # Initial encoder loss, KL[q(z_K|x_0:K) || p(z_K)] 152 | klz = self.encoder.kl_z_term() 153 | 154 | # Get the loss terms from the specific latent dynamics loss 155 | dynamics_loss = self.model_specific_loss(images, preds) 156 | return likelihood, klz, dynamics_loss 157 | 158 | def get_epoch_metrics(self, outputs, length=20): 159 | """ 160 | Takes the dictionary of saved batch metrics, stacks them, and gets outputs to log in the Tensorboard. 
161 | :param outputs: list of dictionaries with outputs from each back 162 | :return: dictionary of metrics aggregated over the epoch 163 | """ 164 | # Convert outputs to Tensors and then Numpy arrays 165 | images = torch.vstack([out["images"] for out in outputs]).cpu().numpy() 166 | preds = torch.vstack([out["preds"] for out in outputs]).cpu().numpy() 167 | 168 | # Iterate through each metric function and add to a dictionary 169 | out_metrics = {} 170 | for met in self.cfg.metrics: 171 | metric_function = getattr(metrics, met) 172 | out_metrics[met] = metric_function(images, preds, cfg=self.cfg, length=length)[0] 173 | 174 | # Return a dictionary of metrics 175 | return out_metrics 176 | 177 | def training_step(self, batch, batch_idx): 178 | """ 179 | PyTorch-Lightning training step where the network is propagated and returns a single loss value, 180 | which is automatically handled for the backward update 181 | :param batch: list of dictionary objects representing a single image 182 | :param batch_idx: how far in the epoch this batch is 183 | """ 184 | # Get model outputs from batch 185 | images, _, labels, preds, _ = self.get_step_outputs(batch, self.cfg.train_length) 186 | 187 | # Get model loss terms for the step 188 | likelihood, klz, dynamics_loss = self.get_step_losses(images, preds) 189 | 190 | # Determine KL annealing factor for the current step 191 | kl_factor = determine_annealing_factor(self.n_updates, anneal_update=1000) 192 | 193 | # Build the full loss 194 | loss = likelihood + kl_factor * ((self.cfg.beta_z0 * klz) + (self.cfg.beta_kl * dynamics_loss)) 195 | 196 | # Log ELBO loss terms 197 | self.log_dict({ 198 | "likelihood": likelihood, 199 | "kl_z": self.cfg.beta_z0 * klz, 200 | "dynamics_loss": self.cfg.beta_kl * dynamics_loss, 201 | "kl_factor": kl_factor 202 | }) 203 | 204 | # Log metrics every N batches 205 | if len(self.outputs) < self.cfg.batches_to_save: 206 | self.outputs.append({"loss": loss, "labels": labels.detach(), "preds": preds.detach(), "images": images.detach()}) 207 | 208 | # Return the loss for updating and track the iteration number 209 | self.n_updates += 1 210 | return {"loss": loss} 211 | 212 | def on_train_batch_end(self, outputs, batch, batch_idx): 213 | """ Given the iterative training, check on every batch's end whether it is evaluation time or not """ 214 | # Show side-by-side reconstructions 215 | if batch_idx % self.cfg.image_interval == 0 and batch_idx != 0: 216 | show_images(self.outputs[0]["images"], self.outputs[0]["preds"], f'{self.logger.log_dir}/images/recon{batch_idx}train.png', num_out=5) 217 | 218 | # Get per-dynamics plots 219 | self.model_specific_plotting(self.logger.log_dir, self.outputs) 220 | 221 | # Get metrics over the window of batches and clear output buffer 222 | if batch_idx % self.cfg.metric_interval == 0 and batch_idx != 0: 223 | metrics = self.get_epoch_metrics(self.outputs[:self.cfg.batches_to_save], length=self.cfg.train_length) 224 | for metric in metrics.keys(): 225 | self.log(f"train_{metric}", metrics[metric], prog_bar=True) 226 | 227 | self.outputs = list() 228 | 229 | def validation_step(self, batch, batch_idx): 230 | """ 231 | PyTorch-Lightning validation step. 
Similar to the training step but on the given val set under torch.no_grad() 232 | :param batch: list of dictionary objects representing a single image 233 | :param batch_idx: how far in the epoch this batch is 234 | """ 235 | # Get model outputs from batch 236 | images, _, _, preds, _ = self.get_step_outputs(batch, self.cfg.val_length) 237 | 238 | # Get model loss terms for the step 239 | likelihood, _, _ = self.get_step_losses(images, preds) 240 | 241 | # Log validation likelihood and metrics 242 | self.log("val_likelihood", likelihood, prog_bar=True) 243 | 244 | # Return outputs as dict 245 | out = {"loss": likelihood} 246 | if batch_idx < self.cfg.batches_to_save: 247 | out["preds"] = preds.detach() 248 | out["images"] = images.detach() 249 | return out 250 | 251 | def validation_epoch_end(self, outputs): 252 | """ 253 | Every N epochs, get a validation reconstruction sample 254 | :param outputs: list of outputs from the validation steps at batch 0 255 | """ 256 | # Log epoch metrics on saved batches 257 | metrics = self.get_epoch_metrics(self.outputs[:self.cfg.batches_to_save], length=self.cfg.val_length) 258 | for metric in metrics.keys(): 259 | self.log(f"val_{metric}", metrics[metric], prog_bar=True) 260 | 261 | # Get image reconstructions 262 | if self.n_updates % self.cfg.image_interval == 0 and self.n_updates != 0: 263 | show_images(outputs[0]["images"], outputs[0]["preds"], f'{self.logger.log_dir}/images/recon{self.n_updates}val.png', num_out=5) 264 | 265 | def test_step(self, batch, batch_idx): 266 | """ 267 | PyTorch-Lightning testing step. 268 | :param batch: list of dictionary objects representing a single image 269 | :param batch_idx: how far in the epoch this batch is 270 | """ 271 | # Get model outputs from batch 272 | images, states, labels, preds, embeddings = self.get_step_outputs(batch, self.cfg.test_length) 273 | 274 | # Build output dictionary 275 | out = {"states": states.detach().cpu(), "embeddings": embeddings.detach().cpu(), 276 | "preds": preds.detach().cpu(), "images": images.detach().cpu(), "labels": labels.detach().cpu()} 277 | return out 278 | 279 | def test_epoch_end(self, batch_outputs): 280 | """ 281 | For testing end, save the predictions, gt, and MSE to NPY files in the respective experiment folder 282 | :param outputs: list of outputs from the validation steps at batch 0 283 | """ 284 | # Stack all output types and convert to numpy 285 | outputs = dict() 286 | for key in batch_outputs[0].keys(): 287 | outputs[key] = torch.vstack([output[key] for output in batch_outputs]).numpy() 288 | 289 | # Iterate through each metric function and add to a dictionary 290 | out_metrics = {} 291 | for met in self.cfg.test_metrics: 292 | metric_function = getattr(metrics, met) 293 | metric_mean, metric_std = metric_function(outputs["images"], outputs["preds"], cfg=self.cfg, length=self.cfg.test_length) 294 | out_metrics[f"{met}_mean"], out_metrics[f"{met}_std"] = float(metric_mean), float(metric_std) 295 | print(f"=> {met}: {metric_mean:4.5f}+-{metric_std:4.5f}") 296 | 297 | # Set up output path and create dir 298 | output_path = f"{self.logger.log_dir}/test_{self.setting}" 299 | if not os.path.exists(output_path): 300 | os.mkdir(output_path) 301 | 302 | # Save some examples 303 | show_images(outputs["images"][:10], outputs["preds"][:10], f"{output_path}/test_{self.setting}_examples.png", num_out=5) 304 | 305 | # Save trajectory examples 306 | get_embedding_trajectories(outputs["embeddings"][0], outputs["states"][0], f"{output_path}/") 307 | 308 | # Get Z0 TSNE 309 | 
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000, early_exaggeration=12) 310 | fitted = tsne.fit(outputs["embeddings"][:, 0]) 311 | tsne_embedding = fitted.embedding_ 312 | 313 | for i in np.unique(outputs["labels"]): 314 | subset = tsne_embedding[np.where(outputs["labels"] == i)[0], :] 315 | plt.scatter(subset[:, 0], subset[:, 1], c=next(plt.gca()._get_lines.prop_cycler)['color']) 316 | 317 | plt.title("t-SNE Plot of Z0 Embeddings") 318 | plt.legend(np.unique(outputs["labels"]), loc='center left', bbox_to_anchor=(1, 0.5)) 319 | plt.savefig(f"{output_path}/test_{self.setting}_Z0tsne.png", bbox_inches='tight') 320 | plt.close() 321 | 322 | # Save metrics to JSON in checkpoint folder 323 | with open(f"{output_path}/test_{self.setting}_metrics.json", 'w') as f: 324 | json.dump(out_metrics, f) 325 | 326 | # Save metrics to an easy excel conversion style 327 | with open(f"{output_path}/test_{self.setting}_excel.txt", 'w') as f: 328 | for metric in self.cfg.metrics: 329 | f.write(f"{out_metrics[f'{metric}_mean']:0.3f}({out_metrics[f'{metric}_std']:0.3f}),") 330 | -------------------------------------------------------------------------------- /models/CommonVAE.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file CommonVAE.py 3 | 4 | Holds the encoder/decoder architectures that are shared across the NSSM works 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from utils.layers import Flatten, Gaussian, UnFlatten 10 | from torch.distributions import Normal, kl_divergence as kl 11 | 12 | 13 | class LatentStateEncoder(nn.Module): 14 | def __init__(self, cfg): 15 | """ 16 | Holds the convolutional encoder that takes in a sequence of images and outputs the 17 | initial state of the latent dynamics 18 | :param cfg: hydra configdict 19 | """ 20 | super(LatentStateEncoder, self).__init__() 21 | self.cfg = cfg 22 | self.z_amort = cfg.z_amort_train 23 | 24 | # Encoder, q(z_0 | x_{0:cfg.z_amort}) 25 | self.initial_encoder = nn.Sequential( 26 | nn.Conv2d(self.z_amort, cfg.num_filters, kernel_size=5, stride=2, padding=(2, 2)), # 14,14 27 | nn.BatchNorm2d(cfg.num_filters), 28 | nn.ReLU(), 29 | nn.Conv2d(cfg.num_filters, cfg.num_filters * 2, kernel_size=5, stride=2, padding=(2, 2)), # 7,7 30 | nn.BatchNorm2d(cfg.num_filters * 2), 31 | nn.ReLU(), 32 | nn.Conv2d(cfg.num_filters * 2, cfg.num_filters * 8, kernel_size=5, stride=2, padding=(2, 2)), 33 | nn.BatchNorm2d(cfg.num_filters * 8), 34 | nn.ReLU(), 35 | nn.AvgPool2d(4), 36 | Flatten() 37 | ) 38 | 39 | self.stochastic_out = Gaussian(cfg.num_filters * 8, cfg.latent_dim) 40 | self.deterministic_out = nn.Linear(cfg.num_filters * 8, cfg.latent_dim) 41 | self.out_act = nn.Tanh() 42 | 43 | # Holds generated z0 means and logvars for use in KL calculations 44 | self.z_means = None 45 | self.z_logvs = None 46 | 47 | def kl_z_term(self): 48 | """ 49 | KL Z term, KL[q(z0|X) || N(0,1)] 50 | :return: mean klz across batch 51 | """ 52 | if self.cfg.stochastic is False: 53 | return 0.0 54 | 55 | batch_size = self.z_means.shape[0] 56 | mus, logvars = self.z_means.view([-1]), self.z_logvs.view([-1]) # N, 2 57 | 58 | q = Normal(mus, torch.exp(0.5 * logvars)) 59 | N = Normal(torch.zeros(len(mus), device=mus.device), 60 | torch.ones(len(mus), device=mus.device)) 61 | 62 | klz = kl(q, N).view([batch_size, -1]).sum([1]).mean() 63 | return klz 64 | 65 | def forward(self, x): 66 | """ 67 | Handles getting the initial state given x and saving the distributional parameters 68 | :param x: input sequences 
[BatchSize, GenerationLen * NumChannels, H, W] 69 | :return: z0 over the batch [BatchSize, LatentDim] 70 | """ 71 | z0 = self.initial_encoder(x[:, :self.z_amort]) 72 | 73 | # Apply the Gaussian layer if stochastic version 74 | if self.cfg.stochastic is True: 75 | self.z_means, self.z_logvs, z0 = self.stochastic_out(z0) 76 | else: 77 | z0 = self.deterministic_out(z0) 78 | 79 | return self.out_act(z0) 80 | 81 | 82 | class EmissionDecoder(nn.Module): 83 | def __init__(self, cfg): 84 | """ 85 | Holds the convolutional decoder that takes in a batch of individual latent states and 86 | transforms them into their corresponding data space reconstructions 87 | """ 88 | super(EmissionDecoder, self).__init__() 89 | self.cfg = cfg 90 | 91 | # Variable that holds the estimated output for the flattened convolution vector 92 | self.conv_dim = cfg.num_filters * 4 ** 3 93 | 94 | # Emission model handling z_i -> x_i 95 | self.decoder = nn.Sequential( 96 | # Transform latent vector into 4D tensor for deconvolution 97 | nn.Linear(cfg.latent_dim, self.conv_dim), 98 | UnFlatten(4), 99 | 100 | # Perform de-conv to output space 101 | nn.ConvTranspose2d(self.conv_dim // 16, cfg.num_filters * 4, kernel_size=4, stride=1, padding=(0, 0)), 102 | nn.BatchNorm2d(cfg.num_filters * 4), 103 | nn.ReLU(), 104 | nn.ConvTranspose2d(cfg.num_filters * 4, cfg.num_filters * 2, kernel_size=5, stride=2, padding=(1, 1)), 105 | nn.BatchNorm2d(cfg.num_filters * 2), 106 | nn.ReLU(), 107 | nn.ConvTranspose2d(cfg.num_filters * 2, cfg.num_filters, kernel_size=5, stride=2, padding=(1, 1), output_padding=(1, 1)), 108 | nn.BatchNorm2d(cfg.num_filters), 109 | nn.ReLU(), 110 | nn.ConvTranspose2d(cfg.num_filters, cfg.num_channels, kernel_size=5, stride=1, padding=(2, 2)), 111 | nn.Sigmoid(), 112 | ) 113 | 114 | def forward(self, zts): 115 | """ 116 | Handles decoding a batch of individual latent states into their corresponding data space reconstructions 117 | :param zts: latent states [BatchSize * GenerationLen, LatentDim] 118 | :return: data output [BatchSize, GenerationLen, NumChannels, H, W] 119 | """ 120 | # Decode back to image space, after first flattening BatchSize * SeqLen 121 | x_rec = self.decoder(zts.contiguous().view([zts.shape[0] * zts.shape[1], -1])) 122 | 123 | # Reshape to image output 124 | x_rec = x_rec.view([zts.shape[0], x_rec.shape[0] // zts.shape[0], self.cfg.dim, self.cfg.dim]) 125 | return x_rec 126 | -------------------------------------------------------------------------------- /models/group_a/DKF.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file DKF.py 3 | 4 | Holds the model for the Deep Kalman Filter baseline, based off of the code from 5 | @url{https://github.com/john-x-jiang/meta_ssm/blob/main/model/model.py} 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as weight_init 10 | 11 | from models.CommonDynamics import LatentDynamicsModel 12 | from torch.distributions import Normal, kl_divergence as kl 13 | 14 | 15 | def reverse_sequence(x, seq_lengths): 16 | """ 17 | Brought from 18 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py 19 | Parameters 20 | ---------- 21 | x: tensor (b, T_max, input_dim) 22 | seq_lengths: tensor (b, ) 23 | Returns 24 | ------- 25 | x_reverse: tensor (b, T_max, input_dim) 26 | The input x in reversed order w.r.t. 
time-axis 27 | """ 28 | x_reverse = torch.zeros_like(x) 29 | for b in range(x.size(0)): 30 | t = seq_lengths[b] 31 | time_slice = torch.arange(t - 1, -1, -1, device=x.device) 32 | reverse_seq = torch.index_select(x[b, :, :], 0, time_slice) 33 | x_reverse[b, 0:t, :] = reverse_seq 34 | 35 | return x_reverse 36 | 37 | 38 | class RnnEncoder(nn.Module): 39 | """ 40 | RNN encoder that outputs hidden states h_t using x_{t:T} 41 | Parameters 42 | ---------- 43 | input_dim: int 44 | Dim. of inputs 45 | num_hidden: int 46 | Dim. of RNN hidden states 47 | n_layer: int 48 | Number of layers of RNN 49 | drop_rate: float [0.0, 1.0] 50 | RNN dropout rate between layers 51 | bd: bool 52 | Use bi-directional RNN or not 53 | Returns 54 | ------- 55 | h_rnn: tensor (b, T_max, num_hidden * n_direction) 56 | RNN hidden states at every time-step 57 | """ 58 | def __init__(self, cfg, input_dim, num_hidden, n_layer=1, drop_rate=0.0, bd=False, nonlin='relu', 59 | rnn_type='rnn', orthogonal_init=False, reverse_input=True): 60 | super().__init__() 61 | self.n_direction = 1 if not bd else 2 62 | self.cfg = cfg 63 | self.input_dim = input_dim 64 | self.num_hidden = num_hidden 65 | self.n_layer = n_layer 66 | self.drop_rate = drop_rate 67 | self.bd = bd 68 | self.nonlin = nonlin 69 | self.reverse_input = reverse_input 70 | 71 | if not isinstance(rnn_type, str): 72 | raise ValueError("`rnn_type` should be type str.") 73 | self.rnn_type = rnn_type 74 | if rnn_type == 'rnn': 75 | self.rnn = nn.RNN(input_size=input_dim, hidden_size=num_hidden, nonlinearity=nonlin, 76 | batch_first=True, bidirectional=bd, num_layers=n_layer, dropout=drop_rate) 77 | elif rnn_type == 'gru': 78 | self.rnn = nn.GRU(input_size=input_dim, hidden_size=num_hidden, 79 | batch_first=True, bidirectional=bd, num_layers=n_layer, dropout=drop_rate) 80 | elif rnn_type == 'lstm': 81 | self.rnn = nn.LSTM(input_size=input_dim, hidden_size=num_hidden, batch_first=True, 82 | bidirectional=bd, num_layers=n_layer, dropout=drop_rate) 83 | else: 84 | raise ValueError("`rnn_type` must instead be ['rnn', 'gru', 'lstm'] %s" % rnn_type) 85 | 86 | if orthogonal_init: 87 | self.init_weights() 88 | 89 | def init_weights(self): 90 | for w in self.rnn.parameters(): 91 | if w.dim() > 1: 92 | weight_init.orthogonal_(w) 93 | 94 | def calculate_effect_dim(self): 95 | return self.num_hidden * self.n_direction 96 | 97 | def init_hidden(self, trainable=True): 98 | if self.rnn_type == 'lstm': 99 | h0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable) 100 | c0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable) 101 | return h0, c0 102 | else: 103 | h0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable) 104 | return h0 105 | 106 | def kl_z_term(self): 107 | """ 108 | KL Z term, KL[q(z0|X) || N(0,1)] 109 | :return: mean klz across batch 110 | """ 111 | return torch.Tensor([0]).to(self.cfg.devices[0]) 112 | 113 | def forward(self, x): 114 | """ 115 | x: pytorch packed object 116 | input packed data; this can be obtained from 117 | `util.get_mini_batch()` 118 | h0: tensor (n_layer * n_direction, b, num_hidden) 119 | """ 120 | B, T, _ = x.shape 121 | seq_lengths = T * torch.ones(B).int().to(self.cfg.devices[0]) 122 | h_rnn, _ = self.rnn(x) 123 | if self.reverse_input: 124 | h_rnn = reverse_sequence(h_rnn, seq_lengths) 125 | return h_rnn 126 | 127 | 128 | class Transition_Recurrent(nn.Module): 129 | """ 130 | Parameterize the 
diagonal Gaussian latent transition probability 131 | `p(z_t | z_{t-1})` 132 | Parameters 133 | ---------- 134 | z_dim: int 135 | Dim. of latent variables 136 | transition_dim: int 137 | Dim. of transition hidden units 138 | gated: bool 139 | Use the gated mechanism to consider both linearity and non-linearity 140 | identity_init: bool 141 | Initialize the linearity transform as an identity matrix; 142 | ignored if `gated == False` 143 | clip: bool 144 | clip the value for numerical issues 145 | Returns 146 | ------- 147 | mu: tensor (b, z_dim) 148 | Mean that parameterizes the Gaussian 149 | logvar: tensor (b, z_dim) 150 | Log-variance that parameterizes the Gaussian 151 | """ 152 | 153 | def __init__(self, z_dim, transition_dim, identity_init=True, domain=False, stochastic=True): 154 | super().__init__() 155 | self.z_dim = z_dim 156 | self.transition_dim = transition_dim 157 | self.identity_init = identity_init 158 | self.domain = domain 159 | self.stochastic = stochastic 160 | 161 | if domain: 162 | # compute the gain (gate) of non-linearity 163 | self.lin1 = nn.Linear(z_dim * 2, transition_dim * 2) 164 | self.lin2 = nn.Linear(transition_dim * 2, z_dim) 165 | # compute the proposed mean 166 | self.lin3 = nn.Linear(z_dim * 2, transition_dim * 2) 167 | self.lin4 = nn.Linear(transition_dim * 2, z_dim) 168 | # linearity 169 | self.lin0 = nn.Linear(z_dim * 2, z_dim) 170 | else: 171 | # compute the gain (gate) of non-linearity 172 | self.lin1 = nn.Linear(z_dim, transition_dim) 173 | self.lin2 = nn.Linear(transition_dim, z_dim) 174 | # compute the proposed mean 175 | self.lin3 = nn.Linear(z_dim, transition_dim) 176 | self.lin4 = nn.Linear(transition_dim, z_dim) 177 | self.lin0 = nn.Linear(z_dim, z_dim) 178 | 179 | # compute the linearity part 180 | self.lin_n = nn.Linear(z_dim, z_dim) 181 | 182 | if identity_init: 183 | self.lin_n.weight.data = torch.eye(z_dim) 184 | self.lin_n.bias.data = torch.zeros(z_dim) 185 | 186 | # compute the variation 187 | self.lin_v = nn.Linear(z_dim, z_dim) 188 | # var activation 189 | # self.act_var = nn.Softplus() 190 | self.act_var = nn.Tanh() 191 | 192 | self.act_weight = nn.Sigmoid() 193 | self.act = nn.ELU(inplace=True) 194 | 195 | def init_z_0(self, trainable=True): 196 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable), \ 197 | nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable) 198 | 199 | def forward(self, z_t_1, z_c=None): 200 | if self.domain: 201 | z_combine = torch.cat((z_t_1, z_c), dim=1) 202 | _g_t = self.act(self.lin1(z_combine)) 203 | g_t = self.act_weight(self.lin2(_g_t)) 204 | _h_t = self.act(self.lin3(z_combine)) 205 | h_t = self.act(self.lin4(_h_t)) 206 | _mu = self.lin0(z_combine) 207 | mu = (1 - g_t) * self.lin_n(_mu) + g_t * h_t 208 | mu = mu + _mu 209 | else: 210 | _g_t = self.act(self.lin1(z_t_1)) 211 | g_t = self.act_weight(self.lin2(_g_t)) 212 | _h_t = self.act(self.lin3(z_t_1)) 213 | h_t = self.act(self.lin4(_h_t)) 214 | mu = (1 - g_t) * self.lin_n(z_t_1) + g_t * h_t 215 | _mu = self.lin0(z_t_1) 216 | mu = mu + _mu 217 | 218 | if self.stochastic: 219 | _var = self.lin_v(h_t) 220 | _var = torch.clamp(_var, min=-100, max=85) 221 | var = self.act_var(_var) 222 | return mu, var 223 | else: 224 | return mu 225 | 226 | 227 | class Correction(nn.Module): 228 | """ 229 | Parameterize variational distribution `q(z_t | z_{t-1}, x_{t:T})` 230 | a diagonal Gaussian distribution 231 | Parameters 232 | ---------- 233 | z_dim: int 234 | Dim. of latent variables 235 | num_hidden: int 236 | Dim. 
of RNN hidden states 237 | clip: bool 238 | clip the value for numerical issues 239 | Returns 240 | ------- 241 | mu: tensor (b, z_dim) 242 | Mean that parameterizes the variational Gaussian distribution 243 | logvar: tensor (b, z_dim) 244 | Log-var that parameterizes the variational Gaussian distribution 245 | """ 246 | def __init__(self, z_dim, num_hidden, stochastic=True): 247 | super().__init__() 248 | self.z_dim = z_dim 249 | self.num_hidden = num_hidden 250 | self.stochastic = stochastic 251 | 252 | self.lin1 = nn.Linear(z_dim, num_hidden) 253 | self.act = nn.Tanh() 254 | 255 | self.lin2 = nn.Linear(num_hidden, z_dim) 256 | self.lin_v = nn.Linear(num_hidden, z_dim) 257 | 258 | # self.act_var = nn.Softplus() 259 | self.act_var = nn.Tanh() 260 | 261 | def init_z_q_0(self, trainable=True): 262 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable) 263 | 264 | def forward(self, h_rnn, z_t_1=None, rnn_bidirection=False): 265 | """ 266 | z_t_1: tensor (b, z_dim) 267 | h_rnn: tensor (b, num_hidden) 268 | """ 269 | assert z_t_1 is not None 270 | h_comb_ = self.act(self.lin1(z_t_1)) 271 | if rnn_bidirection: 272 | h_comb = (1.0 / 3) * (h_comb_ + h_rnn[:, :self.num_hidden] + h_rnn[:, self.num_hidden:]) 273 | else: 274 | h_comb = 0.5 * (h_comb_ + h_rnn) 275 | mu = self.lin2(h_comb) 276 | 277 | if self.stochastic: 278 | _var = self.lin_v(h_comb) 279 | _var = torch.clamp(_var, min=-100, max=85) 280 | var = self.act_var(_var) 281 | return mu, var 282 | else: 283 | return mu 284 | 285 | 286 | class DKF(LatentDynamicsModel): 287 | def __init__(self, cfg): 288 | """ Deep Kalman Filter model """ 289 | super().__init__(cfg) 290 | 291 | # observation 292 | self.embedding = nn.Sequential( 293 | nn.Linear(self.cfg.dim**2, self.cfg.dim**2 // 4), 294 | nn.ReLU(), 295 | # nn.Linear(self.cfg.dim**2 // 4, self.cfg.dim**2 // 8), 296 | # nn.ReLU(), 297 | nn.Linear(self.cfg.dim**2 // 4, self.cfg.num_hidden), 298 | nn.ReLU() 299 | ) 300 | self.encoder = RnnEncoder(cfg, self.cfg.num_hidden, self.cfg.num_hidden, n_layer=self.cfg.rnn_layers, drop_rate=0.0, 301 | bd=self.cfg.bidirectional, nonlin='relu', rnn_type=self.cfg.rnn_type, reverse_input=False) 302 | 303 | # generative model 304 | self.transition = Transition_Recurrent(z_dim=self.cfg.latent_dim, transition_dim=self.cfg.transition_dim) 305 | self.estimation = Correction(z_dim=self.cfg.latent_dim, num_hidden=self.cfg.num_hidden, stochastic=True) 306 | 307 | # initialize hidden states 308 | self.mu_p_0, self.var_p_0 = self.transition.init_z_0(trainable=False) 309 | self.z_q_0 = self.estimation.init_z_q_0(trainable=False) 310 | 311 | # hold p and q distribution parameters for KL term 312 | self.mu_qs = None 313 | self.var_qs = None 314 | self.mu_ps = None 315 | self.var_ps = None 316 | 317 | # placeholder for z_amort 318 | self.z_amort = None 319 | 320 | def reparameterization(self, mu, var): 321 | std = torch.exp(0.5 * var) 322 | eps = torch.randn_like(std) 323 | return mu + eps * std 324 | 325 | def latent_dynamics(self, T, x_rnn): 326 | batch_size = x_rnn.shape[0] 327 | 328 | if T > self.z_amort: 329 | T_final = T 330 | else: 331 | T_final = self.z_amort 332 | 333 | z_ = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0]) 334 | if not self.cfg.use_q_forecast: 335 | mu_ps = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0]) 336 | var_ps = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0]) 337 | mu_qs = torch.zeros([batch_size, self.z_amort, 
self.cfg.latent_dim]).to(self.cfg.devices[0]) 338 | var_qs = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0]) 339 | else: 340 | mu_ps = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0]) 341 | var_ps = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0]) 342 | mu_qs = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0]) 343 | var_qs = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0]) 344 | x_qs = torch.zeros([batch_size, T_final, self.cfg.dim, self.cfg.dim]).to(self.cfg.devices[0]) 345 | 346 | z_q_0 = self.z_q_0.expand(batch_size, self.cfg.latent_dim) # q(z_0) 347 | mu_p_0 = self.mu_p_0.expand(batch_size, 1, self.cfg.latent_dim) 348 | var_p_0 = self.var_p_0.expand(batch_size, 1, self.cfg.latent_dim) 349 | z_prev = z_q_0 350 | z_[:, 0, :] = z_prev 351 | 352 | for t in range(self.z_amort): 353 | # zt = self.transition(z_prev) 354 | mu_q, var_q = self.estimation(x_rnn[:, t, :], z_prev, rnn_bidirection=self.cfg.bidirectional) 355 | zt_q = self.reparameterization(mu_q, var_q) 356 | z_prev = zt_q 357 | 358 | # p(z_{t+1} | z_t) 359 | mu_p, var_p = self.transition(z_prev) 360 | zt_p = self.reparameterization(mu_p, var_p) 361 | 362 | z_[:, t, :] = zt_q 363 | mu_qs[:, t, :] = mu_q 364 | var_qs[:, t, :] = var_q 365 | mu_ps[:, t, :] = mu_p 366 | var_ps[:, t, :] = var_p 367 | 368 | if T > self.z_amort: 369 | for t in range(self.z_amort, T): 370 | if self.cfg.use_q_forecast: 371 | y = self.decoder(z_prev.unsqueeze(1)) 372 | x_qs[:,t,:,:] = y.squeeze() 373 | y = self.encoder(self.embedding(y.view(y.size(0), 1, -1))) 374 | mu_q, var_q = self.estimation(y[:,0,:], z_prev, rnn_bidirection=self.cfg.bidirectional) 375 | zt_q = self.reparameterization(mu_q, var_q) 376 | z_prev = zt_q 377 | 378 | mu_p, var_p = self.transition(z_prev) 379 | zt_p = self.reparameterization(mu_p, var_p) 380 | 381 | z_[:, t, :] = zt_q 382 | mu_qs[:, t, :] = mu_q 383 | var_qs[:, t, :] = var_q 384 | mu_ps[:, t, :] = mu_p 385 | var_ps[:, t, :] = var_p 386 | else: 387 | # p(z_{t+1} | z_t) 388 | mu_p, var_p = self.transition(z_prev) 389 | zt_p = self.reparameterization(mu_p, var_p) 390 | z_[:, t, :] = zt_p 391 | z_prev = zt_p 392 | 393 | mu_ps = torch.cat([mu_p_0, mu_ps[:, :-1, :]], dim=1) 394 | var_ps = torch.cat([var_p_0, var_ps[:, :-1, :]], dim=1) 395 | 396 | self.mu_qs, self.var_qs = mu_qs, var_qs 397 | self.mu_ps, self.var_ps = mu_ps, var_ps 398 | if self.cfg.use_q_forecast: 399 | return z_, mu_qs, var_qs, mu_ps, var_ps, x_qs 400 | else: 401 | return z_, mu_qs, var_qs, mu_ps, var_ps 402 | 403 | def forward(self, x, generation_len): 404 | if self.trainer.training: 405 | self.z_amort = self.cfg.z_amort_train 406 | else: 407 | self.z_amort = self.cfg.z_amort_test 408 | 409 | batch_size = x.size(0) 410 | 411 | x = x.view(batch_size, generation_len, -1) 412 | x = self.embedding(x) 413 | x_rnn = self.encoder(x) 414 | 415 | if not self.cfg.use_q_forecast: 416 | z_, mu_qs, var_qs, mu_ps, var_ps = self.latent_dynamics(generation_len, x_rnn) 417 | x_ = self.decoder(z_) 418 | else: 419 | z_, mu_qs, var_qs, mu_ps, var_ps, x_ = self.latent_dynamics(generation_len, x_rnn) 420 | x_[:,:self.z_amort,:,:] = self.decoder(z_[:,:self.z_amort,:]) 421 | return x_, z_ 422 | 423 | def model_specific_loss(self, *args, train=True): 424 | """ KL term between the p and q distributions (reconstruction and estimation) """ 425 | q = Normal(self.mu_qs, torch.exp(0.5 * self.var_qs)) 426 | p = Normal(self.mu_ps, 
torch.exp(0.5 * self.var_ps)) 427 | return kl(q, p).sum([-1]).mean() 428 | -------------------------------------------------------------------------------- /models/group_a/VRNN.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file VRNN.py 3 | 4 | Holds the model for the Variational Recurrent Neural Network baseline, source code modified from 5 | @url{https://github.com/XiaoyuBIE1994/DVAE/blob/master/dvae/model/vrnn.py} 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from collections import OrderedDict 11 | from models.CommonDynamics import LatentDynamicsModel 12 | from torch.distributions import Normal, kl_divergence as kl 13 | 14 | 15 | class FakeEncoder(nn.Module): 16 | def __init__(self, cfg): 17 | super().__init__() 18 | self.cfg = cfg 19 | 20 | def kl_z_term(self): 21 | return torch.Tensor([0.]).to(self.cfg.devices[0]) 22 | 23 | 24 | class VRNN(LatentDynamicsModel): 25 | def __init__(self, cfg): 26 | """ Latent dynamics as parameterized by a global deterministic neural ODE """ 27 | super().__init__(cfg) 28 | self.encoder = FakeEncoder(cfg) 29 | self.decoder = None 30 | 31 | ### General parameters 32 | self.x_dim = self.cfg.dim ** 2 33 | self.z_dim = self.cfg.latent_dim 34 | self.dropout_p = 0.2 35 | self.y_dim = self.x_dim 36 | self.activation = nn.LeakyReLU(0.1) 37 | self.sigmoid = nn.Sigmoid() 38 | self.z_amort = None 39 | 40 | ### Feature extractors 41 | self.dense_x = [256] 42 | self.dense_z = [256] 43 | 44 | ### Dense layers 45 | self.dense_hx_z = [128] 46 | self.dense_hz_x = [256] 47 | self.dense_h_z = [128] 48 | 49 | ### RNN 50 | self.dim_RNN = 16 51 | self.num_RNN = 1 52 | 53 | ### Beta-loss 54 | self.beta = 1 55 | 56 | ########################### 57 | #### Feature extractor #### 58 | ########################### 59 | # x 60 | dic_layers = OrderedDict() 61 | if len(self.dense_x) == 0: 62 | dim_feature_x = self.x_dim 63 | dic_layers['Identity'] = nn.Identity() 64 | else: 65 | dim_feature_x = self.dense_x[-1] 66 | for n in range(len(self.dense_x)): 67 | if n == 0: 68 | dic_layers['linear' + str(n)] = nn.Linear(self.x_dim, self.dense_x[n]) 69 | else: 70 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x[n - 1], self.dense_x[n]) 71 | dic_layers['activation' + str(n)] = self.activation 72 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 73 | self.feature_extractor_x = nn.Sequential(dic_layers) 74 | # z 75 | dic_layers = OrderedDict() 76 | if len(self.dense_z) == 0: 77 | dim_feature_z = self.z_dim 78 | dic_layers['Identity'] = nn.Identity() 79 | else: 80 | dim_feature_z = self.dense_z[-1] 81 | for n in range(len(self.dense_z)): 82 | if n == 0: 83 | dic_layers['linear' + str(n)] = nn.Linear(self.z_dim, self.dense_z[n]) 84 | else: 85 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_z[n - 1], self.dense_z[n]) 86 | dic_layers['activation' + str(n)] = self.activation 87 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 88 | self.feature_extractor_z = nn.Sequential(dic_layers) 89 | 90 | ###################### 91 | #### Dense layers #### 92 | ###################### 93 | # 1. 
h_t, x_t to z_t (Inference) 94 | dic_layers = OrderedDict() 95 | if len(self.dense_hx_z) == 0: 96 | dim_hx_z = self.dim_RNN + dim_feature_x 97 | dic_layers['Identity'] = nn.Identity() 98 | else: 99 | dim_hx_z = self.dense_hx_z[-1] 100 | for n in range(len(self.dense_hx_z)): 101 | if n == 0: 102 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x[-1] + self.dim_RNN, self.dense_hx_z[n]) 103 | else: 104 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_hx_z[n - 1], self.dense_hx_z[n]) 105 | dic_layers['activation' + str(n)] = self.activation 106 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 107 | self.mlp_hx_z = nn.Sequential(dic_layers) 108 | self.inf_mean = nn.Linear(dim_hx_z, self.z_dim) 109 | self.inf_logvar = nn.Linear(dim_hx_z, self.z_dim) 110 | 111 | # 2. h_t to z_t (Generation z) 112 | dic_layers = OrderedDict() 113 | if len(self.dense_h_z) == 0: 114 | dim_h_z = self.dim_RNN 115 | dic_layers['Identity'] = nn.Identity() 116 | else: 117 | dim_h_z = self.dense_h_z[-1] 118 | for n in range(len(self.dense_h_z)): 119 | if n == 0: 120 | dic_layers['linear' + str(n)] = nn.Linear(self.dim_RNN, self.dense_h_z[n]) 121 | else: 122 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_h_z[n - 1], self.dense_h_z[n]) 123 | dic_layers['activation' + str(n)] = self.activation 124 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 125 | self.mlp_h_z = nn.Sequential(dic_layers) 126 | self.prior_mean = nn.Linear(dim_h_z, self.z_dim) 127 | self.prior_logvar = nn.Linear(dim_h_z, self.z_dim) 128 | 129 | # 3. h_t, z_t to x_t (Generation x) 130 | dic_layers = OrderedDict() 131 | if len(self.dense_hz_x) == 0: 132 | dim_hz_x = self.dim_RNN + dim_feature_z 133 | dic_layers['Identity'] = nn.Identity() 134 | else: 135 | dim_hz_x = self.dense_hz_x[-1] 136 | for n in range(len(self.dense_hz_x)): 137 | if n == 0: 138 | dic_layers['linear' + str(n)] = nn.Linear(self.dim_RNN + dim_feature_z, self.dense_hz_x[n]) 139 | else: 140 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_hz_x[n - 1], self.dense_hz_x[n]) 141 | dic_layers['activation' + str(n)] = self.activation 142 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 143 | self.mlp_hz_x = nn.Sequential(dic_layers) 144 | self.gen_out = nn.Linear(dim_hz_x, self.y_dim) 145 | 146 | #################### 147 | #### Recurrence #### 148 | #################### 149 | self.rnn = nn.LSTM(dim_feature_x + dim_feature_z, self.dim_RNN, self.num_RNN) 150 | 151 | def reparameterization(self, mean, logvar): 152 | std = torch.exp(0.5 * logvar) 153 | eps = torch.randn_like(std) 154 | return torch.addcmul(mean, eps, std) 155 | 156 | def generation_x(self, feature_zt, h_t): 157 | dec_input = torch.cat((feature_zt, h_t), 2) 158 | dec_output = self.mlp_hz_x(dec_input) 159 | y_t = self.gen_out(dec_output) 160 | y_t = self.sigmoid(y_t) 161 | return y_t 162 | 163 | def generation_z(self, h): 164 | prior_output = self.mlp_h_z(h) 165 | mean_prior = self.prior_mean(prior_output) 166 | logvar_prior = self.prior_logvar(prior_output) 167 | return mean_prior, logvar_prior 168 | 169 | def inference(self, feature_xt, h_t): 170 | enc_input = torch.cat((feature_xt, h_t), 2) 171 | enc_output = self.mlp_hx_z(enc_input) 172 | mean_zt = self.inf_mean(enc_output) 173 | logvar_zt = self.inf_logvar(enc_output) 174 | return mean_zt, logvar_zt 175 | 176 | def recurrence(self, feature_xt, feature_zt, h_t, c_t): 177 | rnn_input = torch.cat((feature_xt, feature_zt), -1) 178 | _, (h_tp1, c_tp1) = self.rnn(rnn_input, (h_t, c_t)) 179 | return h_tp1, c_tp1 180 | 
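    # --- Editorial sketch (not part of the original source): the forward pass below
    # repeats a four-step cycle for every frame t, switching from observed frames to
    # its own (detached) reconstructions once t >= z_amort. A minimal single-step
    # sketch, assuming the helpers defined above and a frame x_t of shape (1, bs, x_dim):
    #
    #   feature_xt = self.feature_extractor_x(x_t)                    # phi_x(x_t)
    #   mean_zt, logvar_zt = self.inference(feature_xt, h_t_last)     # q(z_t | x_t, h_t)
    #   z_t = self.reparameterization(mean_zt, logvar_zt)
    #   feature_zt = self.feature_extractor_z(z_t)                    # phi_z(z_t)
    #   y_t = self.generation_x(feature_zt, h_t_last)                 # p(x_t | z_t, h_t)
    #   h_t, c_t = self.recurrence(feature_xt, feature_zt, h_t, c_t)  # update for t+1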
181 | def forward(self, x, generation_len): 182 | batch_size = x.shape[0] 183 | seq_len = generation_len 184 | 185 | # Input is an image so reduce down to [batch_size, seq_len, flattened_dim] 186 | x = x.reshape(x.shape[0], x.shape[1], -1) 187 | 188 | # Permute it to [seq_len, batch_size, flattened_dim] 189 | x = x.permute(1, 0, 2) 190 | 191 | # Create variable holder and send to GPU if needed 192 | self.z_mean = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0]) 193 | self.z_logvar = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0]) 194 | y = torch.zeros((seq_len, batch_size, self.y_dim)).to(self.cfg.devices[0]) 195 | self.z = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0]) 196 | h = torch.zeros((seq_len, batch_size, self.dim_RNN)).to(self.cfg.devices[0]) 197 | h_t = torch.zeros(self.num_RNN, batch_size, self.dim_RNN).to(self.cfg.devices[0]) 198 | c_t = torch.zeros(self.num_RNN, batch_size, self.dim_RNN).to(self.cfg.devices[0]) 199 | 200 | # For the observed frames, use real input; otherwise use previous generated frame 201 | feature_x_obs = self.feature_extractor_x(x[:self.z_amort]) 202 | for t in range(generation_len): 203 | if t < self.z_amort: 204 | feature_xt = feature_x_obs[t, :, :].unsqueeze(0) 205 | else: 206 | feature_xt = self.feature_extractor_x(y_prev) 207 | 208 | h_t_last = h_t.view(self.num_RNN, 1, batch_size, self.dim_RNN)[-1, :, :, :] 209 | mean_zt, logvar_zt = self.inference(feature_xt, h_t_last) 210 | z_t = self.reparameterization(mean_zt, logvar_zt) 211 | feature_zt = self.feature_extractor_z(z_t) 212 | y_t = self.generation_x(feature_zt, h_t_last) 213 | y_prev = y_t.detach() 214 | self.z_mean[t, :, :] = mean_zt 215 | self.z_logvar[t, :, :] = logvar_zt 216 | self.z[t, :, :] = torch.squeeze(z_t) 217 | y[t, :, :] = torch.squeeze(y_t) 218 | h[t, :, :] = torch.squeeze(h_t_last) 219 | h_t, c_t = self.recurrence(feature_xt, feature_zt, h_t, c_t) # recurrence for t+1 220 | 221 | self.z_mean_p, self.z_logvar_p = self.generation_z(h) 222 | 223 | # Reshape and permute reconstructions + embeddings back to useable shapes 224 | y = y.permute(1, 0, 2).reshape([batch_size, seq_len, self.cfg.dim, self.cfg.dim]) 225 | embeddings = self.z.permute(1, 0, 2) 226 | return y, embeddings 227 | 228 | def model_specific_loss(self, x, x_rec, train=True): 229 | """ KL term between the parameter distribution w and a normal prior""" 230 | # Reshape to [BS, SL, LatentDim] 231 | z_mus = self.z_mean.permute(1, 0, 2).reshape([x.shape[0], -1]) 232 | z_logvar = self.z_logvar.permute(1, 0, 2).reshape([x.shape[0], -1]) 233 | 234 | z_mus_prior = self.z_mean_p.permute(1, 0, 2).reshape([x.shape[0], -1]) 235 | z_logvar_prior = self.z_logvar_p.permute(1, 0, 2).reshape([x.shape[0], -1]) 236 | 237 | q = Normal(z_mus, torch.exp(0.5 * z_logvar)) 238 | N = Normal(z_mus_prior, torch.exp(0.5 * z_logvar_prior)) 239 | return kl(q, N).sum([-1]).mean() -------------------------------------------------------------------------------- /models/group_b1/KVAE.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file KVAE.py 3 | 4 | Holds the model for the Kalman Variational Auto-encoder, source code modified from 5 | @url{https://github.com/XiaoyuBIE1994/DVAE/blob/master/dvae/model/kvae.py} 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | from collections import OrderedDict 12 | from models.CommonDynamics import LatentDynamicsModel 13 | from torch.distributions.multivariate_normal import 
MultivariateNormal 14 | 15 | 16 | class FakeEncoder(nn.Module): 17 | def __init__(self, cfg): 18 | super().__init__() 19 | self.cfg = cfg 20 | 21 | def kl_z_term(self): 22 | return torch.Tensor([0.]).to(self.cfg.devices[0]) 23 | 24 | 25 | class KVAE(LatentDynamicsModel): 26 | def __init__(self, cfg): 27 | """ Latent dynamics as parameterized by a global deterministic neural ODE """ 28 | super().__init__(cfg) 29 | self.encoder = FakeEncoder(cfg) 30 | 31 | ## General parameters 32 | self.x_dim = self.cfg.dim**2 33 | self.y_dim = self.cfg.dim**2 34 | self.a_dim = self.cfg.latent_dim * 2 35 | self.z_dim = self.cfg.latent_dim 36 | self.u_dim = self.cfg.latent_dim * 2 37 | self.dropout_p = 0.2 38 | self.activation = nn.ReLU() 39 | 40 | # VAE 41 | self.dense_x_a = [512, 256] 42 | self.dense_a_x = [512, 256] 43 | 44 | # LGSSM 45 | self.init_kf_mat = 0.05 46 | self.noise_transition = 0.08 47 | self.noise_emission = 0.03 48 | self.init_cov = 20 49 | 50 | # Dynamics params (alpha) 51 | self.K = 4 52 | self.dim_RNN_alpha = 128 53 | self.num_RNN_alpha = 2 54 | 55 | self.build() 56 | 57 | def build(self): 58 | ############# 59 | #### VAE #### 60 | ############# 61 | # 1. Inference of a_t 62 | dic_layers = OrderedDict() 63 | if len(self.dense_x_a) == 0: 64 | dim_x_a = self.x_dim 65 | dic_layers["Identity"] = nn.Identity() 66 | else: 67 | dim_x_a = self.dense_x_a[-1] 68 | for n in range(len(self.dense_x_a)): 69 | if n == 0: 70 | dic_layers["linear" + str(n)] = nn.Linear(self.x_dim, self.dense_x_a[n]) 71 | else: 72 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x_a[n - 1], self.dense_x_a[n]) 73 | dic_layers['activation' + str(n)] = self.activation 74 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p) 75 | self.mlp_x_a = nn.Sequential(dic_layers) 76 | self.inf_mean = nn.Linear(dim_x_a, self.a_dim) 77 | self.inf_logvar = nn.Linear(dim_x_a, self.a_dim) 78 | 79 | # 2. Generation of x_t 80 | dic_layers = OrderedDict() 81 | if len(self.dense_a_x) == 0: 82 | dim_a_x = self.a_dim 83 | dic_layers["Identity"] = nn.Identity() 84 | else: 85 | dim_a_x = self.dense_a_x[-1] 86 | for n in range(len(self.dense_x_a)): 87 | if n == 0: 88 | dic_layers["linear" + str(n)] = nn.Linear(self.a_dim, self.dense_a_x[n]) 89 | else: 90 | dic_layers["linear" + str(n)] = nn.Linear(self.dense_a_x[n - 1], self.dense_a_x[n]) 91 | dic_layers["activation" + str(n)] = self.activation 92 | dic_layers["dropout" + str(n)] = nn.Dropout(p=self.dropout_p) 93 | self.mlp_a_x = nn.Sequential(dic_layers) 94 | self.gen_logvar = nn.Linear(dim_a_x, self.x_dim) 95 | 96 | ############### 97 | #### LGSSM #### 98 | ############### 99 | # Initializers for LGSSM variables, torch.tensor(), enforce torch.float32 type 100 | # A is an identity matrix 101 | # B and C are randomly sampled from a Gaussian 102 | # Q and R are isotroipic covariance matrices 103 | # z = Az + Bu 104 | # a = Cz 105 | self.A = torch.tensor(np.array([np.eye(self.z_dim) for _ in range(self.K)]), dtype=torch.float32, 106 | requires_grad=True, device=self.cfg.devices[0]) # (K, z_dim. 
z_dim,) 107 | self.B = torch.tensor( 108 | np.array([self.init_kf_mat * np.random.randn(self.z_dim, self.u_dim) for _ in range(self.K)]), 109 | dtype=torch.float32, requires_grad=True, device=self.cfg.devices[0]) # (K, z_dim, u_dim) 110 | self.C = torch.tensor( 111 | np.array([self.init_kf_mat * np.random.randn(self.a_dim, self.z_dim) for _ in range(self.K)]), 112 | dtype=torch.float32, requires_grad=True, device=self.cfg.devices[0]) # (K, a_dim, z_dim) 113 | self.Q = self.noise_transition * torch.eye(self.z_dim).to(self.cfg.devices[0]) # (z_dim, z_dim) 114 | self.R = self.noise_emission * torch.eye(self.a_dim).to(self.cfg.devices[0]) # (a_dim, a_dim) 115 | self._I = torch.eye(self.z_dim).to(self.cfg.devices[0]) # (z_dim, z_dim) 116 | 117 | ############### 118 | #### Alpha #### 119 | ############### 120 | self.a_init = torch.zeros((1, self.a_dim), requires_grad=True, device=self.cfg.devices[0]) # (bs, a_dim) 121 | self.rnn_alpha = nn.LSTM(self.a_dim, self.dim_RNN_alpha, self.num_RNN_alpha, bidirectional=False) 122 | self.mlp_alpha = nn.Sequential(nn.Linear(self.dim_RNN_alpha, self.K), nn.Softmax(dim=-1)) 123 | 124 | ############################ 125 | #### Scheduler Training #### 126 | ############################ 127 | self.A = nn.Parameter(self.A) 128 | self.B = nn.Parameter(self.B) 129 | self.C = nn.Parameter(self.C) 130 | self.a_init = nn.Parameter(self.a_init) 131 | kf_params = [self.A, self.B, self.C, self.a_init] 132 | 133 | self.iter_kf = (i for i in kf_params) 134 | self.iter_vae = self.concat_iter(self.mlp_x_a.parameters(), 135 | self.inf_mean.parameters(), 136 | self.inf_logvar.parameters(), 137 | self.mlp_a_x.parameters(), 138 | self.gen_logvar.parameters()) 139 | self.iter_alpha = self.concat_iter(self.rnn_alpha.parameters(), 140 | self.mlp_alpha.parameters()) 141 | self.iter_vae_kf = self.concat_iter(self.iter_vae, self.iter_kf) 142 | self.iter_all = self.concat_iter(self.iter_kf, self.iter_vae, self.iter_alpha) 143 | 144 | def concat_iter(self, *iter_list): 145 | for i in iter_list: 146 | yield from i 147 | 148 | def reparameterization(self, mean, logvar): 149 | std = torch.exp(0.5 * logvar) 150 | eps = torch.randn_like(std) 151 | return eps.mul(std).add_(mean) 152 | 153 | def inference(self, x): 154 | x_a = self.mlp_x_a(x) 155 | a_mean = self.inf_mean(x_a) 156 | a_logvar = self.inf_logvar(x_a) 157 | a = self.reparameterization(a_mean, a_logvar) 158 | return a, a_mean, a_logvar 159 | 160 | def get_alpha(self, a_tm1): 161 | """ 162 | Dynamics parameter network alpha for mixing transitions in a SSM 163 | Unlike original code, we only propose RNN here 164 | """ 165 | alpha, _ = self.rnn_alpha(a_tm1) # (seq_len, bs, dim_alpha) 166 | alpha = self.mlp_alpha(alpha) # (seq_len, bs, K), softmax on K dimension 167 | return alpha 168 | 169 | def generation_x(self, a): 170 | a_x = self.mlp_a_x(a) 171 | log_y = self.gen_logvar(a_x) 172 | y = torch.exp(log_y) 173 | return y 174 | 175 | def kf_smoother(self, a, u, K, A, B, C, R, Q, optimal_gain=False, alpha_sq=1): 176 | """" 177 | Kalman Smoother, refer to Murphy's book (MLAPP), section 18.3 178 | Difference from KVAE source code: 179 | - no imputation 180 | - only RNN for the calculation of alpha 181 | - different notations (rather than using same notations as Murphy's book ,we use notation from model KVAE) 182 | # z_t = A_t * z_tm1 + B_t * u_t 183 | #a_t = C_t * z_t 184 | Input: 185 | - a, (seq_len, bs, a_dim) 186 | - u, (seq_len, bs, u_dim) 187 | - alpha, (seq_len, bs, alpha_dim) 188 | - K, real number 189 | - A, (K, z_dim, z_dim) 190 | 
- B, (K, z_dim, u_dim) 191 | - C, (K, a_dim, z_dim) 192 | - R, (z_dim, z_dim) 193 | - Q , (a_dim, a_dim) 194 | """ 195 | # Initialization 196 | seq_len = a.shape[0] 197 | batch_size = a.shape[1] 198 | self.mu = torch.zeros((batch_size, self.z_dim)).to(self.device) # (bs, z_dim), z_0 199 | self.Sigma = self.init_cov * torch.eye(self.z_dim).unsqueeze(0).repeat(batch_size, 1, 1).to(self.device) # (bs, z_dim, z_dim), Sigma_0 200 | mu_pred = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim) 201 | mu_filter = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim) 202 | mu_smooth = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim) 203 | Sigma_pred = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim) 204 | Sigma_filter = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim) 205 | Sigma_smooth = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim) 206 | 207 | # Calculate alpha, initial observation a_init is assumed to be zero and can be learned 208 | a_init_expand = self.a_init.unsqueeze(1).repeat(1, batch_size, 1) # (1, bs, a_dim) 209 | a_tm1 = torch.cat([a_init_expand, a[:-1, :, :]], 0) # (seq_len, bs, a_dim) 210 | alpha = self.get_alpha(a_tm1) # (seq_len, bs, K) 211 | 212 | # Calculate the mixture of A, B and C 213 | A_flatten = A.view(K, self.z_dim * self.z_dim) # (K, z_dim*z_dim) 214 | B_flatten = B.view(K, self.z_dim * self.u_dim) # (K, z_dim*u_dim) 215 | C_flatten = C.view(K, self.a_dim * self.z_dim) # (K, a_dim*z_dim) 216 | A_mix = alpha.matmul(A_flatten).view(seq_len, batch_size, self.z_dim, self.z_dim) 217 | B_mix = alpha.matmul(B_flatten).view(seq_len, batch_size, self.z_dim, self.u_dim) 218 | C_mix = alpha.matmul(C_flatten).view(seq_len, batch_size, self.a_dim, self.z_dim) 219 | 220 | # Forward filter 221 | for t in range(seq_len): 222 | # Mixture of A, B and C 223 | A_t = A_mix[t] # (bs, z_dim. 
z_dim) 224 | B_t = B_mix[t] # (bs, z_dim, u_dim) 225 | C_t = C_mix[t] # (bs, a_dim, z_dim) 226 | 227 | if t == 0: 228 | mu_t_pred = self.mu.unsqueeze(-1) # (bs, z_dim, 1) 229 | Sigma_t_pred = self.Sigma 230 | else: 231 | u_t = u[t, :, :] # (bs, u_dim) 232 | mu_t_pred = A_t.bmm(mu_t) + B_t.bmm(u_t.unsqueeze(-1)) # (bs, z_dim, 1), z_{t|t-1} 233 | Sigma_t_pred = alpha_sq * A_t.bmm(Sigma_t).bmm( 234 | A_t.transpose(1, 2)) + self.Q # (bs, z_dim, z_dim), Sigma_{t|t-1} 235 | # alpha_sq (>=1) is fading memory control, which indicates how much you want to forgert past measurements, see more infos in 'FilterPy' library 236 | 237 | # Residual 238 | a_pred = C_t.bmm(mu_t_pred) # (bs, a_dim, z_dim) x (bs, z_dim, 1) 239 | res_t = a[t, :, :].unsqueeze(-1) - a_pred # (bs, a_dim, 1) 240 | 241 | # Kalman gain 242 | S_t = C_t.bmm(Sigma_t_pred).bmm(C_t.transpose(1, 2)) + self.R # (bs, a_dim, a_dim) 243 | S_t_inv = S_t.inverse() 244 | K_t = Sigma_t_pred.bmm(C_t.transpose(1, 2)).bmm(S_t_inv) # (bs, z_dim, a_dim) 245 | 246 | # Update 247 | mu_t = mu_t_pred + K_t.bmm(res_t) # (bs, z_dim, 1) 248 | I_KC = self._I - K_t.bmm(C_t) # (bs, z_dim, z_dim) 249 | if optimal_gain: 250 | Sigma_t = I_KC.bmm(Sigma_t_pred) # (bs, z_dim, z_dim), only valid with optimal Kalman gain 251 | else: 252 | Sigma_t = I_KC.bmm(Sigma_t_pred).bmm(I_KC.transpose(1, 2)) + K_t.matmul(self.R).matmul( 253 | K_t.transpose(1, 2)) # (bs, z_dim, z_dim), general case 254 | 255 | # Save cache 256 | mu_pred[t] = mu_t_pred.view(batch_size, self.z_dim) 257 | Sigma_pred[t] = Sigma_t_pred 258 | Sigma_filter[t] = Sigma_t 259 | 260 | # Add the final state from filter to the smoother as initialization 261 | mu_smooth[-1] = mu_filter[-1] 262 | Sigma_smooth[-1] = Sigma_filter[-1] 263 | 264 | # Backward smooth, reverse loop from pernultimate state 265 | for t in range(seq_len - 2, -1, -1): 266 | # Backward Kalman gain 267 | J_t = Sigma_filter[t].bmm(A_mix[t + 1].transpose(1, 2)).bmm( 268 | Sigma_pred[t + 1].inverse()) # (bs, z_dim, z_dim) 269 | 270 | # Backward smoothing 271 | dif_mu_tp1 = (mu_smooth[t + 1] - mu_filter[t + 1]).unsqueeze(-1) # (bs, z_dim, 1) 272 | mu_smooth[t] = mu_filter[t] + J_t.matmul(dif_mu_tp1).view(batch_size, self.z_dim) # (bs, z_dim) 273 | dif_Sigma_tp1 = Sigma_smooth[t + 1] - Sigma_pred[t + 1] # (bs, z_dim, z_dim) 274 | Sigma_smooth[t] = Sigma_filter[t] + J_t.bmm(dif_Sigma_tp1).bmm(J_t.transpose(1, 2)) # (bs, z_dim, z_dim) 275 | 276 | # Generate a from smoothing z 277 | a_gen = C_mix.matmul(mu_smooth.unsqueeze(-1)).view(seq_len, batch_size, self.a_dim) # (seq_len, bs, a_dim) 278 | return a_gen, mu_smooth, Sigma_smooth, A_mix, B_mix, C_mix 279 | 280 | def forward(self, x, generation_len): 281 | # train input: (batch_size, x_dim, seq_len) 282 | # test input: (x_dim, seq_len) 283 | # need input: (seq_len, batch_size, x_dim) 284 | 285 | # Input is an image so reduce down to [batch_size, flattened_dim, seq_len] 286 | x = x.reshape(x.shape[0], x.shape[1], -1) 287 | x = x.permute(1, 0, 2) 288 | seq_len = x.shape[0] 289 | batch_size = x.shape[1] 290 | 291 | # main part 292 | self.a, self.a_mean, self.a_logvar = self.inference(x) 293 | u_0 = torch.zeros(1, batch_size, self.u_dim).to(self.cfg.devices[0]) 294 | self.u = torch.cat((u_0, self.a[:-1]), 0) 295 | a_gen, self.mu_smooth, self.Sigma_smooth, self.A_mix, self.B_mix, self.C_mix = self.kf_smoother( 296 | self.a, self.u, self.K, self.A, self.B, self.C, self.R, self.Q 297 | ) 298 | self.y = self.generation_x(a_gen) 299 | 300 | # output of NN: (seq_len, batch_size, dim) 301 | # output of model: 
(batch_size, dim, seq_len) or (dim, seq_len) 302 | self.y = self.y.permute(1, 0, 2) 303 | 304 | # Reshape y to image dimensions 305 | self.y = self.y.reshape([batch_size, seq_len, self.cfg.dim, self.cfg.dim]) 306 | return self.y, torch.zeros([batch_size, seq_len, self.cfg.latent_dim]) 307 | 308 | def model_specific_loss(self, x, x_rec, train=True): 309 | # batch_size, seq_len = x.shape[0], x.shape[1] 310 | # Input is an image so reduce down to [batch_size, flattened_dim, seq_len] 311 | x = x.reshape(x.shape[0], x.shape[1], -1) 312 | x = x.permute(1, 0, 2) 313 | 314 | seq_len = x.shape[0] 315 | batch_size = x.shape[1] 316 | 317 | # log q_{\phi}(a_hat | x), Gaussian 318 | log_qa_given_x = - 0.5 * self.a_logvar - torch.pow(self.a - self.a_mean, 2) / (2 * torch.exp(self.a_logvar)) 319 | 320 | # log p_{\gamma}(a_tilde, z_tilde | u) < in sub-comment, 'tilde' is hidden for simplification > 321 | # >>> log p(z_t | z_tm1, u_t), transition 322 | Sigma_smooth_stable = self.Sigma_smooth 323 | L = torch.linalg.cholesky(Sigma_smooth_stable) 324 | mvn_smooth = MultivariateNormal(self.mu_smooth, scale_tril=L) 325 | 326 | 327 | # mvn_smooth = MultivariateNormal(self.mu_smooth, self.Sigma_smooth) 328 | z_smooth = mvn_smooth.sample() # # (seq_len, bs, z_dim) 329 | Az_tm1 = self.A_mix[:-1].matmul(z_smooth[:-1].unsqueeze(-1)).view(seq_len - 1, batch_size, -1) # (seq_len, bs, z_dim) 330 | Bu_t = self.B_mix[:-1].matmul(self.u[:-1].unsqueeze(-1)).view(seq_len - 1, batch_size, -1) # (seq_len, bs, z_dim) 331 | mu_t_transition = Az_tm1 + Bu_t 332 | z_t_transition = z_smooth[1:] 333 | mvn_transition = MultivariateNormal(z_t_transition, self.Q) 334 | log_prob_transition = mvn_transition.log_prob(mu_t_transition) 335 | 336 | # >>> log p(z_0 | z_init), init state 337 | z_0 = z_smooth[0] 338 | mvn_0 = MultivariateNormal(self.mu, self.Sigma) 339 | log_prob_0 = mvn_0.log_prob(z_0) 340 | 341 | # >>> log p(a_t | z_t), emission 342 | Cz_t = self.C_mix.matmul(z_smooth.unsqueeze(-1)).view(seq_len, batch_size, self.a_dim) 343 | mvn_emission = MultivariateNormal(Cz_t, self.R) 344 | log_prob_emission = mvn_emission.log_prob(self.a) 345 | 346 | # >>> log p_{\gamma}(a_tilde, z_tilde | u) 347 | log_paz_given_u = torch.cat([log_prob_transition, log_prob_0.unsqueeze(0)], 0) + log_prob_emission 348 | 349 | # log p_{\gamma}(z_tilde | a_tilde, u) 350 | # >>> log p(z_t | a, u) 351 | log_pz_given_au = mvn_smooth.log_prob(z_smooth) 352 | 353 | # Normalization 354 | log_qa_given_x = torch.sum(log_qa_given_x) / (batch_size * seq_len) 355 | log_paz_given_u = torch.sum(log_paz_given_u) / (batch_size * seq_len) 356 | log_pz_given_au = torch.sum(log_pz_given_au) / (batch_size * seq_len) 357 | 358 | # Loss 359 | loss_vae = log_qa_given_x 360 | loss_lgssm = - log_paz_given_u + log_pz_given_au 361 | loss_tot = loss_vae + loss_lgssm 362 | return loss_vae + loss_lgssm + loss_tot 363 | -------------------------------------------------------------------------------- /models/group_b2/NeuralODE.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file NeuralODE.py 3 | 4 | Holds the model for the Neural ODE latent dynamics function 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from torchdiffeq import odeint 10 | from utils.utils import get_act 11 | from models.CommonDynamics import LatentDynamicsModel 12 | 13 | 14 | class ODEFunction(nn.Module): 15 | def __init__(self, cfg): 16 | """ Standard Neural ODE dynamics function """ 17 | super(ODEFunction, self).__init__() 18 | 19 | # Build the dynamics 
network 20 | dynamics_network = [] 21 | dynamics_network.extend([ 22 | nn.Linear(cfg.latent_dim, cfg.num_hidden), 23 | get_act(cfg.latent_act) 24 | ]) 25 | 26 | for _ in range(cfg.num_layers - 1): 27 | dynamics_network.extend([ 28 | nn.Linear(cfg.num_hidden, cfg.num_hidden), 29 | get_act(cfg.latent_act) 30 | ]) 31 | 32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()]) 33 | self.dynamics_network = nn.Sequential(*dynamics_network) 34 | 35 | def forward(self, t, z): 36 | """ Wrapper function for the odeint calculation """ 37 | return self.dynamics_network(z) 38 | 39 | 40 | class NeuralODE(LatentDynamicsModel): 41 | def __init__(self, cfg): 42 | """ Latent dynamics as parameterized by a global deterministic neural ODE """ 43 | super().__init__(cfg) 44 | 45 | # ODE-Net which holds mixture logic 46 | self.dynamics_func = ODEFunction(cfg) 47 | 48 | def forward(self, x, generation_len): 49 | """ 50 | Forward function of the ODE network 51 | :param x: data observation, which is a timeseries [BS, Timesteps, N Channels, Dim1, Dim2] 52 | :param generation_len: how many timesteps to generate over 53 | """ 54 | # Sample z_init 55 | z_init = self.encoder(x) 56 | 57 | # Evaluate model forward over T to get L latent reconstructions 58 | t = torch.linspace(0, generation_len - 1, generation_len, device=self.device) 59 | zt = odeint(self.dynamics_func, z_init, t, method=self.cfg.integrator, options={'step_size': self.cfg.integrator_step_size}) 60 | zt = zt.permute([1, 0, 2]) 61 | 62 | # Stack zt and decode zts 63 | x_rec = self.decoder(zt) 64 | return x_rec, zt 65 | -------------------------------------------------------------------------------- /models/group_b2/RGN.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file RGN.py 3 | 4 | Holds the model for the Recurrent Generative Network (RGN) latent dynamics function 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from torchdiffeq import odeint 10 | from utils.utils import get_act 11 | from models.CommonDynamics import LatentDynamicsModel 12 | 13 | 14 | class RGNResFunction(nn.Module): 15 | def __init__(self, cfg): 16 | """ Standard Recurrent Generative Network dynamics function (no residual connection) """ 17 | super(RGNResFunction, self).__init__() 18 | 19 | # Build the dynamics network 20 | dynamics_network = [] 21 | dynamics_network.extend([ 22 | nn.Linear(cfg.latent_dim, cfg.num_hidden), 23 | get_act(cfg.latent_act) 24 | ]) 25 | 26 | for _ in range(cfg.num_layers - 1): 27 | dynamics_network.extend([ 28 | nn.Linear(cfg.num_hidden, cfg.num_hidden), 29 | get_act(cfg.latent_act) 30 | ]) 31 | 32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()]) 33 | self.dynamics_network = nn.Sequential(*dynamics_network) 34 | 35 | def forward(self, t, z): 36 | """ Single discrete transition step; the unused t argument keeps an odeint-style signature """ 37 | return self.dynamics_network(z) 38 | 39 | 40 | class RGNRes(LatentDynamicsModel): 41 | def __init__(self, cfg): 42 | """ Latent dynamics as parameterized by a deterministic recurrent generative network """ 43 | super().__init__(cfg) 44 | 45 | # Recurrent dynamics network 46 | self.dynamics_func = RGNResFunction(cfg) 47 | 48 | def forward(self, x, generation_len): 49 | # Sample z_init 50 | z_init = self.encoder(x) 51 | 52 | # Evaluate forward over timestep 53 | z_cur = z_init 54 | zts = [z_init] 55 | for _ in range(generation_len - 1): 56 | z_cur = self.dynamics_func(None, z_cur) 57 | zts.append(z_cur) 58 | 59 | zt = torch.stack(zts, dim=1) 60 | 61 | # Stack zt and decode zts 62 |
x_rec = self.decoder(zt) 63 | return x_rec, zt 64 | -------------------------------------------------------------------------------- /models/group_b2/RGNRes.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file RGNRes.py 3 | 4 | Holds the model for the Residual Recurrent Generative Network (RGN-Res) latent dynamics function 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | from torchdiffeq import odeint 10 | from utils.utils import get_act 11 | from models.CommonDynamics import LatentDynamicsModel 12 | 13 | 14 | class RGNResFunction(nn.Module): 15 | def __init__(self, cfg): 16 | """ Standard Residual Recurrent Generative Network dynamics function """ 17 | super(RGNResFunction, self).__init__() 18 | 19 | # Build the dynamics network 20 | dynamics_network = [] 21 | dynamics_network.extend([ 22 | nn.Linear(cfg.latent_dim, cfg.num_hidden), 23 | get_act(cfg.latent_act) 24 | ]) 25 | 26 | for _ in range(cfg.num_layers - 1): 27 | dynamics_network.extend([ 28 | nn.Linear(cfg.num_hidden, cfg.num_hidden), 29 | get_act(cfg.latent_act) 30 | ]) 31 | 32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()]) 33 | self.dynamics_network = nn.Sequential(*dynamics_network) 34 | 35 | def forward(self, t, z): 36 | """ Single residual transition step; the unused t argument keeps an odeint-style signature """ 37 | return z + self.dynamics_network(z) 38 | 39 | 40 | class RGNRes(LatentDynamicsModel): 41 | def __init__(self, cfg): 42 | """ Latent dynamics as parameterized by a deterministic residual recurrent generative network """ 43 | super().__init__(cfg) 44 | 45 | # Residual recurrent dynamics network 46 | self.dynamics_func = RGNResFunction(cfg) 47 | 48 | def forward(self, x, generation_len): 49 | # Sample z_init 50 | z_init = self.encoder(x) 51 | 52 | # Evaluate forward over timestep 53 | z_cur = z_init 54 | zts = [z_init] 55 | for _ in range(generation_len - 1): 56 | z_cur = self.dynamics_func(None, z_cur) 57 | zts.append(z_cur) 58 | 59 | zt = torch.stack(zts, dim=1) 60 | 61 | # Stack zt and decode zts 62 | x_rec = self.decoder(zt) 63 | return x_rec, zt 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar.
2 | 3 | hydra-core==1.3.2 4 | matplotlib==3.7.0 5 | numpy==1.24.2 6 | omegaconf==2.3.0 7 | pytorch-lightning==1.9.0 8 | scikit-image==0.25.0 9 | scikit-learn==1.4.2 10 | torch==1.13.1 11 | torchdiffeq==0.2.3 12 | tqdm==4.64.1 13 | -------------------------------------------------------------------------------- /scripts/ablation_generation_length.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file ablation_generation_length.py 3 | 4 | Holds the cmd calls to train models across different generation lengths used in training 5 | """ 6 | import os 7 | 8 | os.system("python3 main.py --generation_len 1 --exptype node_pendulum_1gen --generation_varying False") 9 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_1gen/node/version_1/ --training_len 1") 10 | 11 | os.system("python3 main.py --generation_len 2 --exptype node_pendulum_2gen --generation_varying False") 12 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_2gen/node/version_1/ --training_len 2") 13 | 14 | os.system("python3 main.py --generation_len 3 --exptype node_pendulum_3gen --generation_varying False") 15 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_3gen/node/version_1/ --training_len 3") 16 | 17 | os.system("python3 main.py --generation_len 5 --exptype node_pendulum_5gen --generation_varying False") 18 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_5gen/node/version_1/ --training_len 5") 19 | 20 | os.system("python3 main.py --generation_len 10 --exptype node_pendulum_10gen --generation_varying False") 21 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_10gen/node/version_1/ --training_len 10") 22 | 23 | os.system("python3 main.py --generation_len 20 --exptype node_pendulum_20gen --generation_varying False") 24 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_20gen/node/version_1/ --training_len 20") 25 | -------------------------------------------------------------------------------- /scripts/ablation_odeintegrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file ablation_odeintegrator.py 3 | 4 | Holds the cmd calls to train models across different ODE integrators automatically 5 | """ 6 | import os 7 | 8 | dev = 0 9 | num_epochs = 100 10 | 11 | os.system(f"python main.py --exptype ablation_odeint_rk4_1 --integrator rk4 --integrator_params step_size=0.5 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 12 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.5 --integrator rk4 --integrator_params step_size=0.5 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 13 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.25 --integrator rk4 --integrator_params step_size=0.25 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 14 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.125 --integrator rk4 --integrator_params step_size=0.125 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 15 | 16 | os.system(f"python main.py --exptype ablation_odeint_dopri5_500 --integrator dopri5 --integrator_params max_num_steps=500 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 17 | os.system(f"python main.py --exptype ablation_odeint_dopri5_1000 --integrator dopri5 
--integrator_params max_num_steps=1000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 18 | os.system(f"python main.py --exptype ablation_odeint_dopri5_2000 --integrator dopri5 --integrator_params max_num_steps=2000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 19 | os.system(f"python main.py --exptype ablation_odeint_dopri5_5000 --integrator dopri5 --integrator_params max_num_steps=5000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}") 20 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file dataloader.py 3 | @author Ryan Missel 4 | 5 | Holds the LightningDataModule for the available datasets 6 | """ 7 | import torch 8 | import numpy as np 9 | import pytorch_lightning 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | 13 | class SSMDataset(Dataset): 14 | """ Basic Dataset object for the SSM """ 15 | def __init__(self, images, labels, states, controls): 16 | self.images = images 17 | self.labels = labels 18 | self.states = states 19 | self.controls = controls 20 | 21 | def __len__(self): 22 | return self.images.shape[0] 23 | 24 | def __getitem__(self, idx): 25 | return torch.Tensor([idx]), self.images[idx], self.states[idx], self.controls[idx], self.labels[idx] 26 | 27 | 28 | class SSMDataModule(pytorch_lightning.LightningDataModule): 29 | """ Custom DataModule object that handles preprocessing all sets of data for a given run """ 30 | def __init__(self, cfg): 31 | super(SSMDataModule, self).__init__() 32 | self.cfg = cfg 33 | 34 | def make_loader(self, mode="train", evaluation=False, shuffle=True): 35 | # Load in NPZ 36 | npzfile = np.load(f"data/{self.cfg.dataset}/{mode}.npz") 37 | 38 | # Load in data sources 39 | images = npzfile['image'].astype(np.float32) 40 | labels = npzfile['label'].astype(np.int16) 41 | states = npzfile['state'].astype(np.float32)[:, :, :2] 42 | 43 | # Load control, if it exists, else make a dummy one 44 | controls = npzfile['control'] if 'control' in npzfile else np.zeros((images.shape[0], images.shape[1], 1), dtype=np.float32) 45 | 46 | # Modify based on dataset percent 47 | rand_idx = np.random.choice(range(images.shape[0]), size=int(images.shape[0] * self.cfg.dataset_percent), replace=False) 48 | images = images[rand_idx] 49 | labels = labels[rand_idx] 50 | states = states[rand_idx] 51 | controls = controls[rand_idx] 52 | 53 | # Convert to Tensors 54 | images = torch.from_numpy(images) 55 | labels = torch.from_numpy(labels) 56 | states = torch.from_numpy(states) 57 | controls = torch.from_numpy(controls) 58 | 59 | # Build dataset and corresponding Dataloader 60 | dataset = SSMDataset(images, labels, states, controls) 61 | 62 | # If it is the training setting, set up the iterative dataloader 63 | if mode == "train" and evaluation is False: 64 | sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=self.cfg.num_steps * self.cfg.batch_size) 65 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=self.cfg.batch_size, drop_last=True) 66 | 67 | # Otherwise, setup a normal dataloader 68 | else: 69 | dataloader = DataLoader(dataset, batch_size=self.cfg.batch_size, shuffle=shuffle) 70 | return dataloader 71 | 72 | def train_dataloader(self): 73 | """ Getter function that builds and returns the training dataloader """ 74 
| return self.make_loader("train") 75 | 76 | def evaluate_train_dataloader(self): 77 | return self.make_loader("train", evaluation=True, shuffle=False) 78 | 79 | def val_dataloader(self): 80 | """ Getter function that builds and returns the validation dataloader """ 81 | return self.make_loader("val", shuffle=False) 82 | 83 | def test_dataloader(self): 84 | """ Getter function that builds and returns the testing dataloader """ 85 | return self.make_loader("test", shuffle=False) 86 | -------------------------------------------------------------------------------- /utils/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file layers.py 3 | 4 | Miscellaneous helper Torch layers 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Gaussian(nn.Module): 11 | def __init__(self, in_dim, out_dim, fix_variance=False): 12 | """ 13 | Gaussian sample layer consisting of 2 simple linear layers. 14 | Can choose whether to fix the variance or let it be learned (training instability has been shown when learning). 15 | 16 | :param in_dim: input dimension (often a flattened latent embedding from a CNN) 17 | :param out_dim: output dimension 18 | :param fix_variance: whether to set the log-variance as a constant 0.1 19 | """ 20 | super(Gaussian, self).__init__() 21 | self.fix_variance = fix_variance 22 | 23 | # Mean layer 24 | self.mu = nn.Sequential( 25 | nn.Linear(in_dim, in_dim // 2), 26 | nn.LeakyReLU(0.1), 27 | nn.Linear(in_dim // 2, out_dim) 28 | ) 29 | 30 | # Log-var layer 31 | self.logvar = nn.Sequential( 32 | nn.Linear(in_dim, in_dim // 2), 33 | nn.LeakyReLU(0.1), 34 | nn.Linear(in_dim // 2, out_dim) 35 | ) 36 | 37 | def reparameterize(self, mu, logvar): 38 | """ Reparameterization trick to get a sample from the output distribution """ 39 | std = torch.exp(0.5 * logvar) 40 | noise = torch.randn_like(std) 41 | z = mu + (noise * std) 42 | return z 43 | 44 | def forward(self, x): 45 | """ 46 | Forward function of the Gaussian layer. Handles getting the distributional parameters and sampling a vector 47 | :param x: input vector [BatchSize, InputDim] 48 | """ 49 | # Get mu and logvar 50 | mu = self.mu(x) 51 | 52 | if self.fix_variance: 53 | logvar = torch.full_like(mu, fill_value=0.1) 54 | else: 55 | logvar = self.logvar(x) 56 | 57 | # Check on whether mu/logvar are getting out of normal ranges 58 | if (mu < -100).any() or (mu > 85).any() or (logvar < -100).any() or (logvar > 85).any(): 59 | print("Explosion in mu/logvar. 
Mu {} Logvar {}".format(torch.mean(mu), torch.mean(logvar))) 60 | 61 | # Reparameterize and sample 62 | z = self.reparameterize(mu, logvar) 63 | return mu, logvar, z 64 | 65 | 66 | class GroupSwish(nn.Module): 67 | def __init__(self, groups): 68 | """ 69 | Swish activation function that works on GroupConvolution by reshaping all input groups back into their 70 | batch shapes, activating them, then reshaping them to 1D GroupConv filters 71 | """ 72 | super().__init__() 73 | self.silu = nn.SiLU() 74 | self.groups = groups 75 | 76 | def forward(self, x): 77 | n_ch_group = x.size(1) // self.groups 78 | t = x.shape[2:] 79 | x = x.reshape(-1, self.groups, n_ch_group, *t) 80 | return self.silu(x).reshape(1, self.groups * n_ch_group, *t) 81 | 82 | 83 | class GroupTanh(nn.Module): 84 | def __init__(self, groups): 85 | """ 86 | Tanh activation function that works on GroupConvolution by reshaping all input groups back into their 87 | batch shapes, activating them, then reshaping them to 1D GroupConv filters 88 | """ 89 | super().__init__() 90 | self.tanh = nn.Tanh() 91 | self.groups = groups 92 | 93 | def forward(self, x): 94 | n_ch_group = x.size(1) // self.groups 95 | t = x.shape[2:] 96 | x = x.reshape(self.groups, n_ch_group, *t) 97 | return self.tanh(x).reshape(1, self.groups * n_ch_group, *t) 98 | 99 | 100 | class Flatten(nn.Module): 101 | def forward(self, input): 102 | """ 103 | Handles flattening a Tensor within a nn.Sequential Block 104 | 105 | :param input: Torch object to flatten 106 | """ 107 | return input.view(input.size(0), -1) 108 | 109 | 110 | class UnFlatten(nn.Module): 111 | def __init__(self, w): 112 | """ 113 | Handles unflattening a vector into a 4D vector in a nn.Sequential Block 114 | 115 | :param w: width of the unflattened image vector 116 | """ 117 | super().__init__() 118 | self.w = w 119 | 120 | def forward(self, input): 121 | nc = input[0].numel() // (self.w ** 2) 122 | return input.view(input.size(0), nc, self.w, self.w) 123 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | @file metrics.py 3 | 4 | Holds a variety of metric computing functions for time-series forecasting models, including 5 | Valid Prediction Time (VPT), Valid Prediction Distance (VPD), etc. 
6 | """ 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from skimage.filters import threshold_otsu 12 | from sklearn.linear_model import LinearRegression 13 | from sklearn.neural_network import MLPRegressor 14 | 15 | 16 | def vpt(gt, preds, epsilon=0.010, **kwargs): 17 | """ 18 | Computes the Valid Prediction Time metric, as proposed in https://openreview.net/pdf?id=qBl8hnwR0px 19 | VPT = argmin_t [MSE(gt, pred) > epsilon] 20 | :param gt: ground truth sequences 21 | :param preds: model predicted sequences 22 | :param epsilon: threshold for valid prediction 23 | """ 24 | # Ensure on CPU and numpy 25 | if not isinstance(gt, np.ndarray): 26 | gt = gt.cpu().numpy() 27 | preds = preds.cpu().numpy() 28 | 29 | # Get dimensions 30 | _, timesteps, height, width = gt.shape 31 | 32 | # Get pixel_level MSE at each timestep 33 | mse = (gt - preds) ** 2 34 | mse = np.sum(mse, axis=(2, 3)) / (height * width) 35 | 36 | # Get VPT 37 | vpts = [] 38 | for m in mse: 39 | # Get all indices below the given epsilon 40 | indices = np.where(m < epsilon)[0] + 1 41 | 42 | # If there are none below, then add 0 43 | if len(indices) == 0: 44 | vpts.append(0) 45 | continue 46 | 47 | # Append last in list 48 | vpts.append(indices[-1]) 49 | 50 | # Return VPT mean over the total timesteps 51 | return np.mean(vpts) / timesteps, np.std(vpts) / timesteps 52 | 53 | 54 | def thresholding(preds, gt): 55 | """ 56 | Thresholding function that converts gt and preds into binary images 57 | Activated prediction pixels are found via Otsu's thresholding function 58 | """ 59 | N, T = gt.shape[0], gt.shape[1] 60 | res = np.zeros_like(preds) 61 | 62 | # For each sample and timestep, get Otsu's threshold and binarize gt and pred 63 | for n in range(N): 64 | for t in range(T): 65 | img = preds[n, t] 66 | otsu_th = np.max([0.32, threshold_otsu(img)]) 67 | res[n, t] = (img > otsu_th).astype(np.float32) 68 | gt[n, t] = (gt[n, t] > 0.55).astype(np.float32) 69 | return res, gt 70 | 71 | 72 | def dst(gt, preds, **kwargs): 73 | """ 74 | Computes a Euclidean distance metric between the center of the ball in ground truth and prediction 75 | Activated pixels in the predicted are computed via Otsu's thresholding function 76 | :param gt: ground truth sequences 77 | :param preds: model predicted sequences 78 | """ 79 | # Ensure on CPU and numpy 80 | if not isinstance(gt, np.ndarray): 81 | gt = gt.cpu().numpy() 82 | preds = preds.cpu().numpy() 83 | 84 | # Get shapes 85 | num_samples, timesteps, height, width = gt.shape 86 | 87 | # Apply Otsu thresholding function on output 88 | preds, gt = thresholding(preds, gt) 89 | 90 | # Loop over each sample and timestep to get the distance metric 91 | results = np.zeros([num_samples, timesteps]) 92 | for n in range(num_samples): 93 | for t in range(timesteps): 94 | # Get all active predicted pixels 95 | a = preds[n, t] 96 | b = gt[n, t] 97 | pos_a = np.where(a == 1) 98 | pos_b = np.where(b == 1) 99 | 100 | # If there are in gt, add 0 101 | if pos_b[0].shape[0] == 0: 102 | results[n, t] = 0 103 | continue 104 | 105 | # Get gt center 106 | center_b = np.array([pos_b[0].mean(), pos_b[1].mean()]) 107 | 108 | # Get center of predictions 109 | if pos_a[0].shape[0] != 0: 110 | center_a = np.array([pos_a[0].mean(), pos_a[1].mean()]) 111 | # If no pixels above threshold, add the highest possible error in image space 112 | else: 113 | # results[n, t] = np.sqrt(np.sum(np.array([height, width]) ** 2)) 114 | center_a = [0, 0] 115 | # continue 116 | 117 | # Get distance metric 118 | dist = 
119 |             dist = np.sqrt(dist)
120 | 
121 |             # Add to result
122 |             results[n, t] = dist
123 | 
124 |     return np.mean(results), np.std(results)
125 | 
126 | 
127 | def vpd(target, output, epsilon=10, **kwargs):
128 |     """
129 |     Computes the Valid Prediction Distance metric, as proposed in https://openreview.net/forum?id=7C9aRX2nBf2
130 |     VPD = argmin_t [DST(gt, pred) > epsilon]
131 |     :param target: ground truth sequences
132 |     :param output: model predicted sequences
133 |     :param epsilon: threshold for valid prediction
134 |     """
135 |     # Ensure on CPU and numpy
136 |     if not isinstance(output, np.ndarray):
137 |         output = output.cpu().numpy()
138 |         target = target.cpu().numpy()
139 | 
140 |     # Get shapes
141 |     num_samples, timesteps, height, width = target.shape
142 | 
143 |     # Apply Otsu thresholding function on output
144 |     preds, gt = thresholding(output, target)
145 | 
146 |     # Loop over each sample and timestep to get the distance metric
147 |     dsts = np.zeros([num_samples, timesteps])
148 |     for n in range(num_samples):
149 |         for t in range(timesteps):
150 |             # Get all active predicted pixels
151 |             a = preds[n, t]
152 |             b = gt[n, t]
153 |             pos_a = np.where(a == 1)
154 |             pos_b = np.where(b == 1)
155 | 
156 |             # If there are no active pixels in gt, add 0
157 |             if pos_b[0].shape[0] == 0:
158 |                 dsts[n, t] = 0
159 |                 continue
160 | 
161 |             # Get gt center
162 |             center_b = np.array([pos_b[0].mean(), pos_b[1].mean()])
163 | 
164 |             # Get center of predictions
165 |             if pos_a[0].shape[0] != 0:
166 |                 center_a = np.array([pos_a[0].mean(), pos_a[1].mean()])
167 |             # If no pixels above threshold, add the highest possible error in image space
168 |             else:
169 |                 dsts[n, t] = np.sqrt(np.sum(np.array([height, width]) ** 2))
170 |                 continue
171 | 
172 |             # Get distance metric
173 |             dist = np.sum((center_a - center_b) ** 2)
174 |             dist = np.sqrt(dist)
175 | 
176 |             # Add to result
177 |             dsts[n, t] = dist
178 | 
179 |     # Get the VPD calculation
180 |     B, T = dsts.shape
181 |     vpdist = np.zeros(B)
182 |     for i in range(B):
183 |         idx = np.where(dsts[i, :] >= epsilon)[0]
184 |         if idx.shape[0] > 0:
185 |             vpdist[i] = np.min(idx)
186 |         else:
187 |             vpdist[i] = T
188 | 
189 |     # Return VPD mean over the total timesteps
190 |     return np.mean(vpdist) / T, np.std(vpdist) / T
191 | 
192 | 
193 | def reconstruction_mse(output, target, **kwargs):
194 |     """ Gets the mean of the per-pixel MSE for the given length of timesteps used for training """
195 |     full_pixel_mses = (output[:, :kwargs['length']] - target[:, :kwargs['length']]) ** 2
196 |     sequence_pixel_mse = np.mean(full_pixel_mses, axis=(1, 2, 3))
197 |     return np.mean(sequence_pixel_mse), np.std(sequence_pixel_mse)
198 | 
199 | 
200 | def extrapolation_mse(output, target, **kwargs):
201 |     """ Gets the mean of the per-pixel MSE for a number of steps past the length used in training """
202 |     full_pixel_mses = (output[:, kwargs['cfg'].train_length:] - target[:, kwargs['cfg'].train_length:]) ** 2
203 |     if full_pixel_mses.shape[1] == 0:
204 |         return 0.0, 0.0
205 | 
206 |     sequence_pixel_mse = np.mean(full_pixel_mses, axis=(1, 2, 3))
207 |     return np.mean(sequence_pixel_mse), np.std(sequence_pixel_mse)
208 | 
209 | 
210 | def r2fit(latents, gt_state, mlp=False):
211 |     """
212 |     Computes an R^2 fit value for each ground truth physical state dimension given the latent states at each timestep.
213 |     Gets an average per timestep.
214 |     :param latents: latent states at each timestep [BatchSize, TimeSteps, LatentSize]
215 |     :param gt_state: ground truth physical parameters [BatchSize, TimeSteps, StateSize]
216 |     :param mlp: whether to use a non-linear MLP regressor instead of linear regression
217 |     """
218 |     # Get first dimension of states for evaluation
219 |     sins = np.sin(gt_state[:, :, 0])
220 |     coss = np.cos(gt_state[:, :, 0])
221 |     gt_states = np.stack((sins, coss, gt_state[:, :, 1]), axis=2)
222 | 
223 |     # Ensure on CPU and numpy
224 |     if not isinstance(latents, np.ndarray):
225 |         latents = latents.cpu().numpy()
226 |         gt_state = gt_state.cpu().numpy()
227 | 
228 |     # Convert to one large set of latent states
229 |     latents = latents.reshape([latents.shape[0] * latents.shape[1], -1])
230 |     gt_state = gt_state.reshape([gt_state.shape[0] * gt_state.shape[1], -1])
231 | 
232 |     # For each dimension of gt_state, get the R^2 value
233 |     r2s = []
234 |     for sidx in range(gt_state.shape[-1]):
235 |         gts = gt_state[:, sidx]
236 | 
237 |         # Whether to use LinearRegression or an MLP
238 |         if mlp:
239 |             reg = MLPRegressor().fit(latents, gts)
240 |         else:
241 |             reg = LinearRegression().fit(latents, gts)
242 | 
243 |         r2s.append(reg.score(latents, gts))
244 | 
245 |     # Return r2s for logging
246 |     return r2s
247 | 
248 | 
249 | def normalized_pixel_mse(gt, preds):
250 |     """
251 |     Handles getting the pixel MSE of a trajectory, but normalizes over the average intensity of the ground truth.
252 |     This helps make pixel MSE comparable across different domains rather than looking at pure intensity.
253 |     :param gt: ground truth sequence [BS, TS, Dim1, Dim2]
254 |     :param preds: predictions of the model [BS, TS, Dim1, Dim2]
255 | 
256 |     TODO: Make sure this metric calculation matches WhichPriorsMatter?
257 |     """
258 |     mse = nn.MSELoss(reduction='none')(gt, preds)
259 |     mse = mse / torch.mean(gt ** 2)
260 |     return mse.detach().cpu().numpy(), mse.mean([1, 2, 3]).mean().detach().cpu().numpy()
261 | 
--------------------------------------------------------------------------------
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | """
2 | @file plotting.py
3 | 
4 | Holds general plotting functions for reconstructions of the bouncing ball dataset
5 | """
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | 
9 | 
10 | def show_images(images, preds, out_loc, num_out=None):
11 |     """
12 |     Constructs an image of multiple time-series reconstruction samples compared against their relevant ground truth
13 |     Saves locally in the given out location
14 |     :param images: ground truth images
15 |     :param preds: predictions from a given model
16 |     :param out_loc: where to save the generated image
17 |     :param num_out: how many images to stack. If None, stack all
18 |     """
19 |     assert len(images.shape) == 4      # Assert both matrices are [Batch, Timesteps, H, W]
20 |     assert len(preds.shape) == 4
21 |     assert num_out is None or isinstance(num_out, int)
22 | 
23 |     # Make sure objects are in numpy format
24 |     if not isinstance(images, np.ndarray):
25 |         images = images.cpu().numpy()
26 |         preds = preds.cpu().numpy()
27 | 
28 |     # Splice to the given num_out
29 |     if num_out is not None:
30 |         images = images[:num_out]
31 |         preds = preds[:num_out]
32 | 
33 |     # Iterate through each sample, stacking into one image
34 |     out_image = None
35 |     for idx, (gt, pred) in enumerate(zip(images, preds)):
36 |         # Pad between individual timesteps
37 |         gt = np.pad(gt, pad_width=(
38 |             (0, 0), (5, 5), (0, 1)
39 |         ), constant_values=1)
40 | 
41 |         gt = np.hstack([i for i in gt])
42 | 
43 |         # Pad between individual timesteps
44 |         pred = np.pad(pred, pad_width=(
45 |             (0, 0), (0, 10), (0, 1)
46 |         ), constant_values=1)
47 | 
48 |         # Stack timesteps into one image
49 |         pred = np.hstack([i for i in pred])
50 | 
51 |         # Stack gt/pred into one image
52 |         final = np.vstack((gt, pred))
53 | 
54 |         # Stack into out_image
55 |         if out_image is None:
56 |             out_image = final
57 |         else:
58 |             out_image = np.vstack((out_image, final))
59 | 
60 |     # Save to out location
61 |     plt.imsave(out_loc, out_image, cmap='gray')
62 | 
63 | 
64 | def get_embedding_trajectories(embeddings, states, out_loc):
65 |     """
66 |     Handles getting trajectory plots of the embedded states against the true physical states
67 |     :param embeddings: vector states over time
68 |     :param states: ground truth physical parameter states
69 |     """
70 |     # Make sure objects are in numpy format
71 |     if not isinstance(embeddings, np.ndarray):
72 |         embeddings = embeddings.cpu().numpy()
73 |         states = states.cpu().numpy()
74 | 
75 |     # Get embedding trajectory plots
76 |     for idx, embedding in enumerate(np.swapaxes(embeddings, 1, 0)):
77 |         plt.plot(embedding, label=f"Dim {idx}")
78 | 
79 |     plt.title("Vector State Trajectories")
80 |     plt.savefig(f"{out_loc}/trajectories_embeddings.png")
81 |     plt.close()
82 | 
83 |     # Get physical state trajectories
84 |     for idx, embedding in enumerate(np.swapaxes(states, 1, 0)):
85 |         plt.plot(embedding, label=f"Dim {idx}")
86 | 
87 |     plt.title("GT State Trajectories")
88 |     plt.savefig(f"{out_loc}/trajectories_gt.png")
89 |     plt.close()
90 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | @file utils.py
3 | 
4 | Utility functions across files
5 | """
6 | import math
7 | import torch.nn as nn
8 | 
9 | from omegaconf import DictConfig, OmegaConf
10 | from torch.optim.lr_scheduler import _LRScheduler
11 | 
12 | 
13 | def flatten_cfg(cfg: DictConfig):
14 |     """ Utility function to flatten the primary submodules of a Hydra config """
15 |     # Disable struct flag on the config
16 |     OmegaConf.set_struct(cfg, False)
17 | 
18 |     # Loop through each item, merging with the main cfg if it's another DictConfig
19 |     for key, value in cfg.copy().items():
20 |         if isinstance(value, DictConfig):
21 |             cfg.merge_with(cfg.pop(key))
22 | 
23 |     # Do it a second time for nested cfgs
24 |     for key, value in cfg.copy().items():
25 |         if isinstance(value, DictConfig):
26 |             cfg.merge_with(cfg.pop(key))
27 | 
28 |     print(cfg)
29 |     return cfg
30 | 
31 | 
32 | def get_model(name):
33 |     """ Import and return the specific latent dynamics function by the given name"""
34 |     # Lowercase name to be robust to capitalization differences
35 |     name = name.lower()
36 | 
37 |     ## Group A Models
38 |     if name == "vrnn":
39 |         from models.group_a.VRNN import VRNN
40 |         return VRNN
41 | 
42 |     if name == "dkf":
43 |         from models.group_a.DKF import DKF
44 |         return DKF
45 | 
46 |     ## Group B1 Models
47 |     if name == "kvae":
48 |         from models.group_b1.KVAE import KVAE
49 |         return KVAE
50 | 
51 |     ## Group B2 Models
52 |     if name == "node":
53 |         from models.group_b2.NeuralODE import NeuralODE
54 |         return NeuralODE
55 | 
56 |     if name == "rgn":
57 |         from models.group_b2.RGN import RGN
58 |         return RGN
59 | 
60 |     if name == "rgnres":
61 |         from models.group_b2.RGNRes import RGNRes
62 |         return RGNRes
63 | 
64 |     # Given no correct model type, raise error
65 |     raise NotImplementedError("Model type {} not implemented.".format(name))
66 | 
67 | 
68 | def get_act(act="relu"):
69 |     """
70 |     Return torch function of a given activation function
71 |     :param act: activation function
72 |     :return: torch object
73 |     """
74 |     if act == "relu":
75 |         return nn.ReLU()
76 |     elif act == "leaky_relu":
77 |         return nn.LeakyReLU(0.1)
78 |     elif act == "sigmoid":
79 |         return nn.Sigmoid()
80 |     elif act == "tanh":
81 |         return nn.Tanh()
82 |     elif act == "linear":
83 |         return nn.Identity()
84 |     elif act == 'softplus':
85 |         return nn.modules.activation.Softplus()
86 |     elif act == 'softmax':
87 |         return nn.Softmax()
88 |     elif act == "swish":
89 |         return nn.SiLU()
90 |     else:
91 |         return None
92 | 
93 | 
94 | def determine_annealing_factor(n_updates, min_anneal_factor=0.0, anneal_update=10000):
95 |     """
96 |     Handles annealing the KL restriction over a number of update steps to slowly introduce the regularization
97 |     to ensure a strong initial fit has been set
98 |     :param n_updates: current number of optimizer updates taken
99 |     :param min_anneal_factor: minimum annealing factor to start the KL weight at
100 |     :param anneal_update: number of updates over which to anneal the weight up to 1.0
101 |     :return: weight of the kl annealing factor for the loss term
102 |     """
103 |     if anneal_update > 0 and n_updates < anneal_update:
104 |         anneal_factor = min_anneal_factor + \
105 |             (1.0 - min_anneal_factor) * (
106 |                 (n_updates / anneal_update)
107 |             )
108 |     else:
109 |         anneal_factor = 1.0
110 |     return anneal_factor
111 | 
112 | 
113 | class CosineAnnealingWarmRestartsWithDecayAndLinearWarmup(_LRScheduler):
114 |     r"""Set the learning rate of each parameter group using a cosine annealing
115 |     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
116 |     is the number of epochs since the last restart and :math:`T_{i}` is the number
117 |     of epochs between two warm restarts in SGDR.
118 |     On each restart the base learning rates are additionally multiplied by ``decay``,
119 |     and during the first ``warmup_steps`` calls to ``step`` the rate is scaled up linearly.
120 |     """
121 | 
122 |     def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False, warmup_steps=350, decay=1):
123 |         if T_0 <= 0 or not isinstance(T_0, int):
124 |             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
125 |         if T_mult < 1 or not isinstance(T_mult, int):
126 |             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
127 |         self.T_0 = T_0
128 |         self.T_i = T_0
129 |         self.T_mult = T_mult
130 |         self.eta_min = eta_min
131 |         self.T_cur = last_epoch
132 |         super(CosineAnnealingWarmRestartsWithDecayAndLinearWarmup, self).__init__(optimizer, last_epoch, verbose)
133 | 
134 |         # Decay attributes
135 |         self.decay = decay
136 |         self.initial_lrs = self.base_lrs
137 | 
138 |         # Warmup attributes
139 |         self.warmup_steps = warmup_steps
140 |         self.current_steps = 0
141 | 
142 |     def get_lr(self):
143 |         return [
144 |             (self.current_steps / self.warmup_steps) *
145 |             (self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2)
146 |             for base_lr in self.base_lrs
147 |         ]
148 | 
149 |     def step(self, epoch=None):
150 |         """Step could be called after every batch update"""
151 |         if epoch is None and self.last_epoch < 0:
152 |             epoch = 0
153 | 
154 |         if self.T_cur + 1 == self.T_i:
155 |             if self.verbose:
156 |                 print("multiplying base_lrs by {:.4f}".format(self.decay))
157 |             self.base_lrs = [base_lr * self.decay for base_lr in self.base_lrs]
158 | 
159 |         if epoch is None:
160 |             epoch = self.last_epoch + 1
161 |             self.T_cur = self.T_cur + 1
162 | 
163 |             if self.current_steps < self.warmup_steps:
164 |                 self.current_steps += 1
165 | 
166 |             if self.T_cur >= self.T_i:
167 |                 self.T_cur = self.T_cur - self.T_i
168 |                 self.T_i = self.T_i * self.T_mult
169 | 
170 |         self.last_epoch = math.floor(epoch)
171 | 
172 |         class _enable_get_lr_call:
173 | 
174 |             def __init__(self, o):
175 |                 self.o = o
176 | 
177 |             def __enter__(self):
178 |                 self.o._get_lr_called_within_step = True
179 |                 return self
180 | 
181 |             def __exit__(self, type, value, traceback):
182 |                 self.o._get_lr_called_within_step = False
183 |                 return self
184 | 
185 |         with _enable_get_lr_call(self):
186 |             for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
187 |                 param_group, lr = data
188 |                 param_group['lr'] = lr
189 |                 self.print_lr(self.verbose, i, lr, epoch)
190 | 
191 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
192 | 
--------------------------------------------------------------------------------
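A minimal usage sketch of the metric and plotting helpers above: all of them operate on batched image sequences shaped [BatchSize, Timesteps, Height, Width]. The random stand-in data, the output filename, and the repo-root-relative import paths are assumptions made only for this illustration, not part of the repository.

    import numpy as np

    from utils.metrics import vpt, dst, vpd, reconstruction_mse
    from utils.plotting import show_images

    # Stand-in ground truth and predictions in [0, 1], shaped [Batch, Timesteps, Height, Width]
    targets = np.random.rand(8, 20, 32, 32).astype(np.float32)
    outputs = np.clip(targets + 0.05 * np.random.randn(*targets.shape), 0, 1).astype(np.float32)

    # Last timestep (normalized by sequence length) at which the per-frame MSE is still under epsilon
    vpt_mean, vpt_std = vpt(targets, outputs)

    # Mean center-of-mass distance between the thresholded ball in gt and prediction;
    # thresholding() binarizes the ground truth in place, so a copy is passed here
    dst_mean, dst_std = dst(targets.copy(), outputs)

    # Number of frames (normalized) before the center distance first exceeds its epsilon threshold
    vpd_mean, vpd_std = vpd(targets.copy(), outputs)

    # Per-pixel MSE over the first 10 frames, i.e. the reconstruction window
    mse_mean, mse_std = reconstruction_mse(outputs, targets, length=10)

    # Save a stacked ground-truth vs. prediction panel for the first 4 sequences
    show_images(targets, outputs, "recon_examples.png", num_out=4)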