├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE ├── README.md ├── doc └── fae.png ├── experiments ├── configs │ ├── config_cnn.yaml │ ├── config_dirac_fae.yaml │ ├── config_dirac_vano.yaml │ ├── config_fae.yaml │ ├── config_fae_timing.yaml │ ├── config_sde1d.yaml │ └── config_sde2d.yaml ├── custom_decoders.py ├── custom_encoders.py ├── exp_baseline_comparisons │ ├── main.py │ └── visualizations.ipynb ├── exp_dirac │ ├── main.py │ └── visualizations.ipynb ├── exp_rec_mse_vs_downsample_ratio │ ├── main.py │ └── visualizations.ipynb ├── exp_rec_mse_vs_point_ratio │ ├── main.py │ ├── plots_rec_mse_vs_point_ratio.py │ └── visualizations.ipynb ├── exp_sde1d │ ├── main.py │ ├── plots_sde1d.py │ └── visualizations.ipynb ├── exp_sde2d │ ├── main.py │ ├── plots_sde2d.py │ └── visualizations.ipynb ├── exp_sparse_training │ ├── main.py │ ├── vis_latent_interpolation.ipynb │ ├── vis_patch_superresolution.ipynb │ ├── vis_reconstruction.ipynb │ ├── vis_samples_generation.ipynb │ └── vis_very_high_resolution.ipynb ├── exp_sparse_vs_dense_wall_clock_training │ ├── main.py │ └── visualizations.ipynb ├── exp_train_vs_inference_wall_clock │ ├── main.py │ └── visualizations.ipynb ├── main_run.py ├── plots.py ├── trainer_loader.py └── util_exp.py ├── pyproject.toml ├── quickstart ├── 0_Getting_Started.ipynb ├── 1_FVAE.ipynb ├── 2_FAE.ipynb ├── 3_Custom_Datasets.ipynb ├── 4_Custom_Architectures.ipynb └── images │ ├── FAE_inpainting_superresolution.png │ ├── FAE_self-supervised_training.png │ ├── FVAE_decoder.png │ ├── FVAE_encoder.png │ ├── algorithms_on_fn_space.png │ ├── fae_or_fvae_flowchart.png │ └── fixed-res_algorithms.png ├── requirements.txt ├── src └── functional_autoencoders │ ├── __init__.py │ ├── autoencoder │ └── __init__.py │ ├── datasets │ ├── __init__.py │ ├── darcy_flow.py │ ├── dirac.py │ ├── navier_stokes.py │ ├── sde.py │ └── vano.py │ ├── decoders │ ├── __init__.py │ ├── cnn_decoder.py │ ├── fno_decoder.py │ ├── linear_decoder.py │ └── 
nonlinear_decoder.py │ ├── domains │ ├── __init__.py │ ├── grid.py │ └── off_grid.py │ ├── encoders │ ├── __init__.py │ ├── cnn_encoder.py │ ├── fno_encoder.py │ ├── lno_encoder.py │ ├── mlp_encoder.py │ └── pooling_encoder.py │ ├── losses │ ├── __init__.py │ ├── fae.py │ ├── fvae_sde.py │ └── vano.py │ ├── positional_encodings │ └── __init__.py │ ├── samplers │ ├── __init__.py │ ├── sampler_gmm.py │ └── sampler_vae.py │ ├── train │ ├── __init__.py │ ├── autoencoder_trainer.py │ └── metrics.py │ └── util │ ├── __init__.py │ ├── anti_aliasing.py │ ├── fft.py │ ├── masks.py │ ├── networks │ ├── __init__.py │ ├── fno.py │ ├── lno.py │ └── pooling.py │ ├── pca.py │ └── random │ ├── __init__.py │ ├── grf.py │ └── sde.py └── tests ├── __init__.py ├── dst.py └── sobolev.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.11"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | - name: Run unit tests 31 | run: | 32 | cd tests 33 | python dst.py 34 | python sobolev.py 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 
__pycache__ 2 | data 3 | test*.ipynb 4 | *.sqlite3 5 | *.db 6 | *.out 7 | runs 8 | ckpts 9 | tmp/ 10 | figures/ 11 | todo.txt 12 | *.png 13 | !doc/*.png 14 | !quickstart/images/*.png 15 | *.pdf 16 | experiments/scripts/dirac/*results*.yaml 17 | *.pkl 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Justin Bunker, Mark Girolami, Hefin Lambley, Andrew M. Stuart, and T. J. Sullivan 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `functional_autoencoders`: Autoencoders in Function Space 2 | 3 | ![FAE](doc/fae.png) 4 | 5 | This is the official code repository accompanying the paper: 6 | 7 | > **Autoencoders in Function Space** 8 | > 9 | > Justin Bunker, Mark Girolami, Hefin Lambley, Andrew M. Stuart and T. J. Sullivan (2024). 10 | > 11 | > [arXiv:2408.01362](https://arxiv.org/abs/2408.01362). 12 | 13 | The `functional_autoencoders` module contains implementations of 14 | 1. **Functional Variational Autoencoder (FVAE)**, an extension of variational autoencoders (VAEs) to functional data; and 15 | 2. **Functional Autoencoder (FAE)**, a regularised nonprobabilistic autoencoder for functional data. 16 | 17 | 18 | ## Quickstart 19 | 20 | **If you want to install the `functional_autoencoders` package** (e.g., to use in your own projects and notebooks): clone the repository and install the package using `pip` with 21 | 22 | git clone https://github.com/htlambley/functional_autoencoders 23 | cd functional_autoencoders 24 | pip install . 25 | 26 | 27 | You can then import the `functional_autoencoders` package in your own scripts and notebooks. 28 | To get started, why not follow one of our quickstart notebooks: 29 | head to the [getting started notebook](./quickstart/0_Getting_Started.ipynb) (`quickstart/0_Geting_Started.ipynb`) for an introduction to `functional_autoencoders`, and a guide on when to use each model. 30 | Alternatively, you can go straight to 31 | 32 | - [an introduction to FVAE](./quickstart/1_FVAE.ipynb) (`quickstart/1_FVAE.ipynb`); or 33 | - [an introduction to FAE](./quickstart/2_FAE.ipynb) (`quickstart/2_FAE.ipynb`) 34 | 35 | depending on your interests, where you'll learn how to reproduce some of the results in [the paper](https://arxiv.org/abs/2408.01362). 
36 | You can also learn how to use FVAE and FAE with [your own data](./quickstart/3_Custom_Datasets.ipynb) (`quickstart/3_Custom_Datasets.ipynb`) and [custom encoder/decoder architectures](./quickstart/4_Custom_Architectures.ipynb) (`quickstart/4_Custom_Architectures.ipynb`). 37 | 38 | ## Reproducing results in the paper 39 | 40 | If you want to reproduce the results from the paper without installing the `functional_autoencoders` package: 41 | clone the repository, install the dependencies, and run the main experimental script with 42 | 43 | git clone https://github.com/htlambley/functional_autoencoders 44 | cd functional_autoencoders 45 | pip install -r requirements.txt 46 | python experiments/main_run.py 47 | 48 | 49 | ## Citation 50 | 51 | You can cite the preprint with the following BibTeX/BibLaTeX entry: 52 | 53 | @misc{BunkerGirolamiLambleyStuartSullivan2024, 54 | author = {Bunker, Justin and Girolami, Mark and Lambley, Hefin and Stuart, Andrew M. and Sullivan, T. J.}, 55 | title = {Autoencoders in Function Space}, 56 | note = {arXiv Preprint arXiv:2408.01362} 57 | } 58 | 59 | Questions, comments, and suggestions for the code repository are welcome through the issue tracker on GitHub or via email to 60 | - [Justin Bunker](https://www.eng.cam.ac.uk/profiles/jb2200) (`jb2200@cantab.ac.uk`); and 61 | - [Hefin Lambley](https://warwick.ac.uk/htlambley) (`hefin.lambley@warwick.ac.uk`). 
-------------------------------------------------------------------------------- /doc/fae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/doc/fae.png -------------------------------------------------------------------------------- /experiments/configs/config_cnn.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | latent_dim: 64 3 | is_variational: false 4 | type: cnn 5 | options: 6 | pooling: 7 | mlp_dim: 64 8 | mlp_n_hidden_layers: 3 9 | dirac: 10 | features: 11 | - 128 12 | - 128 13 | - 128 14 | cnn: 15 | cnn_features: 16 | - 4 17 | - 4 18 | - 8 19 | - 16 20 | kernel_sizes: 21 | - 2 22 | - 2 23 | - 4 24 | - 4 25 | strides: 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | mlp_features: 31 | - 64 32 | 33 | decoder: 34 | type: cnn 35 | options: 36 | nonlinear: 37 | out_dim: 1 38 | features: 39 | - 100 40 | - 100 41 | - 100 42 | - 100 43 | - 100 44 | dirac: 45 | features: 46 | - 128 47 | - 128 48 | - 128 49 | cnn: 50 | trans_cnn_features: 51 | - 16 52 | - 8 53 | - 4 54 | - 4 55 | kernel_sizes: 56 | - 4 57 | - 4 58 | - 2 59 | - 2 60 | strides: 61 | - 2 62 | - 2 63 | - 2 64 | - 2 65 | mlp_features: 66 | - 64 67 | final_cnn_features: 68 | - 8 69 | - 1 70 | final_kernel_sizes: 71 | - 3 72 | - 3 73 | final_strides: 74 | - 1 75 | - 1 76 | c_in: 32 77 | grid_pts_in: 4 78 | 79 | domain: 80 | type: off_grid_randomly_sampled_euclidean 81 | options: 82 | grid_zero_boundary_conditions: 83 | s: 0 84 | off_grid_randomly_sampled_euclidean: 85 | s: 0 86 | off_grid_sde: 87 | ~ 88 | 89 | loss: 90 | type: fae 91 | options: 92 | fae: 93 | beta: 0.001 94 | subtract_data_norm: false 95 | vano: 96 | beta: 0.001 97 | n_monte_carlo_samples: 4 98 | normalised_inner_prod: true 99 | rescale_by_norm: false 100 | fvae_sde: 101 | beta: 1 102 | theta: 0 103 | zero_penalty: 10 104 | n_monte_carlo_samples: 4 105 
| 106 | positional_encoding: 107 | is_used: False 108 | dim: 32 109 | 110 | trainer: 111 | max_step: 50_000 112 | lr: 0.001 113 | lr_decay_step: 1000 114 | lr_decay_factor: 0.98 115 | eval_interval: 10 116 | -------------------------------------------------------------------------------- /experiments/configs/config_dirac_fae.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | latent_dim: 1 3 | is_variational: false 4 | type: dirac 5 | options: 6 | pooling: 7 | mlp_dim: 64 8 | mlp_n_hidden_layers: 3 9 | dirac: 10 | features: 11 | - 128 12 | - 128 13 | - 128 14 | cnn: 15 | cnn_features: 16 | - 4 17 | - 4 18 | - 8 19 | - 16 20 | kernel_sizes: 21 | - 2 22 | - 2 23 | - 4 24 | - 4 25 | strides: 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | mlp_features: 31 | - 64 32 | 33 | decoder: 34 | type: dirac 35 | options: 36 | nonlinear: 37 | out_dim: 1 38 | features: 39 | - 100 40 | - 100 41 | - 100 42 | - 100 43 | - 100 44 | dirac: 45 | features: 46 | - 128 47 | - 128 48 | - 128 49 | cnn: 50 | trans_cnn_features: 51 | - 16 52 | - 8 53 | - 4 54 | - 4 55 | kernel_sizes: 56 | - 4 57 | - 4 58 | - 2 59 | - 2 60 | strides: 61 | - 2 62 | - 2 63 | - 2 64 | - 2 65 | mlp_features: 66 | - 64 67 | final_cnn_features: 68 | - 8 69 | - 1 70 | final_kernel_sizes: 71 | - 3 72 | - 3 73 | final_strides: 74 | - 1 75 | - 1 76 | c_in: 32 77 | grid_pts_in: 4 78 | 79 | domain: 80 | type: grid_zero_boundary_conditions 81 | options: 82 | grid_zero_boundary_conditions: 83 | s: -1 84 | off_grid_randomly_sampled_euclidean: 85 | s: 0 86 | off_grid_sde: 87 | ~ 88 | 89 | loss: 90 | type: fae 91 | options: 92 | fae: 93 | beta: 0.000_000_000_001 94 | subtract_data_norm: true 95 | vano: 96 | beta: 0.0001 97 | n_monte_carlo_samples: 4 98 | normalised_inner_prod: true 99 | rescale_by_norm: false 100 | fvae_sde: 101 | beta: 1 102 | theta: 0 103 | zero_penalty: 10 104 | n_monte_carlo_samples: 16 105 | 106 | positional_encoding: 107 | is_used: false 108 | dim: 32 109 | 
110 | trainer: 111 | max_step: 30_000 112 | lr: 0.0001 113 | lr_decay_step: 1000 114 | lr_decay_factor: 0.9 115 | eval_interval: 10 116 | -------------------------------------------------------------------------------- /experiments/configs/config_dirac_vano.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | latent_dim: 1 3 | is_variational: True 4 | type: dirac 5 | options: 6 | pooling: 7 | mlp_dim: 64 8 | mlp_n_hidden_layers: 3 9 | dirac: 10 | features: 11 | - 128 12 | - 128 13 | - 128 14 | cnn: 15 | cnn_features: 16 | - 4 17 | - 4 18 | - 8 19 | - 16 20 | kernel_sizes: 21 | - 2 22 | - 2 23 | - 4 24 | - 4 25 | strides: 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | mlp_features: 31 | - 64 32 | 33 | decoder: 34 | type: dirac 35 | options: 36 | nonlinear: 37 | out_dim: 1 38 | features: 39 | - 100 40 | - 100 41 | - 100 42 | - 100 43 | - 100 44 | dirac: 45 | features: 46 | - 128 47 | - 128 48 | - 128 49 | cnn: 50 | trans_cnn_features: 51 | - 16 52 | - 8 53 | - 4 54 | - 4 55 | kernel_sizes: 56 | - 4 57 | - 4 58 | - 2 59 | - 2 60 | strides: 61 | - 2 62 | - 2 63 | - 2 64 | - 2 65 | mlp_features: 66 | - 64 67 | final_cnn_features: 68 | - 8 69 | - 1 70 | final_kernel_sizes: 71 | - 3 72 | - 3 73 | final_strides: 74 | - 1 75 | - 1 76 | c_in: 32 77 | grid_pts_in: 4 78 | 79 | domain: 80 | type: grid_zero_boundary_conditions 81 | options: 82 | grid_zero_boundary_conditions: 83 | s: -1 84 | off_grid_randomly_sampled_euclidean: 85 | s: 0 86 | off_grid_sde: 87 | ~ 88 | 89 | loss: 90 | type: vano 91 | options: 92 | fae: 93 | beta: 0.001 94 | subtract_data_norm: true 95 | vano: 96 | beta: 0.0001 97 | n_monte_carlo_samples: 4 98 | normalised_inner_prod: true 99 | rescale_by_norm: false 100 | fvae_sde: 101 | beta: 1 102 | theta: 0 103 | zero_penalty: 10 104 | n_monte_carlo_samples: 16 105 | 106 | positional_encoding: 107 | is_used: false 108 | dim: 32 109 | 110 | trainer: 111 | max_step: 30_000 112 | lr: 0.0001 113 | lr_decay_step: 1000 
114 | lr_decay_factor: 0.7 115 | eval_interval: 10 116 | -------------------------------------------------------------------------------- /experiments/configs/config_fae.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | latent_dim: 64 3 | is_variational: false 4 | type: pooling 5 | options: 6 | pooling: 7 | mlp_dim: 64 8 | mlp_n_hidden_layers: 3 9 | dirac: 10 | features: 11 | - 128 12 | - 128 13 | - 128 14 | cnn: 15 | cnn_features: 16 | - 4 17 | - 4 18 | - 8 19 | - 16 20 | kernel_sizes: 21 | - 2 22 | - 2 23 | - 4 24 | - 4 25 | strides: 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | mlp_features: 31 | - 64 32 | 33 | decoder: 34 | type: nonlinear 35 | options: 36 | nonlinear: 37 | out_dim: 1 38 | features: 39 | - 100 40 | - 100 41 | - 100 42 | - 100 43 | - 100 44 | dirac: 45 | features: 46 | - 128 47 | - 128 48 | - 128 49 | cnn: 50 | trans_cnn_features: 51 | - 16 52 | - 8 53 | - 4 54 | - 4 55 | kernel_sizes: 56 | - 4 57 | - 4 58 | - 2 59 | - 2 60 | strides: 61 | - 2 62 | - 2 63 | - 2 64 | - 2 65 | mlp_features: 66 | - 64 67 | final_cnn_features: 68 | - 8 69 | - 1 70 | final_kernel_sizes: 71 | - 3 72 | - 3 73 | final_strides: 74 | - 1 75 | - 1 76 | c_in: 32 77 | grid_pts_in: 4 78 | 79 | domain: 80 | type: off_grid_randomly_sampled_euclidean 81 | options: 82 | grid_zero_boundary_conditions: 83 | s: 0 84 | off_grid_randomly_sampled_euclidean: 85 | s: 0 86 | off_grid_sde: 87 | ~ 88 | 89 | loss: 90 | type: fae 91 | options: 92 | fae: 93 | beta: 0.001 94 | subtract_data_norm: false 95 | vano: 96 | beta: 0.001 97 | n_monte_carlo_samples: 4 98 | normalised_inner_prod: true 99 | rescale_by_norm: false 100 | fvae_sde: 101 | beta: 1 102 | theta: 0 103 | zero_penalty: 10 104 | n_monte_carlo_samples: 4 105 | 106 | positional_encoding: 107 | is_used: true 108 | dim: 32 109 | 110 | trainer: 111 | max_step: 50_000 112 | lr: 0.001 113 | lr_decay_step: 1000 114 | lr_decay_factor: 0.98 115 | eval_interval: 10 116 | 
-------------------------------------------------------------------------------- /experiments/configs/config_fae_timing.yaml: -------------------------------------------------------------------------------- 1 | encoder: 2 | latent_dim: 64 3 | is_variational: false 4 | type: pooling 5 | options: 6 | pooling: 7 | mlp_dim: 64 8 | mlp_n_hidden_layers: 3 9 | dirac: 10 | features: 11 | - 128 12 | - 128 13 | - 128 14 | cnn: 15 | cnn_features: 16 | - 4 17 | - 4 18 | - 8 19 | - 16 20 | kernel_sizes: 21 | - 2 22 | - 2 23 | - 4 24 | - 4 25 | strides: 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | mlp_features: 31 | - 64 32 | 33 | decoder: 34 | type: nonlinear 35 | options: 36 | nonlinear: 37 | out_dim: 1 38 | features: 39 | - 100 40 | - 100 41 | - 100 42 | - 100 43 | - 100 44 | dirac: 45 | features: 46 | - 128 47 | - 128 48 | - 128 49 | cnn: 50 | trans_cnn_features: 51 | - 16 52 | - 8 53 | - 4 54 | - 4 55 | kernel_sizes: 56 | - 4 57 | - 4 58 | - 2 59 | - 2 60 | strides: 61 | - 2 62 | - 2 63 | - 2 64 | - 2 65 | mlp_features: 66 | - 64 67 | final_cnn_features: 68 | - 8 69 | - 1 70 | final_kernel_sizes: 71 | - 3 72 | - 3 73 | final_strides: 74 | - 1 75 | - 1 76 | c_in: 32 77 | grid_pts_in: 4 78 | 79 | domain: 80 | type: off_grid_randomly_sampled_euclidean 81 | options: 82 | grid_zero_boundary_conditions: 83 | s: 0 84 | off_grid_randomly_sampled_euclidean: 85 | s: 0 86 | off_grid_sde: 87 | ~ 88 | 89 | loss: 90 | type: fae 91 | options: 92 | fae: 93 | subtract_data_norm: false 94 | beta: 0.001 95 | vano: 96 | beta: 0.001 97 | n_monte_carlo_samples: 4 98 | normalised_inner_prod: true 99 | rescale_by_norm: false 100 | fvae_sde: 101 | beta: 1 102 | theta: 0 103 | zero_penalty: 10 104 | n_monte_carlo_samples: 4 105 | 106 | positional_encoding: 107 | is_used: true 108 | dim: 32 109 | 110 | trainer: 111 | max_step: 3_500 112 | lr: 0.001 113 | lr_decay_step: 1000 114 | lr_decay_factor: 0.98 115 | eval_interval: 1 116 | 
-------------------------------------------------------------------------------- /experiments/configs/config_sde1d.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | x0: 3 | - -1.0 4 | samples: 8192 5 | pts: 512 6 | T: 5 7 | epsilon: 1 8 | sim_dt: 0.0001220703125 # 1/8192 9 | point_ratio_train: 0.5 10 | batch_size: 32 11 | num_workers: 0 12 | 13 | encoder: 14 | latent_dim: 1 15 | is_variational: true 16 | type: pooling 17 | options: 18 | pooling: 19 | mlp_dim: 64 20 | mlp_n_hidden_layers: 3 21 | dirac: 22 | features: 23 | - 128 24 | - 128 25 | - 128 26 | cnn: 27 | cnn_features: 28 | - 4 29 | - 4 30 | - 8 31 | - 16 32 | kernel_sizes: 33 | - 2 34 | - 2 35 | - 4 36 | - 4 37 | strides: 38 | - 2 39 | - 2 40 | - 2 41 | - 2 42 | mlp_features: 43 | - 64 44 | 45 | decoder: 46 | type: nonlinear 47 | options: 48 | nonlinear: 49 | out_dim: 1 50 | features: 51 | - 100 52 | - 100 53 | - 100 54 | - 100 55 | - 100 56 | dirac: 57 | features: 58 | - 128 59 | - 128 60 | - 128 61 | cnn: 62 | trans_cnn_features: 63 | - 16 64 | - 8 65 | - 4 66 | - 4 67 | kernel_sizes: 68 | - 4 69 | - 4 70 | - 2 71 | - 2 72 | strides: 73 | - 2 74 | - 2 75 | - 2 76 | - 2 77 | mlp_features: 78 | - 64 79 | final_cnn_features: 80 | - 8 81 | - 1 82 | final_kernel_sizes: 83 | - 3 84 | - 3 85 | final_strides: 86 | - 1 87 | - 1 88 | c_in: 32 89 | grid_pts_in: 4 90 | 91 | domain: 92 | type: off_grid_sde 93 | options: 94 | grid_zero_boundary_conditions: 95 | s: 0 96 | off_grid_randomly_sampled_euclidean: 97 | s: 0 98 | off_grid_sde: 99 | ~ 100 | 101 | loss: 102 | type: fvae_sde 103 | options: 104 | fae: 105 | beta: 0.001 106 | subtract_data_norm: false 107 | vano: 108 | beta: 0.001 109 | n_monte_carlo_samples: 4 110 | normalised_inner_prod: true 111 | rescale_by_norm: false 112 | fvae_sde: 113 | beta: 1.2 114 | theta: 0 115 | zero_penalty: 10 116 | n_monte_carlo_samples: 4 117 | 118 | positional_encoding: 119 | is_used: false 120 | dim: 32 121 | 122 | trainer: 123 
| max_step: 100_000 124 | lr: 0.001 125 | lr_decay_step: 1000 126 | lr_decay_factor: 0.98 127 | eval_interval: 10 128 | -------------------------------------------------------------------------------- /experiments/configs/config_sde2d.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | x0: 3 | - 0.0 4 | - 0.0 5 | samples: 16384 # 2048*8 6 | pts: 512 7 | T: 3 8 | epsilon: 0.1 9 | sim_dt: 0.0001220703125 # 1/8192 10 | point_ratio_train: 0.5 11 | batch_size: 32 12 | num_workers: 0 13 | 14 | encoder: 15 | latent_dim: 16 16 | is_variational: true 17 | type: pooling 18 | options: 19 | pooling: 20 | mlp_dim: 64 21 | mlp_n_hidden_layers: 3 22 | dirac: 23 | features: 24 | - 128 25 | - 128 26 | - 128 27 | cnn: 28 | cnn_features: 29 | - 4 30 | - 4 31 | - 8 32 | - 16 33 | kernel_sizes: 34 | - 2 35 | - 2 36 | - 4 37 | - 4 38 | strides: 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | mlp_features: 44 | - 64 45 | 46 | decoder: 47 | type: nonlinear 48 | options: 49 | nonlinear: 50 | out_dim: 2 51 | features: 52 | - 100 53 | - 100 54 | - 100 55 | - 100 56 | - 100 57 | dirac: 58 | features: 59 | - 128 60 | - 128 61 | - 128 62 | cnn: 63 | trans_cnn_features: 64 | - 16 65 | - 8 66 | - 4 67 | - 4 68 | kernel_sizes: 69 | - 4 70 | - 4 71 | - 2 72 | - 2 73 | strides: 74 | - 2 75 | - 2 76 | - 2 77 | - 2 78 | mlp_features: 79 | - 64 80 | final_cnn_features: 81 | - 8 82 | - 1 83 | final_kernel_sizes: 84 | - 3 85 | - 3 86 | final_strides: 87 | - 1 88 | - 1 89 | c_in: 32 90 | grid_pts_in: 4 91 | 92 | domain: 93 | type: off_grid_sde 94 | options: 95 | grid_zero_boundary_conditions: 96 | s: 0 97 | off_grid_randomly_sampled_euclidean: 98 | s: 0 99 | off_grid_sde: 100 | ~ 101 | 102 | loss: 103 | type: fvae_sde 104 | options: 105 | fae: 106 | beta: 0.001 107 | subtract_data_norm: false 108 | vano: 109 | beta: 0.001 110 | n_monte_carlo_samples: 4 111 | normalised_inner_prod: true 112 | rescale_by_norm: false 113 | fvae_sde: 114 | beta: 10 115 | theta: 50 116 
| zero_penalty: 50 117 | n_monte_carlo_samples: 4 118 | 119 | positional_encoding: 120 | is_used: false 121 | dim: 32 122 | 123 | trainer: 124 | max_step: 100_000 125 | lr: 0.001 126 | lr_decay_step: 1000 127 | lr_decay_factor: 0.98 128 | eval_interval: 10 129 | -------------------------------------------------------------------------------- /experiments/custom_decoders.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax 3 | import jax.numpy as jnp 4 | from functools import partial 5 | from typing import Sequence, Callable 6 | from functional_autoencoders.decoders import Decoder 7 | from functional_autoencoders.util.networks import MLP 8 | 9 | 10 | @partial(jax.vmap, in_axes=(0, 0, 0, None)) 11 | def _gaussian(centre, mass, std, x): 12 | """ 13 | Generates a Gaussian density with mean `centre`, standard deviation `std`, mass `mass`, 14 | evaluated at the mesh points `x`. 15 | 16 | Batched over `centre`, `mass`, `std` (with the 0 axis being the batch) but using an unbatched `x`. 17 | """ 18 | const = (2 * jnp.pi * std**2) ** (-0.5) 19 | return mass * const * jnp.exp(-((x - centre) ** 2) / (2 * std**2)) 20 | 21 | 22 | class DiracDecoder(Decoder): 23 | """ 24 | Decoder that outputs a (smoothed) Dirac delta function with learned centre and mass. 25 | The centre and mass are computed using an MLP. 26 | """ 27 | 28 | fixed_centre: bool 29 | features: Sequence[int] = (128, 128, 128) 30 | min_std: Callable[[float], float] = lambda dx: dx 31 | 32 | def setup(self): 33 | self.mlp = MLP( 34 | [*self.features, 2], 35 | ) 36 | 37 | def _forward(self, z, x, train=False): 38 | if x.shape[2] != 1: 39 | raise NotImplementedError() 40 | 41 | # Implictly assumes x is the same across the batch. 
42 | dx = x[0, 1, 0] - x[0, 0, 0] 43 | centre, std, mass = self.get_params(z, dx) 44 | return _gaussian(centre, mass, std, x[0, :, :]) 45 | 46 | def get_params(self, z, dx): 47 | z = self.mlp(z) 48 | mass = jnp.ones_like(z[:, 0]) 49 | if self.fixed_centre: 50 | c = (((1.0 / dx) // 2) + 1) / (1.0 / dx) 51 | centre = c * jnp.ones_like(z[:, 0]) 52 | else: 53 | centre = nn.sigmoid(z[:, 0]) * (1 - 2 * dx) + dx 54 | std = self.min_std(dx) * jnp.ones_like(z[:, 1]) + nn.sigmoid(z[:, 1]) 55 | return centre, std, mass 56 | -------------------------------------------------------------------------------- /experiments/custom_encoders.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from typing import Sequence 4 | from dataclasses import field 5 | from functional_autoencoders.encoders import Encoder 6 | from functional_autoencoders.util.networks import MLP 7 | 8 | 9 | class DiracEncoder(Encoder): 10 | features: Sequence[int] = (128, 128, 128) 11 | latent_dim: int = 64 12 | mlp_args: dict = field(default_factory=dict) 13 | 14 | @nn.compact 15 | def __call__(self, u, x, train=False): 16 | u = jnp.reshape(u, (u.shape[0], -1)) 17 | u = jnp.reshape( 18 | jnp.float32(jnp.argmax(u, axis=1)) / u.shape[1], (u.shape[0], 1) 19 | ) 20 | 21 | d_out = self.latent_dim * 2 if self.is_variational else self.latent_dim 22 | u = MLP([*self.features, d_out], **self.mlp_args)(u) 23 | return u 24 | -------------------------------------------------------------------------------- /experiments/exp_baseline_comparisons/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | import os 5 | import jax 6 | from functional_autoencoders.datasets import get_dataloaders 7 | from functional_autoencoders.datasets.navier_stokes import NavierStokes 8 | from functional_autoencoders.datasets.darcy_flow import DarcyFlow 9 | from 
functional_autoencoders.util import ( 10 | save_data_results, 11 | save_model_results, 12 | yaml_load, 13 | fit_trainer_using_config, 14 | ) 15 | from experiments.trainer_loader import get_trainer 16 | 17 | 18 | def run_baseline_comparisons( 19 | key, 20 | output_dir, 21 | config_path, 22 | n_runs, 23 | ns_viscosity, 24 | is_darcy=False, 25 | verbose="metrics", 26 | ): 27 | 28 | config = yaml_load(config_path) 29 | 30 | if not is_darcy: 31 | train_dataloader, test_dataloader = get_dataloaders( 32 | NavierStokes, 33 | data_base=".", 34 | viscosity=ns_viscosity, 35 | ) 36 | else: 37 | train_dataloader, test_dataloader = get_dataloaders( 38 | DarcyFlow, data_base=".", downscale=9 39 | ) 40 | 41 | for run_idx in range(n_runs): 42 | key, subkey = jax.random.split(key) 43 | trainer = get_trainer(subkey, config, train_dataloader, test_dataloader) 44 | 45 | key, subkey = jax.random.split(key) 46 | results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose) 47 | 48 | save_model_results( 49 | autoencoder=trainer.autoencoder, 50 | results=results, 51 | model_dir=os.path.join(output_dir, "models", str(run_idx)), 52 | ) 53 | 54 | save_data_results( 55 | autoencoder=trainer.autoencoder, 56 | results=results, 57 | test_dataloader=test_dataloader, 58 | data_dir=os.path.join(output_dir, "data", str(run_idx)), 59 | ) 60 | -------------------------------------------------------------------------------- /experiments/exp_dirac/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | import os 5 | import jax 6 | from functional_autoencoders.datasets import get_dataloaders 7 | from functional_autoencoders.datasets.dirac import RandomDirac 8 | from experiments.trainer_loader import get_trainer 9 | from functional_autoencoders.util import ( 10 | save_data_results, 11 | yaml_load, 12 | fit_trainer_using_config, 13 | ) 14 | 15 | 16 | def run_dirac(key, output_dir, config_path, n_runs, 
resolutions, verbose="metrics"): 17 | config = yaml_load(config_path) 18 | 19 | for resolution in resolutions: 20 | for run_idx in range(n_runs): 21 | train_dataloader, test_dataloader = get_dataloaders( 22 | RandomDirac, 23 | pts=resolution, 24 | fixed_centre=False, 25 | batch_size=1, 26 | ) 27 | 28 | key, subkey = jax.random.split(key) 29 | trainer = get_trainer(subkey, config, train_dataloader, test_dataloader) 30 | 31 | key, subkey = jax.random.split(key) 32 | results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose) 33 | 34 | save_data_results( 35 | autoencoder=trainer.autoencoder, 36 | results=results, 37 | test_dataloader=test_dataloader, 38 | data_dir=os.path.join( 39 | output_dir, "data", str(resolution), str(run_idx) 40 | ), 41 | ) 42 | -------------------------------------------------------------------------------- /experiments/exp_rec_mse_vs_downsample_ratio/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | import os 5 | import jax 6 | import jax.numpy as jnp 7 | from tqdm.auto import tqdm 8 | from functional_autoencoders.datasets import get_dataloaders, ComplementMasking 9 | from functional_autoencoders.datasets.navier_stokes import NavierStokes 10 | from experiments.trainer_loader import get_trainer 11 | from functional_autoencoders.util import ( 12 | save_data_results, 13 | save_model_results, 14 | get_raw_x, 15 | yaml_load, 16 | fit_trainer_using_config, 17 | ) 18 | 19 | 20 | def run_rec_mse_vs_downsample_ratio( 21 | key, 22 | output_dir, 23 | config_path, 24 | n_runs, 25 | ns_viscosity, 26 | downsample_ratios, 27 | enc_point_ratio_train, 28 | verbose="metrics", 29 | ): 30 | 31 | config = yaml_load(config_path) 32 | 33 | mask_train = ComplementMasking(enc_point_ratio_train) 34 | train_dataloader, test_dataloader = get_dataloaders( 35 | NavierStokes, data_base=".", viscosity=ns_viscosity, transform_train=mask_train 36 | ) 37 | 38 | for run_idx 
in range(n_runs): 39 | key, subkey = jax.random.split(key) 40 | trainer = get_trainer(subkey, config, train_dataloader, test_dataloader) 41 | 42 | key, subkey = jax.random.split(key) 43 | results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose) 44 | 45 | mse_vs_size = get_mse_vs_size( 46 | autoencoder=trainer.autoencoder, 47 | state=results["state"], 48 | downsample_ratios=downsample_ratios, 49 | dataloader=test_dataloader, 50 | ) 51 | 52 | save_model_results( 53 | autoencoder=trainer.autoencoder, 54 | results=results, 55 | model_dir=os.path.join(output_dir, "models", str(run_idx)), 56 | ) 57 | 58 | save_data_results( 59 | autoencoder=trainer.autoencoder, 60 | results=results, 61 | test_dataloader=test_dataloader, 62 | data_dir=os.path.join(output_dir, "data", str(run_idx)), 63 | additional_data={"mse_vs_size": mse_vs_size}, 64 | ) 65 | 66 | 67 | def get_mse_vs_size(autoencoder, state, downsample_ratios, dataloader): 68 | mse_vs_size = {} 69 | for ratio in tqdm(downsample_ratios): 70 | total_mse = 0 71 | for u, x, _, _ in dataloader: 72 | n_batch = u.shape[0] 73 | n = int(u.shape[1] ** 0.5) 74 | u_down = u.reshape(n_batch, n, n)[:, ::ratio, ::ratio] 75 | n_down = u_down.shape[1] 76 | 77 | x_down_single_batch = get_raw_x(n_down, n_down).reshape(-1, 2) 78 | x_down = jnp.repeat(x_down_single_batch[None, ...], n_batch, axis=0) 79 | u_down = u_down.reshape(n_batch, -1, 1) 80 | 81 | vars = {"params": state.params, "batch_stats": state.batch_stats} 82 | u_hat = autoencoder.apply(vars, u_down, x_down, x) 83 | 84 | sum_batch_mse = jnp.sum(jnp.mean(jnp.sum((u - u_hat) ** 2, axis=2), axis=1)) 85 | total_mse += sum_batch_mse / len(dataloader) 86 | 87 | mse_vs_size[n_down] = total_mse / len(dataloader) 88 | return mse_vs_size 89 | -------------------------------------------------------------------------------- /experiments/exp_rec_mse_vs_point_ratio/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 
sys.path.append(".")
import os
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm
from functional_autoencoders.datasets import get_dataloaders, ComplementMasking
from functional_autoencoders.datasets.navier_stokes import NavierStokes
from experiments.trainer_loader import get_trainer
from functional_autoencoders.util import (
    save_data_results,
    save_model_results,
    yaml_load,
    fit_trainer_using_config,
)


def run_rec_mse_vs_point_ratio(
    key,
    output_dir,
    config_path,
    n_runs,
    ns_viscosity,
    enc_point_ratio_train_list,
    enc_point_ratio_test_list,
    verbose="metrics",
):
    """For each run and each training point ratio, train a model on masked
    Navier-Stokes data and record test MSE as a function of the evaluation
    point ratio; model and data results are saved under
    ``{output_dir}/{models,data}/{run_idx}/{train_ratio}``.
    """
    config = yaml_load(config_path)

    for run_idx in range(n_runs):
        for enc_point_ratio_train in enc_point_ratio_train_list:
            # Only the training split is masked; the test loader keeps the
            # default (unmasked) transform here.
            mask_train = ComplementMasking(enc_point_ratio_train)
            train_dataloader, test_dataloader = get_dataloaders(
                NavierStokes,
                data_base=".",
                viscosity=ns_viscosity,
                transform_train=mask_train,
            )

            key, subkey = jax.random.split(key)
            trainer = get_trainer(subkey, config, train_dataloader, test_dataloader)

            key, subkey = jax.random.split(key)
            results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose)

            key, subkey = jax.random.split(key)
            mse_vs_point_ratio = get_mse_vs_point_ratio(
                key=subkey,
                autoencoder=trainer.autoencoder,
                state=results["state"],
                enc_point_ratio_test_list=enc_point_ratio_test_list,
                dataloader=test_dataloader,
            )

            save_model_results(
                autoencoder=trainer.autoencoder,
                results=results,
                model_dir=os.path.join(
                    output_dir, "models", str(run_idx), str(enc_point_ratio_train)
                ),
            )

            save_data_results(
                autoencoder=trainer.autoencoder,
                results=results,
                test_dataloader=test_dataloader,
                data_dir=os.path.join(
                    output_dir, "data", str(run_idx), str(enc_point_ratio_train)
                ),
                additional_data={
                    "mse_vs_point_ratio": mse_vs_point_ratio,
                    "train_point_ratio": enc_point_ratio_train,
                },
            )


def get_mse_vs_point_ratio(
    key, autoencoder, state, enc_point_ratio_test_list, dataloader
):
    """Mean per-sample reconstruction MSE when the encoder sees only a random
    subset of the evaluation points.

    :param key: JAX PRNG key used to subsample encoder input points.
    :param autoencoder: trained autoencoder (Flax module, called via `.apply`).
    :param state: train state carrying `params` and `batch_stats`.
    :param enc_point_ratio_test_list: iterable of point ratios in (0, 1].
    :param dataloader: evaluation dataloader yielding `(u, x, _, _)` batches.
    :return: dict mapping test point ratio -> mean per-sample MSE.
    """
    mse_vs_point_ratio = {}
    for enc_point_ratio_test in tqdm(enc_point_ratio_test_list):
        sum_total_mse = 0
        for u, x, _, _ in dataloader:
            n_total_pts = u.shape[1]
            n_rand_pts = int(enc_point_ratio_test * n_total_pts)

            # Sample encoder input locations without replacement.
            key, subkey = jax.random.split(key)
            indices = jax.random.choice(
                subkey, n_total_pts, (n_rand_pts,), replace=False
            )

            u_partial = u[:, indices, :]
            x_partial = x[:, indices, :]

            vars = {"params": state.params, "batch_stats": state.batch_stats}
            # Encode from the subsampled points, decode on the full grid.
            u_hat = autoencoder.apply(vars, u_partial, x_partial, x)

            # Sum of per-sample MSEs for this batch.
            sum_batch_mse = jnp.sum(jnp.mean(jnp.sum((u - u_hat) ** 2, axis=2), axis=1))
            sum_total_mse += sum_batch_mse

        # Normalise once by the number of test samples.
        mse_vs_point_ratio[enc_point_ratio_test] = sum_total_mse / len(
            dataloader.dataset
        )
    return mse_vs_point_ratio


# ==== experiments/exp_sde1d/main.py ====
import sys

sys.path.append(".")
import os
import jax
from time import time
from typing import Literal
from jax.typing import ArrayLike
from experiments.trainer_loader import get_trainer
from functional_autoencoders.datasets import get_dataloaders, RandomMissingData
from functional_autoencoders.datasets.sde import (
    SDE,
    get_brownian_dynamics_diffusion,
    get_brownian_dynamics_drift,
)
from functional_autoencoders.util import (
    save_data_results,
    save_model_results,
    yaml_load,
    fit_trainer_using_config,
)


def potential_1d(x: ArrayLike, c:
float = 0.5, alpha: float = 12):
    r"""
    The potential

    $$U(x) = \alpha\left( \frac{1}{4} x^{4} + \frac{c}{3} x^{3} - \frac{1}{2} x^{2} - c x \right),$$

    which, with the particular choices $c = 1/2$ and $\alpha = 12$, yields the potential (2.26) in
    [the paper](https://arxiv.org/pdf/2408.01362), with minima at $-1$ and $+1$.

    :param x: `ArrayLike` of shape `[1]`.
    :param c: `float` parameter for potential $U$ (see above).
    :param alpha: `float` parameter for potential $U$ (see above).
    """
    return alpha * ((c / 3) * x[0] ** 3 - c * x[0] + 0.25 * x[0] ** 4 - 0.5 * x[0] ** 2)


def get_sde_dataloaders(
    config_data, verbose, samples=None, which: Literal["train", "test", "both"] = "both"
):
    """Brownian-dynamics SDE dataloaders for the 1D double-well potential.

    :param config_data: the `data` section of the experiment config.
    :param verbose: forwarded to `get_dataloaders`.
    :param samples: optional override for `config_data["samples"]`.
    :param which: which split(s) to build; the unrequested loader is `None`.
    :return: `(train_dataloader, test_dataloader)` tuple.
    """
    drift = get_brownian_dynamics_drift(potential_1d)
    diffusion = get_brownian_dynamics_diffusion(config_data["epsilon"])
    point_ratio_train = config_data["point_ratio_train"]
    random_missing_data = RandomMissingData(point_ratio_train)

    if which == "train" or which == "both":
        # Only the training split drops points (transform_generated below).
        train_dataloader = get_dataloaders(
            SDE,
            drift=drift,
            diffusion=diffusion,
            T=config_data["T"],
            samples=config_data["samples"] if samples is None else samples,
            pts=config_data["pts"],
            sim_dt=config_data["sim_dt"],
            batch_size=config_data["batch_size"],
            num_workers=config_data["num_workers"],
            x0=config_data["x0"],
            transform_generated=random_missing_data,
            which="train",
            verbose=verbose,
        )
    else:
        train_dataloader = None

    if which == "test" or which == "both":
        # The test split keeps all points: no transform_generated here.
        test_dataloader = get_dataloaders(
            SDE,
            drift=drift,
            diffusion=diffusion,
            T=config_data["T"],
            samples=config_data["samples"] if samples is None else samples,
            pts=config_data["pts"],
            sim_dt=config_data["sim_dt"],
            batch_size=config_data["batch_size"],
            num_workers=config_data["num_workers"],
            x0=config_data["x0"],
            which="test",
            verbose=verbose,
        )
    else:
        test_dataloader = None

    return train_dataloader, test_dataloader


def run_sde1d(key, output_dir, config_path, theta_list, verbose=True):
    """Train one FVAE-SDE model per `theta` value and save model/data results
    (including wall-clock training time) under `output_dir`."""
    config_sde1d = yaml_load(config_path)
    config_data = config_sde1d["data"]
    train_dataloader, test_dataloader = get_sde_dataloaders(config_data, verbose)

    for theta in theta_list:
        # Sweep the fvae_sde loss parameter in-place in the shared config.
        config_sde1d["loss"]["options"]["fvae_sde"]["theta"] = theta

        key, subkey = jax.random.split(key)
        trainer = get_trainer(subkey, config_sde1d, train_dataloader, test_dataloader)

        start_time = time()

        key, subkey = jax.random.split(key)
        results = fit_trainer_using_config(
            subkey, trainer, config_sde1d, verbose="metrics" if verbose else "none"
        )

        training_time = time() - start_time

        save_model_results(
            autoencoder=trainer.autoencoder,
            results=results,
            model_dir=os.path.join(output_dir, "models", str(theta)),
        )

        save_data_results(
            autoencoder=trainer.autoencoder,
            results=results,
            test_dataloader=test_dataloader,
            data_dir=os.path.join(
                output_dir, "data", '0', str(theta)
            ),
            additional_data={
                "theta": theta,
                "training_time": training_time,
            },
        )


# ==== experiments/exp_sde1d/plots_sde1d.py ====
import sys

sys.path.append("../")

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib
import seaborn as sns
from functools import partial
from functional_autoencoders.samplers.sampler_vae import SamplerVAE
from functional_autoencoders.util.random.sde import add_bm_noise
from plots import get_cmap


def plot_potential_and_samples(ax1, ax2, potential_1d,
train_dataloader, n_samples=5):
    """Plot the shifted potential U(x) on `ax1` and `n_samples` SDE sample
    paths from `train_dataloader` on `ax2`."""
    x = np.linspace(-1.5, 1.62, 1000)
    U = potential_1d(x[None, :, None])
    s = 0
    # Shift so the minimum of the plotted potential sits at s (here 0).
    shift = -np.min(U) + s
    ax1.plot(x, U + shift, c="r", zorder=3)
    ax1.set_xticks([-1, 1])
    ax1.set_xticklabels(["$x_{1} = -1$", "$x_{2} = +1$"])
    ax1.set_yticks([0, shift + potential_1d([-1])])
    ax1.set_ymargin(0)
    ax1.yaxis.set_tick_params(right=False)
    ax1.xaxis.set_tick_params(top=False)
    ax1.set_xlabel("$x$")
    ax1.set_title("(a) Potential $U(x)$")

    u, x, _, _ = next(iter(train_dataloader))
    cs = get_cmap("Reds", n_samples)
    for i in range(0, n_samples):
        ax2.plot(x[i], u[i], c=cs[i])
    ax2.set_xlabel("$t$")
    ax2.set_title(r"(b) Sample paths $(u_{t})_{t \in [0, 5]}$")
    ax2.set_yticks([-1, 0, 1])
    ax2.yaxis.set_tick_params(right=False)
    ax2.xaxis.set_tick_params(top=False)


def plot_reconstructions_and_generated_samples(
    key,
    ax1,
    ax2,
    ax3,
    info,
    theta,
    config_data,
    test_dataloader,
    n_samples=5,
    title=False,
):
    """Plot (a) reconstructions, (b) the distribution of decoder outputs with
    Brownian-motion noise, and (c) noisy decoder realizations, for a trained
    model stored in `info` (`autoencoder` + `results["state"]`)."""
    autoencoder = info["autoencoder"]
    state = info["results"]["state"]

    u, x, _, _ = next(iter(test_dataloader))

    vars = {"params": state.params, "batch_stats": state.batch_stats}
    uhat = autoencoder.apply(vars, u, x, x)

    cs_u = get_cmap("Reds", n_samples)
    # cs_uhat = get_cmap("GnBu", n_samples)
    for i in range(n_samples):
        ax1.plot(x[i], u[i], c=cs_u[i])
        ax1.plot(x[i], uhat[i], c="k", zorder=3, linewidth=2, linestyle=(0, (3, 0.5)))

    ax1.yaxis.set_tick_params(right=False)
    ax1.xaxis.set_tick_params(top=False)
    ax1.set_yticks([-1, 0, 1])
    if title:
        ax1.set_title("(a) Reconstructions")
    cs_u = get_cmap("Reds", n_samples)

    # Draw latent samples and decode them at the test grid points.
    key, subkey = jax.random.split(key)
    sampler = SamplerVAE(autoencoder, state)
    samples = sampler.sample(x[:n_samples], subkey)

    s = add_bm_noise(
        samples=samples,
        epsilon=config_data["epsilon"],
        theta=theta,
        sim_dt=config_data["sim_dt"],
        T=config_data["T"],
    )

    for i in range(n_samples):
        ax3.plot(x[i], s[i], c=cs_u[i])
    ax3.yaxis.set_tick_params(right=False)
    ax3.xaxis.set_tick_params(top=False)
    ax3.set_yticks([-1, 0, 1])
    if title:
        ax3.set_title(r"(c) Realizations of $g(z; \psi) + \eta$")

    # Monte-Carlo estimate of the pointwise std of the BM noise around each
    # decoded sample (n_repeats noisy copies per sample).
    n_repeats = 1000
    for i in range(n_samples):
        s = jnp.expand_dims(samples[i, :, :], 0)
        s = jnp.repeat(s, n_repeats, axis=0)
        s = add_bm_noise(
            samples=s,
            epsilon=config_data["epsilon"],
            theta=theta,
            sim_dt=config_data["sim_dt"],
            T=config_data["T"],
        )

        std = jnp.std(s[:, :, 0], axis=0)
        if i == 0:
            # Label only the first sample so the legend has one entry per style.
            ax2.plot(x[i], samples[i], c=cs_u[i], label=r"$g(z; \psi)$")
            ax2.fill_between(
                x[i, :, 0],
                samples[i, :, 0] - std,
                samples[i, :, 0] + std,
                color=cs_u[i],
                alpha=0.4,
                label="1 SD",
            )
        else:
            ax2.plot(x[i], samples[i], c=cs_u[i])
            ax2.fill_between(
                x[i, :, 0],
                samples[i, :, 0] - std,
                samples[i, :, 0] + std,
                color=cs_u[i],
                alpha=0.4,
            )

    ax2.yaxis.set_tick_params(right=False)
    ax2.xaxis.set_tick_params(top=False)
    ax2.set_yticks([-1, 0, 1])
    if title:
        ax2.set_title(r"(b) Distribution of $g(z; \psi) + \eta$")


def plot_latent_variable(
    ax,
    ax_colorbar,
    info,
    test_dataloader,
    z_min=-2.5,
    z_max=2.5,
    n_evals=8,
):
    """Decode a 1D sweep of latent values z in [z_min, z_max] and plot the
    resulting curves, with a matching colorbar on `ax_colorbar`."""
    autoencoder = info["autoencoder"]
    state = info["results"]["state"]

    z = jnp.linspace(z_min, z_max, n_evals)
    z = jnp.expand_dims(z, 1)
    _, x, _, _ = next(iter(test_dataloader))

    # samples = trainer.autoencoder.decode(variables, z, x[: z.shape[0], :, :])
    samples = autoencoder.decode(state, z, x[: z.shape[0], :, :], train=False)

    cs_u = 
get_cmap("Reds", z.shape[0], start=0.3)
    for i in range(z.shape[0]):
        ax.plot(x[i, :, 0], samples[i, :, 0], c=cs_u[i])
    ax.set_yticks([-1, 0, 1])
    ax.xaxis.set_tick_params(right=False)
    ax.yaxis.set_tick_params(top=False)
    ax.set_title(r"(a) $g(z; \psi)(t)$ for $z \in [-2.5, 2.5]$")

    def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
        # Restrict `cmap` to [minval, maxval] so the colorbar matches the
        # start=0.3 offset used for the plotted curves.
        new_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
            "trunc({n},{a:.2f},{b:.2f})".format(n=cmap.name, a=minval, b=maxval),
            cmap(np.linspace(minval, maxval, n)),
        )
        return new_cmap

    matplotlib.colorbar.ColorbarBase(
        ax_colorbar,
        cmap=truncate_colormap(matplotlib.cm.Reds, minval=0.3),
        orientation="horizontal",
        norm=matplotlib.colors.Normalize(vmin=-2.5, vmax=2.5),
    )
    ax_colorbar.set_ylabel("$z$", rotation=0)
    ax_colorbar.yaxis.set_label_coords(-0.025, -0.03)


@partial(jax.vmap, in_axes=(0, 0))
def get_transition_times(path, x):
    # Time of the first point where the path exceeds 0.
    # NOTE(review): if a path never crosses 0, argmax of the all-False mask is
    # 0 and x[0, 0] (the initial time) is returned — confirm this is intended.
    return x[jnp.argmax(path[:, 0] > 0), 0]


def plot_transition_time_distribution(
    key,
    ax,
    info,
    test_dataloader,
):
    """KDE plot comparing first-crossing times of FVAE-generated paths against
    directly simulated paths from `test_dataloader`."""
    autoencoder = info["autoencoder"]
    state = info["results"]["state"]

    tts_samples = []
    tts_u = []
    for u, x, _, _ in iter(test_dataloader):
        key, subkey = jax.random.split(key)
        sampler = SamplerVAE(autoencoder, state)
        samples = sampler.sample(x, subkey)

        tt_samples = get_transition_times(samples, x)
        tt_u = get_transition_times(u, x)
        tts_samples = tts_samples + [*tt_samples.tolist()]
        tts_u = tts_u + [*tt_u.tolist()]

    sns.kdeplot(tts_samples, ax=ax, label="FVAE", color="k", fill=True, alpha=0.5)
    sns.kdeplot(
        tts_u,
        ax=ax,
        label="Direct numerical simulation",
        color="r",
        fill=True,
        alpha=0.5,
    )

    ax.legend()
    ax.yaxis.set_tick_params(right=False)
    ax.xaxis.set_tick_params(top=False)
    ax.set_title(f"(b) Time $t$ of first crossing above $0$ (N={len(tts_samples)})")
    ax.set_ylabel("Density")


# ==== experiments/exp_sde2d/main.py ====
import sys

sys.path.append(".")
import os
import jax
import jax.numpy as jnp
from typing import Literal
from experiments.trainer_loader import get_trainer
from functional_autoencoders.datasets import get_dataloaders, RandomMissingData
from functional_autoencoders.datasets.sde import (
    SDE,
    get_brownian_dynamics_diffusion,
    get_brownian_dynamics_drift,
)
from functional_autoencoders.util import (
    save_model_results,
    yaml_load,
    fit_trainer_using_config,
)


def potential_2d(x):
    """2D potential: a tilted parabola plus six negative-Gaussian wells; the
    result is scaled by 0.3.  Operates on the last axis of `x` (size 2)."""
    parabola = lambda x: jnp.square(x).sum(axis=-1)
    linear = lambda x: jnp.sum(x, axis=-1)
    neg_gaussian = lambda x, C, sigma: -jnp.prod(
        jax.scipy.stats.norm.pdf(x.reshape(-1, 2), loc=C, scale=sigma), axis=-1
    ).reshape(x.shape[:-1])

    # Well centres, widths, and depths (zipped elementwise below).
    mu_list = [
        [0, 0],
        [0.2, 0.2],
        [-0.2, -0.2],
        [0.2, -0.2],
        [0, 0.2],
        [-0.2, 0],
    ]

    sigma_list = [
        0.1,
        0.1,
        0.1,
        0.1,
        0.03,
        0.03,
    ]

    coeff_list = [
        0.1,
        0.1,
        0.1,
        0.1,
        0.01,
        0.01,
    ]

    p = parabola(x)
    lin = linear(x)
    result = p + 0.5 * lin

    for mu, sigma, coeff in zip(mu_list, sigma_list, coeff_list):
        result += coeff * neg_gaussian(x, jnp.array(mu), sigma)

    return 0.3 * result


def get_sde_dataloaders(
    config_data, verbose, samples=None, which: Literal["train", "test", "both"] = "both"
):
    """Brownian-dynamics SDE dataloaders for the 2D potential (see
    `potential_2d`); mirrors the 1D version in exp_sde1d/main.py."""
    drift = get_brownian_dynamics_drift(potential_2d)
    diffusion = 
get_brownian_dynamics_diffusion(config_data["epsilon"])
    point_ratio_train = config_data["point_ratio_train"]
    random_missing_data = RandomMissingData(point_ratio_train)

    if which == "train" or which == "both":
        # Only the training split drops points (transform_generated below).
        train_dataloader = get_dataloaders(
            SDE,
            drift=drift,
            diffusion=diffusion,
            T=config_data["T"],
            samples=config_data["samples"] if samples is None else samples,
            pts=config_data["pts"],
            sim_dt=config_data["sim_dt"],
            batch_size=config_data["batch_size"],
            num_workers=config_data["num_workers"],
            x0=config_data["x0"],
            transform_generated=random_missing_data,
            which="train",
            verbose=verbose,
        )
    else:
        train_dataloader = None

    if which == "test" or which == "both":
        # The test split keeps all points: no transform_generated here.
        test_dataloader = get_dataloaders(
            SDE,
            drift=drift,
            diffusion=diffusion,
            T=config_data["T"],
            samples=config_data["samples"] if samples is None else samples,
            pts=config_data["pts"],
            sim_dt=config_data["sim_dt"],
            batch_size=config_data["batch_size"],
            num_workers=config_data["num_workers"],
            x0=config_data["x0"],
            which="test",
            verbose=verbose,
        )
    else:
        test_dataloader = None

    return train_dataloader, test_dataloader


def run_sde2d(key, output_dir, config_path, verbose=True):
    """Train a single FVAE-SDE model on the 2D potential data and save the
    model under ``{output_dir}/models``."""
    config_sde2d = yaml_load(config_path)
    config_data = config_sde2d["data"]

    train_dataloader, test_dataloader = get_sde_dataloaders(config_data, verbose)
    key, subkey = jax.random.split(key)
    trainer = get_trainer(subkey, config_sde2d, train_dataloader, test_dataloader)

    key, subkey = jax.random.split(key)
    results = fit_trainer_using_config(
        subkey, trainer, config_sde2d, verbose="metrics" if verbose else "none"
    )

    save_model_results(
        autoencoder=trainer.autoencoder,
        results=results,
        model_dir=os.path.join(output_dir, "models"),
    )


# ==== experiments/exp_sde2d/plots_sde2d.py ====
import sys

sys.path.append("../")

import warnings
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from functional_autoencoders.util import get_n_params
from plots import plot_train_val_losses


def get_X_Y_Z(potential, start=-0.5, end=0.5, n=100):
    """Evaluate `potential` on an (n x n) meshgrid over [start, end]^2 and
    return the meshgrid coordinates plus the potential values."""
    X = np.linspace(start, end, n)
    Y = np.linspace(start, end, n)
    X, Y = np.meshgrid(X, Y)
    Z = potential(jnp.concat([X[..., None], Y[..., None]], axis=-1))
    return X, Y, Z


def plot_colored_line(x, y, **lc_kwargs):
    """
    Draw a line through (x, y) whose colour varies along its length.

    Adapted from: https://matplotlib.org/stable/gallery/lines_bars_and_markers/multicolored_line.html
    """
    if "array" in lc_kwargs:
        warnings.warn('The provided "array" keyword argument will be overridden')

    # Default the capstyle to butt so that the line segments smoothly line up
    default_kwargs = {"capstyle": "butt"}
    default_kwargs.update(lc_kwargs)

    # Compute the midpoints of the line segments. Include the first and last points
    # twice so we don't need any special syntax later to handle them.
    x = np.asarray(x)
    y = np.asarray(y)
    x_midpts = np.hstack((x[0], 0.5 * (x[1:] + x[:-1]), x[-1]))
    y_midpts = np.hstack((y[0], 0.5 * (y[1:] + y[:-1]), y[-1]))

    # Determine the start, middle, and end coordinate pair of each line segment.
    # Use the reshape to add an extra dimension so each pair of points is in its
    # own list. Then concatenate them to create:
    # [
    #   [(x1_start, y1_start), (x1_mid, y1_mid), (x1_end, y1_end)],
    #   [(x2_start, y2_start), (x2_mid, y2_mid), (x2_end, y2_end)],
    #   ...
    # ]
    coord_start = np.column_stack((x_midpts[:-1], y_midpts[:-1]))[:, np.newaxis, :]
    coord_mid = np.column_stack((x, y))[:, np.newaxis, :]
    coord_end = np.column_stack((x_midpts[1:], y_midpts[1:]))[:, np.newaxis, :]
    segments = np.concatenate((coord_start, coord_mid, coord_end), axis=1)

    c = np.linspace(0, 1, x.size)  # color values for each line segment
    lc = LineCollection(segments, **default_kwargs)
    lc.set_array(c)  # set the colors of each segment

    plt.gca().add_collection(lc)


def plot_contour_with_partitions(potential, x_locs, y_locs, cmap_contour="hot"):
    """Filled contour of `potential` with a grid of partition lines at the
    interior entries of `x_locs`/`y_locs`."""
    U, Y, Z = get_X_Y_Z(potential)
    plt.contourf(U, Y, Z, cmap=cmap_contour)

    # Skip the first and last locations: only interior partitions are drawn.
    for i in range(1, len(x_locs) - 1):
        plt.axvline(x=x_locs[i], color="black")  # Vertical lines
        plt.axhline(y=y_locs[i], color="black")  # Horizontal lines

    plt.colorbar()
    plt.xticks([])
    plt.yticks([])


def plot_contour_with_samples(
    u, potential, h, w, cmap_contour="hot", cmap_line="winter"
):
    """Grid of h*w subplots, each a potential contour overlaid with one sample
    trajectory from `u` (shape [n, pts, 2])."""
    U, Y, Z = get_X_Y_Z(potential)
    for i in range(h * w):
        plt.subplot(h, w, i + 1)
        plt.contourf(U, Y, Z, cmap=cmap_contour)
        plot_colored_line(u[i, :, 0], u[i, :, 1], cmap=cmap_line)
        plt.xticks([])
        plt.yticks([])

    plt.tight_layout()


def plot_contour_with_reconstructions(
    u, u_rec, potential, cmap_contour="hot", cmap_line="winter"
):
    """Two rows of contour plots: true trajectories `u` on top and their
    reconstructions `u_rec` below."""
    n_recs = u.shape[0]
    U, Y, Z = get_X_Y_Z(potential)
    for i in range(n_recs):
        plt.subplot(2, n_recs, i + 1)
        plt.contourf(U, Y, Z, cmap=cmap_contour)
        plot_colored_line(u[i, :, 0], u[i, :, 1], cmap=cmap_line)
        plt.xticks([])
        plt.yticks([])
        if i == 0:
            plt.ylabel("True")

        plt.subplot(2, n_recs, n_recs + i + 1)
        plt.contourf(U, Y, Z, cmap=cmap_contour)
        plot_colored_line(u_rec[i, :, 0], u_rec[i, :, 1], cmap=cmap_line)
        plt.xticks([])
        plt.yticks([])
        if i == 0:
            plt.ylabel("Reconstructed")

    plt.tight_layout()


def plot_training_results(results):
    """Plot train/validation loss curves and print the final value of each
    metric plus the parameter count."""
    plot_train_val_losses(
        results["training_loss_history"],
        results["metrics_history"],
        start_idx_train=3,
    )
    plt.tight_layout()
    plt.show()

    n_params = get_n_params(results["state"].params)
    metric_names = results["metrics_history"].keys()
    for metric_name in reversed(metric_names):
        print(f'{metric_name}: {results["metrics_history"][metric_name][-1]:.3e}')
    print(f"Number of parameters: {n_params}")


def plot_contour_and_heatmap(potential, cmap_contour="hot"):
    """Side-by-side 3D surface and filled-contour heatmap of `potential`."""
    X, Y, Z = get_X_Y_Z(potential)

    fig = plt.figure(figsize=(10, 4))

    ax1 = fig.add_subplot(121, projection="3d")
    ax1.plot_surface(X, Y, Z, cmap=cmap_contour)
    ax1.set_title("Surface of Potential")
    ax1.set_xlabel("$x_1$")
    ax1.set_ylabel("$x_2$")

    ax2 = fig.add_subplot(122)
    CS = ax2.contourf(X, Y, Z, cmap=cmap_contour)
    cbar = fig.colorbar(CS, ax=ax2)
    cbar.set_label("Potential")
    ax2.set_title("Heatmap of Potential")
    ax2.set_xlabel("$x_1$")
    ax2.set_ylabel("$x_2$")


# ==== experiments/exp_sparse_training/main.py ====
import sys

sys.path.append(".")
import os
import jax
from functional_autoencoders.datasets import get_dataloaders, ComplementMasking
from functional_autoencoders.datasets.navier_stokes import NavierStokes
from functional_autoencoders.datasets.darcy_flow import DarcyFlow
from experiments.trainer_loader import get_trainer
from functional_autoencoders.util import (
    save_model_results,
    yaml_load,
    fit_trainer_using_config,
)


def run_sparse_training(
    key,
    output_dir,
    config_path,
    ratio_rand_pts_enc,
    ns_viscosity,
    is_darcy=False,
    verbose="metrics",
):
    """Train one model on sparsely observed data (both splits masked with
    `ComplementMasking(ratio_rand_pts_enc)`) and save it to
    ``{output_dir}/models``.  Uses Navier-Stokes data, or Darcy flow when
    `is_darcy` is set (in which case `ns_viscosity` is unused)."""
    config = yaml_load(config_path)

    mask_train = ComplementMasking(ratio_rand_pts_enc)
    mask_test = ComplementMasking(ratio_rand_pts_enc)
    if not is_darcy:
        train_dataloader, test_dataloader = get_dataloaders(
            NavierStokes,
            data_base=".",
            viscosity=ns_viscosity,
            transform_train=mask_train,
            transform_test=mask_test,
        )
    else:
        train_dataloader, test_dataloader = get_dataloaders(
            DarcyFlow,
            data_base=".",
            transform_train=mask_train,
            transform_test=mask_test,
            downscale=9,
        )

    key, subkey = jax.random.split(key)
    trainer = get_trainer(subkey, config, train_dataloader, test_dataloader)

    key, subkey = jax.random.split(key)
    results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose)

    save_model_results(
        autoencoder=trainer.autoencoder,
        results=results,
        model_dir=os.path.join(output_dir, "models"),
    )


# ==== experiments/exp_sparse_vs_dense_wall_clock_training/main.py ====
import sys

sys.path.append(".")
import os
import jax
from time import time
from functional_autoencoders.datasets import get_dataloaders, RandomMasking
from functional_autoencoders.datasets.darcy_flow import DarcyFlow
from experiments.trainer_loader import get_trainer
from functional_autoencoders.util import (
    save_data_results,
    save_model_results,
    yaml_load,
    fit_trainer_using_config,
)


def run_sparse_vs_dense_wall_clock_training(
    key,
    output_dir,
    config_path,
    n_runs,
    downscale,
    ratio_rand_pts_enc_train_list,
    verbose="metrics",
):
    """Measure wall-clock training time on Darcy flow as a function of the
    training masking ratio; evaluation always uses the full (unmasked) test
    loader.  Results are saved per (run, ratio)."""
    config = yaml_load(config_path)

    # One full-resolution test loader shared by all runs and ratios.
    _, test_dataloader_full = get_dataloaders(
        DarcyFlow, data_base=".", downscale=downscale
    )

    for run_idx in range(n_runs):
        for ratio_rand_pts_enc_train in ratio_rand_pts_enc_train_list:
            mask_train = RandomMasking(
                ratio_rand_pts_enc_train, ratio_rand_pts_enc_train
            )
            train_dataloader, _ = get_dataloaders(
                DarcyFlow,
                data_base=".",
                transform_train=mask_train,
                downscale=downscale,
            )

            key, subkey = jax.random.split(key)
            trainer = get_trainer(
                subkey, config, train_dataloader, test_dataloader_full
            )

            start_time = time()

            key, subkey = jax.random.split(key)
            results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose)

            training_time = time() - start_time

            save_model_results(
                autoencoder=trainer.autoencoder,
                results=results,
                model_dir=os.path.join(
                    output_dir, "models", str(run_idx), str(ratio_rand_pts_enc_train)
                ),
            )

            save_data_results(
                autoencoder=trainer.autoencoder,
                results=results,
                test_dataloader=test_dataloader_full,
                data_dir=os.path.join(
                    output_dir, "data", str(run_idx), str(ratio_rand_pts_enc_train)
                ),
                additional_data={
                    "train_point_ratio": ratio_rand_pts_enc_train,
                    "training_time": training_time,
                },
            )


# ==== experiments/exp_train_vs_inference_wall_clock/main.py ====
import sys

sys.path.append(".")
import os
import jax
from time import time
from functional_autoencoders.datasets import get_dataloaders, RandomMasking
from functional_autoencoders.datasets.darcy_flow import DarcyFlow
from experiments.trainer_loader import get_trainer
from functional_autoencoders.util import (
save_data_results, 12 | save_model_results, 13 | yaml_load, 14 | fit_trainer_using_config, 15 | ) 16 | 17 | 18 | def run_train_vs_inference_wall_clock( 19 | key, 20 | output_dir, 21 | config_path, 22 | n_runs, 23 | downscale, 24 | ratio_rand_pts_enc_train_list, 25 | verbose="metrics", 26 | ): 27 | 28 | config = yaml_load(config_path) 29 | 30 | for run_idx in range(n_runs): 31 | for ratio_rand_pts_enc_train in ratio_rand_pts_enc_train_list: 32 | mask_train = RandomMasking( 33 | ratio_rand_pts_enc_train, ratio_rand_pts_enc_train 34 | ) 35 | train_dataloader, test_dataloader = get_dataloaders( 36 | DarcyFlow, 37 | data_base=".", 38 | transform_train=mask_train, 39 | downscale=downscale, 40 | ) 41 | 42 | key, subkey = jax.random.split(key) 43 | trainer = get_trainer(subkey, config, train_dataloader, test_dataloader) 44 | 45 | # Evaluate training time 46 | start_time = time() 47 | key, subkey = jax.random.split(key) 48 | results = fit_trainer_using_config(subkey, trainer, config, verbose=verbose) 49 | training_time = time() - start_time 50 | 51 | # Evaluate inference time 52 | start_time = time() 53 | perform_inference(trainer.autoencoder, results["state"], test_dataloader) 54 | inference_time = time() - start_time 55 | 56 | save_model_results( 57 | autoencoder=trainer.autoencoder, 58 | results=results, 59 | model_dir=os.path.join( 60 | output_dir, "models", str(run_idx), str(ratio_rand_pts_enc_train) 61 | ), 62 | ) 63 | 64 | save_data_results( 65 | autoencoder=trainer.autoencoder, 66 | results=results, 67 | test_dataloader=test_dataloader, 68 | data_dir=os.path.join( 69 | output_dir, "data", str(run_idx), str(ratio_rand_pts_enc_train) 70 | ), 71 | additional_data={ 72 | "train_point_ratio": ratio_rand_pts_enc_train, 73 | "training_time": training_time, 74 | "inference_time": inference_time, 75 | }, 76 | ) 77 | 78 | 79 | def perform_inference(autoencoder, state, dataloader): 80 | for u, x, _, _ in dataloader: 81 | vars = {"params": state.params, "batch_stats": 
state.batch_stats} 82 | u_hat = autoencoder.apply(vars, u, x, x) 83 | -------------------------------------------------------------------------------- /experiments/main_run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | sys.path.append("src/") 5 | 6 | import jax 7 | from time import time 8 | from experiments.exp_baseline_comparisons.main import ( 9 | run_baseline_comparisons, 10 | ) 11 | from experiments.exp_dirac.main import run_dirac 12 | from experiments.exp_rec_mse_vs_downsample_ratio.main import ( 13 | run_rec_mse_vs_downsample_ratio, 14 | ) 15 | from experiments.exp_rec_mse_vs_point_ratio.main import run_rec_mse_vs_point_ratio 16 | from experiments.exp_sde1d.main import run_sde1d 17 | from experiments.exp_sde2d.main import run_sde2d 18 | from experiments.exp_sparse_training.main import run_sparse_training 19 | from experiments.exp_sparse_vs_dense_wall_clock_training.main import ( 20 | run_sparse_vs_dense_wall_clock_training, 21 | ) 22 | from experiments.exp_train_vs_inference_wall_clock.main import ( 23 | run_train_vs_inference_wall_clock, 24 | ) 25 | 26 | 27 | def wrap_run(func): 28 | def wrapped_func(*args, **kwargs): 29 | try: 30 | print("*" * 40) 31 | print(f'Saving to: {kwargs["output_dir"]}') 32 | 33 | start_time = time() 34 | func(*args, **kwargs) 35 | 36 | print("Done!") 37 | print(f"Time taken: {(time() - start_time) / 60:.2f} minutes") 38 | print("*" * 40 + "\n") 39 | 40 | except Exception as e: 41 | print("Run failed!") 42 | print(e) 43 | 44 | return wrapped_func 45 | 46 | 47 | if __name__ == "__main__": 48 | key = jax.random.PRNGKey(42) 49 | 50 | start_time = time() 51 | 52 | wrap_run(run_baseline_comparisons)( 53 | key=key, 54 | output_dir="tmp/experiments/exp_baseline_comparisons/cnn", 55 | config_path="experiments/configs/config_cnn.yaml", 56 | n_runs=5, 57 | ns_viscosity=1e-4, 58 | is_darcy=False, 59 | ) 60 | 61 | wrap_run(run_baseline_comparisons)( 62 | 
        key=key,
        output_dir="tmp/experiments/exp_baseline_comparisons/point",
        config_path="experiments/configs/config_fae.yaml",
        n_runs=5,
        ns_viscosity=1e-4,
        is_darcy=False,
    )

    wrap_run(run_dirac)(
        key=key,
        output_dir="tmp/experiments/exp_dirac/fae",
        config_path="experiments/configs/config_dirac_fae.yaml",
        n_runs=50,
        resolutions=(8, 16, 32, 64, 128),
    )

    wrap_run(run_dirac)(
        key=key,
        output_dir="tmp/experiments/exp_dirac/vano",
        config_path="experiments/configs/config_dirac_vano.yaml",
        n_runs=50,
        resolutions=(8, 16, 32, 64, 128),
    )

    wrap_run(run_rec_mse_vs_downsample_ratio)(
        key=key,
        output_dir="tmp/experiments/exp_rec_mse_vs_downsample_ratio",
        config_path="experiments/configs/config_fae.yaml",
        n_runs=5,
        ns_viscosity=1e-4,
        downsample_ratios=(1, 2, 4, 8),
        enc_point_ratio_train=-1,
    )

    wrap_run(run_rec_mse_vs_point_ratio)(
        key=key,
        output_dir="tmp/experiments/exp_rec_mse_vs_point_ratio",
        config_path="experiments/configs/config_fae.yaml",
        n_runs=5,
        ns_viscosity=1e-4,
        enc_point_ratio_train_list=(0.1, 0.5, 0.9),
        enc_point_ratio_test_list=(0.1, 0.3, 0.5, 0.7, 0.9),
    )

    wrap_run(run_sde1d)(
        key=key,
        output_dir="tmp/experiments/exp_sde1d",
        config_path="experiments/configs/config_sde1d.yaml",
        theta_list=(0, 25, 10_000),
    )

    wrap_run(run_sde2d)(
        key=key,
        output_dir="tmp/experiments/exp_sde2d",
        config_path="experiments/configs/config_sde2d.yaml",
    )

    wrap_run(run_sparse_training)(
        key=key,
        output_dir="tmp/experiments/sparse_training",
        config_path="experiments/configs/config_fae.yaml",
        ratio_rand_pts_enc=0.3,
        ns_viscosity=1e-4,
        is_darcy=False,
    )

    wrap_run(run_sparse_training)(
        key=key,
        output_dir="tmp/experiments/sparse_training_darcy",
        config_path="experiments/configs/config_fae.yaml",
        ratio_rand_pts_enc=0.3,
        ns_viscosity=1e-4,
        is_darcy=True,
    )

    wrap_run(run_sparse_vs_dense_wall_clock_training)(
        key=key,
        output_dir="tmp/experiments/exp_sparse_vs_dense_wall_clock_training",
        config_path="experiments/configs/config_fae_timing.yaml",
        n_runs=5,
        downscale=2,
        ratio_rand_pts_enc_train_list=(0.1, 0.5, 1),
    )

    wrap_run(run_train_vs_inference_wall_clock)(
        key=key,
        output_dir="tmp/experiments/exp_train_vs_inference_wall_clock",
        config_path="experiments/configs/config_fae.yaml",
        n_runs=5,
        downscale=2,
        ratio_rand_pts_enc_train_list=(0.1, 1),
    )

    print("\n" + "-" * 40 + "\n")
    print(f"Total time taken: {(time() - start_time) / 60:.2f} minutes")
    print("\n" + "-" * 40 + "\n")


# ==== experiments/trainer_loader.py ====
import jax
import numpy as np
from functional_autoencoders.domains import grid
from functional_autoencoders.domains import off_grid
from functional_autoencoders.encoders.cnn_encoder import CNNEncoder
from functional_autoencoders.encoders.pooling_encoder import PoolingEncoder
from experiments.custom_encoders import DiracEncoder
from experiments.custom_decoders import DiracDecoder
from functional_autoencoders.decoders.cnn_decoder import CNNDecoder
from functional_autoencoders.decoders.nonlinear_decoder import NonlinearDecoder
from functional_autoencoders.positional_encodings import (
    RandomFourierEncoding,
    IdentityEncoding,
)
from functional_autoencoders.util.networks.pooling import DeepSetPooling
from functional_autoencoders.autoencoder import Autoencoder
from 
def get_trainer(key, config, train_dataloader, test_dataloader):
    """Build an ``AutoencoderTrainer`` from a config dict and a pair of data loaders."""
    key, ae_key = jax.random.split(key)
    autoencoder = get_autoencoder(ae_key, config)

    domain = get_domain(config, train_dataloader)
    return AutoencoderTrainer(
        autoencoder=autoencoder,
        loss_fn=get_loss_fn(config, autoencoder, domain),
        metrics=get_metrics(config, autoencoder, domain),
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )


def get_autoencoder(key, config):
    """Assemble the configured encoder and decoder (sharing one positional
    encoding) into an ``Autoencoder``."""
    key, pe_key = jax.random.split(key)
    encoding = get_positional_encoding(pe_key, config)

    return Autoencoder(
        encoder=get_encoder(config, encoding),
        decoder=get_decoder(config, encoding),
    )


def get_positional_encoding(key, config):
    """Return a random Fourier-feature encoding when enabled in the config,
    otherwise the identity encoding."""
    pe_config = config["positional_encoding"]

    if not pe_config["is_used"]:
        return IdentityEncoding()

    # Fourier feature matrix B: one row per frequency pair, two input coords.
    key, normal_key = jax.random.split(key)
    b_mat = jax.random.normal(normal_key, (pe_config["dim"] // 2, 2))
    return RandomFourierEncoding(B=b_mat)
def get_encoder(config, positional_encoding):
    """Instantiate the encoder named by ``config["encoder"]["type"]``.

    Supported types: ``pooling``, ``dirac``, ``cnn``; any other value raises
    ``ValueError``.
    """
    enc_config = config["encoder"]
    enc_type = enc_config["type"]
    opts = enc_config["options"][enc_type]
    latent_dim = enc_config["latent_dim"]
    is_variational = enc_config["is_variational"]

    if enc_type == "pooling":
        return PoolingEncoder(
            latent_dim=latent_dim,
            is_variational=is_variational,
            pooling_fn=DeepSetPooling(
                mlp_dim=opts["mlp_dim"],
                mlp_n_hidden_layers=opts["mlp_n_hidden_layers"],
            ),
            positional_encoding=positional_encoding,
        )

    if enc_type == "dirac":
        return DiracEncoder(
            latent_dim=latent_dim,
            is_variational=is_variational,
            features=opts["features"],
        )

    if enc_type == "cnn":
        return CNNEncoder(
            latent_dim=latent_dim,
            is_variational=is_variational,
            cnn_features=opts["cnn_features"],
            kernel_sizes=opts["kernel_sizes"],
            strides=opts["strides"],
            mlp_features=opts["mlp_features"],
        )

    raise ValueError(f"Unknown encoder type: {enc_type}")
def get_decoder(config, positional_encoding):
    """Instantiate the decoder named by ``config["decoder"]["type"]``.

    Supported types: ``nonlinear``, ``dirac``, ``cnn``; any other value raises
    ``ValueError``.
    """
    dec_config = config["decoder"]
    dec_type = dec_config["type"]
    opts = dec_config["options"][dec_type]

    if dec_type == "nonlinear":
        return NonlinearDecoder(
            out_dim=opts["out_dim"],
            features=opts["features"],
            positional_encoding=positional_encoding,
        )

    if dec_type == "dirac":
        return DiracDecoder(
            fixed_centre=False,
            # Floor for the bump width tied to the mesh spacing dx.
            min_std=lambda dx: (1 / np.sqrt(2 * np.pi)) * dx,
            features=opts["features"],
        )

    if dec_type == "cnn":
        return CNNDecoder(
            trans_cnn_features=opts["trans_cnn_features"],
            kernel_sizes=opts["kernel_sizes"],
            strides=opts["strides"],
            mlp_features=opts["mlp_features"],
            final_cnn_features=opts["final_cnn_features"],
            final_kernel_sizes=opts["final_kernel_sizes"],
            final_strides=opts["final_strides"],
            c_in=opts["c_in"],
            grid_pts_in=opts["grid_pts_in"],
        )

    raise ValueError(f"Unknown decoder type: {dec_type}")


def get_domain(config, train_dataloader):
    """Instantiate the evaluation domain named by ``config["domain"]["type"]``."""
    dom_config = config["domain"]
    dom_type = dom_config["type"]
    opts = dom_config["options"][dom_type]

    if dom_type == "grid_zero_boundary_conditions":
        return grid.ZeroBoundaryConditions(s=opts["s"])

    if dom_type == "off_grid_randomly_sampled_euclidean":
        return off_grid.RandomlySampledEuclidean(s=opts["s"])

    if dom_type == "off_grid_sde":
        # The SDE domain takes its parameters from the data section and the
        # dataset's initial condition rather than from the options block.
        return off_grid.SDE(
            epsilon=config["data"]["epsilon"],
            x0=train_dataloader.dataset.x0[0],
        )

    raise ValueError(f"Unknown domain type: {dom_type}")
def get_loss_fn(config, autoencoder, domain):
    """Build the training loss named by ``config["loss"]["type"]``.

    Supported types: ``fae``, ``vano``, ``fvae_sde``; any other value raises
    ``ValueError``.
    """
    loss_config = config["loss"]
    loss_type = loss_config["type"]
    opts = loss_config["options"][loss_type]

    if loss_type == "fae":
        return get_loss_fae_fn(
            autoencoder=autoencoder,
            domain=domain,
            beta=opts["beta"],
            subtract_data_norm=opts["subtract_data_norm"],
        )

    if loss_type == "vano":
        # The VANO loss does not take a domain argument.
        return get_loss_vano_fn(
            autoencoder=autoencoder,
            rescale_by_norm=opts["rescale_by_norm"],
            normalised_inner_prod=opts["normalised_inner_prod"],
            beta=opts["beta"],
            n_monte_carlo_samples=opts["n_monte_carlo_samples"],
        )

    if loss_type == "fvae_sde":
        return get_loss_fvae_sde_fn(
            autoencoder=autoencoder,
            domain=domain,
            beta=opts["beta"],
            theta=opts["theta"],
            zero_penalty=opts["zero_penalty"],
            n_monte_carlo_samples=opts["n_monte_carlo_samples"],
        )

    raise ValueError(f"Unknown loss type: {loss_type}")


def get_metrics(config, autoencoder, domain):
    """Metrics tracked during training: test-set MSE in the L^2 norm."""
    return [MSEMetric(autoencoder, domain=domain)]
def get_mse_losses_per_quantity_over_time(data_output_dir, quantity):
    """Aggregate test-MSE histories across repeated runs, keyed by a swept quantity.

    Expects the layout ``data_output_dir/<run_idx>/<quantity_dir>/data.pickle``,
    where each pickle stores the metric history, the total training time, and
    the value of ``quantity`` for that run.  Per-run MSE curves are averaged
    over runs and placed on a uniform time axis from 0 to the mean total
    training time.

    :param data_output_dir: root directory with one subdirectory per run index.
    :param quantity: key into the pickle's ``additional_data`` naming the swept quantity.
    :return: dict mapping each quantity value to a ``{time: mean MSE}`` dict.
    """
    mse_over_runs = {}
    training_times_over_runs = {}

    for run_idx_str in os.listdir(data_output_dir):
        run_dir = os.path.join(data_output_dir, run_idx_str)
        for quantity_dir in os.listdir(run_dir):
            result = pickle_load(os.path.join(run_dir, quantity_dir, "data.pickle"))

            mse = result["training_results"]["metrics_history"]["MSE (in L^{2})"]
            training_time = result["additional_data"]["training_time"]
            # Group by the value recorded in the pickle rather than by the
            # directory name (which is only its string representation).
            quantity_value = result["additional_data"][quantity]

            mse_over_runs.setdefault(quantity_value, []).append(mse)
            training_times_over_runs.setdefault(quantity_value, []).append(training_time)

    mse_test_losses_by_quantity_over_time = {}
    for quantity_value, mse_list in mse_over_runs.items():
        mse_values_mean = np.mean(np.array(mse_list), axis=0)
        training_times_mean = np.mean(np.array(training_times_over_runs[quantity_value]))
        # Metrics were logged at evenly spaced intervals, so reconstruct the
        # time axis as a uniform grid over the mean total training time.
        t_range = np.linspace(0, training_times_mean, len(mse_values_mean))
        mse_test_losses_by_quantity_over_time[quantity_value] = dict(
            zip(t_range, mse_values_mean)
        )

    return mse_test_losses_by_quantity_over_time
"Welcome to the quickstart series for `functional_autoencoders`, a code package implementing\n",
\n", 21 | "But we suggest that you first read the following two-minute introduction to `functional_autoencoders` so you can use the most appropriate model for your problems.\n", 22 | "\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "### Why use autoencoders in function space?\n", 30 | "\n", 31 | "Machine-learning methods for image data, e.g., convolutional neural networks, are usually formulated for a specific pixellation resolution.\n", 32 | "This means that all data must be provided on the same grid of pixels at training and inference time, even though the underlying images could in principle be stored at arbitrarily fine resolution.\n", 33 | "\n", 34 | "![Fixed-resolution ML algorithm can only accept data at the training resolution](./images/fixed-res_algorithms.png)\n", 35 | "\n", 36 | "Viewing the pixellated images as discrete representations of **functional data**, conceiving algorithms that operate directly on functions, and only then discretising, it is possible to use the same machine-learning model for training and inference across resolutions.\n", 37 | "\n", 38 | "![Diagram of conceiving of algorithms at function-space level and discretising](images/algorithms_on_fn_space.png)\n", 39 | "\n", 40 | "In scientific machine learning, this has led to significant interest in learnable mappings between function spaces such as [DeepONet](https://www.nature.com/articles/s42256-021-00302-5) and [neural operators](https://jmlr.org/papers/v24/21-1524.html).\n", 41 | "The `functional_autoencoders` package adopts this philosophy for autoencoders, which allows for:\n", 42 | "- training with data at any resolution, potentially with missing mesh points;\n", 43 | "- nonlinear dimension reduction for data provided at any resolution; and\n", 44 | "- encoding and decoding on different meshes, allowing for inpainting and superresolution.\n", 45 | "\n", 46 | "![FAE inpainting and 
superresolution](images/FAE_inpainting_superresolution.png)\n", 47 | "\n", 48 | "\n", 49 | "\n", 50 | "### Should I use FAE or FVAE?\n", 51 | "\n", 52 | "**Summary**: FAE is a good starting point that works \"out of the box\" for most data, whereas FVAE works only for specific types of data.\n", 53 | "\n", 54 | "![Flowchart to aid in deciding whether to use FVAE or FAE](images/fae_or_fvae_flowchart.png)\n", 55 | "\n", 56 | "FVAE uses variational inference to learn a probabilistic encoder and decoder.\n", 57 | "When the FVAE objective is well defined, it gives a very natural extension of VAEs to function space.\n", 58 | "But this training objective is only well defined for specific types of data, such as:\n", 59 | "- path distributions of stochastic differential equations (SDEs); and\n", 60 | "- Bayesian posterior distributions arising from Gaussian priors and \"nice\" forward models.\n", 61 | "\n", 62 | "You can read more about this in sections 2.3 and 3 of [the paper](https://arxiv.org/pdf/2408.01362).\n", 63 | "\n", 64 | "In contrast, FAE is not motivated as a probabilistic model: it is a regularised autoencoder that is well defined in function space. \n", 65 | "This makes FAE much more broadly applicable to many datasets in scientific machine learning (see section 4 of [the paper](https://arxiv.org/pdf/2408.01362)).\n", 66 | "Since FAE is not a probabilistic model, some extra work is needed to use FAE as a generative model for functional data, as we'll explain in the second quickstart notebook." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## What next?\n", 74 | "Start with either:\n", 75 | "1. [An Introduction to FVAE](./1_FVAE.ipynb); or\n", 76 | "2. [An Introduction to FAE](./2_FAE.ipynb).\n", 77 | "\n", 78 | "You can read these in any order, and they'll show you how to get started with a simple FVAE/FAE model." 
79 | ] 80 | } 81 | ], 82 | "metadata": { 83 | "language_info": { 84 | "name": "python" 85 | } 86 | }, 87 | "nbformat": 4, 88 | "nbformat_minor": 2 89 | } 90 | -------------------------------------------------------------------------------- /quickstart/images/FAE_inpainting_superresolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/FAE_inpainting_superresolution.png -------------------------------------------------------------------------------- /quickstart/images/FAE_self-supervised_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/FAE_self-supervised_training.png -------------------------------------------------------------------------------- /quickstart/images/FVAE_decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/FVAE_decoder.png -------------------------------------------------------------------------------- /quickstart/images/FVAE_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/FVAE_encoder.png -------------------------------------------------------------------------------- /quickstart/images/algorithms_on_fn_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/algorithms_on_fn_space.png 
-------------------------------------------------------------------------------- /quickstart/images/fae_or_fvae_flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/fae_or_fvae_flowchart.png -------------------------------------------------------------------------------- /quickstart/images/fixed-res_algorithms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/quickstart/images/fixed-res_algorithms.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | jax 3 | # jax[cuda12] 4 | jaxlib 5 | flax 6 | numpy 7 | scipy 8 | h5py 9 | requests 10 | matplotlib 11 | seaborn 12 | dill 13 | ipywidgets 14 | scikit-learn 15 | -------------------------------------------------------------------------------- /src/functional_autoencoders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Autoencoders in Function Space 3 | 4 | Welcome to the documentation for the official code accompanying the paper [***Autoencoders in Function Space***](https://arxiv.org/abs/2408.01362) by Justin Bunker, Mark Girolami, Hefin Lambley, Andrew M. Stuart, and T. J. Sullivan. 5 | 6 | To get started, head to the README page of the repository or access the API documentation in the sidebar. 
class Autoencoder(nn.Module):
    r"""A flexible autoencoder designed for operator encoders and decoders.

    All encoders should take inputs:
        `u` : jnp.array of shape [batch, n_evals, out_dim]
            Represents `n_evals` evaluations of functions :math:`u \colon [0, 1]^{\text{in\_dim}} \to \mathbb{R}^{\text{out\_dim}}`.
            The evaluations do NOT include the boundary of :math:`[0,1]^{\text{in\_dim}}`

        `x` : jnp.array of shape [n_evals, in_dim]
            Represents the mesh points upon which :math:`u` is evaluated.

    Encoders should return an array of shape [batch, 2 * latent_dim], with the first latent_dim components representing the encoder
    mean and the second latent_dim components representing the log-variances on the diagonal of the encoder covariance.

    All decoders should take inputs:
        `z` : jnp.array of shape [batch, latent_dim]
            Represents the latent variables

        `x` : jnp.array of shape [n_evals, in_dim]
            Represents the mesh grids to evaluate the output function.

    Decoders should return an array of the same shape as the input `u` to the encoder.

    Notes:
        - It is assumed that the input mesh and the output mesh are the same, and that the mesh is the same for each example in
        the batch, so `x` does *not* have a batch dimension.

        - Calling the `Autoencoder` directly will map (u, x) to the latents z using the encoder, then (without adding the encoder noise)
        map straight back using the decoder (without adding any decoder noise).
    """

    encoder: Encoder
    decoder: Decoder

    @nn.compact
    def __call__(self, u, x_enc, x_dec, train=False):
        """Encode ``(u, x_enc)`` and immediately decode on mesh ``x_dec``.

        For a variational encoder only the mean is decoded: no latent noise is
        sampled in this deterministic round trip.
        """
        z = self.encoder(u, x_enc, train)

        if self.encoder.is_variational:
            latent_dim = self.get_latent_dim()
            # Encoder output packs [mean | log-variance]; the variances are
            # discarded here on purpose.
            mean, _ = (
                z[:, :latent_dim],
                z[:, latent_dim:],
            )
            return self.decoder(mean, x_dec, train)
        else:
            return self.decoder(z, x_dec, train)

    def encode(self, state, u, x, train=False):
        """Apply the encoder using the ``encoder`` sub-trees of ``state``'s
        parameters and batch statistics."""
        return self.encoder.apply(
            {
                "params": state.params["encoder"],
                "batch_stats": state.batch_stats["encoder"],
            },
            u,
            x,
            train,
        )

    def decode(self, state, z, x, train=False):
        """Apply the decoder using the ``decoder`` sub-trees of ``state``'s
        parameters and batch statistics."""
        return self.decoder.apply(
            {
                "params": state.params["decoder"],
                "batch_stats": state.batch_stats["decoder"],
            },
            z,
            x,
            train,
        )

    def get_latent_dim(self):
        """Delegate to the encoder's latent dimensionality."""
        return self.encoder.get_latent_dim()
class DarcyFlow(DownloadableDataset):
    r"""
    Solution pressure fields $p \colon \Omega \to \mathbb{R}$ for the Darcy flow steady-state model of flow in a porous medium,
    with domain $\Omega = [0, 1]^{2}$, as described in section 4.3.2 of [the paper](https://arxiv.org/abs/2408.01362), solving the
    partial-differential equation

    $$ - \nabla \cdot \bigl(k \nabla p \bigr) = \varphi \text{~on $\Omega$}, $$
    $$ p = 0 \text{~on $\partial \Omega$}, $$

    with $\varphi = 1$ and $k$ distributed randomly as the pushforward of the measure $N(0, (-\Delta + 9I)^{-2})$ under the map

    $$\psi(x) = 3 + 9 \cdot 1 \bigl[ x \geq 0 \bigr].$$

    This is based on the dataset of Li et al. (2021) and consists of 1,024 train and 1,024 test samples discretised on a $421 \times 421$
    grid.

    Run once with save_fast=True to save the fast data files.
    Subsequent runs can be done with load_fast=True to load the fast data files.

    ## References
    Li, Kovachki, Azizzadenesheli, Liu, Bhattacharya, Stuart, and Anandkumar (2021). Fourier neural operator for parametric partial differential equations. ICLR 2021.
    arXiv:2010.08895.
    """

    data_url = "https://drive.google.com/u/0/uc?id=1Z1uxG9R8AdAGJprG5STcphysjm56_0Jf&export=download&confirm=t&uuid=9d0c35a0-3979-4852-b8fd-c1d4afec423c&at=AB6BwCA0wHtyid20GZfaIBVJ4aQv:1702379684316"
    download_filename = "Darcy_421.zip"
    # Set per train/test split inside _load_data_slow / _load_data_fast.
    dataset_filename = ""
    dataset_name = "fno_darcy"

    def __init__(
        self,
        downscale=-1,
        save_fast=False,
        load_fast=False,
        transform=None,
        *args,
        **kwargs,
    ):
        """
        :param downscale: integer downsampling factor applied with Fourier
            anti-aliasing; -1 keeps the full resolution.
        :param save_fast: pickle the preprocessed arrays after loading the .mat file.
        :param load_fast: load the previously pickled arrays instead of the .mat file.
        :param transform: optional callable ``(u, x) -> sample`` applied in ``__getitem__``.
        """
        self.transform = transform
        self.downscale = downscale
        self.save_fast = save_fast
        self.load_fast = load_fast

        super().__init__(*args, **kwargs)

    def _preprocess_data(self):
        # Extract the downloaded archive into the dataset directory.
        with zipfile.ZipFile(self.download_path, "r") as f:
            f.extractall(self.dataset_dir)

    def _get_slow_data_filename(self, train):
        # The "smooth1" .mat file is the train split; "smooth2" is the test split.
        if train:
            return "piececonst_r421_N1024_smooth1.mat"
        else:
            return "piececonst_r421_N1024_smooth2.mat"

    def _get_fast_data_filename(self, train):
        # The fast (pickled) file name mirrors the .mat name with a .pkl suffix.
        slow_data_filename = self._get_slow_data_filename(train)
        fast_data_filename = slow_data_filename.replace(".mat", "_fast.pkl")
        return fast_data_filename

    def _load_data(self, train):
        # Dispatch between the pickled fast path and the raw .mat slow path.
        if self.load_fast:
            self._load_data_fast(train)
        else:
            self._load_data_slow(train)

    def _load_data_slow(self, train):
        """Load the raw .mat split, optionally downsample, normalise to [0, 1],
        and optionally save the fast pickle."""
        self.dataset_filename = self._get_slow_data_filename(train)

        # Only the solution field "sol" is used here ("coeff" is requested but unused).
        mat = scipy.io.loadmat(self.dataset_path, variable_names=["coeff", "sol"])
        u = mat["sol"].astype(float)

        n = u.shape[-1]
        x = get_raw_x(n, n).reshape(n, n, 2)
        x = np.array(x)

        if self.downscale != -1:
            # Low-pass filter before striding to avoid aliasing artefacts.
            aam = AntiAliasingManagerFourier(
                cutoff_nyq=0.99,
                mask_blur_kernel_size=7,
                gaussian_sigma=0.1,
            )

            u = aam.downsample(u, self.downscale)
            x = x[::self.downscale, ::self.downscale, :]

        u = u.reshape(u.shape[0], -1, 1)
        x = x.reshape(-1, 2)

        # Min-max normalisation over the whole split (global min/max, not per sample).
        u = (u - u.min()) / (u.max() - u.min())
        self.data = {
            "u": u,
            "x": x,
        }

        if self.save_fast:
            print("Saving fast data")

            save_filename = self._get_fast_data_filename(train)
            save_path = os.path.join(self.dataset_dir, save_filename)
            pickle_save(self.data, save_path)

            print("Done!")

    def _load_data_fast(self, train):
        # Load the previously pickled (already preprocessed) arrays.
        self.dataset_filename = self._get_fast_data_filename(train)
        load_path = os.path.join(self.dataset_dir, self.dataset_filename)
        self.data = pickle_load(load_path)

    def __len__(self):
        return self.data["u"].shape[0]

    def __getitem__(self, idx):
        u = self.data["u"][idx].reshape(-1, 1)
        x = self.data["x"].reshape(-1, 2)

        if self.transform is not None:
            return self.transform(u, x)
        else:
            # Default sample layout: (u_enc, x_enc, u_dec, x_dec) with shared mesh.
            return u, x, u, x
class RandomDirac(GenerableDataset):
    """
    Dataset representing Dirac masses with random, uniformly chosen centre in :math:`(0, 1)`.

    For numerical purposes, the Dirac mass is represented by a function which is zero except at one mesh point,
    with height chosen such that the left/right Riemann sum has constant mass 1 at any resolution.
    """

    def __init__(
        self,
        fixed_centre,
        pts=128,
        transform=None,
        *args,
        **kwargs,
    ):
        """
        :param fixed_centre: if True, every sample's Dirac sits at the middle mesh point.
        :param pts: number of interior mesh points in (0, 1).
        :param transform: optional callable ``(u, x) -> sample`` applied in ``__getitem__``.
        """
        self._fixed_centre = fixed_centre
        self._pts = pts
        self.transform = transform
        super().__init__(*args, **kwargs)

    def generate(self):
        # Interior mesh: pts points of spacing dx = 1 / (pts + 1), endpoints excluded.
        x = np.linspace(0, 1, self._pts + 2)[1:-1]
        x = np.expand_dims(x, -1)
        if self._fixed_centre:
            # Two identical samples centred at the midpoint (one each for train/test).
            centres = np.ones((2,), dtype=np.int32) * self._pts // 2
        else:
            # NOTE(review): despite the class docstring's "uniformly chosen centre",
            # these centres are deterministic: six fixed positions at multiples of
            # pts/8, duplicated once (first copy -> train, second -> test).
            centres = np.tile(np.arange(8)[1:-1] * int(self._pts / 8), 2)
        masses = np.ones_like(centres)
        # height = 1 / dx, so height * dx gives unit mass at any resolution.
        height = self._pts + 1
        u = np.zeros((centres.shape[0], self._pts))
        for i, c in enumerate(centres):
            u[i, c] = height * masses[i]
        u = np.expand_dims(u, -1)

        self.data = {"u": u, "x": x, "masses": masses, "centres": centres}

    @property
    def x(self):
        # Mesh points, shape (pts, 1).
        return self.data["x"][:]

    @property
    def masses(self):
        return self.data["masses"][:]

    @property
    def centres(self):
        return self.data["centres"][:]

    def __len__(self):
        # Half the generated samples per split (train = first half, test = second).
        return self.data["u"].shape[0] // 2

    def __getitem__(self, idx):
        # Test split indexes into the second half of the generated samples.
        if not self.train:
            idx += self.data["u"].shape[0] // 2
        u = self.data["u"][idx]
        x = self.data["x"][:]
        if self.transform is not None:
            return self.transform(u, x)
        else:
            return u, x, u, x
from functional_autoencoders.datasets import DownloadableDataset


class NavierStokes(DownloadableDataset):
    r"""The Navier--Stokes dataset for a viscous, incompressible fluid in two dimensions as used by Li et al. (2021),
    as described in [the paper](https://arxiv.org/abs/2408.01362).

    ## Dataset

    The dataset consists of pairs `(initial_condition, trajectory)`.
    The initial condition is generated from a Gaussian random field with mean
    zero and covariance operator

    $$ \mathcal{C} = 7^{3/2} \bigl(-\Delta + 49I\bigr)^{-5/2}.$$

    The trajectory consists of $T$ timesteps evolved from the initial condition
    with a step length $\Delta t$.
    The forcing term $\varphi$ is given by

    $$ f(x) = 0.1 \Bigl( \sin\bigl( 2\pi (x_{1} + x_{2}) \bigr) + \cos\bigl( 2\pi(x_{1} + x_{2}) \bigr) \Bigr).$$

    The possible choices of parameters are as follows:
    - `viscosity = 1e-3`, `resolution = 64`. This dataset has 5,000 trajectories of $T = 50$ seconds with step $\Delta t = 1$ second.
    - `viscosity = 1e-4`, `resolution = 64`. This dataset has 10,000 trajectories of $T = 50$ seconds with step $\Delta t = 1$ second.
    - `viscosity = 1e-4`, `resolution = 256`. This dataset has 20 trajectories of $T = 50$ seconds with step $\Delta t = 0.25$ seconds.
    - `viscosity = 1e-5`, `resolution = 64`. This dataset has 1,200 trajectories of $T = 20$ seconds with step $\Delta t = 1$ second.

    ## Notes

    In Li et al. (2021), the $\nu = 10^{-4}$ dataset is simulated only up to $T = 30$ seconds,
    but the dataset actually includes simulations up to $T = 50$.
    """

    dataset_name = "navier_stokes"
    # Data URL and filename are dynamically determined based on resolution and viscosity choice.
    data_url = ""
    download_filename = ""
    dataset_filename = ""
    # Whether the .mat file is an HDF5 (MATLAB v7.3) file readable with h5py,
    # as opposed to a legacy .mat readable with scipy.io.loadmat.
    _is_h5 = True

    def __init__(
        self,
        viscosity=1e-3,
        resolution=64,
        time_idx=-1,
        train_test_split_ratio=0.8,
        save_fast=False,
        load_fast=False,
        transform=None,
        *args,
        **kwargs,
    ):
        self.train_test_split_ratio = train_test_split_ratio
        self.time_idx = time_idx
        self.save_fast = save_fast
        self.load_fast = load_fast

        # Validate the (viscosity, resolution) combination; see class docstring
        # for the available datasets.
        if viscosity not in [1e-3, 1e-4, 1e-5]:
            raise ValueError("Viscosity must take value 1e-3, 1e-4 or 1e-5")

        if resolution not in [64, 256]:
            raise ValueError(
                "Navier--Stokes dataset only available at 64x64 or 256x256 resolution. See documentation for valid combinations."
            )

        if resolution != 64 and viscosity != 1e-4:
            raise ValueError(
                "Navier--Stokes dataset only available at 256x256 when viscosity is 1e-4."
            )

        self.viscosity = viscosity
        self.resolution = resolution
        self.transform = transform

        if viscosity == 1e-3:
            self.data_url = "https://drive.usercontent.google.com/download?id=1r3idxpsHa21ijhlu3QQ1hVuXcqnBTO7d&export=download&authuser=0&confirm=t&uuid=05b098fa-6a5b-40cd-9b0f-39fa5b7c9261&at=APZUnTWMn104jdp7fiMuS7y5sxL7:1702487907119"
            self.download_filename = "NavierStokes_V1e-3_N5000_T50.zip"
            self.dataset_filename = "ns_V1e-3_N5000_T50.mat"

        elif viscosity == 1e-4:
            if resolution == 64:
                self.data_url = "https://drive.usercontent.google.com/download?id=1RmDQQ-lNdAceLXrTGY_5ErvtINIXnpl3&export=download&authuser=0&confirm=t&uuid=a218d5ab-1b75-4b1c-a5da-ed0b71bd3f20&at=APZUnTXLTgsSn6kzgcqTUn2fwDTk:1702490181229"
                self.download_filename = "NavierStokes_V1e-4_N10000_T30.zip"
                self.dataset_filename = "ns_V1e-4_N10000_T30.mat"
            else:
                self.data_url = "https://drive.usercontent.google.com/download?id=1pr_Up54tNADCGhF8WLvmyTfKlCD5eEkI&export=download&authuser=0&confirm=t&uuid=d6ce8938-295e-40a4-9f87-d24ebdabf310&at=APZUnTUbM3Hu0GZiNSbQuzuyH7fr:1702639561668"
                self.download_filename = "NavierStokes_V1e-4_N20_T50_R256_test.zip"
                self.dataset_filename = "ns_data_V1e-4_N20_T50_R256test.mat"
                self._is_h5 = False

        elif viscosity == 1e-5:
            self.data_url = "https://drive.usercontent.google.com/download?id=1lVgpWMjv9Z6LEv3eZQ_Qgj54lYeqnGl5&export=download&authuser=0&confirm=t&uuid=68addf1d-8b63-4591-b32a-6ee2b4886e66&at=APZUnTU4rRh0Qwj4UUGbBBxetFUR:1702490348128"
            self.download_filename = "NavierStokes_V1e-5_N1200_T20.zip"
            self.dataset_filename = "NavierStokes_V1e-5_N1200_T20.mat"
            self._is_h5 = False

        else:
            raise NotImplementedError()

        super().__init__(*args, **kwargs)

    def _load_data(self, train):
        # Either load a pre-pickled fast copy or parse the raw .mat file.
        if self.load_fast:
            self._load_data_fast(train)
        else:
            self._load_data_slow(train)

    def _load_data_slow(self, train):
        if self._is_h5:
            # HDF5 layout indexes as [time, x, y, sample]; move the sample
            # axis to the front to get [sample, x, y].
            data = h5py.File(self.dataset_path, "r")
            self.u_data = np.moveaxis(data["u"][self.time_idx, :, :, :], -1, 0)
        else:
            # Legacy .mat layout indexes as [sample, x, y, time].
            data = scipy.io.loadmat(self.dataset_path)
            self.u_data = data["u"][:, :, :, self.time_idx]

        n_train = int(self.train_test_split_ratio * self.u_data.shape[0])

        if train:
            self.u_data = self.u_data[:n_train]
        else:
            self.u_data = self.u_data[n_train:]

        # NOTE(review): min-max normalisation is computed on the selected split
        # only, so train and test are rescaled with different statistics —
        # confirm this is intended.
        self.u_data = (self.u_data - self.u_data.min()) / (
            self.u_data.max() - self.u_data.min()
        )
        self.x = get_raw_x(*self.u_data.shape[1:3])
        self.x = np.array(self.x)

        if self.save_fast:
            self._save_data_fast(train)

    def _save_data_fast(self, train):
        # Pickle the parsed arrays next to the raw dataset for faster reloads.
        print("Saving fast data")

        save_filename = self._get_fast_data_filename(train)
        save_path = os.path.join(self.dataset_dir, save_filename)
        pickle_save(
            {
                "u_data": self.u_data,
                "x": self.x,
            },
            save_path,
        )

        print("Done!")

    def _load_data_fast(self, train):
        # NOTE(review): this overwrites self.dataset_filename with the .pkl
        # name, so a later _get_fast_data_filename call would no longer find
        # the ".mat" suffix to replace — verify this is never re-entered.
        self.dataset_filename = self._get_fast_data_filename(train)
        load_path = os.path.join(self.dataset_dir, self.dataset_filename)
        data = pickle_load(load_path)
        self.u_data = data["u_data"]
        self.x = data["x"]

    def _get_fast_data_filename(self, train):
        filename_suffix = ("_train" if train else "_test") + "_fast.pkl"
        fast_data_filename = self.dataset_filename.replace(".mat", filename_suffix)
        return fast_data_filename

    def _preprocess_data(self):
        # Unzip the downloaded archive into the dataset directory.
        with zipfile.ZipFile(self.download_path, "r") as f:
            f.extractall(self.dataset_dir)

    def __len__(self):
        return self.u_data.shape[0]

    def __getitem__(self, idx):
        # Flatten the 2-D field to the [n_evals, out_dim] "sparse" convention.
        u = self.u_data[idx].reshape(-1, 1)
        x = self.x.reshape(-1, 2)

        if self.transform is not None:
            return self.transform(u, x)
        else:
            return u, x, u, x
--------------------------------------------------------------------------------
/src/functional_autoencoders/datasets/sde.py:
--------------------------------------------------------------------------------
import numpy as np
import jax
from functional_autoencoders.util.random.sde import euler_maruyama
from functional_autoencoders.datasets import GenerableDataset


def get_brownian_dynamics_drift(potential, *args, **kwargs):
    # Drift of Brownian/Langevin dynamics: the negative gradient of the
    # potential, vectorised over the batch and JIT-compiled.
    neg_pot = lambda x: -potential(x, *args, **kwargs)
    net_pot_grad = jax.grad(neg_pot)
    neg_pot_grad_vmap = jax.vmap(net_pot_grad, in_axes=(0,))
    neg_pot_grad_vmap_jit = jax.jit(neg_pot_grad_vmap)
    return lambda X, t: neg_pot_grad_vmap_jit(X)


def get_brownian_dynamics_diffusion(epsilon):
    # Constant diffusion sqrt(epsilon) applied to the Brownian increment dwt.
    return lambda x, dwt, t: (epsilon ** (0.5)) * dwt


class SDE(GenerableDataset):
    def __init__(
        self,
        drift,
        diffusion,
        x0,
        samples=200,
        pts=100,
        T=1,
        sim_dt=1e-3,
        verbose=False,
        transform=None,
        transform_generated=None,
        *args,
        **kwargs,
    ):
        # The SDE is simulated at the fine step sim_dt and subsampled to the
        # coarser observation step dt = T / pts.
        self._samples = samples
        self._pts = pts
        self._sim_dt = sim_dt
        self._dt = T / pts
        self._subsample_rate = int(self._dt / sim_dt)
        self._n_steps = pts * self._subsample_rate
        self.drift = drift
        self.diffusion = diffusion
        self.verbose = verbose
        self._x0 = x0
        self.transform = transform
        self.transform_generated = transform_generated
        super().__init__(*args, **kwargs)

    def generate(self):
        if self._samples <= 0 or self._pts <= 0:
            raise ValueError(
                "To generate dataset, need number of realisations `samples > 0` and grid points `pts > 0`"
            )

        # All realisations start from the same initial condition x0.
        x0 = np.repeat(np.expand_dims(self._x0, 0), self._samples, axis=0)
        u = euler_maruyama(
            x0,
            self.drift,
            self.diffusion,
            self._sim_dt,
            self._n_steps,
            self._subsample_rate,
            self.verbose,
        )

        # Observation times 0, dt, ..., T (pts + 1 points, presumably matching
        # the subsampled trajectory including the initial condition — confirm
        # against euler_maruyama).
        x = np.arange(0, self._pts + 1) * self._dt
        x = np.expand_dims(x, -1)

        if self.transform_generated is not None:
            u, x = self.transform_generated(u, x)

        self.data = {
            "u": u,
            "x": x,
            "x0": x0,
            "sim_dt": self._sim_dt,
            "samples": self._samples,
        }

    def __len__(self):
        return self.data["samples"]

    def __getitem__(self, idx):
        u = self.data["u"][idx]
        x = self.data["x"][:]
        if self.transform is not None:
            return self.transform(u, x)
        else:
            return u, x, u, x

    @property
    def x0(self):
        return self.data["x0"][:]

    @property
    def x(self):
        return self.data["x"][:]

    @property
    def sim_dt(self) -> float:
        return self.data["sim_dt"]
--------------------------------------------------------------------------------
/src/functional_autoencoders/datasets/vano.py:
--------------------------------------------------------------------------------
from functional_autoencoders.datasets import GenerableDataset
from functional_autoencoders.util.random import grf
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import scipy
import itertools


class GRF(GenerableDataset):
    r"""Realisations from a Gaussian process on [0, 1] with zero boundary conditions and Matérn-like covariance operator.

    The data are realisations from a Gaussian measure $N(0, C)$ on the space $L^{2}([0, 1])$ with the covariance operator given
    as a Matérn-like inverse power of the Dirichlet Laplacian:

    $$ C = (\tau^{2} I - \Delta)^{d}. $$

    The realisations are generated using the `dirichlet_grf` function in the `random` module, which uses a discrete sine transformation
    to efficiently generate realisations of the Gaussian process.

    Notes:
        This dataset supports transformations using transform (e.g. for rescaling). The VANO implementation of Seidman et al. (2023)
        instead applies a rescaling *in the ELBO loss* and this is preferable when trying to reproduce their results.

    References:
        Seidman et al. (2023). Variational autoencoding neural operator. ICML.
    """

    def __init__(
        self,
        samples=200,
        pts=2000,
        tau=3.0,
        d=2.0,
        even_powers_only=False,
        dim=1,
        out_dim=1,
        transform=None,
        *args,
        **kwargs,
    ):
        self._samples = samples
        self._pts = pts
        self._tau = tau
        self._d = d
        self._even_powers_only = even_powers_only
        self._dim = dim
        self._out_dim = out_dim
        self.transform = transform
        super().__init__(*args, **kwargs)

    def generate(self):
        if not self.train:
            raise NotImplementedError(
                "Test/train split not yet implemented; only train data available"
            )

        if self._samples <= 0 or self._pts <= 0:
            raise ValueError(
                "To generate `GRF` dataset, need number of realisations `samples > 0` and grid points `pts > 0`"
            )

        # Double the sample count so the first half serves as train data and
        # the second half as test data (see __len__/__getitem__).
        self._samples *= 2
        grid = [self._pts] * self._dim
        u = grf.dirichlet_grf(
            None,
            self._samples,
            grid,
            tau=self._tau,
            d=self._d,
            even_powers_only=self._even_powers_only,
            out_dim=self._out_dim,
        )
        # Interior points of [0, 1]^dim (zero boundary conditions).
        x1 = np.linspace(0, 1, self._pts + 2)[1:-1]
        xs = np.meshgrid(*([x1] * self._dim), indexing="ij")
        xs = [np.expand_dims(v, -1) for v in xs]
        x = np.concatenate(xs, axis=-1)

        # Flatten to the "sparse" convention: u [samples, n_evals, out_dim],
        # x [n_evals, in_dim].
        u = np.reshape(u, (u.shape[0], -1, self._out_dim))
        x = np.reshape(x, (-1, x.shape[-1]))

        self.data = {
            "u": u,
            "x": x,
            "d": self._d,
            "tau": self._tau,
            "even_powers_only": self._even_powers_only,
        }

    def __len__(self):
        return self.data["u"].shape[0] // 2

    def __getitem__(self, idx):
        # The second half of the doubled sample set is the test split.
        if not self.train:
            idx += len(self)

        u = self.data["u"][idx]
        x = self.data["x"][:]
        if self.transform is not None:
            return self.transform(u, x)
        else:
            return u, x, u, x

    @property
    def pts(self) -> int:
        return self.data["u"].shape[1]

    @property
    def tau(self) -> float:
        return self.data["tau"]

    @property
    def d(self) -> float:
        return self.data["d"]

    @property
    def x(self) -> np.ndarray:
        return self.data["x"][:]

    @property
    def even_powers_only(self) -> bool:
        return self.data["even_powers_only"]

    @partial(jax.vmap, in_axes=(None, 0))
    def scaled_basis(self, x):
        # Eigenfunction basis sqrt(lambda_n) * 2 sin(pi n x) of the covariance
        # operator; only implemented for scalar-valued data in one dimension.
        if x.shape[-1] != 1 or self.data["u"].shape[-1] != 1:
            raise NotImplementedError()
        n = jnp.reshape(jnp.arange(1, x.shape[0] + 1), (1, -1))
        basis = 2 * jnp.sin(jnp.pi * n * x)
        sqrt_eigs = np.reshape(
            grf._compute_dirichlet_covariance_operator_sqrt_eigenvalues(
                (x.shape[0],),
                tau=self.tau,
                d=self.d,
                even_powers_only=self.even_powers_only,
            ),
            (1, -1),
        )
        return sqrt_eigs * basis


class GaussianDensities(GenerableDataset):
    # Random mixtures of isotropic Gaussian density bumps on the unit square.
    def __init__(
        self,
        samples=2048,
        pts=48,
        n_gaussians=1,
        std_min=0.01,
        std_max=0.11,
        transform=None,
        *args,
        **kwargs,
    ):
        self._samples = samples
        self._pts = pts
        self._n_gaussians = n_gaussians
        self._std_min = std_min
        self._std_max = std_max
        self.transform = transform
        super().__init__(*args, **kwargs)

    def generate(self):
        if self._samples <= 0 or self._pts <= 0:
            raise ValueError(
                "To generate `GaussianDensities` dataset, need number of realisations `samples > 0` and grid points `pts > 0`"
            )

        # Generate equal test--train split (i.e. _samples of train and _samples of test)
        self._samples *= 2
        means = np.random.uniform(size=(self._samples, self._n_gaussians, 2))
        stds = np.random.uniform(
            low=self._std_min,
            high=self._std_max,
            size=(self._samples, self._n_gaussians),
        )
        x = np.linspace(0, 1, self._pts + 1)[:-1]
        xs = np.array(list(itertools.product(x, x)))
        u = np.zeros((self._samples, self._pts, self._pts))
        for sample in range(0, self._samples):
            for i in range(0, self._n_gaussians):
                out = np.reshape(
                    scipy.stats.multivariate_normal.pdf(
                        xs,
                        mean=means[sample, i, :],
                        cov=((stds[sample, i]) ** 2) * np.eye(2),
                    ),
                    (self._pts, self._pts),
                )
                u[sample, :, :] += out
        u = u.reshape(self._samples, -1, 1)

        # NOTE(review): u is flattened over the 2-D grid (pts^2 evaluation
        # points) but the stored "x" is the 1-D axis of shape (pts,), not the
        # product grid `xs` of shape (pts^2, 2) — confirm whether `xs` should
        # be stored instead to match u's evaluation points.
        self.data = {"u": u, "x": x}

    def __len__(self):
        return self.data["u"].shape[0] // 2

    def __getitem__(self, idx):
        # The second half of the doubled sample set is the test split.
        if not self.train:
            idx += len(self)
        u = self.data["u"][idx]
        x = self.data["x"][:]
        if self.transform is not None:
            return self.transform(u, x)
        else:
            return u, x, u, x

    @property
    def x(self) -> np.ndarray:
        return self.data["x"][:]
--------------------------------------------------------------------------------
/src/functional_autoencoders/decoders/__init__.py:
--------------------------------------------------------------------------------
import flax.linen as nn
import jax.numpy as jnp


class Decoder(nn.Module):
    """
    For general comments, see `Autoencoder` documentation.
8 | """ 9 | 10 | @nn.compact 11 | def __call__(self, z, x, train=False): 12 | u = self._forward(z, x, train) 13 | return u 14 | 15 | def _forward(self, z, x): 16 | raise NotImplementedError() 17 | 18 | 19 | def _apply_grid_decoder_operator(z, x, operator): 20 | """ 21 | Helper function to reshape appropriately to a grid in order to apply grid-based operators like 22 | the Fourier neural operator, used in `FNODecoder`. 23 | """ 24 | 25 | # Reshape x to be a grid for use with the FNO 26 | n_batch = z.shape[0] 27 | input_dimension = x.shape[-1] 28 | n = round(x.shape[1] ** (1 / input_dimension)) 29 | x_shape = [n_batch] + [n] * input_dimension + [x.shape[-1]] 30 | x = jnp.reshape(x, x_shape) 31 | 32 | # Lift z to be a constant function 33 | new_dims = x.ndim - 2 34 | z = jnp.reshape(z, [z.shape[0]] + [1] * new_dims + [z.shape[1]]) 35 | tiling_shape = [1] + list(x.shape[1:-1]) + [1] 36 | z = jnp.tile(z, tiling_shape) 37 | 38 | # Concatenate the two functions on the channel axis 39 | zx = jnp.concatenate((z, x), axis=-1) 40 | 41 | # Apply the FNO 42 | u = operator(zx, x) 43 | 44 | # Reshape to the "sparse" convention 45 | u = jnp.reshape(u, (u.shape[0], -1, u.shape[-1])) 46 | return u 47 | -------------------------------------------------------------------------------- /src/functional_autoencoders/decoders/cnn_decoder.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from typing import Sequence 4 | from functional_autoencoders.decoders import Decoder 5 | from functional_autoencoders.util.networks import MLP, CNN 6 | 7 | 8 | class CNNDecoder(Decoder): 9 | c_in: int 10 | grid_pts_in: int 11 | trans_cnn_features: Sequence[int] = (32, 16, 8) 12 | kernel_sizes: Sequence[int] = (2, 2, 2) 13 | strides: Sequence[int] = (2, 2, 2) 14 | final_cnn_features: Sequence[int] = (16, 1) 15 | final_kernel_sizes: Sequence[int] = (3,) 16 | final_strides: Sequence[int] = (1,) 17 | mlp_features: 
Sequence[int] = (128, 128, 128) 18 | 19 | def _forward(self, z, x, train=False): 20 | u = MLP([*self.mlp_features, self.grid_pts_in**2 * self.c_in])(z) 21 | u = jnp.reshape(u, (-1, self.grid_pts_in, self.grid_pts_in, self.c_in)) 22 | 23 | u = CNN( 24 | self.trans_cnn_features, self.kernel_sizes, self.strides, is_transpose=True 25 | )(u) 26 | u = nn.relu(u) 27 | 28 | u = CNN(self.final_cnn_features, self.final_kernel_sizes, self.final_strides)(u) 29 | u = jnp.reshape(u, (u.shape[0], -1, 1)) 30 | return u 31 | -------------------------------------------------------------------------------- /src/functional_autoencoders/decoders/fno_decoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from functional_autoencoders.decoders import Decoder, _apply_grid_decoder_operator 3 | from functional_autoencoders.domains import Domain 4 | from functional_autoencoders.util.networks.fno import FNO 5 | 6 | 7 | class FNODecoder(Decoder): 8 | r"""A nonlinear decoder mapping from a finite-dimensional latent variable to a function with an FNO. 9 | 10 | IMPORTANT: the output mesh `x` must be a regular grid on :math:`[0, 1]^{d}`, excluding the boundary. 11 | If `x` is incorrectly specified then internal reshapes will fail or the specified values of `x` may be silently ignored. 12 | 13 | First, the latent vector :math:`z \in \mathbb{R}^{\text{latent\_dim}}` is lifted to a function 14 | :math:`u \colon [0, 1]^{\text{in\_dim}} \to \mathbb{R}^{\text{latent\_dim} + \text{out\_dim}}` that 15 | takes the constant value :math:`z_{i}`, :math:`i = 1, \dots, \text{latent\_dim}` in the first `latent_dim` 16 | components and is the identity function in the remaining component. 
17 | """ 18 | 19 | out_dim: int 20 | domain: Domain 21 | hidden_dim: int = 64 22 | n_layers: int = 1 23 | n_modes_per_dim: int = 12 24 | fno_args: dict = field(default_factory=dict) 25 | 26 | def _forward(self, z, x, train=False): 27 | n_modes = [[self.n_modes_per_dim] * x.shape[-1]] * self.n_layers 28 | lifting_features = [self.hidden_dim] 29 | projection_features = [self.hidden_dim] 30 | 31 | operator = FNO( 32 | n_modes, 33 | lifting_features, 34 | [*projection_features, self.out_dim], 35 | self.domain, 36 | **self.fno_args, 37 | ) 38 | u = _apply_grid_decoder_operator(z, x, operator) 39 | return u 40 | -------------------------------------------------------------------------------- /src/functional_autoencoders/decoders/linear_decoder.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax.linen as nn 3 | from typing import Sequence 4 | from dataclasses import field 5 | from functional_autoencoders.decoders import Decoder 6 | from functional_autoencoders.positional_encodings import ( 7 | PositionalEncoding, 8 | IdentityEncoding, 9 | ) 10 | from functional_autoencoders.util.networks import MLP 11 | 12 | 13 | class LinearDecoder(Decoder): 14 | """ 15 | Essentially the same as a "stacked" DeepONet. 16 | 17 | Inputs: 18 | 19 | `z` : [batch, basis] 20 | tensor of basis coefficients, e.g. 
[64, 10] for 10 basis coefficients 21 | 22 | `x` : [batch, n_evals, in_dim] 23 | tensor of query points 24 | """ 25 | 26 | out_dim: int 27 | n_basis: int = 64 28 | features: Sequence[int] = (128, 128, 128) 29 | positional_encoding: PositionalEncoding = IdentityEncoding() 30 | mlp_args: dict = field(default_factory=dict) 31 | 32 | def setup(self): 33 | self.net = MLP([*self.features, self.n_basis * self.out_dim], **self.mlp_args) 34 | 35 | def _forward(self, z, x, train=False): 36 | basis = self.basis(x) 37 | return jnp.einsum("ij,...ikjl->ikl", z, basis) 38 | 39 | def basis(self, x): 40 | x = self.positional_encoding(x) 41 | basis = self.net(x) 42 | return jnp.reshape(basis, (x.shape[0], x.shape[1], self.n_basis, self.out_dim)) 43 | -------------------------------------------------------------------------------- /src/functional_autoencoders/decoders/nonlinear_decoder.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.typing import ArrayLike 3 | import jax.numpy as jnp 4 | from typing import Sequence, Callable 5 | from functional_autoencoders.decoders import Decoder 6 | from functional_autoencoders.positional_encodings import ( 7 | PositionalEncoding, 8 | IdentityEncoding, 9 | ) 10 | from functional_autoencoders.util.networks import MLP 11 | from dataclasses import field 12 | 13 | 14 | class NonlinearDecoder(Decoder): 15 | r"""A nonlinear decoder :math:`g(z)(x) = f(z, \gamma(x))` learned using an MLP, where :math:`\gamma` 16 | is a positional encoding. 17 | 18 | The positional information :math:`\gamma(x)` of shape [batch, queries, n] is combined with the latent data 19 | [batch, m] by tiling the latent to shape [batch, queries, m] and concatenating to the final axis of 20 | :math:`\gamma(x)`, which is then fed to an MLP at the start. 
21 | """ 22 | 23 | out_dim: int 24 | features: Sequence[int] = (128, 128, 128) 25 | positional_encoding: PositionalEncoding = IdentityEncoding() 26 | mlp_args: dict = field(default_factory=dict) 27 | post_activation: Callable[[ArrayLike], jax.Array] = lambda x: x 28 | concat_method: str = "initial" 29 | 30 | def _forward(self, z, x, train=False): 31 | x = self.positional_encoding(x) 32 | y = self._mlp_forward(z, x) 33 | y = self.post_activation(y) 34 | return y 35 | 36 | def _mlp_forward(self, z, x): 37 | if self.concat_method == "initial": 38 | return self._mlp_initial_concat(z, x) 39 | else: 40 | raise ValueError(f"Unknown method {self.method}") 41 | 42 | def _mlp_initial_concat(self, z, x): 43 | zx = self._concat(z, x) 44 | return MLP([*self.features, self.out_dim], **self.mlp_args)(zx) 45 | 46 | def _concat(self, z, x): 47 | n_evals = x.shape[1] 48 | z = jnp.repeat(jnp.expand_dims(z, 1), n_evals, axis=1) 49 | r = jnp.concatenate((z, x), axis=-1) 50 | return r 51 | -------------------------------------------------------------------------------- /src/functional_autoencoders/domains/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Domain objects specify which function space the data are assumed to come from, and encapsulate the appropriate function-space 3 | norms and inner products needed in the FVAE and FAE losses. 4 | 5 | - For the `fvae_sde` loss, the correct domain to use is `functional_autoencoders.domains.off_grid.SDE`. 
- For the `fae` loss, we provide domain objects for:
    - functional data defined on a square domain $[0, 1]^{d}$ with zero boundary conditions, discretised on a grid (`functional_autoencoders.domains.grid.ZeroBoundaryConditions`)
    - functional data defined on a periodic domain $\mathbb{T}^{d}$, discretised on a grid (`functional_autoencoders.domains.grid.PeriodicBoundaryConditions`)
    - functional data on a square domain (with arbitrary boundary conditions), discretised on possibly non-grid meshes (`functional_autoencoders.domains.off_grid.RandomlySampledEuclidean`).

The `grid` domains allow for the assumption that the data lies in a Sobolev space of nonzero order, and the commensurate use
of Sobolev norms in the loss, whereas the `non_grid` do not permit this.
"""

import jax
from jax.typing import ArrayLike
from typing import Tuple, Callable

# Pair (forward transform, inverse transform) mapping a function to/from its
# nonlocal (spectral) representation.
NonlocalTransform = Tuple[
    Callable[[ArrayLike], jax.Array], Callable[[ArrayLike], jax.Array]
]


class Domain:
    """
    Base class representing the domain on which the data functions are defined, and the appropriate
    norms, inner products, and boundary conditions.

    Users should instantiate the appropriate derived class for their use case (see top-level `functional_autoencoders.domains`
    documentation).
    """

    name: str

    def __init__(self, name: str):
        self.name = name

    def squared_norm(self, u: ArrayLike, x: ArrayLike) -> jax.Array:
        raise NotImplementedError()

    def inner_product(self, u: ArrayLike, v: ArrayLike, x: ArrayLike) -> jax.Array:
        raise NotImplementedError()

    def nonlocal_transform(self) -> NonlocalTransform:
        raise NotImplementedError()
--------------------------------------------------------------------------------
/src/functional_autoencoders/domains/grid.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike
from functional_autoencoders.domains import (
    Domain,
    NonlocalTransform,
)
from functional_autoencoders.util.fft import dstn, idstn


def _dstn_transform(u):
    # Discrete sine transform over the spatial axes (zero boundary conditions).
    return dstn(u, type=1, axes=range(1, u.ndim - 1), norm="forward")


def _dstn_inverse_transform(u):
    return idstn(u, type=1, axes=range(1, u.ndim - 1), norm="forward")


def _dftn_transform(u):
    # Discrete Fourier transform over the spatial axes (periodic boundary conditions).
    return jnp.fft.fftn(u, axes=range(1, u.ndim - 1), norm="forward")


def _dftn_inverse_transform(u):
    return jnp.fft.ifftn(u, axes=range(1, u.ndim - 1), norm="forward").real


class ZeroBoundaryConditions(Domain):
    s: float

    def __init__(self, s: float):
        self.s = s
        name = f"H^{{{s}}}_{{0}}" if s != 0 else "L^{2}_{0}"
        super().__init__(name)

    def squared_norm(self, u: ArrayLike, x: ArrayLike) -> jax.Array:
        r"""Computes the squared Sobolev :math:`H^{s}` norm of a function defined on the domain :math:`[0, 1]^{d}` with zero boundary conditions.

        Given a function :math:`u \colon [0, 1]^{d} \to \Reals^{m}` of the form

        .. math
            u = \sum_{n \in \mathbb{N}^{d}} \alpha_{n} \varphi_{n}, \qquad \varphi_{n} = 2^{d/2} \prod_{i = 1}^{d} \sin(\pi n_{i} x_{i}),

        the following norm equivalent to the $H^{s}$-norm is computed:

        .. math
            \|u\|_{H^{s}([0, 1]^{d})}^{2} = \sum_{n \in \mathbb{N}^{d}} (1 + \norm{n}^{2})^{s} \norm{\alpha_{n}}^{2}.

        This will need to be noted in any write-up.

        Arguments:
            u : jnp.array of shape [batch, n_evals, out_dim]

            x: jnp.array of shape [n_evals, in_dim]

            s : float
                Sobolev exponent. The exponent $s = 0$ corresponds to the $L^{2}$ norm.

        Notes:
            If a large value of $s$ or a large number of grid points are used, this method may yield unreliable results because
            of numerical instability arising from the application of the discrete sine transform.
            This can be mitigated somewhat by passing an array `u` with double-precision floats, but for very large $s$ or grid sizes
            the results will still be incorrect even with the use of double-precision floats.
        """
        input_dimension = x.shape[-1]
        # Infer the per-axis grid size, assuming a regular grid.
        n = round(x.shape[1] ** (1 / input_dimension))
        u_shape = [u.shape[0]] + [n] * input_dimension + [u.shape[-1]]
        u = jnp.reshape(u, u_shape)

        d = u.ndim - 2
        axes = list(range(1, u.ndim - 1))
        # Scale the DST to get the coefficients in the orthonormal basis
        uhat = dstn(u, type=1, axes=axes, norm="forward") * (2 ** (d / 2))
        l2_norm_squared = jnp.sum(uhat**2, axis=-1)

        if self.s != 0.0:
            # Compute the weights $(1 + \|n\|^{2})^{s}$; when $s < 0$, nans are sometimes produced
            # as $1 + \|n\|^{2}$ is very large but in that case we replace the weights by an appropriate value.
            # NOTE(review): jnp.prod takes the *product* of n_i^2 across
            # dimensions, while the docstring's weight uses the Euclidean norm
            # (a sum of squares); the two agree for d = 1 — confirm for d > 1.
            ax = (slice(1, sz + 1) for sz in u.shape[1:-1])
            weights = (1.0 + jnp.prod(jnp.mgrid[ax] ** 2, axis=0)) ** self.s
            weights = jnp.nan_to_num(weights, nan=jnp.inf if self.s >= 0 else 0)
            weights = jnp.expand_dims(weights, 0)
        else:
            weights = jnp.ones_like(l2_norm_squared)

        return jnp.sum(weights * l2_norm_squared, axis=range(1, weights.ndim))

    def inner_product(self, u: ArrayLike, v: ArrayLike, x: ArrayLike) -> jax.Array:
        # H^s inner product via DST coefficients; the weight is split as
        # (.)^(s/2) on each factor so that inner_product(u, u) matches
        # squared_norm(u).
        input_dimension = x.shape[-1]
        n = round(x.shape[1] ** (1 / input_dimension))
        u_shape = [u.shape[0]] + [n] * input_dimension + [u.shape[-1]]
        u = jnp.reshape(u, u_shape)
        v = jnp.reshape(v, u_shape)

        d = u.ndim - 2
        axes = list(range(1, u.ndim - 1))
        uhat = dstn(u, type=1, axes=axes, norm="forward") * (2 ** (d / 2))
        vhat = dstn(v, type=1, axes=axes, norm="forward") * (2 ** (d / 2))

        if self.s != 0.0:
            ax = (slice(1, sz + 1) for sz in u.shape[1:-1])
            weights = (1.0 + jnp.prod(jnp.mgrid[ax] ** 2, axis=0)) ** (self.s / 2)
            weights = jnp.nan_to_num(weights, nan=jnp.inf if self.s >= 0 else 0)
            weights = jnp.expand_dims(weights, 0)
            weights = jnp.expand_dims(weights, -1)
        else:
            weights = jnp.ones_like(uhat)

        uhat = weights * uhat
        vhat = weights * vhat

        return jnp.sum(jnp.sum(uhat * vhat, axis=-1), axis=range(1, uhat.ndim - 1))

    @property
    def nonlocal_transform(self) -> NonlocalTransform:
        return (_dstn_transform, _dstn_inverse_transform)


class PeriodicBoundaryConditions(Domain):
    s: float

    def __init__(self, s: float):
        if abs(s) > 1e-6:
            # Implementing this would just be a case of swapping out `dstn` to `dftn` in `ZeroBoundaryConditions`
            raise NotImplementedError()
        self.s = s
        name = f"H^{{{s}}}_{{per}}" if s != 0 else "L^{2}_{per}"
        super().__init__(name)

    def squared_norm(self, u: ArrayLike, x: ArrayLike) -> jax.Array:
        # L^2 norm approximated by the mean over evaluation points (s = 0 only).
        return jnp.mean(jnp.sum(u**2, axis=2), axis=1)

    def inner_product(self, u: ArrayLike, v: ArrayLike, x: ArrayLike) -> jax.Array:
        return jnp.mean(jnp.sum(u * v, axis=2), axis=1)

    @property
    def nonlocal_transform(self) -> NonlocalTransform:
        return (_dftn_transform, _dftn_inverse_transform)
--------------------------------------------------------------------------------
/src/functional_autoencoders/domains/off_grid.py:
--------------------------------------------------------------------------------
from functools import partial
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


from functional_autoencoders.domains import Domain, NonlocalTransform


@partial(jax.vmap, in_axes=(0, 0))
def stochastic_integral(
    u: ArrayLike,
    v: ArrayLike,
) -> jax.Array:
    r"""
    Computes an approximation of the Itô stochastic integral

    $$ \int_{0}^{T} u_{t}^{T} \,\mathrm{d} v_{t} \approx \sum_{i} u_{t_{i}}^{T} \bigl( v_{t_{i+1}} - v_{t_{i}} \bigr)$$

    for $\mathbb{R}^{d}$-valued processes $(u_{t})_{t \in [0, T]}$ and $(v_{t})_{t \in [0, T]}$
    discretised on the same mesh. See page 45 of Särkkä and Solin (2019) for details.

    ## References
    Särkkä and Solin (2019). Applied Stochastic Differential Equations. Cambridge University Press,
    DOI: 10.1017/9781108186735.
26 | """ 27 | 28 | # The sum here is both computing the outer sum over timesteps and the sum in the dot product 29 | # u_{t_{i}}^{T} \bigl( v_{t_{i+1}} - v_{t_{i}} \bigr) 30 | return jnp.sum(u[:-1, :] * jnp.diff(v, axis=0)) 31 | 32 | 33 | class RandomlySampledEuclidean(Domain): 34 | s: float 35 | 36 | def __init__(self, s: float): 37 | if s != 0.0: 38 | raise NotImplementedError() 39 | self.s = s 40 | self.name = "L^{2}" 41 | 42 | def squared_norm(self, u: ArrayLike, x: ArrayLike) -> jax.Array: 43 | return jnp.mean(jnp.sum(u**2, axis=2), axis=1) 44 | 45 | def inner_product(self, u: ArrayLike, v: ArrayLike, x: ArrayLike) -> jax.Array: 46 | return jnp.mean(jnp.sum(u * v, axis=2), axis=1) 47 | 48 | @property 49 | def nonlocal_transform(self) -> NonlocalTransform: 50 | raise NotImplementedError("") 51 | 52 | 53 | class SDE(Domain): 54 | epsilon: float 55 | 56 | def __init__(self, epsilon: float, x0: float): 57 | self.epsilon = epsilon 58 | # x0 is not used directly here, but it will be accessed by the SDE loss 59 | self.x0 = x0 60 | name = "L^{2}" 61 | super().__init__(name) 62 | 63 | def squared_norm(self, u: ArrayLike, x: ArrayLike) -> jax.Array: 64 | dx = x[:, 1:, 0] - x[:, 0:-1, 0] 65 | squared_l2_norm = jnp.sum( 66 | jnp.sum(u[:, :-1, :] * u[:, :-1, :], axis=2) * dx, axis=1 67 | ) 68 | return self.epsilon ** (-1) * squared_l2_norm 69 | 70 | def inner_product(self, u: ArrayLike, v: ArrayLike, x: ArrayLike) -> jax.Array: 71 | return self.epsilon ** (-1) * stochastic_integral(u, v) 72 | 73 | @property 74 | def nonlocal_transform(self) -> NonlocalTransform: 75 | raise NotImplementedError("") 76 | -------------------------------------------------------------------------------- /src/functional_autoencoders/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax.linen as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | """ 7 | For general comments, see `Autoencoder` documentation. 
8 | """ 9 | 10 | is_variational: bool 11 | 12 | def __call__(self, u, x): 13 | raise NotImplementedError() 14 | 15 | def get_latent_dim(self): 16 | return self.latent_dim 17 | 18 | 19 | def _apply_grid_encoder_operator( 20 | u, x, x_pos, operator, latent_dim, is_variational, pooling_fn, is_concat, is_grid 21 | ): 22 | if is_grid: 23 | input_dimension = x.shape[-1] 24 | n = round(x.shape[1] ** (1 / input_dimension)) 25 | x_shape = [u.shape[0]] + [n] * input_dimension + [x.shape[-1]] 26 | u_shape = [u.shape[0]] + [n] * input_dimension + [u.shape[-1]] 27 | x = jnp.reshape(x, x_shape) 28 | u = jnp.reshape(u, u_shape) 29 | 30 | u = operator(u, x) 31 | 32 | if is_concat: 33 | u = jnp.concatenate([u, x], axis=-1) 34 | 35 | u = pooling_fn(u, x_pos) 36 | d_out = latent_dim * 2 if is_variational else latent_dim 37 | u = nn.Dense(d_out, use_bias=False)(u) 38 | return u 39 | -------------------------------------------------------------------------------- /src/functional_autoencoders/encoders/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from typing import Sequence 4 | from functional_autoencoders.encoders import Encoder 5 | from functional_autoencoders.util.networks import CNN, MLP 6 | 7 | 8 | class CNNEncoder(Encoder): 9 | """A CNN-based encoder which assumes a fixed input mesh. 10 | 11 | Inputs: 12 | u : jnp.array of shape [batch, n_evals, out_dim] 13 | The input functions. 14 | 15 | x : jnp.array of shape [n_evals, in_dim] 16 | The input mesh, which is not actually used by the MLP encoder and is just assumed to be fixed for all 17 | data realisations. 
18 | """ 19 | 20 | cnn_features: Sequence[int] = (8, 16, 32) 21 | mlp_features: Sequence[int] = (128, 128, 128) 22 | kernel_sizes: Sequence[int] = (2, 2, 2) 23 | strides: Sequence[int] = (2, 2, 2) 24 | latent_dim: int = 64 25 | 26 | @nn.compact 27 | def __call__(self, u, x, train=False): 28 | n = int(u.shape[1] ** 0.5) 29 | u = jnp.reshape(u, (-1, n, n, 1)) 30 | 31 | u = CNN(self.cnn_features, self.kernel_sizes, self.strides)(u) 32 | u = jnp.reshape(u, (u.shape[0], -1)) 33 | 34 | d_out = self.latent_dim * 2 if self.is_variational else self.latent_dim 35 | u = MLP([*self.mlp_features, d_out])(u) 36 | return u 37 | -------------------------------------------------------------------------------- /src/functional_autoencoders/encoders/fno_encoder.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from dataclasses import field 3 | from functional_autoencoders.domains import Domain 4 | from functional_autoencoders.util.networks.fno import FNO 5 | from functional_autoencoders.util.networks.pooling import MLPKernelPooling 6 | from functional_autoencoders.encoders import Encoder, _apply_grid_encoder_operator 7 | from functional_autoencoders.positional_encodings import ( 8 | PositionalEncoding, 9 | IdentityEncoding, 10 | ) 11 | 12 | 13 | class FNOEncoder(Encoder): 14 | """ 15 | An FNO-based encoder. 16 | 17 | IMPORTANT: inputs must be provided on a regular grid for use with the FNO. 18 | See notes below. 19 | 20 | Inputs: 21 | u : jnp.array of shape [batch, n_evals, out_dim] 22 | Will be reshaped internally to shape [batch, ax_1, ..., ax_d, out_dim], 23 | where `d` is `in_dim` from the shape of `x`, and `ax_1 = ax_2 = ... = ax_d = n_evals ** (1/d)`. 24 | If `u` is not evaluated on a regular grid, this will either fail to reshape or incorrectly treat the 25 | data as if it were on a grid. 26 | 27 | x : jnp.array of shape [batch, n_evals, in_dim] 28 | The mesh on which the inputs are evaluated. 
Must be a regular grid in the domain $[0, 1]^{d}$, 29 | excluding the boundary. 30 | """ 31 | 32 | latent_dim: int 33 | domain: Domain 34 | fno_lifting_dim: int = 32 35 | fno_projection_dim: int = 4 36 | n_modes_per_dim: int = 12 37 | kernel_hidden_dim: int = 32 38 | kernel_n_layers: int = 2 39 | post_kernel_hidden_dim: int = 16 40 | post_kernel_n_layers: int = 3 41 | n_layers: int = 1 42 | fno_args: dict = field(default_factory=dict) 43 | positional_encoding: PositionalEncoding = IdentityEncoding() 44 | pooling_fn: nn.Module = MLPKernelPooling() 45 | pooling_concat_x: bool = True 46 | 47 | @nn.compact 48 | def __call__(self, u, x, train=False): 49 | n_modes = [[self.n_modes_per_dim] * x.shape[-1]] * self.n_layers 50 | lifting_features = [self.fno_lifting_dim] 51 | projection_features = [self.fno_projection_dim] 52 | 53 | operator = FNO( 54 | n_modes, 55 | lifting_features, 56 | projection_features, 57 | self.domain, 58 | **self.fno_args, 59 | ) 60 | 61 | u = _apply_grid_encoder_operator( 62 | u=u, 63 | x=x, 64 | x_pos=self.positional_encoding(x), 65 | operator=operator, 66 | pooling_fn=self.pooling_fn, 67 | latent_dim=self.latent_dim, 68 | is_variational=self.is_variational, 69 | is_concat=self.pooling_concat_x, 70 | is_grid=True, 71 | ) 72 | return u 73 | -------------------------------------------------------------------------------- /src/functional_autoencoders/encoders/lno_encoder.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from dataclasses import field 3 | from functional_autoencoders.domains import Domain 4 | from functional_autoencoders.util.networks.lno import LNO 5 | from functional_autoencoders.util.networks.pooling import MLPKernelPooling 6 | from functional_autoencoders.encoders import Encoder, _apply_grid_encoder_operator 7 | from functional_autoencoders.positional_encodings import ( 8 | PositionalEncoding, 9 | IdentityEncoding, 10 | ) 11 | 12 | 13 | class LNOEncoder(Encoder): 14 
| r""" 15 | Encoder based on the low-rank neural operator (LNO; Kovachki et al., 2023), where each layer consists of the 16 | low-rank linear update 17 | 18 | $$v(x) = \sum_{j = 1}^{r} \langle u, \psi_{j} \rangle_{L^{2}} \phi_{j}(x),$$ 19 | 20 | and the nonlinear update 21 | 22 | $$\sigma(Wu(x) + v(x)).$$ 23 | 24 | The rank $r$ is determined by `lno_n_rank`. 25 | Implicitly it is assumed that the data $u$ are evaluated on a grid or a random mesh, such that the Monte Carlo approximation 26 | to the $L^{2}$-inner product is accurate. 27 | 28 | ## References 29 | Kovachki et al. (2023). Neural operator: learning maps between function spaces with applications to PDEs. JMLR. 30 | """ 31 | 32 | latent_dim: int 33 | domain: Domain 34 | hidden_dim: int = 32 35 | n_layers: int = 1 36 | lno_n_rank: int = 12 37 | lno_mlp_n_layers: int = 2 38 | lno_mlp_n_dims: int = 32 39 | lno_args: dict = field(default_factory=dict) 40 | positional_encoding: PositionalEncoding = IdentityEncoding() 41 | pooling_fn: nn.Module = MLPKernelPooling() 42 | pooling_concat_x: bool = True 43 | 44 | @nn.compact 45 | def __call__(self, u, x, train=False): 46 | lifting_features = [self.hidden_dim] 47 | projection_features = [self.hidden_dim] 48 | 49 | lno_n_ranks = [self.lno_n_rank] * self.n_layers 50 | lno_mlp_hidden_features = [self.lno_mlp_n_dims] * self.lno_mlp_n_layers 51 | 52 | operator = LNO( 53 | domain=self.domain, 54 | n_ranks=lno_n_ranks, 55 | lifting_features=lifting_features, 56 | projection_features=projection_features, 57 | lno_mlp_hidden_features=lno_mlp_hidden_features, 58 | **self.lno_args, 59 | ) 60 | 61 | u = _apply_grid_encoder_operator( 62 | u=u, 63 | x=x, 64 | x_pos=self.positional_encoding(x), 65 | operator=operator, 66 | pooling_fn=self.pooling_fn, 67 | latent_dim=self.latent_dim, 68 | is_variational=self.is_variational, 69 | is_concat=self.pooling_concat_x, 70 | is_grid=False, 71 | ) 72 | return u 73 | 
class MLPEncoder(Encoder):
    """An MLP-based encoder which assumes a fixed input mesh.

    Inputs:
        u : jnp.array of shape [batch, n_evals, out_dim]
            The input functions; flattened per batch element before the MLP.

        x : jnp.array of shape [n_evals, in_dim]
            The input mesh, which is not actually used by the MLP encoder and is just assumed
            to be fixed for all data realisations.
    """

    features: Sequence[int] = (128, 128, 128)
    latent_dim: int = 64
    mlp_args: dict = field(default_factory=dict)

    @nn.compact
    def __call__(self, u, x, train=False):
        flat_u = jnp.reshape(u, (u.shape[0], -1))

        # Variational encoders emit means and log-variances, hence twice the width.
        out_dim = self.latent_dim * 2 if self.is_variational else self.latent_dim
        return MLP([*self.features, out_dim], **self.mlp_args)(flat_u)
train=False): 18 | x_pos = self.positional_encoding(x) 19 | 20 | u = jnp.concatenate([x_pos, u], axis=-1) 21 | z = self.pooling_fn(u, x_pos) 22 | 23 | d_out = self.latent_dim * 2 if self.is_variational else self.latent_dim 24 | z = nn.Dense(d_out)(z) 25 | return z 26 | -------------------------------------------------------------------------------- /src/functional_autoencoders/losses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions. 3 | """ 4 | 5 | import jax 6 | from jax.typing import ArrayLike 7 | import jax.numpy as jnp 8 | 9 | 10 | def _diag_normal_unbatched( 11 | key: jax.random.PRNGKey, means: ArrayLike, log_variances: ArrayLike 12 | ) -> jax.Array: 13 | """Generates a realisation of $N(\mu, \Sigma)$ where $\Sigma$ is a diagonal matrix of variances. 14 | 15 | This version is unbatched and generally the batched version `diag_normal` will be more useful. 16 | """ 17 | cov_sqrt = jnp.sqrt(jnp.diag(jnp.exp(log_variances))) 18 | return cov_sqrt @ jax.random.normal(key, means.shape) + means 19 | 20 | 21 | _diag_normal = jax.vmap(_diag_normal_unbatched, (0, 0, 0)) 22 | 23 | 24 | def _kl_gaussian(means, log_variances): 25 | """KL divergence from $N(\mu, \Sigma)$ to $N(0, I)$, when $\Sigma$ is a diagonal matrix of variances. 26 | 27 | The matrix $\Sigma$ is represented by an array of *log-variances* representing the diagonal of the covariance matrix. 
28 | """ 29 | n = means.shape[-1] 30 | logdets = jnp.sum(log_variances, axis=-1) 31 | traces = jnp.sum(jnp.exp(log_variances), axis=-1) 32 | return 0.5 * (-logdets - n + traces + jnp.sum(means * means, axis=-1)) 33 | 34 | 35 | def _call_autoencoder_fn(params, batch_stats, fn, u, x, name, dropout_key): 36 | variables = { 37 | "params": params[name], 38 | "batch_stats": (batch_stats if batch_stats else {}).get(name, {}), 39 | } 40 | result = fn( 41 | variables, 42 | u, 43 | x, 44 | train=True, 45 | mutable=["batch_stats"], 46 | rngs={"dropout": dropout_key}, 47 | ) 48 | return result 49 | -------------------------------------------------------------------------------- /src/functional_autoencoders/losses/fae.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.typing import ArrayLike 3 | import jax.numpy as jnp 4 | from functools import partial 5 | from functional_autoencoders.losses import _call_autoencoder_fn 6 | from functional_autoencoders.autoencoder import Autoencoder 7 | from functional_autoencoders.domains import Domain 8 | 9 | 10 | def get_loss_fae_fn( 11 | autoencoder: Autoencoder, 12 | domain: Domain, 13 | beta: float, 14 | subtract_data_norm: bool = False, 15 | ): 16 | if autoencoder.encoder.is_variational: 17 | raise NotImplementedError( 18 | "The FAE loss requires `is_variational` to be `False`." 
def _get_loss_fae(
    params,
    key: jax.random.PRNGKey,
    batch_stats,
    u_enc: ArrayLike,
    x_enc: ArrayLike,
    u_dec: ArrayLike,
    x_dec: ArrayLike,
    encode_fn,
    decode_fn,
    domain: Domain,
    beta: float,
    subtract_data_norm: bool,
) -> jax.Array:
    """Computes the FAE loss: a squared-norm reconstruction term plus a beta-weighted
    squared-Euclidean penalty on the latents. Returns `(loss_value, batch_stats)`.
    """

    # Pass the input functions through the encoder.
    key, enc_dropout_key = jax.random.split(key)
    latents, encoder_updates = _call_autoencoder_fn(
        params=params,
        batch_stats=batch_stats,
        fn=encode_fn,
        u=u_enc,
        x=x_enc,
        name="encoder",
        dropout_key=enc_dropout_key,
    )

    # Reconstruct on the decoder mesh.
    key, dec_dropout_key = jax.random.split(key)
    reconstructions, decoder_updates = _call_autoencoder_fn(
        params=params,
        batch_stats=batch_stats,
        fn=decode_fn,
        u=latents,
        x=x_dec,
        name="decoder",
        dropout_key=dec_dropout_key,
    )

    if subtract_data_norm:
        # Expanded form 0.5||Du||^2 - <Du, u>, i.e. 0.5||Du - u||^2 minus the
        # parameter-independent data term 0.5||u||^2.
        reconstruction_terms = 0.5 * domain.squared_norm(
            reconstructions, x_dec
        ) - domain.inner_product(reconstructions, u_dec, x_dec)
    else:
        reconstruction_terms = 0.5 * domain.squared_norm(
            reconstructions - u_dec, x_dec
        )

    regularisation_terms = beta * jnp.sum(latents**2, axis=-1)

    updated_batch_stats = {
        "encoder": encoder_updates["batch_stats"],
        "decoder": decoder_updates["batch_stats"],
    }

    loss_value = jnp.mean(reconstruction_terms) + jnp.mean(regularisation_terms)
    return loss_value, updated_batch_stats
def get_loss_fvae_sde_fn(
    autoencoder: Autoencoder,
    domain: Domain,
    n_monte_carlo_samples: int = 4,
    beta: float = 1,
    theta: float = 0.0,
    zero_penalty: float = 0.0,
):
    """Returns the FVAE SDE loss function with all hyperparameters bound.

    :param autoencoder: autoencoder whose encoder/decoder `apply` functions are used;
        the encoder must have `is_variational=True`.
    :param domain: `Domain` providing `squared_norm` and `inner_product`; the loss also reads
        `domain.x0` for the initial-condition penalty, so an SDE-style domain is expected.
    :param n_monte_carlo_samples: number S of latent samples drawn per input function.
    :param beta: weight on the KL-divergence term.
    :param theta: drift coefficient entering the reconstruction terms.
    :param zero_penalty: weight on the penalty tying the decoded path at the first mesh point
        to `domain.x0`.
    """
    if not autoencoder.encoder.is_variational:
        raise NotImplementedError(
            "The FVAE SDE loss requires `is_variational` to be `True`"
        )

    # Bind everything except (params, key, batch_stats, batch data), which the trainer supplies.
    return partial(
        _get_loss_fvae_sde,
        encode_fn=autoencoder.encoder.apply,
        decode_fn=autoencoder.decoder.apply,
        domain=domain,
        n_monte_carlo_samples=n_monte_carlo_samples,
        beta=beta,
        theta=theta,
        zero_penalty=zero_penalty,
    )


def _get_loss_fvae_sde(
    params,
    key: jax.random.PRNGKey,
    batch_stats,
    u_enc: ArrayLike,
    x_enc: ArrayLike,
    u_dec: ArrayLike,
    x_dec: ArrayLike,
    encode_fn,
    decode_fn,
    domain: Domain,
    n_monte_carlo_samples: int,
    beta: float,
    theta: float,
    zero_penalty: float,
) -> jax.Array:
    """Computes the FVAE SDE loss for one batch; returns `(loss_value, batch_stats)`.

    The reconstruction terms combine squared norms and inner products of the decoded paths,
    their pathwise derivatives (via `_decoder_grad`), and the data, weighted by `theta`;
    a KL term (weight `beta`) and an initial-condition penalty (weight `zero_penalty`)
    are added. Only one-dimensional input domains are supported.
    """
    if x_enc.shape[-1] != 1:
        raise NotImplementedError()

    # Encode input functions u
    key, dropout_key = jax.random.split(key)
    encoder_params, encoder_updates = _call_autoencoder_fn(
        params=params,
        batch_stats=batch_stats,
        fn=encode_fn,
        u=u_enc,
        x=x_enc,
        name="encoder",
        dropout_key=dropout_key,
    )
    # The encoder output stacks means and log-variances along the last axis.
    latent_dimension = encoder_params.shape[-1] // 2

    # Generate S Monte Carlo realisations from $\mathbb{Q}_{z \mid u}^{\phi}$
    encoder_params = jnp.tile(encoder_params, (n_monte_carlo_samples, 1))
    keys = jax.random.split(key, encoder_params.shape[0])
    means = encoder_params[:, :latent_dimension]
    log_variances = encoder_params[:, latent_dimension:]
    latents = _diag_normal(keys, means, log_variances)

    # Decode the S Monte Carlo realisations
    tiling_shape = [n_monte_carlo_samples] + [1] * (x_dec.ndim - 1)
    x_tile = jnp.tile(x_dec, tiling_shape)

    key, dropout_key = jax.random.split(key)
    decoded, decoder_updates = _call_autoencoder_fn(
        params=params,
        batch_stats=batch_stats,
        fn=decode_fn,
        u=latents,
        x=x_tile,
        name="decoder",
        dropout_key=dropout_key,
    )
    # Pathwise derivative of the decoded path at each mesh point, per output component.
    decoded_grads = _decoder_grad(
        decoder_apply=decode_fn,
        out_dim=u_dec.shape[-1],
        train=True,
    )(params, batch_stats, latents, x_tile)

    # Tile the data to match the S Monte Carlo copies along the batch axis.
    u_shape = [1] * u_dec.ndim
    u_shape[0] = n_monte_carlo_samples
    u_dec = jnp.tile(u_dec, u_shape)
    norms = 0.5 * domain.squared_norm(
        decoded_grads - theta * (u_dec - decoded), x_tile
    ) - 0.5 * domain.squared_norm(theta * u_dec, x_tile)
    inner_prods = domain.inner_product(
        decoded_grads, u_dec, x_tile
    ) + domain.inner_product(theta * decoded, u_dec, x_tile)
    reconstruction_terms = norms - inner_prods
    kl_divs = beta * _kl_gaussian(means, log_variances)

    batch_stats = {
        "encoder": encoder_updates["batch_stats"],
        "decoder": decoder_updates["batch_stats"],
    }

    # Penalise deviation of the decoded path at the first mesh point from the SDE's
    # initial condition x0 stored on the domain.
    x0 = jnp.expand_dims(domain.x0, 0)
    x0 = jnp.repeat(x0, decoded.shape[0], axis=0)
    loss_value = (
        jnp.mean(reconstruction_terms)
        + jnp.mean(kl_divs)
        + zero_penalty * jnp.mean(jnp.sum((decoded[:, 0, :] - x0) ** 2, axis=-1))
    )

    return loss_value, batch_stats
variables["decoder"], 136 | "batch_stats": batch_stats.get("decoder", {}), 137 | }, 138 | z, 139 | x, 140 | train, 141 | )[0, 0, ax] 142 | 143 | g = jax.grad(decode, argnums=3)( 144 | variables, 145 | batch_stats, 146 | jnp.reshape(z, (1, -1)), 147 | jnp.reshape(x, (1, 1, 1)), 148 | )[0, 0, 0] 149 | 150 | gs.append(g) 151 | 152 | return jnp.array(gs) 153 | 154 | return jax.vmap( 155 | jax.vmap(inner, in_axes=(None, None, None, 0)), in_axes=(None, None, 0, 0) 156 | ) 157 | -------------------------------------------------------------------------------- /src/functional_autoencoders/losses/vano.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.typing import ArrayLike 3 | import jax.numpy as jnp 4 | from functools import partial 5 | from functional_autoencoders.autoencoder import Autoencoder 6 | from functional_autoencoders.losses import ( 7 | _diag_normal, 8 | _kl_gaussian, 9 | _call_autoencoder_fn, 10 | ) 11 | 12 | 13 | def get_loss_vano_fn( 14 | autoencoder: Autoencoder, 15 | n_monte_carlo_samples: int = 4, 16 | beta: float = 1, 17 | normalised_inner_prod: bool = True, 18 | rescale_by_norm: bool = True, 19 | ): 20 | r""" 21 | Computes the VANO loss of Seidman et al. (2023), which corresponds to the FVAE loss with white noise $\eta \sim N(0, I)$ 22 | in the VAE decoder model. 23 | 24 | Notes: 25 | - To follow the loss described by Seidman et al. (2023), set `normalised_inner_prod` to `False`, which corresponds to not normalising by :math:`1/N_{\text{points}}` in :math:`(\dagger)`. 26 | This choice can lead to instability across resolutions as the inner product is no longer correctly normalised but it may be useful for comparison. 
27 | """ 28 | if not autoencoder.encoder.is_variational: 29 | raise NotImplementedError( 30 | "The VANO loss requires `is_variational` to be `True`" 31 | ) 32 | 33 | return partial( 34 | _get_loss_vano, 35 | encode_fn=autoencoder.encoder.apply, 36 | decode_fn=autoencoder.decoder.apply, 37 | n_monte_carlo_samples=n_monte_carlo_samples, 38 | beta=beta, 39 | normalised_inner_prod=normalised_inner_prod, 40 | rescale_by_norm=rescale_by_norm, 41 | ) 42 | 43 | 44 | def _get_loss_vano( 45 | params, 46 | key: jax.random.PRNGKey, 47 | batch_stats, 48 | u_enc: ArrayLike, 49 | x_enc: ArrayLike, 50 | u_dec: ArrayLike, 51 | x_dec: ArrayLike, 52 | encode_fn, 53 | decode_fn, 54 | n_monte_carlo_samples: int, 55 | beta: float, 56 | normalised_inner_prod: bool, 57 | rescale_by_norm: bool, 58 | ) -> jax.Array: 59 | 60 | scales = jnp.ones((u_dec.shape[0],)) 61 | if rescale_by_norm: 62 | scales = ( 63 | jnp.mean(jnp.sum(u_dec**2, axis=-1), axis=range(1, u_dec.ndim - 1)) 64 | ) ** (-0.5) 65 | scales = jnp.tile(scales, (n_monte_carlo_samples,)) 66 | 67 | # Encode input functions u 68 | key, dropout_key = jax.random.split(key) 69 | encoder_params, encoder_updates = _call_autoencoder_fn( 70 | params=params, 71 | batch_stats=batch_stats, 72 | fn=encode_fn, 73 | u=u_enc, 74 | x=x_enc, 75 | name="encoder", 76 | dropout_key=dropout_key, 77 | ) 78 | latent_dim = encoder_params.shape[-1] // 2 79 | 80 | # Generate S Monte Carlo realisations from $\mathbb{Q}_{z \mid u}^{\phi}$ 81 | encoder_params = jnp.tile(encoder_params, (n_monte_carlo_samples, 1)) 82 | keys = jax.random.split(key, encoder_params.shape[0]) 83 | means = encoder_params[:, :latent_dim] 84 | log_variances = encoder_params[:, latent_dim:] 85 | latents = _diag_normal(keys, means, log_variances) 86 | 87 | # Decode the S Monte Carlo realisations 88 | tiling_shape = [n_monte_carlo_samples] + [1] * (x_dec.ndim - 1) 89 | x_tile = jnp.tile(x_dec, tiling_shape) 90 | 91 | key, dropout_key = jax.random.split(key) 92 | decoded, decoder_updates = 
_call_autoencoder_fn( 93 | params=params, 94 | batch_stats=batch_stats, 95 | fn=decode_fn, 96 | u=latents, 97 | x=x_tile, 98 | name="decoder", 99 | dropout_key=dropout_key, 100 | ) 101 | 102 | # Estimate half squared L^{2} norm of each D_{\theta}(z) of shape [batch * S, points, out_dim] 103 | norms = 0.5 * jnp.mean(jnp.sum(decoded**2, axis=2), axis=1) 104 | 105 | # Tile the true data $u$ to have the same shape 106 | u_dec = jnp.tile(u_dec, (n_monte_carlo_samples, 1, 1)) 107 | 108 | # Estimate inner product term, optionally normalising (see documentation of this function for rationale) 109 | inner_prods = jnp.sum(decoded * u_dec, axis=(1, 2)) 110 | if normalised_inner_prod: 111 | inner_prods /= u_dec.shape[1] 112 | 113 | # Explicit formula for KL divergence for multivariate Gaussians with identity covariance 114 | # and different means 115 | reconstruction_terms = scales * (norms - inner_prods) 116 | kl_divs = beta * _kl_gaussian(means, log_variances) 117 | 118 | batch_stats = { 119 | "encoder": encoder_updates["batch_stats"], 120 | "decoder": decoder_updates["batch_stats"], 121 | } 122 | 123 | loss_value = jnp.mean(reconstruction_terms) + jnp.mean(kl_divs) 124 | return loss_value, batch_stats 125 | -------------------------------------------------------------------------------- /src/functional_autoencoders/positional_encodings/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import jax 3 | from jax.typing import ArrayLike 4 | import jax.numpy as jnp 5 | from functools import partial 6 | 7 | 8 | class PositionalEncoding: 9 | """Maps co-ordinates :math:`x` to a function :math:`\gamma(x)`. 10 | 11 | A positional encoding is a map :math:`\gamma(x)` mapping an `in_dim`-dimensional co-ordinate :math:`x` to an `encoding_dim`-dimensional 12 | positional encoding of that co-ordinate. 13 | Examples of positional encodings include Fourier features. 
@dataclass
class FourierEncoding1D:
    """Deterministic Fourier-feature positional encoding for one-dimensional co-ordinates.

    Delegates to `_fourier_features`, producing the 2k features
    $[\cos(\omega x), \sin(\omega x), \dots, \cos(k\omega x), \sin(k\omega x)]$ with
    $\omega = 2\pi / L$.
    """

    # Number of frequencies to include.
    k: int
    # Length of the one-dimensional domain.
    L: float = 1

    def __call__(self, x):
        return _fourier_features(x, self.k, self.L)


@dataclass
class RandomFourierEncoding:
    """Positional encoding using random Fourier features with frequency matrix `B`.

    Delegates to `_random_fourier_features`, producing
    $[\cos(2\pi B x), \sin(2\pi B x)]$ concatenated along the feature axis.
    """

    # Frequency matrix; contracted against the co-ordinate axis of x, so its second
    # dimension matches in_dim and the encoding dimension is twice its first dimension.
    B: ArrayLike

    def __call__(self, x):
        return _random_fourier_features(x, self.B)


class IdentityEncoding(PositionalEncoding):
    """The trivial positional encoding $\gamma(x) = x$."""

    def __call__(self, x):
        return x
65 | """ 66 | omega = 2 * jnp.pi / L 67 | a = jnp.arange(1, k + 1) * omega 68 | a = jnp.reshape(a, (1, -1)) 69 | return jnp.concatenate((jnp.cos(a * x), jnp.sin(a * x)), axis=-1) 70 | 71 | 72 | @partial(jax.vmap, in_axes=(0, None)) 73 | def _random_fourier_features(x, B): 74 | x = jnp.einsum("ij,qj->qi", 2 * jnp.pi * B, x) 75 | return jnp.concatenate((jnp.cos(x), jnp.sin(x)), axis=-1) 76 | -------------------------------------------------------------------------------- /src/functional_autoencoders/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | class SamplerBase: 2 | def __init__(self, autoencoder, state): 3 | self.autoencoder = autoencoder 4 | self.state = state 5 | 6 | def sample(self, x): 7 | raise NotImplementedError() 8 | -------------------------------------------------------------------------------- /src/functional_autoencoders/samplers/sampler_gmm.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from sklearn import mixture 3 | from functional_autoencoders.samplers import SamplerBase 4 | 5 | 6 | class SamplerGMM(SamplerBase): 7 | def __init__(self, autoencoder, state, n_components): 8 | super().__init__(autoencoder, state) 9 | self.gmm = None 10 | self.n_components = n_components 11 | 12 | def fit(self, train_dataloader): 13 | gmm = mixture.GaussianMixture( 14 | n_components=self.n_components, 15 | covariance_type="full", 16 | max_iter=2000, 17 | verbose=0, 18 | tol=1e-3, 19 | ) 20 | 21 | z_dataset = self._get_z_dataset(train_dataloader) 22 | gmm.fit(z_dataset) 23 | 24 | self.gmm = gmm 25 | 26 | def sample(self, x): 27 | z_samples, _ = self.gmm.sample(x.shape[0]) 28 | u_samples = self.autoencoder.decode(self.state, z_samples, x, train=False) 29 | return u_samples 30 | 31 | def _get_z_dataset(self, train_dataloader): 32 | z_dataset = [] 33 | for u, x, _, _ in train_dataloader: 34 | z = self.autoencoder.encode(self.state, u, x, 
class SamplerVAE(SamplerBase):
    """Draws function samples by decoding latents sampled from the standard Gaussian prior."""

    def __init__(self, autoencoder, state):
        super().__init__(autoencoder, state)

    def sample(self, x, key):
        """Returns decoded samples evaluated at mesh `x`, one N(0, I) latent draw per batch element."""
        n_samples = x.shape[0]
        d_latent = self.autoencoder.get_latent_dim()
        prior_draws = jax.random.normal(key, [n_samples, d_latent])
        return self.autoencoder.decode(self.state, prior_draws, x, train=False)
| def __init__( 21 | self, 22 | autoencoder: Autoencoder, 23 | loss_fn, 24 | metrics: Sequence[Metric], 25 | train_dataloader, 26 | test_dataloader, 27 | ): 28 | super().__init__() 29 | self.autoencoder = autoencoder 30 | self.loss_fn = loss_fn 31 | self.train_dataloader = train_dataloader 32 | self.test_dataloader = test_dataloader 33 | self.metrics = metrics 34 | self.metrics_history = {} 35 | self.training_loss_history = [] 36 | 37 | def _get_verbosity_level(self, verbose) -> Literal["full", "metrics", "none"]: 38 | if isinstance(verbose, bool): 39 | return "full" if verbose else "none" 40 | else: 41 | return verbose 42 | 43 | def fit( 44 | self, 45 | key, 46 | lr, 47 | lr_decay_step, 48 | lr_decay_factor, 49 | max_step, 50 | eval_interval=10, 51 | verbose: Union[bool, Literal["full", "metrics", "none"]] = False, 52 | ): 53 | """ 54 | Fits the `AutoencoderTrainer` to the training data provided by `train_dataloader` in the constructor, 55 | using the validation data from `test_dataloader` in the constructor to compute evaluation metrics. 56 | 57 | :param key: JAX pseudorandom number generator key (`jax.random.PRNGKey`) 58 | :param lr: learning rate 59 | :param lr_decay_step: along with `lr_decay_factor`, parameters for the [`optax.exponential_decay`](https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.exponential_decay) learning-rate scheduler 60 | :param lr_decay_factor: see `lr_decay_step` 61 | :param max_step: training will finish when an epoch is complete and the total number of steps exceeds `max_step` 62 | :param eval_interval: number of epochs to train before a validation step 63 | :param verbose: verbosity of the `AutoencoderTrainer`. `"full"` shows a progress bar per epoch (for use in, e.g., interactive sessions). `"metrics"` prints only evaluation metrics every `eval_interval` epochs. `"none"` prints nothing. 
64 | """ 65 | self._init_history() 66 | verbose = self._get_verbosity_level(verbose) 67 | 68 | key, subkey = jax.random.split(key) 69 | state = self._get_init_state(subkey, lr, lr_decay_step, lr_decay_factor) 70 | 71 | if verbose != "none": 72 | print(f"Parameter count: {get_n_params(state.params)}") 73 | 74 | train_step_fn = self._get_train_step_fn() 75 | 76 | epoch = 0 77 | step = 0 78 | while True: 79 | key, subkey = jax.random.split(key) 80 | state, step = self._train_one_epoch( 81 | subkey, state, step, train_step_fn, epoch, verbose 82 | ) 83 | 84 | if epoch % eval_interval == 0: 85 | key, subkey = jax.random.split(key) 86 | self._evaluate(subkey, state) 87 | self._print_metrics(epoch, verbose) 88 | 89 | if step >= max_step: 90 | return { 91 | "state": state, 92 | "training_loss_history": self.training_loss_history, 93 | "metrics_history": self.metrics_history, 94 | } 95 | 96 | epoch += 1 97 | 98 | def _train_one_epoch(self, key, state, step, train_step_fn, epoch, verbose): 99 | epoch_loss = 0.0 100 | for i, batch in enumerate( 101 | pbar := tqdm( 102 | self.train_dataloader, 103 | disable=(verbose != "full"), 104 | desc=f"epoch {epoch}", 105 | ) 106 | ): 107 | key, subkey = jax.random.split(key) 108 | loss_value, state = train_step_fn( 109 | subkey, 110 | state, 111 | batch, 112 | ) 113 | 114 | epoch_loss += loss_value 115 | step += 1 116 | 117 | if verbose == "full": 118 | pbar.set_description(f"epoch {epoch} (loss {loss_value:.3E})") 119 | if jnp.any(jnp.isnan(epoch_loss)): 120 | raise TrainNanError() 121 | 122 | epoch_loss /= i + 1 123 | self.training_loss_history.append(epoch_loss) 124 | 125 | return state, step 126 | 127 | def _evaluate(self, key, state): 128 | for metric in self.metrics: 129 | key, subkey = jax.random.split(key) 130 | self.metrics_history[metric.name].append( 131 | metric(state, subkey, self.test_dataloader) 132 | ) 133 | 134 | def _print_metrics(self, epoch, verbose): 135 | if verbose != "none": 136 | metric_string = " | ".join( 137 | 
[ 138 | f"{metric_name}: {self.metrics_history[metric_name][-1]:.3E}" 139 | for metric_name in self.metrics_history 140 | ] 141 | ) 142 | print(f"epoch {epoch:6} || {metric_string}") 143 | sys.stdout.flush() 144 | 145 | def _get_optimizer(self, lr, lr_decay_step, lr_decay_factor): 146 | schedule = optax.exponential_decay( 147 | init_value=lr, 148 | transition_steps=lr_decay_step, 149 | decay_rate=lr_decay_factor, 150 | ) 151 | optimizer = optax.adam(learning_rate=schedule) 152 | return optimizer 153 | 154 | def _get_init_variables(self, key): 155 | key, subkey = jax.random.split(key) 156 | init_u, init_x, _, _ = next(iter(self.train_dataloader)) 157 | variables = self.autoencoder.init(subkey, init_u, init_x, init_x) 158 | return variables 159 | 160 | def _get_init_state(self, key, lr, lr_decay_step, lr_decay_factor): 161 | optimizer = self._get_optimizer(lr, lr_decay_step, lr_decay_factor) 162 | 163 | key, subkey = jax.random.split(key) 164 | init_variables = self._get_init_variables(subkey) 165 | 166 | key, subkey = jax.random.split(key) 167 | state = TrainState.create( 168 | apply_fn=self.autoencoder.apply, 169 | params=init_variables["params"], 170 | tx=optimizer, 171 | batch_stats=( 172 | init_variables["batch_stats"] 173 | if "batch_stats" in init_variables 174 | else None 175 | ), 176 | key=subkey, 177 | ) 178 | return state 179 | 180 | def _init_history(self): 181 | self.metrics_history = {metric.name: [] for metric in self.metrics} 182 | self.training_loss_history = [] 183 | 184 | def _get_train_step_fn(self): 185 | @jax.jit 186 | def step_func(k, state, batch): 187 | u_dec, x_dec, u_enc, x_enc = batch 188 | grad_fn = jax.value_and_grad(self.loss_fn, has_aux=True) 189 | (loss_value, batch_stats), grads = grad_fn( 190 | state.params, 191 | key=k, 192 | batch_stats=state.batch_stats, 193 | u_enc=u_enc, 194 | x_enc=x_enc, 195 | u_dec=u_dec, 196 | x_dec=x_dec, 197 | ) 198 | state = state.apply_gradients(grads=grads) 199 | state = 
state.replace(batch_stats=batch_stats) 200 | return loss_value, state 201 | 202 | return step_func 203 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions. 3 | """ 4 | 5 | import os 6 | import jax 7 | import jax.numpy as jnp 8 | import yaml 9 | import dill as pickle 10 | from functools import partial 11 | 12 | 13 | def get_n_params(variables): 14 | """ 15 | Computes the number of trainable parameters for the specified `variables` object. 16 | """ 17 | return sum(x.size for x in jax.tree_util.tree_leaves(variables)) 18 | 19 | 20 | def get_raw_x(h, w): 21 | x_mesh_list = jnp.meshgrid( 22 | jnp.linspace(0, 1, h + 2)[1:-1], 23 | jnp.linspace(0, 1, w + 2)[1:-1], 24 | indexing="ij", 25 | ) 26 | xx = jnp.concatenate([jnp.expand_dims(v, -1) for v in x_mesh_list], axis=-1) 27 | xx = xx.reshape(-1, 2) 28 | return xx 29 | 30 | 31 | def pickle_save(obj, save_path): 32 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 33 | with open(save_path, "wb") as file: 34 | pickle.dump(obj, file) 35 | 36 | 37 | def pickle_load(save_path): 38 | with open(save_path, "rb") as file: 39 | return pickle.load(file) 40 | 41 | 42 | def yaml_load(save_path): 43 | with open(save_path, "r") as file: 44 | return yaml.safe_load(file) 45 | 46 | 47 | def fit_trainer_using_config(key, trainer, config, verbose=False): 48 | results = trainer.fit( 49 | key=key, 50 | lr=config["trainer"]["lr"], 51 | lr_decay_step=config["trainer"]["lr_decay_step"], 52 | lr_decay_factor=config["trainer"]["lr_decay_factor"], 53 | max_step=config["trainer"]["max_step"], 54 | eval_interval=config["trainer"]["eval_interval"], 55 | verbose=verbose, 56 | ) 57 | return results 58 | 59 | 60 | def save_data_results( 61 | autoencoder, results, test_dataloader, data_dir, additional_data={} 62 | ): 63 | 64 | state = results["state"] 65 | u, x, _, 
_ = next(iter(test_dataloader)) 66 | u_hat = autoencoder.apply( 67 | {"params": state.params, "batch_stats": state.batch_stats}, u, x, x 68 | ) 69 | reconstructions = {"u": u, "u_hat": u_hat, "x": x} 70 | 71 | light_results = { 72 | "training_results": results, 73 | "reconstructions": reconstructions, 74 | "additional_data": additional_data, 75 | } 76 | 77 | path_results = os.path.join(data_dir, "data.pickle") 78 | pickle_save(light_results, path_results) 79 | 80 | 81 | def save_model_results(autoencoder, results, model_dir): 82 | pickle_save( 83 | { 84 | "autoencoder": autoencoder, 85 | "results": results, 86 | }, 87 | os.path.join(model_dir, "model.pkl"), 88 | ) # Use .pkl to ignore in git 89 | 90 | 91 | @partial(jax.vmap, in_axes=(0, None)) 92 | def get_transition_matrix(u_bucket, n): 93 | P = jnp.zeros((n, n)) 94 | for i in range(n): 95 | for j in range(n): 96 | P = P.at[i, j].set(jnp.sum((u_bucket[1:] == j) & (u_bucket[:-1] == i))) 97 | row_sums = jnp.sum(P, axis=1) 98 | P = jnp.where(row_sums[:, None] == 0, jnp.ones_like(P) / n, P / row_sums[:, None]) 99 | return P 100 | 101 | 102 | @partial(jax.vmap, in_axes=(0, None, None)) 103 | def bucket_data(u, x_locs, y_locs): 104 | u_bucket = -jnp.ones(u.shape[0]) 105 | for i in range(len(y_locs) - 1): 106 | for j in range(len(x_locs) - 1): 107 | mask = ( 108 | (u[:, 0] > x_locs[j]) 109 | & (u[:, 0] <= x_locs[j + 1]) 110 | & (u[:, 1] > y_locs[i]) 111 | & (u[:, 1] <= y_locs[i + 1]) 112 | ) 113 | u_bucket = jnp.where(mask, i * (len(x_locs) - 1) + j, u_bucket) 114 | return u_bucket 115 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import ABC, abstractmethod 3 | from scipy.ndimage import gaussian_filter 4 | 5 | 6 | class AntiAliasingManagerBase(ABC): 7 | """ 8 | Base class that allows for antialiased downsampling and 
upsampling. 9 | 10 | Users should not instantiate this class directly, and instead use `AntiAliasingManagerFourier`. 11 | """ 12 | 13 | @abstractmethod 14 | def lowpass(self, z, sample_rate_h, sample_rate_w): 15 | pass 16 | 17 | def upsample(self, z, scale_factor=2): 18 | z_up = self._upsample_with_zero_insertion(z, scale_factor) 19 | z_up_low = self.lowpass(z_up, z.shape[-2], z.shape[-1]) * scale_factor**2 20 | return z_up_low 21 | 22 | def downsample(self, z, scale_factor=2): 23 | z_low = self.lowpass(z, z.shape[-2] // scale_factor, z.shape[-1] // scale_factor) 24 | z_low_down = z_low[:, ::scale_factor, ::scale_factor] 25 | return z_low_down 26 | 27 | def _upsample_with_zero_insertion(self, x, stride=2): 28 | *cdims, Hin, Win = x.shape 29 | Hout = stride * Hin 30 | Wout = stride * Win 31 | out = x.new_zeros(*cdims, Hout, Wout) 32 | out[..., ::stride, ::stride] = x 33 | return out 34 | 35 | 36 | class AntiAliasingManagerFourier(AntiAliasingManagerBase): 37 | """ 38 | Allows for downsampling and upsampling using a smoothed filter applied in Fourier space. 39 | 40 | The low-pass filter is a mollification of an ideal $\mathrm{sinc}$ filter with bandwidth 41 | selected to eliminate frequencies beyond the Nyquist frequency of the target resolution, 42 | computed in practice by convolving the ideal filter in Fourier space with a Gaussian kernel 43 | of with standard deviation `gaussian_sigma`, truncated to a `mask_blur_kernel_size` convolutional filter. 
44 | """ 45 | 46 | def __init__(self, cutoff_nyq, mask_blur_kernel_size, gaussian_sigma): 47 | self.cutoff_nyq = cutoff_nyq 48 | self.mask_blur_kernel_size = mask_blur_kernel_size 49 | self.gaussian_sigma = gaussian_sigma 50 | 51 | def lowpass(self, z, sample_rate_h, sample_rate_w): 52 | dft = np.fft.fft2(z, norm="ortho") 53 | dft_shift = np.fft.fftshift(dft) 54 | 55 | mask = self._get_blurred_mask(z, sample_rate_h, sample_rate_w) 56 | dft_shift_masked = np.multiply(dft_shift, mask) 57 | 58 | back_ishift_masked = np.fft.ifftshift(dft_shift_masked) 59 | 60 | z_filtered = np.fft.ifft2(back_ishift_masked, norm="ortho").real 61 | return z_filtered 62 | 63 | def _get_blurred_mask(self, z, sample_rate_h, sample_rate_w): 64 | radius = self.mask_blur_kernel_size - 1 // 2 65 | mask = self._get_ideal_square_mask(z, sample_rate_h, sample_rate_w) 66 | mask = gaussian_filter(mask, self.gaussian_sigma, radius=radius) 67 | return mask 68 | 69 | def _get_ideal_square_mask(self, z, sample_rate_h, sample_rate_w): 70 | h, w = z.shape[-2:] 71 | mask = np.zeros((h, w)) 72 | modes_h = int(self.cutoff_nyq * sample_rate_h) 73 | modes_w = int(self.cutoff_nyq * sample_rate_w) 74 | 75 | if (h - modes_h) % 2 != 0: 76 | modes_h -= 1 77 | 78 | if (w - modes_w) % 2 != 0: 79 | modes_w -= 1 80 | 81 | start_idx_h = (h - modes_h) // 2 82 | start_idx_w = (w - modes_w) // 2 83 | 84 | end_idx_h = start_idx_h + modes_h 85 | end_idx_w = start_idx_w + modes_w 86 | 87 | mask[start_idx_h:end_idx_h, start_idx_w:end_idx_w] = 1 88 | 89 | return mask 90 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/fft.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.typing import ArrayLike 3 | import jax.numpy as jnp 4 | from typing import Literal 5 | 6 | 7 | def _invert_norm( 8 | norm: Literal["backward", "ortho", "forward"] 9 | ) -> Literal["backward", "ortho", "forward"]: 10 | if norm == 
"backward": 11 | return "forward" 12 | elif norm == "forward": 13 | return "backward" 14 | return norm 15 | 16 | 17 | def dst( 18 | x: ArrayLike, 19 | type: Literal[1, 2, 3, 4] = 2, 20 | axis: int = -1, 21 | norm: Literal["backward", "ortho", "forward"] = "backward", 22 | ) -> jax.Array: 23 | """ 24 | LAX-backed implementation of the discrete sine transform (DST), following the `scipy.fft.dst` API and its defaults. 25 | 26 | Notes: 27 | - Currently only implements DST-I, emulating `scipy.fft.dst` with `type = 1`. 28 | - The algorithm used is a simple implementation of the DST based on the discrete Fourier transform (DFT), which 29 | forms an intermediate tensor of double the input size. 30 | 31 | ## References 32 | Press, Teukolsky, Vetterling & Flannery (2007). Numerical recipes in C: the art of scientific computing. 33 | 3rd ed. Cambridge University Press. ISBN: 9780521880688. 34 | """ 35 | if type != 1: 36 | raise NotImplementedError() 37 | 38 | norm = _invert_norm(norm) 39 | shape = list(x.shape) 40 | shape[axis] = 1 41 | xaug = jnp.concatenate( 42 | [ 43 | jnp.zeros(shape), 44 | x, 45 | jnp.zeros(shape), 46 | -jnp.flip(x, axis=axis), 47 | ], 48 | axis=axis, 49 | ) 50 | xhat = jnp.fft.ifft(xaug, norm=norm, axis=axis) 51 | idx = [slice(0, dim) for dim in x.shape] 52 | idx[axis] = slice(1, x.shape[axis] + 1) 53 | return (xhat.imag)[tuple(idx)] 54 | 55 | 56 | def dstn( 57 | x: ArrayLike, 58 | type: Literal[1, 2, 3, 4] = 2, 59 | axes=None, 60 | norm: Literal["backward", "ortho", "forward"] = "backward", 61 | ) -> jax.Array: 62 | """ 63 | LAX-backed implementation of the $n$-dimensional discrete sine transform (DST), emulating the `scipy.fft.dstn` API. 64 | 65 | The implementation is based on repeated application of `dst` in each relevant axis (all axes by default, unless `axes` is 66 | specified). 67 | 68 | ## References 69 | Press, Teukolsky, Vetterling & Flannery (2007). Numerical recipes in C: the art of scientific computing. 70 | 3rd ed. 
Cambridge University Press. ISBN: 9780521880688. 71 | """ 72 | if type != 1: 73 | raise NotImplementedError() 74 | 75 | if axes is None: 76 | axes = range(0, x.ndim) 77 | 78 | for axis in axes: 79 | x = dst(x, type=type, axis=axis, norm=norm) 80 | return x 81 | 82 | 83 | def idst( 84 | x: ArrayLike, 85 | type: Literal[1, 2, 3, 4] = 2, 86 | axis: int = -1, 87 | norm: Literal["backward", "ortho", "forward"] = "backward", 88 | ) -> jax.Array: 89 | """ 90 | LAX-backed implementation of the inverse discrete sine transform (DST) emulating the `scipy.fft.idst` API. 91 | 92 | ## References 93 | Press, Teukolsky, Vetterling & Flannery (2007). Numerical recipes in C: the art of scientific computing. 94 | 3rd ed. Cambridge University Press. ISBN: 9780521880688. 95 | """ 96 | norm = _invert_norm(norm) 97 | return dst(x, type, axis, norm) 98 | 99 | 100 | def idstn( 101 | x: ArrayLike, 102 | type: Literal[1, 2, 3, 4] = 2, 103 | axes=None, 104 | norm: Literal["backward", "ortho", "forward"] = "backward", 105 | ) -> jax.Array: 106 | """ 107 | LAX-backed implementation of the $n$-dimensional inverse discrete sine transform (DST), 108 | emulating the `scipy.fft.idstn` API. 109 | 110 | ## References 111 | Press, Teukolsky, Vetterling & Flannery (2007). Numerical recipes in C: the art of scientific computing. 112 | 3rd ed. Cambridge University Press. ISBN: 9780521880688. 
113 | """ 114 | norm = _invert_norm(norm) 115 | return dstn(x, type, axes, norm) 116 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/masks.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | from jax import lax 4 | import jax.numpy as jnp 5 | from functional_autoencoders.util import get_raw_x 6 | from functional_autoencoders.util.random.grf import torus_grf 7 | 8 | 9 | def get_mask_grf_torus(key, u, threshold, tau=3, d=2): 10 | key, subkey = jax.random.split(key) 11 | grid_pts = int(u.shape[1] ** 0.5) 12 | mask = ( 13 | torus_grf(subkey, n=1, shape=(grid_pts, grid_pts), tau=tau, d=d).real.flatten() 14 | > threshold 15 | ) 16 | return mask 17 | 18 | 19 | def get_mask_uniform(key, u, mask_ratio): 20 | key, subkey = jax.random.split(key) 21 | mask = jax.random.bernoulli(subkey, mask_ratio, shape=(u.shape[1],)) 22 | return mask 23 | 24 | 25 | def get_mask_random_circle(key, u, radius): 26 | key, subkey = jax.random.split(key) 27 | random_mean = jax.random.uniform(subkey, shape=(2,)) 28 | random_mean = random_mean * (1 - 2 * radius) + radius 29 | grid_pts = int(u.shape[1] ** 0.5) 30 | xx = get_raw_x(grid_pts, grid_pts) 31 | mask = jnp.linalg.norm(xx - jnp.array(random_mean), axis=1) < radius 32 | return mask 33 | 34 | 35 | def get_mask_rect(key, u, h, w): 36 | grid_pts = int(u.shape[1] ** 0.5) 37 | 38 | key, k1, k2 = jax.random.split(key, 3) 39 | a = jax.random.randint(k1, minval=0, maxval=grid_pts - h + 1, shape=()) 40 | b = jax.random.randint(k2, minval=0, maxval=grid_pts - w + 1, shape=()) 41 | 42 | mask = jnp.zeros((grid_pts, grid_pts), dtype=bool) 43 | rect = jnp.ones((h, w), dtype=bool) 44 | 45 | mask = lax.dynamic_update_slice(mask, rect, (a, b)).flatten() 46 | return mask 47 | 48 | 49 | def get_mask_rect_np(u, h, w): 50 | grid_pts = int(u.shape[1] ** 0.5) 51 | 52 | a = np.random.randint(0, grid_pts - h + 1, size=()) 53 | 
b = np.random.randint(0, grid_pts - w + 1, size=()) 54 | 55 | mask = np.zeros((grid_pts, grid_pts), dtype=bool) 56 | mask[a : a + h, b : b + w] = True 57 | mask = mask.flatten() 58 | 59 | return mask 60 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | from typing import Sequence, Any, Callable 5 | 6 | 7 | Initializer = Callable[[jax.random.PRNGKey, Sequence[int], Any], Any] 8 | 9 | 10 | class MLP(nn.Module): 11 | """ 12 | Multilayer perceptron (MLP) neural network mapping between vectors. 13 | """ 14 | 15 | features: Sequence[int] 16 | """ 17 | Widths for each layer of the MLP. The final value is the width of the output layer. 18 | """ 19 | 20 | act: Callable = lambda x: nn.gelu(x) 21 | """ 22 | The activation function used between hidden layers (no activation is used on the output layer). 23 | Default is the GELU activation (Hendrycks and Gimpel, 2016); others are available through the 24 | [`flax.linen.activation`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/activation_functions.html) 25 | module. 26 | """ 27 | 28 | kernel_init: Initializer = None 29 | bias_init: Initializer = None 30 | 31 | use_bias: bool = True 32 | """ 33 | Enables the use of a bias for each layer, including the output layer. 
34 | """ 35 | 36 | @nn.compact 37 | def __call__(self, x): 38 | kwargs = {} 39 | if self.kernel_init is not None: 40 | kwargs["kernel_init"] = self.kernel_init 41 | if self.bias_init is not None: 42 | kwargs["bias_init"] = self.bias_init 43 | 44 | for feat in self.features[:-1]: 45 | x = self.act(nn.Dense(feat, use_bias=self.use_bias, **kwargs)(x)) 46 | x = nn.Dense(self.features[-1], use_bias=self.use_bias)(x) 47 | return x 48 | 49 | 50 | class CNN(nn.Module): 51 | features: Sequence[int] 52 | kernel_sizes: Sequence[int] 53 | strides: Sequence[int] 54 | act: Callable = lambda x: nn.gelu(x) 55 | is_transpose: bool = False 56 | 57 | @nn.compact 58 | def __call__(self, x): 59 | n_layers = len(self.features) 60 | conv_fn = nn.Conv if not self.is_transpose else nn.ConvTranspose 61 | for i in range(n_layers - 1): 62 | x = self.act( 63 | conv_fn( 64 | features=self.features[i], 65 | kernel_size=(self.kernel_sizes[i], self.kernel_sizes[i]), 66 | strides=(self.strides[i], self.strides[i]), 67 | )(x) 68 | ) 69 | x = conv_fn( 70 | features=self.features[n_layers - 1], 71 | kernel_size=( 72 | self.kernel_sizes[n_layers - 1], 73 | self.kernel_sizes[n_layers - 1], 74 | ), 75 | strides=(self.strides[n_layers - 1], self.strides[n_layers - 1]), 76 | )(x) 77 | return x 78 | 79 | 80 | class MultiheadAttentionBlock(nn.Module): 81 | n_heads: int 82 | mlp_dim: int = 128 83 | mlp_n_hidden_layers: int = 2 84 | 85 | @nn.compact 86 | def __call__(self, X, Y): 87 | MH = MultiheadLinearAttentionLayer(n_heads=self.n_heads)(X, Y, Y) 88 | H = nn.LayerNorm()(X + MH) 89 | rFF = MLP([*[self.mlp_dim] * self.mlp_n_hidden_layers, H.shape[-1]])(H) 90 | X_out = nn.LayerNorm()(H + rFF) 91 | return X_out 92 | 93 | 94 | class MultiheadLinearAttentionLayer(nn.Module): 95 | n_heads: int 96 | 97 | @nn.compact 98 | def __call__(self, X_q, X_k, X_v): 99 | heads = [] 100 | for _ in range(self.n_heads): 101 | h = AttentionLayer(dim_attn=X_v.shape[-1])(X_q, X_k, X_v) 102 | heads.append(h) 103 | 104 | heads_concat 
= jnp.concatenate(heads, axis=-1) 105 | 106 | X_out = nn.Dense(X_v.shape[-1], use_bias=False)(heads_concat) 107 | return X_out 108 | 109 | 110 | class AttentionLayer(nn.Module): 111 | dim_attn: int 112 | 113 | @nn.compact 114 | def __call__(self, X_q, X_k, X_v): 115 | Q = nn.Dense(self.dim_attn, use_bias=False)(X_q) 116 | K = nn.Dense(self.dim_attn, use_bias=False)(X_k) 117 | V = nn.Dense(self.dim_attn, use_bias=False)(X_v) 118 | 119 | P = jnp.einsum("bkd,bnd->bkn", Q, K) 120 | P = P / jnp.sqrt(self.dim_attn) 121 | P = jax.nn.softmax(P, axis=-1) 122 | 123 | X_out = jnp.einsum("bkn,bnd->bkd", P, V) 124 | return X_out 125 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/networks/fno.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | from typing import Sequence 3 | import string 4 | import jax.numpy as jnp 5 | 6 | from functional_autoencoders.domains import Domain 7 | from functional_autoencoders.util.networks import MLP, Initializer 8 | from functional_autoencoders.domains import Domain 9 | 10 | 11 | class FNOLayer(nn.Module): 12 | n_modes: Sequence[int] 13 | domain: Domain 14 | R_init: Initializer = nn.initializers.glorot_normal() 15 | act = nn.gelu 16 | 17 | @nn.compact 18 | def __call__(self, u): 19 | if u.ndim >= 25: 20 | raise NotImplementedError( 21 | "Einsum string unable to handle >= 23-dimensional domain." 
22 | ) 23 | 24 | transform, inverse_transform = self.domain.nonlocal_transform 25 | 26 | # Compute pointwise matrix multiplication 27 | Wu = nn.Dense(u.shape[-1], use_bias=False)(u) 28 | 29 | # Compute spectral convolution 30 | uhat = transform(u) 31 | kmax_shape = tuple( 32 | [slice(None)] + [slice(kmax) for kmax in self.n_modes] + [slice(None)] 33 | ) 34 | uhat = uhat[kmax_shape] 35 | R_shape = [*self.n_modes, u.shape[-1], u.shape[-1]] 36 | R_real = self.param("R_real", self.R_init, tuple(R_shape)) 37 | R_cplx = self.param("R_cplx", self.R_init, tuple(R_shape)) 38 | R = R_real + 1j * R_cplx 39 | 40 | # Build the einsum dynamically so it works for any input dimension 41 | weight_shape = string.ascii_lowercase[: uhat.ndim - 2] + "xy" 42 | uhat_shape = "z" + string.ascii_lowercase[: uhat.ndim - 2] + "y" 43 | out_shape = "z" + string.ascii_lowercase[: uhat.ndim - 2] + "x" 44 | prod = jnp.einsum(f"{weight_shape},{uhat_shape}->{out_shape}", R, uhat) 45 | 46 | # Create a larger zero-padded array and transform back 47 | uhat_ext = jnp.zeros_like(u, dtype=jnp.complex64) 48 | uhat_ext = uhat_ext.at[kmax_shape].set(prod) 49 | Ku = inverse_transform(uhat_ext) 50 | return self.act(Wu + Ku) 51 | 52 | 53 | class FNO(nn.Module): 54 | n_modes: Sequence[Sequence[int]] 55 | lifting_features: Sequence[int] 56 | projection_features: Sequence[int] 57 | domain: Domain 58 | act = None 59 | R_init: Initializer = None 60 | 61 | mlp_init: Initializer = None 62 | mlp_bias: bool = True 63 | 64 | @nn.compact 65 | def __call__(self, u, x): 66 | fno_kwargs = {} 67 | mlp_kwargs = {} 68 | 69 | if self.mlp_init is not None: 70 | mlp_kwargs["initializer"] = self.mlp_init 71 | 72 | if self.act is not None: 73 | fno_kwargs["act"] = self.act 74 | mlp_kwargs["act"] = self.act 75 | 76 | if self.R_init is not None: 77 | fno_kwargs["R_init"] = self.R_init 78 | 79 | u = MLP(self.lifting_features, **mlp_kwargs)(u) 80 | for layer_modes in self.n_modes: 81 | u = FNOLayer(layer_modes, self.domain, 
**fno_kwargs)(u) 82 | u = MLP(self.projection_features, **mlp_kwargs)(u) 83 | return u 84 | 85 | 86 | class FNO1D(nn.Module): 87 | n_modes: Sequence[int] 88 | lifting_features: Sequence[int] 89 | projection_features: Sequence[int] 90 | domain: Domain 91 | act = None 92 | 93 | R_init: Initializer = None 94 | 95 | mlp_init: Initializer = None 96 | mlp_bias: bool = True 97 | 98 | @nn.compact 99 | def __call__(self, u, x): 100 | return FNO( 101 | n_modes=[(mode,) for mode in self.n_modes], 102 | lifting_features=self.lifting_features, 103 | projection_features=self.projection_features, 104 | domain=self.domain, 105 | act=self.act, 106 | R_init=self.R_init, 107 | mlp_init=self.mlp_init, 108 | mlp_bias=self.mlp_bias, 109 | )(u) 110 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/networks/lno.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from typing import Sequence 4 | from functional_autoencoders.util.networks import MLP, Initializer 5 | from functional_autoencoders.domains import Domain 6 | 7 | 8 | class LNOLayer(nn.Module): 9 | n_rank: int 10 | R_init: Initializer = nn.initializers.glorot_normal() 11 | act = nn.gelu 12 | mlp_hidden_features: Sequence[int] = (128, 128) 13 | 14 | @nn.compact 15 | def __call__(self, u, x): 16 | x = x / jnp.max(jnp.abs(x)) 17 | Wu = nn.Dense(u.shape[-1], use_bias=False)(u) 18 | 19 | f = MLP([*self.mlp_hidden_features, 2 * self.n_rank * u.shape[-1]])(x) 20 | f = f.reshape(*x.shape[:-1], 2 * self.n_rank, u.shape[-1]) 21 | phi, psi = jnp.split(f, 2, axis=-2) 22 | l2_inner_prods = (psi * jnp.expand_dims(u, axis=-2)).mean(axis=-3) 23 | l2_inner_prods = jnp.expand_dims(l2_inner_prods, axis=-3) 24 | Ku = jnp.sum(l2_inner_prods * phi, axis=-2) 25 | 26 | return self.act(Wu + Ku) 27 | 28 | 29 | class LNO(nn.Module): 30 | domain: Domain 31 | n_ranks: Sequence[int] 32 | 
lifting_features: Sequence[int] 33 | projection_features: Sequence[int] 34 | act = None 35 | R_init: Initializer = None 36 | mlp_init: Initializer = None 37 | lno_mlp_hidden_features: Sequence[int] = (128, 128) 38 | 39 | @nn.compact 40 | def __call__(self, u, x): 41 | fno_kwargs = {} 42 | mlp_kwargs = {} 43 | 44 | if self.mlp_init is not None: 45 | mlp_kwargs["initializer"] = self.mlp_init 46 | 47 | if self.act is not None: 48 | fno_kwargs["act"] = self.act 49 | mlp_kwargs["act"] = self.act 50 | 51 | if self.R_init is not None: 52 | fno_kwargs["R_init"] = self.R_init 53 | 54 | u = MLP(self.lifting_features, **mlp_kwargs)(u) 55 | for n_rank in self.n_ranks: 56 | u = LNOLayer( 57 | n_rank=n_rank, 58 | mlp_hidden_features=self.lno_mlp_hidden_features, 59 | **fno_kwargs, 60 | )(u, x) 61 | u = MLP(self.projection_features, **mlp_kwargs)(u) 62 | return u 63 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/networks/pooling.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax.linen as nn 3 | from functional_autoencoders.util.networks import MLP, MultiheadLinearAttentionLayer 4 | 5 | 6 | class MLPKernelPooling(nn.Module): 7 | mlp_dim: int = 128 8 | mlp_n_hidden_layers: int = 2 9 | 10 | @nn.compact 11 | def __call__(self, u, x): 12 | u_dim = u.shape[-1] 13 | hidden_features = [self.mlp_dim] * self.mlp_n_hidden_layers 14 | mlp_features = [*hidden_features, self.mlp_dim * u_dim] 15 | 16 | kernel_eval_shape = [*x.shape[:-1], self.mlp_dim, u_dim] 17 | kernel_evals = MLP(mlp_features)(x).reshape(kernel_eval_shape) 18 | 19 | z = jnp.einsum("...xy,...y->...x", kernel_evals, u) 20 | z = z.mean(axis=range(1, z.ndim - 1)) 21 | return z 22 | 23 | 24 | class MultiheadAttentionPooling(nn.Module): 25 | n_heads: int = 2 26 | mlp_dim: int = 128 27 | mlp_n_hidden_layers: int = 2 28 | 29 | @nn.compact 30 | def __call__(self, u, x): 31 | indices = 
jnp.arange(1, dtype=jnp.int32) 32 | z = MLP([self.mlp_dim] * self.mlp_n_hidden_layers)(u) 33 | s = nn.Embed(1, z.shape[-1])(indices)[None, :] 34 | s = jnp.repeat(s, z.shape[0], axis=0) 35 | 36 | z = MultiheadLinearAttentionLayer(n_heads=self.n_heads)(s, z, z) 37 | z = z[:, 0, :] 38 | return z 39 | 40 | 41 | class DeepSetPooling(nn.Module): 42 | mlp_dim: int = 128 43 | mlp_n_hidden_layers: int = 2 44 | 45 | @nn.compact 46 | def __call__(self, u, x): 47 | z = MLP([self.mlp_dim] * self.mlp_n_hidden_layers)(u) 48 | z = z.mean(axis=1) 49 | return z 50 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/pca.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | 4 | def pca(u): 5 | r""" 6 | Performs principal component analysis for discretisations of functional data $(u_{j})_{j = 1}^{N} \subset L^{2}$ 7 | by computing eigenfunctions and eigenvaleus of the empirical covariance operator 8 | 9 | $$C_{N} = \frac{1}{N} \sum_{j = 1}^{N} u_{j} \otimes u_{j}.$$ 10 | 11 | Returns :math:`L^{2}`-orthonormal eigenvectors and associated eigenvalues of the empirical covariance operator. 12 | 13 | Note: this currently only works for functions taking values in :math:`\mathbb{R}^{d}`. 14 | 15 | :param u: `jnp.array` of shape `[samples, grid_pts, out_dim]` 16 | 17 | Returns tuple of: 18 | - **eigenvalues**: `jnp.array` of shape `[n_eigs]` 19 | sorted from smallest to largest 20 | - **eigenvectors**: `jnp.array` of shape `[grid_points, out_dim, n_eigs]`. 21 | 22 | 23 | ## References 24 | Bhattacharya, Hosseini, Kovachki, Stuart (2021). Model reduction and neural networks for parametric PDEs. 25 | SMAI J. Comp. Math 7:121--157, doi:[10.5802/smai-jcm.74](https://dx.doi.org/10.5802/smai-jcm.74). 
26 | """ 27 | 28 | grid_pts = u.shape[1] 29 | out_dim = u.shape[2] 30 | u = jnp.reshape(u, (u.shape[0], -1, 1)) 31 | 32 | # Form empirical covariance operator 33 | cov = (1.0 / u.shape[0]) * jnp.sum(u @ jnp.conj(jnp.swapaxes(u, -1, -2)), axis=0) 34 | eigenvalues, eigenvectors = jnp.linalg.eigh(cov) 35 | 36 | # Normalise eigenvalues and eigenvectors to account for L^2 inner product instead of Euclidean 37 | eigenvectors = eigenvectors.real 38 | l2_norms = jnp.expand_dims( 39 | jnp.mean(jnp.abs(eigenvectors) ** 2, axis=0) ** (-0.5), 0 40 | ) 41 | normalised_eigenvectors = eigenvectors * l2_norms 42 | normalised_eigenvalues = eigenvalues.real / u.shape[1] 43 | normalised_eigenvectors = jnp.reshape( 44 | normalised_eigenvectors, (grid_pts, out_dim, -1) 45 | ) 46 | return normalised_eigenvalues, normalised_eigenvectors 47 | -------------------------------------------------------------------------------- /src/functional_autoencoders/util/random/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htlambley/functional_autoencoders/540179aa43325e9fd96185cd4db738cf1dcd023f/src/functional_autoencoders/util/random/__init__.py -------------------------------------------------------------------------------- /src/functional_autoencoders/util/random/grf.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | from functional_autoencoders.util.fft import idstn 6 | 7 | 8 | def _index_to_frequency(idx, max): 9 | ret = np.zeros_like(idx) 10 | for pos, (i, m) in enumerate(zip(idx, max)): 11 | if i <= m // 2: 12 | ret[pos] = i 13 | else: 14 | ret[pos] = -(m - i) 15 | return ret 16 | 17 | 18 | def _compute_torus_covariance_operator_sqrt_eigenvalues(shape, tau=3, d=2): 19 | eigs = np.zeros(shape) 20 | for index, _ in np.ndenumerate(eigs): 21 | idx = _index_to_frequency(index, shape) 22 | sum_square_idx = 
np.sum(np.square(idx)) 23 | if sum_square_idx != 0: 24 | eigs[index] = (4 * np.pi**2 * sum_square_idx + tau**2) ** (-d / 2) 25 | return eigs 26 | 27 | 28 | def torus_grf(key: jax.random.PRNGKey, n, shape, out_dim=1, tau=3, d=2, method="fft"): 29 | r"""Returns realisations of a mean-zero Gaussian random field on an $n$-dimensional torus with Matérn-type covariance operator and periodic boundary conditions. 30 | 31 | This function generates realisations of a mean-zero Gaussian random field on $X = L^{2}(\Omega; \mathbb{C})$, with 32 | $\Omega = \mathbb{T}^{n}$ defined as the domain $[0, 1]^{n}$ with periodic boundary. 33 | The covariance operator used is a perturbed inverse power of the periodic Laplacian on the torus, 34 | 35 | $$ C = (\tau^{2} I - \Delta)^{-d},$$ 36 | 37 | which is trace class when $d > n/2$. 38 | The dimension $n$ is inferred from the dimensions of `shape`. 39 | 40 | ## Arguments 41 | `n` : int 42 | number of realisations to generate 43 | `shape` : tuple of int 44 | shape of each output realisation; output must be square if `method='fft'`, i.e. all values of `shape` equal. 45 | `tau` : float 46 | length-scale parameter 47 | `d` : float 48 | smoothness parameter 49 | `method`: 'fft' 50 | the method used to generate the Gaussian random field; currently only supports 'fft', which computes the eigenvalues 51 | of the covariance operator analytically and generates a truncated Karhunen--Lo\`eve expansion via the fast Fourier transform. 52 | 53 | ## Returns 54 | Array of shape `(n, *shape)` containing `n` realisations of the Gaussian random field with the required shape. 55 | Note that the returned random field is *complex-valued*, so users may wish to take only the real part, which gives a real-valued 56 | Gaussian random function with mean zero. 
57 | """ 58 | if method == "fft": 59 | if key is not None: 60 | key, subkey = jax.random.split(key) 61 | zhat = jax.random.normal( 62 | key, (n, *shape, out_dim) 63 | ) + 1j * jax.random.normal(subkey, (n, *shape, out_dim)) 64 | else: 65 | zhat = np.random.randn(n, *shape, out_dim) + 1j * np.random.randn( 66 | n, *shape, out_dim 67 | ) 68 | 69 | eigs = _compute_torus_covariance_operator_sqrt_eigenvalues(shape, tau=tau, d=d) 70 | eigs = jnp.expand_dims(eigs, -1) 71 | return jnp.fft.ifftn(eigs * zhat, norm="forward", axes=range(1, zhat.ndim - 1)) 72 | else: 73 | raise NotImplementedError("Only supported generation method is method='fft'") 74 | 75 | 76 | def _compute_dirichlet_covariance_operator_sqrt_eigenvalues( 77 | shape, tau=3, d=2, even_powers_only=False 78 | ): 79 | if len(shape) != 1 and even_powers_only: 80 | raise NotImplementedError("even_powers_only implemented only in 1D") 81 | eigs = np.zeros(shape) 82 | for index, _ in np.ndenumerate(eigs): 83 | sum_square_idx = np.sum(np.square(index + np.ones_like(index))) 84 | if sum_square_idx != 0: 85 | # Note here the 0th index is frequency 1, so the "even" frequencies are the odd indices 86 | if len(shape) == 1 and index[0] % 2 == 0 and even_powers_only: 87 | continue 88 | eigs[index] = (np.pi**2 * sum_square_idx + tau**2) ** (-d / 2) * 2 ** ( 89 | -len(shape) / 2 90 | ) 91 | return eigs 92 | 93 | 94 | def dirichlet_grf( 95 | key: jax.random.PRNGKey, 96 | n, 97 | shape, 98 | out_dim=1, 99 | tau=3, 100 | d=2, 101 | method="dst", 102 | even_powers_only=False, 103 | ): 104 | """Returns realisations of a mean-zero Gaussian random field on the $n$-dimensional square $[0, 1]^{n}$ with Matérn-type covariance operator and zero Dirichlet boundary conditions. 105 | 106 | The arguments are similar to `torus_grf`, with the exception of `even_powers_only`, which assigns coefficient zero to all 107 | odd frequencies in the Karhunen–Loeve expansion for compatibility with the work of Seidman et al. (2023), as discussed below. 
def dirichlet_grf(
    key: jax.random.PRNGKey,
    n,
    shape,
    out_dim=1,
    tau=3,
    d=2,
    method="dst",
    even_powers_only=False,
):
    """Returns realisations of a mean-zero Gaussian random field on the $n$-dimensional square $[0, 1]^{n}$ with Matérn-type covariance operator and zero Dirichlet boundary conditions.

    The arguments are similar to `torus_grf`, with the exception of `even_powers_only`, which assigns coefficient zero to all
    odd frequencies in the Karhunen–Loève expansion for compatibility with the work of Seidman et al. (2023), as discussed below.

    ## Notes
    The variational autoencoding neural operator of Seidman et al. (2023) uses a dataset based on GRFs with Dirichlet Laplacian covariance.
    Their implementation subtly differs because they only allow even frequencies, as noticeable in Fig. 6 of their paper where all the generated
    functions pass through 0.5.
    """
    if method == "dst":
        if key is not None:
            zhat = jax.random.normal(key, (n, *shape, out_dim))
        else:
            # Fall back to the (unseeded) global NumPy generator when no key is given.
            zhat = np.random.randn(n, *shape, out_dim)

        eigs = _compute_dirichlet_covariance_operator_sqrt_eigenvalues(
            shape, tau=tau, d=d, even_powers_only=even_powers_only
        )
        eigs = jnp.expand_dims(eigs, -1)
        return idstn(eigs * zhat, norm="forward", axes=range(1, zhat.ndim - 1), type=1)
    else:
        # BUG FIX: the message previously said method='kl', but the only method
        # accepted by the check above is 'dst'.
        raise NotImplementedError("Only supported generation method is method='dst'")


# ---------------------------------------------------------------------------
# src/functional_autoencoders/util/random/sde.py
# ---------------------------------------------------------------------------
import numpy as np
from typing import Callable

try:
    from tqdm.auto import tqdm
except ImportError:
    # tqdm only provides a cosmetic progress bar (used when verbose=True), so
    # treat it as an optional dependency and fall back to a no-op wrapper.
    def tqdm(iterable, disable=False, **kwargs):
        return iterable


# Signatures of the SDE coefficient functions; both are batched over the leading
# (realisation) axis.  Fixed: `np.array` is a function, not a type — the intended
# annotation is `np.ndarray`.
Drift = Callable[[np.ndarray, float], np.ndarray]
Diffusion = Callable[[np.ndarray, np.ndarray, float], np.ndarray]


def euler_maruyama(
    x0: np.ndarray,
    drift: Drift,
    diffusion: Diffusion,
    simulation_dt: float,
    n_steps: int,
    subsample_rate: int = 1,
    verbose: bool = False,
):
    r"""
    Simulates a stochastic differential equation of the form

    .. math::
        \mathrm{d} X_{t} = a(X_{t}, t) \mathrm{d} t + b(X_{t}, t) \mathrm{d} W_{t}

    using the Euler--Maruyama scheme.

    Arguments:
        x0 : np.ndarray of shape [n_realisations, dimension]
            The initial conditions for the SDE. The size of the leading axis determines the number of
            realisations to simulate.
        drift : function (x: [n_realisations, dimension], t: float) -> [n_realisations, dimension]
            Drift function :math:`a(X_{t}, t)` for the SDE, which should be batched over the leading axis.
        diffusion : function (x: [n_realisations, dimension], dwt: [n_realisations, dimension], t: float) -> [n_realisations, dimension]
            Diffusion function :math:`b(X_{t}, t)` for the SDE, which should be batched over the leading
            axis. It receives the Brownian increment `dwt` directly and must return the full term
            :math:`b(X_{t}, t) \mathrm{d} W_{t}`.
        simulation_dt : float
            Timestep for internal Euler--Maruyama solver
        n_steps : int
            Number of time steps to take. This determines the final time by T = dt * n_steps.
        subsample_rate : int
            Number of simulation steps per step reported in the output array. The dt in the output array
            is thus simulation_dt * subsample_rate. n_steps must be divisible by subsample_rate.
        verbose : bool
            If True, display a progress bar over the simulation steps.

    Returns an array of shape [n_realisations, n_steps // subsample_rate + 1, dimension]; entry j along
    the middle axis holds the state after j * subsample_rate simulation steps, with entry 0 being the
    initial condition.  (The docstring previously claimed shape [n_realisations, dimension, n_steps + 1],
    which did not match the implementation.)

    The Euler--Maruyama scheme has strong convergence order 1/2, i.e. the error decays like
    :math:`\sqrt{\Delta t}` as the step size is refined.
    """
    n_realisations = x0.shape[0]
    d = x0.shape[1]
    x = x0
    result = np.zeros((n_realisations, n_steps // subsample_rate + 1, d))
    result[:, 0, :] = x0
    t = 0.0
    for i in tqdm(range(n_steps), disable=not verbose):
        dwt = np.sqrt(simulation_dt) * np.random.randn(n_realisations, d)
        x = x + drift(x, t) * simulation_dt + diffusion(x, dwt, t)
        t = t + simulation_dt
        # BUG FIX: record the state after every `subsample_rate` complete steps.
        # The previous condition (`i % subsample_rate == 0`, storing at index
        # i // subsample_rate + 1) recorded the state after steps
        # 1, subsample_rate + 1, 2 * subsample_rate + 1, ..., so for
        # subsample_rate > 1 every reported state lagged the intended uniform
        # time grid by subsample_rate - 1 steps.
        if (i + 1) % subsample_rate == 0:
            result[:, (i + 1) // subsample_rate, :] = x
    return result


def add_bm_noise(samples, epsilon, theta, sim_dt, T):
    """Perturbs each sample path by an independent mean-zero noise path.

    Despite the name, the additive noise is an Ornstein--Uhlenbeck process
    dX_t = -theta X_t dt + sqrt(epsilon) dW_t started at zero; it reduces to
    scaled Brownian motion only when theta == 0.

    Arguments:
        samples : array of shape [n_batch, n_pts, n_dim]
            Paths assumed to be sampled on a uniform grid over [0, T] (the
            subsampling rate below relies on spacing T / (n_pts - 1)).
        epsilon : float
            Noise variance scale (diffusion coefficient sqrt(epsilon)).
        theta : float
            Mean-reversion rate of the noise process.
        sim_dt : float
            Internal Euler--Maruyama timestep.
        T : float
            Final time of the sample grid.

    Returns `samples + noise`, with the same shape as `samples`.
    """
    n_batch, n_pts, n_dim = samples.shape
    x0 = np.zeros((n_batch, n_dim))
    bm = euler_maruyama(
        x0,
        lambda xt, t: -theta * xt,
        lambda xt, dwt, t: (epsilon**0.5) * dwt,
        sim_dt,
        int(T / sim_dt),
        int((T / (n_pts - 1)) / sim_dt),
    )
    return samples + bm
# ---------------------------------------------------------------------------
# tests/dst.py
# ---------------------------------------------------------------------------
import sys

sys.path.append("../src")

import jax
import unittest
import numpy as np
import scipy

from functional_autoencoders.util.fft import dst, dstn, idst, idstn


def relative_error(a, b):
    """Squared error of b relative to the squared magnitude of a."""
    return np.sum((a - b) ** 2) / np.sum(a**2)


eps = 1e-2


# JIT-compile the transforms under test (the transform parameters are static).
dst = jax.jit(dst, static_argnums=(1, 2, 3))
dstn = jax.jit(dstn, static_argnums=(1, 2, 3))


class DST(unittest.TestCase):
    def _assert_matches_scipy(self, signal, norm):
        """Compare our jitted DST-I against scipy's reference implementation."""
        ours = dst(signal, type=1, norm=norm)
        reference = scipy.fft.dst(signal, type=1, norm=norm)
        self.assertTrue(relative_error(reference, ours) < eps)

    def test_1d_sins(self):
        grid = np.linspace(0, 1, 101)[1:-1]
        signal = 2 * np.sin(np.pi * grid) + 2 * np.sin(np.pi * 3 * grid)

        ours = dst(signal, type=1, norm="forward")
        reference = scipy.fft.dst(signal, type=1, norm="forward")
        self.assertTrue(relative_error(reference, ours) < eps)
        # The two sine modes appear as unit coefficients at (zero-based) indices 0 and 2.
        for index, expected in enumerate([1.0, 0.0, 1.0, 0.0]):
            self.assertAlmostEqual(ours[index], expected)

    def test_1d_rand(self):
        self._assert_matches_scipy(np.random.randn(5001), "forward")

    def test_2d_rand(self):
        self._assert_matches_scipy(np.random.randn(33, 5001), "forward")

    def test_2d_rand_ortho(self):
        self._assert_matches_scipy(np.random.randn(33, 5001), "ortho")

    def test_2d_rand_backward(self):
        self._assert_matches_scipy(np.random.randn(33, 5001), "backward")

    def test_dst_self_inverse(self):
        # The unnormalised DST-I applied twice scales the input by 2(n + 1).
        n = 101
        signal = np.random.randn(n)
        twice = dst(dst(signal, type=1), type=1)
        expected_scale = 2 * (n + 1)
        self.assertTrue((np.mean(twice / signal) - expected_scale) / expected_scale < eps)

    def test_inverse(self):
        signal = np.random.randn(101)
        self.assertTrue(relative_error(signal, idst(dst(signal, type=1), type=1)) < eps)
        roundtrip = idst(dst(signal, type=1, norm="forward"), type=1, norm="forward")
        self.assertTrue(relative_error(signal, roundtrip) < eps)


class DSTN(unittest.TestCase):
    # NOTE(review): the thresholds below are 1 (only "same order of magnitude"),
    # far looser than the `eps` used in DST above — presumably deliberate, but
    # worth confirming whether a tighter bound holds.

    def test_dstn_corresponds_with_multiple_dst(self):
        data = np.random.randn(33, 33, 33)
        reference = scipy.fft.dstn(data, type=1, norm="forward")
        via_dstn = dstn(data, type=1, norm="forward", axes=(0, 1, 2))
        via_repeated_dst = data
        for axis in (0, 1, 2):
            via_repeated_dst = dst(via_repeated_dst, type=1, norm="forward", axis=axis)
        self.assertTrue(relative_error(reference, via_dstn) < 1)
        self.assertTrue(relative_error(reference, via_repeated_dst) < 1)

    def test_dstn_partial(self):
        data = np.random.randn(32, 33, 33)
        reference = scipy.fft.dstn(data, type=1, norm="forward", axes=(1, 2))
        ours = dstn(data, type=1, norm="forward", axes=(1, 2))
        self.assertTrue(relative_error(reference, ours) < 1)

    def test_inverse(self):
        data = np.random.randn(9, 9, 7)
        self.assertTrue(relative_error(data, idstn(dstn(data, type=1), type=1)) < 1)


if __name__ == "__main__":
    unittest.main()


# ---------------------------------------------------------------------------
# tests/sobolev.py
# ---------------------------------------------------------------------------
import sys

sys.path.append("../src")

import unittest
import jax.numpy as jnp

from functional_autoencoders.domains.grid import ZeroBoundaryConditions


def relative_error(a, b):
    """Squared error of b relative to the squared magnitude of a."""
    return jnp.sum((a - b) ** 2) / jnp.sum(a**2)
def _unit_sine_fixture(n_pts):
    """First Dirichlet eigenfunction sqrt(2) sin(pi s) on the interior grid, shaped (1, n_pts, 1)."""
    grid = jnp.linspace(0, 1, n_pts + 2)[1:-1]
    values = jnp.sqrt(2) * jnp.sin(jnp.pi * grid)
    return jnp.reshape(values, (1, -1, 1)), jnp.reshape(grid, (1, -1, 1))


class SobolevNorm(unittest.TestCase):
    def test_sin(self):
        x, pts = _unit_sine_fixture(100)
        # The squared order-s norm of the unit sine doubles with each order.
        for order, expected in ((0, 1.0), (1, 2.0), (2, 4.0)):
            self.assertAlmostEqual(
                ZeroBoundaryConditions(order).squared_norm(x, pts), expected, places=3
            )

    def test_batched(self):
        x, pts = _unit_sine_fixture(100)
        x = jnp.tile(x, (32, 1, 1))
        pts = jnp.tile(pts, (32, 1, 1))
        for order, expected in ((0, 1.0), (1, 2.0), (2, 4.0)):
            residual = jnp.sum(
                ZeroBoundaryConditions(order).squared_norm(x, pts)
                - expected * jnp.ones((32,))
            )
            self.assertAlmostEqual(residual, 0.0, places=3)

    def test_2d_out_dim(self):
        # Two identical output channels double every squared norm.
        x, pts = _unit_sine_fixture(100)
        x = jnp.tile(x, (32, 1, 2))
        pts = jnp.tile(pts, (32, 1, 1))
        for order, expected in ((0, 2.0), (1, 4.0)):
            residual = jnp.sum(
                ZeroBoundaryConditions(order).squared_norm(x, pts)
                - expected * jnp.ones((32,))
            )
            self.assertAlmostEqual(residual, 0.0, places=3)


class SobolevInnerProd(unittest.TestCase):
    def test_sin(self):
        x, pts = _unit_sine_fixture(100)
        for order, expected in ((0, 1.0), (1, 2.0), (2, 4.0)):
            self.assertAlmostEqual(
                ZeroBoundaryConditions(order).inner_product(x, x, pts),
                expected,
                places=3,
            )

    def test_2d_out_dim(self):
        x, pts = _unit_sine_fixture(100)
        x = jnp.tile(x, (32, 1, 2))
        pts = jnp.tile(pts, (32, 1, 1))
        for order, expected in ((0, 2.0), (1, 4.0)):
            residual = jnp.sum(
                ZeroBoundaryConditions(order).inner_product(x, x, pts)
                - expected * jnp.ones((32,))
            )
            self.assertAlmostEqual(residual, 0.0, places=3)

    def test_zero_inner_prod(self):
        # Fractional order: the inner product against the zero function vanishes.
        x, pts = _unit_sine_fixture(101)
        zero = jnp.zeros_like(x)
        self.assertAlmostEqual(
            ZeroBoundaryConditions(1.1).inner_product(x, zero, pts),
            0.0,
            places=3,
        )

    def test_inner_prod_is_norm(self):
        # Negative fractional order: <x, x> must agree with ||x||^2.
        x, pts = _unit_sine_fixture(101)
        domain = ZeroBoundaryConditions(-0.9)
        self.assertAlmostEqual(
            domain.inner_product(x, x, pts),
            domain.squared_norm(x, pts),
            places=3,
        )


if __name__ == "__main__":
    unittest.main()