, x<i) is weakened, and without
134 | near-term observations to guide the dynamics function, its long-horizon forecasting is limited[1,3].
135 |
136 | System parameters as latent variables (System Identification): Rather than system states, one can instead choose
137 | to select the parameters (denoted as θz in Equation [1](#ssmEQ)). With this change, the resulting PGM
138 | is represented in Fig. [1B](#pgmSchematic) and its marginal likelihood over x0:T is represented now by:
139 |
140 | 
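One way this marginal likelihood can be written out, as a rough sketch assuming deterministic latent transitions z<sub>t</sub> = f<sub>θz</sub>(z<sub>t−1</sub>) and independent priors over θz and z0 (not necessarily the exact form used in each referenced work):

```latex
p(x_{0:T}) = \int\!\!\int p(\theta_z)\, p(z_0) \prod_{t=0}^{T} p(x_t \mid z_t)\; dz_0\, d\theta_z,
\qquad z_t = f_{\theta_z}(z_{t-1})
```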
141 |
142 | where the resulting output observations are derived from an initial latent state z0 and the dynamics
143 | parameters θz. As before, a variational approximation is considered for inference in place of an
144 | intractable posterior but now for the density q(θz, z0) instead. Given
145 | prior density assumptions of p(θz) and p(z0) in a similar vein as
146 | above, the ELBO function for this PGM is constructed as:
147 |
148 | 
149 | where again the first term is a reconstruction likelihood and the terms following represent KL-Divergence losses over
150 | the inferred variables.
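For concreteness, a sketch of how this ELBO could be written, assuming a factorized approximate posterior q(θz, z0) = q(θz)q(z0):

```latex
\mathcal{L} = \mathbb{E}_{q(\theta_z)\,q(z_0)}\!\left[\sum_{t=0}^{T} \log p(x_t \mid z_t)\right]
- \mathrm{KL}\!\left(q(z_0)\,\|\,p(z_0)\right)
- \mathrm{KL}\!\left(q(\theta_z)\,\|\,p(\theta_z)\right)
```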
151 |
152 |
153 | The given formulation here is the most general form for this line of models and other works can be covered under
154 | special assumptions or restrictions of how q(θz) and p(θz) are
155 | modeled. Original Neural SSM parameter works consider Linear-Gaussian SSMs as the transition function and
156 | introduce non-linearity by varying the transition parameters over time as θz0:T[1,2,3].
157 | However, as shown in Fig. [2B1](#latentSchematic), this results in convoluted temporal
158 | modeling and devolves into the same state-estimation problem, as the time-varying parameters now rely on near-term
159 | observations for correctness[8,20]. Rather than time-varying, the system parameters could be considered
160 | an optimized global variable, in which the underlying dynamics function becomes a Bayesian neural network in a VAE's
161 | latent space[5] and is shown in Fig. [2B2](#latentSchematic). Restricting these parameters to
162 | be deterministic results in a model of the form presented in Latent ODE[10]. The furthest restriction,
163 | forgoing stochasticity in the inference of z0 as well, results in the suite of models presented in [4].
164 |
165 |
166 | Regardless of the precise assumptions, this framework builds a strong latent dynamics function that enables long-term
167 | forecasting and, in some settings, even full-scale system identification[1,4] of physical systems. This is
168 | done at the cost of a harder inference task, given no access to dynamics correction during generation; full
169 | identification tasks also often require a large number of training samples over the potential system state space[4,5].
170 |
171 |
172 |
173 | ## Latent Initial State Z0 / Zinit
174 |
175 | As the transition dynamics and the observation space are intentionally disconnected in this framework,
176 | the problem of inferring a strong initial latent state from which to forecast is an important consideration
177 | when designing a neural state-space model[30]. This is primarily a task- and data-dependent choice,
178 | in which the architecture follows the data structure. Thankfully, much work has been done in other research
179 | directions on designing good latent encoding models. As such, works in this area often draw from them.
180 | This section is split into three parts - one on the usual architecture for high-dimensional image tasks,
181 | one on lower-dimensional and/or miscellaneous encoders, and one on the different forms of inference for the initial
182 | state depending on which sequence portions are observed.
183 |
184 | Image-based Encoders: Unsurprisingly, the common architecture used in latent image encoding is a convolutional
185 | neural network (CNN) given its inherent bias toward spatial feature extraction[1,3,4,5]. Works are mixed
186 | between either having the sequential input reshaped as frames stacked over the channel dimension or simply running
187 | the CNN over each observed frame separately and passing the concatenated embedding into an output layer. Regardless
188 | of methodology, a few frames are assumed as observations for initialization, as multiple timesteps are required to
189 | infer the initial system movement. A subset of works considers second-order latent vector spaces, in which the
190 | encoder is explicitly split into two individual position and momenta functions, taking single and multiple frames
191 | respectively[5].
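As a minimal sketch of the stacked-frame variant (hypothetical layer sizes, assuming 32x32 grayscale inputs; not the exact encoder in this repo's CommonVAE.py):

```python
import torch
import torch.nn as nn

class StackedFrameEncoder(nn.Module):
    """Sketch of a stacked-frame CNN encoder for z0 inference; observed frames
    are stacked over the channel dimension."""
    def __init__(self, num_obs_frames=5, latent_dim=8, num_filters=24):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(num_obs_frames, num_filters, kernel_size=4, stride=2, padding=1),   # 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(num_filters, num_filters * 2, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.ReLU(),
            nn.Flatten(),
        )
        self.mu = nn.Linear(num_filters * 2 * 8 * 8, latent_dim)
        self.logvar = nn.Linear(num_filters * 2 * 8 * 8, latent_dim)

    def forward(self, x):
        # x: [batch, num_obs_frames, H, W] grayscale frames stacked over channels
        h = self.conv(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z0 = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return z0, mu, logvar

# e.g., enc = StackedFrameEncoder(); z0, mu, logvar = enc(torch.randn(16, 5, 32, 32))
```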
192 |
193 |
194 | 
195 | Fig N. Visualization of the stacked initial state encoder, modified from [23].
196 |
197 |
198 | Alternate Encoders: In settings with non-image-based inputs, the initial latent encoder can take on a large variety of forms, ranging anywhere from simple linear/MLP networks in physical systems[5] to graph convolution networks for latent medical image forecasting[20]. Multi-modal and dynamics conditioning inputs can be leveraged via combinations of encoders whose embeddings go through a shared linear function.
199 |
200 |
201 | 
202 | Fig N. Visualization of the stacked graph convolutional encoder, modified from [24].
203 |
204 |
205 | Variables z0, zk, and zinit:
206 | Beyond just the inference of this latent variable, there is one more variation that can be seen throughout literature -
207 | that of which portions of the input sequence are observed and used in the initial state inference.
208 |
209 | Generally, there are 3 forms seen:
210 |
211 | - z0 - which uses a sequence x0:k to get z0.
212 | - zk - which uses previous frames x-k:k to get zk to go forward past observed frames.
213 | - zinit - which uses the sequence x0:k to get an abstract initial vector state that the dynamics function starts from.
214 |
215 |
216 | Throughout the literature, these variable names as shown here aren't used (most works just call it z0
217 | and describe its inference), but we differentiate them specifically to highlight the distinctions.
218 | For training purposes, it is a subtle distinction but potentially has implications for the resulting
219 | likelihood optimization and learned vector space.
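A rough sketch of how the three forms differ in practice, using hypothetical encoder/dynamics/decoder callables rather than this repo's actual API:

```python
import torch

def infer_and_forecast(encoder, dynamics, decoder, x, form="z0", k=5):
    """Sketch of the three initial-state forms (hypothetical helper, not the repo's API).
    x: [batch, T, ...]; the first k frames (or frames -k:0 for zk) are observed."""
    T = x.shape[1]
    if form == "z0":
        # z0: the encoder output *is* the latent state at t=0 and is itself decoded as frame 0
        z0 = encoder(x[:, :k])
        z_seq = dynamics(z0, steps=T - 1)             # produces z_1, ..., z_{T-1}
        return decoder(torch.cat([z0.unsqueeze(1), z_seq], dim=1))
    if form == "zk":
        # zk: the encoder output is the state at the end of the window; only later frames are forecast
        zk = encoder(x[:, :k])
        z_seq = dynamics(zk, steps=T - k)
        return decoder(z_seq)                          # reconstructs x_k, ..., x_{T-1}
    if form == "zinit":
        # zinit: the encoder output is an abstract seed; even z_0 is produced by the dynamics function
        z_init = encoder(x[:, :k])
        z_seq = dynamics(z_init, steps=T)              # produces z_0, ..., z_{T-1}
        return decoder(z_seq)
```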
220 |
221 |
222 | 
223 | Fig N. Schematic of the difference between z0 and zinit formulations.
224 |
225 | That said, there is generally a lack of work exploring the considerations for each approach, beyond ad-hoc solutions to bridge
226 | the gap between the latent encoder and dynamics function distributions[5]. This gap can stem from
227 | optimization problems caused by imbalanced reconstruction terms between dynamics and initial states, or from cases
228 | where the initial state distribution lies far from the data distribution of downstream frames.
229 |
230 | However, a recent work, "Learning Neural State-Space Models: Do we need a state estimator?" [30], is the first detailed
231 | study into the considerations of initial state inference, providing ablations across datasets of increasing difficulty and
232 | across inference forms. They found that to get competitive performance of neural SSMs on some
233 | dynamical systems, more advanced architectures were required (feed-forward or LSTM networks). Notably, they only evaluate
234 | on the zk form, varying architectural choices.
235 |
236 | A variety of empirical techniques have been proposed to tackle this gap, much in the same spirit of empirical
237 | VAE stability 'tricks.' These include separated x0 and x1:T terms
238 | (where x0 has a positive weighting coefficient), VAE pre-training for x0,
239 | and KL-regularization terms between the output distributions of the encoder and the dynamics flow[1,5].
240 | One personal intuition regarding these two variable approaches is that there exists
241 | a theoretical trade-off between the two formulations, and that the tricks above empirically alleviate the
242 | shortcomings of either approach. This, however, requires experimentation and validation before any claims can be made.
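As an illustration of the first trick above, a minimal sketch of a separated and re-weighted x0 reconstruction term (alpha is a hypothetical weighting coefficient, and MSE stands in for whatever likelihood a given work uses):

```python
import torch.nn.functional as F

def reconstruction_loss(x_hat, x, alpha=5.0):
    """Sketch of a separated x0 / x1:T likelihood; alpha up-weights the initial-frame term."""
    loss_x0 = F.mse_loss(x_hat[:, 0], x[:, 0])       # initial-frame term, weighted separately
    loss_rest = F.mse_loss(x_hat[:, 1:], x[:, 1:])   # remaining reconstruction term
    return alpha * loss_x0 + loss_rest
```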
243 |
244 |
245 |
246 | ## Reconstruction vs. Extrapolation
247 |
248 | There are three important phases during forecasting for a neural SSM: initial state inference, reconstruction,
249 | and extrapolation.
250 |
251 |

252 | Fig N. Breakdown of the three forecasting phases - initial state inference, reconstruction, and extrapolation.
253 |
254 | Initial State Inference: Inferring the initial state and how many frames are required to get a good
255 | initialization is fairly domain/problem specific, as each problem may require more or less time to highlight
256 | distinctive patterns that enable effective dynamics separation.
257 |
258 | Reconstruction: This refers to the number of timesteps used during training, over which the
259 | likelihood term is calculated. There is no generally agreed-upon standard on how many steps to use here,
260 | and works can be seen using anywhere from 1 (i.e. next-step prediction) to 60 frames in this portion[4].
261 | Some works frame this as a hyper-parameter to tune in experiments, and there is a consideration of computational cost
262 | when scaling up to longer sequences. In our experiments, we've noticed a linear scaling in training time w.r.t.
263 | this sequence length. (TO-DO) In the Experiments section, we perform an ablation study on how fixed reconstruction lengths affect the
264 | extrapolation ability of models on Hamiltonian systems.
265 |
266 | Extrapolation: This phase refers to the arbitrarily long forecasting of frames that goes beyond the length used
267 | in the likelihood term during training. It represents whether a model has captured the system dynamics sufficiently to
268 | enable long-term forecasting or model energy decay in non-conserving systems. For specific
269 | dynamical systems, this can be a difficult task as, at base, there is no training signal to inform the model to learn
270 | good extrapolation. Works often report metrics on the reconstruction and extrapolation phases independently to highlight
271 | a model's strength of identification[4].
272 |
273 | Training Considerations: It is important to note that the exact structure of how the likelihood loss is formulated
274 | plays a role in how this sequence length may affect extrapolation. Having your likelihood incorporate temporal information
275 | (e.g. summation over the sequence, trajectory mean, etc.) can have a detrimental effect on extrapolation as the model
276 | optimizes with respect to the fixed reconstruction length. Figure N highlights an example of using temporal information
277 | in a likelihood term, where there is near flawless reconstruction but immediate forecasting failure when going towards
278 | extrapolation.
279 |
280 | 
281 | Fig N. Example of failed extrapolation given an incorrect likelihood term. Red highlights the beginning of extrapolation.
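To make the distinction concrete, a small sketch contrasting a sequence-summed pixel MSE with a per-frame average, assuming tensors of shape [batch, T, H, W]:

```python
def sequence_summed_mse(x_hat, x):
    # Sums the error over the time dimension; the loss magnitude grows with the fixed
    # reconstruction length, which the model can over-fit to
    return ((x_hat - x) ** 2).mean(dim=(0, 2, 3)).sum()

def per_frame_averaged_mse(x_hat, x):
    # Averages the error per frame; invariant to the reconstruction length used in training
    return ((x_hat - x) ** 2).mean()
```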
282 |
283 | As well, it is often the case that the reconstruction training metrics (e.g. likelihood/pixel MSE) and visualizations
284 | show strong convergence despite still-poor extrapolation. It can sometimes be the case, especially with Neural ODE
285 | latent dynamics, that more training than expected is required to enable strong extrapolation. Our intuition is that
286 | the vector field may require longer optimization than reconstruction convergence alone to become
287 | robust against the error accumulation that impacts long-horizon forecasting.
288 |
289 | 
290 | Fig N. Training vs. Validation pixel MSE metrics, highlighting the continued extrapolation learning past training "convergence."
291 |
292 |
293 | Tips for training good extrapolation in these models include:
294 |
295 | - Perform extrapolation in your validation steps such that there is a metric to highlight extrapolation learning over training.
296 | - Use per-frame averages in the likelihood function rather than any form with temporal information.
297 | - Use variable lengths of reconstruction during training, sampling 1-T frames to reconstruct in a given batch.
298 | - If you have long sequences, especially in non-conserving systems, sample a random starting point per batch.
299 | - Train for longer than you might expect, even when training metrics have converged for "reconstruction."
300 | - The integrator choice can affect this, as non-symplectic integrators have known error accumulation which affects the vector state over long horizons[4].
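A minimal sketch of the variable-length and random-start tips above (hypothetical shapes; not this repo's exact training loop):

```python
import torch

def sample_training_window(x, min_recon=1, max_recon=20):
    """Sketch of variable-length reconstruction with a random start per batch.
    x: [batch, T, H, W] full trajectories; assumes max_recon <= T."""
    T = x.shape[1]
    recon_len = torch.randint(min_recon, max_recon + 1, (1,)).item()  # how many frames to reconstruct this batch
    start = torch.randint(0, T - recon_len + 1, (1,)).item()          # random starting point, useful for non-conserving systems
    return x[:, start:start + recon_len]
```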
301 |
302 |
303 |
304 |
305 | ## Use of System Controls, ut
306 |
307 | So far, we have ignored another common and important component of state-space modeling: the incorporation of external
308 | controls u that affect the transition function of the state. Controls represent factors that influence the
309 | trajectory of a system but are not direct features of the object/system being modeled. For example, an external
310 | force such as friction acting on a moving ball or medications given to a patient could be considered
311 | controls[8,14]. These allow an additional layer of interpretability in SSMs and even enable counterfactual
312 | reasoning; i.e., given the current state, what does its trajectory look like under varying control inputs going
313 | forwards? This has myriad uses in medical modeling with counterfactual medicine[14] or physical system
314 | simulations[8].
315 |
316 |
317 | For Neural SSMs, a variety of approaches have been taken thus far depending on the type of latent transition function used.
318 |
319 |
320 | Linear Dynamics: In latent dynamics still parameterized by traditional linear Gaussian transition functions,
321 | control incorporation is as easy as the addition of another transition matrix Bt that acts on the
322 | control input ut at each timestep[1,2,4,7].
323 |
324 | 
325 | Fig N. Example of control input in a linear transition function[1].
326 |
327 |
328 | Non-Linear Dynamics: In discrete non-linear transition functions using either multi-layer perceptrons or
329 | recurrent cells, controls can be incorporated by either concatenating them to the input vector before the network forward
330 | pass or applying a data transformation in the form of element-wise addition or a weighted combination[10].
331 |
332 | 
333 | Fig N. Example of control input in a non-linear transition function[1].
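A minimal sketch of the concatenation approach for a discrete non-linear transition (hypothetical dimensions; not one of the models implemented in this repo):

```python
import torch
import torch.nn as nn

class ControlledTransition(nn.Module):
    """Sketch of control incorporation by concatenating u_t to the latent state."""
    def __init__(self, latent_dim=8, control_dim=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + control_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, latent_dim),
        )

    def forward(self, z_t, u_t):
        # Concatenate the control to the latent state before the network forward pass
        return self.net(torch.cat([z_t, u_t], dim=-1))
```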
334 |
335 |
336 | Continuous Dynamics: For incorporation into continuous latent dynamics functions, finding the best approaches
337 | is an ongoing topic of interest. Thus far, the reigning approaches are:
338 |
339 | 1. Directly jumping the vector field state with recurrent cells[18]
340 | 
341 |
342 | 2. Influencing the vector field gradient (e.g. neural controlled differential equations)[17]
343 | 
344 |
345 | 3. Introducing another dynamics mechanism, continuous or otherwise (e.g. neural ODE or attention blocks), that is combined with the latent trajectory z1:T into an auxiliary state h1:T[8,14,25].
346 | 
347 | Fig N. Visualization of the IMODE architecture, taken from [8].
348 |
349 |
350 |
352 |
353 |
354 |
358 |
359 |
360 |
361 |
362 | # Implementation
363 |
364 | In this section, specifics on model implementation and the datasets/metrics used are detailed. Specific data generation details are available in the URLs provided for each dataset. The datasets used throughout this repo are solely grayscale physics datasets with underlying Hamiltonian laws, such as pendulum and mass-spring sets. Extensions to color images and non-pixel-based tasks (or even graph-based data!) are easily done in this framework, as the only architectural change needed is the structure of the encoder and decoder networks, since state propagation happens solely in a latent space.
365 |
366 | The project's folder structure is as follows:
367 |
368 | ```
369 | torchssm/
370 | │
371 | ├── main.py # Training entry point that takes in user args and handles boilerplate
372 | ├── tune.py # Performs a hyperparameter search for a given dataset using Ray[Tune]
373 | ├── README.md # What you're reading right now :^)
374 | ├── requirements.txt # Anaconda requirements file to enable easy setup
375 | ├── configs/
376 | │ ├── dataset/ # Dataset config files
377 | │ ├── model/ # Model-specific hyperparameters
378 | │ ├── training/ # Training-setting parameters
379 | | └── config.yaml # Base config file
380 | ├── data/
381 | | ├── # Name of the stored dynamics dataset (e.g. pendulum)
382 | | ├── generate_bouncingball.py # Dataset generation script for Bouncing Ball
383 | | ├── generate_hamiltonian.py # Dataset generation script for Hamiltonian Dynamics
384 | | └── visualize_data.py # Visualize trajectories of a dataset
385 | ├── experiments/
386 | | └── # Name of the dynamics model run
387 | | └── # Given name for the ran experiment
388 | | └── / # Each experiment type has its sequential lightning logs saved
389 | ├── models/
390 | │ ├── CommonDynamics.py # Abstract PyTorch-Lightning Module to handle train/test loops
391 | │ ├── CommonVAE.py # Shared encoder/decoder Modules for the VAE portion
392 | │ ├── group_a/
393 | │ └── ... # State-Estimation Examples
394 | │ └── group_b1/
395 | │ └── ... # Time-Varying System-Identification Examples
396 | │ └── group_b2/
397 | │ └── ... # Time-Invariant System-Identification Examples
398 | ├── utils/
399 | │ ├── dataloader.py # Dataloaders used in train/val/testing
400 | │ ├── layers.py # PyTorch Modules that represent general network layers
401 | │ ├── metrics.py # Metric functions for evaluation
402 | │ ├── plotting.py # Plotting functions for visualizatin
403 | | └── utils.py # General utility functions (e.g. argparsing, etc)
404 | └──
405 | ```
406 |
407 |
408 |
409 | ## Config Management with Hydra
410 |
411 | This repository uses [Hydra](https://hydra.cc/) to manage configurations for running experiments, ensuring modularity, scalability, and ease of experimentation. Hydra allows for a hierarchical organization of configuration files, making it straightforward to adjust parameters for datasets, models, and training setups without editing the core code.
412 |
413 | The configuration files are organized under the `configs/` directory and follow a modular design:
414 |
415 | - **`dataset/`**: Contains configuration files specific to datasets, such as paths, preprocessing steps, and dataset-specific parameters. For example, `pendulum.yaml` might specify the length of the pendulum and time step resolution.
416 | - **`model/`**: Defines model-specific parameters like architecture, latent dimension size, and dynamics function hyperparameters.
417 | - **`training/`**: Manages training-specific settings, such as learning rate, batch size, number of epochs, and checkpointing options.
418 | - **`config.yaml`**: The base configuration file that aggregates default values and sets up common parameters shared across experiments.
419 |
420 | Usage: Hydra facilitates experiment customization and parameter sweeping with ease. By specifying the desired configuration components, users can dynamically compose configurations at runtime. For instance:
421 |
422 | ```bash
423 | python main.py dataset=pendulum model=node training=default
424 | ```
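Individual values can also be overridden at the command line. Assuming the keys compose as laid out in the `configs/` files below (e.g. the architecture group nesting under `model.architecture`), an override might look like:

```bash
python main.py dataset=pendulum model=node training=forecast_recon \
    model.architecture.latent_dim=16 training.batch_size=32 seed=1234
```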
425 |
426 |
427 |
428 |
429 | ## Data
430 |
431 | All data used throughout these experiments are available for download here on Google Drive, already in their .npz forms. Feel free to generate your own sets using the provided data scripts!
432 |
433 | Hamiltonian Dynamics: Provided are a dataloader and generation scripts for DeepMind's Hamiltonian Dynamics
434 | suite[4], a simulation library for 17 different physics datasets that have known underlying Hamiltonian dynamics.
435 | It comes in the form of color image sequences of arbitrary length, coupled with the system's ground truth states (e.g., for pendulum, angular velocity and angle). It is well-benchmarked and customizable, making it a perfect testbed for latent dynamics function evaluation. For each setting, the physical parameters are tweakable alongside an optional friction coefficient to construct non-energy conserving systems. The location of focal points and the color of the objects are all individually tuneable, enabling mixed and complex visual datasets of varying latent dynamics.
436 |
437 | 
438 | Fig N. Pendulum-Colors Examples.
439 |
440 | For the base experiments presented on this dataset, we consider and evaluate grayscale versions of pendulum and
441 | mass-spring - which conveniently are just the sliced red channel of the original sets. Each set has 10000
442 | training and 1000 testing trajectories sampled at Δt = 1 time intervals. Energy conservation
443 | is preserved without friction and we assume constant placement of focal points for simplicity. Note that the
444 | modification to color outputs in this framework is as simple as modifying the number of channels in the
445 | encoder and decoder.
446 |
447 |
448 | Bouncing Balls: Additionally, we provide a dataloader and generation scripts for the standard latent dynamics
449 | dataset of bouncing balls[1,2,5,7,8], modified from the implementation in
450 | [1]. It consists of a ball or multiple
451 | balls moving within a bounding box while being affected by potential external effects, e.g. gravitational
452 | forces[1,2,5], pong[2], and interventions[8]. The starting position, angle,
453 | and velocity of the ball(s) are sampled uniformly within a set range. It is generated with the
454 | PyMunk and PyGame libraries.
455 | In this repository, we consider two sets - a simple set of one gravitational force and a mixed set of 4 gravitational
456 | forces in the cardinal directions with varying strengths. We similarly generate 10000 training and
457 | 1000 testing trajectories, however sample them at Δt = 0.1 intervals.
458 |
459 | 
460 | Fig N. Single Gravity Bouncing Ball Example.
461 |
462 |
463 | Notably, this system is surprisingly difficult to perform successful long-term generation on, especially in cases
464 | of mixed gravities or multiple objects. Most works only report generation within 5-15 timesteps following a
465 | period of 3-5 observation timesteps[1,2,7], and farther horizons show lost trajectories and/or
466 | incoherent reconstructions.
467 |
468 |
469 | Meta-Learning Datasets: One of the latest research directions for neural SSMs is evaluating the potential of
470 | meta-learning to build domain-adaptable latent dynamics functions[26,27,29]. A representative dataset
471 | example for this task is the Turbulent Flow dataset that is affected by various buoyancy forces, highlighting a
472 | task with partially shared yet heterogeneous dynamics[27].
473 |
474 | 
475 | Fig N. Turbulent Flow Example, sourced from [27].
476 |
477 |
478 | Multi-System Dynamics: So far in the literature, the majority of works only consider training Neural SSMs on
479 | one system of dynamics at a time - with variety coming mostly from differing trajectory hyper-parameters.
480 | The ability to infer multiple dynamical systems under one model (or learn to output dynamical functions given
481 | system observations) and leverage similarities between the sets is an ongoing research pursuit - with applications
482 | of neural unit hypernetworks[27] and dynamics functions conditioned on sequences via
483 | meta-learning[26,29] being the first dives into this.
484 |
485 |
486 | Other Sets in Literature: Outside of the previous sets, there are a plethora of other datasets that have been
487 | explored in relevant work. The popular task of human motion prediction in both the pose estimation and video
488 | generation setting has been considered via datasets
489 | Human3.6M,
490 | CMU Mocap, and
491 | Weizmann-Action[5,19],
492 | though proper experimentation into this area would require problem-specific architectures given the depth of the
493 | existing field. Past high-dimensionality and image-space, standard benchmark datasets in time-series forecasting
494 | have also been considered, including the M4,
495 | Electricity Transformer Temperature (ETT), and
496 | the NASA Turbofan Degradation set.
497 | Recent works have begun looking at medical applications in inverse image reconstruction and the incorporation of
498 | physics-inspired priors[20,29]. Regardless of the success of Neural SSMs on these tasks, the task-agnostic
499 | nature and principled structure of this framework make it a versatile and appealing option for generative time-series modeling.
500 |
501 |
502 |
503 |
504 | ## Models
505 |
506 | Here, details on how the model implementation is structured and running experiments locally are given. As well,
507 | an overview of the abstract class implementation for a general Neural SSM and its types are explained.
508 |
509 | ### Implementation Structure
510 | Provided within this repository is a PyTorch class structure in which an abstract PyTorch-Lightning Module is shared
511 | across all the given models, from which the specific VAE and dynamics functions inherit and override the relevant
512 | forward functions for training and evaluation. Swapping between dynamics functions and PGM type is as easy as passing
513 | in the model's name for arguments, e.g. `python3 main.py model=node`. As the implementation is provided in
514 | PyTorch-Lightning, an optimization and boilerplate
515 | library for PyTorch, it is recommended to be familiar with it at a basic level.
516 |
517 |
518 | For every model run, a new experiment version is created under `experiments/` based on the passed-in naming arguments. Hyperparameters passed in for the run are stored both in the created Tensorboard instance and in the local files hparams.yaml and config.json. During training and validation, all of the metrics below are automatically tracked and saved into a Tensorboard instance
519 | which can be used to compare different model runs afterwards. Every 500 batches, reconstruction sequences against the
520 | ground truth for a set of samples are saved to the `experiments/` folder. Currently, only one checkpoint is saved, based
521 | on the last epoch ran, rather than checkpoints based on the best validation score or at set epochs. Restarting training
522 | from a checkpoint or loading in a model for testing is currently done by specifying the ckpt_path to the
523 | base experiment folder and the checkpt filename.
524 |
525 |
526 | The implemented dynamics functions are each separated into their respective PGM groups; however, they still share
527 | the same general classes. Each dynamics subclass has a model_specific_loss function that allows it to
528 | return additional loss values without interrupting the abstract flow. For example, this could be used in a flow-based
529 | prior that has additional KL terms over the ODE flow density without needing to override the training_step
530 | function with a duplicate copy. As well, there is a model_specific_plotting function to enable custom
531 | plots at the end of every training epoch.
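A rough sketch of this pattern (class and attribute names here are assumed for illustration and may differ from models/CommonDynamics.py; optimizers and dataloaders are omitted):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class CommonDynamicsSketch(pl.LightningModule):
    """Rough stand-in for the shared abstract Lightning module (names assumed)."""
    def model_specific_loss(self, x, x_hat):
        return 0.0  # default: no additional loss terms

    def training_step(self, batch, batch_idx):
        x_hat = self(batch)
        # The shared loop simply adds whatever extra terms the subclass reports
        return F.mse_loss(x_hat, batch) + self.model_specific_loss(batch, x_hat)

class FlowPriorDynamics(CommonDynamicsSketch):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 8)   # placeholder dynamics network
        self.flow_kl = torch.tensor(0.0)   # e.g. computed during forward in a flow-based prior

    def forward(self, x):
        return self.net(x)

    def model_specific_loss(self, x, x_hat):
        # Additional KL term over the ODE flow density, without touching training_step
        return self.flow_kl
```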
532 |
533 | ### Implemented Dynamics
534 |
535 | Group A, 'State-Estimation': For the Group A PGM category, we provide a reimplementation of the classic Neural SSM work Deep Kalman Filter[7] alongside state estimation versions of the above, provided in Fig. N below. The DKF model modifies the standard Kalman Filter Gaussian transition function to incorporate non-linearity and expressivity by parameterizing the distribution parameters with neural networks, zt ∼ N(G(zt−1, ∆t), S(zt−1, ∆t))[7].
536 | Additionally, we provide a reimplementation of the Variational Recurrent Neural
537 | Network (VRNN), one of the starting state estimation works in Neural SSMs[22]. For the latent correction
538 | step, we leverage a standard Gated Recurrent Unit (GRU) cell and the corrected latent state is what is passed to the
539 | decoder and likelihood function. Notably, there are two settings these models can be run under: reconstruction
540 | and generation. Reconstruction is used for training and incorporates ground truth observations to correct
541 | the latent state while generation is used to test the forecasting abilities of the model, both short- and long-term.
542 |
543 |
544 | 
545 | Fig N. Model schematics for Group A's implemented dynamics functions.
546 |
547 |
548 | Group B1, 'Time-Varying System-Identification': For the Group B1 PGM category, we present a reimplementation of the Kalman Variational Autoencoder (KVAE)[2], a hybrid model combining Kalman Filter dynamics with a deep recognition network. The KVAE disentangles latent states into dynamics and representations by using a structured transition model paired with a learned recognition function. Specifically, the latent dynamics are modeled as xt ∼ N(A xt−1 + B ut, Q), where A, B, and Q parameterize a linear dynamical system, while the recognition network maps observations into latent space.
549 | We provide both state estimation and generation capabilities. State estimation incorporates observed data into the inference process at every timestep. For generation, the model tests its forecasting ability in both short- and long-term scenarios by propagating dynamics without direct observation, instead using prior reconstructions as input.
550 |
551 |
552 | 
553 | Fig N. Model schematics for Group B1's implemented dynamics functions.
554 |
555 |
556 | Group B2, 'Time-Invariant System-Identification': For the system identification models, we provide a variety of dynamics functions that resemble the general and special
557 | cases detailed above, which are provided in Fig N. below. The most general version is that of the Bayesian Neural ODE,
558 | in which a neural ordinary differential equation[21] is sampled from a set of optimized distributional
559 | parameters and used as the latent dynamics function, z't = fp(θ)(zs)[5]. A deterministic version
560 | of a standard Neural ODE is similarly provided, e.g. z't = fθ(zs)[10,21]. Following that, two forms of a
563 | Recurrent Generative Network are provided, a residual version (RGN-Res) and a full-step version (RGN), that represent
564 | deterministic and discrete non-linear transition functions. RGN-Res is the equivalent of a Neural ODE using a fixed
565 | step Euler integrator while RGN is just a recurrent forward step function.
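A minimal sketch of this distinction, with an assumed MLP transition network:

```python
import torch.nn as nn

f = nn.Sequential(nn.Linear(8, 128), nn.SiLU(), nn.Linear(128, 8))  # shared transition network

def rgn_step(z_t):
    # RGN: the network output *is* the next latent state
    return f(z_t)

def rgn_res_step(z_t, dt=1.0):
    # RGN-Res: the network output is a residual update, equivalent to one fixed-step
    # Euler integration of a Neural ODE dz/dt = f(z)
    return z_t + dt * f(z_t)
```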
566 |
567 |
568 | Training for these models has one mode: taking in several observational frames to infer z0
569 | and then outputting a full sequence autonomously without access to subsequent observations. A likelihood is
570 | computed over the full reconstructed sequence and optimized. Testing and generation in this setting can easily be done
571 | out to any horizon.
572 |
573 |
574 | 
575 | Fig N. Model schematics for system identification's implemented dynamics functions.
576 |
577 |
578 |
579 |
580 |
581 |
582 | ## Metrics
583 |
584 | Mean Squared Error (MSE): A common metric used in video and image tasks, here used as a per-frame average over individual pixel errors. While a multitude of papers solely use plots of frame MSE over time as an evaluation metric, it is insufficient for comparison between models - especially in cases where the dataset contains a small object for reconstruction[4]. This is especially prominent in tasks of system identification where a model that fails to predict long-term may end up with a lower average MSE than a model that has better generation but is slightly off in its object placement.
585 |
586 | 
587 | Fig N. Per-Frame MSE Equation.
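A small sketch of the per-frame MSE computation, assuming predictions and ground truth of shape [batch, T, H, W]:

```python
def per_frame_mse(x_hat, x):
    """Per-frame pixel MSE; returns a [batch, T] tensor of per-frame errors."""
    return ((x_hat - x) ** 2).mean(dim=(2, 3))
```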
588 |
589 | Valid Prediction Time (VPT): Introduced in [4], the VPT metric is an advance on latent dynamics evaluation over pure pixel-based MSE metrics. For each prediction sequence, the per-pixel MSE is taken over the frames individually, and the minimum timestep in which the MSE surpasses a pre-defined epsilon is considered the 'valid prediction time.' The resulting mean number over the samples is often normalized over the total prediction timesteps to get a percentage of valid predictions.
590 |
591 | 
592 | Fig N. Per-Sequence VPT Equation.
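A sketch of how VPT could be computed under the same assumed shapes (epsilon here is a placeholder threshold; [4] defines its own):

```python
import torch

def valid_prediction_time(x_hat, x, epsilon=0.02):
    """VPT sketch: fraction of the horizon before the per-frame MSE first exceeds epsilon."""
    mse = ((x_hat - x) ** 2).mean(dim=(2, 3))                              # [batch, T]
    exceeded = mse > epsilon                                               # [batch, T] booleans
    horizon = torch.full_like(exceeded[:, 0], mse.shape[1], dtype=torch.long)
    first = torch.where(exceeded.any(dim=1), exceeded.float().argmax(dim=1), horizon)
    return (first.float() / mse.shape[1]).mean()                           # mean valid fraction over samples
```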
593 |
594 | Object Distance (DST): Another potential metric to support evaluation (useful in image-based physics forecasting tasks) is using the Euclidean distance between the estimated center of the predicted object and its ground truth center. Otsu's Thresholding method can be applied to grayscale output images to get binary predictions of each pixel and then the average pixel location of all the "active" pixels can be calculated. This approach can help alleviate the prior MSE issues of metric imbalance as the maximum Euclidean error of a given image space can be applied to model predictions that fail to have any pixels over Otsu's threshold.
595 |
596 | 
597 | Fig N. Per-Frame DST Equation.
598 | where RN is the dimension of the output (e.g. number of image channels) and s, shat are the subsets of "active" predicted pixels.
599 |
600 |
601 | Valid Prediction Distance (VPD): Similar in spirit to how VPT leverages MSE, VPD is the minimum timestep in which the DST metric surpasses a pre-defined epsilon[29]. This is useful in tracking how long a model can generate an object in a physical system before either incorrect trajectories and/or error accumulation cause significant divergence.
602 |
603 | 
604 | Fig N. Per-Sequence VPD Equation.
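A combined sketch of DST and VPD for single-object grayscale frames, using Otsu's method from scikit-image (the epsilon and maximum-distance handling here are assumptions, not the exact formulation of [29]):

```python
import numpy as np
from skimage.filters import threshold_otsu

def object_distance(pred_frame, true_frame):
    """DST sketch: distance between estimated object centers of predicted and true grayscale frames."""
    max_dist = float(np.linalg.norm(pred_frame.shape))   # penalty: diagonal of the image space
    centers = []
    for frame in (pred_frame, true_frame):
        active = frame > threshold_otsu(frame)           # binarize pixels via Otsu's method
        if not active.any():
            return max_dist                               # no active pixels -> maximum error
        centers.append(np.argwhere(active).mean(axis=0))  # mean location of the "active" pixels
    return float(np.linalg.norm(centers[0] - centers[1]))

def valid_prediction_distance(pred_seq, true_seq, epsilon=2.0):
    """VPD sketch: fraction of the horizon before DST first exceeds epsilon (placeholder threshold)."""
    for t in range(pred_seq.shape[0]):
        if object_distance(pred_seq[t], true_seq[t]) > epsilon:
            return t / pred_seq.shape[0]
    return 1.0
```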
605 |
606 |
607 | R2 Score: For evaluating systems where the full underlying latent system is available and known (e.g. image translations of Hamiltonian systems), the goodness-of-fit score R2 can be used per dimension to show how well the latent system of the Neural SSM captures the dynamics in an interpretable way[1,3]. This is easiest to leverage in linear transition dynamics.
608 |
609 | Ref. [1], while containing linear transition dynamics, mentioned the possibility of non-linear regression via vanilla neural networks, though this may run into concerns of regressor capacity and data sizes. Additionally, incorporating metrics derived from latent disentanglement learning may provide stronger evaluation capabilities.
610 |
611 | 
612 | Fig N. DVBF Latent Space Visualization for R2 scores, sourced from [1,3].
613 |
614 |
615 |
616 |
617 | # Experiments
618 |
619 | This section details some experiments that evaluate the fundamental aspects of Neural SSMs and the effects of the
620 | framework decisions one can take. Trained model checkpoints and hyperparameter files are provided for each experiment
621 | under experiments/model. Evaluations are done with the metrics discussed above, as well as visualizations of
622 | animated trajectories over time and latent walk visualizations.
623 |
624 |
625 |
626 |
627 | ## Pendulum Dynamics
628 |
629 | Here we report the results of tuning each of the models on the Hamiltonian physics dataset Pendulum. For each model,
630 | we highlight its best-performing hyperparameters with respect to the validation extrapolation MSE. Going forwards,
631 | these hyperparameters will be used in experiments of similar complexity.
632 |
633 | We test two environments for the Pendulum dataset, a fixed-point one-color pendulum and a multi-point multi-color
634 | pendulum set of increased complexity. As described in [4], each individual sequence is sampled from a uniform
635 | distribution over physical parameters like mass, gravity, and pivot length.
636 | We describe data generation above in the Data section.
637 |
638 | [TO-DO: Click to show the results for Fixed-Point Pendulum]
639 | Coming soon.
640 |
641 |
642 |
643 |
644 | [TO-DO: Click to show the results for Multi-Point Pendulum tuning]
645 | Coming soon.
646 |
647 |
648 |
649 |
650 |
651 | ## Bouncing Ball Dynamics
652 |
653 | Similar to above, we highlight the results and hyperparameters of each model for the Bouncing Ball dataset.
654 |
655 | [TO-DO: Click to show the results for Bouncing Ball tuning]
656 | Coming soon.
657 |
658 |
659 |
660 |
661 | ## Ablation Studies
662 |
663 | ODE Solvers: To measure the impact that ODE Solvers have on the optimized dynamics models, we performed an ablation on the available
664 | solvers that exist within the torchdiffeq library, including both fixed and adaptive solvers. We make note of
665 | their respective training times due to increased solution complexity and train each ODE solver over a variety of parameters
666 | depending on their type (e.g. step size or solution tolerances).
667 |
668 | [Click to show the results for the ODE Solver ablation]
669 | Coming soon.
670 |
671 |
672 |
673 |
674 | # Miscellaneous
675 |
676 | This section just consists of to-do's within the repo, contribution guidelines,
677 | and a section on how to find the references used throughout the repo.
678 |
679 |
680 |
681 | ## To-Do
682 |
683 | Ablations-to-do
684 |
685 | - Generation lengths used in training (e.g. 1/2/3/5/10 frames)
686 | - Fixed vs variable generation lengths
687 | - z0 inference scheme (no overlap, overlap-by-one, full overlap)
688 | - Use of ODE solvers (fixed, adaptive, tolerances)
689 | - Different forms of learning rate schedulers
690 | - Linear versus CNN decoder
691 | - Activation functions in the latent dynamics function
692 |
693 | Repository-wise
694 |
695 | Model-wise
696 |
697 | Evaluation-wise
698 |
699 | - Implement latent walk visualizations against data-space observations (like in DVBF)
700 |
701 | README-wise
702 |
703 | - Add guidelines for an ```Experiment``` section highlighting experiments performed in validating the models
704 | - Add a section explaining ```Meta-Learning``` works in Neural SSMs
705 | - Add a section explaining ```ODE Integrator``` considerations in Neural SSMs
706 |
707 |
708 |
709 | ## Contributions
710 | Contributions are welcome and encouraged! If you have an implementation of a latent dynamics function you think
711 | would be relevant and add to the conversation, feel free to submit an Issue or PR and we can discuss its
712 | incorporation. Similarly, if you feel an area of the README is lacking or contains errors, please put up a
713 | README editing PR with your suggested updates. Even tackling items on the To-Do would be massively helpful!
714 |
715 |
716 |
717 | ## References
718 | 1. Maximilian Karl, Maximilian Soelch, Justin Bayer, and Patrick van der Smagt. Deep variational bayes filters: Unsupervised learning of state space models from raw data. In International Conference on Learning Representations, 2017.
719 | 2. Marco Fraccaro, Simon Kamronn, Ulrich Paquet, and Ole Winther. A disentangled recognition and nonlinear dynamics model for unsupervised learning. In Advances in Neural Information Processing Systems, 2017.
720 | 3. Alexej Klushyn, Richard Kurle, Maximilian Soelch, Botond Cseke, and Patrick van der Smagt. Latent matters: Learning deep state-space models. Advances in Neural Information Processing Systems, 34, 2021.
721 | 4. Aleksandar Botev, Andrew Jaegle, Peter Wirnsberger, Daniel Hennes, and Irina Higgins. Which priors matter? benchmarking models for learning latent dynamics. In Advances in Neural Information Processing Systems, 2021.
722 | 5. C. Yildiz, M. Heinonen, and H. Lahdesmaki. ODE2VAE: Deep generative second order odes with bayesian neural networks. In Neural Information Processing Systems, 2020.
723 | 6. Batuhan Koyuncu. Analysis of ode2vae with examples. arXiv preprint arXiv:2108.04899, 2021.
724 | 7. Rahul G. Krishnan, Uri Shalit, and David Sontag. Structured inference networks for nonlinear state space models. In Association for the Advancement of Artificial Intelligence, 2017.
725 | 8. Daehoon Gwak, Gyuhyeon Sim, Michael Poli, Stefano Massaroli, Jaegul Choo, and Edward Choi. Neural ordinary differential equations for intervention modeling. arXiv preprint arXiv:2010.08304, 2020.
726 | 9. Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. In Advances in Neural Information Processing Systems, 2015.
727 | 10. Yulia Rubanova, Ricky T. Q. Chen, and David Duvenaud. Latent odes for irregularly-sampled time series. In Neural Information Processing Systems, 2019.
728 | 11. Tsuyoshi Ishizone, Tomoyuki Higuchi, and Kazuyuki Nakamura. Ensemble kalman variational objectives: Nonlinear latent trajectory inference with a hybrid of variational inference and ensemble kalman filter. arXiv preprint arXiv:2010.08729, 2020.
729 | 12. Justin Bayer, Maximilian Soelch, Atanas Mirchev, Baris Kayalibay, and Patrick van der Smagt. Mind the gap when conditioning amortised inference in sequential latent-variable models. arXiv preprint arXiv:2101.07046, 2021.
730 | 13. Đorđe Miladinović, Muhammad Waleed Gondal, Bernhard Schölkopf, Joachim M Buhmann, and Stefan Bauer. Disentangled state space representations. arXiv preprint arXiv:1906.03255, 2019.
731 | 14. Zeshan Hussain, Rahul G. Krishnan, and David Sontag. Neural pharmacodynamic state space modeling, 2021.
732 | 15. Francesco Paolo Casale, Adrian Dalca, Luca Saglietti, Jennifer Listgarten, and Nicolo Fusi. Gaussian process prior variational autoencoders. Advances in neural information processing systems, 31, 2018.
733 | 16. Yingzhen Li and Stephan Mandt. Disentangled sequential autoencoder. arXiv preprint arXiv:1803.02991, 2018.
734 | 17. Patrick Kidger, James Morrill, James Foster, and Terry Lyons. Neural controlled differential equations for irregular time series. Advances in Neural Information Processing Systems, 33:6696-6707, 2020.
735 | 18. Edward De Brouwer, Jaak Simm, Adam Arany, and Yves Moreau. Gru-ode-bayes: Continuous modeling of sporadically-observed time series. Advances in neural information processing systems, 32, 2019.
736 | 19. Ruben Villegas, Jimei Yang, Yuliang Zou, Sungryull Sohn, Xunyu Lin, and Honglak Lee. Learning to generate long-term future via hierarchical prediction. In International Conference on Machine Learning, pages 3560–3569. PMLR, 2017.
737 | 20. Xiajun Jiang, Ryan Missel, Maryam Toloubidokhti, Zhiyuan Li, Omar Gharbia, John L Sapp, and Linwei Wang. Label-free physics-informed image sequence reconstruction with disentangled spatial-temporal modeling. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 361–371. Springer, 2021.
738 | 21. Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018.
739 | 22. Junyoung Chung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C Courville, and Yoshua Bengio. A recurrent latent variable model for sequential data. Advances in neural information processing systems, 28, 2015.
740 | 23. Junbo Zhang, Yu Zheng, and Dekang Qi. Deep spatio-temporal residual networks for citywide crowd flows prediction. In Thirty-first AAAI conference on artificial intelligence, 2017.
741 | 24. Yong Han, Shukang Wang, Yibin Ren, Cheng Wang, Peng Gao, and Ge Chen. Predicting station-level short-term passenger flow in a citywide metro network using spatiotemporal graph convolutional neural networks. ISPRS International Journal of Geo-Information, 8(6):243, 2019.
742 | 25. Maryam Toloubidokhti, Ryan Missel, Xiajun Jiang, Niels Otani, and Linwei Wang. Neural state-space modeling with latent causal-effect disentanglement. In International Workshop on Machine Learning in Medical Imaging, 2022.
743 | 26. Matthieu Kirchmeyer, Yuan Yin, Jérémie Donà, Nicolas Baskiotis, Alain Rakotomamonjy, and Patrick Gallinari. Generalizing to new physical systems via context-informed dynamics model. arXiv preprint arXiv:2202.01889, 2022.
744 | 27. Rui Wang, Robin Walters, and Rose Yu. Meta-learning dynamics forecasting using task inference. arXiv preprint arXiv:2102.10271, 2021.
745 | 28. Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
746 | 29. Xiajun Jiang, Zhiyuan Li, Ryan Missel, Md Shakil Zaman, Brian Zenger, Wilson W Good, Rob S MacLeod, John L Sapp, and Linwei Wang. Few-shot generation of personalized neural surrogates for cardiac simulation via bayesian meta-learning. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 46–56. Springer, 2022.
747 | 30. Marco Forgione, Manas Mejari, and Dario Piga. Learning neural state-space models: do we need a state estimator? arXiv preprint arXiv:2206.12928, 2022.
748 |
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model: node
4 | - dataset: bouncingball
5 | - training: forecast_forecast
6 |
7 | hydra:
8 | output_subdir: null
9 | run:
10 | dir: .
11 |
12 | # Random seed of the run
13 | seed: 125125125
14 | devices: [0]
15 |
16 | # Experiment folder naming
17 | expname: ${dataset.dataset}_${model.model_type}_${training.z_amort_train}ztrain_${training.z_amort_test}ztest_${training.num_steps}iterations_${seed}seed
18 | model_path: ""
19 | checkpt: ""
20 |
21 | # For training, overrideable by cmd
22 | train: true
23 | resume: false
24 |
--------------------------------------------------------------------------------
/configs/dataset/bouncingball.yaml:
--------------------------------------------------------------------------------
1 | dataset: bouncingball
2 | dataset_percent: 1.0
3 | batches_to_save: 50
4 |
--------------------------------------------------------------------------------
/configs/model/architecture/default.yaml:
--------------------------------------------------------------------------------
1 | # Dimension and channels of the input image
2 | dim: 32
3 | num_channels: 1
4 |
5 | # Dynamics MLP function parameters
6 | num_layers: 2
7 | num_hidden: 128
8 | latent_dim: 8
9 | latent_act: swish
10 |
11 | # VAE parameters
12 | num_filters: 24
13 |
--------------------------------------------------------------------------------
/configs/model/dkf.yaml:
--------------------------------------------------------------------------------
1 | model_type: dkf
2 | stochastic: true
3 |
4 | defaults:
5 | - architecture: default
--------------------------------------------------------------------------------
/configs/model/kvae.yaml:
--------------------------------------------------------------------------------
1 | model_type: kvae
2 | stochastic: true
3 |
4 | defaults:
5 | - architecture: default
--------------------------------------------------------------------------------
/configs/model/node.yaml:
--------------------------------------------------------------------------------
1 | model_type: node
2 | stochastic: true
3 |
4 | integrator: rk4
5 | integrator_step_size: 0.5
6 |
7 | defaults:
8 | - architecture: default
--------------------------------------------------------------------------------
/configs/model/rgn.yaml:
--------------------------------------------------------------------------------
1 | model_type: rgn
2 | stochastic: true
3 |
4 | defaults:
5 | - architecture: default
--------------------------------------------------------------------------------
/configs/model/rgnres.yaml:
--------------------------------------------------------------------------------
1 | model_type: rgnres
2 | stochastic: true
3 |
4 | defaults:
5 | - architecture: default
--------------------------------------------------------------------------------
/configs/model/vrnn.yaml:
--------------------------------------------------------------------------------
1 | model_type: vrnn
2 | stochastic: true
3 |
4 | defaults:
5 | - architecture: default
--------------------------------------------------------------------------------
/configs/training/forecast_forecast.yaml:
--------------------------------------------------------------------------------
1 | # PyTorch-Lightning hardware params
2 | accelerator: gpu
3 | num_workers: 0
4 |
5 | # Which GPU ID to run on
6 | devices: [0]
7 |
8 | # Number of steps per task
9 | num_steps: 50001
10 |
11 | # How often to log metrics and how often to save image reconstructions
12 | metric_interval: 50
13 | image_interval: 500
14 |
15 | # What metrics to evaluate on
16 | metrics:
17 | - vpt
18 | - reconstruction_mse
19 |
20 | test_metrics:
21 | - vpt
22 | - dst
23 | - vpd
24 | - reconstruction_mse
25 | - extrapolation_mse
26 |
27 | # Batch size
28 | batch_size: 64
29 |
30 | # Learning rate and cosine annealing scheduler
31 | # We use CosineAnnealing with WarmRestarts
32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup
33 | learning_rate: 1e-3
34 | scheduler_use: true
35 | scheduler_restart_interval: 5000
36 | scheduler_warmup_steps: 200
37 | scheduler_decay: 0.90
38 |
39 | # KL loss betas
40 | beta_z0: 1e-2
41 | beta_kl: 1e-3
42 |
43 | # How many steps are given as observed data
44 | # For forecasting, this will be small (e.g., 3 to 5)
45 | # For reconstruction, this will be the train_length
46 | z_amort_train: 5
47 | z_amort_test: 5
48 |
49 | # Total steps to either reconstruct or forecast for
50 | train_length: 20
51 | val_length: 20
52 | test_length: 20
--------------------------------------------------------------------------------
/configs/training/forecast_recon.yaml:
--------------------------------------------------------------------------------
1 | # PyTorch-Lightning hardware params
2 | accelerator: gpu
3 | num_workers: 0
4 |
5 | # Which GPU ID to run on
6 | devices: [0]
7 |
8 | # Number of steps per task
9 | num_steps: 50001
10 |
11 | # How often to log metrics and how often to save image reconstructions
12 | metric_interval: 50
13 | image_interval: 500
14 |
15 | # What metrics to evaluate on
16 | metrics:
17 | - vpt
18 | - reconstruction_mse
19 |
20 | test_metrics:
21 | - vpt
22 | - dst
23 | - vpd
24 | - reconstruction_mse
25 | - extrapolation_mse
26 |
27 | # Batch size
28 | batch_size: 64
29 |
30 | # Learning rate and cosine annealing scheduler
31 | # We use CosineAnnealing with WarmRestarts
32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup
33 | learning_rate: 1e-3
34 | scheduler_use: true
35 | scheduler_restart_interval: 5000
36 | scheduler_warmup_steps: 200
37 | scheduler_decay: 0.90
38 |
39 | # KL loss betas
40 | beta_z0: 1e-2
41 | beta_kl: 1e-3
42 |
43 | # How many steps are given as observed data
44 | # For forecasting, this will be small (e.g., 3 to 5)
45 | # For reconstruction, this will be the train_length
46 | z_amort_train: 5
47 | z_amort_test: 20
48 |
49 | # Total steps to either reconstruct or forecast for
50 | train_length: 20
51 | val_length: 20
52 | test_length: 20
--------------------------------------------------------------------------------
/configs/training/recon_forecast.yaml:
--------------------------------------------------------------------------------
1 | # PyTorch-Lightning hardware params
2 | accelerator: gpu
3 | num_workers: 0
4 |
5 | # Which GPU ID to run on
6 | devices: [0]
7 |
8 | # Number of steps per task
9 | num_steps: 50001
10 |
11 | # How often to log metrics and how often to save image reconstructions
12 | metric_interval: 50
13 | image_interval: 500
14 |
15 | # What metrics to evaluate on
16 | metrics:
17 | - vpt
18 | - reconstruction_mse
19 |
20 | test_metrics:
21 | - vpt
22 | - dst
23 | - vpd
24 | - reconstruction_mse
25 | - extrapolation_mse
26 |
27 | # Batch size
28 | batch_size: 64
29 |
30 | # Learning rate and cosine annealing scheduler
31 | # We use CosineAnnealing with WarmRestarts
32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup
33 | learning_rate: 1e-3
34 | scheduler_use: true
35 | scheduler_restart_interval: 5000
36 | scheduler_warmup_steps: 200
37 | scheduler_decay: 0.90
38 |
39 | # KL loss betas
40 | beta_z0: 1e-2
41 | beta_kl: 1e-3
42 |
43 | # How many steps are given as observed data
44 | # For forecasting, this will be small (e.g., 3 to 5)
45 | # For reconstruction, this will be the train_length
46 | z_amort_train: 20
47 | z_amort_test: 5
48 |
49 | # Total steps to either reconstruct or forecast for
50 | train_length: 20
51 | val_length: 20
52 | test_length: 20
--------------------------------------------------------------------------------
/configs/training/recon_recon.yaml:
--------------------------------------------------------------------------------
1 | # PyTorch-Lightning hardware params
2 | accelerator: gpu
3 | num_workers: 0
4 |
5 | # Which GPU ID to run on
6 | devices: [0]
7 |
8 | # Number of steps per task
9 | num_steps: 50001
10 |
11 | # How often to log metrics and how often to save image reconstructions
12 | metric_interval: 50
13 | image_interval: 500
14 |
15 | # What metrics to evaluate on
16 | metrics:
17 | - vpt
18 | - reconstruction_mse
19 |
20 | test_metrics:
21 | - vpt
22 | - dst
23 | - vpd
24 | - reconstruction_mse
25 | - extrapolation_mse
26 |
27 | # Batch size
28 | batch_size: 64
29 |
30 | # Learning rate and cosine annealing scheduler
31 | # We use CosineAnnealing with WarmRestarts
32 | # More information here: https://github.com/qu-gg/pytorch-cosine-annealing-with-decay-and-initial-warmup
33 | learning_rate: 1e-3
34 | scheduler_use: true
35 | scheduler_restart_interval: 5000
36 | scheduler_warmup_steps: 200
37 | scheduler_decay: 0.90
38 |
39 | # KL loss betas
40 | beta_z0: 1e-2
41 | beta_kl: 1e-3
42 |
43 | # How many steps are given as observed data
44 | # For forecasting, this will be small (e.g., 3 to 5)
45 | # For reconstruction, this will be the train_length
46 | z_amort_train: 20
47 | z_amort_test: 20
48 |
49 | # Total steps to either reconstruct or forecast for
50 | train_length: 20
51 | val_length: 20
52 | test_length: 20
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qu-gg/torch-neural-ssm/13813f75e18b9efd5cedc40bd7d707b0a1e7870e/data/README.md
--------------------------------------------------------------------------------
/data/generate_bouncingball.py:
--------------------------------------------------------------------------------
1 | """
2 | @file generate_bouncingball.py
3 |
4 | Handles generating a mixed set of a simple static velocity set (left to right, same initial velocities) alongside
5 | a single gravity
6 | """
7 | import os
8 | import pygame
9 | import random
10 | import tarfile
11 | import numpy as np
12 | import pymunk.pygame_util
13 | import matplotlib.pyplot as plt
14 |
15 | from tqdm import tqdm
16 |
17 |
18 | class BallBox:
19 | def __init__(self, dt=0.2, res=(32, 32), init_pos=(3, 3), init_std=0, wall=None, gravity=(0.0, 0.0), ball_color="white"):
20 | pygame.init()
21 |
22 | self.ball_color = ball_color
23 |
24 | self.dt = dt
25 | self.res = res
26 | if os.environ.get('SDL_VIDEODRIVER', '') == 'dummy':
27 | pygame.display.set_mode(res, 0, 24)
28 | self.screen = pygame.Surface(res, pygame.SRCCOLORKEY, 24)
29 | pygame.draw.rect(self.screen, (0, 0, 0), (0, 0, res[0], res[1]), 0)
30 | else:
31 | self.screen = pygame.display.set_mode(res, 0, 24)
32 | self.gravity = gravity
33 | self.initial_position = init_pos
34 | self.initial_std = init_std
35 | self.space = pymunk.Space()
36 | self.space.gravity = self.gravity
37 | self.draw_options = pymunk.pygame_util.DrawOptions(self.screen)
38 | self.clock = pygame.time.Clock()
39 | self.wall = wall
40 | self.static_lines = None
41 |
42 | self.dd = 2
43 |
44 | def _clear(self):
45 | self.screen.fill((0.3176, 0.3451, 0.3647))
46 |
47 | def create_ball(self, radius=3):
48 | inertia = pymunk.moment_for_circle(1, 0, radius, (0, 0))
49 | body = pymunk.Body(1, inertia)
50 | position = np.array(self.initial_position) + self.initial_std * np.random.normal(size=(2,))
51 | position = np.clip(position, self.dd + radius + 1, self.res[0]-self.dd-radius-1)
52 | position = position.tolist()
53 | body.position = position
54 |
55 | shape = pymunk.Circle(body, radius, (0, 0))
56 | shape.elasticity = 1.0
57 |
58 | shape.color = pygame.color.THECOLORS[self.ball_color]
59 | return shape
60 |
61 | def fire(self, angle=50, velocity=20, radius=3):
62 | speedX = velocity * np.cos(angle * np.pi / 180)
63 | speedY = velocity * np.sin(angle * np.pi / 180)
64 |
65 | ball = self.create_ball(radius)
66 | ball.body.velocity = (speedX, speedY)
67 |
68 | self.space.add(ball, ball.body)
69 | return ball
70 |
71 | def run(self, iterations=20, sequences=500, angle_limits=(0, 360), velocity_limits=(10, 25), radius=3,
72 | flip_gravity=None, save=None, filepath='../../data/balls.npz', delay=None):
73 | if save:
74 | images = np.empty((sequences, iterations, self.res[0], self.res[1]), dtype=np.float32)
75 | state = np.empty((sequences, iterations, 2), dtype=np.float32)
76 |
77 | dd = 0
78 | self.static_lines = [pymunk.Segment(self.space.static_body, (-1, -1), (-1, self.res[1]-dd), 0.0),
79 | pymunk.Segment(self.space.static_body, (-1, -1), (self.res[0]-dd, -1), 0.0),
80 | pymunk.Segment(self.space.static_body, (self.res[0] - dd, self.res[1] - dd),
81 | (-1, self.res[1]-dd), 0.0),
82 | pymunk.Segment(self.space.static_body, (self.res[0] - dd, self.res[1] - dd),
83 | (self.res[0]-dd, -1), 0.0)]
84 | for line in self.static_lines:
85 | line.elasticity = 1.0
86 |
87 | if self.ball_color == "white2":
88 | line.color = pygame.color.THECOLORS["white"]
89 | else:
90 | line.color = pygame.color.THECOLORS[self.ball_color]
91 | # self.space.add(self.static_lines)
92 |
93 | for sl in self.static_lines:
94 | self.space.add(sl)
95 |
96 | for s in range(sequences):
97 | if s % 100 == 0:
98 | print(s)
99 |
100 | angle = np.random.uniform(*angle_limits)
101 | velocity = np.random.uniform(*velocity_limits)
102 | # controls[:, s] = np.array([angle, velocity])
103 | ball = self.fire(angle, velocity, radius)
104 | for i in range(iterations):
105 | self._clear()
106 | self.space.debug_draw(self.draw_options)
107 | self.space.step(self.dt)
108 | pygame.display.flip()
109 |
110 | if delay:
111 | self.clock.tick(delay)
112 |
113 | if save == 'png':
114 | pygame.image.save(self.screen, os.path.join(filepath, "bouncing_balls_%02d_%02d.png" % (s, i)))
115 | elif save == 'npz':
116 | images[s, i] = pygame.surfarray.array2d(self.screen).swapaxes(1, 0).astype(np.float32) / (2**24 - 1)
117 | state[s, i] = list(ball.body.velocity) # list(ball.body.position) + # Note that this is done for compatibility with the combined dataset
118 |
119 | # Remove the ball and the wall from the space
120 | self.space.remove(ball, ball.body)
121 |
122 | return images, state
123 |
124 |
125 | if __name__ == '__main__':
126 | os.environ['SDL_VIDEODRIVER'] = 'dummy'
127 |
128 | # Parameters of generation, resolution and number of samples
129 | scale = 1
130 | timesteps = 40
131 | training_size = 10000
132 | validation_size = 1500
133 | testing_size = 2500
134 | total_size = training_size + validation_size + testing_size
135 |
136 | base_dir = f"bouncingball_{training_size}samples_{timesteps}steps/"
137 |
138 | # Arrays to hold sets
139 | train_images, train_states = [], []
140 | test_images, test_states = [], []
141 | np.random.seed(1234)
142 |
143 | # Generate the data sequences
144 | cannon = BallBox(dt=0.25, res=(32*scale, 32*scale), init_pos=(16*scale, 16*scale), init_std=8,
145 | wall=None, gravity=(0.0, -5.0), ball_color="white")
146 | i, s = cannon.run(delay=None, iterations=timesteps, sequences=total_size + 5000,
147 | radius=3, angle_limits=(0, 360), velocity_limits=(5.0, 10.0), save='npz')
148 |
149 | # Setting the pixels to a uniform background and foreground (for simplicity of training)
150 | i = (i > 0).astype(np.float32)
151 |
152 |     # Brute-force check for any bad trajectories where the ball leaves the frame
153 | bad_indices = []
154 | for seq_idx, sequence in enumerate(i):
155 | sums = np.sum(sequence, axis=(1, 2))
156 | if np.where(sums == 0.0)[0].shape[0] > 0:
157 | bad_indices.append(seq_idx)
158 |
159 | print(f"Bad indices: {len(bad_indices)}")
160 |
161 | i = np.delete(i, bad_indices, 0)
162 | s = np.delete(s, bad_indices, 0)
163 | if i.shape[0] > total_size:
164 | i = i[:total_size, :]
165 | s = s[:total_size, :]
166 |
167 |     # Break into train, val, and test sets, adding in generic labels
168 | train_images = i[:training_size]
169 | train_states = s[:training_size]
170 | train_classes = np.full([train_images.shape[0], 1], fill_value=2)
171 |
172 | val_images = i[training_size:training_size + validation_size]
173 | val_states = s[training_size:training_size + validation_size]
174 | val_classes = np.full([val_images.shape[0], 1], fill_value=2)
175 |
176 | test_images = i[training_size + validation_size:]
177 | test_states = s[training_size + validation_size:]
178 | test_classes = np.full([test_images.shape[0], 1], fill_value=2)
179 |
180 | print(f"Train - Images: {train_images.shape} | States: {train_states.shape} | Classes: {train_classes.shape}")
181 | print(f"Val - Images: {val_images.shape} | States: {val_states.shape} | Classes: {val_classes.shape}")
182 | print(f"Test - Images: {test_images.shape} | States: {test_states.shape} | Classes: {test_classes.shape}")
183 |
184 | # Make sure all directories are made beforehand
185 | if not os.path.exists(f"{base_dir}/"):
186 | os.mkdir(f"{base_dir}/")
187 |
188 | # Permute the sets and states together
189 | p = np.random.permutation(train_images.shape[0])
190 | train_images, train_classes, train_states = train_images[p], train_classes[p], train_states[p]
191 |
192 | p = np.random.permutation(val_images.shape[0])
193 | val_images, val_classes, val_states = val_images[p], val_classes[p], val_states[p]
194 |
195 | p = np.random.permutation(test_images.shape[0])
196 | test_images, test_classes, test_states = test_images[p],test_classes[p], test_states[p]
197 |
198 | # Save sets
199 | np.savez(os.path.abspath(f"{base_dir}/train.npz"), image=train_images, state=train_states, label=train_classes)
200 | np.savez(os.path.abspath(f"{base_dir}/val.npz"), image=val_images, state=val_states, label=val_classes)
201 | np.savez(os.path.abspath(f"{base_dir}/test.npz"), image=test_images, state=test_states, label=test_classes)
202 |
--------------------------------------------------------------------------------
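The script above writes `train.npz`, `val.npz`, and `test.npz` with the keys `image`, `state`, and `label`. A minimal sketch of loading a split back and checking its shapes, assuming the default sizes used in `__main__` (10000 training sequences of 40 steps at 32x32):

```python
import numpy as np

data = np.load("bouncingball_10000samples_40steps/train.npz", allow_pickle=True)
images, states, labels = data["image"], data["state"], data["label"]

print(images.shape)  # expected (10000, 40, 32, 32), binarized frames in {0, 1}
print(states.shape)  # expected (10000, 40, 2), per-step ball velocities
print(labels.shape)  # expected (10000, 1), a constant class id of 2 for this set
```
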
/data/generate_hamiltonian.py:
--------------------------------------------------------------------------------
1 | """
2 | @file generate_hamiltonian.py
3 | @url https://github.com/deepmind/dm_hamiltonian_dynamics_suite/tree/master/dm_hamiltonian_dynamics_suite
4 |
5 | Adapted from the DeepMind Hamiltonian Dynamics Suite for usage with WebDataset
6 | Requires Linux and JAX to run, as well as a local folder structure of dm_hamiltonian_dynamics_suite.
7 | As such, it won't run directly in this repository; it is included here for demonstration purposes.
8 | """
9 | from subprocess import getstatusoutput
10 | from matplotlib import pyplot as plt
11 | from matplotlib import animation as plt_animation
12 | import numpy as np
13 | from jax import config as jax_config
14 |
15 | import os
16 | import random
17 | import tarfile
18 | from tqdm import tqdm
19 |
20 | jax_config.update("jax_enable_x64", True)
21 |
22 | from dm_hamiltonian_dynamics_suite import load_datasets
23 | from dm_hamiltonian_dynamics_suite import datasets
24 |
25 | # @title Helper functions
26 | DATASETS_URL = "gs://dm-hamiltonian-dynamics-suite"
27 | DATASETS_FOLDER = "./datasets" # @param {type: "string"}
28 | os.makedirs(DATASETS_FOLDER, exist_ok=True)
29 |
30 |
31 | def download_file(file_url, destination_file):
32 | print("Downloading", file_url, "to", destination_file)
33 | command = f"gsutil cp {file_url} {destination_file}"
34 | status_code, output = getstatusoutput(command)
35 | if status_code != 0:
36 | raise ValueError(output)
37 |
38 |
39 | def download_dataset(dataset_name: str):
40 | """Downloads the provided dataset from the DM Hamiltonian Dataset Suite"""
41 | destination_folder = os.path.join(DATASETS_FOLDER, dataset_name)
42 | dataset_url = os.path.join(DATASETS_URL, dataset_name)
43 | os.makedirs(destination_folder, exist_ok=True)
44 | if "long_trajectory" in dataset_name:
45 | files = ("features.txt", "test.tfrecord")
46 | else:
47 | files = ("features.txt", "train.tfrecord", "test.tfrecord")
48 | for file_name in files:
49 | file_url = os.path.join(dataset_url, file_name)
50 | destination_file = os.path.join(destination_folder, file_name)
51 | if os.path.exists(destination_file):
52 | print("File", file_url, "already present.")
53 | continue
54 | download_file(file_url, destination_file)
55 |
56 |
57 | def unstack(value: np.ndarray, axis: int = 0):
58 | """Unstacks an array along an axis into a list"""
59 | split = np.split(value, value.shape[axis], axis=axis)
60 | return [np.squeeze(v, axis=axis) for v in split]
61 |
62 |
63 | def make_batch_grid(
64 | batch: np.ndarray,
65 | grid_height: int,
66 | grid_width: int,
67 | with_padding: bool = True):
68 | """Makes a single grid image from a batch of multiple images."""
69 | assert batch.ndim == 5
70 | assert grid_height * grid_width >= batch.shape[0]
71 | batch = batch[:grid_height * grid_width]
72 | batch = batch.reshape((grid_height, grid_width) + batch.shape[1:])
73 | if with_padding:
74 | batch = np.pad(batch, pad_width=[[0, 0], [0, 0], [0, 0],
75 | [1, 0], [1, 0], [0, 0]],
76 | mode="constant", constant_values=1.0)
77 | batch = np.concatenate(unstack(batch), axis=-3)
78 | batch = np.concatenate(unstack(batch), axis=-2)
79 | if with_padding:
80 | batch = batch[:, 1:, 1:]
81 | return batch
82 |
83 |
84 | def plot_animation_from_batch(
85 | batch: np.ndarray,
86 | grid_height,
87 | grid_width,
88 | with_padding=True,
89 | figsize=None):
90 | """Plots an animation of the batch of sequences."""
91 | if figsize is None:
92 | figsize = (grid_width, grid_height)
93 | batch = make_batch_grid(batch, grid_height, grid_width, with_padding)
94 | batch = batch[:, ::-1]
95 | fig = plt.figure(figsize=figsize)
96 | plt.close()
97 | ax = fig.add_subplot(1, 1, 1)
98 | ax.axis('off')
99 | img = ax.imshow(batch[0])
100 |
101 | def frame_update(i):
102 | i = int(np.floor(i).astype("int64"))
103 | img.set_data(batch[i])
104 | return [img]
105 |
106 | anim = plt_animation.FuncAnimation(
107 | fig=fig,
108 | func=frame_update,
109 | frames=np.linspace(0.0, len(batch), len(batch) * 5 + 1)[:-1],
110 | save_count=len(batch),
111 | interval=10,
112 | blit=True
113 | )
114 | return anim
115 |
116 |
117 | def plot_sequence_from_batch(
118 | batch: np.ndarray,
119 | t_start: int = 0,
120 | with_padding: bool = True,
121 | fontsize: int = 20):
122 | """Plots all of the sequences in the batch."""
123 | n, t, dx, dy = batch.shape[:-1]
124 | xticks = np.linspace(dx // 2, t * (dx + 1) - 1 - dx // 2, t)
125 | xtick_labels = np.arange(t) + t_start
126 | yticks = np.linspace(dy // 2, n * (dy + 1) - 1 - dy // 2, n)
127 | ytick_labels = np.arange(n)
128 | batch = batch.reshape((n * t, 1) + batch.shape[2:])
129 | batch = make_batch_grid(batch, n, t, with_padding)[0]
130 | plt.imshow(batch.squeeze())
131 | plt.xticks(ticks=xticks, labels=xtick_labels, fontsize=fontsize)
132 | plt.yticks(ticks=yticks, labels=ytick_labels, fontsize=fontsize)
133 |
134 |
135 | def visualize_dataset(
136 | dataset_path: str,
137 | sequence_lengths: int = 60,
138 | grid_height: int = 2,
139 | grid_width: int = 5):
140 | """Visualizes a dataset loaded from the path provided."""
141 | split = "test"
142 | batch_size = grid_height * grid_width
143 | dataset = load_datasets.load_dataset(
144 | path=dataset_path,
145 | tfrecord_prefix=split,
146 | sub_sample_length=sequence_lengths,
147 | per_device_batch_size=batch_size,
148 | num_epochs=None,
149 | drop_remainder=True,
150 | shuffle=False,
151 | shuffle_buffer=100
152 | )
153 | sample = next(iter(dataset))
154 | batch_x = sample['x'].numpy()
155 | batch_image = sample['image'].numpy()
156 | # Plot real system dimensions
157 | plt.figure(figsize=(24, 8))
158 | for i in range(batch_x.shape[-1]):
159 | plt.subplot(1, batch_x.shape[-1], i + 1)
160 | plt.title(f"Samples from dimension {i + 1}")
161 | plt.plot(batch_x[:, :, i].T)
162 | plt.show()
163 | # Plot a sequence of 50 images
164 | plt.figure(figsize=(30, 10))
165 | plt.title("Samples from 50 steps sub sequences.")
166 | plot_sequence_from_batch(batch_image[:, :50])
167 | plt.show()
168 | # Plot animation
169 |     return plot_animation_from_batch(batch_image, grid_height, grid_width)
170 |
171 |
172 | # Generate dataset
173 | print("Generating datasets....")
174 | folder_to_store = "./generated_datasets"
175 | dataset = "pendulum"
176 | class_id = np.array([1])
177 | dt = 0.1
178 | num_steps = 1000
179 | steps_per_dt = 1
180 | num_train = 1000
181 | num_test = 2000
182 | overwrite = True
183 | datasets.generate_full_dataset(
184 | folder=folder_to_store,
185 | dataset=dataset,
186 | dt=dt,
187 | num_steps=num_steps,
188 | steps_per_dt=steps_per_dt,
189 | num_train=num_train,
190 | num_test=num_test,
191 | overwrite=overwrite,
192 | )
193 | dataset_full_name = dataset + "_dt_" + str(dt).replace(".", "_")
194 | dataset_output_path = dataset + "_{}samples_{}steps".format(num_train, num_steps) + "_dt" + str(dt).replace(".", "")
195 | dataset_path = os.path.join(folder_to_store, dataset_full_name)
196 | visualize_dataset(dataset_path)
197 |
198 | if not os.path.exists("data_out/{}".format(dataset_output_path)):
199 | os.mkdir("data_out/{}".format(dataset_output_path))
200 |
201 | """
202 | Training Generation
203 | """
204 | print("Converting training files...")
205 | if not os.path.exists("data_out/{}/train/".format(dataset_output_path)):
206 | os.mkdir("data_out/{}/train/".format(dataset_output_path))
207 |
208 | loaded_dataset = load_datasets.load_dataset(
209 | path=dataset_path,
210 | tfrecord_prefix="train",
211 | sub_sample_length=num_steps,
212 | per_device_batch_size=1,
213 | num_epochs=1,
214 | drop_remainder=True,
215 | shuffle=True,
216 | shuffle_buffer=100
217 | )
218 |
219 | images = None
220 | xs = None
221 | for idx, sample in enumerate(loaded_dataset):
222 | image = sample['image'][0].numpy()
223 |
224 | # (32, 32, 3) -> (3, 32, 32)
225 | image = np.swapaxes(image, 2, 3)
226 | image = np.swapaxes(image, 1, 2)
227 |
228 | # Just grab the R channel
229 | image = image[:, 0, :, :]
230 |
231 | x = sample['x'].numpy()
232 |
233 | np.savez("data_out/{}/train/{}.npz".format(dataset_output_path, idx), image=image, x=x, class_id=class_id)
234 |
235 | # Get file list and then shuffle it
236 | file_list = os.listdir("data_out/" + dataset_output_path + "/train/")
237 | random.shuffle(file_list)
238 |
239 | if not os.path.exists("data_out/" + dataset_output_path + '/train_tars/'):
240 | os.mkdir("data_out/" + dataset_output_path + '/train_tars/')
241 |
242 | n_shards = 200
243 | elements_per_shard = len(file_list) // n_shards
244 |
245 | for n in tqdm(range(n_shards)):
246 | with tarfile.open("data_out/" + dataset_output_path + "/train_tars/train{0:03}.tar".format(n), "w:gz") as tar:
247 | for file in file_list[n * elements_per_shard: (n + 1) * elements_per_shard]:
248 | tar.add("data_out/" + dataset_output_path + "/train/{}".format(file))
249 |
250 | """
251 | Testing Generation
252 | """
253 | print("Converting testing files...")
254 | if not os.path.exists("data_out/{}/test/".format(dataset_output_path)):
255 | os.mkdir("data_out/{}/test/".format(dataset_output_path))
256 | loaded_dataset = load_datasets.load_dataset(
257 | path=dataset_path,
258 | tfrecord_prefix="test",
259 | sub_sample_length=num_steps,
260 | per_device_batch_size=1,
261 | num_epochs=1,
262 | drop_remainder=True,
263 | shuffle=True,
264 | shuffle_buffer=100
265 | )
266 |
267 | images = None
268 | xs = None
269 | for idx, sample in enumerate(loaded_dataset):
270 | image = sample['image'][0].numpy()
271 |
272 | # (32, 32, 3) -> (3, 32, 32)
273 | image = np.swapaxes(image, 2, 3)
274 | image = np.swapaxes(image, 1, 2)
275 |
276 | # Just grab the R channel
277 | image = image[:, 0, :, :]
278 |
279 | x = sample['x'].numpy()
280 | np.savez("data_out/{}/test/{}.npz".format(dataset_output_path, idx), image=image, x=x, class_id=class_id)
281 |
282 | # Get file list and then shuffle it
283 | file_list = os.listdir("data_out/" + dataset_output_path + "/test/")
284 | random.shuffle(file_list)
285 |
286 | if not os.path.exists("data_out/" + dataset_output_path + '/test_tars/'):
287 | os.mkdir("data_out/" + dataset_output_path + '/test_tars/')
288 |
289 | n_shards = 200
290 | elements_per_shard = len(file_list) // n_shards
291 |
292 | for n in tqdm(range(n_shards)):
293 | with tarfile.open("data_out/" + dataset_output_path + "/test_tars/test{0:03}.tar".format(n), "w:gz") as tar:
294 | for file in file_list[n * elements_per_shard: (n + 1) * elements_per_shard]:
295 | tar.add("data_out/" + dataset_output_path + "/test/{}".format(file))
296 |
--------------------------------------------------------------------------------
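Each converted sample above is stored as its own NPZ file before being sharded into gzipped tars for WebDataset. A minimal sketch of inspecting one such file, assuming the pendulum settings used in the script (`num_train=1000`, `num_steps=1000`, `dt=0.1`):

```python
import numpy as np

sample = np.load("data_out/pendulum_1000samples_1000steps_dt01/train/0.npz")
print(sample["image"].shape)   # (num_steps, H, W): the R channel of every rendered frame
print(sample["x"].shape)       # (1, num_steps, state_dim): ground-truth phase-space states
print(sample["class_id"])      # [1], the label assigned to this dataset
```
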
/data/visualize_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib.colors import Normalize
4 | from matplotlib.cm import ScalarMappable
5 |
6 | def plot_ball_trajectory(images):
7 | # Define number of timesteps and image dimensions
8 | timesteps, dim1, dim2 = images.shape
9 |
10 | # Use 'turbo' colormap for a bright and high-contrast gradient
11 | cmap = plt.get_cmap("turbo")
12 | norm = Normalize(vmin=2, vmax=timesteps - 1)
13 | sm = ScalarMappable(cmap=cmap, norm=norm)
14 |
15 | # Create a base figure for the trajectory plot
16 | plt.figure(figsize=(6, 6))
17 | plt.axis("off")
18 |
19 | # Plot each timestep's image on the plot with a color gradient
20 | for t in range(timesteps):
21 | # Get color for the current timestep
22 | color = cmap(norm(t))[:3] # RGB color for the current timestep
23 |
24 | # Create an RGB mask for the current frame
25 | mask = images[t] > 0 # Assuming ball is represented by non-zero values
26 | colored_image = np.zeros((dim1, dim2, 3)) # Create an RGB image
27 | for c in range(3): # Assign the color to masked areas only
28 | colored_image[:, :, c] = mask * color[c]
29 |
30 | # Plot the colored mask with some transparency
31 | plt.imshow(1 - colored_image, alpha=0.3)
32 |
33 | # Add color bar to indicate time progression
34 | plt.colorbar(sm, label="Time Steps")
35 | plt.title("Ball Trajectory Over Time", color='white')
36 | plt.gca().set_facecolor('black') # Set the background to black for contrast
37 | plt.show()
38 |
39 | images = np.load("bouncingball_10000samples_40steps/train.npz", allow_pickle=True)["image"]
40 | plot_ball_trajectory(images[0])
41 | plot_ball_trajectory(images[1])
42 | plot_ball_trajectory(images[2])
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qu-gg/torch-neural-ssm/13813f75e18b9efd5cedc40bd7d707b0a1e7870e/experiments/README.md
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | @file main.py
3 |
4 | Main entrypoint for the training and testing environments. Takes in a configuration file
5 | of arguments and either trains a model or tests a given model and checkpoint.
6 | """
7 | import torch
8 | import hydra
9 | import pytorch_lightning
10 | import pytorch_lightning.loggers as pl_loggers
11 |
12 | from omegaconf import DictConfig
13 | from utils.dataloader import SSMDataModule
14 | from utils.utils import get_model, flatten_cfg
15 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
16 |
17 |
18 | @hydra.main(version_base="1.3", config_path="configs", config_name="config")
19 | def main(cfg: DictConfig):
20 | # Flatten the cfg down a level
21 | cfg.expname = cfg.expname
22 | cfg = flatten_cfg(cfg)
23 |
24 |     # Set a global seed for reproducible runs and analysis
25 | pytorch_lightning.seed_everything(cfg.seed, workers=True)
26 |
27 |     # Performance settings: cuDNN benchmarking, TF32, and medium-precision float32 matmuls
28 | torch.backends.cudnn.benchmark = True
29 | torch.backends.cudnn.allow_tf32 = True
30 | torch.set_float32_matmul_precision('medium')
31 |
32 |     # Limit the number of CPU threads used by PyTorch
33 | torch.set_num_threads(8)
34 |
35 | # Building the PL-DataModule for all splits
36 | datamodule = SSMDataModule(cfg=cfg)
37 |
38 |     # Resolve the model class from its type string and initialize it
39 | model = get_model(cfg.model_type)(cfg)
40 |
41 | # Tensorboard Logger
42 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=f"experiments/{cfg.expname}/", name=f"{cfg.model_type}")
43 |
44 | # Callbacks for checkpointing and early stopping
45 | checkpoint_callback = ModelCheckpoint(monitor='val_reconstruction_mse',
46 | filename='step{step:02d}-val_reconstruction_mse{val_reconstruction_mse:.4f}',
47 | auto_insert_metric_name=False, save_last=True)
48 | early_stop_callback = EarlyStopping(monitor="val_reconstruction_mse", min_delta=0.000001, patience=15, mode="min")
49 | lr_monitor = LearningRateMonitor(logging_interval='step')
50 |
51 | # Initialize trainer
52 | trainer = pytorch_lightning.Trainer(
53 | callbacks=[
54 | # early_stop_callback,
55 | lr_monitor,
56 | checkpoint_callback
57 | ],
58 | accelerator=cfg.accelerator,
59 | devices=cfg.devices,
60 | deterministic=True,
61 | max_steps=cfg.num_steps * cfg.batch_size,
62 | max_epochs=1,
63 | gradient_clip_val=5.0,
64 | val_check_interval=cfg.metric_interval,
65 | num_sanity_val_steps=0,
66 | logger=tb_logger
67 | )
68 |
69 | # Training from scratch
70 | if cfg.train is True and cfg.resume is False:
71 | trainer.fit(model, datamodule)
72 |
73 | # Training from the last epoch
74 | elif cfg.train is True and cfg.resume is True:
75 | ckpt_path = tb_logger.log_dir + "/checkpoints/" + cfg.checkpt if cfg.checkpt != "" \
76 | else f"{tb_logger.log_dir}/checkpoints/last.ckpt"
77 |
78 | trainer.fit(model, datamodule, ckpt_path=ckpt_path)
79 |
80 | # Testing the model on each split
81 | ckpt_path = tb_logger.log_dir + "/checkpoints/" + cfg.checkpt if cfg.checkpt != "" \
82 | else f"{tb_logger.log_dir}/checkpoints/last.ckpt"
83 |
84 | model.setting = 'train'
85 | trainer.test(model, datamodule.evaluate_train_dataloader(), ckpt_path=ckpt_path)
86 |
87 | model.setting = 'val'
88 | trainer.test(model, datamodule.val_dataloader(), ckpt_path=ckpt_path)
89 |
90 | model.setting = 'test'
91 | trainer.test(model, datamodule.test_dataloader(), ckpt_path=ckpt_path)
92 |
93 |
94 | if __name__ == '__main__':
95 | main()
96 |
--------------------------------------------------------------------------------
/models/CommonDynamics.py:
--------------------------------------------------------------------------------
1 | """
2 | @file CommonDynamics.py
3 |
4 | A common class that each latent dynamics function inherits.
5 | Holds the training + validation step logic and the VAE components for reconstructions.
6 | """
7 | import os
8 | import json
9 | import torch
10 | import numpy as np
11 | import torch.nn as nn
12 | import pytorch_lightning
13 | import utils.metrics as metrics
14 | import matplotlib.pyplot as plt
15 |
16 | from sklearn.manifold import TSNE
17 | from models.CommonVAE import LatentStateEncoder, EmissionDecoder
18 | from utils.plotting import show_images, get_embedding_trajectories
19 | from utils.utils import determine_annealing_factor, CosineAnnealingWarmRestartsWithDecayAndLinearWarmup
20 |
21 |
22 | class LatentDynamicsModel(pytorch_lightning.LightningModule):
23 | def __init__(self, cfg):
24 | """
25 | Generic implementation of a Latent Dynamics Model
26 | Holds the training and testing boilerplate, as well as experiment tracking
27 | :param cfg: passed in hydra configdict
28 | """
29 | super().__init__()
30 | # Config
31 | self.cfg = cfg
32 | self.setting = 'train'
33 |
34 | # Encoder + Decoder
35 | self.encoder = LatentStateEncoder(cfg)
36 | self.decoder = EmissionDecoder(cfg)
37 |
38 | # Recurrent dynamics function
39 | self.dynamics_func = None
40 |
41 | # Number of steps for training
42 | self.n_updates = 0
43 |
44 | # Loss function
45 | self.reconstruction_loss = nn.MSELoss(reduction='none')
46 |
47 | # Variable to hold batch outputs to manually log
48 | self.outputs = list()
49 |
50 | def forward(self, x, generation_len):
51 | """ Placeholder function for the dynamics forward pass """
52 | raise NotImplementedError("In forward function: Latent Dynamics function not specified.")
53 |
54 | def model_specific_loss(self, x, x_rec, train=True):
55 | """ Placeholder function for any additional loss terms a dynamics function may have """
56 | return 0.0
57 |
58 | def model_specific_plotting(self, version_path, outputs):
59 | """ Placeholder function for any additional plots a dynamics function may have """
60 | return None
61 |
62 | def configure_optimizers(self):
63 | """
64 |         Most standard NSSM models have a joint optimization step under one ELBO; however, there is room
65 | for EM-optimization procedures based on the PGM.
66 |
67 |         By default, we assume a joint optimization with the AdamW optimizer. We additionally include LR warmup and
68 | CosineAnnealing with decay for standard learning rate care during training.
69 |
70 | For CosineAnnealing, we set the LR bounds to be [LR * 1e-2, LR]
71 | """
72 | optim = torch.optim.AdamW(self.parameters(), lr=self.cfg.learning_rate)
73 |
74 | # Build the scheduler if using it
75 | if self.cfg.scheduler_use is True:
76 | scheduler = CosineAnnealingWarmRestartsWithDecayAndLinearWarmup(
77 | optim,
78 | T_0=self.cfg.scheduler_restart_interval, T_mult=1,
79 | eta_min=self.cfg.learning_rate * 1e-2,
80 | warmup_steps=self.cfg.scheduler_warmup_steps,
81 | decay=self.cfg.scheduler_decay
82 | )
83 |
84 | # Explicit dictionary to state how often to ping the scheduler
85 | scheduler = {
86 | 'scheduler': scheduler,
87 | 'frequency': 1,
88 | 'interval': 'step'
89 | }
90 |
91 | return [optim], [scheduler]
92 |
93 | # Otherwise just return the optimizer
94 | return optim
95 |
96 | def on_train_start(self):
97 | """
98 | Before a training session starts, we set some model variables and save a JSON configuration of the
99 | used hyperparameters to allow for easy load-in at test-time.
100 | """
101 | # Get total number of parameters for the model and save
102 | self.log("total_num_parameters", float(sum(p.numel() for p in self.parameters() if p.requires_grad)), prog_bar=False)
103 |
104 | # Make image dir in lightning experiment folder if it doesn't exist
105 | if not os.path.exists(f"{self.logger.log_dir}/images/"):
106 | os.mkdir(f"{self.logger.log_dir}/images/")
107 |
108 | def get_step_outputs(self, batch, generation_len):
109 | """
110 |         Handles pre-processing and subsequence sampling of a batch,
111 |         as well as getting the outputs from the models regardless of step
112 |         :param batch: batch tuple containing (among other entries) the image sequences, states, and labels
113 | :param generation_len: how far out to generate for, dependent on the step (train/val)
114 | :return: processed model outputs
115 | """
116 | # Deconstruct batch
117 | _, images, states, _, labels = batch
118 |
119 | # Set the length of z_amort depending on training/testing
120 | if self.trainer.training:
121 | self.z_amort = self.cfg.z_amort_train
122 | else:
123 | self.z_amort = self.cfg.z_amort_test
124 |
125 |         # Sample a random portion of the sequence over generation_len, saving room for backwards solving
126 | if max(images.shape[1] - self.z_amort - generation_len, 0) > 0:
127 | random_start = np.random.randint(0, images.shape[1] - self.z_amort - generation_len)
128 | images = images[:, random_start:random_start + generation_len + self.z_amort]
129 | states = states[:, random_start:random_start + generation_len + self.z_amort]
130 |
131 | # Get predictions
132 | preds, embeddings = self(images, generation_len)
133 |
134 | # Restrict images to start from after inference, for metrics and likelihood
135 | images = images[:, self.z_amort:]
136 | states = states[:, self.z_amort:]
137 | return images, states, labels, preds, embeddings
138 |
139 | def get_step_losses(self, images, preds):
140 | """
141 | Handles getting the ELBO terms for the given step
142 | :param images: ground truth images
143 |         :param images_rev: ground truth images, reversed in time for some models' secondary TRS loss (unused here)
144 | :param preds: forward predictions from the model
145 | :return: likelihood, kl on z0, model-specific dynamics loss
146 | """
147 | # Reconstruction loss for the sequence and z0
148 | likelihood = self.reconstruction_loss(preds, images)
149 | likelihood = likelihood.reshape([likelihood.shape[0] * likelihood.shape[1], -1]).sum([-1]).mean()
150 |
151 | # Initial encoder loss, KL[q(z_K|x_0:K) || p(z_K)]
152 | klz = self.encoder.kl_z_term()
153 |
154 | # Get the loss terms from the specific latent dynamics loss
155 | dynamics_loss = self.model_specific_loss(images, preds)
156 | return likelihood, klz, dynamics_loss
157 |
158 | def get_epoch_metrics(self, outputs, length=20):
159 | """
160 | Takes the dictionary of saved batch metrics, stacks them, and gets outputs to log in the Tensorboard.
161 |         :param outputs: list of dictionaries with outputs from each batch
162 | :return: dictionary of metrics aggregated over the epoch
163 | """
164 | # Convert outputs to Tensors and then Numpy arrays
165 | images = torch.vstack([out["images"] for out in outputs]).cpu().numpy()
166 | preds = torch.vstack([out["preds"] for out in outputs]).cpu().numpy()
167 |
168 | # Iterate through each metric function and add to a dictionary
169 | out_metrics = {}
170 | for met in self.cfg.metrics:
171 | metric_function = getattr(metrics, met)
172 | out_metrics[met] = metric_function(images, preds, cfg=self.cfg, length=length)[0]
173 |
174 | # Return a dictionary of metrics
175 | return out_metrics
176 |
177 | def training_step(self, batch, batch_idx):
178 | """
179 | PyTorch-Lightning training step where the network is propagated and returns a single loss value,
180 | which is automatically handled for the backward update
181 |         :param batch: batch tuple containing (among other entries) the image sequences, states, and labels
182 | :param batch_idx: how far in the epoch this batch is
183 | """
184 | # Get model outputs from batch
185 | images, _, labels, preds, _ = self.get_step_outputs(batch, self.cfg.train_length)
186 |
187 | # Get model loss terms for the step
188 | likelihood, klz, dynamics_loss = self.get_step_losses(images, preds)
189 |
190 | # Determine KL annealing factor for the current step
191 | kl_factor = determine_annealing_factor(self.n_updates, anneal_update=1000)
192 |
193 | # Build the full loss
194 | loss = likelihood + kl_factor * ((self.cfg.beta_z0 * klz) + (self.cfg.beta_kl * dynamics_loss))
195 |
196 | # Log ELBO loss terms
197 | self.log_dict({
198 | "likelihood": likelihood,
199 | "kl_z": self.cfg.beta_z0 * klz,
200 | "dynamics_loss": self.cfg.beta_kl * dynamics_loss,
201 | "kl_factor": kl_factor
202 | })
203 |
204 | # Log metrics every N batches
205 | if len(self.outputs) < self.cfg.batches_to_save:
206 | self.outputs.append({"loss": loss, "labels": labels.detach(), "preds": preds.detach(), "images": images.detach()})
207 |
208 | # Return the loss for updating and track the iteration number
209 | self.n_updates += 1
210 | return {"loss": loss}
211 |
212 | def on_train_batch_end(self, outputs, batch, batch_idx):
213 | """ Given the iterative training, check on every batch's end whether it is evaluation time or not """
214 | # Show side-by-side reconstructions
215 | if batch_idx % self.cfg.image_interval == 0 and batch_idx != 0:
216 | show_images(self.outputs[0]["images"], self.outputs[0]["preds"], f'{self.logger.log_dir}/images/recon{batch_idx}train.png', num_out=5)
217 |
218 | # Get per-dynamics plots
219 | self.model_specific_plotting(self.logger.log_dir, self.outputs)
220 |
221 | # Get metrics over the window of batches and clear output buffer
222 | if batch_idx % self.cfg.metric_interval == 0 and batch_idx != 0:
223 | metrics = self.get_epoch_metrics(self.outputs[:self.cfg.batches_to_save], length=self.cfg.train_length)
224 | for metric in metrics.keys():
225 | self.log(f"train_{metric}", metrics[metric], prog_bar=True)
226 |
227 | self.outputs = list()
228 |
229 | def validation_step(self, batch, batch_idx):
230 | """
231 | PyTorch-Lightning validation step. Similar to the training step but on the given val set under torch.no_grad()
232 |         :param batch: batch tuple containing (among other entries) the image sequences, states, and labels
233 | :param batch_idx: how far in the epoch this batch is
234 | """
235 | # Get model outputs from batch
236 | images, _, _, preds, _ = self.get_step_outputs(batch, self.cfg.val_length)
237 |
238 | # Get model loss terms for the step
239 | likelihood, _, _ = self.get_step_losses(images, preds)
240 |
241 | # Log validation likelihood and metrics
242 | self.log("val_likelihood", likelihood, prog_bar=True)
243 |
244 | # Return outputs as dict
245 | out = {"loss": likelihood}
246 | if batch_idx < self.cfg.batches_to_save:
247 | out["preds"] = preds.detach()
248 | out["images"] = images.detach()
249 | return out
250 |
251 | def validation_epoch_end(self, outputs):
252 | """
253 | Every N epochs, get a validation reconstruction sample
254 | :param outputs: list of outputs from the validation steps at batch 0
255 | """
256 | # Log epoch metrics on saved batches
257 | metrics = self.get_epoch_metrics(self.outputs[:self.cfg.batches_to_save], length=self.cfg.val_length)
258 | for metric in metrics.keys():
259 | self.log(f"val_{metric}", metrics[metric], prog_bar=True)
260 |
261 | # Get image reconstructions
262 | if self.n_updates % self.cfg.image_interval == 0 and self.n_updates != 0:
263 | show_images(outputs[0]["images"], outputs[0]["preds"], f'{self.logger.log_dir}/images/recon{self.n_updates}val.png', num_out=5)
264 |
265 | def test_step(self, batch, batch_idx):
266 | """
267 | PyTorch-Lightning testing step.
268 |         :param batch: batch tuple containing (among other entries) the image sequences, states, and labels
269 | :param batch_idx: how far in the epoch this batch is
270 | """
271 | # Get model outputs from batch
272 | images, states, labels, preds, embeddings = self.get_step_outputs(batch, self.cfg.test_length)
273 |
274 | # Build output dictionary
275 | out = {"states": states.detach().cpu(), "embeddings": embeddings.detach().cpu(),
276 | "preds": preds.detach().cpu(), "images": images.detach().cpu(), "labels": labels.detach().cpu()}
277 | return out
278 |
279 | def test_epoch_end(self, batch_outputs):
280 | """
281 |         At test end, compute metrics over all batches and save example reconstructions, a Z0 t-SNE plot, and the metrics (JSON/txt) to the experiment folder
282 |         :param batch_outputs: list of output dictionaries from the test steps
283 | """
284 | # Stack all output types and convert to numpy
285 | outputs = dict()
286 | for key in batch_outputs[0].keys():
287 | outputs[key] = torch.vstack([output[key] for output in batch_outputs]).numpy()
288 |
289 | # Iterate through each metric function and add to a dictionary
290 | out_metrics = {}
291 | for met in self.cfg.test_metrics:
292 | metric_function = getattr(metrics, met)
293 | metric_mean, metric_std = metric_function(outputs["images"], outputs["preds"], cfg=self.cfg, length=self.cfg.test_length)
294 | out_metrics[f"{met}_mean"], out_metrics[f"{met}_std"] = float(metric_mean), float(metric_std)
295 | print(f"=> {met}: {metric_mean:4.5f}+-{metric_std:4.5f}")
296 |
297 | # Set up output path and create dir
298 | output_path = f"{self.logger.log_dir}/test_{self.setting}"
299 | if not os.path.exists(output_path):
300 | os.mkdir(output_path)
301 |
302 | # Save some examples
303 | show_images(outputs["images"][:10], outputs["preds"][:10], f"{output_path}/test_{self.setting}_examples.png", num_out=5)
304 |
305 | # Save trajectory examples
306 | get_embedding_trajectories(outputs["embeddings"][0], outputs["states"][0], f"{output_path}/")
307 |
308 | # Get Z0 TSNE
309 | tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000, early_exaggeration=12)
310 | fitted = tsne.fit(outputs["embeddings"][:, 0])
311 | tsne_embedding = fitted.embedding_
312 |
313 | for i in np.unique(outputs["labels"]):
314 | subset = tsne_embedding[np.where(outputs["labels"] == i)[0], :]
315 | plt.scatter(subset[:, 0], subset[:, 1], c=next(plt.gca()._get_lines.prop_cycler)['color'])
316 |
317 | plt.title("t-SNE Plot of Z0 Embeddings")
318 | plt.legend(np.unique(outputs["labels"]), loc='center left', bbox_to_anchor=(1, 0.5))
319 | plt.savefig(f"{output_path}/test_{self.setting}_Z0tsne.png", bbox_inches='tight')
320 | plt.close()
321 |
322 | # Save metrics to JSON in checkpoint folder
323 | with open(f"{output_path}/test_{self.setting}_metrics.json", 'w') as f:
324 | json.dump(out_metrics, f)
325 |
326 | # Save metrics to an easy excel conversion style
327 | with open(f"{output_path}/test_{self.setting}_excel.txt", 'w') as f:
328 | for metric in self.cfg.metrics:
329 | f.write(f"{out_metrics[f'{metric}_mean']:0.3f}({out_metrics[f'{metric}_std']:0.3f}),")
330 |
--------------------------------------------------------------------------------
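Concrete dynamics functions subclass `LatentDynamicsModel` and only need to implement `forward` (plus, optionally, `model_specific_loss` and `model_specific_plotting`). A minimal sketch of such a subclass using a GRU cell as the transition; this is purely illustrative and not one of the repository's models:

```python
import torch
import torch.nn as nn

from models.CommonDynamics import LatentDynamicsModel


class ToyRecurrentDynamics(LatentDynamicsModel):
    """Hypothetical dynamics function: a GRUCell rolled out in latent space."""
    def __init__(self, cfg):
        super().__init__(cfg)
        self.dynamics_func = nn.GRUCell(cfg.latent_dim, cfg.latent_dim)

    def forward(self, x, generation_len):
        # Infer z0 from the first z_amort frames (handled inside the shared encoder)
        z0 = self.encoder(x)                           # [B, latent_dim]

        # Roll the latent state forward for generation_len steps
        zts = [z0]
        for _ in range(generation_len - 1):
            zts.append(self.dynamics_func(zts[-1], zts[-1]))
        zts = torch.stack(zts, dim=1)                  # [B, generation_len, latent_dim]

        # Decode every latent state back into image space
        x_rec = self.decoder(zts)                      # [B, generation_len, dim, dim]
        return x_rec, zts
```
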
/models/CommonVAE.py:
--------------------------------------------------------------------------------
1 | """
2 | @file CommonVAE.py
3 |
4 | Holds the encoder/decoder architectures that are shared across the NSSM works
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | from utils.layers import Flatten, Gaussian, UnFlatten
10 | from torch.distributions import Normal, kl_divergence as kl
11 |
12 |
13 | class LatentStateEncoder(nn.Module):
14 | def __init__(self, cfg):
15 | """
16 | Holds the convolutional encoder that takes in a sequence of images and outputs the
17 | initial state of the latent dynamics
18 | :param cfg: hydra configdict
19 | """
20 | super(LatentStateEncoder, self).__init__()
21 | self.cfg = cfg
22 | self.z_amort = cfg.z_amort_train
23 |
24 | # Encoder, q(z_0 | x_{0:cfg.z_amort})
25 | self.initial_encoder = nn.Sequential(
26 |             nn.Conv2d(self.z_amort, cfg.num_filters, kernel_size=5, stride=2, padding=(2, 2)),  # 32x32 -> 16x16 (for 32x32 inputs)
27 | nn.BatchNorm2d(cfg.num_filters),
28 | nn.ReLU(),
29 |             nn.Conv2d(cfg.num_filters, cfg.num_filters * 2, kernel_size=5, stride=2, padding=(2, 2)),  # 16x16 -> 8x8
30 | nn.BatchNorm2d(cfg.num_filters * 2),
31 | nn.ReLU(),
32 | nn.Conv2d(cfg.num_filters * 2, cfg.num_filters * 8, kernel_size=5, stride=2, padding=(2, 2)),
33 | nn.BatchNorm2d(cfg.num_filters * 8),
34 | nn.ReLU(),
35 | nn.AvgPool2d(4),
36 | Flatten()
37 | )
38 |
39 | self.stochastic_out = Gaussian(cfg.num_filters * 8, cfg.latent_dim)
40 | self.deterministic_out = nn.Linear(cfg.num_filters * 8, cfg.latent_dim)
41 | self.out_act = nn.Tanh()
42 |
43 | # Holds generated z0 means and logvars for use in KL calculations
44 | self.z_means = None
45 | self.z_logvs = None
46 |
47 | def kl_z_term(self):
48 | """
49 | KL Z term, KL[q(z0|X) || N(0,1)]
50 | :return: mean klz across batch
51 | """
52 | if self.cfg.stochastic is False:
53 | return 0.0
54 |
55 | batch_size = self.z_means.shape[0]
56 |         mus, logvars = self.z_means.view([-1]), self.z_logvs.view([-1])  # flattened over batch and latent dims
57 |
58 | q = Normal(mus, torch.exp(0.5 * logvars))
59 | N = Normal(torch.zeros(len(mus), device=mus.device),
60 | torch.ones(len(mus), device=mus.device))
61 |
62 | klz = kl(q, N).view([batch_size, -1]).sum([1]).mean()
63 | return klz
64 |
65 | def forward(self, x):
66 | """
67 | Handles getting the initial state given x and saving the distributional parameters
68 | :param x: input sequences [BatchSize, GenerationLen * NumChannels, H, W]
69 | :return: z0 over the batch [BatchSize, LatentDim]
70 | """
71 | z0 = self.initial_encoder(x[:, :self.z_amort])
72 |
73 | # Apply the Gaussian layer if stochastic version
74 | if self.cfg.stochastic is True:
75 | self.z_means, self.z_logvs, z0 = self.stochastic_out(z0)
76 | else:
77 | z0 = self.deterministic_out(z0)
78 |
79 | return self.out_act(z0)
80 |
81 |
82 | class EmissionDecoder(nn.Module):
83 | def __init__(self, cfg):
84 | """
85 | Holds the convolutional decoder that takes in a batch of individual latent states and
86 | transforms them into their corresponding data space reconstructions
87 | """
88 | super(EmissionDecoder, self).__init__()
89 | self.cfg = cfg
90 |
91 |         # Dimension of the flattened convolutional feature vector fed into the decoder's first linear layer
92 | self.conv_dim = cfg.num_filters * 4 ** 3
93 |
94 | # Emission model handling z_i -> x_i
95 | self.decoder = nn.Sequential(
96 | # Transform latent vector into 4D tensor for deconvolution
97 | nn.Linear(cfg.latent_dim, self.conv_dim),
98 | UnFlatten(4),
99 |
100 | # Perform de-conv to output space
101 | nn.ConvTranspose2d(self.conv_dim // 16, cfg.num_filters * 4, kernel_size=4, stride=1, padding=(0, 0)),
102 | nn.BatchNorm2d(cfg.num_filters * 4),
103 | nn.ReLU(),
104 | nn.ConvTranspose2d(cfg.num_filters * 4, cfg.num_filters * 2, kernel_size=5, stride=2, padding=(1, 1)),
105 | nn.BatchNorm2d(cfg.num_filters * 2),
106 | nn.ReLU(),
107 | nn.ConvTranspose2d(cfg.num_filters * 2, cfg.num_filters, kernel_size=5, stride=2, padding=(1, 1), output_padding=(1, 1)),
108 | nn.BatchNorm2d(cfg.num_filters),
109 | nn.ReLU(),
110 | nn.ConvTranspose2d(cfg.num_filters, cfg.num_channels, kernel_size=5, stride=1, padding=(2, 2)),
111 | nn.Sigmoid(),
112 | )
113 |
114 | def forward(self, zts):
115 | """
116 | Handles decoding a batch of individual latent states into their corresponding data space reconstructions
117 | :param zts: latent states [BatchSize * GenerationLen, LatentDim]
118 | :return: data output [BatchSize, GenerationLen, NumChannels, H, W]
119 | """
120 | # Decode back to image space, after first flattening BatchSize * SeqLen
121 | x_rec = self.decoder(zts.contiguous().view([zts.shape[0] * zts.shape[1], -1]))
122 |
123 | # Reshape to image output
124 | x_rec = x_rec.view([zts.shape[0], x_rec.shape[0] // zts.shape[0], self.cfg.dim, self.cfg.dim])
125 | return x_rec
126 |
--------------------------------------------------------------------------------
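A quick shape check of the shared encoder/decoder pair, assuming the repository root is importable and a 32x32 single-channel setup; the config values below are placeholders for illustration, not the repository defaults:

```python
import torch
from types import SimpleNamespace

from models.CommonVAE import LatentStateEncoder, EmissionDecoder

# Placeholder config values; the real ones live under configs/ and may differ
cfg = SimpleNamespace(z_amort_train=20, num_filters=16, latent_dim=8,
                      stochastic=True, dim=32, num_channels=1)

encoder, decoder = LatentStateEncoder(cfg), EmissionDecoder(cfg)

x = torch.randn(4, 40, 32, 32)          # [B, T, H, W]; time steps are treated as channels
z0 = encoder(x)                         # [4, 8], one latent initial state per sequence
x_rec = decoder(torch.randn(4, 20, 8))  # [4, 20, 32, 32], one decoded frame per latent state
print(z0.shape, x_rec.shape)
```
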
/models/group_a/DKF.py:
--------------------------------------------------------------------------------
1 | """
2 | @file DKF.py
3 |
4 | Holds the model for the Deep Kalman Filter baseline, based off of the code from
5 | @url{https://github.com/john-x-jiang/meta_ssm/blob/main/model/model.py}
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as weight_init
10 |
11 | from models.CommonDynamics import LatentDynamicsModel
12 | from torch.distributions import Normal, kl_divergence as kl
13 |
14 |
15 | def reverse_sequence(x, seq_lengths):
16 | """
17 |     Adapted from
18 | https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm/polyphonic_data_loader.py
19 | Parameters
20 | ----------
21 | x: tensor (b, T_max, input_dim)
22 | seq_lengths: tensor (b, )
23 | Returns
24 | -------
25 | x_reverse: tensor (b, T_max, input_dim)
26 | The input x in reversed order w.r.t. time-axis
27 | """
28 | x_reverse = torch.zeros_like(x)
29 | for b in range(x.size(0)):
30 | t = seq_lengths[b]
31 | time_slice = torch.arange(t - 1, -1, -1, device=x.device)
32 | reverse_seq = torch.index_select(x[b, :, :], 0, time_slice)
33 | x_reverse[b, 0:t, :] = reverse_seq
34 |
35 | return x_reverse
36 |
37 |
38 | class RnnEncoder(nn.Module):
39 | """
40 | RNN encoder that outputs hidden states h_t using x_{t:T}
41 | Parameters
42 | ----------
43 | input_dim: int
44 | Dim. of inputs
45 | num_hidden: int
46 | Dim. of RNN hidden states
47 | n_layer: int
48 | Number of layers of RNN
49 | drop_rate: float [0.0, 1.0]
50 | RNN dropout rate between layers
51 | bd: bool
52 | Use bi-directional RNN or not
53 | Returns
54 | -------
55 | h_rnn: tensor (b, T_max, num_hidden * n_direction)
56 | RNN hidden states at every time-step
57 | """
58 | def __init__(self, cfg, input_dim, num_hidden, n_layer=1, drop_rate=0.0, bd=False, nonlin='relu',
59 | rnn_type='rnn', orthogonal_init=False, reverse_input=True):
60 | super().__init__()
61 | self.n_direction = 1 if not bd else 2
62 | self.cfg = cfg
63 | self.input_dim = input_dim
64 | self.num_hidden = num_hidden
65 | self.n_layer = n_layer
66 | self.drop_rate = drop_rate
67 | self.bd = bd
68 | self.nonlin = nonlin
69 | self.reverse_input = reverse_input
70 |
71 | if not isinstance(rnn_type, str):
72 | raise ValueError("`rnn_type` should be type str.")
73 | self.rnn_type = rnn_type
74 | if rnn_type == 'rnn':
75 | self.rnn = nn.RNN(input_size=input_dim, hidden_size=num_hidden, nonlinearity=nonlin,
76 | batch_first=True, bidirectional=bd, num_layers=n_layer, dropout=drop_rate)
77 | elif rnn_type == 'gru':
78 | self.rnn = nn.GRU(input_size=input_dim, hidden_size=num_hidden,
79 | batch_first=True, bidirectional=bd, num_layers=n_layer, dropout=drop_rate)
80 | elif rnn_type == 'lstm':
81 | self.rnn = nn.LSTM(input_size=input_dim, hidden_size=num_hidden, batch_first=True,
82 | bidirectional=bd, num_layers=n_layer, dropout=drop_rate)
83 | else:
84 |             raise ValueError("`rnn_type` must be one of ['rnn', 'gru', 'lstm'], got %s" % rnn_type)
85 |
86 | if orthogonal_init:
87 | self.init_weights()
88 |
89 | def init_weights(self):
90 | for w in self.rnn.parameters():
91 | if w.dim() > 1:
92 | weight_init.orthogonal_(w)
93 |
94 | def calculate_effect_dim(self):
95 | return self.num_hidden * self.n_direction
96 |
97 | def init_hidden(self, trainable=True):
98 | if self.rnn_type == 'lstm':
99 | h0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable)
100 | c0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable)
101 | return h0, c0
102 | else:
103 | h0 = nn.Parameter(torch.zeros(self.n_layer * self.n_direction, 1, self.num_hidden), requires_grad=trainable)
104 | return h0
105 |
106 | def kl_z_term(self):
107 | """
108 | KL Z term, KL[q(z0|X) || N(0,1)]
109 | :return: mean klz across batch
110 | """
111 | return torch.Tensor([0]).to(self.cfg.devices[0])
112 |
113 | def forward(self, x):
114 | """
115 | x: pytorch packed object
116 | input packed data; this can be obtained from
117 | `util.get_mini_batch()`
118 | h0: tensor (n_layer * n_direction, b, num_hidden)
119 | """
120 | B, T, _ = x.shape
121 | seq_lengths = T * torch.ones(B).int().to(self.cfg.devices[0])
122 | h_rnn, _ = self.rnn(x)
123 | if self.reverse_input:
124 | h_rnn = reverse_sequence(h_rnn, seq_lengths)
125 | return h_rnn
126 |
127 |
128 | class Transition_Recurrent(nn.Module):
129 | """
130 | Parameterize the diagonal Gaussian latent transition probability
131 | `p(z_t | z_{t-1})`
132 | Parameters
133 | ----------
134 | z_dim: int
135 | Dim. of latent variables
136 | transition_dim: int
137 | Dim. of transition hidden units
138 | gated: bool
139 | Use the gated mechanism to consider both linearity and non-linearity
140 | identity_init: bool
141 | Initialize the linearity transform as an identity matrix;
142 | ignored if `gated == False`
143 | clip: bool
144 | clip the value for numerical issues
145 | Returns
146 | -------
147 | mu: tensor (b, z_dim)
148 | Mean that parameterizes the Gaussian
149 | logvar: tensor (b, z_dim)
150 | Log-variance that parameterizes the Gaussian
151 | """
152 |
153 | def __init__(self, z_dim, transition_dim, identity_init=True, domain=False, stochastic=True):
154 | super().__init__()
155 | self.z_dim = z_dim
156 | self.transition_dim = transition_dim
157 | self.identity_init = identity_init
158 | self.domain = domain
159 | self.stochastic = stochastic
160 |
161 | if domain:
162 | # compute the gain (gate) of non-linearity
163 | self.lin1 = nn.Linear(z_dim * 2, transition_dim * 2)
164 | self.lin2 = nn.Linear(transition_dim * 2, z_dim)
165 | # compute the proposed mean
166 | self.lin3 = nn.Linear(z_dim * 2, transition_dim * 2)
167 | self.lin4 = nn.Linear(transition_dim * 2, z_dim)
168 | # linearity
169 | self.lin0 = nn.Linear(z_dim * 2, z_dim)
170 | else:
171 | # compute the gain (gate) of non-linearity
172 | self.lin1 = nn.Linear(z_dim, transition_dim)
173 | self.lin2 = nn.Linear(transition_dim, z_dim)
174 | # compute the proposed mean
175 | self.lin3 = nn.Linear(z_dim, transition_dim)
176 | self.lin4 = nn.Linear(transition_dim, z_dim)
177 | self.lin0 = nn.Linear(z_dim, z_dim)
178 |
179 | # compute the linearity part
180 | self.lin_n = nn.Linear(z_dim, z_dim)
181 |
182 | if identity_init:
183 | self.lin_n.weight.data = torch.eye(z_dim)
184 | self.lin_n.bias.data = torch.zeros(z_dim)
185 |
186 | # compute the variation
187 | self.lin_v = nn.Linear(z_dim, z_dim)
188 | # var activation
189 | # self.act_var = nn.Softplus()
190 | self.act_var = nn.Tanh()
191 |
192 | self.act_weight = nn.Sigmoid()
193 | self.act = nn.ELU(inplace=True)
194 |
195 | def init_z_0(self, trainable=True):
196 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable), \
197 | nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable)
198 |
199 | def forward(self, z_t_1, z_c=None):
200 | if self.domain:
201 | z_combine = torch.cat((z_t_1, z_c), dim=1)
202 | _g_t = self.act(self.lin1(z_combine))
203 | g_t = self.act_weight(self.lin2(_g_t))
204 | _h_t = self.act(self.lin3(z_combine))
205 | h_t = self.act(self.lin4(_h_t))
206 | _mu = self.lin0(z_combine)
207 | mu = (1 - g_t) * self.lin_n(_mu) + g_t * h_t
208 | mu = mu + _mu
209 | else:
210 | _g_t = self.act(self.lin1(z_t_1))
211 | g_t = self.act_weight(self.lin2(_g_t))
212 | _h_t = self.act(self.lin3(z_t_1))
213 | h_t = self.act(self.lin4(_h_t))
214 | mu = (1 - g_t) * self.lin_n(z_t_1) + g_t * h_t
215 | _mu = self.lin0(z_t_1)
216 | mu = mu + _mu
217 |
218 | if self.stochastic:
219 | _var = self.lin_v(h_t)
220 | _var = torch.clamp(_var, min=-100, max=85)
221 | var = self.act_var(_var)
222 | return mu, var
223 | else:
224 | return mu
225 |
226 |
227 | class Correction(nn.Module):
228 | """
229 | Parameterize variational distribution `q(z_t | z_{t-1}, x_{t:T})`
230 |     as a diagonal Gaussian distribution
231 | Parameters
232 | ----------
233 | z_dim: int
234 | Dim. of latent variables
235 | num_hidden: int
236 | Dim. of RNN hidden states
237 | clip: bool
238 | clip the value for numerical issues
239 | Returns
240 | -------
241 | mu: tensor (b, z_dim)
242 | Mean that parameterizes the variational Gaussian distribution
243 | logvar: tensor (b, z_dim)
244 | Log-var that parameterizes the variational Gaussian distribution
245 | """
246 | def __init__(self, z_dim, num_hidden, stochastic=True):
247 | super().__init__()
248 | self.z_dim = z_dim
249 | self.num_hidden = num_hidden
250 | self.stochastic = stochastic
251 |
252 | self.lin1 = nn.Linear(z_dim, num_hidden)
253 | self.act = nn.Tanh()
254 |
255 | self.lin2 = nn.Linear(num_hidden, z_dim)
256 | self.lin_v = nn.Linear(num_hidden, z_dim)
257 |
258 | # self.act_var = nn.Softplus()
259 | self.act_var = nn.Tanh()
260 |
261 | def init_z_q_0(self, trainable=True):
262 | return nn.Parameter(torch.zeros(self.z_dim), requires_grad=trainable)
263 |
264 | def forward(self, h_rnn, z_t_1=None, rnn_bidirection=False):
265 | """
266 | z_t_1: tensor (b, z_dim)
267 | h_rnn: tensor (b, num_hidden)
268 | """
269 | assert z_t_1 is not None
270 | h_comb_ = self.act(self.lin1(z_t_1))
271 | if rnn_bidirection:
272 | h_comb = (1.0 / 3) * (h_comb_ + h_rnn[:, :self.num_hidden] + h_rnn[:, self.num_hidden:])
273 | else:
274 | h_comb = 0.5 * (h_comb_ + h_rnn)
275 | mu = self.lin2(h_comb)
276 |
277 | if self.stochastic:
278 | _var = self.lin_v(h_comb)
279 | _var = torch.clamp(_var, min=-100, max=85)
280 | var = self.act_var(_var)
281 | return mu, var
282 | else:
283 | return mu
284 |
285 |
286 | class DKF(LatentDynamicsModel):
287 | def __init__(self, cfg):
288 | """ Deep Kalman Filter model """
289 | super().__init__(cfg)
290 |
291 | # observation
292 | self.embedding = nn.Sequential(
293 | nn.Linear(self.cfg.dim**2, self.cfg.dim**2 // 4),
294 | nn.ReLU(),
295 | # nn.Linear(self.cfg.dim**2 // 4, self.cfg.dim**2 // 8),
296 | # nn.ReLU(),
297 | nn.Linear(self.cfg.dim**2 // 4, self.cfg.num_hidden),
298 | nn.ReLU()
299 | )
300 | self.encoder = RnnEncoder(cfg, self.cfg.num_hidden, self.cfg.num_hidden, n_layer=self.cfg.rnn_layers, drop_rate=0.0,
301 | bd=self.cfg.bidirectional, nonlin='relu', rnn_type=self.cfg.rnn_type, reverse_input=False)
302 |
303 | # generative model
304 | self.transition = Transition_Recurrent(z_dim=self.cfg.latent_dim, transition_dim=self.cfg.transition_dim)
305 | self.estimation = Correction(z_dim=self.cfg.latent_dim, num_hidden=self.cfg.num_hidden, stochastic=True)
306 |
307 | # initialize hidden states
308 | self.mu_p_0, self.var_p_0 = self.transition.init_z_0(trainable=False)
309 | self.z_q_0 = self.estimation.init_z_q_0(trainable=False)
310 |
311 | # hold p and q distribution parameters for KL term
312 | self.mu_qs = None
313 | self.var_qs = None
314 | self.mu_ps = None
315 | self.var_ps = None
316 |
317 | # placeholder for z_amort
318 | self.z_amort = None
319 |
320 | def reparameterization(self, mu, var):
321 | std = torch.exp(0.5 * var)
322 | eps = torch.randn_like(std)
323 | return mu + eps * std
324 |
325 | def latent_dynamics(self, T, x_rnn):
326 | batch_size = x_rnn.shape[0]
327 |
328 | if T > self.z_amort:
329 | T_final = T
330 | else:
331 | T_final = self.z_amort
332 |
333 | z_ = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0])
334 | if not self.cfg.use_q_forecast:
335 | mu_ps = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0])
336 | var_ps = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0])
337 | mu_qs = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0])
338 | var_qs = torch.zeros([batch_size, self.z_amort, self.cfg.latent_dim]).to(self.cfg.devices[0])
339 | else:
340 | mu_ps = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0])
341 | var_ps = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0])
342 | mu_qs = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0])
343 | var_qs = torch.zeros([batch_size, T_final, self.cfg.latent_dim]).to(self.cfg.devices[0])
344 | x_qs = torch.zeros([batch_size, T_final, self.cfg.dim, self.cfg.dim]).to(self.cfg.devices[0])
345 |
346 | z_q_0 = self.z_q_0.expand(batch_size, self.cfg.latent_dim) # q(z_0)
347 | mu_p_0 = self.mu_p_0.expand(batch_size, 1, self.cfg.latent_dim)
348 | var_p_0 = self.var_p_0.expand(batch_size, 1, self.cfg.latent_dim)
349 | z_prev = z_q_0
350 | z_[:, 0, :] = z_prev
351 |
352 | for t in range(self.z_amort):
353 | # zt = self.transition(z_prev)
354 | mu_q, var_q = self.estimation(x_rnn[:, t, :], z_prev, rnn_bidirection=self.cfg.bidirectional)
355 | zt_q = self.reparameterization(mu_q, var_q)
356 | z_prev = zt_q
357 |
358 | # p(z_{t+1} | z_t)
359 | mu_p, var_p = self.transition(z_prev)
360 | zt_p = self.reparameterization(mu_p, var_p)
361 |
362 | z_[:, t, :] = zt_q
363 | mu_qs[:, t, :] = mu_q
364 | var_qs[:, t, :] = var_q
365 | mu_ps[:, t, :] = mu_p
366 | var_ps[:, t, :] = var_p
367 |
368 | if T > self.z_amort:
369 | for t in range(self.z_amort, T):
370 | if self.cfg.use_q_forecast:
371 | y = self.decoder(z_prev.unsqueeze(1))
372 | x_qs[:,t,:,:] = y.squeeze()
373 | y = self.encoder(self.embedding(y.view(y.size(0), 1, -1)))
374 | mu_q, var_q = self.estimation(y[:,0,:], z_prev, rnn_bidirection=self.cfg.bidirectional)
375 | zt_q = self.reparameterization(mu_q, var_q)
376 | z_prev = zt_q
377 |
378 | mu_p, var_p = self.transition(z_prev)
379 | zt_p = self.reparameterization(mu_p, var_p)
380 |
381 | z_[:, t, :] = zt_q
382 | mu_qs[:, t, :] = mu_q
383 | var_qs[:, t, :] = var_q
384 | mu_ps[:, t, :] = mu_p
385 | var_ps[:, t, :] = var_p
386 | else:
387 | # p(z_{t+1} | z_t)
388 | mu_p, var_p = self.transition(z_prev)
389 | zt_p = self.reparameterization(mu_p, var_p)
390 | z_[:, t, :] = zt_p
391 | z_prev = zt_p
392 |
393 | mu_ps = torch.cat([mu_p_0, mu_ps[:, :-1, :]], dim=1)
394 | var_ps = torch.cat([var_p_0, var_ps[:, :-1, :]], dim=1)
395 |
396 | self.mu_qs, self.var_qs = mu_qs, var_qs
397 | self.mu_ps, self.var_ps = mu_ps, var_ps
398 | if self.cfg.use_q_forecast:
399 | return z_, mu_qs, var_qs, mu_ps, var_ps, x_qs
400 | else:
401 | return z_, mu_qs, var_qs, mu_ps, var_ps
402 |
403 | def forward(self, x, generation_len):
404 | if self.trainer.training:
405 | self.z_amort = self.cfg.z_amort_train
406 | else:
407 | self.z_amort = self.cfg.z_amort_test
408 |
409 | batch_size = x.size(0)
410 |
411 | x = x.view(batch_size, generation_len, -1)
412 | x = self.embedding(x)
413 | x_rnn = self.encoder(x)
414 |
415 | if not self.cfg.use_q_forecast:
416 | z_, mu_qs, var_qs, mu_ps, var_ps = self.latent_dynamics(generation_len, x_rnn)
417 | x_ = self.decoder(z_)
418 | else:
419 | z_, mu_qs, var_qs, mu_ps, var_ps, x_ = self.latent_dynamics(generation_len, x_rnn)
420 | x_[:,:self.z_amort,:,:] = self.decoder(z_[:,:self.z_amort,:])
421 | return x_, z_
422 |
423 | def model_specific_loss(self, *args, train=True):
424 | """ KL term between the p and q distributions (reconstruction and estimation) """
425 | q = Normal(self.mu_qs, torch.exp(0.5 * self.var_qs))
426 | p = Normal(self.mu_ps, torch.exp(0.5 * self.var_ps))
427 | return kl(q, p).sum([-1]).mean()
428 |
--------------------------------------------------------------------------------
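The DKF objective adds a step-wise KL term between the correction distribution q and the transition prior p, realized in `model_specific_loss` above with `torch.distributions`. A standalone illustration of that term with random values (the stored `var` tensors are treated as log-variances, matching the `exp(0.5 * var)` usage above):

```python
import torch
from torch.distributions import Normal, kl_divergence as kl

B, T, D = 8, 20, 16
mu_q, logvar_q = torch.randn(B, T, D), torch.randn(B, T, D)
mu_p, logvar_p = torch.randn(B, T, D), torch.randn(B, T, D)

q = Normal(mu_q, torch.exp(0.5 * logvar_q))
p = Normal(mu_p, torch.exp(0.5 * logvar_p))

kl_term = kl(q, p).sum([-1]).mean()  # sum over the latent dim, mean over batch and time
print(kl_term)
```
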
/models/group_a/VRNN.py:
--------------------------------------------------------------------------------
1 | """
2 | @file VRNN.py
3 |
4 | Holds the model for the Variational Recurrent Neural Network baseline, source code modified from
5 | @url{https://github.com/XiaoyuBIE1994/DVAE/blob/master/dvae/model/vrnn.py}
6 | """
7 | import torch
8 | import torch.nn as nn
9 |
10 | from collections import OrderedDict
11 | from models.CommonDynamics import LatentDynamicsModel
12 | from torch.distributions import Normal, kl_divergence as kl
13 |
14 |
15 | class FakeEncoder(nn.Module):
16 | def __init__(self, cfg):
17 | super().__init__()
18 | self.cfg = cfg
19 |
20 | def kl_z_term(self):
21 | return torch.Tensor([0.]).to(self.cfg.devices[0])
22 |
23 |
24 | class VRNN(LatentDynamicsModel):
25 | def __init__(self, cfg):
26 |         """ Variational Recurrent Neural Network (VRNN) baseline with per-step stochastic latent variables """
27 | super().__init__(cfg)
28 | self.encoder = FakeEncoder(cfg)
29 | self.decoder = None
30 |
31 | ### General parameters
32 | self.x_dim = self.cfg.dim ** 2
33 | self.z_dim = self.cfg.latent_dim
34 | self.dropout_p = 0.2
35 | self.y_dim = self.x_dim
36 | self.activation = nn.LeakyReLU(0.1)
37 | self.sigmoid = nn.Sigmoid()
38 | self.z_amort = None
39 |
40 | ### Feature extractors
41 | self.dense_x = [256]
42 | self.dense_z = [256]
43 |
44 | ### Dense layers
45 | self.dense_hx_z = [128]
46 | self.dense_hz_x = [256]
47 | self.dense_h_z = [128]
48 |
49 | ### RNN
50 | self.dim_RNN = 16
51 | self.num_RNN = 1
52 |
53 | ### Beta-loss
54 | self.beta = 1
55 |
56 | ###########################
57 | #### Feature extractor ####
58 | ###########################
59 | # x
60 | dic_layers = OrderedDict()
61 | if len(self.dense_x) == 0:
62 | dim_feature_x = self.x_dim
63 | dic_layers['Identity'] = nn.Identity()
64 | else:
65 | dim_feature_x = self.dense_x[-1]
66 | for n in range(len(self.dense_x)):
67 | if n == 0:
68 | dic_layers['linear' + str(n)] = nn.Linear(self.x_dim, self.dense_x[n])
69 | else:
70 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x[n - 1], self.dense_x[n])
71 | dic_layers['activation' + str(n)] = self.activation
72 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
73 | self.feature_extractor_x = nn.Sequential(dic_layers)
74 | # z
75 | dic_layers = OrderedDict()
76 | if len(self.dense_z) == 0:
77 | dim_feature_z = self.z_dim
78 | dic_layers['Identity'] = nn.Identity()
79 | else:
80 | dim_feature_z = self.dense_z[-1]
81 | for n in range(len(self.dense_z)):
82 | if n == 0:
83 | dic_layers['linear' + str(n)] = nn.Linear(self.z_dim, self.dense_z[n])
84 | else:
85 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_z[n - 1], self.dense_z[n])
86 | dic_layers['activation' + str(n)] = self.activation
87 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
88 | self.feature_extractor_z = nn.Sequential(dic_layers)
89 |
90 | ######################
91 | #### Dense layers ####
92 | ######################
93 | # 1. h_t, x_t to z_t (Inference)
94 | dic_layers = OrderedDict()
95 | if len(self.dense_hx_z) == 0:
96 | dim_hx_z = self.dim_RNN + dim_feature_x
97 | dic_layers['Identity'] = nn.Identity()
98 | else:
99 | dim_hx_z = self.dense_hx_z[-1]
100 | for n in range(len(self.dense_hx_z)):
101 | if n == 0:
102 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x[-1] + self.dim_RNN, self.dense_hx_z[n])
103 | else:
104 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_hx_z[n - 1], self.dense_hx_z[n])
105 | dic_layers['activation' + str(n)] = self.activation
106 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
107 | self.mlp_hx_z = nn.Sequential(dic_layers)
108 | self.inf_mean = nn.Linear(dim_hx_z, self.z_dim)
109 | self.inf_logvar = nn.Linear(dim_hx_z, self.z_dim)
110 |
111 | # 2. h_t to z_t (Generation z)
112 | dic_layers = OrderedDict()
113 | if len(self.dense_h_z) == 0:
114 | dim_h_z = self.dim_RNN
115 | dic_layers['Identity'] = nn.Identity()
116 | else:
117 | dim_h_z = self.dense_h_z[-1]
118 | for n in range(len(self.dense_h_z)):
119 | if n == 0:
120 | dic_layers['linear' + str(n)] = nn.Linear(self.dim_RNN, self.dense_h_z[n])
121 | else:
122 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_h_z[n - 1], self.dense_h_z[n])
123 | dic_layers['activation' + str(n)] = self.activation
124 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
125 | self.mlp_h_z = nn.Sequential(dic_layers)
126 | self.prior_mean = nn.Linear(dim_h_z, self.z_dim)
127 | self.prior_logvar = nn.Linear(dim_h_z, self.z_dim)
128 |
129 | # 3. h_t, z_t to x_t (Generation x)
130 | dic_layers = OrderedDict()
131 | if len(self.dense_hz_x) == 0:
132 | dim_hz_x = self.dim_RNN + dim_feature_z
133 | dic_layers['Identity'] = nn.Identity()
134 | else:
135 | dim_hz_x = self.dense_hz_x[-1]
136 | for n in range(len(self.dense_hz_x)):
137 | if n == 0:
138 | dic_layers['linear' + str(n)] = nn.Linear(self.dim_RNN + dim_feature_z, self.dense_hz_x[n])
139 | else:
140 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_hz_x[n - 1], self.dense_hz_x[n])
141 | dic_layers['activation' + str(n)] = self.activation
142 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
143 | self.mlp_hz_x = nn.Sequential(dic_layers)
144 | self.gen_out = nn.Linear(dim_hz_x, self.y_dim)
145 |
146 | ####################
147 | #### Recurrence ####
148 | ####################
149 | self.rnn = nn.LSTM(dim_feature_x + dim_feature_z, self.dim_RNN, self.num_RNN)
150 |
151 | def reparameterization(self, mean, logvar):
152 | std = torch.exp(0.5 * logvar)
153 | eps = torch.randn_like(std)
154 | return torch.addcmul(mean, eps, std)
155 |
156 | def generation_x(self, feature_zt, h_t):
157 | dec_input = torch.cat((feature_zt, h_t), 2)
158 | dec_output = self.mlp_hz_x(dec_input)
159 | y_t = self.gen_out(dec_output)
160 | y_t = self.sigmoid(y_t)
161 | return y_t
162 |
163 | def generation_z(self, h):
164 | prior_output = self.mlp_h_z(h)
165 | mean_prior = self.prior_mean(prior_output)
166 | logvar_prior = self.prior_logvar(prior_output)
167 | return mean_prior, logvar_prior
168 |
169 | def inference(self, feature_xt, h_t):
170 | enc_input = torch.cat((feature_xt, h_t), 2)
171 | enc_output = self.mlp_hx_z(enc_input)
172 | mean_zt = self.inf_mean(enc_output)
173 | logvar_zt = self.inf_logvar(enc_output)
174 | return mean_zt, logvar_zt
175 |
176 | def recurrence(self, feature_xt, feature_zt, h_t, c_t):
177 | rnn_input = torch.cat((feature_xt, feature_zt), -1)
178 | _, (h_tp1, c_tp1) = self.rnn(rnn_input, (h_t, c_t))
179 | return h_tp1, c_tp1
180 |
181 | def forward(self, x, generation_len):
182 | batch_size = x.shape[0]
183 | seq_len = generation_len
184 |
185 | # Input is an image so reduce down to [batch_size, seq_len, flattened_dim]
186 | x = x.reshape(x.shape[0], x.shape[1], -1)
187 |
188 | # Permute it to [seq_len, batch_size, flattened_dim]
189 | x = x.permute(1, 0, 2)
190 |
191 | # Create variable holder and send to GPU if needed
192 | self.z_mean = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0])
193 | self.z_logvar = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0])
194 | y = torch.zeros((seq_len, batch_size, self.y_dim)).to(self.cfg.devices[0])
195 | self.z = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.cfg.devices[0])
196 | h = torch.zeros((seq_len, batch_size, self.dim_RNN)).to(self.cfg.devices[0])
197 | h_t = torch.zeros(self.num_RNN, batch_size, self.dim_RNN).to(self.cfg.devices[0])
198 | c_t = torch.zeros(self.num_RNN, batch_size, self.dim_RNN).to(self.cfg.devices[0])
199 |
200 | # For the observed frames, use real input; otherwise use previous generated frame
201 | feature_x_obs = self.feature_extractor_x(x[:self.z_amort])
202 | for t in range(generation_len):
203 | if t < self.z_amort:
204 | feature_xt = feature_x_obs[t, :, :].unsqueeze(0)
205 | else:
206 | feature_xt = self.feature_extractor_x(y_prev)
207 |
208 | h_t_last = h_t.view(self.num_RNN, 1, batch_size, self.dim_RNN)[-1, :, :, :]
209 | mean_zt, logvar_zt = self.inference(feature_xt, h_t_last)
210 | z_t = self.reparameterization(mean_zt, logvar_zt)
211 | feature_zt = self.feature_extractor_z(z_t)
212 | y_t = self.generation_x(feature_zt, h_t_last)
213 | y_prev = y_t.detach()
214 | self.z_mean[t, :, :] = mean_zt
215 | self.z_logvar[t, :, :] = logvar_zt
216 | self.z[t, :, :] = torch.squeeze(z_t)
217 | y[t, :, :] = torch.squeeze(y_t)
218 | h[t, :, :] = torch.squeeze(h_t_last)
219 | h_t, c_t = self.recurrence(feature_xt, feature_zt, h_t, c_t) # recurrence for t+1
220 |
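    | # Evaluate the prior p(z_t | h_t) over all stored hidden states; used in the KL term of model_specific_loss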
221 | self.z_mean_p, self.z_logvar_p = self.generation_z(h)
222 |
223 | # Reshape and permute reconstructions + embeddings back to useable shapes
224 | y = y.permute(1, 0, 2).reshape([batch_size, seq_len, self.cfg.dim, self.cfg.dim])
225 | embeddings = self.z.permute(1, 0, 2)
226 | return y, embeddings
227 |
228 | def model_specific_loss(self, x, x_rec, train=True):
229 | """ KL term between the parameter distribution w and a normal prior"""
230 | # Reshape to [BS, SL, LatentDim]
231 | z_mus = self.z_mean.permute(1, 0, 2).reshape([x.shape[0], -1])
232 | z_logvar = self.z_logvar.permute(1, 0, 2).reshape([x.shape[0], -1])
233 |
234 | z_mus_prior = self.z_mean_p.permute(1, 0, 2).reshape([x.shape[0], -1])
235 | z_logvar_prior = self.z_logvar_p.permute(1, 0, 2).reshape([x.shape[0], -1])
236 |
237 | q = Normal(z_mus, torch.exp(0.5 * z_logvar))
238 | N = Normal(z_mus_prior, torch.exp(0.5 * z_logvar_prior))
239 | return kl(q, N).sum([-1]).mean()
--------------------------------------------------------------------------------
/models/group_b1/KVAE.py:
--------------------------------------------------------------------------------
1 | """
2 | @file KVAE.py
3 |
4 | Holds the model for the Kalman Variational Auto-encoder, source code modified from
5 | @url{https://github.com/XiaoyuBIE1994/DVAE/blob/master/dvae/model/kvae.py}
6 | """
7 | import torch
8 | import numpy as np
9 | import torch.nn as nn
10 |
11 | from collections import OrderedDict
12 | from models.CommonDynamics import LatentDynamicsModel
13 | from torch.distributions.multivariate_normal import MultivariateNormal
14 |
15 |
16 | class FakeEncoder(nn.Module):
17 | def __init__(self, cfg):
18 | super().__init__()
19 | self.cfg = cfg
20 |
21 | def kl_z_term(self):
22 | return torch.Tensor([0.]).to(self.cfg.devices[0])
23 |
24 |
25 | class KVAE(LatentDynamicsModel):
26 | def __init__(self, cfg):
27 | """ Latent dynamics as parameterized by a global deterministic neural ODE """
28 | super().__init__(cfg)
29 | self.encoder = FakeEncoder(cfg)
30 |
31 | ## General parameters
32 | self.x_dim = self.cfg.dim**2
33 | self.y_dim = self.cfg.dim**2
34 | self.a_dim = self.cfg.latent_dim * 2
35 | self.z_dim = self.cfg.latent_dim
36 | self.u_dim = self.cfg.latent_dim * 2
37 | self.dropout_p = 0.2
38 | self.activation = nn.ReLU()
39 |
40 | # VAE
41 | self.dense_x_a = [512, 256]
42 | self.dense_a_x = [512, 256]
43 |
44 | # LGSSM
45 | self.init_kf_mat = 0.05
46 | self.noise_transition = 0.08
47 | self.noise_emission = 0.03
48 | self.init_cov = 20
49 |
50 | # Dynamics params (alpha)
51 | self.K = 4
52 | self.dim_RNN_alpha = 128
53 | self.num_RNN_alpha = 2
54 |
55 | self.build()
56 |
57 | def build(self):
58 | #############
59 | #### VAE ####
60 | #############
61 | # 1. Inference of a_t
62 | dic_layers = OrderedDict()
63 | if len(self.dense_x_a) == 0:
64 | dim_x_a = self.x_dim
65 | dic_layers["Identity"] = nn.Identity()
66 | else:
67 | dim_x_a = self.dense_x_a[-1]
68 | for n in range(len(self.dense_x_a)):
69 | if n == 0:
70 | dic_layers["linear" + str(n)] = nn.Linear(self.x_dim, self.dense_x_a[n])
71 | else:
72 | dic_layers['linear' + str(n)] = nn.Linear(self.dense_x_a[n - 1], self.dense_x_a[n])
73 | dic_layers['activation' + str(n)] = self.activation
74 | dic_layers['dropout' + str(n)] = nn.Dropout(p=self.dropout_p)
75 | self.mlp_x_a = nn.Sequential(dic_layers)
76 | self.inf_mean = nn.Linear(dim_x_a, self.a_dim)
77 | self.inf_logvar = nn.Linear(dim_x_a, self.a_dim)
78 |
79 | # 2. Generation of x_t
80 | dic_layers = OrderedDict()
81 | if len(self.dense_a_x) == 0:
82 | dim_a_x = self.a_dim
83 | dic_layers["Identity"] = nn.Identity()
84 | else:
85 | dim_a_x = self.dense_a_x[-1]
86 | for n in range(len(self.dense_a_x)):
87 | if n == 0:
88 | dic_layers["linear" + str(n)] = nn.Linear(self.a_dim, self.dense_a_x[n])
89 | else:
90 | dic_layers["linear" + str(n)] = nn.Linear(self.dense_a_x[n - 1], self.dense_a_x[n])
91 | dic_layers["activation" + str(n)] = self.activation
92 | dic_layers["dropout" + str(n)] = nn.Dropout(p=self.dropout_p)
93 | self.mlp_a_x = nn.Sequential(dic_layers)
94 | self.gen_logvar = nn.Linear(dim_a_x, self.x_dim)
95 |
96 | ###############
97 | #### LGSSM ####
98 | ###############
99 | # Initializers for the LGSSM variables, created as torch.float32 tensors
100 | # A is an identity matrix
101 | # B and C are randomly sampled from a Gaussian
102 | # Q and R are isotropic covariance matrices
103 | # z = Az + Bu
104 | # a = Cz
105 | self.A = torch.tensor(np.array([np.eye(self.z_dim) for _ in range(self.K)]), dtype=torch.float32,
106 | requires_grad=True, device=self.cfg.devices[0]) # (K, z_dim, z_dim)
107 | self.B = torch.tensor(
108 | np.array([self.init_kf_mat * np.random.randn(self.z_dim, self.u_dim) for _ in range(self.K)]),
109 | dtype=torch.float32, requires_grad=True, device=self.cfg.devices[0]) # (K, z_dim, u_dim)
110 | self.C = torch.tensor(
111 | np.array([self.init_kf_mat * np.random.randn(self.a_dim, self.z_dim) for _ in range(self.K)]),
112 | dtype=torch.float32, requires_grad=True, device=self.cfg.devices[0]) # (K, a_dim, z_dim)
113 | self.Q = self.noise_transition * torch.eye(self.z_dim).to(self.cfg.devices[0]) # (z_dim, z_dim)
114 | self.R = self.noise_emission * torch.eye(self.a_dim).to(self.cfg.devices[0]) # (a_dim, a_dim)
115 | self._I = torch.eye(self.z_dim).to(self.cfg.devices[0]) # (z_dim, z_dim)
116 |
117 | ###############
118 | #### Alpha ####
119 | ###############
120 | self.a_init = torch.zeros((1, self.a_dim), requires_grad=True, device=self.cfg.devices[0]) # (bs, a_dim)
121 | self.rnn_alpha = nn.LSTM(self.a_dim, self.dim_RNN_alpha, self.num_RNN_alpha, bidirectional=False)
122 | self.mlp_alpha = nn.Sequential(nn.Linear(self.dim_RNN_alpha, self.K), nn.Softmax(dim=-1))
123 |
124 | ############################
125 | #### Scheduler Training ####
126 | ############################
127 | self.A = nn.Parameter(self.A)
128 | self.B = nn.Parameter(self.B)
129 | self.C = nn.Parameter(self.C)
130 | self.a_init = nn.Parameter(self.a_init)
131 | kf_params = [self.A, self.B, self.C, self.a_init]
132 |
133 | self.iter_kf = (i for i in kf_params)
134 | self.iter_vae = self.concat_iter(self.mlp_x_a.parameters(),
135 | self.inf_mean.parameters(),
136 | self.inf_logvar.parameters(),
137 | self.mlp_a_x.parameters(),
138 | self.gen_logvar.parameters())
139 | self.iter_alpha = self.concat_iter(self.rnn_alpha.parameters(),
140 | self.mlp_alpha.parameters())
141 | self.iter_vae_kf = self.concat_iter(self.iter_vae, self.iter_kf)
142 | self.iter_all = self.concat_iter(self.iter_kf, self.iter_vae, self.iter_alpha)
143 |
144 | def concat_iter(self, *iter_list):
145 | for i in iter_list:
146 | yield from i
147 |
148 | def reparameterization(self, mean, logvar):
149 | std = torch.exp(0.5 * logvar)
150 | eps = torch.randn_like(std)
151 | return eps.mul(std).add_(mean)
152 |
153 | def inference(self, x):
154 | x_a = self.mlp_x_a(x)
155 | a_mean = self.inf_mean(x_a)
156 | a_logvar = self.inf_logvar(x_a)
157 | a = self.reparameterization(a_mean, a_logvar)
158 | return a, a_mean, a_logvar
159 |
160 | def get_alpha(self, a_tm1):
161 | """
162 | Dynamics parameter network alpha for mixing transitions in an SSM
163 | Unlike the original code, only an RNN is used here to compute alpha
164 | """
165 | alpha, _ = self.rnn_alpha(a_tm1) # (seq_len, bs, dim_alpha)
166 | alpha = self.mlp_alpha(alpha) # (seq_len, bs, K), softmax on K dimension
167 | return alpha
168 |
169 | def generation_x(self, a):
170 | a_x = self.mlp_a_x(a)
171 | log_y = self.gen_logvar(a_x)
172 | y = torch.exp(log_y)
173 | return y
174 |
175 | def kf_smoother(self, a, u, K, A, B, C, R, Q, optimal_gain=False, alpha_sq=1):
176 | """"
177 | Kalman Smoother, refer to Murphy's book (MLAPP), section 18.3
178 | Difference from KVAE source code:
179 | - no imputation
180 | - only RNN for the calculation of alpha
181 | - different notations (rather than using same notations as Murphy's book ,we use notation from model KVAE)
182 | # z_t = A_t * z_tm1 + B_t * u_t
183 | #a_t = C_t * z_t
184 | Input:
185 | - a, (seq_len, bs, a_dim)
186 | - u, (seq_len, bs, u_dim)
187 | - alpha, (seq_len, bs, alpha_dim)
188 | - K, real number
189 | - A, (K, z_dim, z_dim)
190 | - B, (K, z_dim, u_dim)
191 | - C, (K, a_dim, z_dim)
192 | - R, (z_dim, z_dim)
193 | - Q , (a_dim, a_dim)
194 | """
195 | # Initialization
196 | seq_len = a.shape[0]
197 | batch_size = a.shape[1]
198 | self.mu = torch.zeros((batch_size, self.z_dim)).to(self.device) # (bs, z_dim), z_0
199 | self.Sigma = self.init_cov * torch.eye(self.z_dim).unsqueeze(0).repeat(batch_size, 1, 1).to(self.device) # (bs, z_dim, z_dim), Sigma_0
200 | mu_pred = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim)
201 | mu_filter = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim)
202 | mu_smooth = torch.zeros((seq_len, batch_size, self.z_dim)).to(self.device) # (seq_len, bs, z_dim)
203 | Sigma_pred = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim)
204 | Sigma_filter = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim)
205 | Sigma_smooth = torch.zeros((seq_len, batch_size, self.z_dim, self.z_dim)).to(self.device) # (seq_len, bs, z_dim, z_dim)
206 |
207 | # Calculate alpha, initial observation a_init is assumed to be zero and can be learned
208 | a_init_expand = self.a_init.unsqueeze(1).repeat(1, batch_size, 1) # (1, bs, a_dim)
209 | a_tm1 = torch.cat([a_init_expand, a[:-1, :, :]], 0) # (seq_len, bs, a_dim)
210 | alpha = self.get_alpha(a_tm1) # (seq_len, bs, K)
211 |
212 | # Calculate the mixture of A, B and C
213 | A_flatten = A.view(K, self.z_dim * self.z_dim) # (K, z_dim*z_dim)
214 | B_flatten = B.view(K, self.z_dim * self.u_dim) # (K, z_dim*u_dim)
215 | C_flatten = C.view(K, self.a_dim * self.z_dim) # (K, a_dim*z_dim)
216 | A_mix = alpha.matmul(A_flatten).view(seq_len, batch_size, self.z_dim, self.z_dim)
217 | B_mix = alpha.matmul(B_flatten).view(seq_len, batch_size, self.z_dim, self.u_dim)
218 | C_mix = alpha.matmul(C_flatten).view(seq_len, batch_size, self.a_dim, self.z_dim)
219 |
220 | # Forward filter
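    | # Standard Kalman predict / update recursion using the per-timestep mixed matrices (A_t, B_t, C_t)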
221 | for t in range(seq_len):
222 | # Mixture of A, B and C
223 | A_t = A_mix[t] # (bs, z_dim, z_dim)
224 | B_t = B_mix[t] # (bs, z_dim, u_dim)
225 | C_t = C_mix[t] # (bs, a_dim, z_dim)
226 |
227 | if t == 0:
228 | mu_t_pred = self.mu.unsqueeze(-1) # (bs, z_dim, 1)
229 | Sigma_t_pred = self.Sigma
230 | else:
231 | u_t = u[t, :, :] # (bs, u_dim)
232 | mu_t_pred = A_t.bmm(mu_t) + B_t.bmm(u_t.unsqueeze(-1)) # (bs, z_dim, 1), z_{t|t-1}
233 | Sigma_t_pred = alpha_sq * A_t.bmm(Sigma_t).bmm(
234 | A_t.transpose(1, 2)) + self.Q # (bs, z_dim, z_dim), Sigma_{t|t-1}
235 | # alpha_sq (>=1) is a fading-memory control that indicates how much to forget past measurements; see the 'FilterPy' library for more info
236 |
237 | # Residual
238 | a_pred = C_t.bmm(mu_t_pred) # (bs, a_dim, z_dim) x (bs, z_dim, 1)
239 | res_t = a[t, :, :].unsqueeze(-1) - a_pred # (bs, a_dim, 1)
240 |
241 | # Kalman gain
242 | S_t = C_t.bmm(Sigma_t_pred).bmm(C_t.transpose(1, 2)) + self.R # (bs, a_dim, a_dim)
243 | S_t_inv = S_t.inverse()
244 | K_t = Sigma_t_pred.bmm(C_t.transpose(1, 2)).bmm(S_t_inv) # (bs, z_dim, a_dim)
245 |
246 | # Update
247 | mu_t = mu_t_pred + K_t.bmm(res_t) # (bs, z_dim, 1)
248 | I_KC = self._I - K_t.bmm(C_t) # (bs, z_dim, z_dim)
249 | if optimal_gain:
250 | Sigma_t = I_KC.bmm(Sigma_t_pred) # (bs, z_dim, z_dim), only valid with optimal Kalman gain
251 | else:
252 | Sigma_t = I_KC.bmm(Sigma_t_pred).bmm(I_KC.transpose(1, 2)) + K_t.matmul(self.R).matmul(
253 | K_t.transpose(1, 2)) # (bs, z_dim, z_dim), general case
254 |
255 | # Save cache
256 | mu_pred[t] = mu_t_pred.view(batch_size, self.z_dim)
    | mu_filter[t] = mu_t.view(batch_size, self.z_dim)  # store the filtered mean, needed by the backward smoothing pass below
257 | Sigma_pred[t] = Sigma_t_pred
258 | Sigma_filter[t] = Sigma_t
259 |
260 | # Add the final state from filter to the smoother as initialization
261 | mu_smooth[-1] = mu_filter[-1]
262 | Sigma_smooth[-1] = Sigma_filter[-1]
263 |
264 | # Backward smooth, reverse loop from the penultimate state
265 | for t in range(seq_len - 2, -1, -1):
266 | # Backward Kalman gain
267 | J_t = Sigma_filter[t].bmm(A_mix[t + 1].transpose(1, 2)).bmm(
268 | Sigma_pred[t + 1].inverse()) # (bs, z_dim, z_dim)
269 |
270 | # Backward smoothing
271 | dif_mu_tp1 = (mu_smooth[t + 1] - mu_filter[t + 1]).unsqueeze(-1) # (bs, z_dim, 1)
272 | mu_smooth[t] = mu_filter[t] + J_t.matmul(dif_mu_tp1).view(batch_size, self.z_dim) # (bs, z_dim)
273 | dif_Sigma_tp1 = Sigma_smooth[t + 1] - Sigma_pred[t + 1] # (bs, z_dim, z_dim)
274 | Sigma_smooth[t] = Sigma_filter[t] + J_t.bmm(dif_Sigma_tp1).bmm(J_t.transpose(1, 2)) # (bs, z_dim, z_dim)
275 |
276 | # Generate a from smoothing z
277 | a_gen = C_mix.matmul(mu_smooth.unsqueeze(-1)).view(seq_len, batch_size, self.a_dim) # (seq_len, bs, a_dim)
278 | return a_gen, mu_smooth, Sigma_smooth, A_mix, B_mix, C_mix
279 |
280 | def forward(self, x, generation_len):
281 | # Input x: [batch_size, seq_len, dim, dim] image sequences
282 | # The VAE inference and LGSSM below expect [seq_len, batch_size, x_dim]
284 |
285 | # Flatten the image frames down to [batch_size, seq_len, flattened_dim], then permute to [seq_len, batch_size, flattened_dim]
286 | x = x.reshape(x.shape[0], x.shape[1], -1)
287 | x = x.permute(1, 0, 2)
288 | seq_len = x.shape[0]
289 | batch_size = x.shape[1]
290 |
291 | # main part
292 | self.a, self.a_mean, self.a_logvar = self.inference(x)
293 | u_0 = torch.zeros(1, batch_size, self.u_dim).to(self.cfg.devices[0])
294 | self.u = torch.cat((u_0, self.a[:-1]), 0)
295 | a_gen, self.mu_smooth, self.Sigma_smooth, self.A_mix, self.B_mix, self.C_mix = self.kf_smoother(
296 | self.a, self.u, self.K, self.A, self.B, self.C, self.R, self.Q
297 | )
298 | self.y = self.generation_x(a_gen)
299 |
300 | # output of NN: (seq_len, batch_size, dim)
301 | # output of model: (batch_size, dim, seq_len) or (dim, seq_len)
302 | self.y = self.y.permute(1, 0, 2)
303 |
304 | # Reshape y to image dimensions
305 | self.y = self.y.reshape([batch_size, seq_len, self.cfg.dim, self.cfg.dim])
306 | return self.y, torch.zeros([batch_size, seq_len, self.cfg.latent_dim])
307 |
308 | def model_specific_loss(self, x, x_rec, train=True):
309 | # batch_size, seq_len = x.shape[0], x.shape[1]
310 | # Flatten the image frames down to [batch_size, seq_len, flattened_dim], then permute to [seq_len, batch_size, flattened_dim]
311 | x = x.reshape(x.shape[0], x.shape[1], -1)
312 | x = x.permute(1, 0, 2)
313 |
314 | seq_len = x.shape[0]
315 | batch_size = x.shape[1]
316 |
317 | # log q_{\phi}(a_hat | x), Gaussian
318 | log_qa_given_x = - 0.5 * self.a_logvar - torch.pow(self.a - self.a_mean, 2) / (2 * torch.exp(self.a_logvar))
319 |
320 | # log p_{\gamma}(a_tilde, z_tilde | u)  (in the sub-comments below, the 'tilde' is dropped for simplicity)
321 | # >>> log p(z_t | z_tm1, u_t), transition
322 | # Build the smoothing distribution from its Cholesky factor for numerical stability
323 | L = torch.linalg.cholesky(self.Sigma_smooth)
324 | mvn_smooth = MultivariateNormal(self.mu_smooth, scale_tril=L)
328 | z_smooth = mvn_smooth.sample() # (seq_len, bs, z_dim)
329 | Az_tm1 = self.A_mix[:-1].matmul(z_smooth[:-1].unsqueeze(-1)).view(seq_len - 1, batch_size, -1) # (seq_len-1, bs, z_dim)
330 | Bu_t = self.B_mix[:-1].matmul(self.u[:-1].unsqueeze(-1)).view(seq_len - 1, batch_size, -1) # (seq_len-1, bs, z_dim)
331 | mu_t_transition = Az_tm1 + Bu_t
332 | z_t_transition = z_smooth[1:]
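    | # For a fixed covariance Q, N(x; mu, Q) depends only on (x - mu), so evaluating the density at the
    | # predicted mean mu_t_transition is equivalent to evaluating the transition density p(z_t | z_tm1, u_t) at z_t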
333 | mvn_transition = MultivariateNormal(z_t_transition, self.Q)
334 | log_prob_transition = mvn_transition.log_prob(mu_t_transition)
335 |
336 | # >>> log p(z_0 | z_init), init state
337 | z_0 = z_smooth[0]
338 | mvn_0 = MultivariateNormal(self.mu, self.Sigma)
339 | log_prob_0 = mvn_0.log_prob(z_0)
340 |
341 | # >>> log p(a_t | z_t), emission
342 | Cz_t = self.C_mix.matmul(z_smooth.unsqueeze(-1)).view(seq_len, batch_size, self.a_dim)
343 | mvn_emission = MultivariateNormal(Cz_t, self.R)
344 | log_prob_emission = mvn_emission.log_prob(self.a)
345 |
346 | # >>> log p_{\gamma}(a_tilde, z_tilde | u)
347 | log_paz_given_u = torch.cat([log_prob_transition, log_prob_0.unsqueeze(0)], 0) + log_prob_emission
348 |
349 | # log p_{\gamma}(z_tilde | a_tilde, u)
350 | # >>> log p(z_t | a, u)
351 | log_pz_given_au = mvn_smooth.log_prob(z_smooth)
352 |
353 | # Normalization
354 | log_qa_given_x = torch.sum(log_qa_given_x) / (batch_size * seq_len)
355 | log_paz_given_u = torch.sum(log_paz_given_u) / (batch_size * seq_len)
356 | log_pz_given_au = torch.sum(log_pz_given_au) / (batch_size * seq_len)
357 |
358 | # Loss
359 | loss_vae = log_qa_given_x
360 | loss_lgssm = - log_paz_given_u + log_pz_given_au
361 | loss_tot = loss_vae + loss_lgssm
362 | return loss_tot
363 |
--------------------------------------------------------------------------------
/models/group_b2/NeuralODE.py:
--------------------------------------------------------------------------------
1 | """
2 | @file NeuralODE.py
3 |
4 | Holds the model for the Neural ODE latent dynamics function
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | from torchdiffeq import odeint
10 | from utils.utils import get_act
11 | from models.CommonDynamics import LatentDynamicsModel
12 |
13 |
14 | class ODEFunction(nn.Module):
15 | def __init__(self, cfg):
16 | """ Standard Neural ODE dynamics function """
17 | super(ODEFunction, self).__init__()
18 |
19 | # Build the dynamics network
20 | dynamics_network = []
21 | dynamics_network.extend([
22 | nn.Linear(cfg.latent_dim, cfg.num_hidden),
23 | get_act(cfg.latent_act)
24 | ])
25 |
26 | for _ in range(cfg.num_layers - 1):
27 | dynamics_network.extend([
28 | nn.Linear(cfg.num_hidden, cfg.num_hidden),
29 | get_act(cfg.latent_act)
30 | ])
31 |
32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()])
33 | self.dynamics_network = nn.Sequential(*dynamics_network)
34 |
35 | def forward(self, t, z):
36 | """ Wrapper function for the odeint calculation """
37 | return self.dynamics_network(z)
38 |
39 |
40 | class NeuralODE(LatentDynamicsModel):
41 | def __init__(self, cfg):
42 | """ Latent dynamics as parameterized by a global deterministic neural ODE """
43 | super().__init__(cfg)
44 |
45 | # ODE-Net dynamics function
46 | self.dynamics_func = ODEFunction(cfg)
47 |
48 | def forward(self, x, generation_len):
49 | """
50 | Forward function of the ODE network
51 | :param x: data observation, which is a timeseries [BS, Timesteps, N Channels, Dim1, Dim2]
52 | :param generation_len: how many timesteps to generate over
53 | """
54 | # Sample z_init
55 | z_init = self.encoder(x)
56 |
57 | # Evaluate model forward over T to get L latent reconstructions
58 | t = torch.linspace(0, generation_len - 1, generation_len, device=self.device)
59 | zt = odeint(self.dynamics_func, z_init, t, method=self.cfg.integrator, options={'step_size': self.cfg.integrator_step_size})
60 | zt = zt.permute([1, 0, 2])
61 |
62 | # Stack zt and decode zts
63 | x_rec = self.decoder(zt)
64 | return x_rec, zt
65 |
--------------------------------------------------------------------------------
/models/group_b2/RGN.py:
--------------------------------------------------------------------------------
1 | """
2 | @file NeuralODE.py
3 |
4 | Holds the model for the Neural ODE latent dynamics function
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | from torchdiffeq import odeint
10 | from utils.utils import get_act
11 | from models.CommonDynamics import LatentDynamicsModel
12 |
13 |
14 | class RGNFunction(nn.Module):
15 | def __init__(self, cfg):
16 | """ Standard Recurrent Generative Network dynamics function """
17 | super(RGNFunction, self).__init__()
18 |
19 | # Build the dynamics network
20 | dynamics_network = []
21 | dynamics_network.extend([
22 | nn.Linear(cfg.latent_dim, cfg.num_hidden),
23 | get_act(cfg.latent_act)
24 | ])
25 |
26 | for _ in range(cfg.num_layers - 1):
27 | dynamics_network.extend([
28 | nn.Linear(cfg.num_hidden, cfg.num_hidden),
29 | get_act(cfg.latent_act)
30 | ])
31 |
32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()])
33 | self.dynamics_network = nn.Sequential(*dynamics_network)
34 |
35 | def forward(self, t, z):
36 | """ Wrapper function for the odeint calculation """
37 | return self.dynamics_network(z)
38 |
39 |
40 | class RGN(LatentDynamicsModel):
41 | def __init__(self, cfg):
42 | """ Latent dynamics parameterized by an autoregressive recurrent generative network (RGN) """
43 | super().__init__(cfg)
44 |
45 | # RGN dynamics network
46 | self.dynamics_func = RGNFunction(cfg)
47 |
48 | def forward(self, x, generation_len):
49 | # Sample z_init
50 | z_init = self.encoder(x)
51 |
52 | # Evaluate forward over timestep
53 | z_cur = z_init
54 | zts = [z_init]
55 | for _ in range(generation_len - 1):
56 | z_cur = self.dynamics_func(None, z_cur)
57 | zts.append(z_cur)
58 |
59 | zt = torch.stack(zts, dim=1)
60 |
61 | # Stack zt and decode zts
62 | x_rec = self.decoder(zt)
63 | return x_rec, zt
64 |
--------------------------------------------------------------------------------
/models/group_b2/RGNRes.py:
--------------------------------------------------------------------------------
1 | """
2 | @file NeuralODE.py
3 |
4 | Holds the model for the Neural ODE latent dynamics function
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | from torchdiffeq import odeint
10 | from utils.utils import get_act
11 | from models.CommonDynamics import LatentDynamicsModel
12 |
13 |
14 | class RGNResFunction(nn.Module):
15 | def __init__(self, cfg):
16 | """ Standard Residual Recurrent Generative Network dynamics function """
17 | super(RGNResFunction, self).__init__()
18 |
19 | # Build the dynamics network
20 | dynamics_network = []
21 | dynamics_network.extend([
22 | nn.Linear(cfg.latent_dim, cfg.num_hidden),
23 | get_act(cfg.latent_act)
24 | ])
25 |
26 | for _ in range(cfg.num_layers - 1):
27 | dynamics_network.extend([
28 | nn.Linear(cfg.num_hidden, cfg.num_hidden),
29 | get_act(cfg.latent_act)
30 | ])
31 |
32 | dynamics_network.extend([nn.Linear(cfg.num_hidden, cfg.latent_dim), nn.Tanh()])
33 | self.dynamics_network = nn.Sequential(*dynamics_network)
34 |
35 | def forward(self, t, z):
36 | """ Wrapper function for the odeint calculation """
37 | return z + self.dynamics_network(z)
38 |
39 |
40 | class RGNRes(LatentDynamicsModel):
41 | def __init__(self, cfg):
42 | """ Latent dynamics as parameterized by a global deterministic neural ODE """
43 | super().__init__(cfg)
44 |
45 | # ODE-Net which holds mixture logic
46 | self.dynamics_func = RGNResFunction(cfg)
47 |
48 | def forward(self, x, generation_len):
49 | # Sample z_init
50 | z_init = self.encoder(x)
51 |
52 | # Evaluate forward over timestep
53 | z_cur = z_init
54 | zts = [z_init]
55 | for _ in range(generation_len - 1):
56 | z_cur = self.dynamics_func(None, z_cur)
57 | zts.append(z_cur)
58 |
59 | zt = torch.stack(zts, dim=1)
60 |
61 | # Stack zt and decode zts
62 | x_rec = self.decoder(zt)
63 | return x_rec, zt
64 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Automatically generated by https://github.com/damnever/pigar.
2 |
3 | hydra-core==1.3.2
4 | matplotlib==3.7.0
5 | numpy==1.24.2
6 | omegaconf==2.3.0
7 | pytorch-lightning==1.9.0
8 | scikit-image==0.25.0
9 | scikit-learn==1.4.2
10 | torch==1.13.1
11 | torchdiffeq==0.2.3
12 | tqdm==4.64.1
13 |
--------------------------------------------------------------------------------
/scripts/ablation_generation_length.py:
--------------------------------------------------------------------------------
1 | """
2 | @file ablation_generation_length.py
3 |
4 | Holds the cmd calls to train models across different generation lengths used in training
5 | """
6 | import os
7 |
8 | os.system("python3 main.py --generation_len 1 --exptype node_pendulum_1gen --generation_varying False")
9 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_1gen/node/version_1/ --training_len 1")
10 |
11 | os.system("python3 main.py --generation_len 2 --exptype node_pendulum_2gen --generation_varying False")
12 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_2gen/node/version_1/ --training_len 2")
13 |
14 | os.system("python3 main.py --generation_len 3 --exptype node_pendulum_3gen --generation_varying False")
15 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_3gen/node/version_1/ --training_len 3")
16 |
17 | os.system("python3 main.py --generation_len 5 --exptype node_pendulum_5gen --generation_varying False")
18 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_5gen/node/version_1/ --training_len 5")
19 |
20 | os.system("python3 main.py --generation_len 10 --exptype node_pendulum_10gen --generation_varying False")
21 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_10gen/node/version_1/ --training_len 10")
22 |
23 | os.system("python3 main.py --generation_len 20 --exptype node_pendulum_20gen --generation_varying False")
24 | os.system("python3 test.py --ckpt_path experiments/node_pendulum_20gen/node/version_1/ --training_len 20")
25 |
--------------------------------------------------------------------------------
/scripts/ablation_odeintegrator.py:
--------------------------------------------------------------------------------
1 | """
2 | @file ablation_odeintegrator.py
3 |
4 | Holds the cmd calls to train models across different ODE integrators automatically
5 | """
6 | import os
7 |
8 | dev = 0
9 | num_epochs = 100
10 |
11 | os.system(f"python main.py --exptype ablation_odeint_rk4_1 --integrator rk4 --integrator_params step_size=0.5 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
12 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.5 --integrator rk4 --integrator_params step_size=0.5 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
13 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.25 --integrator rk4 --integrator_params step_size=0.25 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
14 | os.system(f"python main.py --exptype ablation_odeint_rk4_0.125 --integrator rk4 --integrator_params step_size=0.125 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
15 |
16 | os.system(f"python main.py --exptype ablation_odeint_dopri5_500 --integrator dopri5 --integrator_params max_num_steps=500 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
17 | os.system(f"python main.py --exptype ablation_odeint_dopri5_1000 --integrator dopri5 --integrator_params max_num_steps=1000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
18 | os.system(f"python main.py --exptype ablation_odeint_dopri5_2000 --integrator dopri5 --integrator_params max_num_steps=2000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
19 | os.system(f"python main.py --exptype ablation_odeint_dopri5_5000 --integrator dopri5 --integrator_params max_num_steps=5000 --num_epochs {num_epochs} --latent_dim 8 --num_hidden 128 --num_layers 3 --num_filt 8 --dev {dev}")
20 |
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | @file dataloader.py
3 | @author Ryan Missel
4 |
5 | Holds the LightningDataModule for the available datasets
6 | """
7 | import torch
8 | import numpy as np
9 | import pytorch_lightning
10 | from torch.utils.data import Dataset, DataLoader
11 |
12 |
13 | class SSMDataset(Dataset):
14 | """ Basic Dataset object for the SSM """
15 | def __init__(self, images, labels, states, controls):
16 | self.images = images
17 | self.labels = labels
18 | self.states = states
19 | self.controls = controls
20 |
21 | def __len__(self):
22 | return self.images.shape[0]
23 |
24 | def __getitem__(self, idx):
25 | return torch.Tensor([idx]), self.images[idx], self.states[idx], self.controls[idx], self.labels[idx]
26 |
27 |
28 | class SSMDataModule(pytorch_lightning.LightningDataModule):
29 | """ Custom DataModule object that handles preprocessing all sets of data for a given run """
30 | def __init__(self, cfg):
31 | super(SSMDataModule, self).__init__()
32 | self.cfg = cfg
33 |
34 | def make_loader(self, mode="train", evaluation=False, shuffle=True):
35 | # Load in NPZ
36 | npzfile = np.load(f"data/{self.cfg.dataset}/{mode}.npz")
37 |
38 | # Load in data sources
39 | images = npzfile['image'].astype(np.float32)
40 | labels = npzfile['label'].astype(np.int16)
41 | states = npzfile['state'].astype(np.float32)[:, :, :2]
42 |
43 | # Load control, if it exists, else make a dummy one
44 | controls = npzfile['control'] if 'control' in npzfile else np.zeros((images.shape[0], images.shape[1], 1), dtype=np.float32)
45 |
46 | # Modify based on dataset percent
47 | rand_idx = np.random.choice(range(images.shape[0]), size=int(images.shape[0] * self.cfg.dataset_percent), replace=False)
48 | images = images[rand_idx]
49 | labels = labels[rand_idx]
50 | states = states[rand_idx]
51 | controls = controls[rand_idx]
52 |
53 | # Convert to Tensors
54 | images = torch.from_numpy(images)
55 | labels = torch.from_numpy(labels)
56 | states = torch.from_numpy(states)
57 | controls = torch.from_numpy(controls)
58 |
59 | # Build dataset and corresponding Dataloader
60 | dataset = SSMDataset(images, labels, states, controls)
61 |
62 | # If it is the training setting, set up the iterative dataloader
63 | if mode == "train" and evaluation is False:
64 | sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=self.cfg.num_steps * self.cfg.batch_size)
65 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=self.cfg.batch_size, drop_last=True)
66 |
67 | # Otherwise, setup a normal dataloader
68 | else:
69 | dataloader = DataLoader(dataset, batch_size=self.cfg.batch_size, shuffle=shuffle)
70 | return dataloader
71 |
72 | def train_dataloader(self):
73 | """ Getter function that builds and returns the training dataloader """
74 | return self.make_loader("train")
75 |
76 | def evaluate_train_dataloader(self):
77 | return self.make_loader("train", evaluation=True, shuffle=False)
78 |
79 | def val_dataloader(self):
80 | """ Getter function that builds and returns the validation dataloader """
81 | return self.make_loader("val", shuffle=False)
82 |
83 | def test_dataloader(self):
84 | """ Getter function that builds and returns the testing dataloader """
85 | return self.make_loader("test", shuffle=False)
86 |
--------------------------------------------------------------------------------
/utils/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | @file layers.py
3 |
4 | Miscellaneous helper Torch layers
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class Gaussian(nn.Module):
11 | def __init__(self, in_dim, out_dim, fix_variance=False):
12 | """
13 | Gaussian sample layer consisting of 2 simple linear layers.
14 | Can choose whether to fix the variance or let it be learned (training instability has been shown when learning).
15 |
16 | :param in_dim: input dimension (often a flattened latent embedding from a CNN)
17 | :param out_dim: output dimension
18 | :param fix_variance: whether to set the log-variance as a constant 0.1
19 | """
20 | super(Gaussian, self).__init__()
21 | self.fix_variance = fix_variance
22 |
23 | # Mean layer
24 | self.mu = nn.Sequential(
25 | nn.Linear(in_dim, in_dim // 2),
26 | nn.LeakyReLU(0.1),
27 | nn.Linear(in_dim // 2, out_dim)
28 | )
29 |
30 | # Log-var layer
31 | self.logvar = nn.Sequential(
32 | nn.Linear(in_dim, in_dim // 2),
33 | nn.LeakyReLU(0.1),
34 | nn.Linear(in_dim // 2, out_dim)
35 | )
36 |
37 | def reparameterize(self, mu, logvar):
38 | """ Reparameterization trick to get a sample from the output distribution """
39 | std = torch.exp(0.5 * logvar)
40 | noise = torch.randn_like(std)
41 | z = mu + (noise * std)
42 | return z
43 |
44 | def forward(self, x):
45 | """
46 | Forward function of the Gaussian layer. Handles getting the distributional parameters and sampling a vector
47 | :param x: input vector [BatchSize, InputDim]
48 | """
49 | # Get mu and logvar
50 | mu = self.mu(x)
51 |
52 | if self.fix_variance:
53 | logvar = torch.full_like(mu, fill_value=0.1)
54 | else:
55 | logvar = self.logvar(x)
56 |
57 | # Check on whether mu/logvar are getting out of normal ranges
58 | if (mu < -100).any() or (mu > 85).any() or (logvar < -100).any() or (logvar > 85).any():
59 | print("Explosion in mu/logvar. Mu {} Logvar {}".format(torch.mean(mu), torch.mean(logvar)))
60 |
61 | # Reparameterize and sample
62 | z = self.reparameterize(mu, logvar)
63 | return mu, logvar, z
64 |
65 |
66 | class GroupSwish(nn.Module):
67 | def __init__(self, groups):
68 | """
69 | Swish activation function that works on GroupConvolution by reshaping all input groups back into their
70 | batch shapes, activating them, then reshaping them to 1D GroupConv filters
71 | """
72 | super().__init__()
73 | self.silu = nn.SiLU()
74 | self.groups = groups
75 |
76 | def forward(self, x):
77 | n_ch_group = x.size(1) // self.groups
78 | t = x.shape[2:]
79 | x = x.reshape(-1, self.groups, n_ch_group, *t)
80 | return self.silu(x).reshape(1, self.groups * n_ch_group, *t)
81 |
82 |
83 | class GroupTanh(nn.Module):
84 | def __init__(self, groups):
85 | """
86 | Tanh activation function that works on GroupConvolution by reshaping all input groups back into their
87 | batch shapes, activating them, then reshaping them to 1D GroupConv filters
88 | """
89 | super().__init__()
90 | self.tanh = nn.Tanh()
91 | self.groups = groups
92 |
93 | def forward(self, x):
94 | n_ch_group = x.size(1) // self.groups
95 | t = x.shape[2:]
96 | x = x.reshape(self.groups, n_ch_group, *t)
97 | return self.tanh(x).reshape(1, self.groups * n_ch_group, *t)
98 |
99 |
100 | class Flatten(nn.Module):
101 | def forward(self, input):
102 | """
103 | Handles flattening a Tensor within a nn.Sequential Block
104 |
105 | :param input: Torch object to flatten
106 | """
107 | return input.view(input.size(0), -1)
108 |
109 |
110 | class UnFlatten(nn.Module):
111 | def __init__(self, w):
112 | """
113 | Handles unflattening a flat vector into a 4D image tensor within a nn.Sequential block
114 |
115 | :param w: width of the unflattened image vector
116 | """
117 | super().__init__()
118 | self.w = w
119 |
120 | def forward(self, input):
121 | nc = input[0].numel() // (self.w ** 2)
122 | return input.view(input.size(0), nc, self.w, self.w)
123 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | @file metrics.py
3 |
4 | Holds a variety of metric computing functions for time-series forecasting models, including
5 | Valid Prediction Time (VPT), Valid Prediction Distance (VPD), etc.
6 | """
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | from skimage.filters import threshold_otsu
12 | from sklearn.linear_model import LinearRegression
13 | from sklearn.neural_network import MLPRegressor
14 |
15 |
16 | def vpt(gt, preds, epsilon=0.010, **kwargs):
17 | """
18 | Computes the Valid Prediction Time metric, as proposed in https://openreview.net/pdf?id=qBl8hnwR0px
19 | VPT = argmin_t [MSE(gt, pred) > epsilon]
20 | :param gt: ground truth sequences
21 | :param preds: model predicted sequences
22 | :param epsilon: threshold for valid prediction
23 | """
24 | # Ensure on CPU and numpy
25 | if not isinstance(gt, np.ndarray):
26 | gt = gt.cpu().numpy()
27 | preds = preds.cpu().numpy()
28 |
29 | # Get dimensions
30 | _, timesteps, height, width = gt.shape
31 |
32 | # Get pixel_level MSE at each timestep
33 | mse = (gt - preds) ** 2
34 | mse = np.sum(mse, axis=(2, 3)) / (height * width)
35 |
36 | # Get VPT
37 | vpts = []
38 | for m in mse:
39 | # Get all indices below the given epsilon
40 | indices = np.where(m < epsilon)[0] + 1
41 |
42 | # If there are none below, then add 0
43 | if len(indices) == 0:
44 | vpts.append(0)
45 | continue
46 |
47 | # Append last in list
48 | vpts.append(indices[-1])
49 |
50 | # Return VPT mean over the total timesteps
51 | return np.mean(vpts) / timesteps, np.std(vpts) / timesteps
52 |
53 |
54 | def thresholding(preds, gt):
55 | """
56 | Thresholding function that converts gt and preds into binary images
57 | Activated prediction pixels are found via Otsu's thresholding function
58 | """
59 | N, T = gt.shape[0], gt.shape[1]
60 | res = np.zeros_like(preds)
61 |
62 | # For each sample and timestep, get Otsu's threshold and binarize gt and pred
63 | for n in range(N):
64 | for t in range(T):
65 | img = preds[n, t]
66 | otsu_th = np.max([0.32, threshold_otsu(img)])
67 | res[n, t] = (img > otsu_th).astype(np.float32)
68 | gt[n, t] = (gt[n, t] > 0.55).astype(np.float32)
69 | return res, gt
70 |
71 |
72 | def dst(gt, preds, **kwargs):
73 | """
74 | Computes a Euclidean distance metric between the center of the ball in ground truth and prediction
75 | Activated pixels in the predicted are computed via Otsu's thresholding function
76 | :param gt: ground truth sequences
77 | :param preds: model predicted sequences
78 | """
79 | # Ensure on CPU and numpy
80 | if not isinstance(gt, np.ndarray):
81 | gt = gt.cpu().numpy()
82 | preds = preds.cpu().numpy()
83 |
84 | # Get shapes
85 | num_samples, timesteps, height, width = gt.shape
86 |
87 | # Apply Otsu thresholding function on output
88 | preds, gt = thresholding(preds, gt)
89 |
90 | # Loop over each sample and timestep to get the distance metric
91 | results = np.zeros([num_samples, timesteps])
92 | for n in range(num_samples):
93 | for t in range(timesteps):
94 | # Get all active predicted pixels
95 | a = preds[n, t]
96 | b = gt[n, t]
97 | pos_a = np.where(a == 1)
98 | pos_b = np.where(b == 1)
99 |
100 | # If there are no active pixels in gt, add 0
101 | if pos_b[0].shape[0] == 0:
102 | results[n, t] = 0
103 | continue
104 |
105 | # Get gt center
106 | center_b = np.array([pos_b[0].mean(), pos_b[1].mean()])
107 |
108 | # Get center of predictions
109 | if pos_a[0].shape[0] != 0:
110 | center_a = np.array([pos_a[0].mean(), pos_a[1].mean()])
111 | # If no pixels are above the threshold, fall back to the image origin as the predicted center
112 | else:
113 | center_a = np.array([0, 0])
116 |
117 | # Get distance metric
118 | dist = np.sum((center_a - center_b) ** 2)
119 | dist = np.sqrt(dist)
120 |
121 | # Add to result
122 | results[n, t] = dist
123 |
124 | return np.mean(results), np.std(results)
125 |
126 |
127 | def vpd(target, output, epsilon=10, **kwargs):
128 | """
129 | Computes the Valid Prediction Distance metric, as proposed in https://openreview.net/forum?id=7C9aRX2nBf2
130 | VPD = argmin_t [DST(gt, pred) > epsilon]
131 | :param target: ground truth sequences
132 | :param output: model predicted sequences
133 | :param epsilon: threshold for valid prediction
134 | """
135 | # Ensure on CPU and numpy
136 | if not isinstance(output, np.ndarray):
137 | output = output.cpu().numpy()
138 | target = target.cpu().numpy()
139 |
140 | # Get shapes
141 | num_samples, timesteps, height, width = target.shape
142 |
143 | # Apply Otsu thresholding function on output
144 | preds, gt = thresholding(output, target)
145 |
146 | # Loop over each sample and timestep to get the distance metric
147 | dsts = np.zeros([num_samples, timesteps])
148 | for n in range(num_samples):
149 | for t in range(timesteps):
150 | # Get all active predicted pixels
151 | a = preds[n, t]
152 | b = gt[n, t]
153 | pos_a = np.where(a == 1)
154 | pos_b = np.where(b == 1)
155 |
156 | # If there are no active pixels in gt, add 0
157 | if pos_b[0].shape[0] == 0:
158 | dsts[n, t] = 0
159 | continue
160 |
161 | # Get gt center
162 | center_b = np.array([pos_b[0].mean(), pos_b[1].mean()])
163 |
164 | # Get center of predictions
165 | if pos_a[0].shape[0] != 0:
166 | center_a = np.array([pos_a[0].mean(), pos_a[1].mean()])
167 | # If no pixels above threshold, add the highest possible error in image space
168 | else:
169 | dsts[n, t] = np.sqrt(np.sum(np.array([height, width]) ** 2))
170 | continue
171 |
172 | # Get distance metric
173 | dist = np.sum((center_a - center_b) ** 2)
174 | dist = np.sqrt(dist)
175 |
176 | # Add to result
177 | dsts[n, t] = dist
178 |
179 | # Get the VPD calculation
180 | B, T = dsts.shape
181 | vpdist = np.zeros(B)
182 | for i in range(B):
183 | idx = np.where(dsts[i, :] >= epsilon)[0]
184 | if idx.shape[0] > 0:
185 | vpdist[i] = np.min(idx)
186 | else:
187 | vpdist[i] = T
188 |
189 | # Return VPD mean over the total timesteps
190 | return np.mean(vpdist) / T, np.std(vpdist) / T
191 |
192 |
193 | def reconstruction_mse(output, target, **kwargs):
194 | """ Gets the mean of the per-pixel MSE for the given length of timesteps used for training """
195 | full_pixel_mses = (output[:, :kwargs['length']] - target[:, :kwargs['length']]) ** 2
196 | sequence_pixel_mse = np.mean(full_pixel_mses, axis=(1, 2, 3))
197 | return np.mean(sequence_pixel_mse), np.std(sequence_pixel_mse)
198 |
199 |
200 | def extrapolation_mse(output, target, **kwargs):
201 | """ Gets the mean of the per-pixel MSE for a number of steps past the length used in training """
202 | full_pixel_mses = (output[:, kwargs['cfg'].train_length:] - target[:, kwargs['cfg'].train_length:]) ** 2
203 | if full_pixel_mses.shape[1] == 0:
204 | return 0.0, 0.0
205 |
206 | sequence_pixel_mse = np.mean(full_pixel_mses, axis=(1, 2, 3))
207 | return np.mean(sequence_pixel_mse), np.std(sequence_pixel_mse)
208 |
209 |
210 | def r2fit(latents, gt_state, mlp=False):
211 | """
212 | Computes an R^2 fit value for each ground truth physical state dimension given the latent states at each timestep.
213 | Gets an average per timestep.
214 | :param latents: latent states at each timestep [BatchSize, TimeSteps, LatentSize]
215 | :param gt_state: ground truth physical parameters [BatchSize, TimeSteps, StateSize]
216 | :param mlp: whether to use a non-linear MLP regressor instead of linear regression
217 | """
218 | # Get first dimension of states for evaluation
219 | sins = np.sin(gt_state[:, :, 0])
220 | coss = np.cos(gt_state[:, :, 0])
221 | gt_states = np.stack((sins, coss, gt_state[:, :, 1]), axis=2)
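    | # NOTE: gt_states (sin/cos encoding of the angle) is constructed here but the regressions below fit against the raw gt_state dimensions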
222 |
223 | # Ensure on CPU and numpy
224 | if not isinstance(latents, np.ndarray):
225 | latents = latents.cpu().numpy()
226 | gt_state = gt_state.cpu().numpy()
227 |
228 | # Convert to one large set of latent states
229 | latents = latents.reshape([latents.shape[0] * latents.shape[1], -1])
230 | gt_state = gt_state.reshape([gt_state.shape[0] * gt_state.shape[1], -1])
231 |
232 | # For each dimension of gt_state, get the R^2 value
233 | r2s = []
234 | for sidx in range(gt_state.shape[-1]):
235 | gts = gt_state[:, sidx]
236 |
237 | # Whether to use LinearRegression or an MLP
238 | if mlp:
239 | reg = MLPRegressor().fit(latents, gts)
240 | else:
241 | reg = LinearRegression().fit(latents, gts)
242 |
243 | r2s.append(reg.score(latents, gts))
244 |
245 | # Return r2s for logging
246 | return r2s
247 |
248 |
249 | def normalized_pixel_mse(gt, preds):
250 | """
251 | Handles getting the pixel MSE of a trajectory, but normalizes over the average intensity of the ground truth.
252 | This helps to be able to compare pixel MSE effectively over different domains rather than looking at pure intensity.
253 | :param gt: ground truth sequence [BS, TS, Dim1, Dim2]
254 | :param preds: predictions of the model [BS, TS, Dim1, Dim2]
255 |
256 | TODO: Make sure this metric calculation matches WhichPriorsMatter?
257 | """
258 | mse = nn.MSELoss(reduction='none')(gt, preds)
259 | mse = mse / torch.mean(gt ** 2)
260 | return mse.detach().cpu().numpy(), mse.mean([1, 2, 3]).mean().detach().cpu().numpy()
261 |
--------------------------------------------------------------------------------
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | """
2 | @file plotting.py
3 |
4 | Holds general plotting functions for reconstructions of the bouncing ball dataset
5 | """
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def show_images(images, preds, out_loc, num_out=None):
11 | """
12 | Constructs an image of multiple time-series reconstruction samples compared against its relevant ground truth
13 | Saves locally in the given out location
14 | :param images: ground truth images
15 | :param preds: predictions from a given model
16 | :param out_loc: where to save the generated image
17 | :param num_out: how many images to stack. If None, stack all
18 | """
19 | assert len(images.shape) == 4 # Assert both matrices are [Batch, Timesteps, H, W]
20 | assert len(preds.shape) == 4
21 | assert num_out is None or isinstance(num_out, int)
22 |
23 | # Make sure objects are in numpy format
24 | if not isinstance(images, np.ndarray):
25 | images = images.cpu().numpy()
26 | preds = preds.cpu().numpy()
27 |
28 | # Splice to the given num_out
29 | if num_out is not None:
30 | images = images[:num_out]
31 | preds = preds[:num_out]
32 |
33 | # Iterate through each sample, stacking into one image
34 | out_image = None
35 | for idx, (gt, pred) in enumerate(zip(images, preds)):
36 | # Pad between individual timesteps
37 | gt = np.pad(gt, pad_width=(
38 | (0, 0), (5, 5), (0, 1)
39 | ), constant_values=1)
40 |
41 | gt = np.hstack([i for i in gt])
42 |
43 | # Pad between individual timesteps
44 | pred = np.pad(pred, pad_width=(
45 | (0, 0), (0, 10), (0, 1)
46 | ), constant_values=1)
47 |
48 | # Stack timesteps into one image
49 | pred = np.hstack([i for i in pred])
50 |
51 | # Stack gt/pred into one image
52 | final = np.vstack((gt, pred))
53 |
54 | # Stack into out_image
55 | if out_image is None:
56 | out_image = final
57 | else:
58 | out_image = np.vstack((out_image, final))
59 |
60 | # Save to out location
61 | plt.imsave(out_loc, out_image, cmap='gray')
62 |
63 |
64 | def get_embedding_trajectories(embeddings, states, out_loc):
65 | """
66 | Handles getting trajectory plots of the embedded states against the true physical states
67 | :param embeddings: vector states over time
68 | :param states: ground truth physical parameter states
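    | :param out_loc: directory in which to save the trajectory plots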
69 | """
70 | # Make sure objects are in numpy format
71 | if not isinstance(embeddings, np.ndarray):
72 | embeddings = embeddings.cpu().numpy()
73 | states = states.cpu().numpy()
74 |
75 | # Get embedding trajectory plots
76 | for idx, embedding in enumerate(np.swapaxes(embeddings, 1, 0)):
77 | plt.plot(embedding, label=f"Dim {idx}")
78 |
79 | plt.title("Vector State Trajectories")
80 | plt.savefig(f"{out_loc}/trajectories_embeddings.png")
81 | plt.close()
82 |
83 | # Get physical state trajectories
84 | for idx, embedding in enumerate(np.swapaxes(states, 1, 0)):
85 | plt.plot(embedding, label=f"Dim {idx}")
86 |
87 | plt.title("GT State Trajectories")
88 | plt.savefig(f"{out_loc}/trajectories_gt.png")
89 | plt.close()
90 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | @file utils.py
3 |
4 | Utility functions across files
5 | """
6 | import math
7 | import torch.nn as nn
8 |
9 | from omegaconf import DictConfig, OmegaConf
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | def flatten_cfg(cfg: DictConfig):
14 | """ Utility function to flatten the primary submodules of a Hydra config """
15 | # Disable struct flag on the config
16 | OmegaConf.set_struct(cfg, False)
17 |
18 | # Loop through each item, merging with the main cfg if its another DictConfig
19 | for key, value in cfg.copy().items():
20 | if isinstance(value, DictConfig):
21 | cfg.merge_with(cfg.pop(key))
22 |
23 | # Do it a second time for nested cfgs
24 | for key, value in cfg.copy().items():
25 | if isinstance(value, DictConfig):
26 | cfg.merge_with(cfg.pop(key))
27 |
28 | print(cfg)
29 | return cfg
30 |
31 |
32 | def get_model(name):
33 | """ Import and return the specific latent dynamics function by the given name"""
34 | # Lowercase name in case of misspellings
35 | name = name.lower()
36 |
37 | ## Group A Models
38 | if name == "vrnn":
39 | from models.group_a.VRNN import VRNN
40 | return VRNN
41 |
42 | if name == "dkf":
43 | from models.group_a.DKF import DKF
44 | return DKF
45 |
46 | ## Group B1 Models
47 | if name == "kvae":
48 | from models.group_b1.KVAE import KVAE
49 | return KVAE
50 |
51 | ## Group B2 Models
52 | if name == "node":
53 | from models.group_b2.NeuralODE import NeuralODE
54 | return NeuralODE
55 |
56 | if name == "rgnres":
57 | from models.group_b2.RGNRes import RGNRes
58 | return RGNRes
59 |
60 | # Given no correct model type, raise error
61 | raise NotImplementedError("Model type {} not implemented.".format(name))
62 |
63 |
64 | def get_act(act="relu"):
65 | """
66 | Return torch function of a given activation function
67 | :param act: activation function
68 | :return: torch object
69 | """
70 | if act == "relu":
71 | return nn.ReLU()
72 | elif act == "leaky_relu":
73 | return nn.LeakyReLU(0.1)
74 | elif act == "sigmoid":
75 | return nn.Sigmoid()
76 | elif act == "tanh":
77 | return nn.Tanh()
78 | elif act == "linear":
79 | return nn.Identity()
80 | elif act == 'softplus':
81 | return nn.Softplus()
82 | elif act == 'softmax':
83 | return nn.Softmax()
84 | elif act == "swish":
85 | return nn.SiLU()
86 | else:
87 | return None
88 |
89 |
90 | def determine_annealing_factor(n_updates, min_anneal_factor=0.0, anneal_update=10000):
91 | """
92 | Handles annealing the KL regularization weight over a number of update steps,
93 | slowly introducing the restriction so that a strong reconstruction fit is established first
94 | :param n_updates: current number of optimizer update steps taken so far
95 | :param min_anneal_factor: minimum annealing factor that the schedule starts from
96 | :param anneal_update: number of update steps over which to linearly increase
97 | the annealing factor from min_anneal_factor up to 1.0
98 | :return: weight of the KL annealing factor to apply to the KL loss term,
99 | reaching 1.0 once n_updates >= anneal_update
100 | """
101 | if anneal_update > 0 and n_updates < anneal_update:
102 | anneal_factor = min_anneal_factor + \
103 | (1.0 - min_anneal_factor) * (
104 | (n_updates / anneal_update)
105 | )
106 | else:
107 | anneal_factor = 1.0
108 | return anneal_factor
109 |
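# Hedged usage sketch (editorial addition, not part of the original file): the annealing factor is
# usually multiplied into the KL terms of the ELBO at every optimizer update; n_updates, recon_loss
# and kl_loss below are hypothetical placeholders for the model's actual counters and loss terms.
#
#   kl_weight = determine_annealing_factor(n_updates, min_anneal_factor=0.0, anneal_update=10000)
#   loss = recon_loss + kl_weight * kl_loss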
110 |
111 | class CosineAnnealingWarmRestartsWithDecayAndLinearWarmup(_LRScheduler):
112 | r"""Set the learning rate of each parameter group using a cosine annealing
113 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
114 | is the number of epochs since the last restart and :math:`T_{i}` is the number
115 | of epochs between two warm restarts in SGDR:
116 | """
117 |
118 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False, warmup_steps=350, decay=1):
119 | if T_0 <= 0 or not isinstance(T_0, int):
120 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
121 | if T_mult < 1 or not isinstance(T_mult, int):
122 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
123 | self.T_0 = T_0
124 | self.T_i = T_0
125 | self.T_mult = T_mult
126 | self.eta_min = eta_min
127 | self.T_cur = last_epoch
128 |
129 | # Decay and warmup attributes; set before the base constructor, since
130 | # _LRScheduler.__init__ triggers an initial step()/get_lr() call that uses them
131 | self.decay = decay
132 | self.warmup_steps = warmup_steps
133 | self.current_steps = 0
134 |
135 | super(CosineAnnealingWarmRestartsWithDecayAndLinearWarmup, self).__init__(optimizer, last_epoch, verbose)
136 | self.initial_lrs = self.base_lrs
137 |
138 | def get_lr(self):
139 | return [
140 | (self.current_steps / self.warmup_steps) *
141 | (self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2)
142 | for base_lr in self.base_lrs
143 | ]
144 |
145 | def step(self, epoch=None):
146 | """Step could be called after every batch update"""
147 | if epoch is None and self.last_epoch < 0:
148 | epoch = 0
149 |
150 | if self.T_cur + 1 == self.T_i:
151 | if self.verbose:
152 | print("multiplying base_lrs by {:.4f}".format(self.decay))
153 | self.base_lrs = [base_lr * self.decay for base_lr in self.base_lrs]
154 |
155 | if epoch is None:
156 | epoch = self.last_epoch + 1
157 | self.T_cur = self.T_cur + 1
158 |
159 | if self.current_steps < self.warmup_steps:
160 | self.current_steps += 1
161 |
162 | if self.T_cur >= self.T_i:
163 | self.T_cur = self.T_cur - self.T_i
164 | self.T_i = self.T_i * self.T_mult
165 |
166 | self.last_epoch = math.floor(epoch)
167 |
168 | class _enable_get_lr_call:
169 |
170 | def __init__(self, o):
171 | self.o = o
172 |
173 | def __enter__(self):
174 | self.o._get_lr_called_within_step = True
175 | return self
176 |
177 | def __exit__(self, type, value, traceback):
178 | self.o._get_lr_called_within_step = False
179 | return self
180 |
181 | with _enable_get_lr_call(self):
182 | for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
183 | param_group, lr = data
184 | param_group['lr'] = lr
185 | self.print_lr(self.verbose, i, lr, epoch)
186 |
187 | self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
188 |
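# Hedged usage sketch (editorial addition, not part of the original file): this scheduler is intended
# to be stepped once per batch rather than once per epoch, so the linear warmup and cosine restarts
# are driven by update steps; `model`, `train_loader`, and `compute_loss` are hypothetical placeholders.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = CosineAnnealingWarmRestartsWithDecayAndLinearWarmup(
#       optimizer, T_0=1000, T_mult=1, eta_min=1e-5, warmup_steps=350, decay=0.9)
#   for batch in train_loader:
#       optimizer.zero_grad()
#       loss = compute_loss(model, batch)
#       loss.backward()
#       optimizer.step()
#       scheduler.step()   # per-batch stepping drives both warmup and restarts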
--------------------------------------------------------------------------------