├── requirements_pip.txt ├── __init__.py ├── code ├── __init__.py ├── data │ ├── __init__.py │ ├── RLC │ │ ├── __init__.py │ │ └── GenerateDataset.py │ ├── DampedPendulum │ │ ├── __init__.py │ │ └── GenerateDataset.py │ ├── DoublePendulum │ │ ├── __init__.py │ │ └── GenerateDataset.py │ └── ReactionDiffusion │ │ ├── __init__.py │ │ └── GenerateDataset.py ├── scripts │ ├── __init__.py │ ├── configs │ │ ├── ReactionDiffusion │ │ │ ├── APHYNITY.yaml │ │ │ ├── APHYNITY_plus.yaml │ │ │ ├── HVAE.yaml │ │ │ └── HVAE_plus.yaml │ │ ├── Pendulum │ │ │ ├── HVAE.yaml │ │ │ ├── APHYNITY.yaml │ │ │ ├── HVAE_plus.yaml │ │ │ └── APHYNITY_plus.yaml │ │ ├── RLC │ │ │ ├── APHYNITY.yaml │ │ │ ├── HVAE.yaml │ │ │ ├── APHYNITY_plus.yaml │ │ │ └── HVAE_plus.yaml │ │ └── DoublePendulum │ │ │ ├── Fp_only.yaml │ │ │ ├── APHYNITY.yaml │ │ │ ├── Fa_only.yaml │ │ │ ├── APHYNITY_plus.yaml │ │ │ ├── APHYNITY_robustness.yaml │ │ │ ├── APHYNITY_robustness_10epochs.yaml │ │ │ ├── HVAE.yaml │ │ │ ├── HVAE_robustness.yaml │ │ │ ├── HVAE_robustness_10epochs.yaml │ │ │ └── HVAE_plus.yaml │ ├── run_experiments_robustness.py │ └── run_experiments.py ├── nn │ ├── __pycache__ │ │ ├── mlp.cpython-310.pyc │ │ ├── unet.cpython-310.pyc │ │ ├── utils.cpython-310.pyc │ │ └── __init__.cpython-310.pyc │ ├── __init__.py │ ├── utils.py │ ├── mlp.py │ └── unet.py ├── utils │ ├── __pycache__ │ │ ├── utils.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── loaders.cpython-310.pyc │ │ └── plotter.cpython-310.pyc │ ├── __init__.py │ ├── loaders.py │ ├── double_pendulum.py │ ├── utils.py │ └── plotter.py ├── hybrid_models │ ├── __pycache__ │ │ ├── HVAE.cpython-310.pyc │ │ ├── APHYNITY.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── HybridAutoencoder.cpython-310.pyc │ ├── __init__.py │ ├── HybridAutoencoder.py │ ├── APHYNITY.py │ └── HVAE.py └── simulators │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── NoSimulator.cpython-310.pyc │ ├── RLCCircuit.cpython-310.pyc │ ├── DampedPendulum.cpython-310.pyc │ ├── DoublePendulum.cpython-310.pyc │ ├── GenericSimulator.cpython-310.pyc │ └── ReactionDiffusion.cpython-310.pyc │ ├── __init__.py │ ├── NoSimulator.py │ ├── GenericSimulator.py │ ├── DoublePendulum.py │ ├── RLCCircuit.py │ ├── ReactionDiffusion.py │ └── DampedPendulum.py ├── figures └── improved_diffusion.png ├── requirements.txt ├── CONTRIBUTING.md ├── README.md ├── License.txt └── CODE_OF_CONDUCT.md /requirements_pip.txt: -------------------------------------------------------------------------------- 1 | torchdiffeq==0.2.3 2 | torchdyn==1.0.3 3 | torchvision==0.14.0 -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
3 | 4 | -------------------------------------------------------------------------------- /figures/improved_diffusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/figures/improved_diffusion.png -------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/data/RLC/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/data/DampedPendulum/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/data/DoublePendulum/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | -------------------------------------------------------------------------------- /code/nn/__pycache__/mlp.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/nn/__pycache__/mlp.cpython-310.pyc -------------------------------------------------------------------------------- /code/nn/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/nn/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /code/data/ReactionDiffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
3 | 4 | -------------------------------------------------------------------------------- /code/nn/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/nn/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /code/nn/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/nn/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/loaders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/utils/__pycache__/loaders.cpython-310.pyc -------------------------------------------------------------------------------- /code/utils/__pycache__/plotter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/utils/__pycache__/plotter.cpython-310.pyc -------------------------------------------------------------------------------- /code/hybrid_models/__pycache__/HVAE.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/hybrid_models/__pycache__/HVAE.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /code/hybrid_models/__pycache__/APHYNITY.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/hybrid_models/__pycache__/APHYNITY.cpython-310.pyc -------------------------------------------------------------------------------- /code/hybrid_models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/hybrid_models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/NoSimulator.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/NoSimulator.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/RLCCircuit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/RLCCircuit.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/DampedPendulum.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/DampedPendulum.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/DoublePendulum.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/DoublePendulum.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/GenericSimulator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/GenericSimulator.cpython-310.pyc -------------------------------------------------------------------------------- /code/simulators/__pycache__/ReactionDiffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/simulators/__pycache__/ReactionDiffusion.cpython-310.pyc -------------------------------------------------------------------------------- /code/hybrid_models/__pycache__/HybridAutoencoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-robust-expert-augmentations/HEAD/code/hybrid_models/__pycache__/HybridAutoencoder.cpython-310.pyc -------------------------------------------------------------------------------- /code/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | from code.nn.mlp import MLP, act_dict 5 | from code.nn.unet import UNet, ConditionalUNet, ConditionalUNetReactionDiffusion 6 | from code.nn.utils import Permute, kl_gaussians -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
3 | 4 | from code.utils.plotter import plot_curves, plot_curves_partial, plot_diffusion, plot_curves_complete, plot_curves_double_pendulum 5 | from code.utils.loaders import load_data 6 | from code.utils.utils import get_models -------------------------------------------------------------------------------- /code/hybrid_models/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | from code.hybrid_models.APHYNITY import APHYNITYAutoencoder, APHYNITYAutoencoderReactionDiffusion 5 | from code.hybrid_models.HVAE import HybridVAE, HybridVAEReactionDiffusion, HybridVAEDoublePendulum 6 | from code.hybrid_models.HybridAutoencoder import HybridAutoencoder 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: osx-arm64 4 | jupyter=1.0.0 5 | matplotlib=3.5.3 6 | #networkx=2.8.8 7 | #normalizingflows=0.1 8 | numpy=1.23.3 9 | pandas=1.4.4 10 | pickleshare=0.7.5 11 | pip=22.2.2 12 | python=3.10.6 13 | pyyaml=6.0 14 | scikit-learn=1.1.3 15 | scipy=1.9.3 16 | seaborn=0.12.1 17 | pytorch -------------------------------------------------------------------------------- /code/simulators/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | from code.simulators.GenericSimulator import GenericSimulator, PhysicalModel 5 | from code.simulators.DampedPendulum import DampedPendulum 6 | from code.simulators.RLCCircuit import RLCCircuit 7 | from code.simulators.ReactionDiffusion import ReactionDiffusion 8 | from code.simulators.DoublePendulum import DoublePendulum 9 | from code.simulators.NoSimulator import NoSimulator 10 | -------------------------------------------------------------------------------- /code/scripts/configs/ReactionDiffusion/APHYNITY.yaml: -------------------------------------------------------------------------------- 1 | name: 'Reaction Diffusion - APHYNITY' 2 | parameters: 3 | data_path: "code/data/ReactionDiffusion" 4 | optimization: 5 | model: "APHYNITYReactionDiffusion" 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0 8 | n_epochs: 300 9 | n_iter: 1 10 | lambda_0: 10. 11 | tau_2: 5. 12 | fa: 13 | regularization_scheme: 14 | augmented: False 15 | zp_priors: 16 | a: 17 | b: 18 | simulator: 19 | name: 'ReactionDiffusion' 20 | init_param: 21 | a: -10000000. 22 | b: -10000000. 23 | partial_model_param: 24 | - a 25 | - b 26 | solver: 'APHYNITYReactionDiffusion' -------------------------------------------------------------------------------- /code/scripts/configs/Pendulum/HVAE.yaml: -------------------------------------------------------------------------------- 1 | name: 'Pendulum - HVAE' 2 | parameters: 3 | optimization: 4 | normalize_loss: False 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .000001 7 | n_epochs: 1000 8 | b_size: 200 9 | model: HybridVAE 10 | act_mu_p: Softplus 11 | gamma: 1. 12 | alpha: 0.01 13 | beta: 0.01 14 | omicron: 0. 15 | zp_priors: 16 | omega_0: 17 | mu: 2.
18 | sigma: .7 19 | min: .392 20 | max: 3.53 21 | data_path: "code/data/DampedPendulum" 22 | simulator: 23 | name: 'DampedPendulum' 24 | init_param: 25 | omega_0: .5 26 | true_param: 27 | omega_0: 1. 28 | alpha: .5 29 | partial_model_param: 30 | - omega_0 -------------------------------------------------------------------------------- /code/scripts/configs/ReactionDiffusion/APHYNITY_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Reaction Diffusion - APHYNITY+' 2 | parameters: 3 | data_path: "code/data/ReactionDiffusion" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .0 7 | n_epochs: 300 8 | n_iter: 1 9 | lambda_0: 10. 10 | tau_2: 5. 11 | fa_regularization_scheme: 12 | augmented: True 13 | path_model: "here" 14 | zp_priors: 15 | a: 16 | min: .001 17 | max: .004 18 | b: 19 | min: .001 20 | max: .01 21 | simulator: 22 | name: 'ReactionDiffusion' 23 | init_param: 24 | a: -10000000. 25 | b: -10000000. 26 | partial_model_param: 27 | - a 28 | - b 29 | solver: 'APHYNITYReactionDiffusion' -------------------------------------------------------------------------------- /code/simulators/NoSimulator.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchdyn.core import NeuralODE 7 | from torchdyn import * 8 | from code.simulators.GenericSimulator import PhysicalModel 9 | import math 10 | 11 | 12 | class NoSimulator(PhysicalModel): 13 | def __init__(self, param_values=None, trainable_param=None): 14 | super(NoSimulator, self).__init__({}, []) 15 | 16 | def forward(self, t, x): 17 | return 0 * x 18 | 19 | def parameterized_forward(self, t, x, **parameters): 20 | return 0 * x 21 | 22 | def get_x_labels(self): 23 | return [] 24 | 25 | def get_name(self): 26 | return "No Fp" -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /code/scripts/configs/ReactionDiffusion/HVAE.yaml: -------------------------------------------------------------------------------- 1 | name: 'Reaction Diffusion - HVAE' 2 | parameters: 3 | data_path: "code/data/ReactionDiffusion" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .00001 7 | n_epochs: 1000 8 | b_size: 100 9 | gamma: 1. 10 | alpha: .01 11 | beta: .01 12 | omicron: 0.
13 | model: HybridVAEReactionDiffusion 14 | zp_priors: 15 | a: 16 | mu: 0.0015 17 | sigma: 0.0004 18 | min: 0.001 19 | max: 0.004 20 | b: 21 | mu: 0.005 22 | sigma: 0.0012 23 | min: 0.001 24 | max: 0.01 25 | simulator: 26 | name: 'ReactionDiffusion' 27 | simulator_init_param: 28 | a: -100000. 29 | b: -100000. 30 | simulator_partial_model_param: 31 | - a 32 | - b -------------------------------------------------------------------------------- /code/scripts/configs/RLC/APHYNITY.yaml: -------------------------------------------------------------------------------- 1 | name: 'RLC - APHYNITY' 2 | parameters: 3 | data_path: "code/data/RLC" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .0 7 | n_epochs: 50 8 | b_size: 100 9 | n_iter: 5 10 | lambda_0: 10. 11 | tau_2: 5. 12 | fa_linear: False 13 | fa_n_layers: 3 14 | fa_n_neurons: 150 15 | fa_act: ReLU 16 | fa_final_act: 17 | fa_regularization_scheme: 18 | augmented: False 19 | zp_priors: 20 | L: 21 | C: 22 | model: "APHYNITY" 23 | simulator: 24 | name: 'RLC' 25 | init_param: 26 | omega: 2. 27 | V_a: 2.5 28 | V_c: 1. 29 | L: -10000000. 30 | C: -10000000. 31 | true_param: 32 | omega: 2. 33 | V_a: 2.5 34 | V_c: 1. 35 | partial_model_param: 36 | - L 37 | - C 38 | solver: 'APHYNITY' -------------------------------------------------------------------------------- /code/scripts/configs/Pendulum/APHYNITY.yaml: -------------------------------------------------------------------------------- 1 | name: 'Pendulum - APHYNITY' 2 | parameters: 3 | optimization: 4 | learning_rate_fa: 0.0005 5 | weight_decay_fa: .0 6 | n_epochs: 50 7 | n_iter: 1 8 | lambda_0: 200. 9 | tau_2: 5. 10 | fa: 11 | linear: False 12 | hidden_n: 3 13 | hidden_w: 150 14 | act: ReLU 15 | final_act: 16 | regularization_scheme: 17 | fp_param_converter_hidden_n: 3 18 | fp_param_converter_hidden_w: 150 19 | augmented: False 20 | zp_priors: 21 | omega_0: 22 | min: .5 23 | max: 1.5 24 | model: "APHYNITY" 25 | data_path: "code/data/DampedPendulum" 26 | simulator: 27 | name: 'DampedPendulum' 28 | init_param: 29 | omega_0: .5 30 | true_param: 31 | omega_0: 1. 32 | alpha: .5 33 | partial_model_param: 34 | - omega_0 35 | solver: 'APHYNITY' -------------------------------------------------------------------------------- /code/scripts/configs/Pendulum/HVAE_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Pendulum - HVAE' 2 | parameters: 3 | optimization: 4 | normalize_loss: False 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .000001 7 | n_epochs: 1000 8 | b_size: 200 9 | model: HybridVAE 10 | act_mu_p: Softplus 11 | gamma: 1. 12 | alpha: 0.01 13 | beta: 0.01 14 | omicron: 0. 15 | zp_priors: 16 | omega_0: 17 | mu: 2. 18 | sigma: .7 19 | min: .5 20 | max: 3.5 21 | augmented: True 22 | loss_params: False 23 | path_model: "code/data/DampedPendulum/runs/HVAE.yaml/02_15_2023_11_31_17/HybridVAE_best_valid_model.pt" 24 | data_path: "code/data/DampedPendulum" 25 | simulator: 26 | name: 'DampedPendulum' 27 | init_param: 28 | omega_0: .5 29 | true_param: 30 | omega_0: 1. 31 | alpha: .5 32 | partial_model_param: 33 | - omega_0 -------------------------------------------------------------------------------- /code/nn/utils.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
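# Overview of this module (descriptive comments, inferred from the code below): `Permute`
# wraps `Tensor.permute` as an `nn.Module` so that axis reordering can be placed inside an
# `nn.Sequential`; `ReactionDiffusionParametersScaler` squashes raw network outputs into the
# interval (0, scale) with a sigmoid, so predicted reaction-diffusion parameters stay in a
# small physical range (default scale .01); `kl_gaussians` computes the KL divergence between
# two diagonal Gaussians via `torch.distributions`, summed over the feature dimension (dim 1).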
3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Permute(nn.Module): 9 | def __init__(self, order): 10 | super().__init__() 11 | self.order = order 12 | 13 | def forward(self, x): 14 | return x.permute(self.order) 15 | 16 | 17 | class ReactionDiffusionParametersScaler(nn.Module): 18 | def __init__(self, scale=.01): 19 | super(ReactionDiffusionParametersScaler, self).__init__() 20 | self.scale = scale 21 | 22 | def forward(self, x): 23 | return torch.sigmoid(x) * self.scale 24 | 25 | 26 | def kl_gaussians(mu_1, sigma_1, mu_2, sigma_2): 27 | p = torch.distributions.Normal(mu_1, sigma_1) 28 | q = torch.distributions.Normal(mu_2, sigma_2) 29 | kl = torch.distributions.kl_divergence(p, q) 30 | return kl.sum(1) -------------------------------------------------------------------------------- /code/scripts/configs/RLC/HVAE.yaml: -------------------------------------------------------------------------------- 1 | name: 'RLC - HVAE' 2 | parameters: 3 | data_path: "code/data/RLC" 4 | optimization: 5 | normalize_loss: False 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .000001 8 | n_epochs: 1000 9 | b_size: 100 10 | act_mu_p: Softplus 11 | model: HybridVAE 12 | gamma: 1. 13 | alpha: .01 14 | beta: .01 15 | posterior_type: positive_gaussian 16 | gp_1_hidden: 17 | - 100 18 | - 100 19 | - 100 20 | zp_priors: 21 | L: 22 | mu: 2.5 23 | sigma: .8 24 | min: 1. 25 | max: 5. 26 | C: 27 | mu: 1. 28 | sigma: .4 29 | min: .5 30 | max: 2.5 31 | simulator: 32 | name: 'RLC' 33 | init_param: 34 | omega: 2. 35 | V_a: 2.5 36 | V_c: 1. 37 | L: -10000000. 38 | C: -10000000. 39 | true_param: 40 | omega: 2. 41 | V_a: 2.5 42 | V_c: 1. 43 | partial_model_param: 44 | - L 45 | - C -------------------------------------------------------------------------------- /code/scripts/configs/Pendulum/APHYNITY_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Pendulum - APHYNITY+' 2 | parameters: 3 | optimization: 4 | learning_rate_fa: 0.0005 5 | weight_decay_fa: .0 6 | n_epochs: 50 7 | n_iter: 5 8 | lambda_0: 10. 9 | tau_2: 5. 10 | fa: 11 | linear: False 12 | hidden_n: 3 13 | hidden_w: 50 14 | act: ReLU 15 | final_act: 16 | regularization_scheme: 17 | fp_param_converter_hidden_n: 3 18 | fp_param_converter_hidden_w: 150 19 | augmented: True 20 | combined_augmentation: True 21 | zp_priors: 22 | omega_0: 23 | min: .5 24 | max: 3.5 25 | model: "APHYNITY" 26 | path_model: "code/data/DampedPendulum/runs/APHYNITY.yaml/11_25_2022_20_59_02/APHYNITY_best_valid_model.pt" 27 | data_path: "code/data/DampedPendulum" 28 | simulator: 29 | name: 'DampedPendulum' 30 | init_param: 31 | omega_0: .5 32 | true_param: 33 | omega_0: 1. 34 | alpha: .5 35 | partial_model_param: 36 | - omega_0 37 | solver: 'APHYNITY' -------------------------------------------------------------------------------- /code/hybrid_models/HybridAutoencoder.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
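# Descriptive note: this file defines the abstract interface that the hybrid models in this
# package (e.g. `APHYNITYAutoencoder` and `HybridVAE`, re-exported from
# `code/hybrid_models/__init__.py`) are expected to implement — `forward` rolls a trajectory
# out over `t_span`, `augmented_data` produces augmented samples, `predicted_parameters` /
# `predicted_parameters_as_dict` expose the inferred physical parameters, and `loss` returns
# the components of the training objective.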
3 | 4 | from abc import abstractmethod 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | class HybridAutoencoder(nn.Module): 10 | def __init__(self): 11 | super(HybridAutoencoder, self).__init__() 12 | 13 | @abstractmethod 14 | def forward(self, t_span, x) -> tuple[torch.FloatTensor, torch.FloatTensor]: 15 | pass 16 | 17 | @abstractmethod 18 | def augmented_data(self, t_span, x) -> tuple[torch.FloatTensor, torch.FloatTensor]: 19 | pass 20 | 21 | @abstractmethod 22 | def predicted_parameters(self, t_span, x, zero_param=False) -> torch.FloatTensor: 23 | pass 24 | 25 | @abstractmethod 26 | def predicted_parameters_as_dict(self, t_span, x, zero_param=False) -> dict: 27 | pass 28 | 29 | @abstractmethod 30 | def loss(self, t_span, x) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 31 | pass 32 | -------------------------------------------------------------------------------- /code/scripts/configs/RLC/APHYNITY_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'RLC - APHYNITY+' 2 | parameters: 3 | data_path: "code/data/RLC" 4 | optimization: 5 | augmented: True 6 | path_model: "code/data/RLC/runs/APHYNITY.yaml/02_15_2023_09_40_48/APHYNITY_best_valid_model.pt" 7 | learning_rate_fa: 0.0005 8 | weight_decay_fa: .0 9 | n_epochs: 50 10 | b_size: 100 11 | n_iter: 5 12 | lambda_0: 10. 13 | tau_2: 5. 14 | fa_linear: False 15 | fa_n_layers: 3 16 | fa_n_neurons: 150 17 | fa_act: ReLU 18 | fa_final_act: 19 | fa_regularization_scheme: 20 | zp_priors: 21 | L: 22 | mu: 2. 23 | sigma: .8 24 | min: 1. 25 | max: 5. 26 | C: 27 | mu: 1. 28 | sigma: .4 29 | min: .5 30 | max: 2.5 31 | model: "APHYNITY" 32 | simulator: 33 | name: 'RLC' 34 | init_param: 35 | omega: 2. 36 | V_a: 2.5 37 | V_c: 1. 38 | L: -10000000. 39 | C: -10000000. 40 | true_param: 41 | omega: 2. 42 | V_a: 2.5 43 | V_c: 1. 44 | partial_model_param: 45 | - L 46 | - C 47 | solver: 'APHYNITY' -------------------------------------------------------------------------------- /code/scripts/configs/RLC/HVAE_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'RLC - HVAE+' 2 | parameters: 3 | data_path: "code/data/RLC" 4 | optimization: 5 | normalize_loss: False 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .000001 8 | n_epochs: 1000 9 | b_size: 200 10 | act_mu_p: Softplus 11 | model: HybridVAE 12 | gamma: 1. 13 | alpha: .01 14 | beta: .01 15 | posterior_type: positive_gaussian 16 | gp_1_hidden: 17 | - 100 18 | - 100 19 | - 100 20 | zp_priors: 21 | L: 22 | mu: 2.5 23 | sigma: .8 24 | min: 1. 25 | max: 5. 26 | C: 27 | mu: 1. 28 | sigma: .4 29 | min: .5 30 | max: 2.5 31 | augmented: True 32 | loss_params: True 33 | path_model: "code/data/RLC/runs/HVAE.yaml/11_25_2022_10_23_37/HybridVAE_best_valid_model.pt" 34 | simulator: 35 | name: 'RLC' 36 | init_param: 37 | omega: 2. 38 | V_a: 2.5 39 | V_c: 1. 40 | L: -10000000. 41 | C: -10000000. 42 | true_param: 43 | omega: 2. 44 | V_a: 2.5 45 | V_c: 1. 46 | partial_model_param: 47 | - L 48 | - C -------------------------------------------------------------------------------- /code/scripts/configs/ReactionDiffusion/HVAE_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Reaction Diffusion - HVAE+' 2 | parameters: 3 | optimization_learning_rate_fa: 0.0005 4 | optimization_weight_decay_fa: .00001 5 | optimization_n_epochs: 1000 6 | optimization_b_size: 100 7 | optimization_gamma: 1. 
8 | optimization_alpha: .01 9 | optimization_beta: .01 10 | optimization_omicron: 0. 11 | optimization_model: HybridVAEReactionDiffusion 12 | data_path: "code/data/ReactionDiffusion" 13 | optimization_zp_priors: 14 | a: 15 | mu: 0.0015 16 | sigma: 0.0004 17 | min: 0.001 18 | max: 0.004 19 | b: 20 | mu: 0.005 21 | sigma: 0.0012 22 | min: 0.001 23 | max: 0.01 24 | simulator_name: 'ReactionDiffusion' 25 | optimization_augmented: True 26 | optimization_path_model: "code/data/models/HVAE/vae_amortized_HybridVAEReactionDiffusion_Reaction Diffusion['a', 'b']_HVAE_best_valid_model.pt" 27 | simulator_init_param_a: -100000. 28 | simulator_init_param_b: -100000. 29 | simulator_true_param_a: -100000. 30 | simulator_true_param_b: -100000. 31 | simulator_true_param_k: -100000. 32 | simulator_partial_model_param: 33 | - a 34 | - b -------------------------------------------------------------------------------- /code/utils/loaders.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch.utils.data as data_utils 5 | import pickle 6 | 7 | 8 | def load_data(data_path, device): 9 | with open(r"%s/train.pkl" % data_path, "rb") as output_file: 10 | t_train, x_train, true_param_train = pickle.load(output_file) 11 | t_train, x_train = t_train.to(device), x_train 12 | 13 | train = data_utils.TensorDataset(x_train) 14 | dl_train = data_utils.DataLoader(train, batch_size=100, shuffle=True) 15 | with open(r"%s/valid.pkl" % data_path, "rb") as output_file: 16 | t_valid, x_valid, true_param_valid = pickle.load(output_file) 17 | t_valid, x_valid = t_valid.to(device), x_valid.to(device) 18 | with open(r"%s/test_shifted.pkl" % data_path, "rb") as output_file: 19 | t_test_shifted, x_test_shifted, true_param_test_shifted = pickle.load(output_file) 20 | t_test_shifted, x_test_shifted = t_test_shifted.to(device), x_test_shifted.to(device) 21 | return (t_train, x_train, true_param_train), dl_train, (t_valid, x_valid, true_param_valid), (t_test_shifted, x_test_shifted, true_param_test_shifted) -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/Fp_only.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0 8 | n_epochs: 100 9 | n_iter: 1 10 | lambda_0: 1. 11 | tau_2: 5. 12 | no_fa: True 13 | za_dim: 0 14 | reduced_time_frame: True 15 | fa_regularization_scheme: 16 | augmented: False 17 | cos_sin_encoding: True 18 | no_APHYNITY: True 19 | use_complete_signal: True 20 | obtain_init_position: True 21 | nb_observed_theta_0: 10 22 | nb_observed_theta_1: 5 23 | nb_observed: 10 24 | model: 'APHYNITYDoublePendulum' 25 | zp_priors: 26 | \theta_0: 27 | min: .5 28 | max: 1.5 29 | \theta_1: 30 | min: .5 31 | max: 1.5 32 | \dot \theta_0: 33 | min: .5 34 | max: 1.5 35 | \dot \theta_1: 36 | min: .5 37 | max: 1.5 38 | simulator: 39 | name: 'DoublePendulum' 40 | init_param_omega_0: .5 41 | true_param_omega_0: 1. 
42 | true_param_alpha: .5 43 | partial_model_param: 44 | - \theta_0 45 | - \theta_1 46 | - \dot \theta_0 47 | - \dot \theta_1 48 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/APHYNITY.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0 8 | n_epochs: 100 9 | n_iter: 1 10 | lambda_0: 1000. 11 | tau_2: 5. 12 | no_fa: False 13 | za_dim: 10 14 | reduced_time_frame: True 15 | fa_regularization_scheme: 16 | augmented: False 17 | cos_sin_encoding: True 18 | no_APHYNITY: False 19 | use_complete_signal: True 20 | obtain_init_position: True 21 | nb_observed_theta_0: 10 22 | nb_observed_theta_1: 5 23 | nb_observed: 10 24 | model: 'APHYNITYDoublePendulum' 25 | zp_priors: 26 | \theta_0: 27 | min: .5 28 | max: 1.5 29 | \theta_1: 30 | min: .5 31 | max: 1.5 32 | \dot \theta_0: 33 | min: .5 34 | max: 1.5 35 | \dot \theta_1: 36 | min: .5 37 | max: 1.5 38 | simulator: 39 | name: 'DoublePendulum' 40 | init_param_omega_0: .5 41 | true_param_omega_0: 1. 42 | true_param_alpha: .5 43 | partial_model_param: 44 | - \theta_0 45 | - \theta_1 46 | - \dot \theta_0 47 | - \dot \theta_1 48 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/Fa_only.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0 8 | n_epochs: 1000 9 | n_iter: 1 10 | lambda_0: 1. 11 | tau_2: 5. 12 | no_fa: False 13 | no_fp: True 14 | za_dim: 10 15 | reduced_time_frame: True 16 | fa_regularization_scheme: 17 | augmented: False 18 | cos_sin_encoding: True 19 | no_APHYNITY: True 20 | use_complete_signal: True 21 | obtain_init_position: True 22 | nb_observed_theta_0: 10 23 | nb_observed_theta_1: 5 24 | nb_observed: 10 25 | model: 'APHYNITYDoublePendulum' 26 | zp_priors: 27 | \theta_0: 28 | min: .5 29 | max: 1.5 30 | \theta_1: 31 | min: .5 32 | max: 1.5 33 | \dot \theta_0: 34 | min: .5 35 | max: 1.5 36 | \dot \theta_1: 37 | min: .5 38 | max: 1.5 39 | simulator: 40 | name: 'DoublePendulum' 41 | init_param_omega_0: .5 42 | true_param_omega_0: 1. 43 | true_param_alpha: .5 44 | partial_model_param: 45 | - \theta_0 46 | - \theta_1 47 | - \dot \theta_0 48 | - \dot \theta_1 49 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/APHYNITY_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .0 7 | n_epochs: 50 8 | n_iter: 1 9 | lambda_0: 5000. 10 | tau_2: 5. 
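      # note (inferred, not documented in this repo): lambda_0 and tau_2 presumably correspond to
      # the initial Lagrange multiplier and its update step size in APHYNITY's constrained
      # optimization scheme (Yin et al., 2021), with n_iter the number of inner steps per batch.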
11 | no_fa: False 12 | za_dim: 10 13 | fa_regularization_scheme: 14 | augmented: True 15 | nb_augmentation: 2 16 | cos_sin_encoding: True 17 | no_APHYNITY: False 18 | use_complete_signal: True 19 | obtain_init_position: True 20 | nb_observed_theta_0: 10 21 | nb_observed_theta_1: 5 22 | nb_observed: 10 23 | combined_augmentation: True 24 | reduced_time_frame: True 25 | loss_params: False 26 | path_model: "code/data/DoublePendulum/runs/APHYNITY.yaml/12_14_2022_20_02_29/APHYNITYDoublePendulum_best_valid_model.pt" 27 | model: 'APHYNITYDoublePendulum' 28 | zp_priors: 29 | theta_0: 30 | min: -1.5691 31 | max: 4.7124 32 | theta_1: 33 | min: -1.5691 34 | max: 4.7124 35 | dtheta_0: 36 | min: -15. 37 | max: 15. 38 | dtheta_1: 39 | min: -30. 40 | max: 30. 41 | simulator: 42 | name: 'DoublePendulum' 43 | init_param: 44 | omega_0: .5 45 | true_param: 46 | omega_0: 1. 47 | alpha: .5 48 | partial_model_param: 49 | - \theta_0 50 | - \theta_1 51 | - \dot \theta_0 52 | - \dot \theta_1 53 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/APHYNITY_robustness.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .0 7 | n_epochs: 1 8 | n_iter: 1 9 | lambda_0: 1000. 10 | tau_2: 5. 11 | no_fa: False 12 | za_dim: 10 13 | fa_regularization_scheme: 14 | augmented: True 15 | nb_augmentation: 2 16 | cos_sin_encoding: True 17 | no_APHYNITY: False 18 | use_complete_signal: True 19 | obtain_init_position: True 20 | nb_observed_theta_0: 10 21 | nb_observed_theta_1: 5 22 | nb_observed: 10 23 | combined_augmentation: True 24 | reduced_time_frame: True 25 | loss_params: False 26 | path_model: "code/data/DoublePendulum/runs/APHYNiTYBis_2_given_init_angles.yaml/11_27_2022_19_41_06/APHYNITYDoublePendulum_best_valid_model.pt" 27 | model: 'APHYNITYDoublePendulum' 28 | zp_priors: 29 | theta_0: 30 | min: -1.5691 31 | max: 4.7124 32 | theta_1: 33 | min: -1.5691 34 | max: 4.7124 35 | dtheta_0: 36 | min: -15. 37 | max: 15. 38 | dtheta_1: 39 | min: -30. 40 | max: 30. 41 | simulator: 42 | name: 'DoublePendulum' 43 | init_param: 44 | omega_0: .5 45 | true_param: 46 | omega_0: 1. 47 | alpha: .5 48 | partial_model_param: 49 | - \theta_0 50 | - \theta_1 51 | - \dot \theta_0 52 | - \dot \theta_1 53 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/APHYNITY_robustness_10epochs.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - APHYNITY' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | learning_rate_fa: 0.0005 6 | weight_decay_fa: .0 7 | n_epochs: 10 8 | n_iter: 1 9 | lambda_0: 1000. 10 | tau_2: 5. 
11 | no_fa: False 12 | za_dim: 10 13 | fa_regularization_scheme: 14 | augmented: True 15 | nb_augmentation: 2 16 | cos_sin_encoding: True 17 | no_APHYNITY: False 18 | use_complete_signal: True 19 | obtain_init_position: True 20 | nb_observed_theta_0: 10 21 | nb_observed_theta_1: 5 22 | nb_observed: 10 23 | combined_augmentation: True 24 | reduced_time_frame: True 25 | loss_params: False 26 | path_model: "code/data/DoublePendulum/runs/APHYNiTYBis_2_given_init_angles.yaml/11_27_2022_19_41_06/APHYNITYDoublePendulum_best_valid_model.pt" 27 | model: 'APHYNITYDoublePendulum' 28 | zp_priors: 29 | theta_0: 30 | min: -1.5691 31 | max: 4.7124 32 | theta_1: 33 | min: -1.5691 34 | max: 4.7124 35 | dtheta_0: 36 | min: -15. 37 | max: 15. 38 | dtheta_1: 39 | min: -30. 40 | max: 30. 41 | simulator: 42 | name: 'DoublePendulum' 43 | init_param: 44 | omega_0: .5 45 | true_param: 46 | omega_0: 1. 47 | alpha: .5 48 | partial_model_param: 49 | - \theta_0 50 | - \theta_1 51 | - \dot \theta_0 52 | - \dot \theta_1 53 | -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/HVAE.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - HVAE' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0000 8 | b_size: 100 9 | n_epochs: 200 10 | alpha: .0 11 | beta: 0. 12 | gamma: 0. 13 | omicron: 0. 14 | no_fa: False 15 | no_fp: True 16 | za_dim: 2 17 | reduced_time_frame: True 18 | augmented: False 19 | simple_encoder: False 20 | cos_sin_encoding: True 21 | use_complete_signal: True 22 | obtain_init_position: True 23 | nb_observed_theta_0: 10 24 | nb_observed_theta_1: 5 25 | nb_observed: 10 26 | model: 'HybridVAEDoublePendulum' 27 | zp_prior_type: "Uniform" 28 | zp_priors: 29 | \theta_0: 30 | sigma: -1. 31 | mean: -1. 32 | min: -1.5691 33 | max: 4.7124 34 | \theta_1: 35 | sigma: -1. 36 | mean: -1. 37 | min: -1.5691 38 | max: 4.7124 39 | \dot \theta_0: 40 | sigma: -1. 41 | mean: -1. 42 | min: -15. 43 | max: 15. 44 | \dot \theta_1: 45 | sigma: -1. 46 | mean: -1. 47 | min: -30. 48 | max: 30. 49 | simulator: 50 | name: 'DoublePendulum' 51 | init_param_omega_0: .5 52 | true_param_omega_0: 1. 53 | true_param_alpha: .5 54 | partial_model_param: 55 | - \theta_0 56 | - \theta_1 57 | - \dot \theta_0 58 | - \dot \theta_1 -------------------------------------------------------------------------------- /code/nn/mlp.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
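# Descriptive note: `act_dict` maps the activation names used in the YAML configs (e.g.
# `fa_act: ReLU`) to module instances, with `None` resolving to `nn.Identity`. `MLP` builds a
# fully connected network from the `layers` width list, inserting `hidden_act` between layers
# and `final_act` at the output; when `linear_dim > 0`, `AddModule` sums the MLP output with a
# parallel `nn.Linear(linear_dim, linear_dim)` skip path. A hypothetical instantiation matching
# the RLC config (fa_n_layers: 3, fa_n_neurons: 150) would be
# `MLP(layers=[in_dim, 150, 150, 150, out_dim], hidden_act="ReLU")`, where `in_dim` and
# `out_dim` are placeholders for the state dimensions.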
3 | 4 | import torch.nn as nn 5 | from code.nn.utils import ReactionDiffusionParametersScaler 6 | 7 | act_dict = {"ReLU": nn.ReLU(), 8 | "Softplus": nn.Softplus(), 9 | "SELU": nn.SELU(), 10 | "ReactionDiffusionParametersScaler": ReactionDiffusionParametersScaler(), 11 | None: nn.Identity()} 12 | 13 | 14 | class AddModule(nn.Module): 15 | def __init__(self, m1, m2): 16 | super(AddModule, self).__init__() 17 | self.m1 = m1 18 | self.m2 = m2 19 | 20 | def forward(self, x): 21 | return self.m1(x) + self.m2(x) 22 | 23 | 24 | class MLP(nn.Module): 25 | def __init__(self, layers=None, linear_dim=0, hidden_act="ReLU", final_act=None): 26 | super(MLP, self).__init__() 27 | if layers is not None and len(layers) > 0: 28 | nn_lay = [] 29 | for l1, l2 in zip(layers[:-1], layers[1:]): 30 | nn_lay += [nn.Linear(l1, l2), act_dict[hidden_act]] 31 | nn_lay.pop() 32 | nn_lay.append(act_dict[final_act]) 33 | self.nn = nn.Sequential(*nn_lay) 34 | else: 35 | self.nn = lambda x: 0. 36 | if linear_dim > 0: 37 | self.linear = nn.Linear(linear_dim, linear_dim) 38 | self.net = AddModule(self.nn, self.linear) 39 | else: 40 | self.linear = lambda x: 0. 41 | self.net = self.nn 42 | 43 | 44 | def forward(self, x): 45 | return self.net(x) -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/HVAE_robustness.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - HVAE' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .00001 8 | b_size: 500 9 | n_epochs: 1 10 | alpha: .0 11 | beta: 0. 12 | gamma: 0. 13 | omicron: 0. 14 | no_fa: True 15 | za_dim: 10 16 | reduced_time_frame: True 17 | augmented: True 18 | nb_augmentation: 1 19 | cos_sin_encoding: True 20 | use_complete_signal: True 21 | obtain_init_position: True 22 | nb_observed_theta_0: 10 23 | nb_observed_theta_1: 5 24 | nb_observed: 10 25 | model: 'HybridVAEDoublePendulum' 26 | path_model: "code/data/DoublePendulum/runs/HVAE.yaml/12_13_2022_20_28_08/HybridVAEDoublePendulum_best_valid_model.pt" 27 | zp_prior_type: "Uniform" 28 | zp_priors: 29 | \theta_0: 30 | sigma: -1. 31 | mean: -1. 32 | min: -1.5691 33 | max: 4.7124 34 | \theta_1: 35 | sigma: -1. 36 | mean: -1. 37 | min: -1.5691 38 | max: 4.7124 39 | \dot \theta_0: 40 | sigma: -1. 41 | mean: -1. 42 | min: -15. 43 | max: 15. 44 | \dot \theta_1: 45 | sigma: -1. 46 | mean: -1. 47 | min: -30. 48 | max: 30. 49 | simulator: 50 | name: 'DoublePendulum' 51 | init_param_omega_0: .5 52 | true_param_omega_0: 1. 53 | true_param_alpha: .5 54 | partial_model_param: 55 | - \theta_0 56 | - \theta_1 57 | - \dot \theta_0 58 | - \dot \theta_1 -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/HVAE_robustness_10epochs.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - HVAE' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .00001 8 | b_size: 100 9 | n_epochs: 10 10 | alpha: .0 11 | beta: 0. 12 | gamma: 0. 13 | omicron: 0. 
14 | no_fa: True 15 | za_dim: 10 16 | reduced_time_frame: True 17 | augmented: True 18 | nb_augmentation: 1 19 | cos_sin_encoding: True 20 | use_complete_signal: True 21 | obtain_init_position: True 22 | nb_observed_theta_0: 10 23 | nb_observed_theta_1: 5 24 | nb_observed: 10 25 | model: 'HybridVAEDoublePendulum' 26 | path_model: "code/data/DoublePendulum/runs/HVAE.yaml/12_13_2022_20_28_08/HybridVAEDoublePendulum_best_valid_model.pt" 27 | zp_prior_type: "Uniform" 28 | zp_priors: 29 | \theta_0: 30 | sigma: -1. 31 | mean: -1. 32 | min: -1.5691 33 | max: 4.7124 34 | \theta_1: 35 | sigma: -1. 36 | mean: -1. 37 | min: -1.5691 38 | max: 4.7124 39 | \dot \theta_0: 40 | sigma: -1. 41 | mean: -1. 42 | min: -15. 43 | max: 15. 44 | \dot \theta_1: 45 | sigma: -1. 46 | mean: -1. 47 | min: -30. 48 | max: 30. 49 | simulator: 50 | name: 'DoublePendulum' 51 | init_param_omega_0: .5 52 | true_param_omega_0: 1. 53 | true_param_alpha: .5 54 | partial_model_param: 55 | - \theta_0 56 | - \theta_1 57 | - \dot \theta_0 58 | - \dot \theta_1 -------------------------------------------------------------------------------- /code/scripts/configs/DoublePendulum/HVAE_plus.yaml: -------------------------------------------------------------------------------- 1 | name: 'Double Pendulum - HVAE' 2 | parameters: 3 | data_path: "code/data/DoublePendulum" 4 | optimization: 5 | save_all_models: True 6 | learning_rate_fa: 0.0005 7 | weight_decay_fa: .0000 8 | b_size: 100 9 | n_epochs: 200 10 | alpha: 0. 11 | beta: 0. 12 | gamma: 0. 13 | omicron: 0. 14 | no_fa: False 15 | za_dim: 2 16 | simple_encoder: False 17 | reduced_time_frame: True 18 | augmented: True 19 | nb_augmentation: 1 20 | combined_augmentation: True 21 | cos_sin_encoding: True 22 | use_complete_signal: True 23 | obtain_init_position: True 24 | nb_observed_theta_0: 10 25 | nb_observed_theta_1: 5 26 | nb_observed: 10 27 | loss_params: True 28 | model: 'HybridVAEDoublePendulum' 29 | path_model: "code/data/DoublePendulum/runs/HVAE.yaml/12_16_2022_00_08_08/HybridVAEDoublePendulum_best_valid_model14.pt" 30 | zp_prior_type: "Uniform" 31 | zp_priors: 32 | \theta_0: 33 | sigma: -1. 34 | mean: -1. 35 | min: -1.5691 36 | max: 4.7124 37 | \theta_1: 38 | sigma: -1. 39 | mean: -1. 40 | min: -1.5691 41 | max: 4.7124 42 | \dot \theta_0: 43 | sigma: -1. 44 | mean: -1. 45 | min: -15. 46 | max: 15. 47 | \dot \theta_1: 48 | sigma: -1. 49 | mean: -1. 50 | min: -30. 51 | max: 30. 52 | simulator: 53 | name: 'DoublePendulum' 54 | init_param_omega_0: .5 55 | true_param_omega_0: 1. 56 | true_param_alpha: .5 57 | partial_model_param: 58 | - \theta_0 59 | - \theta_1 60 | - \dot \theta_0 61 | - \dot \theta_1 -------------------------------------------------------------------------------- /code/simulators/GenericSimulator.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
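# Descriptive note: `PhysicalModel` is the base class of the expert models in
# `code/simulators`; its constructor records which physical parameters are trainable versus
# missing, and `parameterized_forward` rejects keyword arguments that were not declared in
# `trainable_param`. `GenericSimulator` sketches the data-generation interface
# (`sample_theta`, `sample_sequence`, ...) used by the per-dataset `GenerateDataset.py`
# scripts; the dynamics themselves live in subclasses such as `DampedPendulum` or
# `ReactionDiffusion`.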
3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PhysicalModel(nn.Module): 9 | def __init__(self, param_values, trainable_param): 10 | super(PhysicalModel, self).__init__() 11 | self._nb_parameters = len(trainable_param) 12 | self._X_dim = -1 13 | self.incomplete_param_dim_textual = [] 14 | self.full_param_dim_textual = [] 15 | self.missing_param_dim_textual = [] 16 | self.trainable_param = trainable_param 17 | for p in param_values.keys(): 18 | if p in trainable_param: 19 | self.incomplete_param_dim_textual.append(p) 20 | else: 21 | self.missing_param_dim_textual.append(p) 22 | self.full_param_dim_textual.append(p) 23 | # NB: the two accessors below are shadowed by the same-named instance attributes set in __init__ 24 | def _nb_parameters(self): 25 | return self._nb_parameters 26 | 27 | def _X_dim(self): 28 | return self._X_dim 29 | 30 | def forward(self, t, x): 31 | pass 32 | 33 | def parameterized_forward(self, t, x, **parameters): 34 | if len(set(parameters.keys()) - set(self.trainable_param)) != 0: 35 | raise Exception("Parameterized forward physical arguments do not match the simulator specification. " 36 | "Simulator: {} - kwargs: {}".format(self.trainable_param, parameters.keys())) 37 | pass 38 | 39 | def get_x_labels(self): 40 | return ["$x_%d$" % i for i in range(self._X_dim)] 41 | 42 | def get_name(self): 43 | return "Generic Simulator" 44 | 45 | 46 | class GenericSimulator: 47 | def __init__(self): 48 | pass 49 | 50 | def sample_theta(self, n=1) -> torch.tensor: 51 | pass 52 | 53 | def sample_sequence(self, theta=None) -> tuple[torch.tensor, torch.tensor]: 54 | pass 55 | 56 | def sample_split_sequence(self, theta=None) -> tuple[torch.tensor, tuple[torch.tensor, torch.tensor]]: 57 | pass 58 | 59 | def get_uncomplete_forward_model(self) -> nn.Module: 60 | pass 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust Hybrid Learning With Expert Augmentation 2 | ![](figures/improved_diffusion.png) 3 | This repository contains the official implementations of the experiments in the paper ["Robust Hybrid Learning With Expert Augmentation 4 | "](https://openreview.net/forum?id=oe4dl4MCGY) by Antoine Wehenkel, Jens Behrmann, Hsiang Hsu, Guillermo Sapiro, Gilles Louppe, Joern-Henrik Jacobsen. 5 | You can use the following reference to cite our work: 6 | ``` 7 | @article{wehenkel2023robust, 8 | title={Robust Hybrid Learning With Expert Augmentation}, 9 | author={Wehenkel, Antoine and Behrmann, Jens and Hsu, Hsiang and Sapiro, Guillermo and Louppe, Gilles and Jacobsen, J{\"o}rn-Henrik}, 10 | journal={Transactions on Machine Learning Research}, 11 | year={2023} 12 | } 13 | ``` 14 | ## Dependencies 15 | You can install the dependencies from the files `requirements.txt` and `requirements_pip.txt` as follows: 16 | ``` 17 | conda env create -f requirements.txt -n RHL 18 | conda activate RHL 19 | pip install -r requirements_pip.txt 20 | ``` 21 | 22 | To be able to run all commands from the root of the repository, you must also execute: 23 | `export PYTHONPATH=.` 24 | 25 | ## Data generation 26 | Before running any experiments, you must generate the train, validation and test sets by 27 | running the corresponding `GenerateDataset.py` script.
28 | For instance, to generate the data for the Damped Pendulum experiments, you would run: 29 | 30 | ``` 31 | python code/data/DampedPendulum/GenerateDataset.py 32 | ``` 33 | 34 | ## Double pendulum data 35 | 36 | You can get the data for the double pendulum experiments from [this website](https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/) and place it in the folder `code/data/DoublePendulum/dpc_dataset_csv`. 37 | 38 | ## Training hybrid learning methods 39 | 40 | You can simply run the `run_experiments.py` script from the `code/scripts` subfolder with the appropriate configuration file. 41 | For instance, the configuration file `code/scripts/configs/Pendulum/APHYNITY.yaml` will train a hybrid model with APHYNITY 42 | on the *Damped Pendulum* data: 43 | ``` 44 | python code/scripts/run_experiments.py --config code/scripts/configs/Pendulum/APHYNITY.yaml 45 | ``` 46 | 47 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /code/utils/double_pendulum.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import math 5 | 6 | import cv2 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | import torch.utils.data as data_utils 11 | 12 | 13 | def xy_to_theta(dx, dy): 14 | theta = np.arctan2(dy, dx) + math.pi / 2 15 | # cond_1 = ((dy > 0) & (theta < 0)) 16 | # cond_2 = ((dy < 0) & (theta > 0)) 17 | # else_cond = 1 - (cond_1 & cond_2) 18 | # theta = (theta + math.pi) * cond_1 + cond_2 * (theta - math.pi) + else_cond * theta 19 | # theta = theta + math.pi/2 20 | # theta = theta * (theta < math.pi) + (theta - 2*math.pi) * (theta >= math.pi) 21 | return theta 22 | 23 | 24 | def from_raw_pixels_to_angle(markers): 25 | cos_1 = markers[:, 3] - markers[:, 1] 26 | sin_1 = markers[:, 0] - markers[:, 2] 27 | theta_1 = xy_to_theta(cos_1, sin_1) 28 | 29 | cos_2 = markers[:, 5] - markers[:, 3] 30 | sin_2 = markers[:, 2] - markers[:, 4] 31 | theta_2 = xy_to_theta(cos_2, sin_2) 32 | 33 | return theta_1, theta_2 34 | 35 | 36 | def video_to_frames(video_path, max_amount=np.Inf): 37 | '''Convert a video into its frames.''' 38 | frames = [] 39 | # load the video 40 | vidcap = cv2.VideoCapture(video_path) 41 | while vidcap.isOpened(): 42 | success, frame = vidcap.read() 43 | if not success: break  # stop at the end of the stream instead of appending empty frames forever 44 | frames.append(frame) 45 | if len(frames) >= max_amount: break 46 | return frames 47 | 48 | 49 | def get_dataloaders(path, files=['0'], sub_sampling_rate=1, seq_len=100, seq_gap=10, b_size=100): 50 | train, val, test = [], [], [] 51 | for f in files: 52 | marker_positions = pd.read_csv(path + f +'.csv', 53 | header=None).values 54 | angles = from_raw_pixels_to_angle(marker_positions) 55 | angles_dataset = torch.cat((torch.tensor(angles[0]).unsqueeze(1), torch.tensor(angles[1]).unsqueeze(1)), 1) 56 | angles_dataset = angles_dataset[::sub_sampling_rate] 57 | length_tot = angles_dataset.shape[0] 58 | 59 | samples_ids = torch.arange(seq_len).unsqueeze(0) + torch.arange(0, length_tot-seq_len, seq_gap).unsqueeze(1) 60 | 61 | train_s = int(samples_ids.shape[0] * 0.6) 62 | val_s = int(samples_ids.shape[0] * .2) 63 | train.append(angles_dataset[samples_ids[:train_s]]) 64 | val.append(angles_dataset[samples_ids[train_s:train_s+val_s]]) 65 | test.append(angles_dataset[samples_ids[train_s+val_s:]]) 66 | 67 | train = torch.cat(train, 0) 68 | val = torch.cat(val, 0) 69 | test = torch.cat(test, 0) 70 | 71 | train = data_utils.TensorDataset(train.float()) 72 | dl_train = data_utils.DataLoader(train, batch_size=b_size, shuffle=True) 73 | val = data_utils.TensorDataset(val.float()) 74 | dl_val = data_utils.DataLoader(val, batch_size=b_size, shuffle=False) 75 | test = data_utils.TensorDataset(test.float()) 76 | dl_test = data_utils.DataLoader(test, batch_size=b_size, shuffle=False) 77 | 78
| frequency = int(400/sub_sampling_rate) 79 | time = torch.arange(0, seq_len / frequency, 1 / frequency) 80 | 81 | return dl_train, dl_val, dl_test, time 82 | 83 | 84 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 
63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /code/data/ReactionDiffusion/GenerateDataset.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | from code.simulators import ReactionDiffusion 5 | import torch 6 | from tqdm import tqdm 7 | import os 8 | import pickle 9 | 10 | 11 | def gen_data(n=500, shifted=""): 12 | if shifted == "small_all": 13 | distributions = { 14 | "a": lambda x: torch.distributions.Uniform(2e-3, 4e-3).sample_n(x), 15 | "b": lambda x: torch.distributions.Uniform(1e-3, 1e-2).sample_n(x), 16 | "k": lambda x: torch.distributions.Uniform(5e-3, 8e-3).sample_n(x) 17 | } 18 | elif shifted == "small_k": 19 | distributions = { 20 | "a": lambda x: torch.distributions.Uniform(1e-3, 2e-3).sample_n(x), 21 | "b": lambda x: torch.distributions.Uniform(3e-3, 7e-3).sample_n(x), 22 | "k": lambda x: torch.distributions.Uniform(5e-3, 8e-3).sample_n(x) 23 | } 24 | elif shifted == "medium_all": 25 | distributions = { 26 | "a": lambda x: torch.distributions.Uniform(2e-3, 4e-3).sample_n(x), 27 | "b": lambda x: torch.distributions.Uniform(1e-3, 1e-2).sample_n(x), 28 | "k": lambda x: torch.distributions.Uniform(8e-3, 2e-2).sample_n(x) 29 | } 30 | elif shifted == "medium_k": 31 | distributions = { 32 | "a": lambda x: torch.distributions.Uniform(1e-3, 2e-3).sample_n(x), 33 | "b": lambda x: torch.distributions.Uniform(3e-3, 7e-3).sample_n(x), 34 | "k": lambda x: torch.distributions.Uniform(8e-3, 2e-2).sample_n(x) 35 | } 36 | elif shifted == "large_all": 37 | distributions = { 38 | "a": lambda x: torch.distributions.Uniform(2e-3, 4e-3).sample_n(x), 39 | "b": lambda x: torch.distributions.Uniform(1e-3, 1e-2).sample_n(x), 40 | "k": lambda x: torch.distributions.Uniform(2e-2, 1e-1).sample_n(x) 41 | } 42 | elif shifted == "large_k": 43 | distributions = { 44 | "a": lambda x: torch.distributions.Uniform(1e-3, 2e-3).sample_n(x), 45 | "b": lambda x: torch.distributions.Uniform(3e-3, 7e-3).sample_n(x), 46 | "k": lambda x: torch.distributions.Uniform(2e-2, 1e-1).sample_n(x) 47 | } 48 | elif shifted == "None": 49 | distributions = { 50 | "a": lambda x: torch.distributions.Uniform(2e-3, 4e-3).sample_n(x), 51 | "b": lambda x: torch.distributions.Uniform(1e-3, 1e-2).sample_n(x), 52 | "k": lambda x: torch.distributions.Uniform(3e-3, 5e-3).sample_n(x) 53 | } 54 | else: 55 | distributions = { 56 | "a": lambda x: torch.distributions.Uniform(1e-3, 2e-3).sample_n(x), 57 | "b": lambda x: torch.distributions.Uniform(3e-3, 7e-3).sample_n(x), 58 | "k": lambda x: torch.distributions.Uniform(3e-3, 5e-3).sample_n(x) 59 | } 60 | s = ReactionDiffusion() 61 | s.n_timesteps = 50 62 | s.T0 = 0. 63 | s.T1 = 5. 
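    # Shape note (inferred from the defaults set just above): each generated
    # sample is a sequence of n_timesteps + 1 = 51 snapshots of the 2x32x32
    # (U, V) reaction-diffusion field over t in [0, 5].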
64 | dataset = torch.zeros((n, s.n_timesteps + 1) + s._X_dim) 65 | true_param = {"a": torch.zeros(n), 66 | "b": torch.zeros(n), 67 | "k": torch.zeros(n)} 68 | for i in tqdm(range(n)): 69 | true_param['a'][i] = distributions["a"](1)[0] 70 | true_param['b'][i] = distributions["b"](1)[0] 71 | true_param['k'][i] = distributions["k"](1)[0] 72 | t, x = s.sample_sequences({x: t[i] for x, t in true_param.items()}) 73 | dataset[i] = x.squeeze(1) 74 | 75 | return t, dataset, true_param 76 | 77 | 78 | path = 'code/data/ReactionDiffusion' 79 | if not os.path.exists(path): 80 | os.makedirs(path) 81 | 82 | with open(r"%s/train.pkl" % path, "wb") as output_file: 83 | pickle.dump(gen_data(2000), output_file) 84 | 85 | with open(r"%s/valid.pkl" % path, "wb") as output_file: 86 | pickle.dump(gen_data(100), output_file) 87 | 88 | with open(r"%s/test_shifted.pkl" % path, "wb") as output_file: 89 | pickle.dump(gen_data(100, "None"), output_file) 90 | 91 | for hardness in ["small", "medium", "large"]: 92 | for shift in ["all", "k"]: 93 | shift_def = hardness + "_" + shift 94 | with open(r"%s/%s_test.pkl" % (path, shift_def), "wb") as output_file: 95 | pickle.dump(gen_data(100, shifted=shift_def), output_file) 96 | -------------------------------------------------------------------------------- /code/simulators/DoublePendulum.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchdyn.core import NeuralODE 7 | from torchdyn import * 8 | from code.simulators.GenericSimulator import PhysicalModel 9 | import math 10 | 11 | 12 | class DoublePendulumODE(PhysicalModel): 13 | def __init__(self, param_values=None, trainable_param=None): 14 | super(DoublePendulumODE, self).__init__({}, []) 15 | # Defining constants that we could potentially set as parameters 16 | self.g = 9.81 17 | self.m1 = 1. 18 | self.m2 = 1. 
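        # Fixed physical constants (not exposed as trainable parameters here):
        # g in m/s^2, unit point masses, and arm lengths in metres that appear
        # to match the video-tracked pendulum the datasets are built from.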
19 | self.l1 = 0.091 20 | self.l2 = 0.070 21 | self._X_dim = 4 22 | 23 | def forward(self, t, x): 24 | return self.parameterized_forward(t, x) 25 | 26 | def parameterized_forward(self, t, x, **parameters): 27 | #super(DoublePendulumODE, self).parameterized_forward(None, x, **parameters) 28 | g = self.g 29 | m1 = self.m1 30 | m2 = self.m2 31 | l1 = self.l1 32 | l2 = self.l2 33 | 34 | theta_1, theta_2, d_theta_1, d_theta_2 = torch.chunk(x, 4, 1) 35 | dd_theta_1 = (-g * (2 * m1 + m2) * torch.sin(theta_1) - 36 | m2 * g * torch.sin(theta_1 - 2 * theta_2) - 37 | 2 * torch.sin(theta_1 - theta_2) * 38 | m2 * (d_theta_2 ** 2 * l2 + d_theta_1 ** 2 * l1 * torch.cos(theta_1 - theta_2))) / \ 39 | (l1 * (2 * m1 + m2 - m2 * torch.cos(2 * theta_1 - 2 * theta_2))) 40 | dd_theta_2 = (2 * torch.sin(theta_1 - theta_2) * (d_theta_1 ** 2 * l1 * (m1 + m2) + 41 | g * (m1 + m2) * torch.cos(theta_1) + 42 | d_theta_2 ** 2 * l2 * m2 * torch.cos(theta_1 - theta_2))) / \ 43 | (l2 * (2 * m1 + m2 - m2 * torch.cos(2 * theta_1 - 2 * theta_2))) 44 | 45 | return torch.cat((d_theta_1, d_theta_2, dd_theta_1, dd_theta_2), 1) 46 | 47 | def get_x_labels(self): 48 | return ["$\\theta_0$", "$\\theta_1$", "$\\dot \\theta_0$", "$\\dot \\theta_1$"] 49 | 50 | def get_name(self): 51 | return "Double Pendulum" 52 | 53 | 54 | class DoublePendulum: 55 | def __init__(self, init_param=None, true_param=None, T0=0., T1=20, n_timesteps=100, partial_model_param=None, 56 | name="DoublePendulum", 57 | **kwargs): 58 | self.full_param_dim_textual = [] 59 | self.incomplete_param_dim_textual = partial_model_param 60 | self.init_param = self.prior_full_parameters() if init_param is None else init_param 61 | self.T0 = float(T0) 62 | self.T1 = float(T1) 63 | self.n_timesteps = int(n_timesteps) 64 | self.name = name 65 | self.true_param = {} if true_param is None else true_param 66 | 67 | def prior_incomplete_parameters(self): 68 | return {} 69 | 70 | def prior_full_parameters(self): 71 | return {} 72 | 73 | def sample_init_state(self, n=1): 74 | theta = torch.rand([n, 2]) * 2.0 * math.pi - math.pi 75 | return torch.cat([theta, torch.zeros_like(theta)], 1) 76 | 77 | def sample_sequences(self, parameters=None, n=1, x0=None): 78 | if parameters is None: 79 | parameters = self.true_param 80 | x0 = self.sample_init_state(n) if x0 is None else x0 81 | t_span = torch.linspace(self.T0, self.T1, self.n_timesteps + 1) 82 | f = self.get_full_physical_model(parameters) 83 | model = NeuralODE(f, sensitivity='adjoint', solver='dopri5', rtol=1e-7, atol=1e-7) 84 | with torch.no_grad(): 85 | t_eval, y_hat = model(x0, t_span) 86 | return t_eval, y_hat + torch.randn_like(y_hat) * .01 87 | 88 | def get_incomplete_physical_model(self, parameters, trainable=True) -> nn.Module: 89 | if trainable: 90 | for p in self.full_param_dim_textual: 91 | if p not in self.incomplete_param_dim_textual: 92 | parameters[p] = 0. 
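            # Full-model parameters missing from the incomplete model are
            # pinned to 0 so the ODE stays well-defined; for this simulator the
            # full-parameter list is empty, so the loop is a no-op kept for
            # interface symmetry with the other simulators.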
93 | return DoublePendulumODE(param_values=parameters, trainable_param=self.incomplete_param_dim_textual) 94 | else: 95 | return DoublePendulumODE(param_values=parameters, trainable_param=[]) 96 | 97 | def get_full_physical_model(self, parameters, trainable=False) -> nn.Module: 98 | if trainable: 99 | return DoublePendulumODE(param_values=parameters, trainable_param=self.full_param_dim_textual) 100 | else: 101 | return DoublePendulumODE(param_values=parameters, trainable_param=[]) 102 | -------------------------------------------------------------------------------- /code/data/DampedPendulum/GenerateDataset.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | from code.simulators import DampedPendulum 5 | import torch 6 | from tqdm import tqdm 7 | import os 8 | import pickle 9 | import argparse 10 | 11 | 12 | def gen_data(n=500, shifted=""): 13 | if shifted == "small_all": 14 | distributions = { 15 | "alpha": lambda x: torch.distributions.Uniform(.3, .6).sample_n(x), 16 | "omega_0": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x) 17 | } 18 | elif shifted == "small_alpha": 19 | distributions = { 20 | "alpha": lambda x: torch.distributions.Uniform(.3, .6).sample_n(x), 21 | "omega_0": lambda x: torch.distributions.Uniform(1.5, 3.1).sample_n(x) 22 | } 23 | elif shifted == "medium_all": 24 | distributions = { 25 | "alpha": lambda x: torch.distributions.Uniform(.6, 1.).sample_n(x), 26 | "omega_0": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x) 27 | } 28 | elif shifted == "medium_alpha": 29 | distributions = { 30 | "alpha": lambda x: torch.distributions.Uniform(.6, 1.).sample_n(x), 31 | "omega_0": lambda x: torch.distributions.Uniform(1.5, 3.1).sample_n(x) 32 | } 33 | elif shifted == "large_all": 34 | distributions = { 35 | "alpha": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 36 | "omega_0": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x) 37 | } 38 | elif shifted == "large_alpha": 39 | distributions = { 40 | "alpha": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 41 | "omega_0": lambda x: torch.distributions.Uniform(1.5, 3.1).sample_n(x) 42 | } 43 | elif shifted == "None": 44 | distributions = { 45 | "alpha": lambda x: torch.distributions.Uniform(0., .6).sample_n(x), 46 | "omega_0": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x) 47 | } 48 | else: 49 | distributions = { 50 | "alpha": lambda x: torch.distributions.Uniform(0., 0.6).sample_n(x), 51 | "omega_0": lambda x: torch.distributions.Uniform(1.5, 3.1).sample_n(x) 52 | } 53 | s = DampedPendulum() 54 | s.n_timesteps = 200 55 | s.T0 = 0. 56 | s.T1 = 20 57 | dataset = torch.zeros(n, s.n_timesteps + 1, 1, 2) 58 | true_param = {'omega_0': torch.zeros(n), 'alpha': torch.zeros(n), 'A': torch.zeros(n), 'phi': torch.zeros(n)} 59 | for i in tqdm(range(n)): 60 | true_param['omega_0'][i] = distributions["omega_0"](1)[0] 61 | true_param['alpha'][i] = distributions["alpha"](1)[0] 62 | true_param['phi'][i] = 0. 63 | true_param['A'][i] = 0. 
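        # A and phi are zeroed, so the generated trajectories contain no
        # external forcing; only omega_0 and alpha actually vary per sample.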
64 |         with torch.no_grad():
65 |             t, x = s.sample_sequences({'omega_0': true_param['omega_0'][i],
66 |                                        'alpha': true_param['alpha'][i],
67 |                                        'A': true_param['A'][i],
68 |                                        'phi': true_param['phi'][i]})
69 |         dataset[i, :, :, :] = x
70 |     dataset = dataset.permute(0, 1, 3, 2).unsqueeze(3)
71 |     return t, dataset, true_param
72 | 
73 | if __name__ == "__main__":
74 |     # Create the parser
75 |     parser = argparse.ArgumentParser()
76 |     # Add an argument
77 |     parser.add_argument('--nb_train', type=int, default=1000)
78 |     # Parse the argument
79 |     args = parser.parse_args()
80 | 
81 |     nb_data_train = args.nb_train
82 |     path = 'code/data/DampedPendulum'
83 |     if not os.path.exists(path):
84 |         os.makedirs(path)
85 |     with open(r"%s/train.pkl" % path, "wb") as output_file:
86 |         pickle.dump(gen_data(nb_data_train), output_file)
87 | 
88 |     with open(r"%s/valid.pkl" % path, "wb") as output_file:
89 |         pickle.dump(gen_data(100), output_file)
90 | 
91 |     with open(r"%s/test.pkl" % path, "wb") as output_file:
92 |         pickle.dump(gen_data(100), output_file)
93 | 
94 |     with open(r"%s/valid_shifted.pkl" % path, "wb") as output_file:
95 |         pickle.dump(gen_data(100, shifted="None"), output_file)
96 | 
97 |     with open(r"%s/test_shifted.pkl" % path, "wb") as output_file:
98 |         pickle.dump(gen_data(100, shifted="None"), output_file)
99 | 
100 |     for hardness in ["small", "medium", "large"]:
101 |         for shift in ["all", "alpha"]:
102 |             shift_def = hardness + "_" + shift
103 |             with open(r"%s/%s_test.pkl" % (path, shift_def), "wb") as output_file:
104 |                 pickle.dump(gen_data(100, shifted=shift_def), output_file)
--------------------------------------------------------------------------------
/code/simulators/RLCCircuit.py:
--------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
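
# Series RLC circuit driven by a sinusoidal source V(t) = V_c + V_a * sin(omega * t);
# see RLCODE.forward for the exact state convention of x = (x_0, x_1).
#
# Minimal usage sketch (an illustrative assumption, not part of the original
# pipeline):
#
#   s = RLCCircuit()
#   t, x = s.sample_sequences(n=4)  # x: (n_timesteps + 1, 4, 2) noiseless trajectories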
3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchdyn.core import NeuralODE 7 | from torchdyn import * 8 | from code.simulators.GenericSimulator import PhysicalModel 9 | import math 10 | 11 | 12 | class RLCODE(PhysicalModel): 13 | def __init__(self, param_values, trainable_param): 14 | super(RLCODE, self).__init__(param_values, trainable_param) 15 | self._X_dim = 2 16 | self.R = nn.Parameter(torch.tensor(param_values["R"])) if "R" in trainable_param else param_values["R"] 17 | self.L = nn.Parameter(torch.tensor(param_values["L"])) if "L" in trainable_param else param_values["L"] 18 | self.C = nn.Parameter(torch.tensor(param_values["C"])) if "C" in trainable_param else param_values["C"] 19 | self.V_c = nn.Parameter(torch.tensor(param_values["V_c"])) if "V_c" in trainable_param else param_values["V_c"] 20 | self.V_a = nn.Parameter(torch.tensor(param_values["V_a"])) if "V_a" in trainable_param else param_values["V_a"] 21 | self.omega = nn.Parameter(torch.tensor(param_values["omega"])) if "omega" in trainable_param \ 22 | else param_values["omega"] 23 | 24 | def forward(self, t, x): 25 | return torch.cat((x[:, [1]]/self.C, (self.V(t) - x[:, [0]] - self.R * x[:, [1]]/self.C)/self.L), 1) 26 | 27 | def V(self, t, V_c=None, V_a=None, omega=None): 28 | if V_a is None: 29 | return self.V_c + self.V_a * torch.sin(t*self.omega) 30 | return V_c + V_a * torch.sin(t*omega) 31 | 32 | def parameterized_forward(self, t, x, **parameters): 33 | super(RLCODE, self).parameterized_forward(t, x, **parameters) 34 | C = self.C if "C" not in parameters else parameters["C"] 35 | R = self.R if "R" not in parameters else parameters["R"] 36 | L = self.L if "L" not in parameters else parameters["L"] 37 | V_a = self.V_a if "V_a" not in parameters else parameters["V_a"] 38 | V_c = self.V_c if "V_c" not in parameters else parameters["V_c"] 39 | omega = self.omega if "omega" not in parameters else parameters["omega"] 40 | return torch.cat((x[:, [1]] / C, (self.V(t, V_c, V_a, omega) - x[:, [0]] - R * x[:, [1]] / C) / L), 1) 41 | 42 | def get_x_labels(self): 43 | return ["$U_C$", "$I_C$"] 44 | 45 | def get_name(self): 46 | return "RLC Circuit" + str(self.trainable_param) 47 | 48 | 49 | class RLCCircuit: 50 | def __init__(self, init_param=None, true_param=None, T0=0., T1=5, n_timesteps=40, partial_model_param=None, 51 | name="RLCCircuit", **kwargs): 52 | if partial_model_param is None: 53 | partial_model_param = ["R", "L", "C"] 54 | self.full_param_dim_textual = ["R", "L", "C", "V_a", "V_c", "omega"] 55 | self.incomplete_param_dim_textual = partial_model_param 56 | self.init_param = self.prior_full_parameters() if init_param is None else init_param 57 | self.T0 = float(T0) 58 | self.T1 = float(T1) 59 | self.n_timesteps = int(n_timesteps) 60 | self.name = "RLC" 61 | self.true_param = {"R": 5, "L": 5., "C": 1., "V_a": 2.5, "V_c": 1., "omega": 2.} if true_param is None else true_param 62 | 63 | def prior_incomplete_parameters(self): 64 | return {"R": 20, "L": .1, "C": 0.1} 65 | 66 | def prior_full_parameters(self): 67 | return {"R": 20, "L": .1, "C": 0.1, "V_a": 2.5, "V_c": 2., "omega": 2.} 68 | 69 | def sample_init_state(self, n=1): 70 | return torch.cat((torch.randn(n, 1), torch.zeros([n, 1])), 1) 71 | 72 | def sample_sequences(self, parameters=None, n=1, x0=None): 73 | if parameters is None: 74 | parameters = self.true_param 75 | x0 = self.sample_init_state(n) if x0 is None else x0 76 | t_span = torch.linspace(self.T0, self.T1, self.n_timesteps + 1) 77 | f = self.get_full_physical_model(parameters) 78 | model = 
NeuralODE(f, sensitivity='adjoint', solver='dopri5') 79 | with torch.no_grad(): 80 | t_eval, y_hat = model(x0, t_span) 81 | return t_eval, y_hat# + torch.randn_like(y_hat) * .01 82 | 83 | def get_incomplete_physical_model(self, parameters, trainable=True) -> nn.Module: 84 | if trainable: 85 | for p in self.full_param_dim_textual: 86 | if p not in self.incomplete_param_dim_textual and p not in parameters.keys(): 87 | parameters[p] = 0. 88 | return RLCODE(param_values=parameters, trainable_param=self.incomplete_param_dim_textual) 89 | else: 90 | return RLCODE(param_values=parameters, trainable_param=[]) 91 | 92 | def get_full_physical_model(self, parameters, trainable=False) -> nn.Module: 93 | if trainable: 94 | return RLCODE(param_values=parameters, trainable_param=self.full_param_dim_textual) 95 | else: 96 | return RLCODE(param_values=parameters, trainable_param=[]) 97 | 98 | -------------------------------------------------------------------------------- /code/simulators/ReactionDiffusion.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchdyn.core import NeuralODE 7 | from torchdyn import * 8 | from code.simulators.GenericSimulator import PhysicalModel 9 | import math 10 | import torch.nn.functional as F 11 | 12 | # This code is strongly inspired by https://github.com/yuan-yin/APHYNITY 13 | 14 | 15 | class ReactionDiffusionPDE(PhysicalModel): 16 | def __init__(self, param_values, trainable_param): 17 | super(ReactionDiffusionPDE, self).__init__(param_values, trainable_param) 18 | self._X_dim = (2, 32, 32) 19 | self.a = nn.Parameter(torch.tensor(param_values["a"])) if "a" in trainable_param else param_values["a"] 20 | self.b = nn.Parameter(torch.tensor(param_values["b"])) if "b" in trainable_param else param_values["b"] 21 | self.k = nn.Parameter(torch.tensor(param_values["k"])) if "k" in trainable_param else param_values["k"] 22 | 23 | self._dx = 2./32. 
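        # Standard 5-point finite-difference Laplacian stencil, scaled by
        # 1/dx^2; forward() applies it with circular padding, i.e. periodic
        # boundary conditions on the 32x32 grid.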
24 | self.register_buffer('_laplacian', torch.tensor( 25 | [ 26 | [0, 1, 0], 27 | [1, -4, 1], 28 | [0, 1, 0], 29 | ], 30 | ).float().view(1, 1, 3, 3) / (self._dx * self._dx)) 31 | ''' 32 | self.params_org = nn.ParameterDict({ 33 | 'a_org': nn.Parameter(torch.tensor(1e-3)), 34 | 'b_org': nn.Parameter(torch.tensor(5e-3)), 35 | 'k_org': nn.Parameter(torch.tensor(5e-3)), 36 | }) 37 | ''' 38 | 39 | def forward(self, t, x): 40 | return self.parameterized_forward(t, x) 41 | 42 | def parameterized_forward(self, t, x, **parameters): 43 | super(ReactionDiffusionPDE, self).parameterized_forward(t, x, **parameters) 44 | a = self.a if "a" not in parameters else parameters["a"].unsqueeze(2).unsqueeze(3) 45 | b = self.b if "b" not in parameters else parameters["b"].unsqueeze(2).unsqueeze(3) 46 | k = self.k if "k" not in parameters else parameters["k"].unsqueeze(2).unsqueeze(3) 47 | 48 | U = x[:, [0]] 49 | V = x[:, [1]] 50 | 51 | # if self.real_params is None: 52 | # self.params['a'] = torch.sigmoid(self.params_org['a_org']) * 1e-2 53 | # self.params['b'] = torch.sigmoid(self.params_org['b_org']) * 1e-2 54 | 55 | U_ = F.pad(U, pad=(1, 1, 1, 1), mode='circular') 56 | Delta_u = F.conv2d(U_, self._laplacian) 57 | 58 | V_ = F.pad(V, pad=(1, 1, 1, 1), mode='circular') 59 | Delta_v = F.conv2d(V_, self._laplacian) 60 | 61 | dUdt = a * Delta_u + U - U.pow(3) - V - k 62 | dVdt = b * Delta_v + U - V 63 | 64 | return torch.cat([dUdt, dVdt], dim=1) # .reshape(-1, 2*32*32) 65 | 66 | def get_x_labels(self): 67 | return ["$U$", "$V$"] 68 | 69 | def get_name(self): 70 | return "Reaction Diffusion" + str(self.trainable_param) 71 | 72 | 73 | class ReactionDiffusion: 74 | def __init__(self, init_param=None, true_param=None, T0=0., T1=5, n_timesteps=40, partial_model_param=None, 75 | name="RLCCircuit", **kwargs): 76 | if partial_model_param is None: 77 | partial_model_param = ["a", "b"] 78 | self.full_param_dim_textual = ["a", "b", "k"] 79 | self.incomplete_param_dim_textual = partial_model_param 80 | self.init_param = self.prior_full_parameters() if init_param is None else init_param 81 | self.T0 = float(T0) 82 | self.T1 = float(T1) 83 | self.n_timesteps = int(n_timesteps) 84 | self.name = "ReactionDiffusion" 85 | self._X_dim = (2, 32, 32) 86 | self.true_param = {"a": 1e-3, "b": 5e-3, "k": 5e-3} if true_param is None else true_param 87 | 88 | def prior_incomplete_parameters(self): 89 | return {"a": .1, "b": .1} 90 | 91 | def prior_full_parameters(self): 92 | return {"a": .1, "b": .1, "k": .1} 93 | 94 | def sample_init_state(self, n=1): 95 | return torch.rand(n, 2, 32, 32) 96 | 97 | def sample_sequences(self, parameters=None, n=1, x0=None): 98 | if parameters is None: 99 | parameters = self.true_param 100 | x0 = self.sample_init_state(n) if x0 is None else x0 101 | t_span = torch.linspace(self.T0, self.T1, self.n_timesteps + 1) 102 | f = self.get_full_physical_model(parameters) 103 | model = NeuralODE(f, sensitivity='adjoint', solver='dopri5') 104 | with torch.no_grad(): 105 | t_eval, y_hat = model(x0, t_span) 106 | return t_eval, y_hat# + torch.randn_like(y_hat) * .01 107 | 108 | def get_incomplete_physical_model(self, parameters, trainable=True) -> nn.Module: 109 | if trainable: 110 | for p in self.full_param_dim_textual: 111 | if p not in self.incomplete_param_dim_textual and p not in parameters.keys(): 112 | parameters[p] = 0. 
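            # Any full-model parameter outside the incomplete set (by default
            # the reaction term k) is fixed to 0 unless explicitly supplied,
            # so the incomplete physical model deliberately omits that term.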
113 | return ReactionDiffusionPDE(param_values=parameters, trainable_param=self.incomplete_param_dim_textual) 114 | else: 115 | return ReactionDiffusionPDE(param_values=parameters, trainable_param=[]) 116 | 117 | def get_full_physical_model(self, parameters, trainable=False) -> nn.Module: 118 | if trainable: 119 | return ReactionDiffusionPDE(param_values=parameters, trainable_param=self.full_param_dim_textual) 120 | else: 121 | return ReactionDiffusionPDE(param_values=parameters, trainable_param=[]) 122 | 123 | -------------------------------------------------------------------------------- /code/utils/utils.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from torch import nn 7 | 8 | from code.hybrid_models import HybridAutoencoder 9 | from code.hybrid_models.APHYNITY import APHYNITYAutoencoderDoublePendulum 10 | from code.hybrid_models import APHYNITYAutoencoder, APHYNITYAutoencoderReactionDiffusion, HybridVAEDoublePendulum, \ 11 | HybridVAEReactionDiffusion, HybridVAE 12 | 13 | 14 | def compute_zp_metrics(true_zp, est_zp, prefix, device, list_param): 15 | sum_cur_rel_errors = 0 16 | metrics = {} 17 | for p in list_param: 18 | if p in true_zp and p in est_zp: 19 | true_zp_c = true_zp[p].to(device) 20 | est_zp_c = est_zp[p].squeeze(1).to(device) 21 | metrics[prefix + "MAE {}".format(p)] = (true_zp_c - est_zp_c).abs().mean().item() 22 | metrics[prefix + "cur_rel_error_" + p] = ((est_zp_c - true_zp_c).abs() / true_zp_c).mean().item() * 100 23 | sum_cur_rel_errors += metrics[prefix + "cur_rel_error_" + p] 24 | 25 | metrics[prefix + "avg_cur_rel_errors_"] = sum_cur_rel_errors / len(list_param) 26 | 27 | return metrics 28 | 29 | 30 | def get_models(solver, fp, device, config) -> tuple[HybridAutoencoder, HybridAutoencoder, torch.optim.Optimizer]: 31 | if solver in ["APHYNITY"]: 32 | trained_model = APHYNITYAutoencoder(fp.to(device), device=device, **config).to(device) 33 | model = APHYNITYAutoencoder(fp.to(device), device=device, **config).to(device) 34 | elif solver in ["APHYNITYReactionDiffusion"]: 35 | trained_model = APHYNITYAutoencoderReactionDiffusion(fp.to(device), device=device, **config).to(device) 36 | model = APHYNITYAutoencoderReactionDiffusion(fp.to(device), device=device, **config).to(device) 37 | elif solver in ["APHYNITYDoublePendulum"]: 38 | trained_model = APHYNITYAutoencoderDoublePendulum(fp.to(device), device=device, **config).to(device) 39 | model = APHYNITYAutoencoderDoublePendulum(fp.to(device), device=device, **config).to(device) 40 | elif config["model"] == "HybridVAE": 41 | trained_model = HybridVAE(fp.to(device), device=device, **config).to(device) 42 | model = HybridVAE(fp.to(device), device=device, **config).to(device) 43 | elif config["model"] == "HybridVAEReactionDiffusion": 44 | trained_model = HybridVAEReactionDiffusion(fp.to(device), device=device, **config).to(device) 45 | model = HybridVAEReactionDiffusion(fp.to(device), device=device, **config).to(device) 46 | elif config["model"] == "HybridVAEDoublePendulum": 47 | trained_model = HybridVAEDoublePendulum(fp.to(device), device=device, **config).to(device) 48 | model = HybridVAEDoublePendulum(fp.to(device), device=device, **config).to(device) 49 | else: 50 | raise Exception("The model chosen does not exist.") 51 | 52 | if "path_model" in config and config["augmented"]: 53 | if config["model"] == 
"HybridVAE": 54 | to_train = nn.ModuleList([model.ga, model.gp_1, model.gp_2]) 55 | elif config["model"] == "HybridVAEReactionDiffusion": 56 | to_train = nn.ModuleList([model.enc_za, model.enc_zp, model.gp_1]) 57 | elif config["model"] == "HybridVAEDoublePendulum": 58 | to_train = nn.ModuleList([model]) 59 | else: 60 | to_train = model.enc 61 | 62 | print("Loading models...") 63 | trained_model.load_state_dict(torch.load(config["path_model"], map_location=device)) 64 | model.load_state_dict(torch.load(config["path_model"], map_location=device)) 65 | 66 | if config["model"] in ["HybridVAE"]: 67 | model.sigma_x = torch.nn.Parameter(torch.zeros_like(model.sigma_x)) 68 | elif config["model"] in ["HybridVAEDoublePendulum"]: 69 | model.sigma_x_cos = nn.Parameter(torch.zeros(2, requires_grad=True)).to(device) 70 | model.sigma_x_sin = nn.Parameter(torch.zeros(2, requires_grad=True)).to(device) 71 | 72 | optimizer = torch.optim.Adam(to_train.parameters(), 73 | lr=config["learning_rate_fa"], 74 | weight_decay=config["weight_decay_fa"]) 75 | else: 76 | trained_model = None 77 | optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate_fa"], 78 | weight_decay=config["weight_decay_fa"]) 79 | 80 | return trained_model, model, optimizer 81 | 82 | 83 | def plot_augmented_datasets(all_zp_train_augm, save_path): 84 | plt.figure(figsize=(15, 20)) 85 | plt.subplot(2, 2, 1) 86 | plt.xlim(xmin=-4, xmax=4) 87 | plt.title("$ \\theta_0$") 88 | z_p = torch.cat(all_zp_train_augm, 0).permute(1, 0)[[0]] 89 | z_p = z_p[~z_p.isinf()] 90 | plt.hist(z_p, bins=40, histtype="stepfilled", alpha=.3, density=True) 91 | plt.legend() 92 | 93 | plt.subplot(2, 2, 2) 94 | plt.title("$ \\theta_1$") 95 | plt.xlim(xmin=-4, xmax=4) 96 | z_p = torch.cat(all_zp_train_augm, 0).permute(1, 0)[[1]] 97 | z_p = z_p[~z_p.isinf()] 98 | plt.hist(z_p, bins=40, histtype="stepfilled", alpha=.3, 99 | density=True) 100 | plt.subplot(2, 2, 3) 101 | plt.xlim(xmin=-25, xmax=25) 102 | plt.title("$\\dot \\theta_0$") 103 | z_p = torch.cat(all_zp_train_augm, 0).permute(1, 0)[[2]] 104 | z_p = z_p[~z_p.isinf()] 105 | plt.hist(z_p, bins=40, histtype="stepfilled", alpha=.3, 106 | density=True) 107 | plt.legend() 108 | 109 | plt.subplot(2, 2, 4) 110 | plt.title("$\\dot \\theta_1$") 111 | plt.xlim(xmin=-50, xmax=70) 112 | z_p = torch.cat(all_zp_train_augm, 0).permute(1, 0)[[3]] 113 | z_p = z_p[~z_p.isinf()] 114 | plt.hist(z_p, bins=40, histtype="stepfilled", alpha=.3, 115 | density=True) 116 | plt.savefig("%s/init_state_dist.png" % save_path) 117 | 118 | -------------------------------------------------------------------------------- /code/simulators/DampedPendulum.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchdyn.core import NeuralODE 7 | from torchdyn import * 8 | from code.simulators.GenericSimulator import PhysicalModel 9 | import math 10 | 11 | 12 | class DampedPendulumODE(PhysicalModel): 13 | def __init__(self, param_values, trainable_param): 14 | super(DampedPendulumODE, self).__init__(param_values, trainable_param) 15 | self._X_dim = 2 16 | if "omega_0" in trainable_param: 17 | self.omega_0 = nn.Parameter(torch.tensor(param_values["omega_0"], requires_grad=True)) 18 | else: 19 | self.register_buffer("omega_0", torch.tensor(param_values["omega_0"], requires_grad=False)) 20 | 21 | if "alpha" in trainable_param: 22 | self.alpha = nn.Parameter(torch.tensor(param_values["alpha"], requires_grad=True)) 23 | else: 24 | self.register_buffer("alpha", torch.tensor(param_values["alpha"], requires_grad=False)) 25 | 26 | if "A" in trainable_param: 27 | self.A = nn.Parameter(torch.tensor(param_values["A"], requires_grad=True)) 28 | else: 29 | self.register_buffer("A", torch.tensor(param_values["A"], requires_grad=False)) 30 | 31 | if "phi" in trainable_param: 32 | self.phi = nn.Parameter(torch.tensor(param_values["phi"], requires_grad=True)) 33 | else: 34 | self.register_buffer("phi", torch.tensor(param_values["phi"], requires_grad=False)) 35 | 36 | def forward(self, t, x): 37 | return torch.cat((x[:, [1]], -self.omega_0 ** 2 * torch.sin(x[:, [0]]) - self.alpha * x[:, [1]]), 1) 38 | 39 | def u(self, t, A=None, phi=None, omega_0=None): 40 | return A*(omega_0**2)*torch.cos(2*math.pi*phi) 41 | 42 | def parameterized_forward(self, t, x, **parameters): 43 | super(DampedPendulumODE, self).parameterized_forward(None, x, **parameters) 44 | omega_0 = self.omega_0 if "omega_0" not in parameters else parameters["omega_0"] 45 | alpha = self.alpha if "alpha" not in parameters else parameters["alpha"] 46 | #A = self.A if "A" not in parameters else parameters["A"] 47 | phi = self.phi if "phi" not in parameters else parameters["phi"] 48 | return torch.cat((x[:, [1]], -omega_0 ** 2 * torch.sin(x[:, [0]]) - alpha * x[:, [1]]), 1) 49 | #+ self.u(t, A, phi, omega_0) 50 | 51 | def get_x_labels(self): 52 | return ["$\\theta$", "$\\dot \\theta$"] 53 | 54 | def get_name(self): 55 | return "Damped Pendulum" 56 | 57 | def to(self, device): 58 | super(DampedPendulumODE, self).to(device) 59 | self.omega_0 = self.omega_0.to(device) 60 | self.alpha = self.alpha.to(device) 61 | self.A = self.A.to(device) 62 | self.phi = self.phi.to(device) 63 | return self 64 | 65 | 66 | class DampedPendulum: 67 | def __init__(self, init_param=None, true_param=None, T0=0., T1=20, n_timesteps=40, partial_model_param=None, name="DampedPendulum", 68 | **kwargs): 69 | if partial_model_param is None: 70 | partial_model_param = ["omega_0"] 71 | self.full_param_dim_textual = ["omega_0", "alpha", "A", "phi"] 72 | self.incomplete_param_dim_textual = partial_model_param 73 | self.init_param = self.prior_full_parameters() if init_param is None else init_param 74 | self.T0 = float(T0) 75 | self.T1 = float(T1) 76 | self.n_timesteps = int(n_timesteps) 77 | self.name = name 78 | self.true_param = {"omega_0": 2 * math.pi / 6, "alpha": 0.1, "A": 0., "phi": 1} if true_param is None else true_param 79 | 80 | def prior_incomplete_parameters(self): 81 | T0 = torch.rand(1) * 7 + 3 82 | return {"omega_0": 2 * math.pi / T0} 83 | 84 | def prior_full_parameters(self): 85 | alpha = torch.rand(1) * .5 86 | omega_0 = self.prior_incomplete_parameters()["omega_0"] 87 | A = torch.rand(1) * 40. 
88 |         phi = torch.rand(1) + 1.
89 |         return {"omega_0": omega_0, "alpha": alpha, "A": A, "phi": phi}
90 | 
91 |     def init_state_APHYNITY_CODE(self, n):
92 |         y0 = torch.rand([n, 2]) * 2.0 - 1
93 |         radius = (torch.rand([n, 1]) + 1.3).expand(-1, 2)
94 |         y0 = y0 / torch.sqrt((y0 ** 2).sum(1).unsqueeze(1).expand(-1, 2)) * radius
95 |         return y0
96 | 
97 |     def sample_init_state(self, n=1):
98 |         #return self.init_state_APHYNITY_CODE(n)
99 |         min_theta, max_theta = -math.pi / 2, math.pi / 2
100 |         return torch.cat((torch.rand([n, 1]) * (max_theta - min_theta) + min_theta, torch.zeros([n, 1])), 1)
101 | 
102 |     def sample_sequences(self, parameters=None, n=1, x0=None):
103 |         if parameters is None:
104 |             parameters = self.true_param
105 |         x0 = self.sample_init_state(n) if x0 is None else x0
106 |         t_span = torch.linspace(self.T0, self.T1, self.n_timesteps + 1)
107 |         f = self.get_full_physical_model(parameters)
108 |         model = NeuralODE(f, sensitivity='adjoint', solver='dopri5', rtol=1e-7, atol=1e-7)
109 |         with torch.no_grad():
110 |             t_eval, y_hat = model(x0, t_span)
111 |         return t_eval, y_hat + torch.randn_like(y_hat) * .01
112 | 
113 |     def get_incomplete_physical_model(self, parameters, trainable=True) -> nn.Module:
114 |         if trainable:
115 |             for p in self.full_param_dim_textual:
116 |                 if p not in self.incomplete_param_dim_textual:
117 |                     parameters[p] = 0.
118 |             return DampedPendulumODE(param_values=parameters, trainable_param=self.incomplete_param_dim_textual)
119 |         else:
120 |             return DampedPendulumODE(param_values=parameters, trainable_param=[])
121 | 
122 |     def get_full_physical_model(self, parameters, trainable=False) -> nn.Module:
123 |         if trainable:
124 |             return DampedPendulumODE(param_values=parameters, trainable_param=self.full_param_dim_textual)
125 |         else:
126 |             return DampedPendulumODE(param_values=parameters, trainable_param=[])
127 | 
128 | 
--------------------------------------------------------------------------------
/code/data/RLC/GenerateDataset.py:
--------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
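
# Generates train/valid/shifted-test pickles for the RLC system. Each branch of
# gen_data defines Uniform sampling ranges per parameter, but note that only
# R, L and C are actually sampled in the generation loop below; V_a, V_c and
# omega are held fixed at 2.5, 1. and 2. for every sequence.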
3 | 4 | from code.simulators import RLCCircuit 5 | import torch 6 | from tqdm import tqdm 7 | import os 8 | import pickle 9 | 10 | 11 | def gen_data(n=500, shifted="False"): 12 | if shifted == "small_all": 13 | distributions = { 14 | "R": lambda x: torch.distributions.Uniform(3., 4.).sample_n(x), 15 | "L": lambda x: torch.distributions.Uniform(3., 5.).sample_n(x), 16 | "C": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x), 17 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 18 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 19 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 20 | } 21 | elif shifted == "small_R": 22 | distributions = { 23 | "R": lambda x: torch.distributions.Uniform(3., 4.).sample_n(x), 24 | "L": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 25 | "C": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x), 26 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 27 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 28 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 29 | } 30 | elif shifted == "medium_all": 31 | distributions = { 32 | "R": lambda x: torch.distributions.Uniform(4., 8.).sample_n(x), 33 | "L": lambda x: torch.distributions.Uniform(3., 5.).sample_n(x), 34 | "C": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x), 35 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 36 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 37 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 38 | } 39 | elif shifted == "medium_R": 40 | distributions = { 41 | "R": lambda x: torch.distributions.Uniform(4., 8.).sample_n(x), 42 | "L": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 43 | "C": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x), 44 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 45 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 46 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 47 | } 48 | elif shifted == "large_all": 49 | distributions = { 50 | "R": lambda x: torch.distributions.Uniform(10., 20.).sample_n(x), 51 | "L": lambda x: torch.distributions.Uniform(3., 5.).sample_n(x), 52 | "C": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x), 53 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 54 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 55 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 56 | } 57 | elif shifted == "large_R": 58 | distributions = { 59 | "R": lambda x: torch.distributions.Uniform(10., 20.).sample_n(x), 60 | "L": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 61 | "C": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x), 62 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 63 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 64 | "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x) 65 | } 66 | elif shifted == "None": 67 | distributions = { 68 | "R": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x), 69 | "L": lambda x: torch.distributions.Uniform(3., 5.).sample_n(x), 70 | "C": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x), 71 | "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x), 72 | "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x), 73 | "omega": lambda x: 
torch.distributions.Uniform(.5, 1.5).sample_n(x)  # assumed bounds (.5, 1.5): Uniform requires low < high
74 |         }
75 |     else:
76 |         distributions = {
77 |             "R": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x),
78 |             "L": lambda x: torch.distributions.Uniform(1., 3.).sample_n(x),
79 |             "C": lambda x: torch.distributions.Uniform(.5, 1.5).sample_n(x),
80 |             "V_a": lambda x: torch.distributions.Uniform(1.5, 3.5).sample_n(x),
81 |             "V_c": lambda x: torch.distributions.Uniform(.5, 2.5).sample_n(x),
82 |             "omega": lambda x: torch.distributions.Uniform(1., 2.5).sample_n(x)
83 |         }
84 |     s = RLCCircuit()
85 |     s.n_timesteps = 200
86 |     s.T0 = 0.
87 |     s.T1 = 20.
88 |     dataset = torch.zeros(n, s.n_timesteps + 1, 1, 2)
89 |     true_param = {"R": torch.zeros(n),
90 |                   "L": torch.zeros(n),
91 |                   "C": torch.zeros(n),
92 |                   "V_a": torch.zeros(n),
93 |                   "V_c": torch.zeros(n),
94 |                   "omega": torch.zeros(n)}
95 |     for i in tqdm(range(n)):
96 |         true_param['R'][i] = distributions["R"](1)[0]
97 |         true_param['L'][i] = distributions["L"](1)[0]
98 |         true_param['C'][i] = distributions["C"](1)[0]
99 |         true_param['V_a'][i] = 2.5
100 |         true_param['V_c'][i] = 1.
101 |         true_param['omega'][i] = 2.
102 |         t, x = s.sample_sequences({x: t[i] for x, t in true_param.items()})
103 |         dataset[i, :, :, :] = x
104 |     dataset = dataset.permute(0, 1, 3, 2).unsqueeze(3)
105 |     return t, dataset, true_param
106 | 
107 | 
108 | path = 'code/data/RLC'
109 | if not os.path.exists(path):
110 |     os.makedirs(path)
111 | 
112 | with open(r"%s/train.pkl" % path, "wb") as output_file:
113 |     pickle.dump(gen_data(3000), output_file)
114 | 
115 | with open(r"%s/valid.pkl" % path, "wb") as output_file:
116 |     pickle.dump(gen_data(1000), output_file)
117 | 
118 | with open(r"%s/test_shifted.pkl" % path, "wb") as output_file:
119 |     pickle.dump(gen_data(1000, "None"), output_file)
120 | 
121 | for hardness in ["small", "medium", "large"]:
122 |     for shift in ["all", "R"]:
123 |         shift_def = hardness + "_" + shift
124 |         with open(r"%s/%s_test.pkl" % (path, shift_def), "wb") as output_file:
125 |             pickle.dump(gen_data(1000, shifted=shift_def), output_file)
--------------------------------------------------------------------------------
/code/data/DoublePendulum/GenerateDataset.py:
--------------------------------------------------------------------------------
1 | # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
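
# Builds double-pendulum datasets from the tracked-marker CSV files: marker
# pixel positions are converted to the two joint angles, and the initial
# angular velocities are estimated by a central difference over two frames,
# picking the 2*pi-shifted difference of smallest magnitude to handle angle
# wrap-around.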
3 | 4 | import math 5 | 6 | import pandas as pd 7 | 8 | import code.simulators.DoublePendulum as DP 9 | import torch 10 | from torchdiffeq import odeint 11 | from tqdm import tqdm 12 | import os 13 | import pickle 14 | import argparse 15 | 16 | from code.utils.double_pendulum import from_raw_pixels_to_angle 17 | 18 | 19 | def get_datatensors(files=['0'], sub_sampling_rate=1, seq_len=100, seq_gap=1): 20 | train, val, test = [], [], [] 21 | for f in files: 22 | marker_positions = pd.read_csv(f'code/data/DoublePendulum/dpc_dataset_csv/{f}.csv', 23 | header=None).values 24 | angles = from_raw_pixels_to_angle(marker_positions) 25 | angles_dataset = torch.cat((torch.tensor(angles[0]).unsqueeze(1), torch.tensor(angles[1]).unsqueeze(1)), 1) 26 | angles_dataset = angles_dataset[::sub_sampling_rate] 27 | length_tot = angles_dataset.shape[0] // 2 28 | 29 | samples_ids = torch.arange(seq_len).unsqueeze(0) + torch.arange(0, length_tot - seq_len, seq_gap).unsqueeze( 30 | 1) + angles_dataset.shape[0] // 2 - 1 31 | 32 | train_s = int(samples_ids.shape[0] * 0.4) 33 | val_s = int(samples_ids.shape[0] * .3) 34 | train.append(angles_dataset[samples_ids[-train_s:]]) 35 | val.append(angles_dataset[samples_ids[-train_s - val_s:-train_s]]) 36 | test.append(angles_dataset[samples_ids[:-train_s - val_s]]) 37 | 38 | train = torch.cat(train, 0).float() 39 | train = train[torch.randperm(train.shape[0])] 40 | 41 | val = torch.cat(val, 0).float() 42 | val = val[torch.randperm(val.shape[0])] 43 | 44 | test = torch.cat(test, 0).float() 45 | test = test[torch.randperm(test.shape[0])] 46 | 47 | frequency = int(400 / sub_sampling_rate) 48 | time = torch.arange(0, seq_len / frequency, 1 / frequency) 49 | 50 | return train, val, test, time 51 | 52 | 53 | if __name__ == "__main__": 54 | 55 | # Create the parser 56 | parser = argparse.ArgumentParser() 57 | # Add an argument 58 | parser.add_argument('--nb_train', type=int, default=1000) 59 | # Parse the argument 60 | args = parser.parse_args() 61 | 62 | nb_data_train = args.nb_train 63 | path = 'code/data/DoublePendulum' 64 | if not os.path.exists(path): 65 | os.makedirs(path) 66 | 67 | train, valid, test, time = get_datatensors(files=['%d' % i for i in range(21)], sub_sampling_rate=4, seq_len=21, 68 | seq_gap=1) 69 | 70 | with open(r"%s/train.pkl" % path, "wb") as output_file: 71 | dx, dy = train[:, 0, :], train[:, 2, :] 72 | diff = (dx - dy).unsqueeze(2) 73 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 74 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 75 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) / (2 * time[1]) 76 | init_states = torch.cat((train[:, 1, :], -omegas), 1) 77 | init_states = {"\\theta_0": init_states[:, 0], "\\theta_1": init_states[:, 1], 78 | "\\dot \\theta_0": init_states[:, 2], "\\dot \\theta_1": init_states[:, 3]} 79 | 80 | pickle.dump([time[1:] - time[1], train[:, 1:].unsqueeze(3).unsqueeze(3), init_states], output_file) 81 | 82 | with open(r"%s/valid_shifted.pkl" % path, "wb") as output_file: 83 | dx, dy = valid[:, 0, :], valid[:, 2, :] 84 | diff = (dx - dy).unsqueeze(2) 85 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 86 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 87 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) / (2 * time[1]) 88 | init_states = torch.cat((valid[:, 1, :], -omegas), 1) 89 | init_states = {"\\theta_0": init_states[:, 0], "\\theta_1": 
init_states[:, 1], 90 | "\\dot \\theta_0": init_states[:, 2], "\\dot \\theta_1": init_states[:, 3]} 91 | 92 | pickle.dump([time[1:] - time[1], valid[:, 1:].unsqueeze(3).unsqueeze(3), init_states], output_file) 93 | 94 | with open(r"%s/valid.pkl" % path, "wb") as output_file: 95 | dx, dy = valid[:, 0, :], valid[:, 2, :] 96 | diff = (dx - dy).unsqueeze(2) 97 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 98 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 99 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) / (2 * time[1]) 100 | init_states = torch.cat((valid[:, 1, :], -omegas), 1) 101 | init_states = {"\\theta_0": init_states[:, 0], "\\theta_1": init_states[:, 1], 102 | "\\dot \\theta_0": init_states[:, 2], "\\dot \\theta_1": init_states[:, 3]} 103 | pickle.dump([time[1:] - time[1], valid[:, 1:].unsqueeze(3).unsqueeze(3), init_states], output_file) 104 | 105 | with open(r"%s/test.pkl" % path, "wb") as output_file: 106 | dx, dy = test[:, 0, :], test[:, 2, :] 107 | diff = (dx - dy).unsqueeze(2) 108 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 109 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 110 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) / (2 * time[1]) 111 | init_states = torch.cat((test[:, 1, :], -omegas), 1) 112 | init_states = {"\\theta_0": init_states[:, 0], "\\theta_1": init_states[:, 1], 113 | "\\dot \\theta_0": init_states[:, 2], "\\dot \\theta_1": init_states[:, 3]} 114 | 115 | pickle.dump([time[1:] - time[1], test[:, 1:].unsqueeze(3).unsqueeze(3), init_states], output_file) 116 | 117 | with open(r"%s/test_shifted.pkl" % path, "wb") as output_file: 118 | dx, dy = test[:, 0, :], test[:, 2, :] 119 | diff = (dx - dy).unsqueeze(2) 120 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 121 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 122 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) / (2 * time[1]) 123 | init_states = torch.cat((test[:, 1, :], -omegas), 1) 124 | init_states = {"\\theta_0": init_states[:, 0], "\\theta_1": init_states[:, 1], 125 | "\\dot \\theta_0": init_states[:, 2], "\\dot \\theta_1": init_states[:, 3]} 126 | 127 | pickle.dump([time[1:] - time[1], test[:, 1:].unsqueeze(3).unsqueeze(3), init_states], output_file) 128 | -------------------------------------------------------------------------------- /code/utils/plotter.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import matplotlib.pyplot as plt 5 | from textwrap import wrap 6 | 7 | import torch 8 | 9 | 10 | def plot_curves(t, x_pred, x_obs, title, labels, save_name): 11 | fig = plt.figure(figsize=(24, 12)) 12 | fig.suptitle("\n".join(wrap(title, 150)), fontsize=18) 13 | for i in range(9): 14 | plt.subplot(3, 3, i + 1) 15 | plt.plot(t, x_pred[i], '--', linewidth=2.) 16 | for j in range(x_obs.shape[2]): 17 | plt.scatter(t, x_obs[i, :, j], marker='x', linewidth=1.) 
18 | if i == 0: 19 | plt.legend(["Predicted %s" % labels[0], "Predicted %s" % labels[1], 20 | "Realized %s" % labels[0], "Realized %s" % labels[1]]) 21 | plt.xlabel("Time - t") 22 | plt.ylabel("State - X(t)") 23 | plt.savefig(save_name + ".pdf") 24 | print("Figure saved in %s" % save_name) 25 | plt.close(fig) 26 | 27 | 28 | def plot_curves_partial(t, x_pred, x_obs, title, labels, save_name, est_param=None, true_param=None): 29 | fig = plt.figure(figsize=(24, 12)) 30 | fig.suptitle("\n".join(wrap(title, 150)), fontsize=18) 31 | for i in range(9): 32 | plt.subplot(3, 3, i + 1) 33 | plt.plot(t[-x_pred.shape[1]:], x_pred[i], '--', linewidth=2.) 34 | for j in range(x_obs.shape[2]): 35 | plt.scatter(t, x_obs[i, :, j], marker='x', linewidth=1.) 36 | if i == 0: 37 | plt.legend(["Predicted %s" % labels[0], "Predicted %s" % labels[1], 38 | "Realized %s" % labels[0], "Realized %s" % labels[1]]) 39 | plt.xlabel("Time - t") 40 | plt.ylabel("State - X(t)") 41 | if est_param is not None: 42 | message = "".join(["{}: True: {:.4f} Est: {:.4f} -- ".format(p, 43 | true_param[p][i], 44 | est_param[p][i, 0]) for p in est_param.keys()]) 45 | else: 46 | message = "Sin" 47 | plt.title("\n".join(wrap(message[:-3], 60)), fontsize=13) 48 | plt.tight_layout() 49 | plt.savefig(save_name + ".pdf") 50 | print("Figure saved in %s" % save_name) 51 | plt.close(fig) 52 | 53 | 54 | def plot_curves_complete(t, x_pred, x_obs, title, labels, save_name, nb_observed=-1, est_param=None, true_param=None): 55 | fig = plt.figure(figsize=(24, 12)) 56 | fig.suptitle("\n".join(wrap(title, 150)), fontsize=18) 57 | for i in range(9): 58 | plt.subplot(3, 3, i + 1) 59 | plt.plot(t[-x_pred.shape[1]:], x_pred[i], '--', linewidth=2.) 60 | if nb_observed > 0: 61 | plt.vlines(t[nb_observed], x_obs.min(), x_obs.max(), label='Nb Observed') 62 | for j in range(x_obs.shape[2]): 63 | plt.scatter(t, x_obs[i, :, j], marker='x', linewidth=1.) 64 | if i == 0: 65 | plt.legend(["Predicted %s" % labels[0], "Predicted %s" % labels[1], 66 | "Realized %s" % labels[0], "Realized %s" % labels[1]]) 67 | plt.xlabel("Time - t") 68 | plt.ylabel("State - X(t)") 69 | if est_param is not None: 70 | message = "".join(["${}$: True: {:.4f} Est: {:.4f} \n".format(p, 71 | true_param[p][i], 72 | est_param[p][i, 0]) for p in est_param.keys()]) 73 | else: 74 | message = "Sin" 75 | plt.title(message[:-3], fontsize=13) 76 | #plt.title("\n".join(wrap(message[:-3], 60)), fontsize=13) 77 | plt.tight_layout() 78 | plt.savefig(save_name + ".pdf") 79 | print("Figure saved in %s" % save_name) 80 | plt.close(fig) 81 | 82 | 83 | def plot_curves_double_pendulum(t, x_pred, x_obs, title, save_name, nb_observed_theta_0, nb_observed_theta_1, 84 | nb_observed, est_param=None, true_param=None): 85 | fig = plt.figure(figsize=(24, 12)) 86 | fig.suptitle("\n".join(wrap(title, 150)), fontsize=18) 87 | for i in range(9): 88 | plt.subplot(3, 3, i + 1) 89 | plt.plot(t, x_pred[i], '--', linewidth=2.) 
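        # Marker convention, inferred from the legend below: filled circles
        # for the observed window, crosses for held-out ground truth.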
90 | plt.scatter(t[nb_observed-nb_observed_theta_0:nb_observed], 91 | x_obs[i, nb_observed-nb_observed_theta_0:nb_observed, 0], marker='o', linewidth=1., c='blue') 92 | plt.scatter(t[nb_observed-nb_observed_theta_1:nb_observed], 93 | x_obs[i, nb_observed-nb_observed_theta_1:nb_observed, 1], marker='o', linewidth=1, c='orange') 94 | 95 | if nb_observed - nb_observed_theta_0 > 0: 96 | time_theta0 = torch.cat((t[:nb_observed-nb_observed_theta_0], t[nb_observed:]), 0) 97 | real_theta0 = torch.cat((x_obs[i, :nb_observed-nb_observed_theta_0, 0], x_obs[i, nb_observed:, 0]), 0) 98 | else: 99 | time_theta0 = t[nb_observed:] 100 | real_theta0 = x_obs[i, nb_observed:, 0] 101 | plt.scatter(time_theta0, real_theta0, marker='x', linewidth=1., c='blue') 102 | 103 | if nb_observed - nb_observed_theta_1 > 0: 104 | time_theta1 = torch.cat((t[:nb_observed-nb_observed_theta_1], t[nb_observed:]), 0) 105 | real_theta1 = torch.cat((x_obs[i, :nb_observed-nb_observed_theta_1, 1], x_obs[i, nb_observed:, 1]), 0) 106 | else: 107 | time_theta1 = t[nb_observed:] 108 | real_theta1 = x_obs[i, nb_observed:, 1] 109 | plt.scatter(time_theta1, real_theta1, marker='x', linewidth=1., c='orange') 110 | 111 | if i == 0: 112 | plt.legend(["Predicted $\\theta\_0$", "Predicted $\\theta\_1$", 113 | "Observed $\\theta\_0$", "Observed $\\theta\_1$", 114 | "Realized $\\theta\_0$", "Realized $\\theta\_1$"]) 115 | plt.xlabel("Time - t") 116 | plt.ylabel("State - X(t)") 117 | if est_param is not None: 118 | message = "".join(["${}$: True: {:.4f} Est: {:.4f} \n".format(p, 119 | true_param[p][i], 120 | est_param[p][i, 0]) for p in est_param.keys()]) 121 | else: 122 | message = "Sin" 123 | plt.title("\n".join(wrap(message[:-3], 90)), fontsize=13) 124 | #plt.title("\n".join(wrap("Sin", 60)), fontsize=13) 125 | plt.tight_layout() 126 | plt.savefig(save_name + ".pdf") 127 | print("Figure saved in %s" % save_name) 128 | plt.close(fig) 129 | 130 | 131 | def plot_diffusion(x_pred, x_obs, title, labels, save_name, est_param, true_param): 132 | fig = plt.figure(figsize=(24, 30)) 133 | fig.suptitle("\n".join(wrap(title, 250)), fontsize=24) 134 | nb_sec = 5 135 | for i in range(nb_sec): 136 | plt.subplot(6, nb_sec, i + 1) 137 | plt.imshow(x_pred[0, i * 10, 0], cmap='BrBG') 138 | plt.subplot(6, nb_sec, nb_sec + i + 1) 139 | plt.imshow(x_obs[0, i * 10, 0], cmap='BrBG') 140 | plt.subplot(6, nb_sec, 2*nb_sec + i + 1) 141 | plt.imshow(x_obs[0, i * 10, 0] - x_pred[0, i * 10, 0], cmap='BrBG') 142 | 143 | plt.subplot(6, nb_sec, nb_sec*3 + i + 1) 144 | plt.imshow(x_pred[0, i * 10, 1], cmap='PiYG') 145 | plt.subplot(6, nb_sec, nb_sec*4 + i + 1) 146 | plt.imshow(x_obs[0, i * 10, 1], cmap='PiYG') 147 | plt.subplot(6, nb_sec, nb_sec*5 + i + 1) 148 | plt.imshow(x_obs[0, i * 10, 1] - x_pred[0, i * 10, 1], cmap='PiYG') 149 | plt.tight_layout() 150 | 151 | plt.tight_layout() 152 | plt.savefig(save_name + ".pdf") 153 | print("Figure saved in %s" % save_name) 154 | plt.close(fig) -------------------------------------------------------------------------------- /code/nn/unet.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
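
# 1-D and 2-D U-Nets with an optional conditioning vector that is injected at
# the bottleneck by channel-wise concatenation before the first up-convolution.
#
# Shape sketch (an illustrative assumption; input sizes must survive the four
# max-pooling steps):
#
#   net = UNet()
#   y = net(torch.randn(8, 1, 64))  # -> (8, 1, 64), since retain_dim
#                                   #    interpolates back to the input length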
3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | from torch.utils.data import Dataset, DataLoader 10 | import torchvision 11 | 12 | 13 | class Block(nn.Module): 14 | def __init__(self, in_ch, out_ch): 15 | super().__init__() 16 | self.conv1 = nn.Conv1d(in_ch, out_ch, 3, padding=1) 17 | self.relu = nn.ReLU() 18 | self.conv2 = nn.Conv1d(out_ch, out_ch, 3, padding=1) 19 | 20 | def forward(self, x): 21 | return self.relu(self.conv2(self.relu(self.conv1(x)))) 22 | 23 | 24 | class Block2d(nn.Module): 25 | def __init__(self, in_ch, out_ch): 26 | super().__init__() 27 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) 28 | self.relu = nn.ReLU() 29 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) 30 | 31 | def forward(self, x): 32 | return self.relu(self.conv2(self.relu(self.conv1(x)))) 33 | 34 | 35 | class Encoder(nn.Module): 36 | def __init__(self, chs=(1, 64, 128, 256, 512, 1024)): 37 | super().__init__() 38 | self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]) 39 | self.pool = nn.MaxPool1d(2) 40 | 41 | def forward(self, x): 42 | ftrs = [] 43 | for block in self.enc_blocks: 44 | # print(x.shape) 45 | x = block(x) 46 | # print(x.shape) 47 | ftrs.append(x) 48 | x = self.pool(x) 49 | return ftrs 50 | 51 | 52 | class Encoder2d(nn.Module): 53 | def __init__(self, chs=(1, 64, 128, 256, 512, 1024)): 54 | super().__init__() 55 | self.enc_blocks = nn.ModuleList([Block2d(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]) 56 | self.pool = nn.MaxPool2d(2) 57 | 58 | def forward(self, x): 59 | ftrs = [] 60 | for block in self.enc_blocks: 61 | # print(x.shape) 62 | x = block(x) 63 | # print(x.shape) 64 | ftrs.append(x) 65 | x = self.pool(x) 66 | return ftrs 67 | 68 | 69 | class Decoder(nn.Module): 70 | def __init__(self, chs=(1024, 512, 256, 128, 64), cond_size=0): 71 | super().__init__() 72 | self.chs = chs 73 | cond_sizes = [0] * len(chs) 74 | cond_sizes[0] = cond_size 75 | self.upconvs = nn.ModuleList([nn.ConvTranspose1d(chs[i] + cond_sizes[i], 76 | chs[i + 1], 2, stride=2) for i in range(len(chs) - 1)]) 77 | self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]) 78 | 79 | def forward(self, x, encoder_features): 80 | for i in range(len(self.chs) - 1): 81 | x = self.upconvs[i](x) 82 | enc_ftrs = self.crop(encoder_features[i], x) 83 | x = torch.cat([x, enc_ftrs], dim=1) 84 | x = self.dec_blocks[i](x) 85 | return x 86 | 87 | def crop(self, enc_ftrs, x): 88 | _, _, H = x.shape 89 | enc_ftrs = torchvision.transforms.CenterCrop([H, 1])(enc_ftrs.unsqueeze(3)).squeeze(3) 90 | return enc_ftrs 91 | 92 | 93 | class Decoder2d(nn.Module): 94 | def __init__(self, chs=(1024, 512, 256, 128, 64), cond_size=0): 95 | super().__init__() 96 | self.chs = chs 97 | cond_sizes = [0] * len(chs) 98 | cond_sizes[0] = cond_size 99 | self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i] + cond_sizes[i], 100 | chs[i + 1], 2, stride=2) for i in range(len(chs) - 1)]) 101 | self.dec_blocks = nn.ModuleList([Block2d(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]) 102 | 103 | def forward(self, x, encoder_features): 104 | for i in range(len(self.chs) - 1): 105 | x = self.upconvs[i](x) 106 | enc_ftrs = self.crop(encoder_features[i], x) 107 | x = torch.cat([x, enc_ftrs], dim=1) 108 | x = self.dec_blocks[i](x) 109 | return x 110 | 111 | def crop(self, enc_ftrs, x): 112 | _, _, H, W = x.shape 113 | enc_ftrs = torchvision.transforms.CenterCrop([H, 
W])(enc_ftrs)#.squeeze(3) 114 | return enc_ftrs 115 | 116 | 117 | class UNet(nn.Module): 118 | def __init__(self, enc_chs=(1, 64, 128, 256, 512, 1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, 119 | retain_dim=True): 120 | super().__init__() 121 | self.encoder = Encoder(enc_chs) 122 | self.decoder = Decoder(dec_chs) 123 | self.head = nn.Conv1d(dec_chs[-1], num_class, 1) 124 | self.final_act = nn.Sigmoid() 125 | self.retain_dim = retain_dim 126 | 127 | def forward(self, x): 128 | enc_ftrs = self.encoder(x) 129 | out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:]) 130 | out = self.head(out) 131 | out = self.final_act(out) 132 | if self.retain_dim: 133 | out = F.interpolate(out, x.shape[-1]) 134 | return out 135 | 136 | 137 | class UNet2d(nn.Module): 138 | def __init__(self, enc_chs=(1, 64, 128, 256, 512, 1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, 139 | retain_dim=True): 140 | super().__init__() 141 | self.encoder = Encoder2d(enc_chs) 142 | self.decoder = Decoder2d(dec_chs) 143 | self.head = nn.Conv2d(dec_chs[-1], num_class, 1) 144 | self.final_act = nn.Sigmoid() 145 | self.retain_dim = retain_dim 146 | 147 | def forward(self, x): 148 | enc_ftrs = self.encoder(x) 149 | out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:]) 150 | out = self.head(out) 151 | out = self.final_act(out) 152 | if self.retain_dim: 153 | out = F.interpolate(out, x.shape[-1]) 154 | return out 155 | 156 | 157 | class ConditionalUNet(nn.Module): 158 | def __init__(self, enc_chs=(1, 64, 128, 256, 512, 1024), dec_chs=(1024, 512, 256, 128, 64), 159 | num_class=1, retain_dim=True, cond_dim=0, final_act=None): 160 | super().__init__() 161 | self.encoder = Encoder(enc_chs) 162 | self.decoder = Decoder(dec_chs, cond_size=cond_dim) 163 | self.head = nn.Conv1d(dec_chs[-1], num_class, 1) 164 | self.final_act = nn.Sigmoid() if final_act is not None else nn.Identity() 165 | self.retain_dim = retain_dim 166 | 167 | def forward(self, x, cond=None): 168 | enc_ftrs = self.encoder(x) 169 | if cond is not None: 170 | cond = cond.unsqueeze(2).expand(-1, -1, enc_ftrs[-1].shape[2]) 171 | enc_ftrs[-1] = torch.cat((cond, enc_ftrs[-1]), 1) 172 | out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:]) 173 | out = self.head(out) 174 | out = self.final_act(out) 175 | if self.retain_dim: 176 | out = F.interpolate(out, x.shape[-1]) 177 | return out 178 | 179 | 180 | class ConditionalUNetReactionDiffusion(nn.Module): 181 | def __init__(self, z_a_dim=1, enc_chs=(2, 16, 32, 64, 128), dec_chs=(128, 64, 32, 16), 182 | num_class=2, retain_dim=True, final_act=None): 183 | super().__init__() 184 | self.encoder = Encoder2d(enc_chs) 185 | self.decoder = Decoder2d(dec_chs, cond_size=z_a_dim) 186 | self.head = nn.Conv2d(dec_chs[-1], num_class, 1) 187 | self.final_act = nn.Sigmoid() if final_act is not None else nn.Identity() 188 | self.retain_dim = retain_dim 189 | 190 | def forward(self, x, cond=None): 191 | enc_ftrs = self.encoder(x) 192 | if cond is not None: 193 | cond = cond.unsqueeze(2).unsqueeze(2).expand(-1, -1, enc_ftrs[-1].shape[2], enc_ftrs[-1].shape[3]) 194 | enc_ftrs[-1] = torch.cat((cond, enc_ftrs[-1]), 1) 195 | out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:]) 196 | out = self.head(out) 197 | out = self.final_act(out) 198 | if self.retain_dim: 199 | out = F.interpolate(out, x.shape[-1]) 200 | return out -------------------------------------------------------------------------------- /code/scripts/run_experiments_robustness.py: -------------------------------------------------------------------------------- 1 
| # For licensing see accompanying LICENSE file.
2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
3 | 
4 | import shutil
5 | 
6 | import torch
7 | from matplotlib import pyplot as plt
8 | from torch import nn
9 | 
10 | from code.hybrid_models.APHYNITY import APHYNITYAutoencoderDoublePendulum
11 | from code.simulators import PhysicalModel
12 | from code.simulators import DampedPendulum, RLCCircuit, ReactionDiffusion, DoublePendulum
13 | from code.hybrid_models import APHYNITYAutoencoder, APHYNITYAutoencoderReactionDiffusion, HybridVAEDoublePendulum, \
14 |     HybridVAEReactionDiffusion, HybridVAE
15 | import yaml
16 | from code.utils import plot_curves_partial, load_data, plot_curves_complete, plot_curves_double_pendulum
17 | import torch.utils.data as data_utils
18 | from tqdm import tqdm
19 | import os
20 | import math
21 | import argparse
22 | from datetime import datetime
23 | from code.utils.utils import *
24 | 
25 | 
26 | 
27 | def run_exp(s: PhysicalModel, verbose=False, config=None, solver="None", config_name="", data_path=None,
28 |             save_path=None):
29 |     if torch.cuda.is_available():
30 |         device = "cuda:0"
31 |     elif torch.backends.mps.is_available():
32 |         device = "cpu"  # "mps" is detected but deliberately not used: conv3d is not yet supported on MPS.
33 |     else:
34 |         device = "cpu"
35 | 
36 |     if config is None:
37 |         raise Exception("No config provided.")
38 | 
39 |     param_text = s.incomplete_param_dim_textual
40 |     # Loading data
41 |     data_path = 'code/data/%s' % s.name if data_path is None else data_path
42 |     print("Loading data in %s" % data_path)
43 |     (t_train, x_train, true_param_train), dl_train, \
44 |     (t_valid, x_valid, true_param_valid), \
45 |     (t_test_shifted, x_test_shifted, true_param_test_shifted) = load_data(data_path, device)
46 | 
47 |     # Option to enforce training with a reduced time frame for the double pendulum.
48 |     if config.get("reduced_time_frame", False):
49 |         print("Training with fewer steps (%d) than validation and test." % config.get("nb_observed", t_train.shape[0]))
50 |         t_train = t_train[:config.get("nb_observed", t_train.shape[0])]
51 | 
52 |         x_train = x_train[:, :config.get("nb_observed", t_train.shape[0])]
53 |         train = data_utils.TensorDataset(x_train)
54 |         dl_train = data_utils.DataLoader(train, batch_size=config.get("b_size", 100), shuffle=True, drop_last=True)
55 | 
56 |     # Creating the model
57 |     starting_param = s.init_param
58 |     fp = s.get_incomplete_physical_model(starting_param, trainable=True)
59 | 
60 |     config["lambda_p"] = config.get("lambda_0", float("nan"))
61 | 
62 |     trained_model, model, optimizer = get_models(solver, fp, device, config)
63 |     print(model)
64 | 
65 |     lambda_0 = config.get("lambda_0", float("nan"))
66 |     tau_2 = config.get("tau_2", float("nan"))
67 |     n_iter = config.get("n_iter", 1)
68 |     all_x_train = x_train
69 | 
70 |     # Expert augmentation if a trained model is available.
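    # The loop below is the core of expert augmentation: the trained hybrid model
    # re-simulates each batch with freshly resampled physical parameters, and the
    # resampled parameters become regression targets for the new encoder. A minimal
    # sketch of the idea (names follow the loop below; illustrative only):
    #
    #     z_p_new = Uniform(min_zp, max_zp).sample()             # resampled expert parameters
    #     x_new = odeint(f_p(.; z_p_new) + f_a(.; z_a), x0, t)   # re-simulated trajectory
    #     loss += (encoder(x_new) - z_p_new) ** 2                # supervise the encoder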
71 |     if trained_model is not None:
72 |         # Generating an augmented training set:
73 |         all_x_train_augm = []
74 |         all_zp_train_augm = []
75 |         print("Generating augmented data...")
76 |         for i in range(config.get("nb_augmentation", 1)):
77 |             for x_train in tqdm(dl_train):
78 |                 x_train = x_train[0].to(device)
79 |                 with torch.no_grad():
80 |                     x_train_augm, z_p_augm = trained_model.augmented_data(t_train, x_train)
81 |                 all_x_train_augm.append(x_train_augm)
82 |                 all_zp_train_augm.append(z_p_augm)
83 |                 if config.get("combined_augmentation", True):
84 |                     all_x_train_augm.append(x_train)
85 |                     all_zp_train_augm.append(1/torch.zeros_like(z_p_augm))  # +inf sentinel: marks real samples, masked out of the parameter loss below.
86 |         train = data_utils.TensorDataset(torch.cat(all_x_train_augm, 0), torch.cat(all_zp_train_augm, 0))
87 |         dl_train = data_utils.DataLoader(train, batch_size=config.get("b_size", 100), shuffle=True, drop_last=True)
88 | 
89 |     # Variables to keep track of the loss along training and save the best model
90 |     best_valid = float("inf")
91 |     best_param = {}
92 |     counter = 0
93 |     valid_increase_nb = 0
94 |     last_epochs_res = torch.zeros(5)  # Only used for updating the lambda value of APHYNITY
95 |     debug = False
96 |     if not debug and config.get("normalize_loss", True):
97 |         with torch.no_grad():
98 |             baseline, _, _ = model.loss(t_train, x_train.to(device))  # note: x_train here is the last minibatch seen above
99 |         baseline = baseline.item()
100 |     else:
101 |         print("Not normalizing the loss.")
102 |         baseline = 1.
103 | 
104 |     for epoch in range(config["n_epochs"]):
105 |         for i in range(n_iter):
106 |             # Update step
107 |             sum_loss = 0.
108 |             sum_traj = 0.
109 |             counter_loss = 0
110 |             for x_train in tqdm(dl_train):
111 | 
112 |                 if trained_model is not None:
113 |                     x_train, z_p = x_train
114 |                     x_train = x_train.to(device)
115 |                     z_p = z_p.to(device)
116 |                 else:
117 |                     x_train = x_train[0].to(device)
118 |                 if debug:
119 |                     counter_loss = 1
120 |                     fa_norm = torch.tensor(0.)
121 | break 122 | loss, fa_norm, l_trajectory = model.loss(t_train, x_train) 123 | loss = loss / baseline 124 | if trained_model is not None and config.get("loss_params", True): 125 | est_param = model.predicted_parameters(t_train, x_train) 126 | loss_param = ((est_param[~z_p.isinf()] - z_p[~z_p.isinf()]) ** 2).mean() 127 | loss = loss + loss_param 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | sum_loss += loss.item() 133 | sum_traj += l_trajectory 134 | counter_loss += 1 135 | 136 | loss_train = sum_loss / counter_loss 137 | mean_traj = sum_traj / counter_loss 138 | last_epochs_res[counter] = loss_train 139 | counter = (counter + 1) % 5 140 | 141 | with torch.no_grad(): 142 | est_param_valid = model.predicted_parameters_as_dict(t_valid, x_valid, True) 143 | est_param_test_shifted = model.predicted_parameters_as_dict(t_test_shifted, x_test_shifted, True) 144 | est_param_train = model.predicted_parameters_as_dict(t_train, all_x_train, True) 145 | 146 | _, pred_x_valid = model.forward(t_valid, x_valid) 147 | _, pred_x_test_shifted = model.forward(t_test_shifted, x_test_shifted) 148 | _, pred_x_train = model.forward(t_train, all_x_train) 149 | 150 | loss_valid, fa_norm_valid, l_trajectory = model.loss(t_valid, x_valid) 151 | loss_test_shifted, fa_norm_test_shifted, l_trajectory = model.loss(t_test_shifted, x_test_shifted) 152 | if solver == "APHYNITYDoublePendulum": 153 | se_valid, _ = model.constraint_traj_from_sol(t_valid, x_valid, pred_x_valid) 154 | se_valid = se_valid.unsqueeze(1) 155 | se_test_shifted, _ = model.constraint_traj_from_sol(t_valid, x_test_shifted, 156 | pred_x_test_shifted) 157 | se_test_shifted = se_test_shifted.unsqueeze(1) 158 | print(se_valid.mean(), se_test_shifted.mean()) 159 | else: 160 | se_valid = ((pred_x_valid[:, model.nb_observed - 1:] - 161 | x_valid[:, model.nb_observed - 1:]) ** 2).sum(4).sum(3).sum(2) 162 | se_test_shifted = ((pred_x_test_shifted[:, model.nb_observed - 1:] - 163 | x_test_shifted[:, model.nb_observed - 1:]) ** 2).sum(4).sum(3).sum(2) 164 | 165 | mu_log_mse_valid = se_valid.mean(1).log().mean() 166 | std_log_mse_valid = se_valid.mean(1).log().std() 167 | mse_valid = se_valid.mean() 168 | mse_test_shifted = se_test_shifted.mean() 169 | if verbose: 170 | message = "Epoch {:d} - Training loss: {:4f} - Training error on trajectory: {:4f}" \ 171 | " - Validation loss: {:4f}" \ 172 | " - Test loss: {:4f}" \ 173 | " - Validation log-mse: {:4f}" \ 174 | " - Test log-mse: {:4f}" \ 175 | " - Validation log-mse mu: {:4f} ± {:4f}" \ 176 | " - train |fa|: {:4f} - valid |fa|: {:4f}".format(epoch, 177 | loss_train, 178 | mean_traj, 179 | loss_valid.item(), 180 | loss_test_shifted.item(), 181 | mse_valid.log().item(), 182 | mse_test_shifted.log().item(), 183 | mu_log_mse_valid.item(), 184 | std_log_mse_valid.item(), 185 | fa_norm.item(), 186 | fa_norm_valid.item()) 187 | 188 | print(message) 189 | 190 | cur_valid = mse_valid.log() 191 | if best_valid > cur_valid: 192 | fp = model.dec.fp if solver in ["APHYNITY"] else model.fp 193 | pth = save_path 194 | if solver in ["APHYNITY", "HybridVAE"]: 195 | plot_curves_partial(t_train.cpu(), 196 | pred_x_train.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 197 | all_x_train.squeeze(2).cpu(), message, fp.get_x_labels(), pth + "train_" 198 | + fp.get_name(), est_param_train, true_param_train) 199 | plot_curves_partial(t_valid.cpu(), 200 | pred_x_valid.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 201 | x_valid.squeeze(2).cpu(), message, fp.get_x_labels(), pth + "valid_" 202 | + 
fp.get_name(), est_param_valid, true_param_valid) 203 | plot_curves_partial(t_test_shifted.cpu(), 204 | pred_x_test_shifted.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 205 | x_test_shifted.squeeze(2).cpu(), message, fp.get_x_labels(), 206 | pth + "test_" + fp.get_name(), 207 | est_param_test_shifted, true_param_test_shifted) 208 | elif solver == "APHYNITYDoublePendulum": 209 | nb_observed_theta_0 = model.nb_observed_theta_0 210 | nb_observed_theta_1 = model.nb_observed_theta_1 211 | plot_curves_double_pendulum(t_train.cpu(), 212 | torch.sin(pred_x_train).cpu()[:, :, :2, 0, 0], 213 | torch.sin(all_x_train).cpu()[:, :, :2, 0, 0], message, 214 | pth + "train" + fp.get_name(), nb_observed_theta_0, 215 | nb_observed_theta_1, model.nb_observed, 216 | est_param_train, true_param_train) 217 | plot_curves_double_pendulum(t_valid.cpu(), torch.sin(pred_x_valid).cpu()[:, :, :2, 0, 0], 218 | torch.sin(x_valid)[:, :, :2, 0, 0], message, pth + "valid_" + 219 | fp.get_name(), nb_observed_theta_0, nb_observed_theta_1, 220 | model.nb_observed, est_param_valid, true_param_valid) 221 | plot_curves_double_pendulum(t_test_shifted.cpu(), 222 | torch.sin(pred_x_test_shifted).cpu()[:, :, :2, 0, 0], 223 | torch.sin(x_test_shifted).cpu()[:, :, :2, 0, 0], message, 224 | pth + "test_" + "OOD_" + fp.get_name(), nb_observed_theta_0, 225 | nb_observed_theta_1, model.nb_observed, est_param_test_shifted, 226 | true_param_test_shifted) 227 | if trained_model is not None: 228 | save_name = solver + "_plus_best_valid_model.pt" 229 | else: 230 | save_name = solver + "_best_valid_model.pt" 231 | if config.get("save_all_models", False): 232 | torch.save(model.state_dict(), pth + save_name[:-3] + str(valid_increase_nb) + ".pt") 233 | valid_increase_nb += 1 234 | torch.save(model.state_dict(), pth + save_name) 235 | if solver in ["APHYNITY", "APHYNITYReactionDiffusion", "APHYNITYDoublePendulum"]: 236 | for p in param_text: 237 | best_param[p] = est_param_valid[p].squeeze(1) 238 | best_valid = cur_valid.item() 239 | best_test = mse_test_shifted.log().item() 240 | print("New best validation log-mse at epoch: %d" % epoch) 241 | 242 | if "APHYNITY" in solver and (last_epochs_res.max() - last_epochs_res.min()).abs() / last_epochs_res.max().abs() < .2: 243 | print("Increase constraint weight") 244 | lambda_0 += tau_2 * model.constraint_traj(t_train, x_train)[0].mean().item() 245 | 246 | # Logging metrics. 
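        # The metrics dict below is assembled once per epoch; no logger is wired in
        # here, so hook in whatever experiment tracker you use, e.g. (hypothetical):
        #
        #     logger.log(metrics)   # any dict-consuming tracker works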
247 | metrics = {'Progress': epoch, 248 | 'Train Loss': loss_train, 249 | 'Log-MSE Validation': mse_valid.log().item(), 250 | 'Validation Loss': loss_valid.item(), 251 | 'Log-MSE OOD': mse_test_shifted.log().item(), 252 | 'OOD Loss': loss_test_shifted.item(), 253 | 'Train |fa|': fa_norm.item(), 254 | 'Validation |fa|': fa_norm_valid.item(), 255 | 'Progress Text': message, 256 | 'best_valid': best_valid} 257 | 258 | 259 | nb_exp = 40 260 | sim_dic = {"DampedPendulum": DampedPendulum, 261 | "RLC": RLCCircuit, 262 | "ReactionDiffusion": ReactionDiffusion, 263 | "DoublePendulum": DoublePendulum} 264 | 265 | # Create the parser 266 | parser = argparse.ArgumentParser() 267 | # Add an argument 268 | parser.add_argument('--config', type=str, default="code/scripts/configs/Pendulum/APHYNITY.yaml") 269 | # Parse the argument 270 | args = parser.parse_args() 271 | config_name = args.config.split('/')[-1] 272 | all_config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader) 273 | config = all_config["parameters"] 274 | data_path = config.get("data_path", None) 275 | now = datetime.now() 276 | pth = "%s/runs/%s/%s" % (data_path, config_name, now.strftime("%m_%d_%Y_%H_%M_%S")) 277 | if not os.path.exists(pth): 278 | os.makedirs(pth) 279 | shutil.copyfile(args.config, pth + "/" + config_name) 280 | 281 | s = sim_dic[config["simulator"]["name"]](**config["simulator"]) 282 | path_models = config["optimization"]["path_model"][:-3] 283 | for i in range(nb_exp): 284 | print("------------- Robustness of Augmentation experiment %d -------------" % i) 285 | try: 286 | config["optimization"]["path_model"] = path_models + str(i) + ".pt" 287 | save_pth = pth + "/robustness_%d_" % i 288 | run_exp(s, True, config=config["optimization"], solver=config["optimization"].get("model", "None"), config_name=config_name, 289 | data_path=data_path, save_path=save_pth) 290 | except Exception as e: 291 | print(e) 292 | print("skipping identifier %d" % i) 293 | -------------------------------------------------------------------------------- /code/scripts/run_experiments.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | 4 | import shutil 5 | 6 | import torch 7 | from matplotlib import pyplot as plt 8 | from torch import nn 9 | 10 | from code.hybrid_models.APHYNITY import APHYNITYAutoencoderDoublePendulum 11 | from code.simulators import PhysicalModel 12 | from code.simulators import DampedPendulum, RLCCircuit, ReactionDiffusion, DoublePendulum 13 | from code.hybrid_models import APHYNITYAutoencoder, APHYNITYAutoencoderReactionDiffusion, HybridVAEDoublePendulum, \ 14 | HybridVAEReactionDiffusion, HybridVAE 15 | import yaml 16 | from code.utils import plot_curves_partial, load_data, plot_curves_complete, plot_curves_double_pendulum 17 | import torch.utils.data as data_utils 18 | from tqdm import tqdm 19 | import os 20 | import math 21 | import argparse 22 | from datetime import datetime 23 | from code.utils.utils import * 24 | 25 | 26 | 27 | def run_exp(s: PhysicalModel, verbose=False, config=None, solver="None", config_name="", data_path=None, 28 | save_path=None): 29 | if torch.cuda.is_available(): 30 | device = "cuda:0" 31 | elif torch.backends.mps.is_available(): 32 | device = "cpu" #"mps" commented because conv3d not supported yet. 
33 | else: 34 | device = "cpu" 35 | 36 | if config is None: 37 | raise Exception("No config provided.") 38 | 39 | param_text = s.incomplete_param_dim_textual 40 | # Loading data 41 | data_path = 'code/data/%s' % s.name if data_path is None else data_path 42 | print("Loading data in %s" % data_path) 43 | (t_train, x_train, true_param_train), dl_train, \ 44 | (t_valid, x_valid, true_param_valid), \ 45 | (t_test_shifted, x_test_shifted, true_param_test_shifted) = load_data(data_path, device) 46 | 47 | # Option to enforce training with reduced timeframe for the double pendulum 48 | # Training with long sequences is hard because of the chaotic behaviors of the double pendulum equations. 49 | if config.get("reduced_time_frame", False): 50 | print("training with fewer steps ( %d ) than valid and test." % config.get("nb_observed", t_train.shape[0])) 51 | t_train = t_train[:config.get("nb_observed", t_train.shape[0])] 52 | x_train = x_train[:, :config.get("nb_observed", t_train.shape[0])] 53 | train = data_utils.TensorDataset(x_train) 54 | dl_train = data_utils.DataLoader(train, batch_size=config.get("b_size", 100), shuffle=True, drop_last=True) 55 | 56 | # Creating the model 57 | starting_param = s.init_param 58 | fp = s.get_incomplete_physical_model(starting_param, trainable=True) 59 | 60 | config["lambda_p"] = config.get("lambda_0", float("nan")) 61 | 62 | trained_model, model, optimizer = get_models(solver, fp, device, config) 63 | #print(model) 64 | 65 | lambda_0 = config.get("lambda_0", float("nan")) 66 | tau_2 = config.get("tau_2", float("nan")) 67 | n_iter = config.get("n_iter", 1) 68 | all_x_train = x_train 69 | 70 | # Expert augmentation if there exist a trained model. 71 | if trained_model is not None and True: 72 | # Generating an augmented training set: 73 | all_x_train_augm = [] 74 | all_zp_train_augm = [] 75 | print("Generating augmented data...") 76 | for i in range(config.get("nb_augmentation", 1)): 77 | for j, x_train in enumerate(tqdm(dl_train)): 78 | x_train = x_train[0].to(device) 79 | with torch.no_grad(): 80 | x_train_augm, z_p_augm = trained_model.augmented_data(t_train, x_train) 81 | all_x_train_augm.append(x_train_augm) 82 | all_zp_train_augm.append(z_p_augm) 83 | if config.get("combined_augmentation", True): 84 | all_x_train_augm.append(x_train) 85 | all_zp_train_augm.append(1/torch.zeros_like(z_p_augm)) 86 | train = data_utils.TensorDataset(torch.cat(all_x_train_augm, 0), torch.cat(all_zp_train_augm, 0)) 87 | dl_train = data_utils.DataLoader(train, batch_size=config.get("b_size", 100), shuffle=True, drop_last=True) 88 | 89 | # Variables to keep track of the loss along training and save the best model 90 | best_valid = float("inf") 91 | best_param = {} 92 | counter = 0 93 | valid_increase_nb = 0 94 | last_epochs_res = torch.zeros(5) # Only used for updating the lambda value of APHYNITY 95 | debug = False 96 | if not debug and config.get("normalize_loss", True): 97 | with torch.no_grad(): 98 | baseline, _, _ = model.loss(t_train, x_train.to(device)) 99 | baseline = baseline.item() 100 | else: 101 | print("Not normalizing the loss.") 102 | baseline = 1. 103 | 104 | for epoch in range(config["n_epochs"]): 105 | for i in range(n_iter): 106 | # Update step 107 | sum_loss = 0. 108 | sum_traj = 0. 
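            # Running sums over minibatches; divided by counter_loss below to report
            # per-epoch averages of the total loss and of the trajectory term.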
109 | counter_loss = 0 110 | for x_train in tqdm(dl_train): 111 | 112 | if trained_model is not None and True: 113 | x_train, z_p = x_train 114 | x_train = x_train.to(device) 115 | z_p = z_p.to(device) 116 | else: 117 | x_train = x_train[0].to(device) 118 | if debug: 119 | counter_loss = 1 120 | fa_norm = torch.tensor(0.) 121 | break 122 | loss, fa_norm, l_trajectory = model.loss(t_train, x_train) 123 | loss = loss / baseline 124 | if trained_model is not None and config.get("loss_params", True) and True: 125 | est_param = model.predicted_parameters(t_train, x_train) 126 | loss_param = ((est_param[~z_p.isinf()] - z_p[~z_p.isinf()]) ** 2).mean() 127 | loss = loss + loss_param 128 | 129 | optimizer.zero_grad() 130 | loss.backward() 131 | optimizer.step() 132 | sum_loss += loss.item() 133 | sum_traj += l_trajectory 134 | counter_loss += 1 135 | 136 | loss_train = sum_loss / counter_loss 137 | mean_traj = sum_traj / counter_loss 138 | last_epochs_res[counter] = loss_train 139 | counter = (counter + 1) % 5 140 | 141 | with torch.no_grad(): 142 | est_param_valid = model.predicted_parameters_as_dict(t_valid, x_valid, True) 143 | est_param_test_shifted = model.predicted_parameters_as_dict(t_test_shifted, x_test_shifted, True) 144 | est_param_train = model.predicted_parameters_as_dict(t_train, all_x_train, True) 145 | 146 | _, pred_x_valid = model.forward(t_valid, x_valid) 147 | _, pred_x_test_shifted = model.forward(t_test_shifted, x_test_shifted) 148 | _, pred_x_train = model.forward(t_train, all_x_train) 149 | 150 | loss_valid, fa_norm_valid, l_trajectory = model.loss(t_valid, x_valid) 151 | loss_test_shifted, fa_norm_test_shifted, l_trajectory = model.loss(t_test_shifted, x_test_shifted) 152 | if solver in ["APHYNITYDoublePendulum", "HybridVAEDoublePendulum"]: 153 | se_train, _ = model.constraint_traj_from_sol(t_train, all_x_train, pred_x_train) 154 | se_valid, _ = model.constraint_traj_from_sol(t_valid, x_valid, pred_x_valid) 155 | se_valid = se_valid.unsqueeze(1) 156 | se_test_shifted, _ = model.constraint_traj_from_sol(t_valid, x_test_shifted, 157 | pred_x_test_shifted) 158 | se_test_shifted = se_test_shifted.unsqueeze(1) 159 | print(se_valid.mean(), se_test_shifted.mean()) 160 | else: 161 | se_train = ((pred_x_train[:, model.nb_observed - 1:] - 162 | all_x_train[:, model.nb_observed - 1:]) ** 2).sum(4).sum(3).sum(2) 163 | se_valid = ((pred_x_valid[:, model.nb_observed - 1:] - 164 | x_valid[:, model.nb_observed - 1:]) ** 2).sum(4).sum(3).sum(2) 165 | se_test_shifted = ((pred_x_test_shifted[:, model.nb_observed - 1:] - 166 | x_test_shifted[:, model.nb_observed - 1:]) ** 2).sum(4).sum(3).sum(2) 167 | 168 | mu_log_mse_valid = se_valid.mean(1).log().mean() 169 | std_log_mse_valid = se_valid.mean(1).log().std() 170 | mse_train = se_train.mean() 171 | mse_valid = se_valid.mean() 172 | mse_test_shifted = se_test_shifted.mean() 173 | if verbose: 174 | message = "Epoch {:d} - Training loss: {:4f} - Training error on trajectory: {:4f}" \ 175 | " - Validation loss: {:4f}" \ 176 | " - Test loss: {:4f}" \ 177 | " - Training log-mse: {:4f}" \ 178 | " - Validation log-mse: {:4f}" \ 179 | " - Test log-mse: {:4f}" \ 180 | " - Validation log-mse mu: {:4f} ± {:4f}" \ 181 | " - train |fa|: {:4f} - valid |fa|: {:4f}".format(epoch, 182 | loss_train, 183 | mean_traj, 184 | loss_valid.item(), 185 | loss_test_shifted.item(), 186 | mse_train.log().item(), 187 | mse_valid.log().item(), 188 | mse_test_shifted.log().item(), 189 | mu_log_mse_valid.item(), 190 | std_log_mse_valid.item(), 191 | fa_norm.item(), 192 | 
fa_norm_valid.item()) 193 | 194 | print(message) 195 | 196 | cur_valid = mse_valid.log() 197 | if best_valid > cur_valid: 198 | fp = model.dec.fp if solver in ["APHYNITY"] else model.fp 199 | pth = "%s/" % save_path 200 | if solver in ["APHYNITY", "HybridVAE"]: 201 | plot_curves_partial(t_train.cpu(), 202 | pred_x_train.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 203 | all_x_train.squeeze(2).cpu(), message, fp.get_x_labels(), pth + "train_" 204 | + fp.get_name(), est_param_train, true_param_train) 205 | plot_curves_partial(t_valid.cpu(), 206 | pred_x_valid.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 207 | x_valid.squeeze(2).cpu(), message, fp.get_x_labels(), pth + "valid_" 208 | + fp.get_name(), est_param_valid, true_param_valid) 209 | plot_curves_partial(t_test_shifted.cpu(), 210 | pred_x_test_shifted.cpu()[:, model.nb_observed - 1:, :2, 0, 0], 211 | x_test_shifted.squeeze(2).cpu(), message, fp.get_x_labels(), 212 | pth + "test_" + fp.get_name(), 213 | est_param_test_shifted, true_param_test_shifted) 214 | elif solver in ["APHYNITYDoublePendulum", "HybridVAEDoublePendulum"]: 215 | nb_observed_theta_0 = model.nb_observed_theta_0 216 | nb_observed_theta_1 = model.nb_observed_theta_1 217 | plot_curves_double_pendulum(t_train.cpu(), 218 | torch.sin(pred_x_train).cpu()[:, :, :2, 0, 0], 219 | torch.sin(all_x_train).cpu()[:, :, :2, 0, 0], message, 220 | pth + "train" + fp.get_name(), nb_observed_theta_0, 221 | nb_observed_theta_1, model.nb_observed, 222 | est_param_train, true_param_train) 223 | plot_curves_double_pendulum(t_valid.cpu(), torch.sin(pred_x_valid).cpu()[:, :, :2, 0, 0], 224 | torch.sin(x_valid)[:, :, :2, 0, 0], message, pth + "valid_" + 225 | fp.get_name(), nb_observed_theta_0, nb_observed_theta_1, 226 | model.nb_observed, est_param_valid, true_param_valid) 227 | plot_curves_double_pendulum(t_test_shifted.cpu(), 228 | torch.sin(pred_x_test_shifted).cpu()[:, :, :2, 0, 0], 229 | torch.sin(x_test_shifted).cpu()[:, :, :2, 0, 0], message, 230 | pth + "test_" + "OOD_" + fp.get_name(), nb_observed_theta_0, 231 | nb_observed_theta_1, model.nb_observed, est_param_test_shifted, 232 | true_param_test_shifted) 233 | if trained_model is not None: 234 | save_name = solver + "_plus_best_valid_model.pt" 235 | else: 236 | save_name = solver + "_best_valid_model.pt" 237 | if config.get("save_all_models", False): 238 | torch.save(model.state_dict(), pth + save_name[:-3] + str(valid_increase_nb) + ".pt") 239 | valid_increase_nb += 1 240 | torch.save(model.state_dict(), pth + save_name) 241 | if solver in ["APHYNITY", "APHYNITYReactionDiffusion", "APHYNITYDoublePendulum"]: 242 | for p in param_text: 243 | best_param[p] = est_param_valid[p].squeeze(1) 244 | best_valid = cur_valid.item() 245 | best_test = mse_test_shifted.log().item() 246 | print("New best validation log-mse at epoch: %d" % epoch) 247 | 248 | if "APHYNITY" in solver and (last_epochs_res.max() - last_epochs_res.min()).abs() / last_epochs_res.max().abs() < .2: 249 | print("Increase constraint weight") 250 | lambda_0 += tau_2 * model.constraint_traj(t_train, x_train)[0].mean().item() 251 | 252 | # Logging metrics. 
253 | metrics = {'Progress': epoch, 254 | 'Train Loss': loss_train, 255 | 'Log-MSE Validation': mse_valid.log().item(), 256 | 'Validation Loss': loss_valid.item(), 257 | 'Log-MSE OOD': mse_test_shifted.log().item(), 258 | 'OOD Loss': loss_test_shifted.item(), 259 | 'Train |fa|': fa_norm.item(), 260 | 'Validation |fa|': fa_norm_valid.item(), 261 | 'Progress Text': message, 262 | 'best_valid': best_valid} 263 | 264 | 265 | nb_exp = 1 266 | sim_dic = {"DampedPendulum": DampedPendulum, 267 | "RLC": RLCCircuit, 268 | "ReactionDiffusion": ReactionDiffusion, 269 | "DoublePendulum": DoublePendulum} 270 | 271 | # Create the parser 272 | parser = argparse.ArgumentParser() 273 | # Add an argument 274 | parser.add_argument('--config', type=str, default="code/scripts/configs/Pendulum/APHYNITY.yaml") 275 | # Parse the argument 276 | args = parser.parse_args() 277 | config_name = args.config.split('/')[-1] 278 | all_config = yaml.load(open(args.config, 'r'), Loader=yaml.FullLoader) 279 | config = all_config["parameters"] 280 | data_path = config.get("data_path", None) 281 | now = datetime.now() 282 | pth = "%s/runs/%s/%s" % (data_path, config_name, now.strftime("%m_%d_%Y_%H_%M_%S")) 283 | if not os.path.exists(pth): 284 | os.makedirs(pth) 285 | shutil.copyfile(args.config, pth + "/" + config_name) 286 | 287 | s = sim_dic[config["simulator"]["name"]](**config["simulator"]) 288 | 289 | run_exp(s, True, config=config["optimization"], solver=config["optimization"].get("model", "None"), config_name=config_name, 290 | data_path=data_path, save_path=pth) 291 | -------------------------------------------------------------------------------- /code/hybrid_models/APHYNITY.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
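# APHYNITY-style hybrid models: the dynamics are decomposed as
#
#     dx/dt = f_p(x; z_p) + f_a(x; z_a)
#
# where f_p is the (incomplete) expert ODE with physical parameters z_p and f_a is a
# neural residual. Training minimizes the norm of f_a subject to fitting the observed
# trajectories; the classes below relax this into the Lagrangian
#
#     L = ||f_a|| + lambda_p * E[trajectory error]
#
# which is what the `lagrangian` / `loss` methods compute (see `penalty` and
# `constraint_traj`).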
3 | 4 | import math 5 | import pickle 6 | 7 | import torch.nn as nn 8 | from code.simulators import PhysicalModel, NoSimulator 9 | import torch 10 | from code.nn import MLP, ConditionalUNet, Permute, act_dict 11 | from torchdiffeq import odeint 12 | from code.hybrid_models.HybridAutoencoder import HybridAutoencoder 13 | 14 | 15 | class Encoder(nn.Module): 16 | def __init__(self, x_dim, hidden_size): 17 | super(Encoder, self).__init__() 18 | self.rnn = nn.GRU(input_size=x_dim, hidden_size=hidden_size, batch_first=True) 19 | self.mu, self.sigma = None, None 20 | 21 | def forward(self, t, x): 22 | _, out = self.rnn(x.squeeze(4).squeeze(3)) 23 | return out[0] 24 | 25 | 26 | class DoublePendulumEncoder(nn.Module): 27 | def __init__(self, layers=[300, 300, 300], za_dim=50, ze_dim=4, initial_guess=True, 28 | nb_observed_theta_0=25, nb_observed_theta_1=25, obtain_init_position=False): 29 | super(DoublePendulumEncoder, self).__init__() 30 | self.za_dim = za_dim 31 | self.ze_dim = ze_dim 32 | in_dim = 2 33 | self.obtain_init_position = obtain_init_position 34 | self.initial_guess = initial_guess 35 | self.total_time = nb_observed_theta_0 + nb_observed_theta_1 36 | self.nb_observed_theta_0 = nb_observed_theta_0 37 | self.nb_observed_theta_1 = nb_observed_theta_1 38 | 39 | lis = [self.total_time*in_dim] + layers 40 | los = layers + [za_dim + ze_dim] 41 | layers = [] 42 | for li, lo in zip(lis, los): 43 | layers += [nn.Linear(li, lo), nn.SELU()] 44 | layers.pop() 45 | self.encoder = nn.Sequential(*layers) 46 | 47 | def forward(self, t, x): 48 | in_s = t.shape[0] 49 | frequency = 1 / t[1] 50 | if self.total_time < 50 or True: 51 | x_masked = torch.cat((x[:, -self.nb_observed_theta_0:, 0], x[:, -self.nb_observed_theta_1:, 1]), 1) 52 | else: 53 | dx, dy = x[:, 0, :], x[:, 2, :] 54 | diff = (dx - dy).unsqueeze(2) 55 | choices = torch.cat((diff - 2 * math.pi, diff, diff + 2 * math.pi), 2) 56 | _, choice = torch.min(choices ** 2, 2) # np.arctan2(np.sin(x-y), np.cos(x-y)) * 200 57 | omegas = torch.gather(choices, dim=2, index=choice.unsqueeze(2)).squeeze(2) * frequency / 2 58 | init_state = torch.cat((x[:, 1, :], -omegas), 1) 59 | x_masked = x.view(x.shape[0], -1) 60 | 61 | sin_cos_encoded = torch.cat((torch.sin(x_masked), torch.cos(x_masked)), 1) 62 | z_all = self.encoder(sin_cos_encoded.reshape(x_masked.shape[0], -1)) 63 | 64 | if self.initial_guess and self.total_time == 50: 65 | z_e = init_state + z_all[:, :self.ze_dim] #* 0.1 66 | elif self.obtain_init_position: 67 | z_e = torch.cat([x[:, 0, :2, 0, 0], z_all[:, 2:self.ze_dim]], 1) 68 | else: 69 | z_e = z_all[:, :self.ze_dim] 70 | z_a = None if self.za_dim == 0 else z_all[:, self.ze_dim:] 71 | return z_e, z_a 72 | 73 | 74 | class HybridDecoder(nn.Module): 75 | def __init__(self, fp: PhysicalModel, fp_param_converter_hidden=None, fp_param_converter_act="SELU", 76 | fp_param_converter_final_act="Softplus", fa_hidden=None, fa_hidden_act="SELU", fa_final_act=None, 77 | encoder_dim=128, **kwargs): 78 | super(HybridDecoder, self).__init__() 79 | 80 | # Create net that maps hidden from encoder to physical parameters. 81 | if fp_param_converter_hidden is None: 82 | fp_param_converter_hidden = 4 * [300] 83 | layers_fp = [encoder_dim] + fp_param_converter_hidden + [len(fp.incomplete_param_dim_textual)] 84 | self.fp_param_converter = MLP(layers_fp, hidden_act=fp_param_converter_act, 85 | final_act=fp_param_converter_final_act) 86 | 87 | # Create net that maps hidden from encoder + state x to fa(x;h) contribution. 
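        # fa is an MLP from the concatenated [state x, hidden code h] to a residual on
        # dx/dt: layer sizes are [encoder_dim + x_dim] -> fa_hidden -> [x_dim],
        # mirroring the fp_param_converter construction above.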
88 |         if fa_hidden is None:
89 |             fa_hidden = [300] * 4
90 |         x_dim = fp._X_dim
91 |         layers_fa = [encoder_dim + x_dim] + fa_hidden + [x_dim]
92 |         self.fa = MLP(layers_fa, hidden_act=fa_hidden_act, final_act=fa_final_act)
93 | 
94 |         self.fp = fp
95 | 
96 |         self.h = None
97 |         self.fp_params = None
98 | 
99 |     def to(self, device):
100 |         super(HybridDecoder, self).to(device)
101 |         self.fp.to(device)
102 |         self.fa.to(device)
103 |         self.fp_param_converter.to(device)
104 |         return self
105 | 
106 |     def set_hidden(self, h):
107 |         # h: B_size x hidden_size
108 |         self.h = h
109 | 
110 |     def forward_fa(self, t, x):
111 |         # Evaluate the data-driven residual on the state concatenated with the hidden code.
112 | 
113 |         return self.fa(torch.cat((x, self.h), 1))
114 | 
115 |     def get_physical_parameters(self, x, zero_param=False, zp=None, as_dict=True):
116 |         fp_params = self.fp_param_converter(self.h) if zp is None else zp
117 |         physical_params = {}
118 |         if not as_dict:
119 |             return fp_params
120 |         for i, p in enumerate(self.fp.incomplete_param_dim_textual):
121 |             physical_params[p] = fp_params[:, i].unsqueeze(1)
122 |         if zero_param:
123 |             for i, p in enumerate(self.fp.missing_param_dim_textual):
124 |                 physical_params[p] = torch.zeros(x.shape[0], device=x.device).unsqueeze(1)
125 | 
126 |         return physical_params
127 | 
128 |     def forward_fp(self, t, x, zp=None):
129 |         physical_params = self.get_physical_parameters(x) if zp is None else zp
130 |         return self.fp.parameterized_forward(t, x, **physical_params)
131 | 
132 |     def forward(self, t, x, zp=None):
133 |         if x.shape[0] != self.h.shape[0]:
134 |             raise Exception(
135 |                 "Mismatch between hidden state batch size %d and x batch size %d." % (self.h.shape[0], x.shape[0]))
136 | 
137 |         dx = self.forward_fa(t, x) + self.forward_fp(t, x, zp)
138 |         return dx
139 | 
140 | 
141 | class APHYNITYAutoencoder(HybridAutoencoder):
142 |     def __init__(self, fp, augmented=False, zp_priors=None, device="cpu", **config):
143 |         super(APHYNITYAutoencoder, self).__init__()
144 |         encoder_out = 128
145 |         self.enc = Encoder(fp._X_dim, encoder_out).to(device)
146 |         param_decoder = {"fp": fp.to(device),
147 |                          "fp_param_converter_hidden_w": 200,
148 |                          "fp_param_converter_hidden_n": 3,
149 |                          "fp_param_converter_act": "ReLU",
150 |                          "fp_param_converter_final_act": "Softplus",
151 |                          "fa_hidden_w": 200,
152 |                          "fa_hidden_n": 3,
153 |                          "fa_hidden_act": "ReLU",
154 |                          "fa_final_act": None,
155 |                          "encoder_dim": encoder_out}
156 | 
157 |         self.lambda_p = config.get("lambda_p", float("nan"))
158 | 
159 |         for k, v in param_decoder.items():
160 |             if k[:2] == "fa":
161 |                 if k[3:] in config.get("fa", {}):
162 |                     param_decoder[k] = config["fa"][k[3:]]
163 |             elif k in config:
164 |                 param_decoder[k] = config[k]
165 | 
166 |         param_decoder["fa_hidden"] = [param_decoder["fa_hidden_w"]] * param_decoder["fa_hidden_n"]
167 |         param_decoder["fp_param_converter_hidden"] = [param_decoder["fp_param_converter_hidden_w"]] * param_decoder["fp_param_converter_hidden_n"]
168 | 
169 |         self.dec = HybridDecoder(**param_decoder).to(device)
170 |         self.nb_observed = 50
171 |         self.device = device
172 |         self.augmented = augmented
173 |         self.min_zp = None
174 |         if zp_priors is not None and self.augmented:
175 |             min_zp = []
176 |             max_zp = []
177 |             for k, v in zp_priors.items():
178 |                 min_zp.append(v["min"])
179 |                 max_zp.append(v["max"])
180 |             self.min_zp = torch.tensor(min_zp, device=device)
181 |             self.max_zp = torch.tensor(max_zp, device=device)
182 | 
183 |     def to(self, device):
184 |         self.device = device
185 |         self.enc.to(device)
186 |         self.dec.to(device)
187 |         if self.min_zp is not None:
188 |             self.min_zp = self.min_zp.to(device)
189 |             self.max_zp = self.max_zp.to(device)
190 |         return self
191 | 
192 |     def forward(self, t_span, x):
193 |         h = self.enc(t_span[:self.nb_observed], x[:, :self.nb_observed])
194 |         self.dec.set_hidden(h)
195 |         x_pred = odeint(self.dec, x[:, 0, :, 0, 0], t_span, atol=1e-5, rtol=1e-5)
196 |         return t_span, x_pred.permute(1, 0, 2).unsqueeze(3).unsqueeze(4)
197 | 
198 |     def augmented_data(self, t_span, x):
199 |         h = self.enc(t_span[:self.nb_observed], x[:, :self.nb_observed])
200 |         self.dec.set_hidden(h)
201 | 
202 |         resampled_zp = torch.rand(x.shape[0], self.min_zp.shape[0], device=self.device) * (self.max_zp - self.min_zp) \
203 |                        + self.min_zp
204 |         fp = self.dec.get_physical_parameters(None, None, resampled_zp)
205 |         with torch.no_grad():
206 |             dec = lambda t, x: self.dec(t, x, fp)
207 |             x_pred = odeint(dec, x[:, 0, :, 0, 0], t_span, atol=1e-5, rtol=1e-5)
208 |         return x_pred.permute(1, 0, 2).unsqueeze(3).unsqueeze(4), resampled_zp
209 | 
210 |     def predicted_parameters(self, t_span, x, zero_param=False):
211 |         h = self.enc(t_span[:self.nb_observed], x[:, :self.nb_observed])
212 |         self.dec.set_hidden(h)
213 |         return self.dec.get_physical_parameters(x, zero_param, as_dict=False)
214 | 
215 |     def predicted_parameters_as_dict(self, t_span, x, zero_param=False) -> dict:
216 |         h = self.enc(t_span[:self.nb_observed], x[:, :self.nb_observed])
217 |         self.dec.set_hidden(h)
218 |         return self.dec.get_physical_parameters(x, zero_param, as_dict=True)
219 | 
220 |     def penalty(self, t_eval, x_span):
221 |         x = x_span[:, :self.nb_observed, :, 0, 0]
222 |         concat_x_h = torch.cat((x, self.dec.h.unsqueeze(1).expand(-1, x.shape[1], -1)), 2)  # (x, h) ordering matches forward_fa
223 |         l = (self.dec.fa(concat_x_h).norm(2, dim=2)) ** 2
224 |         return l.mean(1)
225 | 
226 |     def constraint_traj(self, t_span, x):
227 |         t_eval, x_hat = self.forward(t_span, x[:, :self.nb_observed])
228 |         l_traj = (x[:, :self.nb_observed] - x_hat[:, :self.nb_observed]).norm(2, dim=2).mean(1)  # per-sample error over the observed window
229 |         return l_traj, t_eval
230 | 
231 |     def lagrangian(self, lambda_p, t_span, x):
232 |         l_trajectory, t_eval = self.constraint_traj(t_span, x)
233 |         l_penalty = self.penalty(t_eval, x).mean() ** .5
234 |         return l_penalty + lambda_p * l_trajectory.mean(), l_penalty.detach(), l_trajectory.mean().detach()
235 | 
236 |     def loss(self, t_span, x) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
237 |         return self.lagrangian(self.lambda_p, t_span, x)
238 | 
239 | 
240 | class APHYNITYAutoencoderDoublePendulum(HybridAutoencoder):
241 |     def __init__(self, fp, augmented=False, zp_priors=None, device="cpu", **config):
242 |         super(APHYNITYAutoencoderDoublePendulum, self).__init__()
243 |         self.ze_dim = 4
244 |         self.lambda_p = config.get("lambda_p", float("nan"))
245 |         self.nb_observed = config.get("nb_observed", 25)
246 |         self.nb_observed_theta_0 = config.get("nb_observed_theta_0", 25)
247 |         self.nb_observed_theta_1 = config.get("nb_observed_theta_1", 25)
248 |         self.za_dim = config.get("za_dim", 5)
249 |         self.cos_sin_encoding = config.get("cos_sin_encoding", False)
250 |         self.weight_penalty = config.get("weight_penalty", False)
251 |         self.weight_constraint = config.get("weight_constraint", False)
252 |         self.use_complete_signal = config.get("use_complete_signal", False)
253 |         self.partial_observability = config.get("partial_observability", False)
254 |         self.initial_guess = config.get("initial_guess", True)
255 |         self.simplified_fa = config.get("simplified_fa", False)
256 |         self.obtain_init_position = config.get("obtain_init_position", False)
257 |         layers_encoder = [config.get("hidden_size_encoder", 300)] * config.get("nb_layers_encoder", 3)
258 | 
259 |         self.enc = DoublePendulumEncoder(layers=layers_encoder, za_dim=self.za_dim, ze_dim=self.ze_dim,
260 |                                          nb_observed_theta_0=self.nb_observed_theta_0,
261 |                                          nb_observed_theta_1=self.nb_observed_theta_1,
262 |                                          obtain_init_position=self.obtain_init_position).to(device)
263 |         self.fp = fp.to(device)
264 |         self.no_fa = config.get("no_fa", False)
265 |         self.no_fp = config.get("no_fp", False)
266 |         self.no_APHYNITY = config.get("no_APHYNITY", False)
267 |         if self.no_fp:
268 |             self.fp = NoSimulator()
269 | 
270 |         if self.no_fa:
271 |             self.dec = lambda x: torch.zeros(list(x.shape[:-1]) + [self.ze_dim], device=x.device)
272 |         elif self.simplified_fa:
273 |             self.linear_layer = nn.Linear(8, self.ze_dim)
274 |             self.cos_sin_encoding = False
275 |             self.dec = lambda x: self.linear_layer(torch.cat((torch.sin(x[..., :2]), torch.cos(x[..., :2]),
276 |                                                               x[..., 2:]/10, (x[..., 2:]**2)/100), -1))
277 |         else:
278 |             in_dim = self.ze_dim if not self.cos_sin_encoding else self.ze_dim + 2
279 | 
280 |             layers_fa = [config.get("hidden_size_fa", 300)] * config.get("nb_layers_fa", 3)
281 |             lis = [in_dim + self.za_dim] + layers_fa
282 |             los = layers_fa + [self.ze_dim]
283 |             layers = []
284 |             for li, lo in zip(lis, los):
285 |                 layers += [nn.Linear(li, lo), nn.SELU()]
286 |             layers.pop()
287 |             self.dec = nn.Sequential(*layers)
288 | 
289 |         self.device = device
290 |         self.augmented = augmented
291 |         self.min_zp = None
292 |         if zp_priors is not None and zp_priors.get("path", None):
293 |             with open(zp_priors.get("path", None), "rb") as output_file:
294 |                 self.init_states = pickle.load(output_file)
295 |         else:
296 |             self.init_states = None
297 |         if zp_priors is not None and self.augmented:
298 |             min_zp = []
299 |             max_zp = []
300 |             for k, v in zp_priors.items():
301 |                 min_zp.append(v["min"])
302 |                 max_zp.append(v["max"])
303 |             self.min_zp = torch.tensor(min_zp, device=device)
304 |             self.max_zp = torch.tensor(max_zp, device=device)
305 | 
306 |     def to(self, device):
307 |         self.device = device
308 |         self.enc.to(device)
309 |         if not self.no_fa and not self.simplified_fa:
310 |             self.dec.to(device)
311 |         self.fp.to(device)
312 |         if self.min_zp is not None:
313 |             self.min_zp = self.min_zp.to(device)
314 |             self.max_zp = self.max_zp.to(device)
315 |         return self
316 | 
317 |     def ode_f(self, z_e, z_a):
318 |         if self.za_dim > 0:
319 |             if self.cos_sin_encoding:
320 |                 return lambda t, theta: self.fp(t, theta) + self.dec(torch.cat((torch.sin(theta[:, :2]),
321 |                                                                                 torch.cos(theta[:, :2]),
322 |                                                                                 theta[:, 2:], z_a), 1))
323 | 
324 |             return lambda t, theta: self.fp(t, theta) + self.dec(torch.cat((theta, z_a), 1))
325 | 
326 |         if self.simplified_fa:
327 |             return lambda t, theta: self.fp(t, theta) + self.dec(theta)
328 | 
329 |         if self.cos_sin_encoding:
330 |             return lambda t, theta: self.fp(t, theta) + self.dec(torch.cat((torch.sin(theta[:, :2]),
331 |                                                                             torch.cos(theta[:, :2]),
332 |                                                                             theta[:, 2:]), 1))
333 | 
334 |         return lambda t, theta: self.fp(t, theta) + self.dec(theta)
335 | 
336 |     def forward(self, t_span, x):
337 |         z_e, z_a = self.enc(t_span, x[:, :self.nb_observed])
338 |         x_pred = odeint(self.ode_f(z_e, z_a), z_e, t_span, atol=1e-5,
339 |                         rtol=1e-5).permute(1, 0, 2)[:, :, :2].unsqueeze(3).unsqueeze(3)
340 | 
341 |         return t_span, x_pred
342 | 
343 |     def augmented_data(self, t_span, x):
344 |         theta_0, z_a = self.enc(t_span, x[:, :self.nb_observed])
345 | 
346 |         if self.init_states is not None:
347 |             idx = torch.randperm(self.init_states.shape[0])[:x.shape[0]]
348 |             resampled_theta_0 = self.init_states[idx] +
torch.randn_like(self.init_states[idx]) * .1 349 | else: 350 | resampled_theta_0 = torch.rand(x.shape[0], self.min_zp.shape[0], device=self.device) * \ 351 | (self.max_zp - self.min_zp) + self.min_zp 352 | with torch.no_grad(): 353 | dec = self.ode_f(resampled_theta_0, z_a) 354 | x_pred = odeint(dec, resampled_theta_0, t_span, atol=1e-5, 355 | rtol=1e-5).permute(1, 0, 2)[:, :, :2].unsqueeze(3).unsqueeze(3) 356 | return x_pred, resampled_theta_0 357 | 358 | def predicted_parameters_as_dict(self, t_span, x, zero_param=False): 359 | z_e, z_a = self.enc(t_span, x[:, :self.nb_observed]) 360 | return {"\\theta_0": z_e[:, [0]], "\\theta_1": z_e[:, [1]], 361 | "\\dot \\theta_0": z_e[:, [2]], "\\dot \\theta_1": z_e[:, [3]]} 362 | 363 | def predicted_parameters(self, t_span, x, zero_param=False): 364 | z_e, z_a = self.enc(t_span, x[:, :self.nb_observed]) 365 | return z_e 366 | 367 | def penalty(self, t_eval, x_span): 368 | raise NotImplementedError 369 | 370 | def penalty_from_sol(self, x_span, z_a): 371 | nb_observed = x_span.shape[1] if self.use_complete_signal else self.nb_observed 372 | 373 | x = x_span[:, :nb_observed] 374 | if self.za_dim > 0: 375 | z_a = z_a.unsqueeze(1).expand(-1, nb_observed, -1) 376 | if self.cos_sin_encoding: 377 | l = ((self.dec(torch.cat((torch.sin(x[:, :, :2, 0, 0]), 378 | torch.cos(x[:, :, :2, 0, 0]), 379 | x[:, :, 2:, 0, 0], z_a), 2))).norm(2, dim=2)) #** 2 380 | else: 381 | l = ((self.dec(torch.cat((x, z_a), 2))).norm(2, dim=2)) #** 2 382 | elif self.cos_sin_encoding: 383 | l = ((self.dec(torch.cat((torch.sin(x[:, :, :2, 0, 0]), 384 | torch.cos(x[:, :, :2, 0, 0]), 385 | x[:, :, 2:, 0, 0]), 2))).norm(2, dim=2)) #** 2 386 | else: 387 | l = ((self.dec(x)).norm(2, dim=2)) #** 2 388 | 389 | if self.weight_penalty: 390 | weights = 1.25 ** -torch.arange(l.shape[1]) 391 | weights = weights.sum() 392 | l = l * weights.unsqueeze(0) 393 | return l.sum(1) 394 | 395 | return l.mean(1) 396 | 397 | def constraint_traj(self, t_span, x): 398 | nb_observed = x.shape[1] if self.use_complete_signal else self.nb_observed 399 | t_eval, x_hat = self.forward(t_span, x[:, :nb_observed]) 400 | x_hat = x_hat[:, :, :2] 401 | diff_sin = (torch.sin(x[:, :nb_observed]) - torch.sin(x_hat[:, :nb_observed]))**2 402 | diff_cos = (torch.cos(x[:, :nb_observed]) - torch.cos(x_hat[:, :nb_observed]))**2 403 | if self.weight_constraint: 404 | l = (diff_sin + diff_cos).mean(2) 405 | weights = 1.25 ** -torch.arange(l.shape[1]) 406 | weights = weights / weights.sum() 407 | l = l * weights.unsqueeze(0) 408 | return l.sum(1) 409 | l_traj = (diff_sin + diff_cos).mean(2).mean(1) 410 | return l_traj, t_eval 411 | 412 | def constraint_traj_from_sol(self, t_eval, x, x_hat): 413 | x_hat = x_hat[:, :, :2] 414 | nb_observed = x.shape[1] if self.use_complete_signal else self.nb_observed 415 | diff_sin = (torch.sin(x[:, :nb_observed]) - torch.sin(x_hat[:, :nb_observed]))**2 416 | diff_cos = (torch.cos(x[:, :nb_observed]) - torch.cos(x_hat[:, :nb_observed]))**2 417 | l_traj = (diff_sin + diff_cos).mean(2).mean(1) 418 | return l_traj, t_eval 419 | 420 | def lagrangian(self, lambda_p, t_span, x): 421 | z_e, z_a = self.enc(t_span, x[:, :self.nb_observed]) 422 | x_pred = odeint(self.ode_f(z_e, z_a), z_e, t_span, atol=1e-5, rtol=1e-5).permute(1, 0, 2).unsqueeze(3).unsqueeze(3) 423 | l_trajectory, t_eval = self.constraint_traj_from_sol(t_span, x, x_pred) 424 | l_trajectory = l_trajectory ** .5 425 | l_penalty = self.penalty_from_sol(x_pred, z_a).mean() 426 | 427 | if self.augmented and False: 428 | x_pred, _ = 
self.augmented_data(t_span, x) 429 | self.augmented = False 430 | augm_loss, _ = self.lagrangian(lambda_p, t_span, x_pred) 431 | self.augmented = True 432 | return augm_loss, l_penalty.detach() 433 | 434 | if self.no_APHYNITY: 435 | return l_trajectory.mean(), l_penalty.detach(), l_trajectory.mean().detach() 436 | return l_penalty + lambda_p * l_trajectory.mean(), l_penalty.detach(), l_trajectory.mean().detach() 437 | 438 | def loss(self, t_span, x): 439 | return self.lagrangian(self.lambda_p, t_span, x) 440 | 441 | def lagrangian_augm(self, lambda_p, t_span, x, zp=None): 442 | return self.lagrangian(lambda_p, t_span, x) 443 | 444 | 445 | class APHYNITYAutoencoderReactionDiffusion(HybridAutoencoder): 446 | def __init__(self, fp, augmented=False, zp_priors=None, device="cpu", **config): 447 | super(APHYNITYAutoencoderReactionDiffusion, self).__init__() 448 | self.dim_in = 2 449 | self.nb_observed = 10 450 | self.dim_za, self.dim_zp = 10, 2 451 | self.lambda_p = config.get("lambda_p", float("nan")) 452 | self.enc = nn.Sequential(nn.Flatten(0, 1), nn.Conv2d(self.dim_in, 16, 3), nn.ReLU(), 453 | nn.Conv2d(16, 32, 3), nn.AvgPool2d(2), 454 | nn.Conv2d(32, 64, 3), nn.ReLU(), 455 | nn.Conv2d(64, 64, 3), nn.AvgPool2d(2), 456 | nn.Conv2d(64, 32, 3), nn.ReLU(), nn.Unflatten(0, (-1, self.nb_observed)), 457 | Permute((0, 2, 1, 3, 4)), nn.Conv3d(32, 16, 2), nn.ReLU(), 458 | nn.Conv3d(16, 16, 2), nn.Flatten(1, 4), 459 | nn.Linear(128, 256), nn.ReLU(), 460 | nn.Linear(256, 256), nn.ReLU(), 461 | nn.Linear(256, 256), nn.ReLU(), 462 | nn.Linear(256, self.dim_za + self.dim_zp)) 463 | 464 | self.fp_param_converter_final_act = act_dict["ReactionDiffusionParametersScaler"] 465 | self.fa = nn.Sequential(nn.Conv2d(2 + self.dim_za, 16, 3, padding=1), nn.ReLU(), 466 | nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(), 467 | nn.Conv2d(16, 2, 3, padding=1)) 468 | self.fp = fp 469 | 470 | self.device = device 471 | self.augmented = augmented 472 | if augmented: 473 | min_zp = [] 474 | max_zp = [] 475 | for k, v in zp_priors.items(): 476 | min_zp.append(v["min"]) 477 | max_zp.append(v["max"]) 478 | self.min_zp = torch.tensor(min_zp, device=device) 479 | self.max_zp = torch.tensor(max_zp, device=device) 480 | 481 | def forward_step(self, t, x, zp, za): 482 | x_fa = torch.cat((x, za.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])), 1) 483 | return self.fa(x_fa) + self.fp.parameterized_forward(t, x, **zp) 484 | 485 | def to(self, device): 486 | self.device = device 487 | self.enc.to(device) 488 | self.fa.to(device) 489 | self.fp.to(device) 490 | if self.augmented: 491 | self.min_zp.to(device) 492 | self.max_zp.to(device) 493 | return self 494 | 495 | def get_physical_parameters(self, h, zero_param=False, as_dict=False): 496 | za, zp = h[:, :self.dim_za], h[:, self.dim_za:] 497 | zp = self.fp_param_converter_final_act(zp) 498 | 499 | fp_params = zp 500 | if not as_dict: 501 | return fp_params, za 502 | physical_params = {} 503 | for i, p in enumerate(self.fp.incomplete_param_dim_textual): 504 | physical_params[p] = fp_params[:, i].unsqueeze(1) 505 | if zero_param: 506 | for i, p in enumerate(self.fp.missing_param_dim_textual): \ 507 | physical_params[p] = torch.zeros(h.shape[0], device=h.device).unsqueeze(1) 508 | 509 | return physical_params, za 510 | 511 | def forward(self, t_span, x, h=None): 512 | if h is None: 513 | h = self.enc(x[:, :self.nb_observed]) 514 | physical_params, za = self.get_physical_parameters(h, as_dict=True) 515 | x_pred = odeint(lambda t, x: self.forward_step(t, x, physical_params, za), 
516 | x[:, 0], t_span, atol=1e-5, rtol=1e-5).permute(1, 0, 2, 3, 4) 517 | return t_span, x_pred 518 | 519 | def augmented_data(self, t_span, x): 520 | h = self.enc(x[:, :self.nb_observed]) 521 | physical_params, za = self.get_physical_parameters(h) 522 | resampled_zp = torch.rand(x.shape[0], self.min_zp.shape[0], device=self.device) * (self.max_zp - self.min_zp) \ 523 | + self.min_zp 524 | fp_params = resampled_zp 525 | physical_params = {} 526 | for i, p in enumerate(self.fp.incomplete_param_dim_textual): 527 | physical_params[p] = fp_params[:, i].unsqueeze(1) 528 | with torch.no_grad(): 529 | x_pred = odeint(lambda t, x: self.forward_step(t, x, physical_params, za), 530 | x[:, 0], t_span, atol=1e-5, rtol=1e-5).permute(1, 0, 2, 3, 4) 531 | return x_pred, resampled_zp 532 | 533 | def predicted_parameters(self, t_span, x, zero_param=False): 534 | h = self.enc(x[:, :self.nb_observed]) 535 | physical_params, za = self.get_physical_parameters(h, zero_param=False) 536 | return physical_params 537 | 538 | def predicted_parameters_as_dict(self, t_span, x, zero_param=False) -> dict: 539 | h = self.enc(x[:, :self.nb_observed]) 540 | physical_params, za = self.get_physical_parameters(h, zero_param=False, as_dict=True) 541 | return physical_params 542 | 543 | def penalty(self, t_eval, x_span, h=None): 544 | if h is None: 545 | h = self.enc(x_span[:, :self.nb_observed]) 546 | x = x_span[:, :self.nb_observed] 547 | physical_params, za = self.get_physical_parameters(h) 548 | za = za.unsqueeze(1).expand(-1, self.nb_observed, -1).reshape(-1, self.dim_za) 549 | x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4]) 550 | x_fa = torch.cat((x, za.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])), 1) 551 | response = self.fa(x_fa).reshape(x_span.shape[0], self.nb_observed, -1) 552 | l = (response.norm(2, dim=2)) ** 2 553 | return l.mean(1) 554 | 555 | def constraint_traj(self, t_span, x, h=None): 556 | t_eval, x_hat = self.forward(t_span, x[:, :self.nb_observed], h) 557 | l_traj = (x[:, :self.nb_observed] - x_hat[:, :self.nb_observed]).view(x.shape[0], 558 | self.nb_observed, 559 | -1).norm(2, dim=2).mean(1) 560 | return l_traj, t_eval 561 | 562 | def lagrangian(self, lambda_p, t_span, x): 563 | h = self.enc(x[:, :self.nb_observed]) 564 | l_trajectory, t_eval = self.constraint_traj(t_span, x, h) 565 | l_penalty = self.penalty(t_eval, x, h).mean() ** .5 566 | return l_penalty + lambda_p * l_trajectory.mean(), l_penalty.detach(), l_trajectory.mean().detach() 567 | 568 | def loss(self, t_span, x) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 569 | return self.lagrangian(self.lambda_p, t_span, x) 570 | -------------------------------------------------------------------------------- /code/hybrid_models/HVAE.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
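# Hybrid VAE: amortized encoders q(z_a | x) and q(z_p | x, z_a) infer the neural and
# physical latent codes, and an AugmentedHybridDecoder integrates
#
#     dx/dt = f_p(x; z_p) + f_a(x; z_a)
#
# (alongside an f_p-only trajectory) to reconstruct the signal. The alpha / beta /
# gamma weights in HybridVAE balance the reconstruction and KL terms of the
# ELBO-style objective.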
3 | 
4 | import math
5 | 
6 | import torch.nn as nn
7 | from matplotlib import pyplot as plt
8 | 
9 | from code.simulators import PhysicalModel, NoSimulator
10 | import torch
11 | from code.nn import MLP, act_dict, ConditionalUNet, ConditionalUNetReactionDiffusion, Permute, kl_gaussians, UNet
12 | from torchdiffeq import odeint as odeint
13 | from code.hybrid_models.HybridAutoencoder import HybridAutoencoder
14 | 
15 | 
16 | class DynamicalPhysicalDecoder(nn.Module):
17 |     def __init__(self, fp: PhysicalModel):
18 |         super(DynamicalPhysicalDecoder, self).__init__()
19 |         self.fp = fp
20 |         self.z_p = None
21 | 
22 |     def to(self, device):
23 |         super(DynamicalPhysicalDecoder, self).to(device)
24 |         self.fp.to(device)
25 |         return self
26 | 
27 |     def set_latent(self, z_p):
28 |         # z_p: B_size x z_p_dim
29 |         self.z_p = z_p
30 | 
31 |     def get_physical_parameters(self, x):
32 |         fp_params = self.z_p
33 |         physical_params = {}
34 |         for i, p in enumerate(self.fp.incomplete_param_dim_textual):
35 |             physical_params[p] = fp_params[:, i].unsqueeze(1)
36 | 
37 |         return physical_params
38 | 
39 |     def forward_fp(self, t, x):
40 |         physical_params = self.get_physical_parameters(x)
41 |         return self.fp.parameterized_forward(t, x, **physical_params)
42 | 
43 |     def forward(self, t, x):
44 |         dx = self.forward_fp(t, x)
45 |         return dx.unsqueeze(1)
46 | 
47 | 
48 | class AugmentedHybridDecoder(nn.Module):
49 |     def __init__(self, fp: PhysicalModel, fa: nn.Module):
50 |         super(AugmentedHybridDecoder, self).__init__()
51 |         # Create net that maps hidden from encoder + state x to fa(x;h) contribution.
52 |         self.fa = fa
53 | 
54 |         self.fp = fp
55 | 
56 |         self.x_dim = fp._X_dim
57 | 
58 |         self.z_a = None
59 |         self.z_p = None
60 |         self.fp_params = None
61 | 
62 |     def to(self, device):
63 |         super(AugmentedHybridDecoder, self).to(device)
64 |         self.fp.to(device)
65 |         self.fa.to(device)
66 |         return self
67 | 
68 |     def set_latent(self, z_a, z_p):
69 |         # z_a: B_size x z_a_dim, z_p: B_size x z_p_dim
70 |         self.z_a = z_a
71 |         self.z_p = z_p
72 | 
73 |     def forward_fa(self, t, x):
74 |         if len(x.shape) == 4:
75 |             x_fa = torch.cat((x, self.z_a.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])), 1)
76 |             return self.fa(x_fa)
77 |         return self.fa(x, self.z_a)  # fa takes the latent code z_a as a conditioning input here.
78 | 
79 | 
80 |     def get_physical_parameters(self, x):
81 |         fp_params = self.z_p
82 |         physical_params = {}
83 |         for i, p in enumerate(self.fp.incomplete_param_dim_textual):
84 |             physical_params[p] = fp_params[:, i].unsqueeze(1)
85 | 
86 |         return physical_params
87 | 
88 |     def forward_fp(self, t, x):
89 |         physical_params = self.get_physical_parameters(x)
90 |         return self.fp.parameterized_forward(t, x, **physical_params)
91 | 
92 |     def forward(self, t, x):
93 |         if len(x.shape) == 4:
94 |             x, x_r = torch.chunk(x, 2, 1)
95 |             dx, dx_r = torch.chunk(self.forward_fp(t, torch.cat((x, x_r), 0)), 2, 0)
96 |             dx = torch.cat((self.forward_fa(t, x) + dx, dx_r), 2)
97 |             return dx
98 |         x, x_r = torch.chunk(x, 2, 1)
99 |         dx, dx_r = torch.chunk(self.forward_fp(t, torch.cat((x, x_r), 0)), 2, 0)
100 |         dx = torch.cat((self.forward_fa(t, x) + dx, dx_r), 1)
101 |         return dx
102 | 
103 | 
104 | class HybridVAE(HybridAutoencoder):
105 |     def predicted_parameters_as_dict(self, t_span, x, zero_param=False) -> dict:
106 |         x_obs = x[:, :self.nb_observed].view(-1, self.nb_observed * self.x_dim)
107 | 
108 |         mu_a, log_sigma_a = torch.chunk(self.ga(x_obs), 2, 1)
109 |         z_a = mu_a
110 | 
111 |         x_p = x_obs + self.gp_1(torch.cat((x_obs, z_a), 1))
112 |         mu_p, log_sigma_p = torch.chunk(self.gp_2(x_p), 2, 1)
| mu_p = self.act_mu_p(mu_p) 114 | 115 | fp_params = mu_p 116 | physical_params = {} 117 | for i, p in enumerate(self.fp.incomplete_param_dim_textual): 118 | physical_params[p] = fp_params[:, i].unsqueeze(1) 119 | for i, p in enumerate(self.fp.missing_param_dim_textual): 120 | physical_params[p] = torch.zeros(x.shape[0], device=x.device).unsqueeze(1) 121 | return physical_params 122 | 123 | def __init__(self, fp: PhysicalModel, device="cpu", **config): 124 | super(HybridVAE, self).__init__() 125 | self.device = device 126 | self.z_a_dim = 1 if "z_a_dim" not in config else config["z_a_dim"] 127 | self.nb_observed = 50 if "nb_observed" not in config else config["nb_observed"] 128 | self.x_dim, self.z_p_dim = fp._X_dim, len(fp.incomplete_param_dim_textual) 129 | factor = 1 130 | self.alpha = .01 * factor if "alpha" not in config else config["alpha"] 131 | self.beta = .001 * factor if "beta" not in config else config["beta"] 132 | self.gamma = .1 * factor if "gamma" not in config else config["gamma"] 133 | self.posterior_type = "dirac" if "posterior_type" not in config else config["posterior_type"] 134 | 135 | seq_dim = self.x_dim * self.nb_observed 136 | 137 | mu_prior_zp = [] 138 | sigma_prior_zp = [] 139 | min_zp = [] 140 | max_zp = [] 141 | for k, v in config["zp_priors"].items(): 142 | mu_prior_zp.append(v["mu"]) 143 | sigma_prior_zp.append(v["sigma"]) 144 | min_zp.append(v["min"]) 145 | max_zp.append(v["max"]) 146 | self.mu_prior_zp, self.sigma_prior_zp = torch.tensor(mu_prior_zp, device=device), torch.tensor(sigma_prior_zp, 147 | device=device) 148 | self.min_zp, self.max_zp = torch.tensor(min_zp, device=device), torch.tensor(max_zp, device=device) 149 | 150 | # Fp encoders 151 | gp_1_hidden = [128, 128] if "gp_1_hidden" not in config else config["gp_1_hidden"] 152 | gp_2_hidden = [128, 128, 256, 64, 32] if "gp_2_hidden" not in config else config["gp_2_hidden"] 153 | self.gp_1 = MLP([seq_dim + self.z_a_dim] + gp_1_hidden + [seq_dim], 0, "SELU", None).to(device) 154 | self.gp_2 = MLP([seq_dim] + gp_2_hidden + [self.z_p_dim * 2], 0, "SELU", None).to(device) 155 | self.act_mu_p = nn.Identity() if "act_mu_p" not in config else act_dict[config["act_mu_p"]] 156 | 157 | # Fa Encoder 158 | ga_hidden = [256, 256, 128, 32] if "ga_hidden" not in config else config["ga_hidden"] 159 | self.ga = MLP([seq_dim] + ga_hidden + [self.z_a_dim * 2], 0, "SELU", None).to(device) 160 | 161 | # Hybrid decoder 162 | fa_hidden = [64, 64] if "fa_hidden" not in config else config["fa_hidden"] 163 | param_decoder = {"fp": fp.to(device), 164 | "layers_fa": [self.z_a_dim + self.x_dim] + fa_hidden + [self.x_dim], 165 | "fa_hidden_act": "SELU", 166 | "fa_final_act": None} 167 | fa = MLP(param_decoder["layers_fa"], hidden_act=param_decoder["fa_hidden_act"], 168 | final_act=param_decoder["fa_final_act"]) 169 | self.dec = AugmentedHybridDecoder(fp=fp, fa=fa).to(device) 170 | self.sigma_x = nn.Parameter(torch.zeros(self.x_dim, device=device))  # calling .to() on a Parameter returns a plain tensor and would silently skip parameter registration 171 | 172 | # Fp only decoder 173 | self.fp = fp 174 | self.fp_only = DynamicalPhysicalDecoder(fp) 175 | 176 | def augmented_data(self, t_span, x): 177 | resampled_zp = torch.rand(x.shape[0], self.z_p_dim, device=self.device) * (self.max_zp - self.min_zp) \ 178 | + self.min_zp 179 | resampled_za = torch.randn(x.shape[0], self.z_a_dim, device=self.device) 180 | self.dec.set_latent(resampled_za, resampled_zp.repeat(2, 1)) 181 | 182 | mu_x_pred_all = odeint(self.dec, x[:, 0, :, 0, 0].repeat(1, 2), t_span).permute(1, 0, 2) 183 | 184 | mu_x_pred_tot, mu_x_pred_fp =
torch.chunk(mu_x_pred_all, 2, 2) 185 | 186 | return mu_x_pred_tot.unsqueeze(3).unsqueeze(3), resampled_zp 187 | 188 | def forward(self, t_span, x): 189 | x_obs = x[:, :self.nb_observed].view(-1, self.nb_observed * self.x_dim) 190 | 191 | mu_a, log_sigma_a = torch.chunk(self.ga(x_obs), 2, 1) 192 | z_a = mu_a 193 | 194 | x_p = x_obs + self.gp_1(torch.cat((x_obs, z_a), 1)) 195 | mu_p, log_sigma_p = torch.chunk(self.gp_2(x_p), 2, 1) 196 | mu_p = self.act_mu_p(mu_p) 197 | z_p = mu_p 198 | 199 | self.dec.set_latent(z_a, z_p.repeat(2, 1)) 200 | mu_x_pred_all = odeint(self.dec, x[:, 0, :, 0, 0].repeat(1, 2), t_span).permute(1, 0, 2) 201 | 202 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 2) 203 | return t_span, mu_x_pred_tot.unsqueeze(3).unsqueeze(3) 204 | 205 | def predicted_parameters(self, t_span, x): 206 | x_obs = x[:, :self.nb_observed].view(-1, self.nb_observed * self.x_dim) 207 | 208 | mu_a, log_sigma_a = torch.chunk(self.ga(x_obs), 2, 1) 209 | z_a = mu_a 210 | 211 | x_p = x_obs + self.gp_1(torch.cat((x_obs, z_a), 1)) 212 | mu_p, log_sigma_p = torch.chunk(self.gp_2(x_p), 2, 1) 213 | mu_p = self.act_mu_p(mu_p) 214 | 215 | fp_params = mu_p 216 | return fp_params 217 | 218 | def loss(self, t_span, x): 219 | return self.loss_augm(t_span, x, None) 220 | 221 | def loss_augm(self, t_span, x, zp): 222 | b_size = x.shape[0] 223 | x = x[:, :, :, 0, 0] 224 | 225 | x_obs = x[:, :self.nb_observed].view(-1, self.nb_observed * self.x_dim) 226 | 227 | mu_a, log_sigma_a = torch.chunk(self.ga(x_obs), 2, 1) 228 | sigma_a = torch.exp(log_sigma_a) 229 | z_a = mu_a + sigma_a * torch.randn_like(mu_a) 230 | 231 | x_p = x_obs + self.gp_1(torch.cat((x_obs, z_a), 1)) 232 | mu_p, log_sigma_p = torch.chunk(self.gp_2(x_p), 2, 1) 233 | mu_p = self.act_mu_p(mu_p) 234 | sigma_p = torch.exp(log_sigma_p) 235 | ll_zp_augm = torch.distributions.Normal(loc=mu_p, scale=sigma_p).log_prob(zp).sum(1) if zp is not None else 0. 236 | 237 | if self.posterior_type == "positive_gaussian": 238 | z_p = mu_p + sigma_p * torch.randn_like(mu_p) 239 | z_p[z_p <= 0.] = 1. 
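# non-positive draws are crudely reset to 1. so the physical parameters stay valid; this is a truncation heuristic rather than a true truncated-Gaussian sample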
240 | elif self.posterior_type == "dirac": 241 | z_p = mu_p 242 | else: 243 | raise NotImplementedError("The posterior type %s is not implemented" % self.posterior_type) 244 | 245 | self.dec.set_latent(z_a, z_p.repeat(2, 1)) 246 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 2), t_span[:self.nb_observed], method='dopri5', 247 | atol=1e-5, rtol=1e-5).permute(1, 0, 2) 248 | 249 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 2) 250 | sigma_x_pred = torch.exp(self.sigma_x.unsqueeze(0).unsqueeze(1).repeat(mu_x_pred_tot.shape[0], 251 | mu_x_pred_tot.shape[1], 1)) 252 | x_obs = x[:, :self.nb_observed] 253 | 254 | mse_traj = (x_obs - mu_x_pred_tot[:, :self.nb_observed]).norm(2, dim=2).mean(0).mean() 255 | 256 | 257 | ll = torch.distributions.Normal(loc=mu_x_pred_tot, scale=sigma_x_pred).log_prob(x_obs).sum(1).sum(1) 258 | 259 | ELBO = ll - kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a)) \ 260 | - kl_gaussians(mu_p, sigma_p, self.mu_prior_zp.unsqueeze(0).repeat(b_size, 1), 261 | self.sigma_prior_zp.unsqueeze(0).repeat(b_size, 1)) + ll_zp_augm 262 | 263 | bound_kl_physics_reg = kl_gaussians(mu_x_pred_tot, sigma_x_pred, mu_x_pred_fp, sigma_x_pred).sum(1) \ 264 | + kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a)) \ 265 | + kl_gaussians(mu_p, sigma_p, torch.ones_like(mu_p) * self.mu_prior_zp, 266 | torch.ones_like(sigma_p) * self.sigma_prior_zp) 267 | 268 | x_r_detached = mu_x_pred_fp.detach().requires_grad_(True) 269 | x_p = x_p.view(b_size, self.nb_observed, self.x_dim) 270 | R_da_1 = ((x_p - x_r_detached) ** 2).sum(1).sum(1) 271 | 272 | resampled_zp = torch.rand(x.shape[0], self.z_p_dim, device=self.device) * (self.max_zp - self.min_zp) \ 273 | + self.min_zp 274 | resampled_za = torch.randn(x.shape[0], self.z_a_dim, device=self.device) 275 | 276 | self.dec.set_latent(resampled_za, resampled_zp.repeat(2, 1)) 277 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 2), t_span[:self.nb_observed], method='dopri5', 278 | atol=1e-5, rtol=1e-5) 279 | 280 | mu_x_pred_all = mu_x_pred_all.squeeze(2).permute(1, 0, 2) 281 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 2) 282 | 283 | x_r_detached = mu_x_pred_fp.detach().requires_grad_(True)  # already batch-first after the permute above; a second permute here would scramble batch and time in the view below 284 | mu_p, log_sigma_p = torch.chunk(self.gp_2(x_r_detached.contiguous().view(b_size, -1)), 2, 1) 285 | R_da_2 = ((mu_p - resampled_zp) ** 2).sum(1) 286 | 287 | loss_tot = -ELBO + self.alpha * bound_kl_physics_reg + self.beta * R_da_1 + self.gamma * R_da_2 288 | return loss_tot.mean(), torch.tensor(-1.), mse_traj.detach() 289 | 290 | 291 | class HybridVAEReactionDiffusion(nn.Module): 292 | def __init__(self, fp: PhysicalModel, device="cpu", **config): 293 | super(HybridVAEReactionDiffusion, self).__init__() 294 | self.device = device 295 | self.z_a_dim = 10 if "z_a_dim" not in config else config["z_a_dim"] 296 | self.nb_observed = 10 if "nb_observed" not in config else config["nb_observed"] 297 | self.x_dim, self.z_p_dim = fp._X_dim, len(fp.incomplete_param_dim_textual) 298 | factor = 1 299 | self.alpha = .01 * factor if "alpha" not in config else config["alpha"] 300 | self.beta = .001 * factor if "beta" not in config else config["beta"] 301 | self.gamma = .1 * factor if "gamma" not in config else config["gamma"] 302 | 303 | mu_prior_zp = [] 304 | sigma_prior_zp = [] 305 | min_zp = [] 306 | max_zp = [] 307 | for k, v in config["zp_priors"].items(): 308 | mu_prior_zp.append(v["mu"]) 309 |
sigma_prior_zp.append(v["sigma"]) 310 | min_zp.append(v["min"]) 311 | max_zp.append(v["max"]) 312 | self.mu_prior_zp, self.sigma_prior_zp = torch.tensor(mu_prior_zp, device=device), torch.tensor(sigma_prior_zp, 313 | device=device) 314 | self.min_zp, self.max_zp = torch.tensor(min_zp, device=device), torch.tensor(max_zp, device=device) 315 | 316 | # Fp encoders 317 | self.gp_1 = ConditionalUNetReactionDiffusion(z_a_dim=self.z_a_dim).to(device) 318 | self.act_mu_p = act_dict["ReactionDiffusionParametersScaler"] 319 | 320 | # Hybrid decoder 321 | fa = nn.Sequential(nn.Conv2d(2 + self.z_a_dim, 16, 3, padding=1), nn.ReLU(), 322 | nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(), 323 | nn.Conv2d(16, 2, 3, padding=1)) 324 | self.dec = AugmentedHybridDecoder(fp=fp, fa=fa).to(device) 325 | self.sigma_x = torch.zeros(self.x_dim, requires_grad=False).to(device) # - 3. 326 | 327 | # Fp only decoder 328 | self.fp = fp 329 | self.fp_only = DynamicalPhysicalDecoder(fp) 330 | 331 | self.enc_common = nn.Sequential(nn.Flatten(0, 1), nn.Conv2d(2, 16, 3), nn.ReLU(), 332 | nn.Conv2d(16, 32, 3), nn.AvgPool2d(2), 333 | nn.Conv2d(32, 64, 3), nn.ReLU(), 334 | nn.Conv2d(64, 64, 3), nn.AvgPool2d(2), 335 | nn.Conv2d(64, 32, 3), nn.ReLU(), nn.Unflatten(0, (-1, self.nb_observed)), 336 | Permute((0, 2, 1, 3, 4)), nn.Conv3d(32, 16, 2), nn.ReLU(), 337 | nn.Conv3d(16, 16, 2), nn.Flatten(1, 4)) 338 | 339 | self.enc_za = nn.Sequential(self.enc_common, 340 | nn.Linear(128, 256), nn.ReLU(), 341 | nn.Linear(256, 256), nn.ReLU(), 342 | nn.Linear(256, 256), nn.ReLU(), 343 | nn.Linear(256, 2 * self.z_a_dim)) 344 | 345 | self.enc_zp = nn.Sequential(self.enc_common, 346 | nn.Linear(128, 256), nn.ReLU(), 347 | nn.Linear(256, 256), nn.ReLU(), 348 | nn.Linear(256, 256), nn.ReLU(), 349 | nn.Linear(256, 2 * self.z_p_dim)) 350 | 351 | def to(self, device): 352 | super(HybridVAEReactionDiffusion, self).to(device) 353 | self.sigma_x = self.sigma_x.to(device) 354 | self.device = device 355 | self.min_zp, self.max_zp = self.min_zp.to(device), self.max_zp.to(device) 356 | return self 357 | 358 | def augmented_data(self, t_span, x): 359 | b_size = x.shape[0] 360 | x_obs = x[:, :self.nb_observed] 361 | 362 | resampled_zp = torch.rand(b_size, self.z_p_dim, device=self.device) * (self.max_zp - self.min_zp) + self.min_zp 363 | resampled_za = torch.randn(b_size, self.z_a_dim, device=self.device) 364 | self.dec.set_latent(resampled_za, resampled_zp.repeat(2, 1)) 365 | 366 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 1, 2, 1), t_span, method='dopri5', atol=1e-5, rtol=1e-5) 367 | 368 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 3) 369 | mu_x_pred_tot = mu_x_pred_tot.permute(1, 0, 2, 3, 4).reshape(b_size, -1) 370 | return mu_x_pred_tot, resampled_zp 371 | 372 | def forward(self, t_span, x): 373 | b_size, im_size = x.shape[0], x.shape[-1] 374 | x_obs = x[:, :self.nb_observed] 375 | 376 | mu_a, log_sigma_a = torch.chunk(self.enc_za(x_obs), 2, 1) 377 | z_a = mu_a 378 | 379 | x_p = x_obs + self.gp_1(x_obs.reshape(-1, 2, im_size, im_size), 380 | z_a.unsqueeze(1).expand(-1, self.nb_observed, -1).reshape(-1, self.z_a_dim)).reshape( 381 | b_size, self.nb_observed, 2, im_size, im_size) 382 | 383 | mu_p, log_sigma_p = torch.chunk(self.enc_zp(x_p), 2, 1) 384 | mu_p = self.act_mu_p(mu_p) 385 | z_p = mu_p 386 | self.dec.set_latent(z_a, z_p.repeat(2, 1)) 387 | 388 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 1, 2, 1), t_span, method='dopri5', 389 | atol=1e-5, rtol=1e-5) 390 | 391 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 
2, 3) 392 | mu_x_pred_tot = mu_x_pred_tot.permute(1, 0, 2, 3, 4).reshape(b_size, -1) 393 | mu_x_pred_fp = mu_x_pred_fp.permute(1, 0, 2, 3, 4).reshape(b_size, -1) 394 | 395 | return t_span, mu_x_pred_tot, mu_x_pred_fp 396 | 397 | def predicted_parameters(self, t_span, x): 398 | b_size, im_size = x.shape[0], x.shape[-1] 399 | x_obs = x[:, :self.nb_observed] 400 | 401 | mu_a, log_sigma_a = torch.chunk(self.enc_za(x_obs), 2, 1) 402 | z_a = mu_a 403 | 404 | x_p = x_obs + self.gp_1(x_obs.reshape(-1, 2, im_size, im_size), 405 | z_a.unsqueeze(1).expand(-1, self.nb_observed, -1).reshape(-1, self.z_a_dim)).reshape( 406 | b_size, self.nb_observed, 2, im_size, im_size) 407 | 408 | mu_p, log_sigma_p = torch.chunk(self.enc_zp(x_p), 2, 1) 409 | mu_p = self.act_mu_p(mu_p) 410 | 411 | fp_params = mu_p 412 | physical_params = {} 413 | for i, p in enumerate(self.fp.incomplete_param_dim_textual): 414 | physical_params[p] = fp_params[:, i].unsqueeze(1) 415 | for i, p in enumerate(self.fp.missing_param_dim_textual): 416 | physical_params[p] = torch.zeros(x.shape[0], device=x.device).unsqueeze(1) 417 | return physical_params 418 | 419 | def loss_augm(self, t_span, x, zp): 420 | return self.loss(t_span, x, zp) 421 | 422 | def loss(self, t_span, x, zp=None): 423 | b_size, im_size = x.shape[0], x.shape[-1] 424 | 425 | x_obs = x[:, :self.nb_observed] 426 | 427 | mu_a, log_sigma_a = torch.chunk(self.enc_za(x_obs), 2, 1) 428 | 429 | sigma_a = torch.exp(log_sigma_a) 430 | z_a = mu_a + sigma_a * torch.randn_like(mu_a) 431 | 432 | delta = self.gp_1(x_obs.reshape(-1, 2, im_size, im_size), 433 | z_a.unsqueeze(1).expand(-1, self.nb_observed, -1).reshape(-1, self.z_a_dim)).reshape(b_size, 434 | self.nb_observed, 435 | 2, im_size, 436 | im_size) 437 | x_p = x_obs + delta 438 | mu_p, log_sigma_p = torch.chunk(self.enc_zp(x_p), 2, 1) 439 | mu_p = self.act_mu_p(mu_p) 440 | sigma_p = torch.exp(log_sigma_p) 441 | 442 | ll_zp_augm = torch.distributions.Normal(loc=mu_p, scale=sigma_p).log_prob(zp).sum(1) if zp is not None else 0. 
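# ll_zp_augm is the expert-augmentation term: when zp carries the resampled physical parameters that generated x, it rewards the encoder for recovering them (zp is None on real data, so the term vanishes)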
443 | 444 | z_p = mu_p 445 | self.dec.set_latent(z_a, z_p.repeat(2, 1)) 446 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 1, 2, 1), t_span[:self.nb_observed], method='dopri5', 447 | atol=1e-5, rtol=1e-5) 448 | 449 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 3) 450 | mu_x_pred_tot = mu_x_pred_tot.permute(1, 0, 2, 3, 4).reshape(b_size, -1) 451 | mu_x_pred_fp = mu_x_pred_fp.permute(1, 0, 2, 3, 4).reshape(b_size, -1) 452 | 453 | sigma_x_pred = torch.exp( 454 | self.sigma_x.unsqueeze(0).unsqueeze(0).repeat(b_size, self.nb_observed, 1, 1, 1).reshape(b_size, -1)) 455 | 456 | x_obs = x[:, :self.nb_observed] 457 | ll = torch.distributions.Normal(loc=mu_x_pred_tot, scale=sigma_x_pred).log_prob(x_obs.view(b_size, -1)).sum(1) 458 | ELBO = ll - kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a)) + ll_zp_augm\ 459 | - kl_gaussians(mu_p, sigma_p, self.mu_prior_zp.unsqueeze(0).repeat(b_size, 1), 460 | self.sigma_prior_zp.unsqueeze(0).repeat(b_size, 1)) 461 | 462 | bound_kl_physics_reg = kl_gaussians(mu_x_pred_tot, sigma_x_pred, 463 | mu_x_pred_fp, sigma_x_pred) \ 464 | + kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a)) \ 465 | + kl_gaussians(mu_p, sigma_p, torch.ones_like(mu_p) * self.mu_prior_zp, 466 | torch.ones_like(sigma_p) * self.sigma_prior_zp) 467 | 468 | x_r_detached = mu_x_pred_fp.detach().requires_grad_(True) 469 | x_p = x_p.view(b_size, -1) 470 | R_da_1 = ((x_p - x_r_detached) ** 2).sum(1) 471 | 472 | resampled_zp = torch.rand(b_size, self.z_p_dim, device=self.device) * (self.max_zp - self.min_zp) \ 473 | + self.min_zp 474 | resampled_za = torch.randn(b_size, self.z_a_dim, device=self.device) 475 | self.dec.set_latent(resampled_za, resampled_zp.repeat(2, 1)) 476 | mu_x_pred_all = odeint(self.dec, x[:, 0].repeat(1, 1, 2, 1), t_span[:self.nb_observed], method='dopri5', 477 | atol=1e-5, rtol=1e-5) 478 | 479 | 480 | mu_x_pred_tot, mu_x_pred_fp = torch.chunk(mu_x_pred_all, 2, 3) 481 | x_r_detached = mu_x_pred_fp.detach().permute(1, 0, 2, 3, 4).requires_grad_(True)  # batch-first, as enc_zp's Unflatten expects (cf. the reshape pattern above) 482 | mu_p, log_sigma_p = torch.chunk(self.enc_zp(x_r_detached), 2, 1) 483 | 484 | R_da_2 = ((mu_p - resampled_zp) ** 2).sum(1) 485 | 486 | loss_tot = -ELBO + self.alpha * bound_kl_physics_reg + self.beta * R_da_1 + self.gamma * R_da_2 487 | return loss_tot 488 | 489 | 490 | class DoublePendulumEncoder(nn.Module): 491 | def __init__(self, layers=[300, 300, 300], za_dim=50, ze_dim=4, initial_guess=True, 492 | nb_observed_theta_0=25, nb_observed_theta_1=25, obtain_init_position=False, 493 | **config): 494 | super(DoublePendulumEncoder, self).__init__() 495 | self.za_dim = za_dim 496 | self.ze_dim = ze_dim 497 | in_dim = 2 # cos and sin 498 | self.obtain_init_position = obtain_init_position 499 | self.initial_guess = initial_guess 500 | self.total_time = nb_observed_theta_0 + nb_observed_theta_1 501 | self.nb_observed_theta_0 = nb_observed_theta_0 502 | self.nb_observed_theta_1 = nb_observed_theta_1 503 | self.simple_encoder = config.get("simple_encoder", False) 504 | 505 | lis = [self.total_time*in_dim] + layers 506 | if self.simple_encoder: 507 | los = layers + [za_dim * 2 + ze_dim*2] 508 | else: 509 | los = layers + [za_dim*2] 510 | layers_nn = [] 511 | for li, lo in zip(lis, los): 512 | layers_nn += [nn.Linear(li, lo), nn.SELU()] 513 | layers_nn.pop() 514 | self.enc_za = nn.Sequential(*layers_nn) 515 | 516 | lis = [self.total_time * in_dim + za_dim] + layers 517 | los = layers + [self.total_time * in_dim] 518 | layers_nn = [] 519 | for li, lo in
zip(lis, los): 520 | layers_nn += [nn.Linear(li, lo), nn.SELU()] 521 | layers_nn.pop() 522 | if self.simple_encoder: 523 | self.clean_x = lambda x: x 524 | else: 525 | self.clean_x = nn.Sequential(*layers_nn) 526 | 527 | if not self.simple_encoder: 528 | lis = [self.total_time*in_dim] + layers 529 | los = layers + [ze_dim*2] 530 | layers_nn = [] 531 | for li, lo in zip(lis, los): 532 | layers_nn += [nn.Linear(li, lo), nn.SELU()] 533 | layers_nn.pop() 534 | self.enc_ze = nn.Sequential(*layers_nn) 535 | 536 | 537 | def forward(self, t, x):  # x: B x T x C x 1 x 1, with the two pendulum angles in channels 0 and 1 538 | 539 | 540 | b_size = x.shape[0] 541 | x_masked = torch.cat((x[:, -self.nb_observed_theta_0:, 0], x[:, -self.nb_observed_theta_1:, 1]), 1).reshape(b_size, -1) 542 | 543 | sin_cos_encoded = torch.cat((torch.sin(x_masked), torch.cos(x_masked)), 1) 544 | q_z_a = self.enc_za(sin_cos_encoded) 545 | if self.simple_encoder: 546 | q_z_e = q_z_a[:, :2*self.ze_dim] 547 | q_z_a = q_z_a[:, 2 * self.ze_dim:] 548 | mu_z_a, log_sigma_z_a = torch.chunk(q_z_a, 2, 1) 549 | mu_z_e, log_sigma_z_e = torch.chunk(q_z_e, 2, 1) 550 | z_e = mu_z_e + torch.randn_like(mu_z_e) * torch.exp(log_sigma_z_e) 551 | if self.obtain_init_position: 552 | z_e = torch.cat([x[:, 0, :2, 0, 0], z_e[:, 2:]], 1) 553 | z_a = mu_z_a + torch.randn_like(mu_z_a) * torch.exp(log_sigma_z_a) 554 | return z_e, z_a, q_z_e, q_z_a, sin_cos_encoded 555 | 556 | mu_z_a, log_sigma_z_a = torch.chunk(q_z_a, 2, 1) 557 | z_a = mu_z_a + torch.randn_like(mu_z_a) * torch.exp(log_sigma_z_a) 558 | 559 | x_clean = sin_cos_encoded + self.clean_x(torch.cat([sin_cos_encoded, z_a], 1)) 560 | 561 | q_z_e = self.enc_ze(x_clean) 562 | mu_z_e, log_sigma_z_e = torch.chunk(q_z_e, 2, 1) 563 | z_e = mu_z_e + torch.randn_like(mu_z_e) * torch.exp(log_sigma_z_e) 564 | 565 | if self.obtain_init_position: 566 | z_e = torch.cat([x[:, 0, :2, 0, 0], z_e[:, 2:]], 1) 567 | 568 | z_a = None if self.za_dim == 0 else z_a 569 | return z_e, z_a, q_z_e, q_z_a, x_clean 570 | 571 | 572 | class HybridVAEDoublePendulum(HybridAutoencoder): 573 | def __init__(self, fp: PhysicalModel, device="cpu", **config): 574 | super(HybridVAEDoublePendulum, self).__init__() 575 | self.device = device 576 | self.nb_observed = config.get("nb_observed", 25) 577 | self.nb_observed_theta_0 = config.get("nb_observed_theta_0", 25) 578 | self.nb_observed_theta_1 = config.get("nb_observed_theta_1", 25) 579 | self.za_dim = config.get("za_dim", 5) 580 | self.zp_dim = 4 581 | factor = 1 582 | self.alpha = config.get("alpha", .01 * factor) 583 | self.beta = config.get("beta", .001 * factor) 584 | self.gamma = config.get("gamma", .1 * factor) 585 | self.posterior_type = config.get("posterior_type", "dirac") 586 | self.zp_prior_type = config.get("zp_prior_type", "Normal") 587 | self.obtain_init_position = config.get("obtain_init_position", False) 588 | self.use_complete_signal = config.get("use_complete_signal", False) 589 | self.no_fa = config.get("no_fa", False) 590 | self.no_fp = config.get("no_fp", False) 591 | 592 | seq_dim_cos_sin = 2 * (self.nb_observed_theta_0 + self.nb_observed_theta_1) 593 | self.x_dim = 4 * self.nb_observed 594 | 595 | mu_prior_zp = [] 596 | sigma_prior_zp = [] 597 | min_zp = [] 598 | max_zp = [] 599 | for k, v in config["zp_priors"].items(): 600 | if self.zp_prior_type == "Normal": 601 | mu_prior_zp.append(v["mu"]) 602 | sigma_prior_zp.append(v["sigma"]) 603 | min_zp.append(v["min"]) 604 | max_zp.append(v["max"]) 605 | if self.zp_prior_type == "Normal": 606 | self.mu_prior_zp, self.sigma_prior_zp =
torch.tensor(mu_prior_zp, device=device), \ 607 | torch.tensor(sigma_prior_zp, device=device) 608 | self.min_zp, self.max_zp = torch.tensor(min_zp, device=device), torch.tensor(max_zp, device=device) 609 | 610 | # Fp encoders 611 | gp_1_hidden = [200, 200, 200] if "gp_1_hidden" not in config else config["gp_1_hidden"] 612 | gp_2_hidden = [200, 200, 200] if "gp_2_hidden" not in config else config["gp_2_hidden"] 613 | # We encode the angles as Cartesian (sin/cos) features rather than raw polar angles. 614 | self.gp_1 = MLP([seq_dim_cos_sin + self.za_dim] + gp_1_hidden + [seq_dim_cos_sin], 0, "SELU", None).to(device) 615 | self.gp_2 = MLP([seq_dim_cos_sin] + gp_2_hidden + [self.zp_dim * 2], 0, "SELU", None).to(device) 616 | 617 | def act_double_pendulum_parameters(mu_p): 618 | return torch.cat([math.pi*(2*torch.sigmoid(mu_p[:, :2]) - 1.), mu_p[:, 2:]], 1) 619 | 620 | self.act_mu_p = nn.Identity()  # note: act_double_pendulum_parameters above is defined but currently unused; mu_p is passed through unchanged 621 | 622 | # Fa Encoder 623 | ga_hidden = [300] * 3 if "ga_hidden" not in config else config["ga_hidden"] 624 | self.ga = MLP([seq_dim_cos_sin] + ga_hidden + [self.za_dim * 2], 0, "SELU", None).to(device) 625 | self.enc = DoublePendulumEncoder(layers=ga_hidden, ze_dim=self.zp_dim, 626 | **config).to(device) 627 | 628 | # Hybrid decoder 629 | fa_hidden = 3 * [300] if "fa_hidden" not in config else config["fa_hidden"] 630 | param_decoder = {"fp": fp.to(device), 631 | "layers_fa": [self.za_dim + 6] + fa_hidden + [4], 632 | "fa_hidden_act": "SELU", 633 | "fa_final_act": None} 634 | fa = MLP(param_decoder["layers_fa"], hidden_act=param_decoder["fa_hidden_act"], 635 | final_act=param_decoder["fa_final_act"]) 636 | self.fp = NoSimulator() if self.no_fp else fp 637 | self.fa = fa 638 | self.sigma_x_cos = nn.Parameter(torch.zeros(2, device=device))  # registered as learnable parameters; .to() on a Parameter would de-register them 639 | self.sigma_x_sin = nn.Parameter(torch.zeros(2, device=device)) 640 | 641 | self.param_ode_solver = {"method": "dopri5", 642 | "rtol": 1e-8, 643 | "atol": 1e-8 644 | } 645 | # Alternative fixed-step configuration, kept for reference: 646 | # {"method": "rk4", "rtol": 1e-5, "atol": 1e-5, "options": {"step_size": .0001}} 647 | 648 | 649 | 650 | 651 | 652 | # Fp only decoder 653 | 654 | self.fp_only = DynamicalPhysicalDecoder(fp) 655 | 656 | def ode_f(self, t, theta, z_a): 657 | theta_fp_fa, theta_fp_only = torch.chunk(theta, 2, 1) 658 | ode_fp_only = self.fp(t, theta_fp_only) 659 | if self.za_dim > 0 and not self.no_fa: 660 | ode_fp_fa = self.fp(t, theta_fp_fa) + self.fa(torch.cat((torch.sin(theta_fp_fa[:, :2]), 661 | torch.cos(theta_fp_fa[:, :2]), 662 | theta_fp_fa[:, 2:], z_a), 1)) 663 | 664 | else: 665 | ode_fp_fa = self.fp(t, theta_fp_fa) 666 | 667 | return torch.cat((ode_fp_fa, ode_fp_only), 1) 668 | 669 | def augmented_data(self, t_span, x): 670 | with torch.no_grad(): 671 | theta_0, z_a, q_z_e, q_z_a, x_clean = self.enc(t_span, x[:, :self.nb_observed]) 672 | 673 | mu_a, log_sigma_a = torch.chunk(q_z_a, 2, 1) 674 | z_a = mu_a 675 | resampled_theta_0 = torch.rand(x.shape[0], self.min_zp.shape[0], device=self.device) * \ 676 | (self.max_zp - self.min_zp) + self.min_zp 677 | 678 | dec = lambda t, theta: self.fp(t, theta) + self.fa(torch.cat((torch.sin(theta[:, :2]), 679 | torch.cos(theta[:, :2]), 680 | theta[:, 2:], z_a), 1)) 681 | x_pred = odeint(dec, resampled_theta_0, t_span, **self.param_ode_solver).permute(1, 0, 2)[:, :, :2].unsqueeze(3).unsqueeze(3) 682 | return x_pred, resampled_theta_0 683 | 684 | def norm_fa_from_sol(self, x_span, z_a):
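# Mean L2 norm of the learned correction fa along the decoded trajectory; diagnostic only (returned as the second element of loss_augm's output).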
685 | theta = x_span 686 | nb_observed = x_span.shape[1] if self.use_complete_signal else self.nb_observed 687 | if self.za_dim > 0 and not self.no_fa: 688 | norm_fa = self.fa(torch.cat((torch.sin(theta[:, :, :2, 0, 0]), 689 | torch.cos(theta[:, :, :2, 0, 0]), 690 | theta[:, :, 2:, 0, 0], 691 | z_a.unsqueeze(1).expand(-1, nb_observed, -1)), 2)).norm(2, dim=2).mean().item() 692 | else: 693 | norm_fa = 0. 694 | 695 | return norm_fa 696 | 697 | def constraint_traj_from_sol(self, t_eval, x, x_hat): 698 | x_hat = x_hat[:, :, :2] 699 | nb_observed = x.shape[1] if self.use_complete_signal else self.nb_observed 700 | diff_sin = (torch.sin(x[:, :nb_observed]) - torch.sin(x_hat[:, :nb_observed])) ** 2 701 | diff_cos = (torch.cos(x[:, :nb_observed]) - torch.cos(x_hat[:, :nb_observed])) ** 2 702 | l_traj = (diff_sin + diff_cos).mean(2).mean(1) 703 | return l_traj, t_eval 704 | 705 | def forward(self, t_span, x): 706 | z_e, z_a, _, _, _ = self.enc(t_span, x[:, :self.nb_observed]) 707 | ode_f = lambda t, theta: self.ode_f(t, theta, z_a) 708 | x_pred_hybrid = odeint(ode_f, z_e.repeat(1, 2), t_span, **self.param_ode_solver).permute(1, 0, 2)[:, :, :2].unsqueeze(3).unsqueeze(3) 709 | return t_span, x_pred_hybrid 710 | 711 | def detailed_forward(self, t_span, x): 712 | z_e, z_a, q_z_e, q_z_a, x_clean = self.enc(t_span, x[:, :self.nb_observed]) 713 | ode_f = lambda t, theta: self.ode_f(t, theta, z_a) 714 | mu_x_pred_tot = odeint(ode_f, z_e.repeat(1, 2), t_span, **self.param_ode_solver) 715 | x_pred_hybrid, x_pred_fp = torch.chunk(mu_x_pred_tot, 2, 2) 716 | x_pred_hybrid = x_pred_hybrid.permute(1, 0, 2)[:, :, :].unsqueeze(3).unsqueeze(3) 717 | x_pred_fp = x_pred_fp.permute(1, 0, 2)[:, :, :].unsqueeze(3).unsqueeze(3) 718 | return t_span, x_pred_hybrid, x_pred_fp, z_e, z_a, q_z_e, q_z_a, x_clean 719 | 720 | def predicted_parameters(self, t_span, x, zero_param=False): 721 | z_e, z_a, q_z_e, q_z_a, x_clean = self.enc(t_span, x[:, :self.nb_observed]) 722 | mu_z_e, log_sigma_z_e = torch.chunk(q_z_e, 2, 1) 723 | 724 | return mu_z_e 725 | 726 | def predicted_parameters_as_dict(self, t_span, x, zero_param=False): 727 | mu_p = self.predicted_parameters(t_span, x, zero_param) 728 | return {"\\theta_0": mu_p[:, [0]], "\\theta_1": mu_p[:, [1]], 729 | "\\dot \\theta_0": mu_p[:, [2]], "\\dot \\theta_1": mu_p[:, [3]]} 730 | 731 | def loss(self, t_span, x): 732 | return self.loss_augm(t_span, x, None) 733 | 734 | def loss_augm(self, t_span, x, zp): 735 | b_size = x.shape[0] 736 | nb_steps = t_span.shape[0] 737 | 738 | nb_observed = x.shape[1] if self.use_complete_signal else self.nb_observed 739 | t_span, x_pred_hybrid, x_pred_fp, z_e, z_a, q_z_e, q_z_a, x_clean = self.detailed_forward(t_span, x[:, :nb_observed]) 740 | mu_a, log_sigma_a = torch.chunk(q_z_a, 2, 1) 741 | sigma_a = torch.exp(log_sigma_a) 742 | x_pred_hybrid_all = x_pred_hybrid 743 | x_pred_hybrid = x_pred_hybrid[:, :, :2] 744 | x_pred_fp = x_pred_fp[:, :, :2] 745 | 746 | sigma_x_pred_tot_sin = torch.exp(self.sigma_x_sin).unsqueeze(0).unsqueeze(1).expand(b_size, nb_steps, 747 | -1).unsqueeze( 748 | 3).unsqueeze(3) 749 | sigma_x_pred_tot_cos = torch.exp(self.sigma_x_cos).unsqueeze(0).unsqueeze(1).expand(b_size, nb_steps, 750 | -1).unsqueeze( 751 | 3).unsqueeze(3) 752 | 753 | norm_fa = self.norm_fa_from_sol(x_pred_hybrid_all, mu_a) 754 | 755 | diff_sin = torch.distributions.Normal(loc=torch.sin(x_pred_hybrid[:, :nb_observed]), 756 | scale=sigma_x_pred_tot_sin).log_prob(torch.sin(x[:, :nb_observed])).sum(1).sum(1) 757 | diff_cos = 
torch.distributions.Normal(loc=torch.cos(x_pred_hybrid[:, :nb_observed]), 758 | scale=sigma_x_pred_tot_cos).log_prob(torch.cos(x[:, :nb_observed])).sum(1).sum(1) 759 | ll_traj = diff_sin + diff_cos 760 | 761 | KL_prior_posterior = kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a)) 762 | 763 | # No term for z_p as we assume a uniform prior. 764 | ELBO = ll_traj - KL_prior_posterior 765 | 766 | # Regularizer for making the correction small. 767 | if self.alpha > 0.: 768 | bound_kl_physics_reg = (kl_gaussians(torch.flatten(torch.sin(x_pred_hybrid[:, :nb_observed]), 1), 769 | torch.flatten(sigma_x_pred_tot_sin, 1), 770 | torch.flatten(torch.sin(x_pred_fp[:, :nb_observed]), 1), 771 | torch.flatten(sigma_x_pred_tot_sin, 1)) \ 772 | + kl_gaussians(torch.flatten(torch.cos(x_pred_hybrid[:, :nb_observed]), 1), 773 | torch.flatten(sigma_x_pred_tot_cos, 1), 774 | torch.flatten(torch.cos(x_pred_fp[:, :nb_observed]), 1), 775 | torch.flatten(sigma_x_pred_tot_cos, 1)) \ 776 | + kl_gaussians(mu_a, sigma_a, torch.zeros_like(mu_a), torch.ones_like(sigma_a))).mean() 777 | else: 778 | bound_kl_physics_reg = 0. 779 | 780 | 781 | # Trying to make the two-step encoder as much related to the physics as possible 782 | if self.beta > 0.: 783 | x_pred_fp_detached = x_pred_fp.detach() 784 | 785 | x_pred_fp_detached_masked = torch.cat((x_pred_fp_detached[:, -self.nb_observed_theta_0:, 0], 786 | x_pred_fp_detached[:, -self.nb_observed_theta_1:, 1]), 1).reshape(b_size, -1) 787 | 788 | x_pred_fp_detached_masked_sin_cos_encoded = torch.cat((torch.sin(x_pred_fp_detached_masked), torch.cos(x_pred_fp_detached_masked)), 1) 789 | R_da_1 = ((x_clean - x_pred_fp_detached_masked_sin_cos_encoded) ** 2).sum(1).mean() 790 | else: 791 | R_da_1 = 0. 792 | 793 | # Forcing the second step to find the correct latent variables 794 | if self.gamma > 0.: 795 | resampled_zp = torch.rand(b_size, self.zp_dim, device=self.device) * (self.max_zp - self.min_zp) \ 796 | + self.min_zp 797 | 798 | mu_x_pred_fp = odeint(self.fp, resampled_zp, t_span[:self.nb_observed], **self.param_ode_solver) 799 | 800 | x_r_detached = mu_x_pred_fp.detach().requires_grad_(True).permute(1, 0, 2)[:, :, :2].unsqueeze(3).unsqueeze(3) 801 | x_r_detached_masked = torch.cat((x_r_detached[:, -self.nb_observed_theta_0:, 0], 802 | x_r_detached[:, -self.nb_observed_theta_1:, 1]), 1).reshape(b_size, -1) 803 | 804 | x_r_detached_masked_sin_cos_encoded = torch.cat((torch.sin(x_r_detached_masked), 805 | torch.cos(x_r_detached_masked)), 1) 806 | mu_p, log_sigma_p = torch.chunk(self.enc.enc_ze(x_r_detached_masked_sin_cos_encoded), 2, 1) 807 | 808 | R_da_2 = ((mu_p - resampled_zp) ** 2).sum(1).mean() 809 | else: 810 | R_da_2 = 0. 811 | 812 | loss = -ELBO.mean() + self.alpha * bound_kl_physics_reg + self.beta * R_da_1 + self.gamma * R_da_2 813 | return loss, torch.tensor(norm_fa), torch.tensor(-1.) 814 | 815 | --------------------------------------------------------------------------------
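For orientation, the block below is a minimal, hypothetical sketch (not part of the repository) of the two training signals the classes above expose: a standard ELBO step on observed trajectories via loss(), and an expert-augmentation step via augmented_data() followed by loss_augm(). It assumes the repository root is on PYTHONPATH, that code.simulators exports a DampedPendulum simulator implementing the PhysicalModel interface used above, and that the zp_priors key ("omega_0") matches that simulator's incomplete_param_dim_textual; all names, shapes, and values here are illustrative, not the repository's actual configuration.

import torch
from code.simulators import DampedPendulum  # assumed export; any PhysicalModel works
from code.hybrid_models.HVAE import HybridVAE

config = {"nb_observed": 50, "z_a_dim": 1,
          "zp_priors": {"omega_0": {"mu": 1., "sigma": .5, "min": .5, "max": 2.}}}
fp = DampedPendulum()  # incomplete physical model (illustrative constructor)
model = HybridVAE(fp, device="cpu", **config)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

t_span = torch.linspace(0., 5., 50)
x = torch.randn(16, 50, fp._X_dim, 1, 1)  # stand-in for a batch of observed trajectories

# ELBO step on observed data; loss() is loss_augm() with zp=None.
loss, _, mse_traj = model.loss(t_span, x)
opt.zero_grad()
loss.backward()
opt.step()

# Expert augmentation: decode fresh trajectories from resampled (z_p, z_a),
# then train the encoder to recover the resampled physical parameters.
with torch.no_grad():
    x_aug, zp_aug = model.augmented_data(t_span, x)
loss_aug, _, _ = model.loss_augm(t_span, x_aug, zp_aug)
opt.zero_grad()
loss_aug.backward()
opt.step()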