├── .gitignore ├── .gitmodules ├── README.md ├── VERSION ├── config ├── architecture │ ├── big_concat.yaml │ ├── concat.yaml │ ├── embed.yaml │ ├── mha.yaml │ ├── multi.yaml │ └── small_concat.yaml ├── base │ ├── default.yaml │ ├── normal.yaml │ └── uniform.yaml ├── beta_schedule │ ├── constant.yaml │ ├── linear.yaml │ ├── linear_product2.yaml │ ├── quadratic.yaml │ └── triangle.yaml ├── dataset │ ├── angles.yaml │ ├── angles2.yaml │ ├── chain.yaml │ ├── chain1.yaml │ ├── dirac.yaml │ ├── earthquake.yaml │ ├── fire.yaml │ ├── flood.yaml │ ├── langevin.yaml │ ├── loop-torus.yaml │ ├── mix_vmf.yaml │ ├── polytope-torus.yaml │ ├── polytope.yaml │ ├── uniform.yaml │ ├── us_forest_fires.yaml │ ├── vmf.yaml │ ├── volcanoe.yaml │ ├── wrapnormmix.yaml │ └── wrapped.yaml ├── embedding │ ├── laplacian_eigenfunction.yaml │ └── none.yaml ├── experiment │ ├── earthquake.yaml │ ├── fire.yaml │ ├── flood.yaml │ ├── hessian_dirichlet_d10.yaml │ ├── hessian_dirichlet_d2.yaml │ ├── hessian_dirichlet_d3.yaml │ ├── hessian_hypercube_analytic_d2.yaml │ ├── hessian_hypercube_d10.yaml │ ├── hessian_hypercube_d2.yaml │ ├── hessian_hypercube_d3.yaml │ ├── hessian_loops.yaml │ ├── hessian_smooth_L.yaml │ ├── hyperboloid.yaml │ ├── poincare.yaml │ ├── reflect_hypercube_analytic_d2.yaml │ ├── s2_toy.yaml │ ├── smooth_loops.yaml │ ├── so3.yaml │ ├── spd_d3.yaml │ ├── tn.yaml │ └── volcanoe.yaml ├── flow │ ├── brownian.yaml │ ├── brownian_hessian_product2.yaml │ ├── brownian_product2.yaml │ ├── cnf.yaml │ ├── hessian.yaml │ ├── hessian2k.yaml │ ├── langevin.yaml │ ├── product.yaml │ ├── reflectedbrownian.yaml │ └── vp.yaml ├── generator │ ├── ambient-torus.yaml │ ├── ambient.yaml │ ├── canonical.yaml │ ├── div_free.yaml │ ├── eigen.yaml │ ├── lie_algebra.yaml │ ├── torus.yaml │ └── transport.yaml ├── logger │ ├── all.yaml │ ├── csv.yaml │ └── wandb.yaml ├── loss │ ├── dsm0.yaml │ ├── dsms.yaml │ ├── dsmv.yaml │ ├── euclidean_ism.yaml │ ├── hessian.yaml │ ├── hessian_sqrt_like_w.yaml │ ├── ism.yaml │ ├── logp.yaml │ ├── moser.yaml │ ├── product_ism.yaml │ └── ssm.yaml ├── main.yaml ├── manifold │ ├── bounded-sphere.yaml │ ├── chain.yaml │ ├── chain1.yaml │ ├── hyperbolic.yaml │ ├── hypercube.yaml │ ├── nsphere.yaml │ ├── polytope-and-sphere.yaml │ ├── polytope-torus.yaml │ ├── polytope.yaml │ ├── s1.yaml │ ├── so3.yaml │ └── tn.yaml ├── model │ ├── cnf.yaml │ ├── exp_sgm.yaml │ ├── moser.yaml │ ├── rsgm.yaml │ ├── stereo_sgm.yaml │ └── tanhexp_sgm.yaml ├── optim │ └── adam.yaml ├── pushf │ ├── default.yaml │ ├── moser.yaml │ ├── product.yaml │ └── sde.yaml ├── scheduler │ ├── constant.yaml │ ├── cosine.yaml │ ├── r3cosine.yaml │ ├── ramp.yaml │ ├── rcosine.yaml │ └── rloglinear.yaml ├── server │ ├── base.yaml │ ├── debug_ziz.yaml │ ├── local.yaml │ └── ziz.yaml └── transform │ ├── exp.yaml │ ├── id.yaml │ ├── stereo.yaml │ └── tanhexp.yaml ├── data ├── .gitkeep ├── L.npz ├── T.walk.csv ├── angles.csv ├── b.walk.csv ├── big_hypercube_d=10.npz ├── big_hypercube_d=3.npz ├── birkhoff_d=4.npz ├── dirichlet=15.npz ├── dirichlet_d=10.npz ├── dirichlet_d=15.npz ├── dirichlet_d=2.npz ├── dirichlet_d=3.npz ├── fire.csv ├── flood.csv ├── hypercube_d=10.npz ├── hypercube_d=2.npz ├── hypercube_d=3.npz ├── multimodal_dirichlet_d=2.npz ├── multimodal_hypercube_d=2.npz ├── quakes_all.csv ├── smooth_loops_dist.npz ├── unit_hypercube_d=2.npz └── volerup.csv ├── main.py ├── main_product.py ├── requirements.txt ├── requirements_dev.txt ├── requirements_exps.txt ├── riemannian_score_sde ├── datasets │ ├── __init__.py │ ├── antibody.py │ ├── 
earth.py │ ├── mixture.py │ └── simple.py ├── losses.py ├── models │ ├── __init__.py │ ├── distribution.py │ ├── embedding.py │ ├── transform.py │ └── vector_field.py ├── sampling.py ├── sde.py └── utils │ ├── __init__.py │ ├── normalization.py │ └── vis.py ├── run.py ├── run_product.py ├── score_sde ├── __init__.py ├── datasets │ ├── __init__.py │ ├── mixture.py │ ├── split.py │ └── tensordataset.py ├── likelihood.py ├── losses.py ├── models │ ├── __init__.py │ ├── architecture.py │ ├── distribution.py │ ├── flow.py │ ├── layers │ │ ├── __init__.py │ │ └── layers.py │ ├── mlp.py │ ├── model.py │ ├── normalization.py │ └── transform.py ├── ode.py ├── optim.py ├── product.py ├── sampling.py ├── schedule.py ├── sde.py └── utils │ ├── __init__.py │ ├── cfg.py │ ├── data.py │ ├── jax.py │ ├── loggers_pl │ ├── __init__.py │ ├── base.py │ ├── csv_log.py │ ├── utilities.py │ └── wandb.py │ ├── random.py │ ├── registry.py │ ├── schedule.py │ ├── switch.py │ ├── training.py │ └── typing.py ├── scripts ├── abdb │ └── download_data.py ├── approximate_forward.py ├── basalisk_calc_angles.py ├── blender │ ├── mesh_utils.py │ └── render_utils.py ├── deploy │ ├── config.sh │ ├── make_venv.sh │ ├── setup_ssh_keys_on_zizgpus.sh │ ├── sync_keys.sh │ └── sync_venv.sh ├── diffusion_sphere │ ├── blender_diffusion_sphere_stills.py │ └── make_diffusion_sphere_plot_data.py ├── examples │ ├── 1D_example.py │ ├── 1D_example2.py │ ├── diffusion.ipynb │ ├── diffusion2.ipynb │ ├── earth_datasets.py │ ├── s2_example.py │ └── so3.ipynb ├── gaussian_pushforward │ └── make_gaussian_pushforward_data.py ├── gaussian_random_walk │ └── make_grw_data.py ├── gaussian_random_walk_step │ └── make_grw_step_data.py ├── horizontal_vs_vertical.py ├── hyperbolic │ └── make_hyperbolic_data.py ├── kent │ ├── agg_results.py │ ├── config.yaml │ ├── kent_model.py │ └── run_kent.py ├── make_fisher_data.py ├── plot_diffusion.py ├── pull_earth_results.py ├── pull_scaling_results.py ├── sabdab │ └── download_data.py ├── self_concordance_barrier.ipynb ├── so3_results.py ├── test_abdb.py ├── test_heat_kernel.py ├── test_hyperbolic.py ├── test_hyperbolic2.py ├── test_log_cholesky.py ├── test_sabdab.py └── utils.py ├── setup.py └── tests ├── test_divergence.py ├── test_haiku.py ├── test_likelihood.py ├── test_likelihood_transform.py ├── test_registry.py ├── test_transform.py └── test_vmf.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | results/ 133 | images/ 134 | data/abdb/ 135 | data/sabdab/ 136 | scripts/**/*.png 137 | scripts/**/*.csv 138 | scripts/**/*.obj 139 | 140 | .vscode -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "geomstats"] 2 | path = geomstats 3 | url = https://github.com/oxcsml/geomstats 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Constrained Riemannian Score-Based Generative Modelling](https://arxiv.org/abs/2202.02763) 2 | 3 | This repo requires a modified version of [geomstats](https://github.com/geomstats/geomstats) that adds jax functionality, along with a number of other modifications. It can be found [here](https://github.com/oxcsml/geomstats/tree/tmlr). 4 | 5 | This repository contains the code for the paper `Diffusion Models for Constrained Domains`. This paper theoretically and practically extends score-based generative modelling (SGM) from Riemannian manifolds to convex subsets of connected and complete Riemannian manifolds. 6 | 7 | SGMs are a powerful class of generative models that exhibit remarkable empirical performance. Score-based generative modelling consists of a “noising” stage, whereby a diffusion is used to gradually add Gaussian noise to data, and a generative model, which entails a “denoising” process defined by approximating the time-reversal of the diffusion.
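To make the two stages concrete, here is a minimal Euclidean sketch in JAX. It is an illustration only, not the repo's manifold implementation: the score of the Gaussian marginals is available in closed form here, whereas in practice a neural network is trained to approximate it, and all names below are made up for the example.

```
# Noising SDE (VP-style):   dX_t = -0.5 * beta * X_t dt + sqrt(beta) dB_t,
# with marginal X_t | X_0=x0  ~  N(x0 * exp(-0.5*beta*t), 1 - exp(-beta*t)).
# Denoising = time-reversal: dY_s = [0.5*beta*Y_s + beta*score(Y_s, T-s)] ds + sqrt(beta) dB_s.
import jax
import jax.numpy as jnp

beta, T, N, x0 = 1.0, 5.0, 500, 2.0
dt = T / N

def score(x, t):
    # Exact score of the Gaussian marginal p_t(. | x0); this is the quantity
    # that the learned score network replaces in the actual models.
    mean = x0 * jnp.exp(-0.5 * beta * t)
    var = 1.0 - jnp.exp(-beta * t)
    return -(x - mean) / var

def reverse_step(y, inp):
    s, eps = inp
    drift = 0.5 * beta * y + beta * score(y, T - s)
    return y + drift * dt + jnp.sqrt(beta * dt) * eps, None

# Start from the (approximately stationary) N(0, 1) reference distribution and
# integrate the reverse SDE with Euler-Maruyama; the score is only ever
# evaluated down to t = dt, where the marginal is still non-degenerate.
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
y = jax.random.normal(key0, (1000,))
ss = jnp.arange(N) * dt
noise = jax.random.normal(key1, (N, 1000))
samples, _ = jax.lax.scan(reverse_step, y, (ss, noise))
print(samples.mean(), samples.std())  # concentrates near the "data" at x0
```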
8 | 9 | ## Install 10 | 11 | Simple install instructions are: 12 | ``` 13 | git clone https://github.com/oxcsml/score-sde.git 14 | git clone https://github.com/oxcsml/geomstats.git 15 | virtualenv -p python3.9 venv 16 | source venv/bin/activate 17 | pip install -r requirements.txt 18 | pip install -r requirements_exps.txt 19 | GEOMSTATS_BACKEND=jax pip install -e geomstats 20 | pip install -e . 21 | ``` 22 | 23 | - `requirements.txt` contains the core requirements for running the code in the `score_sde` and `riemannian_score_sde` packages. NOTE: you may need to alter the jax versions here to match your setup. 24 | - `requirements_exps.txt` contains extra dependencies needed for running our experiments and for using the `run.py` file provided for training / testing models. It also contains extra dependencies for using the job scheduling functionality of hydra. 25 | - `requirements_dev.txt` contains some handy development packages. 26 | 27 | ## Code structure 28 | 29 | The bulk of the code for this project can be found in three places: 30 | - The `score_sde` package contains code to run SDEs on Euclidean space. This code is modified from the codebase of the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://github.com/yang-song/score_sde). 31 | - The `riemannian_score_sde` package contains code needed to extend the code in `score_sde` to Riemannian manifolds. 32 | - An extended version of [geomstats](https://github.com/oxcsml/geomstats) that adds `jax` support, and a number of other extensions. 33 | 34 | ### Different classes of models 35 | Most of the models used in this paper can be thought of as a pushforward of a simple density under some continuous-time transformation into a more complex density. In code, this is represented by a `score_sde.models.flow.PushForward`, containing a base distribution and, in the simplest case, a time-dependent vector field that defines the flow of the density through time. 36 | 37 | A `Continuous Normalizing Flow (CNF) [score_sde.models.flow.PushForward]` 38 | samples from the pushforward distribution by evolving samples from the base 39 | measure under the action of the vector field. The log-likelihood is computed by 40 | adding the integral of the divergence of the vector field along the sample path 41 | to the log-likelihood of the point under the base measure (a minimal sketch of this computation is given below). Models are trained by 42 | optimising this log-likelihood of the training data. 43 | 44 | `Moser flows [score_sde.models.flow.MoserFlow]` alleviate the expensive 45 | likelihood computation in training by using an alternative, cheaper method of 46 | computing the likelihood. This unfortunately requires a condition on the 47 | pushforward vector field, which is enforced by a regularisation term in the 48 | loss. As a result, the cheaper likelihood computation is unreliable, and 49 | sampling must still be done by solving expensive ODEs. 50 | 51 | `Score-based Generative Models (SGMs) [score_sde.models.flow.SDEPushForward]` 52 | instead consider a pushforward defined by the time-reversal of a noising 53 | Stochastic Differential Equation (SDE). Instead of relying on likelihood-based 54 | training, these models are trained using score matching. The likelihood is 55 | computed by converting the SDE to the corresponding likelihood ODE. While 56 | identical in nature to the likelihood ODE of CNFs/Moser flows, these are 57 | typically easier to solve computationally due to the learned vector fields 58 | being less stiff.
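The CNF likelihood computation described above fits in a few lines of plain JAX. The sketch below uses a stand-in vector field and a crude Euler integrator purely for illustration; the actual implementation (see `score_sde/models/flow.py` and `score_sde/likelihood.py`) works on manifolds, uses proper ODE solvers, and optionally uses Hutchinson-style divergence estimators.

```
import jax
import jax.numpy as jnp

def vector_field(x, t):
    # Stand-in for the learned network: any smooth map R^d x [0, 1] -> R^d.
    return -x * (1.0 + t)

def divergence(x, t):
    # Exact divergence as the trace of the Jacobian (fine in low dimensions).
    return jnp.trace(jax.jacfwd(lambda y: vector_field(y, t))(x))

def log_prob(x, n_steps=100):
    # Flow the data point backwards from t=1 to t=0 to recover its preimage z,
    # accumulating the divergence along the path. The instantaneous
    # change-of-variables formula gives
    #   log p(x) = log p_base(z) - \int_0^1 div v(x(t), t) dt,
    # i.e. the divergence integral, taken along the time-reversed sample path,
    # is added to the base log-likelihood.
    dt = 1.0 / n_steps
    div_int = 0.0
    for i in range(n_steps):
        t = 1.0 - i * dt
        x = x - vector_field(x, t) * dt  # Euler step backwards in time
        div_int = div_int + divergence(x, t) * dt
    base_logp = jax.scipy.stats.multivariate_normal.logpdf(
        x, jnp.zeros_like(x), jnp.eye(x.shape[0])
    )
    return base_logp - div_int

print(log_prob(jnp.array([0.5, -0.3])))
```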
59 | 60 | Other core pieces of code include: 61 | 62 | - `score_sde/models/transform.py` which defines transforms between manifolds and Euclidean space, designed to allow for pushing a Euclidean density onto a manifold. 63 | - `score_sde/models/vector_field.py` which contains various parametrisations of the vector fields needed for defining the score functions 64 | - `score_sde/sde.py` which defines various SDEs 65 | - `score_sde/losses.py` which contains all the loss functions used 66 | - `score_sde/sampling.py` which provides methods for sampling SDEs 67 | - `score_sde/ode.py` which provides methods for solving ODEs 68 | 69 | and their counterparts in `riemannian_score_sde`. 70 | 71 | ### Model structure 72 | Models are decomposed into three blocks: 73 | - a `base` distribution, with `z ~ base` (a 'prior') 74 | - a learnable diffeomorphic `flow: z -> y` (the flexible component of the model, potentially stochastic as for SGMs) 75 | - a `transform` map `y -> x ∈ M` (if the model is *not* defined on the manifold and needs to be 'projected', else the model is *Riemannian* and `transform=Id`) 76 | Thus, the generative models are defined as `z -> y -> x` (a minimal sketch of this composition is given at the end of this README). 77 | 78 | ## Reproducing experiments 79 | Experiment configuration is handled by [hydra](https://hydra.cc/docs/intro/), a highly flexible `yaml`-based configuration package. Base configs can be found in `config`, and parameters are overridden on the command line. Sweeps over parameters can also be managed with a single command. 80 | 81 | Jobs are scheduled on a cluster using a number of different plugins. We use Slurm, and configs for this can be found in `config/server` (note these are reasonably general but have some setup-specific parts). Other systems can easily be substituted by creating a new server configuration. 82 | 83 | The main training and testing script can be found in `run.py`, and is dispatched by running `python main.py [OPTIONS]`. 84 | 85 | ### Logging 86 | By default we log to CSV files and to [Weights & Biases](https://wandb.ai). To use Weights & Biases, you will need to have an appropriate `WANDB_API_KEY` set in your environment, and to modify the `entity` and `project` entries in the `config/logger/wandb.yaml` file. The top-level local logging directory can be set via the `logs_dir` variable.
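For concreteness, here is the three-block decomposition from the 'Model structure' section as a minimal sketch. The class and argument names are hypothetical, chosen only to mirror the `base` / `flow` / `transform` terminology; the real components live in `score_sde.models.flow` and `score_sde/models/transform.py`.

```
import jax

class PushForwardSketch:
    # Hypothetical illustration of the z -> y -> x composition.
    def __init__(self, base_sample, flow, transform):
        self.base_sample = base_sample  # z ~ base (the 'prior')
        self.flow = flow                # z -> y, the learnable component
        self.transform = transform      # y -> x in M (identity if Riemannian)

    def sample(self, rng, n):
        z = self.base_sample(rng, n)
        y = self.flow(z)
        return self.transform(y)

# Euclidean toy instance: normal base, affine 'flow', identity transform.
model = PushForwardSketch(
    base_sample=lambda rng, n: jax.random.normal(rng, (n, 2)),
    flow=lambda z: 0.5 * z + 1.0,
    transform=lambda y: y,
)
x = model.sample(jax.random.PRNGKey(0), 4)  # shape (4, 2)
```

As a usage note, the experiment configs document the corresponding invocations, e.g. `python main.py experiment=s2_toy` for the sphere toy experiment, and any composed config value can be overridden on the command line (e.g. `python main.py experiment=s2_toy seed=1`).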
87 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.0.0 -------------------------------------------------------------------------------- /config/architecture/big_concat.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Concat 2 | # output_shape: ${manifold.embedding_space.dim} 3 | hidden_shapes: [512, 512, 512, 512, 512, 512, 512, 512] 4 | act: sin 5 | -------------------------------------------------------------------------------- /config/architecture/concat.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Concat 2 | # output_shape: ${manifold.embedding_space.dim} 3 | hidden_shapes: [512, 512, 512, 512, 512, 512] 4 | act: sin 5 | -------------------------------------------------------------------------------- /config/architecture/embed.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.ConcatEmbed 2 | enc_shapes: [128] 3 | dec_shapes: [512, 512, 512, 512, 512] 4 | t_dim: 64 5 | act: sin -------------------------------------------------------------------------------- /config/architecture/mha.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Attention 2 | # output_shape: ${manifold.embedding_space.dim} 3 | hidden_shapes: [512, 512] 4 | act: sin -------------------------------------------------------------------------------- /config/architecture/multi.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Multi 2 | # output_shape: ${manifold.embedding_space.dim} 3 | hidden_shapes: [512, 512, 512, 512, 512] 4 | act: sin -------------------------------------------------------------------------------- /config/architecture/small_concat.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Concat 2 | # output_shape: ${manifold.embedding_space.dim} 3 | hidden_shapes: [512, 512] 4 | act: sin 5 | -------------------------------------------------------------------------------- /config/base/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.distribution.DefaultDistribution -------------------------------------------------------------------------------- /config/base/normal.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.sde.NormalDistribution -------------------------------------------------------------------------------- /config/base/uniform.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.UniformDistribution -------------------------------------------------------------------------------- /config/beta_schedule/constant.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.schedule.ConstantBetaSchedule 2 | tf: 1. 3 | value: 1. 
4 | -------------------------------------------------------------------------------- /config/beta_schedule/linear.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.schedule.LinearBetaSchedule 2 | beta_0: ${beta_0} 3 | beta_f: ${beta_f} 4 | t0: 0. 5 | tf: 1. 6 | lambda0: 1. -------------------------------------------------------------------------------- /config/beta_schedule/linear_product2.yaml: -------------------------------------------------------------------------------- 1 | - _target_: score_sde.schedule.LinearBetaSchedule 2 | beta_0: ${beta_0_0} 3 | beta_f: ${beta_f_0} 4 | t0: 0. 5 | tf: 1. 6 | lambda0: 1. 7 | - _target_: score_sde.schedule.LinearBetaSchedule 8 | beta_0: ${beta_0_1} 9 | beta_f: ${beta_f_1} 10 | t0: 0. 11 | tf: 1. 12 | lambda0: 1. -------------------------------------------------------------------------------- /config/beta_schedule/quadratic.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.schedule.QuadraticBetaSchedule 2 | beta_0: 0.01 3 | beta_f: 5. 4 | t0: 0. 5 | tf: 1. -------------------------------------------------------------------------------- /config/beta_schedule/triangle.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.schedule.TriangleBetaSchedule 2 | beta_min: 0.001 3 | beta_max: 2. 4 | t0: 0. 5 | tf: 1. -------------------------------------------------------------------------------- /config/dataset/angles.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.datasets.graph_dataset.GraphLoader 2 | batch_dims: ${batch_size} 3 | seed: ${seed} -------------------------------------------------------------------------------- /config/dataset/angles2.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.datasets.tensordataset.CSVDataset 2 | file: /data/ziz/not-backed-up/fishman/score-sde/data/angles.csv -------------------------------------------------------------------------------- /config/dataset/chain.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.utils.switch.SwitchObject 2 | objects: 3 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 4 | n_links: 9 5 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 6 | n_links: 10 7 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 8 | n_links: 11 9 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 10 | n_links: 14 11 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 12 | n_links: 15 13 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 14 | n_links: 13 15 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 16 | n_links: 16 17 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 18 | n_links: 17 19 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 20 | n_links: 18 21 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 22 | n_links: 19 23 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 24 | n_links: 20 25 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 26 | n_links: 21 27 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 28 | n_links: 24 29 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 30 | n_links: 12 31 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 32 | n_links: 25 33 | - _target_: 
riemannian_score_sde.datasets.antibody.Antibody 34 | n_links: 22 35 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 36 | n_links: 27 37 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 38 | n_links: 23 39 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 40 | n_links: 28 41 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 42 | n_links: 30 43 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 44 | n_links: 26 45 | - _target_: riemannian_score_sde.datasets.antibody.Antibody 46 | n_links: 29 -------------------------------------------------------------------------------- /config/dataset/chain1.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.antibody.Antibody 2 | n_links: 16 3 | batch_size: ${batch_size} 4 | -------------------------------------------------------------------------------- /config/dataset/dirac.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.DiracDataset 2 | _convert_: partial 3 | mu: [1., 0., 0.] 4 | manifold: ${manifold} -------------------------------------------------------------------------------- /config/dataset/earthquake.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.earth.Earthquake 2 | _convert_: partial 3 | data_dir: ${data_dir} 4 | name: earthquake -------------------------------------------------------------------------------- /config/dataset/fire.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.earth.Fire 2 | _convert_: partial 3 | data_dir: ${data_dir} 4 | name: fire -------------------------------------------------------------------------------- /config/dataset/flood.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.earth.Flood 2 | _convert_: partial 3 | data_dir: ${data_dir} 4 | name: flood -------------------------------------------------------------------------------- /config/dataset/langevin.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.Langevin 2 | _convert_: partial 3 | scale: 10 4 | K: 1 5 | batch_dims: 6 | - ${batch_size} 7 | manifold: ${manifold} 8 | seed: ${seed} 9 | conditional: false -------------------------------------------------------------------------------- /config/dataset/loop-torus.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.antibody.LoopTorus 2 | npz: "/data/ziz/not-backed-up/fishman/score-sde/data/smooth_loops_dist.npz" 3 | batch_size: ${batch_size} 4 | -------------------------------------------------------------------------------- /config/dataset/mix_vmf.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.vMFMixture 2 | _convert_: partial 3 | batch_dims: 4 | - ${batch_size} 5 | mu: 6 | - [1., 0., 0.] 7 | - [0., 1., 0.] 8 | kappa: 9 | - 15. 10 | - 15. 
11 | manifold: ${manifold} -------------------------------------------------------------------------------- /config/dataset/polytope-torus.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.antibody.PolytopeTorus 2 | npz: ${npz} 3 | batch_size: ${batch_size} 4 | -------------------------------------------------------------------------------- /config/dataset/polytope.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.antibody.Polytope 2 | npz: ${npz} 3 | batch_size: ${batch_size} 4 | -------------------------------------------------------------------------------- /config/dataset/uniform.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.simple.Uniform 2 | _convert_: partial 3 | batch_dims: 4 | - ${batch_size} 5 | manifold: ${manifold} 6 | seed: ${seed} -------------------------------------------------------------------------------- /config/dataset/us_forest_fires.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.earth.USForestFire 2 | _convert_: partial 3 | data_dir: ${data_dir} 4 | batch_size: ${batch_size} 5 | name: us_forest_fires -------------------------------------------------------------------------------- /config/dataset/vmf.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.vMFDataset 2 | _convert_: partial 3 | batch_dims: 4 | - ${batch_size} 5 | mu: [1., 0., 0.] 6 | kappa: 15. 7 | manifold: ${manifold} -------------------------------------------------------------------------------- /config/dataset/volcanoe.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.earth.VolcanicErruption 2 | _convert_: partial 3 | data_dir: ${data_dir} 4 | name: volcanoe -------------------------------------------------------------------------------- /config/dataset/wrapnormmix.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.WrapNormMixtureDistribution 2 | batch_dims: ${batch_size} 3 | manifold: ${manifold} 4 | scale: 0.15 5 | mean: [[-0.8, 0.0],[0.8, 0.0],[0.0, -0.8],[0.0, 0.8]] -------------------------------------------------------------------------------- /config/dataset/wrapped.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.datasets.Wrapped 2 | _convert_: partial 3 | scale_type: random 4 | scale: 100 5 | K: 1 6 | batch_dims: 7 | - ${batch_size} 8 | manifold: ${manifold} 9 | seed: ${seed} 10 | conditional: false -------------------------------------------------------------------------------- /config/embedding/laplacian_eigenfunction.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.LaplacianEigenfunctionEmbedding 2 | 3 | n_manifold: 11 4 | n_time: 72 5 | max_t: 1. 
-------------------------------------------------------------------------------- /config/embedding/none.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.NoneEmbedding -------------------------------------------------------------------------------- /config/experiment/earthquake.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=earthquake 5 | 6 | name: earthquake 7 | experiment: earthquake 8 | 9 | defaults: 10 | - earth_data 11 | - /dataset: earthquake -------------------------------------------------------------------------------- /config/experiment/fire.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=fire 5 | 6 | name: fire 7 | experiment: fire 8 | 9 | defaults: 10 | - earth_data 11 | - /dataset: fire -------------------------------------------------------------------------------- /config/experiment/flood.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=flood 5 | 6 | name: flood 7 | experiment: flood 8 | 9 | defaults: 10 | - earth_data 11 | - /dataset: flood -------------------------------------------------------------------------------- /config/experiment/hessian_dirichlet_d10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_dirichlet_d10 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/dirichlet_d=10.npz 18 | 19 | beta_0: 0.001 20 | beta_f: 7 21 | 22 | N: 300 23 | std_trick: true 24 | boundary_enforce: true 25 | 26 | rpb: 1 27 | lr: 2e-4 28 | 29 | metric_type: "Hessian" 30 | p_eps: 1e-3 31 | 32 | splits: [1.0, 0.0, 0.0] 33 | batch_size: 1024 34 | warmup_steps: 1000 35 | steps: 100001 36 | val_freq: 10000 37 | ema_rate: 0.999 38 | eps: 1e-3 39 | eval_batch_size: 256 -------------------------------------------------------------------------------- /config/experiment/hessian_dirichlet_d2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_dirichlet_d2 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/multimodal_dirichlet_d=2.npz 18 | 19 | beta_0: 0.001 20 | beta_f: 22 21 | 22 | N: 300 23 | std_trick: true 24 | boundary_enforce: true 25 | 26 | rpb: 1 27 | lr: 2e-4 28 | 29 | metric_type: "Hessian" 30 | p_eps: 1e-3 31 | 32 | splits: [1.0, 0.0, 0.0] 33 | batch_size: 1024 34 | warmup_steps: 1000 35 | steps: 100001 36 | val_freq: 10000 37 | ema_rate: 0.999 38 | eps: 1e-3 39 |
eval_batch_size: 256 40 | -------------------------------------------------------------------------------- /config/experiment/hessian_dirichlet_d3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_dirichlet_d3 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/dirichlet_d=3.npz 18 | 19 | beta_0: 0.001 20 | beta_f: 28 21 | 22 | N: 300 23 | std_trick: true 24 | boundary_enforce: true 25 | 26 | rpb: 1 27 | lr: 2e-4 28 | 29 | metric_type: "Hessian" 30 | p_eps: 1e-3 31 | 32 | splits: [1.0, 0.0, 0.0] 33 | batch_size: 1024 34 | warmup_steps: 1000 35 | steps: 100001 36 | val_freq: 10000 37 | ema_rate: 0.999 38 | eps: 1e-3 39 | eval_batch_size: 256 40 | -------------------------------------------------------------------------------- /config/experiment/hessian_hypercube_analytic_d2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_hypercube_analytic_d2 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: hypercube 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | flow: 17 | full_diffusion_matrix: False 18 | 19 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 20 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/unit_hypercube_d=2.npz 21 | eps: 1e-3 22 | 23 | beta_0: 1e-3 24 | beta_f: 9 25 | 26 | N: 300 27 | std_trick: true 28 | boundary_enforce: true 29 | 30 | rpb: 1 31 | lr: 2e-4 32 | 33 | metric_type: "Hessian" 34 | p_eps: 1e-3 35 | 36 | splits: [1.0, 0.0, 0.0] 37 | batch_size: 1024 38 | warmup_steps: 1000 39 | steps: 100001 40 | val_freq: 10000 41 | ema_rate: 0.999 42 | -------------------------------------------------------------------------------- /config/experiment/hessian_hypercube_d10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_hypercube_d10 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/hypercube_d=10.npz 18 | 19 | beta_0: 1e-3 20 | beta_f: 35 21 | 22 | N: 300 23 | std_trick: true 24 | boundary_enforce: true 25 | 26 | rpb: 1 27 | lr: 2e-4 28 | 29 | metric_type: "Hessian" 30 | p_eps: 1e-3 31 | 32 | splits: [1.0, 0.0, 0.0] 33 | batch_size: 1024 34 | warmup_steps: 1000 35 | steps: 100001 36 | val_freq: 10000 37 | ema_rate: 0.999 38 | eps: 1e-3 39 | eval_batch_size: 256 40 | -------------------------------------------------------------------------------- /config/experiment/hessian_hypercube_d2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_hypercube_d2 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - 
/manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/hypercube_d=2.npz 18 | eps: 1e-3 19 | 20 | beta_0: 1e-3 21 | beta_f: 22 22 | 23 | N: 300 24 | std_trick: true 25 | boundary_enforce: true 26 | 27 | rpb: 1 28 | lr: 2e-4 29 | 30 | metric_type: "Hessian" 31 | p_eps: 1e-3 32 | 33 | splits: [1.0, 0.0, 0.0] 34 | batch_size: 1024 35 | warmup_steps: 1000 36 | steps: 100001 37 | val_freq: 10000 38 | ema_rate: 0.999 39 | -------------------------------------------------------------------------------- /config/experiment/hessian_hypercube_d3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_hypercube_d3 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/hypercube_d=3.npz 18 | eps: 1e-3 19 | 20 | beta_0: 1e-3 21 | beta_f: 28 22 | 23 | N: 300 24 | std_trick: true 25 | boundary_enforce: true 26 | 27 | rpb: 1 28 | lr: 2e-4 29 | 30 | metric_type: "Hessian" 31 | p_eps: 1e-3 32 | 33 | splits: [1.0, 0.0, 0.0] 34 | batch_size: 1024 35 | warmup_steps: 1000 36 | steps: 100001 37 | val_freq: 10000 38 | ema_rate: 0.999 39 | -------------------------------------------------------------------------------- /config/experiment/hessian_loops.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_loops 4 | 5 | defaults: 6 | - /dataset: polytope-torus 7 | - /manifold: polytope-torus 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /generator: ambient-torus 12 | - override /pushf: product 13 | - override /flow: brownian_product2 14 | - override /loss: product_ism 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/smooth_loops_dist.npz 18 | 19 | beta_0_0: 0.001 20 | beta_f_0: 55 21 | beta_0_1: 0.001 22 | beta_f_1: 15 23 | 24 | N: 2000 25 | n_torus: 4 26 | rpb: 1 27 | 28 | std_trick: false 29 | boundary_enforce: true 30 | 31 | metric_type: "Hessian" 32 | p_eps: 1e-3 33 | 34 | lr: 2e-4 35 | splits: [1.0, 0.0, 0.0] 36 | batch_size: 256 37 | warmup_steps: 100 38 | steps: 150001 39 | val_freq: 10000 40 | ema_rate: 0.999 41 | eps: 1e-3 42 | eval_batch_size: 256 43 | -------------------------------------------------------------------------------- /config/experiment/hessian_smooth_L.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: hessian_smooth_L 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope-and-sphere 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: hessian 14 | - override /loss: hessian 15 | 16 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 17 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/SmoothL.npz 18 | 19 
| beta_0: 0.001 20 | beta_f: 26 21 | 22 | 23 | N: 2000 24 | std_trick: true 25 | boundary_enforce: true 26 | 27 | rpb: 1 28 | lr: 2e-4 29 | 30 | metric_type: "Hessian" 31 | p_eps: 1e-3 32 | 33 | splits: [1.0, 0.0, 0.0] 34 | batch_size: 1024 35 | warmup_steps: 1000 36 | steps: 100001 37 | val_freq: 10000 38 | ema_rate: 0.999 39 | eps: 1e-3 40 | eval_batch_size: 256 -------------------------------------------------------------------------------- /config/experiment/hyperboloid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=hyperboloid 5 | 6 | name: hyperboloid 7 | 8 | defaults: 9 | - /manifold: hyperbolic 10 | - /model: rsgm 11 | - /dataset: wrapnormmix 12 | - /architecture: concat 13 | - /embedding: none 14 | - override /generator: ambient 15 | - override /flow: langevin 16 | - override /loss: ism 17 | 18 | mean: 19 | _target_: geomstats.geometry.hyperbolic.Hyperbolic._ball_to_extrinsic_coordinates 20 | point: 21 | _target_: jax.numpy.array 22 | _convert_: all 23 | object: [[-0.4, 0.0],[0.4, 0.0],[0.0, -0.4],[0.0, 0.4]] 24 | 25 | dataset: 26 | mean: ${mean} 27 | scale: [[0., 0.15, 0.5],[0., 0.15, 0.5],[0., 0.5, 0.15],[0., 0.5, 0.15]] 28 | 29 | manifold: 30 | dim: 2 31 | default_coords_type: extrinsic 32 | 33 | beta_schedule: 34 | beta_0: 0.01 35 | beta_f: 2 36 | 37 | flow: 38 | ref_scale: 0.5 39 | N: 1000 40 | 41 | batch_size: 512 42 | eval_batch_size: 512 43 | warmup_steps: ${min:1000,${eval:${steps}-1}} 44 | steps: 100000 45 | val_freq: 5000 # !!float inf 46 | ema_rate: 0.999 47 | eps: 1e-3 -------------------------------------------------------------------------------- /config/experiment/poincare.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=poincare 5 | 6 | name: poincare 7 | 8 | defaults: 9 | - hyperboloid 10 | - override /generator: transport 11 | 12 | mean: 13 | _target_: jax.numpy.array 14 | _convert_: all 15 | object: [[-0.4, 0.0],[0.4, 0.0],[0.0, -0.4],[0.0, 0.4]] 16 | dataset: 17 | scale: [[0.15, 0.5],[0.15, 0.5],[0.5, 0.15],[0.5, 0.15]] 18 | 19 | manifold: 20 | default_coords_type: ball -------------------------------------------------------------------------------- /config/experiment/reflect_hypercube_analytic_d2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: reflect_hypercube_analytic_d2 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: hypercube 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: ambient 13 | - override /flow: brownian 14 | - override /loss: ism 15 | 16 | 17 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 18 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/unit_hypercube_d=2.npz 19 | eps: 1e-3 20 | 21 | beta_0: 1e-3 22 | beta_f: 1.5 23 | 24 | N: 300 25 | std_trick: true 26 | boundary_enforce: true 27 | boundary_dis: 1e-3 28 | 29 | loss: 30 | w_floor: 1 31 | 32 | rpb: 1 33 | lr: 2e-4 34 | 35 | metric_type: "Reflected" 36 | p_eps: 1e-3 37 | 38 | splits: [1.0, 0.0, 0.0] 39 | batch_size: 1024 40 | warmup_steps: 1000 41 | steps: 100001 42 | val_freq: 10000 43 | ema_rate: 0.999 44 | -------------------------------------------------------------------------------- /config/experiment/s2_toy.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=s2_toy 5 | 6 | name: s2_toy 7 | 8 | defaults: 9 | - /manifold: nsphere 10 | - /model: rsgm 11 | - /dataset: vmf 12 | - /architecture: concat 13 | - /embedding: none 14 | 15 | manifold: 16 | dim: 2 17 | 18 | batch_size: 512 19 | eval_batch_size: 512 20 | warmup_steps: 100 21 | steps: 5000 22 | val_freq: 500 23 | ema_rate: 0.999 24 | eps: 1e-3 -------------------------------------------------------------------------------- /config/experiment/smooth_loops.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: smooth_loops 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: big_concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /generator: ambient 12 | - override /flow: brownian 13 | - override /loss: hessian 14 | 15 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 16 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/smooth_loops_dist.npz 17 | 18 | 19 | beta_0: 0.001 20 | beta_f: 50 21 | 22 | splits: [1.0, 0.0, 0.0] 23 | batch_size: 1024 24 | warmup_steps: 1000 25 | steps: 150001 26 | val_freq: 5000 27 | ema_rate: 0.999 28 | eps: 1e-3 29 | eval_batch_size: 256 30 | -------------------------------------------------------------------------------- /config/experiment/so3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: so3 4 | 5 | defaults: 6 | - /dataset: wrapped 7 | - /manifold: so3 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | # - override /generator: lie_algebra 12 | 13 | data_dir: /data/ziz/not-backed-up/scratch/score-sde/data/ 14 | 15 | dataset: 16 | K: 32 17 | mean: unif 18 | scale: 100 19 | scale_type: random 20 | 21 | 22 | flow: 23 | beta_schedule: 24 | beta_0: 0.001 25 | beta_f: 6 26 | 27 | splits: [0.8, 0.1, 0.1] 28 | batch_size: 512 29 | eval_batch_size: 2048 30 | warmup_steps: 100 31 | steps: 100000 32 | val_freq: 10000 33 | ema_rate: 0.999 34 | eps: 2e-4 -------------------------------------------------------------------------------- /config/experiment/spd_d3.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: spd_d3 4 | 5 | defaults: 6 | - /dataset: polytope 7 | - /manifold: polytope 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /generator: ambient 12 | - override /flow: brownian 13 | - override /loss: hessian 14 | 15 | data_dir: /data/ziz/not-backed-up/fishman/score-sde/data/ 16 | npz: /data/ziz/not-backed-up/fishman/score-sde/data/L.npz 17 | 18 | 19 | beta_0: 0.001 20 | beta_f: 15 21 | 22 | splits: [1.0, 0.0, 0.0] 23 | batch_size: 512 24 | warmup_steps: 1000 25 | steps: 50001 26 | val_freq: 10000 27 | ema_rate: 0.999 28 | eps: 1e-3 29 | eval_batch_size: 256 30 | -------------------------------------------------------------------------------- /config/experiment/tn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | name: tn 4 | 5 | defaults: 6 | - /dataset: wrapped 7 | - /manifold: tn 8 | - /architecture: concat 9 | - /embedding: none 10 | - /model: rsgm 11 | - override /base: uniform 12 | - override /generator: torus 13 | 14 | data_dir: /data/ziz/not-backed-up/scratch/score-sde/data/ 15 | 
16 | n: 5 17 | 18 | dataset: 19 | scale: 0.2 20 | scale_type: fixed 21 | mean: unif 22 | K: 1 23 | 24 | 25 | 26 | beta_0: 0.001 27 | beta_f: 15 28 | 29 | splits: [0.8, 0.1, 0.1] 30 | batch_size: 512 31 | warmup_steps: 100 32 | steps: 50000 33 | val_freq: 1000 34 | ema_rate: 0.999 35 | eps: 1e-3 36 | -------------------------------------------------------------------------------- /config/experiment/volcanoe.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python main.py experiment=volcanoe 5 | 6 | name: volcanoe 7 | experiment: volcanoe 8 | 9 | defaults: 10 | - earth_data 11 | - /dataset: volcanoe -------------------------------------------------------------------------------- /config/flow/brownian.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.Brownian 2 | N: ${N} 3 | std_trick: ${std_trick} 4 | boundary_enforce: ${boundary_enforce} 5 | boundary_dis: ${boundary_dis} -------------------------------------------------------------------------------- /config/flow/brownian_hessian_product2.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.product.ProductSDE 2 | sdes: 3 | - _target_: riemannian_score_sde.sde.HessianSDE 4 | - _target_: riemannian_score_sde.sde.Brownian 5 | N: ${N} 6 | boundary_enforce: ${boundary_enforce} 7 | std_trick: ${std_trick} -------------------------------------------------------------------------------- /config/flow/brownian_product2.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.product.ProductSDE 2 | sdes: 3 | - _target_: riemannian_score_sde.sde.Brownian 4 | - _target_: riemannian_score_sde.sde.Brownian 5 | N: ${N} 6 | boundary_enforce: ${boundary_enforce} 7 | std_trick: ${std_trick} -------------------------------------------------------------------------------- /config/flow/cnf.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.CNF 2 | get_drift_fn: 3 | _target_: hydra.utils.get_method 4 | path: score_sde.models.get_ode_drift_fn 5 | t0: 0 6 | tf: 1 7 | hutchinson_type: None 8 | rtol: 1e-5 9 | atol: 1e-5 -------------------------------------------------------------------------------- /config/flow/hessian.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.HessianSDE 2 | N: ${N} 3 | std_trick: ${std_trick} 4 | boundary_enforce: ${boundary_enforce} -------------------------------------------------------------------------------- /config/flow/hessian2k.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.HessianSDE 2 | N: 2000 -------------------------------------------------------------------------------- /config/flow/langevin.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.Langevin 2 | ref_scale: 0.5 3 | N: 100 -------------------------------------------------------------------------------- /config/flow/product.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.product.ProductSDE 2 | sdes: 3 | - _target_: riemannian_score_sde.sde.Brownian 4 | - _target_: riemannian_score_sde.sde.Hessian 5 | N: 1000
-------------------------------------------------------------------------------- /config/flow/reflectedbrownian.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.ReflectedBrownian 2 | N: ${N} 3 | std_trick: ${std_trick} 4 | boundary_enforce: ${boundary_enforce} 5 | boundary_dis: ${boundary_dis} -------------------------------------------------------------------------------- /config/flow/vp.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.sde.VPSDE -------------------------------------------------------------------------------- /config/generator/ambient-torus.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.AmbientTorusGenerator -------------------------------------------------------------------------------- /config/generator/ambient.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.AmbientGenerator -------------------------------------------------------------------------------- /config/generator/canonical.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.CanonicalGenerator -------------------------------------------------------------------------------- /config/generator/div_free.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.DivFreeGenerator -------------------------------------------------------------------------------- /config/generator/eigen.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.EigenGenerator -------------------------------------------------------------------------------- /config/generator/lie_algebra.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.LieAlgebraGenerator -------------------------------------------------------------------------------- /config/generator/torus.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.TorusGenerator -------------------------------------------------------------------------------- /config/generator/transport.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.ParallelTransportGenerator -------------------------------------------------------------------------------- /config/logger/all.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - csv 3 | # - tensorboard 4 | - wandb -------------------------------------------------------------------------------- /config/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | csv: 2 | _target_: score_sde.utils.loggers_pl.CSVLogger 3 | save_dir: logs 4 | name: "" 5 | # version: "" 6 | flush_logs_every_n_steps: 1000 -------------------------------------------------------------------------------- /config/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | _target_: score_sde.utils.loggers_pl.WandbLogger 3 | save_dir: null # /data/localhost/not-backed-up/mhutchin/score-sde/logs 
#${logs_dir} 4 | name: ${str:${seed}} 5 | group: ${name}_${hydra:job.override_dirname} 6 | entity: oxcsml 7 | project: diffusion_manifold 8 | offline: False -------------------------------------------------------------------------------- /config/loss/dsm0.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_dsm_loss_fn 2 | like_w: false 3 | eps: ${eps} 4 | thresh: 0.5 5 | n_max: 5 6 | -------------------------------------------------------------------------------- /config/loss/dsms.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_dsm_loss_fn 2 | like_w: false 3 | eps: ${eps} 4 | thresh: 0.1 5 | n_max: 0 6 | s_zero: false -------------------------------------------------------------------------------- /config/loss/dsmv.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_dsm_loss_fn 2 | like_w: false 3 | eps: ${eps} 4 | thresh: 1. 5 | n_max: -1 6 | -------------------------------------------------------------------------------- /config/loss/euclidean_ism.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.losses.get_ism_loss_fn 2 | like_w: true 3 | hutchinson_type: None 4 | eps: ${eps} 5 | -------------------------------------------------------------------------------- /config/loss/hessian.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_hessian_loss_fn 2 | like_w: true 3 | sqrt_like_w: false 4 | hutchinson_type: None 5 | eps: ${eps} 6 | repeats_per_batch: ${rpb} 7 | w_floor: 1 8 | -------------------------------------------------------------------------------- /config/loss/hessian_sqrt_like_w.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_hessian_loss_fn 2 | like_w: false 3 | sqrt_like_w: true 4 | hutchinson_type: None 5 | eps: ${eps} 6 | -------------------------------------------------------------------------------- /config/loss/ism.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_ism_loss_fn 2 | like_w: true 3 | hutchinson_type: None 4 | eps: ${eps} 5 | repeats_per_batch: ${rpb} 6 | w_floor: 0 -------------------------------------------------------------------------------- /config/loss/logp.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.losses.get_logp_loss_fn -------------------------------------------------------------------------------- /config/loss/moser.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_moser_loss_fn 2 | alpha_m: 100 3 | alpha_p: 0 4 | K: 10000 5 | hutchinson_type: None 6 | eps: 1e-5 -------------------------------------------------------------------------------- /config/loss/product_ism.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_product_ism_loss_fn 2 | like_w: true 3 | w_floor: 1 4 | repeats_per_batch: ${rpb} 5 | hutchinson_type: None 6 | -------------------------------------------------------------------------------- /config/loss/ssm.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.losses.get_ism_loss_fn 2 | like_w: true 3 | hutchinson_type: Rademacher 4 | eps: ${eps} 5 | -------------------------------------------------------------------------------- /config/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # GENERAL # 3 | - _self_ 4 | 5 | - server: local 6 | - experiment: s2_toy 7 | 8 | - logger: csv 9 | 10 | - optim: adam 11 | - scheduler: rcosine 12 | - beta_schedule: linear 13 | 14 | # enable color logging 15 | - override hydra/hydra_logging: colorlog 16 | - override hydra/job_logging: colorlog 17 | 18 | eval_batch_size: ${batch_size} 19 | now: ${now:%Y-%m-%d}/${now:%H-%M-%S} 20 | 21 | resume: false 22 | mode: all 23 | seed: 0 24 | PROJECT_NAME: score-sde 25 | work_dir: ${hydra:runtime.cwd} 26 | 27 | # path to folder with data 28 | data_dir: ${work_dir}/data/ 29 | ckpt_dir: ckpt 30 | logs_dir: logs 31 | 32 | logdir: ${work_dir}/results 33 | # rundir: ${dataset.name}/${model.name} 34 | 35 | # perform actions in the val loop during training 36 | train_val: false 37 | train_plot: true 38 | 39 | # perform certain metrics in test mode 40 | test_val: true 41 | test_test: true 42 | test_plot: true 43 | 44 | save_freq: 5000 45 | -------------------------------------------------------------------------------- /config/manifold/bounded-sphere.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.with_boundary.us_border.BoundedHypersphere 2 | dim: 2 -------------------------------------------------------------------------------- /config/manifold/chain1.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.product_manifold.ProductManifold 2 | manifolds: 3 | - _target_: geomstats.geometry.polytope.Polytope 4 | n_links: 16 5 | - _target_: geomstats.geometry.product_manifold.ProductSameManifold 6 | manifold: 7 | _target_: geomstats.geometry.hypersphere.Hypersphere 8 | dim: 1 9 | mul: 15 -------------------------------------------------------------------------------- /config/manifold/hyperbolic.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.hyperbolic.Hyperbolic 2 | dim: 2 3 | default_coords_type: ball -------------------------------------------------------------------------------- /config/manifold/hypercube.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.with_boundary.hypercube.Hypercube 2 | metric_type: ${metric_type} 3 | dim: 2 -------------------------------------------------------------------------------- /config/manifold/nsphere.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.hypersphere.Hypersphere 2 | dim: 2 -------------------------------------------------------------------------------- /config/manifold/polytope-and-sphere.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.with_boundary.polytope_and_sphere.PolytopeAndSphere 2 | npz: ${npz} 3 | metric_type: ${metric_type} 4 | eps: ${p_eps} -------------------------------------------------------------------------------- /config/manifold/polytope-torus.yaml: -------------------------------------------------------------------------------- 1 
| - _target_: geomstats.geometry.with_boundary.polytope.Polytope 2 | npz: ${npz} 3 | metric_type: ${metric_type} 4 | eps: ${p_eps} 5 | - _target_: geomstats.geometry.product_manifold.ProductSameManifold 6 | manifold: 7 | _target_: geomstats.geometry.hypersphere.Hypersphere 8 | dim: 1 9 | mul: ${n_torus} 10 | default_point_type: vector -------------------------------------------------------------------------------- /config/manifold/polytope.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.with_boundary.polytope.Polytope 2 | npz: ${npz} 3 | metric_type: ${metric_type} 4 | eps: ${p_eps} -------------------------------------------------------------------------------- /config/manifold/s1.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.hypersphere.Hypersphere 2 | dim: 1 3 | -------------------------------------------------------------------------------- /config/manifold/so3.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.special_orthogonal.SpecialOrthogonal 2 | n: 3 3 | # point_type: matrix 4 | point_type: vector 5 | -------------------------------------------------------------------------------- /config/manifold/tn.yaml: -------------------------------------------------------------------------------- 1 | _target_: geomstats.geometry.product_manifold.ProductSameManifold 2 | manifold: 3 | _target_: geomstats.geometry.hypersphere.Hypersphere 4 | dim: 1 5 | mul: ${n_torus} 6 | default_point_type: vector -------------------------------------------------------------------------------- /config/model/cnf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: id 5 | - /flow: cnf 6 | - /base: default 7 | - /pushf: default 8 | - /loss: logp 9 | - /generator: div_free -------------------------------------------------------------------------------- /config/model/exp_sgm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: exp 5 | - /flow: vp 6 | - /base: default 7 | - /pushf: sde 8 | - /loss: dsm0 9 | - /generator: canonical -------------------------------------------------------------------------------- /config/model/moser.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: id 5 | - /flow: cnf 6 | - /base: default 7 | - /pushf: moser 8 | - /loss: moser 9 | - /generator: div_free -------------------------------------------------------------------------------- /config/model/rsgm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: id 5 | - /flow: brownian 6 | - /base: default 7 | - /pushf: sde 8 | - /loss: dsm0 9 | - /generator: div_free -------------------------------------------------------------------------------- /config/model/stereo_sgm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: stereo 5 | - /flow: vp 6 | - /base: default 7 | - /pushf: sde 8 | - /loss: dsm0 9 | - /generator: ambient -------------------------------------------------------------------------------- /config/model/tanhexp_sgm.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /transform: tanhexp 5 | - /flow: vp 6 | - /base: default 7 | - /pushf: sde 8 | - /loss: dsm0 9 | - /generator: canonical -------------------------------------------------------------------------------- /config/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | # base_lr: # base learning rate, rescaled by batch_size/256 2 | # object: 3 | _target_: optax.adam 4 | learning_rate: ${lr} # ${eval:${optim.base_lr}*${batch_size}/256} 5 | b1: .9 6 | b2: 0.999 7 | eps: 1e-8 8 | -------------------------------------------------------------------------------- /config/pushf/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.PushForward -------------------------------------------------------------------------------- /config/pushf/moser.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.MoserFlow 2 | eps: 1e-5 3 | diffeq: True -------------------------------------------------------------------------------- /config/pushf/product.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.flow.ProductSDEPushForward 2 | diffeq: sde -------------------------------------------------------------------------------- /config/pushf/sde.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.flow.SDEPushForward 2 | diffeq: sde 3 | -------------------------------------------------------------------------------- /config/scheduler/constant.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.linear_schedule 2 | init_value: 1.0 3 | end_value: 1.0 4 | transition_steps: ${warmup_steps} -------------------------------------------------------------------------------- /config/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.cosine_decay_schedule 2 | init_value: 1.0 3 | decay_steps: ${steps} 4 | alpha: 0.0 -------------------------------------------------------------------------------- /config/scheduler/r3cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.join_schedules 2 | schedules: 3 | - _target_: optax.linear_schedule 4 | init_value: 0.0 5 | end_value: 1.0 6 | transition_steps: ${warmup_steps} 7 | - _target_: optax.cosine_decay_schedule 8 | init_value: 1.0 9 | decay_steps: ${int:${eval:${eval:${steps}-${warmup_steps}}/3}} 10 | alpha: 0.1 11 | - _target_: optax.cosine_decay_schedule 12 | init_value: 0.1 13 | decay_steps: ${int:${eval:${eval:${steps}-${warmup_steps}}/3}} 14 | alpha: 0.1 15 | - _target_: optax.cosine_decay_schedule 16 | init_value: 0.01 17 | decay_steps: ${int:${eval:${eval:${steps}-${warmup_steps}}/3}} 18 | alpha: 0.1 19 | 20 | boundaries: 21 | - ${warmup_steps} 22 | - ${eval:${warmup_steps} + ${int:${eval:${eval:${steps}-${warmup_steps}}/3}}} 23 | - ${eval:${warmup_steps} + ${eval:2*${int:${eval:${eval:${steps}-${warmup_steps}}/3}}}} -------------------------------------------------------------------------------- /config/scheduler/ramp.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.linear_schedule 2 | init_value: 0.0 3 | 
end_value: 1.0 4 | transition_steps: ${warmup_steps} -------------------------------------------------------------------------------- /config/scheduler/rcosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.join_schedules 2 | schedules: 3 | - _target_: optax.linear_schedule 4 | init_value: 0.0 5 | end_value: 1.0 6 | transition_steps: ${warmup_steps} 7 | - _target_: optax.cosine_decay_schedule 8 | init_value: 1.0 9 | decay_steps: ${eval:${steps}-${warmup_steps}} 10 | alpha: 0.0 11 | 12 | boundaries: 13 | - ${warmup_steps} -------------------------------------------------------------------------------- /config/scheduler/rloglinear.yaml: -------------------------------------------------------------------------------- 1 | _target_: optax.join_schedules 2 | schedules: 3 | - _target_: optax.linear_schedule 4 | init_value: 0.0 5 | end_value: 1.0 6 | transition_steps: ${warmup_steps} 7 | - _target_: score_sde.utils.schedule.loglinear_schedule 8 | init_value: 1.0 9 | end_value: 1e-5 10 | decay_steps: ${eval:${steps}-${warmup_steps}} 11 | 12 | boundaries: 13 | - ${warmup_steps} -------------------------------------------------------------------------------- /config/server/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | num_workers: 8 4 | 5 | paths: 6 | experiments: results 7 | 8 | hydra: 9 | sweep: 10 | dir: ./${paths.experiments}/${name}/${hydra.job.override_dirname} 11 | subdir: ${seed} 12 | # subdir: ${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.num}_${seed} 13 | run: 14 | dir: ./${paths.experiments}/${name}/${hydra.job.override_dirname}/${seed} 15 | # dir: ./${paths.experiments}/${name}/${hydra.job.override_dirname}/${now:%Y-%m-%d_%H-%M-%S}_${seed} 16 | 17 | job_logging: 18 | formatters: 19 | simple: 20 | format: '[%(levelname)s] - %(message)s' 21 | handlers: 22 | file: 23 | filename: run.log 24 | root: 25 | handlers: [console, file] 26 | 27 | job: 28 | config: 29 | # configuration for the ${hydra.job.override_dirname} runtime variable 30 | override_dirname: 31 | exclude_keys: [name, experiment, server, seed, run, resume, num_workers, num_gpus, val_freq, logger, mode, n_jobs, test_val, test_test, test_plot] 32 | 33 | # git: 34 | # commit: ${git_commit:} 35 | # diff: ${bool:${git_diff:}} -------------------------------------------------------------------------------- /config/server/debug_ziz.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | 4 | defaults: 5 | - ziz 6 | - override /hydra/launcher: submitit_local -------------------------------------------------------------------------------- /config/server/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - base 5 | - override /hydra/launcher: joblib 6 | 7 | n_jobs: 2 8 | 9 | hydra: 10 | launcher: 11 | n_jobs: ${n_jobs} -------------------------------------------------------------------------------- /config/server/ziz.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base 4 | - override /hydra/launcher: submitit_slurm 5 | 6 | # paths: 7 | # experiments: /data/ziz/not-backed-up/scratch/${oc.env:USER}/${PROJECT_NAME}/results 8 | 9 | n_jobs: 8 10 | num_gpus: 1 11 | 12 | hydra: 13 | # job: 14 | # env_set: 15 | # XLA_FLAGS: --xla_gpu_cuda_data_dir=/opt/cuda11.1 16 | launcher: 17 |
submitit_folder: ${hydra.sweep.dir}/.submitit/%j 18 | timeout_min: 10000 19 | cpus_per_task: ${num_gpus} 20 | tasks_per_node: 1 21 | mem_gb: 10 22 | name: ${hydra.job.name} 23 | partition: high-bigbayes-gpu 24 | max_num_timeout: 0 25 | array_parallelism: ${n_jobs} 26 | setup: ["export XLA_FLAGS='--xla_gpu_cuda_data_dir=/opt/cuda'", "export PATH=/opt/cuda/bin/:$PATH", "export LD_LIBRARY_PATH=/opt/cuda/lib64:$LD_LIBRARY_PATH", "export XLA_PYTHON_CLIENT_PREALLOCATE=false", "export GEOMSTATS_BACKEND=jax", "wandb login 49e5475b053130664d1d6455d41081cae0baea0c"] 27 | # setup: ["export XLA_PYTHON_CLIENT_PREALLOCATE=false"] 28 | # executable: /data/ziz/not-backed-up/fishman/miniconda3/envs/score-sde/bin/python3 29 | additional_parameters: { 30 | "clusters": "srf_gpu_01", 31 | "wckey": "wck_${oc.env:USER}", 32 | "gres": "gpu:${num_gpus}", 33 | "exclude": "zizgpu05.cpu.stats.ox.ac.uk" 34 | # "nodelist": ["zizgpu02.cpu.stats.ox.ac.uk", "zizgpu03.cpu.stats.ox.ac.uk"] 35 | # "nodelist": "zizgpu04.cpu.stats.ox.ac.uk" 36 | } 37 | -------------------------------------------------------------------------------- /config/transform/exp.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.ExpMap -------------------------------------------------------------------------------- /config/transform/id.yaml: -------------------------------------------------------------------------------- 1 | _target_: score_sde.models.Id -------------------------------------------------------------------------------- /config/transform/stereo.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.InvStereographic -------------------------------------------------------------------------------- /config/transform/tanhexp.yaml: -------------------------------------------------------------------------------- 1 | _target_: riemannian_score_sde.models.TanhExpMap -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/.gitkeep -------------------------------------------------------------------------------- /data/L.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/L.npz -------------------------------------------------------------------------------- /data/b.walk.csv: -------------------------------------------------------------------------------- 1 | 7.586351694846405813e+00 2 | -2.234418872170351733e-03 3 | 3.794293056859288082e+00 4 | 3.794293056859288082e+00 5 | -3.794293056859288082e+00 6 | 3.790693136141552788e+00 7 | 3.790693136141552788e+00 8 | -3.790693136141552788e+00 9 | 3.792526958232480094e+00 10 | 3.792526958232480094e+00 11 | -3.792526958232480094e+00 12 | 3.821897839671521524e+00 13 | 3.821897839671521524e+00 14 | -3.821897839671521524e+00 15 | 3.790441868329621045e+00 16 | 3.790441868329621045e+00 17 | -3.790441868329621045e+00 18 | 3.834013022556145511e+00 19 | 3.834013022556145511e+00 20 | -3.834013022556145511e+00 21 | 3.799120025749744389e+00 22 | 3.799120025749744389e+00 23 | -3.799120025749744389e+00 24 | 3.822541087521799774e+00 25 | 3.822541087521799774e+00 26 | -3.822541087521799774e+00 27 | 
3.784572192794873668e+00 28 | 3.784572192794873668e+00 29 | -3.784572192794873668e+00 30 | 3.788124062823500005e+00 31 | 3.788124062823500005e+00 32 | -3.788124062823500005e+00 33 | 3.810476004299380826e+00 34 | 3.810476004299380826e+00 35 | -3.810476004299380826e+00 36 | 3.800629681976757546e+00 37 | 3.800629681976757546e+00 38 | -3.800629681976757546e+00 39 | 3.814715804485460549e+00 40 | 3.814715804485460549e+00 41 | -3.814715804485460549e+00 42 | 3.792012336975473463e+00 43 | 3.792012336975473463e+00 44 | -3.792012336975473463e+00 45 | 9.441478624189253210e+00 46 | -1.812047015218331669e+00 47 | -------------------------------------------------------------------------------- /data/big_hypercube_d=10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/big_hypercube_d=10.npz -------------------------------------------------------------------------------- /data/big_hypercube_d=3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/big_hypercube_d=3.npz -------------------------------------------------------------------------------- /data/birkhoff_d=4.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/birkhoff_d=4.npz -------------------------------------------------------------------------------- /data/dirichlet=15.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/dirichlet=15.npz -------------------------------------------------------------------------------- /data/dirichlet_d=10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/dirichlet_d=10.npz -------------------------------------------------------------------------------- /data/dirichlet_d=15.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/dirichlet_d=15.npz -------------------------------------------------------------------------------- /data/dirichlet_d=2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/dirichlet_d=2.npz -------------------------------------------------------------------------------- /data/dirichlet_d=3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/dirichlet_d=3.npz -------------------------------------------------------------------------------- /data/hypercube_d=10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/hypercube_d=10.npz -------------------------------------------------------------------------------- 
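A minimal sketch for inspecting any of the `.npz` archives listed here before wiring one into a `npz: ${npz}` config entry. Nothing is assumed beyond the files being standard `np.savez` archives; the key names (for instance the `r` and `tau` arrays read by the dataset classes later in this repo) are whatever `.files` reports for the archive at hand:

    import numpy as np

    # list every array stored in the archive, with its shape and dtype
    archive = np.load("data/hypercube_d=2.npz")
    for key in archive.files:
        print(key, archive[key].shape, archive[key].dtype)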
/data/hypercube_d=2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/hypercube_d=2.npz -------------------------------------------------------------------------------- /data/hypercube_d=3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/hypercube_d=3.npz -------------------------------------------------------------------------------- /data/multimodal_dirichlet_d=2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/multimodal_dirichlet_d=2.npz -------------------------------------------------------------------------------- /data/multimodal_hypercube_d=2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/multimodal_hypercube_d=2.npz -------------------------------------------------------------------------------- /data/smooth_loops_dist.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/smooth_loops_dist.npz -------------------------------------------------------------------------------- /data/unit_hypercube_d=2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/data/unit_hypercube_d=2.npz -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | 4 | # from score_sde.utils.cfg import * 5 | 6 | 7 | @hydra.main(config_path="config", config_name="main") 8 | def main(cfg): 9 | os.environ["GEOMSTATS_BACKEND"] = "jax" 10 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 11 | os.environ["WANDB_START_METHOD"] = "thread" 12 | # os.environ["JAX_ENABLE_X64"] = "True" 13 | 14 | from run import run 15 | 16 | return run(cfg) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /main_product.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | 4 | # from score_sde.utils.cfg import * 5 | 6 | 7 | @hydra.main(config_path="config", config_name="main") 8 | def main(cfg): 9 | os.environ["GEOMSTATS_BACKEND"] = "jax" 10 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 11 | os.environ["WANDB_START_METHOD"] = "thread" 12 | # os.environ["JAX_ENABLE_X64"] = "True" 13 | 14 | from run_product import run 15 | 16 | return run(cfg) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 2 | dm-haiku 3 | optax 4 | distrax 5 | diffrax 6 | fsspec[http]>=2021.05.0, !=2021.06.0 7 | autograd 8 | tqdm 9 | 
joblib 10 | cvxpy 11 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | ipykernel 3 | tensorflow 4 | tensorflow_datasets -------------------------------------------------------------------------------- /requirements_exps.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | seaborn 3 | wandb 4 | hydra_core 5 | hydra_colorlog 6 | submitit 7 | hydra-submitit-launcher 8 | hydra-joblib-launcher 9 | -------------------------------------------------------------------------------- /riemannian_score_sde/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .earth import * 2 | from .simple import * 3 | from .mixture import * 4 | -------------------------------------------------------------------------------- /riemannian_score_sde/datasets/antibody.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | 4 | 5 | class Polytope: 6 | def __init__(self, npz, scale=None, rng=None, batch_size=64): 7 | if rng is None: 8 | rng = jax.random.PRNGKey(0) 9 | self.rng = rng 10 | self.data = dict(np.load(npz)) 11 | self.data["data"] = self.data["r"][:, 1:-1] if "walk" in npz else self.data["r"] 12 | self.batch_size = batch_size 13 | 14 | def __len__(self): 15 | return self.data["data"].shape[0] 16 | 17 | def __getitem__(self, idx): 18 | return self.data["data"][idx], None # , self.data['seq'][idx] 19 | 20 | def __next__(self): 21 | self.rng, next_rng = jax.random.split(self.rng) 22 | 23 | idx = jax.random.choice(next_rng, len(self), shape=(self.batch_size,)) 24 | 25 | return self[idx] 26 | 27 | def get_all(self): 28 | return self.data["data"], None # , self.data['seq'] 29 | 30 | 31 | class PolytopeTorus: 32 | def __init__(self, npz, scale=None, rng=None, batch_size=64): 33 | if rng is None: 34 | rng = jax.random.PRNGKey(0) 35 | self.rng = rng 36 | self.data = dict(np.load(npz)) 37 | self.data["polytope"] = self.data["r"] 38 | self.data["torus"] = np.stack( 39 | [np.cos(self.data["tau"]), np.sin(self.data["tau"])], axis=-1 40 | ).reshape(-1, 2 * self.data["tau"].shape[1]) 41 | 42 | self.batch_size = batch_size 43 | 44 | def __len__(self): 45 | return self.data["polytope"].shape[0] 46 | 47 | def __getitem__(self, idx): 48 | return [ 49 | self.data["polytope"][idx], 50 | self.data["torus"][idx], 51 | ], None # , self.data['seq'][idx] 52 | 53 | def __next__(self): 54 | self.rng, next_rng = jax.random.split(self.rng) 55 | 56 | idx = jax.random.choice(next_rng, len(self), shape=(self.batch_size,)) 57 | 58 | return self[idx] 59 | 60 | def get_all(self): 61 | return [self.data["polytope"], self.data["torus"]], None # , self.data['seq'] 62 | 63 | 64 | class LoopTorus: 65 | def __init__(self, npz, scale=None, rng=None, batch_size=64): 66 | if rng is None: 67 | rng = jax.random.PRNGKey(0) 68 | self.rng = rng 69 | self.data = dict(np.load(npz)) 70 | self.data["polytope"] = self.data["r"] 71 | self.data["torus"] = np.stack( 72 | [np.cos(self.data["tau"]), np.sin(self.data["tau"])], axis=-1 73 | ).reshape(-1, 2 * self.data["tau"].shape[1]) 74 | 75 | self.batch_size = batch_size 76 | 77 | def __len__(self): 78 | return self.data["polytope"].shape[0] 79 | 80 | def __getitem__(self, idx): 81 | return self.data["torus"][idx], None 82 | 83 | def __next__(self): 84 | self.rng, next_rng = jax.random.split(self.rng) 85 |
86 | idx = jax.random.choice(next_rng, len(self), shape=(self.batch_size,)) 87 | 88 | return self[idx] 89 | 90 | def get_all(self): 91 | return self.data["torus"], None # , self.data['seq'] 92 | 93 | 94 | class Antibody: 95 | def __init__(self, scale=None, n_links=16, rng=None, batch_size=64): 96 | if rng is None: 97 | rng = jax.random.PRNGKey(0) 98 | self.rng = rng 99 | self.n_links = n_links 100 | self.data = dict( 101 | np.load( 102 | f"/data/ziz/not-backed-up/fishman/score-sde/data/walk.0.{n_links}.npz" 103 | ) 104 | ) 105 | self.data["data"] = np.hstack( 106 | [ 107 | self.data["r"][:, 1:-1], 108 | np.stack( 109 | [np.cos(self.data["tau"]), np.sin(self.data["tau"])], axis=-1 110 | ).reshape(-1, 2 * self.data["tau"].shape[1]), 111 | ] 112 | ) 113 | self.data["seq"] = self.data["seq"].astype(float) 114 | self.batch_size = batch_size 115 | 116 | def __len__(self): 117 | return self.data["data"].shape[0] 118 | 119 | def __getitem__(self, idx): 120 | return self.data["data"][idx], self.data["seq"][idx] 121 | 122 | def __next__(self): 123 | self.rng, next_rng = jax.random.split(self.rng) 124 | 125 | idx = jax.random.choice(next_rng, len(self), shape=(self.batch_size,)) 126 | 127 | return self[idx] 128 | 129 | def get_all(self): 130 | return self.data["data"], self.data["seq"] 131 | -------------------------------------------------------------------------------- /riemannian_score_sde/datasets/earth.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from score_sde.utils import register_dataset 4 | from score_sde.datasets import CSVDataset 5 | 6 | import geomstats as gs 7 | import jax.numpy as jnp 8 | 9 | 10 | class SphericalDataset(CSVDataset): 11 | def __init__(self, file, extrinsic=False, delimiter=",", skip_header=1, batch_size=256): 12 | super().__init__(file, delimiter=delimiter, skip_header=skip_header, batch_size=batch_size) 13 | 14 | self.manifold = gs.geometry.hypersphere.Hypersphere(2) 15 | self.intrinsic_data = ( 16 | jnp.pi * (self.data / 180.0) + jnp.array([jnp.pi / 2, jnp.pi])[None, :] 17 | ) 18 | self.data = self.manifold.spherical_to_extrinsic(self.intrinsic_data) 19 | 20 | 21 | 22 | class VolcanicErruption(SphericalDataset): 23 | def __init__(self, data_dir="data", **kwargs): 24 | super().__init__(os.path.join(data_dir, "volerup.csv"), skip_header=2) 25 | 26 | 27 | class Fire(SphericalDataset): 28 | def __init__(self, data_dir="data", **kwargs): 29 | super().__init__(os.path.join(data_dir, "fire.csv")) 30 | 31 | 32 | class Flood(SphericalDataset): 33 | def __init__(self, data_dir="data", **kwargs): 34 | super().__init__(os.path.join(data_dir, "flood.csv"), skip_header=2) 35 | 36 | 37 | class Earthquake(SphericalDataset): 38 | def __init__(self, data_dir="data", **kwargs): 39 | super().__init__(os.path.join(data_dir, "quakes_all.csv"), skip_header=4) 40 | 41 | class USForestFire(SphericalDataset): 42 | def __init__(self, data_dir="data", batch_size=256, **kwargs): 43 | super().__init__(os.path.join(data_dir, "us_forest_fires.csv"), batch_size=batch_size) -------------------------------------------------------------------------------- /riemannian_score_sde/datasets/mixture.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from jax.scipy.special import logsumexp 6 | import geomstats.backend as gs 7 | 8 | from riemannian_score_sde.models.distribution import ( 9 | WrapNormDistribution as 
WrappedNormal, 10 | ) 11 | 12 | 13 | class vMFMixture: 14 | def __init__( 15 | self, batch_dims, rng, manifold, mu, kappa, weights=[0.5, 0.5], **kwargs 16 | ): 17 | self.manifold = manifold 18 | self.mu = jnp.array(mu) 19 | self.kappa = jnp.expand_dims(jnp.array(kappa), -1) 20 | self.weights = jnp.array(weights) 21 | self.batch_dims = batch_dims 22 | self.rng = rng 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def __next__(self): 28 | rng = jax.random.split(self.rng, num=3) 29 | 30 | self.rng = rng[0] 31 | choice_key = rng[1] 32 | normal_key = rng[2] 33 | 34 | indices = jax.random.choice( 35 | choice_key, a=len(self.weights), shape=self.batch_dims, p=self.weights 36 | ) 37 | random_von_mises_fisher = jax.vmap( 38 | partial( 39 | self.manifold.random_von_mises_fisher, 40 | n_samples=np.prod(self.batch_dims), 41 | ) 42 | ) 43 | samples = random_von_mises_fisher(mu=self.mu[indices], kappa=self.kappa[indices]) 44 | diag = jnp.diag_indices(np.prod(self.batch_dims)) 45 | samples = samples[diag] 46 | return (samples, None) 47 | 48 | 49 | class WrapNormMixtureDistribution: 50 | def __init__( 51 | self, 52 | batch_dims, 53 | manifold, 54 | mean, 55 | scale, 56 | seed=0, 57 | rng=None, 58 | ): 59 | self.mean = jnp.array(mean) 60 | self.K = self.mean.shape[0] 61 | self.scale = jnp.array(scale) 62 | self.batch_dims = batch_dims 63 | self.manifold = manifold 64 | self.rng = rng if rng is not None else jax.random.PRNGKey(seed) 65 | 66 | def __iter__(self): 67 | return self 68 | 69 | def __next__(self): 70 | n_samples = np.prod(self.batch_dims) 71 | ks = jnp.arange(self.K) 72 | self.rng, next_rng = jax.random.split(self.rng) 73 | _, k = gs.random.choice(state=next_rng, a=ks, n=n_samples) 74 | mean = self.mean[k] 75 | scale = self.scale[k] 76 | tangent_vec = self.manifold.random_normal_tangent( 77 | next_rng, self.manifold.identity, n_samples 78 | )[1] 79 | tangent_vec *= scale 80 | tangent_vec = self.manifold.metric.transpfrom0(mean, tangent_vec) 81 | samples = self.manifold.metric.exp(tangent_vec, mean) 82 | return (samples, None) 83 | 84 | def log_prob(self, x): 85 | def component_log_prob(mean, scale): 86 | return WrappedNormal(self.manifold, scale, mean).log_prob(x) 87 | 88 | component_log_like = jax.vmap(component_log_prob)(self.mean, self.scale) 89 | b = 1 / self.K * jnp.ones_like(component_log_like) 90 | return logsumexp(component_log_like, axis=0, b=b) 91 | -------------------------------------------------------------------------------- /riemannian_score_sde/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import NoneEmbedding, LaplacianEigenfunctionEmbedding 2 | from .transform import * 3 | from .vector_field import * 4 | -------------------------------------------------------------------------------- /riemannian_score_sde/models/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.numpy as jnp 3 | 4 | from geomstats.geometry.euclidean import Euclidean 5 | from score_sde.sde import SDE 6 | from distrax import MultivariateNormalDiag 7 | 8 | 9 | class UniformDistribution: 10 | """Uniform density on compact manifold""" 11 | 12 | def __init__(self, manifold, **kwargs): 13 | self.manifold = manifold 14 | 15 | def sample(self, rng, shape): 16 | x = self.manifold.random_uniform(state=rng, n_samples=shape[0]) 17 | return x 18 | 19 | def log_prob(self, z): 20 | return -jnp.ones([z.shape[0]]) * self.manifold.log_volume 21 | 22 | def 
grad_U(self, x): 23 | return jnp.zeros_like(x) 24 | 25 | 26 | class MultivariateNormal(MultivariateNormalDiag): 27 | def __init__(self, dim, mean=None, scale=None, **kwargs): 28 | mean = jnp.zeros((dim)) if mean is None else mean 29 | scale = jnp.ones((dim)) if scale is None else scale 30 | super().__init__(mean, scale) 31 | 32 | def sample(self, rng, shape): 33 | return super().sample(seed=rng, sample_shape=shape) 34 | 35 | def log_prob(self, z): 36 | return super().log_prob(z) 37 | 38 | def grad_U(self, x): 39 | return x / (self.scale_diag**2) 40 | 41 | 42 | class DefaultDistribution: 43 | def __new__(cls, manifold, flow, **kwargs): 44 | if isinstance(flow, SDE): 45 | return flow.limiting 46 | else: 47 | if isinstance(manifold, Euclidean): 48 | zeros = jnp.zeros((manifold.dim)) 49 | ones = jnp.ones((manifold.dim)) 50 | return MultivariateNormalDiag(zeros, ones) 51 | elif hasattr(manifold, "random_uniform"): 52 | return UniformDistribution(manifold) 53 | else: 54 | # TODO: WrappedNormal 55 | raise NotImplementedError(f"No default distribution for {manifold}") 56 | 57 | 58 | class WrapNormDistribution: 59 | def __init__(self, manifold, scale=1.0, mean=None): 60 | self.manifold = manifold 61 | if mean is None: 62 | mean = self.manifold.identity 63 | self.mean = mean 64 | # NOTE: assuming diagonal scale 65 | self.scale = ( 66 | jnp.ones((self.manifold.dim * 2)) * scale 67 | if isinstance(scale, float) 68 | else jnp.array(scale) 69 | ) 70 | 71 | def sample(self, rng, shape): 72 | mean = self.mean[None, ...] 73 | tangent_vec = self.manifold.random_normal_tangent( 74 | rng, self.manifold.identity, np.prod(shape) 75 | )[1] 76 | # tangent_vec = self.manifold.random_normal_tangent(rng, mean, np.prod(shape))[1] 77 | tangent_vec *= self.scale 78 | tangent_vec = self.manifold.metric.transpfrom0(mean, tangent_vec) 79 | return self.manifold.metric.exp(tangent_vec, mean) 80 | 81 | def log_prob(self, z): 82 | tangent_vec = self.manifold.metric.log(z, self.mean) 83 | tangent_vec = self.manifold.metric.transpback0(self.mean, tangent_vec) 84 | zero = jnp.zeros((self.manifold.dim)) 85 | # TODO: to refactor axis contenation / removal 86 | if self.scale.shape[-1] == self.manifold.dim: # poincare 87 | scale = self.scale 88 | else: # hyperboloid 89 | scale = self.scale[..., 1:] 90 | norm_pdf = MultivariateNormalDiag(zero, scale).log_prob(tangent_vec) 91 | logdetexp = self.manifold.metric.logdetexp(self.mean, z) 92 | return norm_pdf - logdetexp 93 | 94 | def grad_U(self, x): 95 | def U(x): 96 | sq_dist = self.manifold.metric.dist(x, self.mean) ** 2 97 | res = 0.5 * sq_dist / (self.scale[0] ** 2) # scale must be isotropic 98 | logdetexp = self.manifold.metric.logdetexp(self.mean, x) 99 | return res + logdetexp 100 | 101 | # U = lambda x: -self.log_prob(x) #NOTE: this does not work 102 | 103 | return self.manifold.to_tangent(self.manifold.metric.grad(U)(x), x) 104 | -------------------------------------------------------------------------------- /riemannian_score_sde/models/embedding.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from dataclasses import dataclass 3 | 4 | import jax 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | 8 | from score_sde.models import MLP 9 | 10 | 11 | class Embedding(hk.Module, abc.ABC): 12 | def __init__(self, manifold): 13 | super().__init__() 14 | self.manifold = manifold 15 | 16 | 17 | class NoneEmbedding(Embedding): 18 | def __call__(self, x, t): 19 | return x, t 20 | 21 | 22 | class 
LaplacianEigenfunctionEmbedding(Embedding): 23 | def __init__(self, manifold, n_manifold, n_time, max_t): 24 | super().__init__(manifold) 25 | self.n_time = n_time 26 | self.frequencies = 2 * jnp.pi * jnp.arange(n_time) / max_t 27 | self.n_manifold = n_manifold 28 | 29 | def __call__(self, x, t): 30 | t = jnp.array(t) 31 | if len(t.shape) == 0: 32 | t = t * jnp.ones(x.shape[:-1]) 33 | 34 | if len(t.shape) == len(x.shape) - 1: 35 | t = jnp.expand_dims(t, axis=-1) 36 | 37 | x = self.manifold.laplacian_eigenfunctions(x, self.n_manifold) 38 | t = jnp.concatenate( 39 | (jnp.cos(self.frequencies * t), jnp.sin(self.frequencies * t)), axis=-1 40 | ) 41 | 42 | return x, t 43 | 44 | 45 | @dataclass 46 | class ConcatEigenfunctionEmbed(hk.Module): 47 | def __init__(self, output_shape, hidden_shapes, act): 48 | super().__init__() 49 | self._layer = MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 50 | 51 | def __call__(self, x, t): 52 | t = jnp.array(t) 53 | if len(t.shape) == 0: 54 | t = t * jnp.ones(x.shape[:-1]) 55 | 56 | if len(t.shape) == len(x.shape) - 1: 57 | t = jnp.expand_dims(t, axis=-1) 58 | 59 | return self._layer(jnp.concatenate([x, t], axis=-1)) 60 | -------------------------------------------------------------------------------- /riemannian_score_sde/models/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jax.numpy as jnp 4 | 5 | from geomstats.geometry.lie_group import MatrixLieGroup 6 | from geomstats.geometry.hypersphere import Hypersphere 7 | from geomstats.geometry.euclidean import Euclidean 8 | from geomstats import algebra_utils as utils 9 | 10 | from score_sde.models import Transform, ComposeTransform 11 | 12 | 13 | def get_likelihood_fn_w_transform(likelihood_fn, transform): 14 | def log_prob(x, context=None): 15 | y = transform.inv(x) 16 | logp, nfe = likelihood_fn(y, context=context) 17 | log_abs_det_jacobian = transform.log_abs_det_jacobian(y, x) 18 | logp -= log_abs_det_jacobian 19 | return logp, nfe 20 | 21 | return log_prob 22 | 23 | 24 | class ExpMap(Transform): 25 | def __init__(self, manifold, base_point=None, **kwargs): 26 | super().__init__(Euclidean(manifold.dim), manifold) 27 | self.manifold = manifold 28 | self.base_point = manifold.identity if base_point is None else base_point 29 | if (self.base_point == manifold.identity).all() and isinstance( 30 | manifold, MatrixLieGroup 31 | ): 32 | self.forward = lambda x: manifold.exp_from_identity(x) 33 | self.inverse = lambda y: manifold.log_from_identity(y) 34 | else: 35 | # self.manifold.metric.exp(x, base_point=self.base_point) 36 | self.forward = lambda x: manifold.exp(x, base_point=self.base_point) 37 | self.inverse = lambda y: manifold.log(y, base_point=self.base_point) 38 | 39 | def __call__(self, x): 40 | x = self.manifold.hat(x) 41 | return self.forward(x) 42 | 43 | def inv(self, y): 44 | x = self.inverse(y) 45 | return self.manifold.vee(x) 46 | 47 | def log_abs_det_jacobian(self, x, y): 48 | # TODO: factor 49 | if isinstance(self.manifold, MatrixLieGroup): 50 | return self.manifold.logdetexp(x, y) 51 | else: 52 | return self.manifold.logdetexp(self.base_point, y) 53 | 54 | 55 | class TanhExpMap(ComposeTransform): 56 | def __init__(self, manifold, base_point=None, radius=None, **kwargs): 57 | if radius is None: 58 | radius = manifold.injectivity_radius 59 | if jnp.isposinf(radius): 60 | parts = [] 61 | else: 62 | parts = [RadialTanhTransform(radius, manifold.dim)] 63 | exp_transform = ExpMap(manifold, base_point) 64 |
self.base_point = exp_transform.base_point 65 | parts.append(exp_transform) 66 | super().__init__(parts) 67 | 68 | 69 | class InvStereographic(Transform): 70 | def __init__(self, manifold, base_point=None, **kwargs): 71 | assert isinstance(manifold, Hypersphere) 72 | super().__init__(Euclidean(manifold.dim), manifold) 73 | self.manifold = manifold 74 | assert base_point is None or base_point == manifold.identity 75 | self.base_point = manifold.identity 76 | 77 | def __call__(self, x): 78 | return self.manifold.inv_stereographic_projection(x) 79 | 80 | def inv(self, y): 81 | return self.manifold.stereographic_projection(y) 82 | 83 | def log_abs_det_jacobian(self, x, y): 84 | return self.manifold.inv_stereographic_projection_logdet(x) 85 | 86 | 87 | class RadialTanhTransform(Transform): 88 | r""" 89 | from: https://github.com/pimdh/relie/blob/master/relie/flow/radial_tanh_transform.py 90 | Transform R^d of radius (0, inf) to (0, R) 91 | Uses the fact that tanh is linear near 0. 92 | """ 93 | 94 | def __init__(self, radius, dim): 95 | super().__init__(Euclidean(dim), Euclidean(dim)) 96 | self.radius = radius 97 | 98 | def __call__(self, x): 99 | """x -> tanh(||x||) x / ||x|| * R""" 100 | # x_norm = jnp.linalg.norm(x, axis=-1, keepdims=True) 101 | # mask = x_norm > 1e-8 102 | # x_norm = jnp.where(mask, x_norm, jnp.ones_like(x_norm)) 103 | # return jnp.where( 104 | # mask, jnp.tanh(x_norm) * x / x_norm * self.radius, x * self.radius 105 | # ) 106 | x_sq_norm = jnp.sum(jnp.square(x), axis=-1, keepdims=True) 107 | tanh_ratio = utils.taylor_exp_even_func(x_sq_norm, utils.tanh_close_0, order=5) 108 | return tanh_ratio * x * self.radius 109 | 110 | def inv(self, y): 111 | """ 112 | y -> arctanh(||y|| / R) y / ||y|| 113 | y -> arctanh(||y|| / R) y / (||y|| / R) / R 114 | """ 115 | # y_norm = jnp.linalg.norm(y, axis=-1, keepdims=True) 116 | # mask = y_norm > 1e-8 117 | # y_norm = jnp.where(mask, y_norm, jnp.ones_like(y_norm)) 118 | # return jnp.where( 119 | # mask, jnp.arctanh(y_norm / self.radius) * y / y_norm, y / self.radius 120 | # ) 121 | 122 | y_sq_norm = jnp.sum(jnp.square(y), axis=-1, keepdims=True) 123 | y_sq_norm = y_sq_norm / (self.radius**2) 124 | # y_sq_norm = jnp.clip(y_sq_norm, a_max=1) 125 | y_sq_norm = jnp.clip(y_sq_norm, a_max=1 - 1e-7) 126 | arctanh = utils.taylor_exp_even_func( 127 | y_sq_norm, utils.arctanh_card_close_0, order=5 128 | ) 129 | return arctanh * y / self.radius 130 | 131 | def log_abs_det_jacobian(self, x, y): 132 | """ 133 | computation similar to exp map in https://arxiv.org/abs/1902.02992 134 | x -> dim * log R + (dim - 1) * log(tanh(r)/r) + log1p(- tanh(r^2)) 135 | :param x: Tensor 136 | :param y: Tensor 137 | :return: Tensor 138 | """ 139 | x_sq_norm = jnp.sum(jnp.square(x), axis=-1) 140 | x_norm = jnp.sqrt(x_sq_norm) 141 | dim = x.shape[-1] 142 | # tanh = jnp.tanh(x_norm) 143 | # term1 = -jnp.log(x_norm / tanh) 144 | # term1 = -jnp.log( 145 | # utils.taylor_exp_even_func(x_sq_norm, utils.inv_tanh_close_0, order=5) 146 | # ) 147 | # term2 = jnp.log1p(-tanh ** 2) 148 | # return jnp.where(x_norm > 1e-8, out, log_radius) 149 | term1 = utils.taylor_exp_even_func(x_sq_norm, utils.log_tanh_close_0, order=4) 150 | term2 = utils.taylor_exp_even_func( 151 | x_sq_norm, utils.log1p_m_tanh_sq_close_0, order=5 152 | ) 153 | 154 | log_radius = math.log(self.radius) * jnp.ones_like(x_norm) 155 | out = dim * log_radius + (dim - 1) * term1 + term2 156 | return out 157 | -------------------------------------------------------------------------------- 
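A hedged usage sketch for the transforms above (not code from this repo): it assumes the project's patched geomstats fork with the jax backend selected, since `ExpMap`/`TanhExpMap` rely on `identity`, `hat`/`vee` and `injectivity_radius` attributes that vanilla geomstats does not expose on `Hypersphere`. Inside the injectivity radius, the forward map and `inv` should round-trip:

    import os
    os.environ["GEOMSTATS_BACKEND"] = "jax"  # must be set before importing geomstats

    import jax.numpy as jnp
    from geomstats.geometry.hypersphere import Hypersphere
    from riemannian_score_sde.models.transform import TanhExpMap

    manifold = Hypersphere(dim=2)
    transform = TanhExpMap(manifold)      # tanh-rescaled exp map: R^2 -> S^2

    x = jnp.array([[0.3, -0.1]])          # tangent coordinates at the base point
    y = transform(x)                      # point on the sphere, in ambient R^3
    x_rec = transform.inv(y)              # should recover x up to numerical error
    print(jnp.allclose(x, x_rec, atol=1e-4))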
/riemannian_score_sde/sampling.py: -------------------------------------------------------------------------------- 1 | """Various sampling methods.""" 2 | from typing import Tuple 3 | from functools import partial 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from score_sde.utils import batch_mul 8 | from score_sde.sampling import ( 9 | get_pc_sampler, 10 | get_discrete_pc_sampler, 11 | Predictor, 12 | Corrector, 13 | register_predictor, 14 | register_corrector, 15 | ) 16 | 17 | 18 | @partial(register_predictor, name="GRW") 19 | class EulerMaruyamaManifoldPredictor(Predictor): 20 | def __init__(self, sde): 21 | super().__init__(sde) 22 | 23 | def update_fn( 24 | self, 25 | rng: jax.random.KeyArray, 26 | x: jnp.ndarray, 27 | t: float, 28 | dt: float, 29 | temperature: bool = False, 30 | score: jnp.ndarray = None, 31 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 32 | drift, diffusion = self.sde.coefficients(x, t, score=score) 33 | z = self.sde.manifold.random_normal_tangent( 34 | state=rng, base_point=x, n_samples=x.shape[0] 35 | )[1] 36 | drift = drift * dt[..., None] 37 | if len(diffusion.shape) > 1 and diffusion.shape[-1] == diffusion.shape[-2]: 38 | # if square matrix diffusion coeffs 39 | tangent_vector = drift + jnp.einsum( 40 | "...ij,...j,...->...i", diffusion, z, jnp.sqrt(jnp.abs(dt)) 41 | ) 42 | elif len(diffusion.shape) > 1 and diffusion.shape[-1] == z.shape[-1]: 43 | tangent_vector = drift + jnp.einsum( 44 | "...i,...i,...->...i", diffusion, z, jnp.sqrt(jnp.abs(dt)) 45 | ) 46 | else: 47 | # if scalar diffusion coeffs (i.e. no extra dims on the diffusion) 48 | tangent_vector = drift + jnp.einsum( 49 | "...,...i,...->...i", diffusion, z, jnp.sqrt(jnp.abs(dt)) 50 | ) 51 | 52 | x = self.sde.manifold.exp(tangent_vec=tangent_vector, base_point=x) 53 | return x, x 54 | 55 | 56 | @register_corrector 57 | class LangevinCorrector(Corrector): 58 | r""" 59 | dX = c \nabla \log p dt + (2c)^{1/2} dBt 60 | c = 1/2 61 | """ 62 | 63 | def __init__(self, sde, snr, n_steps): 64 | raise NotImplementedError("This corrector has not been properly tested") 65 | super().__init__(sde, snr, n_steps) 66 | 67 | def update_fn( 68 | self, 69 | rng: jax.random.KeyArray, 70 | x: jnp.ndarray, 71 | t: float, 72 | dt: float, 73 | temperature: bool = False, 74 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 75 | sde = self.sde 76 | n_steps = self.n_steps 77 | target_snr = self.snr 78 | """ timestep = (t * (sde.N - 1) / sde.T).astype(jnp.int32) 79 | alpha = sde.alphas[timestep] """ 80 | alpha = jnp.ones_like(t) 81 | 82 | def loop_body(step, val): 83 | rng, x = val 84 | grad = sde.score_fn(x, t) 85 | rng, step_rng = jax.random.split(rng) 86 | 87 | noise = self.sde.manifold.random_normal_tangent( 88 | state=step_rng, base_point=x, n_samples=x.shape[0] 89 | )[1] 90 | 91 | grad_norm = self.sde.manifold.metric.norm(grad, x).mean() 92 | noise_norm = self.sde.manifold.metric.norm(noise, x).mean() 93 | step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha 94 | step_size = jnp.expand_dims(step_size, -1) 95 | 96 | tangent_vector = batch_mul((step_size / 2), grad) 97 | tangent_vector += batch_mul(jnp.sqrt(step_size), noise) 98 | 99 | x = self.sde.manifold.exp(tangent_vec=tangent_vector, base_point=x) 100 | return rng, x 101 | 102 | _, x = jax.lax.fori_loop(0, n_steps, loop_body, (rng, x)) 103 | return x, x # x_mean 104 | -------------------------------------------------------------------------------- /riemannian_score_sde/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/riemannian_score_sde/utils/__init__.py -------------------------------------------------------------------------------- /riemannian_score_sde/utils/normalization.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | import numpy as np 4 | 5 | from geomstats.geometry.euclidean import Euclidean 6 | from geomstats.geometry.hypersphere import Hypersphere 7 | from geomstats.geometry.hyperbolic import Hyperbolic, PoincareBall, Hyperboloid 8 | from geomstats.geometry.special_orthogonal import ( 9 | _SpecialOrthogonalMatrices, 10 | _SpecialOrthogonal3Vectors, 11 | ) 12 | from riemannian_score_sde.utils.vis import make_disk_grid 13 | 14 | # def compute_microbatch_split(x, K=1): 15 | # """ Checks if batch needs to be broken down further to fit in memory. """ 16 | # B = x.shape[0] 17 | # S = int(2e5 / (K * np.prod(x.shape[1:]))) # float heuristic for 12Gb cuda memory 18 | # return min(B, S) 19 | 20 | 21 | # def compute_across_microbatch(func, x): 22 | # S = compute_microbatch_split(x) 23 | # split = jnp.split(x, S) 24 | # lw = jnp.concatenate([func(_x) for _x in split], axis=0) # concat on batch 25 | # return lw 26 | 27 | 28 | def get_spherical_grid(N, eps=0.0): 29 | theta = jnp.linspace(eps, jnp.pi - eps, N // 2) 30 | phi = jnp.linspace(eps, 2 * jnp.pi - eps, N) 31 | 32 | theta, phi = jnp.meshgrid(theta, phi) 33 | theta = theta.reshape(-1, 1) 34 | phi = phi.reshape(-1, 1) 35 | xs = jnp.concatenate( 36 | [jnp.sin(theta) * jnp.cos(phi), jnp.sin(theta) * jnp.sin(phi), jnp.cos(theta)], 37 | axis=-1, 38 | ) 39 | volume = (2 * np.pi) * np.pi 40 | lambda_x = jnp.sin(theta).reshape((-1)) 41 | return xs, volume, lambda_x 42 | 43 | 44 | def get_so3_grid(N, eps=0.0): 45 | angle1 = jnp.linspace(-jnp.pi + eps, jnp.pi - eps, N) 46 | angle2 = jnp.linspace(-jnp.pi / 2 + eps, jnp.pi / 2 - eps, N // 2) 47 | angle3 = jnp.linspace(-jnp.pi + eps, jnp.pi - eps, N) 48 | 49 | angle1, angle2, angle3 = jnp.meshgrid(angle1, angle2, angle3) 50 | xs = jnp.concatenate( 51 | [ 52 | angle1.reshape(-1, 1), 53 | angle2.reshape(-1, 1), 54 | angle3.reshape(-1, 1), 55 | ], 56 | axis=-1, 57 | ) 58 | xs = jax.vmap(_SpecialOrthogonal3Vectors().matrix_from_tait_bryan_angles)(xs) 59 | 60 | # remove points too close from the antipole 61 | vs = jax.vmap(_SpecialOrthogonal3Vectors().rotation_vector_from_matrix)(xs) 62 | norm_v = jnp.linalg.norm(vs, axis=-1, keepdims=True) 63 | max_norm = jnp.pi - eps 64 | cond = jnp.expand_dims(norm_v <= max_norm, -1) 65 | rescaled_vs = vs * max_norm / norm_v 66 | rescaled_xs = jax.vmap(_SpecialOrthogonal3Vectors().matrix_from_rotation_vector)( 67 | rescaled_vs 68 | ) 69 | xs = jnp.where(cond, xs, rescaled_xs) 70 | 71 | volume = (2 * np.pi) * (2 * np.pi) * np.pi 72 | lambda_x = (jnp.sin(angle2 + np.pi / 2)).reshape((-1)) 73 | return xs, volume, lambda_x 74 | 75 | 76 | def get_euclidean_grid(N, dim): 77 | dim = int(dim) 78 | bound = 10 79 | x = jnp.linspace(-bound, bound, N) 80 | xs = dim * [x] 81 | 82 | xs = jnp.meshgrid(*xs) 83 | xs = jnp.concatenate([x.reshape(-1, 1) for x in xs], axis=-1) 84 | volume = (2 * bound) ** dim 85 | lambda_x = (jnp.ones((xs.shape[0], 1))).reshape(-1) 86 | return xs, volume, lambda_x 87 | 88 | 89 | def make_disk_grid(N, eps=1e-2, dim=2, radius=1.0): 90 | h = Hyperbolic(dim=dim, default_coords_type="ball") 91 | x = jnp.linspace(-radius, radius, N) 92 | xs = dim * [x] 93 | xs = jnp.meshgrid(*xs) 94 | 
xs = jnp.concatenate([x.reshape(-1, 1) for x in xs], axis=-1) 95 | mask = jnp.linalg.norm(xs, axis=-1) < 1.0 - eps 96 | idx = jnp.nonzero(mask)[0] 97 | xs = xs[idx] 98 | lambda_x = h.metric.lambda_x(xs) ** 2 99 | # lambda_x = h.metric.lambda_x(xs) ** 2 * mask 100 | volume = (2 * radius) ** dim 101 | 102 | return xs, volume, lambda_x 103 | 104 | 105 | def make_hyp_grid(N, eps=1e-2, dim=2, radius=1.0): 106 | xs, volume, lambda_x = make_disk_grid(N, eps=eps, dim=dim, radius=radius) 107 | ball_to_extr = Hyperbolic._ball_to_extrinsic_coordinates 108 | return ball_to_extr(xs), volume, lambda_x 109 | 110 | 111 | def compute_normalization( 112 | likelihood_fn, manifold, context=None, N=None, eps=0.0, return_all=False 113 | ): 114 | if isinstance(manifold, Euclidean): 115 | N = N if N is not None else int(jnp.power(1e5, 1 / manifold.dim)) 116 | xs, volume, lambda_x = get_euclidean_grid(N, manifold.dim) 117 | elif isinstance(manifold, Hypersphere) and manifold.dim == 2: 118 | N = N if N is not None else 200 119 | xs, volume, lambda_x = get_spherical_grid(N, eps) 120 | elif isinstance(manifold, _SpecialOrthogonalMatrices) and manifold.dim == 3: 121 | N = N if N is not None else 50 122 | xs, volume, lambda_x = get_so3_grid(N, eps=1e-3) 123 | # elif isinstance(manifold, PoincareBall): 124 | # N = N if N is not None else 100 125 | # xs, volume, lambda_x = make_disk_grid(N, dim=manifold.dim) 126 | # elif isinstance(manifold, Hyperboloid): 127 | # N = N if N is not None else 100 128 | # xs, volume, lambda_x = make_hyp_grid(N, dim=manifold.dim) 129 | else: 130 | print("Only integration over R^d, S^2 and SO(3) is implemented.") 131 | return 0.0 132 | context = ( 133 | context 134 | if context is None 135 | else jnp.repeat(jnp.expand_dims(context, 0), xs.shape[0], 0) 136 | ) 137 | logp = likelihood_fn(xs, context) 138 | if isinstance(logp, tuple): 139 | logp, nfe = logp 140 | prob = jnp.exp(logp) 141 | Z = (prob * lambda_x).mean() * volume 142 | if return_all: 143 | return Z.item(), prob, lambda_x * volume, N 144 | else: 145 | return Z.item() 146 | -------------------------------------------------------------------------------- /score_sde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oxcsml/constrained-diffusion/fbeee12df178ed2e08dd0596e791f0d32a3a8c6a/score_sde/__init__.py -------------------------------------------------------------------------------- /score_sde/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from score_sde.utils import register_dataset 2 | 3 | from .mixture import * 4 | from .tensordataset import * 5 | from .split import * 6 | -------------------------------------------------------------------------------- /score_sde/datasets/mixture.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | import numpy as np 5 | 6 | 7 | class GaussianMixture: 8 | def __init__( 9 | self, 10 | batch_dims, 11 | rng, 12 | means=[-1.0, 1.0], 13 | stds=[1.0, 1.0], 14 | weights=[0.5, 0.5], 15 | **kwargs 16 | ): 17 | self.means = jnp.array(means) 18 | self.stds = jnp.array(stds) 19 | self.weights = jnp.array(weights) 20 | self.batch_dims = batch_dims 21 | self.rng = rng 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | rng = jax.random.split(self.rng, num=3) 28 | 29 | self.rng = rng[0] 30 | choice_key = rng[1] 31 | normal_key =
rng[2] 32 | 33 | indices = jax.random.choice( 34 | choice_key, a=len(self.weights), shape=self.batch_dims, p=self.weights 35 | ) 36 | samples = self.means[indices] + self.stds[indices] * jax.random.normal( 37 | normal_key, shape=self.batch_dims + self.means.shape[1:] 38 | ) 39 | 40 | return (samples, None) 41 | -------------------------------------------------------------------------------- /score_sde/datasets/split.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import math 5 | from score_sde.datasets import SubDataset 6 | 7 | 8 | def validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None): 9 | """ 10 | Validation helper to check if the train/test sizes are meaningful w.r.t. the 11 | size of the data (n_samples) 12 | """ 13 | if test_size is None and train_size is None: 14 | test_size = default_test_size 15 | 16 | test_size_type = np.asarray(test_size).dtype.kind 17 | train_size_type = np.asarray(train_size).dtype.kind 18 | 19 | if ( 20 | test_size_type == "i" 21 | and (test_size >= n_samples or test_size <= 0) 22 | or test_size_type == "f" 23 | and (test_size <= 0 or test_size >= 1) 24 | ): 25 | raise ValueError( 26 | "test_size={0} should be either positive and smaller" 27 | " than the number of samples {1} or a float in the " 28 | "(0, 1) range".format(test_size, n_samples) 29 | ) 30 | 31 | if ( 32 | train_size_type == "i" 33 | and (train_size >= n_samples or train_size <= 0) 34 | or train_size_type == "f" 35 | and (train_size <= 0 or train_size >= 1) 36 | ): 37 | raise ValueError( 38 | "train_size={0} should be either positive and smaller" 39 | " than the number of samples {1} or a float in the " 40 | "(0, 1) range".format(train_size, n_samples) 41 | ) 42 | 43 | if train_size is not None and train_size_type not in ("i", "f"): 44 | raise ValueError("Invalid value for train_size: {}".format(train_size)) 45 | if test_size is not None and test_size_type not in ("i", "f"): 46 | raise ValueError("Invalid value for test_size: {}".format(test_size)) 47 | 48 | if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1: 49 | raise ValueError( 50 | "The sum of test_size and train_size = {}, should be in the (0, 1)" 51 | " range. Reduce test_size and/or train_size.".format(train_size + test_size) 52 | ) 53 | 54 | if test_size_type == "f": 55 | n_test = math.ceil(test_size * n_samples) 56 | elif test_size_type == "i": 57 | n_test = float(test_size) 58 | 59 | if train_size_type == "f": 60 | n_train = math.floor(train_size * n_samples) 61 | elif train_size_type == "i": 62 | n_train = float(train_size) 63 | 64 | if train_size is None: 65 | n_train = n_samples - n_test 66 | elif test_size is None: 67 | n_test = n_samples - n_train 68 | 69 | if n_train + n_test > n_samples: 70 | raise ValueError( 71 | "The sum of train_size and test_size = %d, " 72 | "should be smaller than the number of " 73 | "samples %d. Reduce test_size and/or " 74 | "train_size." % (n_train + n_test, n_samples) 75 | ) 76 | 77 | n_train, n_test = int(n_train), int(n_test) 78 | 79 | if n_train == 0: 80 | raise ValueError( 81 | "With n_samples={}, test_size={} and train_size={}, the " 82 | "resulting train set will be empty. 
Adjust any of the " 83 | "aforementioned parameters.".format(n_samples, test_size, train_size) 84 | ) 85 | 86 | return n_train, n_test 87 | 88 | 89 | def random_split(dataset, lengths, rng): 90 | if lengths is None: 91 | return dataset, dataset, dataset 92 | elif sum(lengths) == len(dataset): 93 | pass 94 | elif math.isclose(sum(lengths), 1): 95 | lengths = [int(l * len(dataset)) for l in lengths] 96 | lengths[-1] = len(dataset) - int(sum(lengths[:-1])) 97 | else: 98 | raise ValueError( 99 | "Sum of input lengths does not equal the length of the input dataset" 100 | ) 101 | 102 | indices = jax.random.permutation(rng, len(dataset)) 103 | return [ 104 | SubDataset(dataset, indices[sum(lengths[:i]) : sum(lengths[: i + 1])]) 105 | for i in range(len(lengths)) 106 | ] 107 | -------------------------------------------------------------------------------- /score_sde/datasets/tensordataset.py: -------------------------------------------------------------------------------- 1 | from math import prod, floor 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | 8 | # class TensorDataset: 9 | # def __init__(self, data): 10 | # self.data = jnp.array(data) 11 | # 12 | # def __len__(self): 13 | # return self.data.shape[0] 14 | # 15 | # def __getitem__(self, idx): 16 | # return self.data[idx] 17 | 18 | class TensorDataset: 19 | def __init__(self, data, rng=None, batch_size=256): 20 | self.rng = rng if rng is not None else jax.random.PRNGKey(0) 21 | self.batch_size = batch_size 22 | self.data = jnp.array(data) 23 | 24 | def __len__(self): 25 | return self.data.shape[0] 26 | 27 | def __getitem__(self, idx): 28 | return self.data[idx], None # , self.data['seq'][idx] 29 | 30 | def __next__(self): 31 | self.rng, next_rng = jax.random.split(self.rng) 32 | idx = jax.random.choice(next_rng, len(self), shape=(self.batch_size,)) 33 | return self[idx] 34 | 35 | def get_all(self): 36 | return self.data, None 37 | 38 | 39 | 40 | 41 | class DataLoader: 42 | def __init__(self, dataset, batch_dims, rng, shuffle=False, drop_last=False): 43 | self.dataset = dataset 44 | assert isinstance(batch_dims, int) 45 | self.batch_dims = batch_dims 46 | self.rng = rng 47 | 48 | self.shuffle = shuffle 49 | self.drop_last = drop_last 50 | 51 | def __getitem__(self, idx): 52 | return self.dataset[idx] 53 | 54 | def __len__(self): 55 | bs = self.batch_dims 56 | N = floor(len(self.dataset) / bs) 57 | return N if self.drop_last or len(self.dataset) % bs == 0 else N + 1 58 | 59 | def __iter__(self): 60 | return DatasetIterator(self) 61 | 62 | def __next__(self): 63 | rng, next_rng = jax.random.split(self.rng) 64 | self.rng = rng 65 | 66 | indices = jax.random.choice(next_rng, len(self.dataset), shape=(self.batch_dims,)) 67 | 68 | return self.dataset[indices], None 69 | # return self.data[indices].reshape((self.batch_dims, *self.dataset.shape[1:])) 70 | 71 | def get_all(self): 72 | return self.dataset, None 73 | 74 | class DatasetIterator: 75 | def __init__(self, dataloader: DataLoader): 76 | self.dataloader = dataloader 77 | rng, self.dataloader.rng = jax.random.split(self.dataloader.rng) 78 | if self.dataloader.shuffle: 79 | self.indices = jax.random.permutation(rng, len(self.dataloader.dataset)) 80 | else: 81 | self.indices = jnp.arange(len(self.dataloader.dataset)) 82 | self.bs = self.dataloader.batch_dims 83 | self.N = floor(len(dataloader.dataset) / self.bs) 84 | self.n = 0 85 | 86 | def __next__(self): 87 | if self.n < self.N: 88 | batch = self.dataloader.dataset[ 89 | self.indices[self.bs * self.n : self.bs * (self.n + 1)], ... 
90 | ] 91 | self.n = self.n + 1 92 | # batch = batch.reshape( 93 | # (self.dataset.batch_dims, *self.dataset.data.shape[1:]) 94 | # ) 95 | elif (self.n == self.N) and not self.dataloader.drop_last and len(self.dataloader.dataset) % self.bs != 0: 96 | batch = self.dataloader.dataset[self.indices[self.bs * self.n :]] 97 | self.n = self.n + 1 98 | # TODO: This only works for 1D batch dims rn 99 | # batch = batch.reshape((-1, *self.dataset.data.shape[1:])) 100 | else: 101 | raise StopIteration 102 | 103 | return batch, None 104 | 105 | 106 | # TODO: assumes 1d batch_dims 107 | class SubDataset: 108 | def __init__(self, dataset, indices): 109 | self.dataset = dataset 110 | self.indices = indices 111 | 112 | def __getitem__(self, idx): 113 | if isinstance(idx, list): 114 | return self.dataset[[self.indices[i] for i in idx]] 115 | return self.dataset[self.indices[idx]] 116 | 117 | def __len__(self): 118 | return self.indices.shape[0] 119 | 120 | 121 | class CSVDataset(TensorDataset): 122 | def __init__(self, file, delimiter=",", skip_header=1, **kwargs): 123 | data = np.genfromtxt(file, delimiter=delimiter, skip_header=skip_header) 124 | super().__init__(data, **kwargs) 125 | -------------------------------------------------------------------------------- /score_sde/likelihood.py: -------------------------------------------------------------------------------- 1 | """Modified code from https://github.com/yang-song/score_sde""" 2 | # coding=utf-8 3 | # Copyright 2020 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # pylint: skip-file 18 | # pytype: skip-file 19 | """Likelihood computation via the probability flow ODE.""" 20 | import jax 21 | import numpy as np 22 | import jax.numpy as jnp 23 | 24 | from score_sde.sde import SDE, ProbabilityFlowODE 25 | from score_sde.utils import ParametrisedScoreFunction 26 | from score_sde.ode import odeint 27 | from score_sde.models import get_div_fn, div_noise 28 | 29 | def get_likelihood_fn( 30 | sde: SDE, 31 | score_fn: ParametrisedScoreFunction, 32 | inverse_scaler=lambda x: x, 33 | hutchinson_type: str = "Rademacher", 34 | rtol: float = 1e-5, 35 | atol: float = 1e-5, 36 | method: str = "RK45", 37 | eps: float = 1e-5, 38 | bits_per_dimension=True, 39 | ): 40 | def likelihood_fn(rng: jax.random.KeyArray, data: jnp.ndarray, tf: float = None): 41 | """Compute an unbiased estimate to the log-likelihood in bits/dim. 42 | 43 | Args: 44 | rng: A JAX random key. 45 | data: A JAX array of shape [batch size, ...]. 46 | tf: Optional final time for the probability flow ODE; defaults to `sde.tf`. 47 | 48 | Returns: 49 | bpd: A JAX array of shape [batch size]. The log-likelihoods on `data` in bits/dim, 50 | or the raw log-likelihoods when `bits_per_dimension` is False. 51 | z: A JAX array of the same shape as `data`. The latent representation of `data` under the probability flow ODE. 52 | nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
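        Example (a sketch; assumes an `sde`, a wrapped score function
        `score_fn` and a data batch `x` are already built):

            likelihood_fn = get_likelihood_fn(sde, score_fn, bits_per_dimension=False)
            logp, z, nfe = likelihood_fn(jax.random.PRNGKey(0), x)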
53 | """ 54 | pode = ProbabilityFlowODE(sde, score_fn) 55 | drift_fn = lambda x, t: pode.coefficients(x, t)[0] 56 | div_fn = get_div_fn(drift_fn, hutchinson_type) 57 | drift_fn, div_fn = jax.jit(drift_fn), jax.jit(div_fn) 58 | 59 | rng, step_rng = jax.random.split(rng) 60 | shape = data.shape 61 | epsilon = div_noise(step_rng, shape, hutchinson_type) 62 | tf = sde.tf if tf is None else tf 63 | 64 | def ode_func(x: jnp.ndarray, t: jnp.ndarray) -> np.array: 65 | sample = x[:, :shape[1]] 66 | vec_t = jnp.ones((sample.shape[0],)) * t 67 | drift = drift_fn(sample, vec_t) 68 | logp_grad = div_fn(sample, vec_t, epsilon).reshape([*shape[:-1], 1]) 69 | return jnp.concatenate([drift, logp_grad], axis=1) 70 | 71 | init = jnp.concatenate([data, np.zeros((shape[0], 1))], axis=1) 72 | ts = jnp.array([eps, tf]) 73 | y, nfe = odeint(ode_func, init, ts, rtol=rtol, atol=atol) 74 | 75 | z = y[-1, ..., :-1] 76 | delta_logp = y[-1, ..., -1] 77 | 78 | prior_logp = sde.limiting_distribution_logp(z) 79 | posterior_logp = prior_logp + delta_logp 80 | bpd = -posterior_logp / np.log(2) 81 | N = np.prod(shape[2:]) 82 | bpd = bpd / N 83 | # A hack to convert log-likelihoods to bits/dim 84 | # based on the gradient of the inverse data normalizer. 85 | offset = jnp.log2(jax.grad(inverse_scaler)(0.0)) + 8.0 86 | bpd += offset 87 | return bpd if bits_per_dimension else posterior_logp, z, nfe 88 | 89 | return likelihood_fn -------------------------------------------------------------------------------- /score_sde/models/__init__.py: -------------------------------------------------------------------------------- 1 | from score_sde.utils import register_category 2 | 3 | from .mlp import MLP 4 | from .model import * 5 | from .architecture import * 6 | from .transform import * 7 | 8 | from .flow import * 9 | -------------------------------------------------------------------------------- /score_sde/models/architecture.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import math 3 | 4 | import jax 5 | import haiku as hk 6 | import numpy as np 7 | import jax.numpy as jnp 8 | 9 | from .mlp import MLP, MHA 10 | 11 | 12 | @dataclass 13 | class Attention(hk.Module): 14 | def __init__(self, output_shape, hidden_shapes, act, heads=512, n_layers=3): 15 | super().__init__() 16 | self.heads = heads 17 | self.output_shape = output_shape 18 | self.hidden_shapes = hidden_shapes 19 | self.act = act 20 | self.n_layers = n_layers 21 | 22 | def __call__(self, X, context): 23 | if len(context) == 2: 24 | T, Z = context 25 | else: 26 | T = context 27 | 28 | if len(T.shape) == 0: 29 | T = T * jnp.ones(X.shape[:-1]) 30 | 31 | if len(T.shape) == len(X.shape) - 1: 32 | T = jnp.expand_dims(T, axis=-1) 33 | 34 | N, S = X.shape[0], X.shape[1] // 3 35 | print(X.shape) 36 | x = x0 = jnp.concatenate([ 37 | X[:, :S].reshape(-1, S, 1), 38 | X[:, S:].reshape(-1, S, 2), 39 | jnp.tile(jnp.arange(S), (N, 1))[:, :, None], 40 | jnp.tile(T[:, None], (S, 1)) 41 | ], axis=2) 42 | print(x.shape) 43 | 44 | for i in range(self.n_layers): 45 | z = jnp.concatenate([x, x0], axis=-1) 46 | x = MHA( 47 | hidden_shapes=self.hidden_shapes[:-1], 48 | output_shape=self.hidden_shapes[-1], 49 | act=self.act, heads=self.heads 50 | )(x, z, z) 51 | x = jnp.concatenate([ 52 | x[..., 0], 53 | x[..., 1:].reshape(N, S * 2) 54 | ], axis=-1) 55 | 56 | return x 57 | 58 | @dataclass 59 | class Concat(hk.Module): 60 | def __init__(self, output_shape, hidden_shapes, act): 61 | super().__init__() 62 | self._layer = 
MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 63 | 64 | def __call__(self, x, context): 65 | if isinstance(x, list): 66 | x = jnp.concatenate(x, axis=-1) 67 | 68 | if len(context) == 2: 69 | t, z = context 70 | z = z.reshape(z.shape[0], -1) 71 | else: 72 | t = context 73 | 74 | if len(t.shape) == 0: 75 | t = t * jnp.ones(x.shape[:-1]) 76 | 77 | if len(t.shape) == len(x.shape) - 1: 78 | t = jnp.expand_dims(t, axis=-1) 79 | 80 | if len(context) == 2: 81 | t = jnp.concatenate([z, t], axis=-1) 82 | 83 | return self._layer(jnp.concatenate([x, t], axis=-1)) 84 | 85 | 86 | @dataclass 87 | class Ignore(hk.Module): 88 | def __init__(self, output_shape, hidden_shapes, act): 89 | super().__init__() 90 | self._layer = MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 91 | 92 | def __call__(self, x, t): 93 | return self._layer(x) 94 | 95 | 96 | @dataclass 97 | class Sum(hk.Module): 98 | def __init__(self, output_shape, hidden_shapes, act): 99 | super().__init__() 100 | self._layer = MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 101 | self._hyper_bias = MLP( 102 | hidden_shapes=[], output_shape=output_shape, act="", bias=False 103 | ) 104 | 105 | def __call__(self, x, t): 106 | t = jnp.array(t, dtype=float).reshape(-1, 1) 107 | return self._layer(x) + self._hyper_bias(t) 108 | 109 | 110 | @dataclass 111 | class Squash(hk.Module): 112 | def __init__(self, output_shape, hidden_shapes, act): 113 | super().__init__() 114 | self._layer = MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 115 | self._hyper = MLP(hidden_shapes=[], output_shape=output_shape, act="") 116 | 117 | def __call__(self, x, t): 118 | t = jnp.array(t, dtype=float).reshape(-1, 1) 119 | return self._layer(x) * jax.nn.sigmoid(self._hyper(t)) 120 | 121 | 122 | @dataclass 123 | class SquashSum(hk.Module): 124 | def __init__(self, output_shape, hidden_shapes, act): 125 | super().__init__() 126 | self._layer = MLP(hidden_shapes=hidden_shapes, output_shape=output_shape, act=act) 127 | self._hyper_bias = MLP( 128 | hidden_shapes=[], output_shape=output_shape, act="", bias=False 129 | ) 130 | self._hyper_gate = MLP(hidden_shapes=[], output_shape=output_shape, act="") 131 | 132 | def __call__(self, x, t): 133 | t = jnp.array(t, dtype=float).reshape(-1, 1) 134 | return self._layer(x) * jax.nn.sigmoid(self._hyper_gate(t)) + self._hyper_bias(t) 135 | 136 | 137 | def get_timestep_embedding(timesteps, embedding_dim=128): 138 | """ 139 | From Fairseq. 140 | Build sinusoidal embeddings. 141 | This matches the implementation in tensor2tensor, but differs slightly 142 | from the description in Section 3.5 of "Attention Is All You Need". 
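    Shape sketch: for `timesteps` of shape [B, 1] and even `embedding_dim`,
    the result is a [B, embedding_dim] array of concatenated sin/cos features.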
143 | https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py 144 | """ 145 | half_dim = embedding_dim // 2 146 | emb = math.log(10000) / (half_dim - 1) 147 | emb = jnp.exp(jnp.arange(half_dim, dtype=float) * -emb) 148 | 149 | emb = timesteps * jnp.expand_dims(emb, 0) 150 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], -1) 151 | if embedding_dim % 2 == 1: # zero pad 152 | emb = jnp.pad(emb, [0, 1]) 153 | 154 | return emb 155 | 156 | 157 | @dataclass 158 | class ConcatEmbed(hk.Module): 159 | def __init__( 160 | self, 161 | output_shape, 162 | enc_shapes, 163 | t_dim, 164 | dec_shapes, 165 | act, 166 | ): 167 | super().__init__() 168 | self.temb_dim = t_dim 169 | t_enc_dim = t_dim * 2 170 | 171 | self.net = MLP(hidden_shapes=dec_shapes, output_shape=output_shape, act=act) 172 | 173 | self.t_encoder = MLP(hidden_shapes=enc_shapes, output_shape=t_enc_dim, act=act) 174 | 175 | self.x_encoder = MLP(hidden_shapes=enc_shapes, output_shape=t_enc_dim, act=act) 176 | 177 | def __call__(self, x, t): 178 | t = jnp.array(t, dtype=float).reshape(-1, 1) 179 | if len(x.shape) == 1: 180 | x = x.unsqueeze(0) 181 | 182 | temb = get_timestep_embedding(t, self.temb_dim) 183 | temb = self.t_encoder(temb) 184 | xemb = self.x_encoder(x) 185 | temb = jnp.broadcast_to(temb, [xemb.shape[0], *temb.shape[1:]]) 186 | h = jnp.concatenate([xemb, temb], -1) 187 | out = self.net(h) 188 | return out 189 | -------------------------------------------------------------------------------- /score_sde/models/distribution.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | 6 | class NormalDistribution: 7 | def __init__(self, **kwargs): 8 | pass 9 | 10 | def sample(self, rng, shape): 11 | return jax.random.normal(rng, shape) 12 | 13 | def log_prob(self, z): 14 | shape = z.shape 15 | N = np.prod(shape[1:]) 16 | logp_fn = lambda z: -N / 2.0 * jnp.log(2 * np.pi) - jnp.sum(z**2) / 2.0 17 | return jax.vmap(logp_fn)(z) 18 | -------------------------------------------------------------------------------- /score_sde/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | -------------------------------------------------------------------------------- /score_sde/models/layers/layers.py: -------------------------------------------------------------------------------- 1 | """Modified code from https://github.com/yang-song/score_sde""" 2 | # coding=utf-8 3 | # Copyright 2020 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # pylint: skip-file 18 | """Common layers for defining score networks. 
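This module also hosts the activation registry: activations are registered by
name below, so that e.g. `get_activation("swish")` returns the registered
`jax.nn.swish` (a usage sketch of the registry from `score_sde.utils`).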
19 | """ 20 | import functools 21 | import math 22 | import string 23 | from typing import Any, Sequence, Optional 24 | from dataclasses import dataclass 25 | 26 | import jax 27 | import haiku as hk 28 | import jax.nn as jnn 29 | import jax.numpy as jnp 30 | 31 | from score_sde.utils import register_category 32 | 33 | get_activation, register_activation = register_category("activation") 34 | 35 | register_activation(jnn.elu, name="elu") 36 | register_activation(jnn.relu, name="relu") 37 | register_activation(functools.partial(jnn.leaky_relu, negative_slope=0.01), name="lrelu") 38 | register_activation(jnn.swish, name="swish") 39 | register_activation(jnp.sin, name='sin') 40 | 41 | 42 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 43 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 44 | half_dim = embedding_dim // 2 45 | # magic number 10000 is from transformers 46 | emb = math.log(max_positions) / (half_dim - 1) 47 | # emb = math.log(2.) / (half_dim - 1) 48 | emb = jnp.exp(jnp.arange(half_dim, dtype=jnp.float32) * -emb) 49 | # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] 50 | # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] 51 | emb = timesteps[:, None] * emb[None, :] 52 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1) 53 | if embedding_dim % 2 == 1: # zero pad 54 | emb = jnp.pad(emb, [[0, 0], [0, 1]]) 55 | assert emb.shape == (timesteps.shape[0], embedding_dim) 56 | return emb 57 | -------------------------------------------------------------------------------- /score_sde/models/mlp.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from collections.abc import Iterable 3 | 4 | import jax 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | 8 | from .layers import get_activation 9 | from score_sde.utils import register_model 10 | 11 | 12 | @register_model 13 | @dataclass 14 | class MLP: 15 | hidden_shapes: list 16 | output_shape: list 17 | act: str 18 | bias: bool = True 19 | 20 | def __call__(self, x): 21 | for hs in self.hidden_shapes: 22 | x = hk.Linear(output_size=hs, with_bias=self.bias)(x) 23 | x = get_activation(self.act)(x) 24 | 25 | x = [hk.Linear(output_size=s)(x) for s in self.output_shape] \ 26 | if isinstance(self.output_shape, Iterable) else hk.Linear(output_size=self.output_shape)(x) 27 | 28 | return x 29 | 30 | 31 | @register_model 32 | @dataclass 33 | class MHA(hk.Module): 34 | def __init__(self, output_shape, hidden_shapes, act, heads=10): 35 | super().__init__() 36 | self._len = MLP( 37 | hidden_shapes=hidden_shapes[:-1], output_shape=hidden_shapes[-1], act=act 38 | ) 39 | self._key = MLP( 40 | hidden_shapes=hidden_shapes[:-1], output_shape=hidden_shapes[-1], act=act 41 | ) 42 | self._query = MLP( 43 | hidden_shapes=hidden_shapes[:-1], output_shape=hidden_shapes[-1], act=act 44 | ) 45 | self._value = MLP( 46 | hidden_shapes=hidden_shapes[:-1], output_shape=hidden_shapes[-1], act=act 47 | ) 48 | self._out = MLP(hidden_shapes=hidden_shapes, output_shape=1, act=act) 49 | self._mha = hk.MultiHeadAttention(3, heads, w_init_scale=1) 50 | 51 | def __call__(self, Q, K, V): 52 | 53 | LQ = self._len(jnp.arange(Q.shape[-2])[None, :]) 54 | print(LQ) 55 | EQ = self._query(jnp.concatenate([Q, LQ], axis=-2)) 56 | LK = self._len(jnp.arange(K.shape[-2])[None, :]) 57 | EK = self._key(jnp.concatenate([K, LK], axis=-2)) 58 | LV = self._len(jnp.arange(V.shape[-2])[None, :]) 59 | EV = 
self._value(jnp.concatenate([V, LV], axis=-2)) 60 | 61 | x = self._mha(EQ, EK, EV) 62 | x = self._out(x.reshape(*x.shape[:-1], 3, -1))[..., 0] 63 | return x 64 | -------------------------------------------------------------------------------- /score_sde/models/model.py: -------------------------------------------------------------------------------- 1 | """Modified code from https://github.com/yang-song/score_sde""" 2 | # coding=utf-8 3 | # Copyright 2020 The Google Research Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """All functions and modules related to model definition. 18 | """ 19 | # from collections import Iterable 20 | 21 | import jax 22 | import jax.numpy as jnp 23 | from score_sde.utils.jax import batch_mul 24 | from score_sde.utils.typing import ParametrisedScoreFunction 25 | 26 | 27 | def get_score_fn( 28 | sde, 29 | model: ParametrisedScoreFunction, 30 | params, 31 | state, 32 | train=False, 33 | return_state=False, 34 | std_trick=False, 35 | residual_trick=True, 36 | boundary_enforce=False, 37 | boundary_dis=0.01, 38 | ): 39 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 40 | Args: 41 | sde: An `sde.SDE` object that represents the forward SDE. 42 | model: A Haiku transformed function representing the score function model 43 | params: A dictionary that contains all trainable parameters. 44 | states: A dictionary that contains all other mutable parameters. 45 | train: `True` for training and `False` for evaluation. 46 | return_state: If `True`, return the new mutable states alongside the model output. 47 | Returns: 48 | A score function. 
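    Example (a sketch; assumes a Haiku-transformed `model` with initialised
    `params` and `state`, and inputs `y`, `t` living on the SDE's manifold):

        score_fn = get_score_fn(sde, model, params, state)
        score = score_fn(y, t, context=None, rng=rng)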
49 | """ 50 | 51 | def smoother_fn(bounded_manifold, y, eps=1e-6, maximum=1): 52 | return jnp.minimum( 53 | jax.nn.relu(bounded_manifold.distance_to_boundary(y) - boundary_dis) + eps, 54 | maximum, 55 | ) 56 | 57 | def score_fn(y, t, context, rng=None): 58 | model_out, new_state = model.apply( 59 | params, state, rng, y=y, t=t, context=context 60 | ) 61 | score = model_out 62 | 63 | if std_trick: 64 | # NOTE: scaling the output with 1.0 / std helps cf 'Improved Techniques for Training Score-Based Generative Model' 65 | # stds = [sde.sdes[i].marginal_prob(jnp.zeros_like(score[i]), t)[1] for i in range(len(sde.sdes))] 66 | # score = [batch_mul(score[i], 1.0 / stds[i]) for i in range(len(sde.sdes))] 67 | std = sde.marginal_prob(jnp.zeros_like(y), t)[1] 68 | score = batch_mul(score, 1.0 / std) 69 | if residual_trick: 70 | # NOTE: so that if NN = 0 then time reversal = forward 71 | fwd_drift = sde.drift(y, t) 72 | residual = 2 * fwd_drift / sde.beta_schedule.beta_t(t)[..., None] 73 | score += residual 74 | if boundary_enforce: 75 | if isinstance(score, tuple): 76 | bounded_manifold = ( 77 | sde.sdes[0].manifold if hasattr(sde, "sdes") else sde.manifold 78 | ) 79 | if isinstance(y, list): 80 | smooth_value = smoother_fn(bounded_manifold, y[0]) 81 | if len(smooth_value.shape) == 1: 82 | smooth_value = smooth_value[..., None] 83 | score = ( 84 | score[0] * smooth_value, 85 | score[1], 86 | ) 87 | else: 88 | smooth_value = smoother_fn( 89 | bounded_manifold, y[..., : bounded_manifold.dim] 90 | ) 91 | if len(smooth_value.shape) == 1: 92 | smooth_value = smooth_value[..., None] 93 | score = ( 94 | score[0] * smooth_value, 95 | score[1], 96 | ) 97 | else: 98 | smooth_value = smoother_fn(sde.manifold, y) 99 | if len(smooth_value.shape) == 1: 100 | smooth_value = smooth_value[..., None] 101 | score *= smooth_value 102 | 103 | if return_state: 104 | return score, new_state 105 | else: 106 | return score 107 | 108 | return score_fn 109 | -------------------------------------------------------------------------------- /score_sde/models/transform.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import functools 3 | import operator 4 | 5 | import jax.numpy as jnp 6 | 7 | 8 | def get_likelihood_fn_w_transform(likelihood_fn, transform): 9 | def log_prob(x, context=None): 10 | y = transform.inv(x) 11 | logp, nfe = likelihood_fn(y, context=context) 12 | log_abs_det_jacobian = transform.log_abs_det_jacobian(y, x) 13 | logp -= log_abs_det_jacobian 14 | return logp, nfe 15 | 16 | return log_prob 17 | 18 | 19 | class Transform(abc.ABC): 20 | def __init__(self, domain, codomain): 21 | self.domain = domain 22 | self.codomain = codomain 23 | 24 | @abc.abstractmethod 25 | def __call__(self, x): 26 | """Computes the transform `x => y`.""" 27 | 28 | @abc.abstractmethod 29 | def inv(self, y): 30 | """Inverts the transform `y => x`.""" 31 | 32 | @abc.abstractmethod 33 | def log_abs_det_jacobian(self, x, y): 34 | """Computes the log det jacobian `log |dy/dx|` given input and output.""" 35 | 36 | 37 | class ComposeTransform(Transform): 38 | def __init__(self, parts): 39 | assert len(parts) > 0 40 | # NOTE: Could check constraints on domains and codomains 41 | super().__init__(parts[0].domain, parts[-1].codomain) 42 | self.parts = parts 43 | 44 | def __call__(self, x): 45 | for part in self.parts: 46 | x = part(x) 47 | return x 48 | 49 | def inv(self, y): 50 | for part in self.parts[::-1]: 51 | y = part.inv(y) 52 | return y 53 | 54 | def log_abs_det_jacobian(self, x, y): 
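        # Replay the forward pass to recover every intermediate value, then sum
        # the per-part terms: by the chain rule, log|det J| of a composition is
        # the sum of each part's log|det J| at its own input/output pair.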
55 | xs = [x] 56 | for part in self.parts[:-1]: 57 | xs.append(part(xs[-1])) 58 | xs.append(y) 59 | terms = [] 60 | for part, x, y in zip(self.parts, xs[:-1], xs[1:]): 61 | terms.append(part.log_abs_det_jacobian(x, y)) 62 | return functools.reduce(operator.add, terms) 63 | 64 | 65 | class Id(Transform): 66 | def __init__(self, domain, **kwargs): 67 | super().__init__(domain, domain) 68 | 69 | def __call__(self, x): 70 | return x 71 | 72 | def inv(self, y): 73 | return y 74 | 75 | def log_abs_det_jacobian(self, x, y): 76 | return jnp.zeros((x.shape[0])) 77 | -------------------------------------------------------------------------------- /score_sde/optim.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import optax 3 | import jax.numpy as jnp 4 | from jax.tree_util import tree_map, tree_leaves 5 | 6 | from score_sde.utils import TrainState 7 | 8 | 9 | def build_optimize_fn(warmup: bool, grad_clip: float): 10 | """Returns an optimize_fn based on `config`.""" 11 | 12 | def optimize_fn(state: TrainState, grad: dict, warmup=warmup, grad_clip=grad_clip): 13 | """Optimizes with warmup and gradient clipping (disabled if negative).""" 14 | lr = state.lr 15 | if warmup > 0: 16 | lr = lr * jnp.minimum(state.step / warmup, 1.0) 17 | if grad_clip >= 0: 18 | # Compute global gradient norm 19 | grad_norm = jnp.sqrt(sum([jnp.sum(jnp.square(x)) for x in tree_leaves(grad)])) 20 | # Clip gradient 21 | clipped_grad = tree_map( 22 | lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad 23 | ) 24 | else: # disabling gradient clipping if grad_clip < 0 25 | clipped_grad = grad 26 | return state.optimizer.apply_updates(clipped_grad, learning_rate=lr) 27 | 28 | return optimize_fn 29 | -------------------------------------------------------------------------------- /score_sde/schedule.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from jax import numpy as jnp 3 | 4 | 5 | class BetaSchedule(ABC): 6 | @abstractmethod 7 | def beta_t(self, t): 8 | pass 9 | 10 | @abstractmethod 11 | def log_mean_coeff(self, t): 12 | pass 13 | 14 | @abstractmethod 15 | def reverse(self): 16 | pass 17 | 18 | 19 | class LinearBetaSchedule(BetaSchedule): 20 | def __init__( 21 | self, 22 | tf: float = 1, 23 | t0: float = 0, 24 | beta_0: float = 0.1, 25 | beta_f: float = 20, 26 | lambda0: float = 1.0, 27 | psi: float = 0.0, 28 | ): 29 | self.tf = tf 30 | self.t0 = t0 31 | self.beta_0 = beta_0 32 | self.beta_f = beta_f 33 | self.lambda0 = lambda0 34 | self.psi = psi 35 | 36 | def log_mean_coeff(self, t): 37 | normed_t = (t - self.t0) / (self.tf - self.t0) 38 | return -0.5 * ( 39 | 0.5 * normed_t**2 * (self.beta_f - self.beta_0) + normed_t * self.beta_0 40 | ) 41 | 42 | def rescale_t(self, t): 43 | return -2 * self.log_mean_coeff(t) 44 | 45 | def beta_t(self, t): 46 | normed_t = (t - self.t0) / (self.tf - self.t0) 47 | return self.beta_0 + normed_t * (self.beta_f - self.beta_0) 48 | 49 | def B(self, t): 50 | interval = self.tf - self.t0 51 | normed_t = (t - self.t0) / interval 52 | # This is not done in the package. 
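        # i.e. B(t) = \int_{t0}^{t} beta(s) ds for this linear schedule; on
        # [t0, tf] = [0, 1] this coincides with rescale_t(t) = -2 * log_mean_coeff(t).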
53 | return interval * ( 54 | self.beta_0 * normed_t + 0.5 * (normed_t**2) * (self.beta_f - self.beta_0) 55 | ) 56 | 57 | def reverse(self): 58 | return LinearBetaSchedule( 59 | tf=self.t0, t0=self.tf, beta_f=self.beta_0, beta_0=self.beta_f 60 | ) 61 | 62 | 63 | class QuadraticBetaSchedule(BetaSchedule): 64 | def __init__(self, tf, t0=0, beta_0=1.0, beta_f=1.0, lambda0=1.0, psi=0.0): 65 | self.tf = tf 66 | self.t0 = t0 67 | self.beta_f = beta_f 68 | self.beta_0 = beta_0 69 | self.lambda0 = lambda0 70 | self.psi = psi 71 | 72 | def beta_t(self, t): 73 | normed_t = (t - self.t0) / (self.tf - self.t0) 74 | 75 | return self.beta_0 + normed_t**2 * (self.beta_f - self.beta_0) 76 | 77 | def rescale_t(self, t): 78 | return -2 * self.log_mean_coeff(t) 79 | 80 | def log_mean_coeff(self, t): 81 | normed_t = (t - self.t0) / (self.tf - self.t0) 82 | return -0.5 * ( 83 | normed_t * self.beta_0 + 1 / 3 * normed_t**3 * (self.beta_f - self.beta_0) 84 | ) 85 | 86 | def reverse(self): 87 | return QuadraticBetaSchedule( 88 | tf=self.t0, t0=self.tf, beta_f=self.beta_0, beta_0=self.beta_f 89 | ) 90 | 91 | 92 | class ConstantBetaSchedule(LinearBetaSchedule): 93 | def __init__( 94 | self, 95 | tf: float = 1, 96 | value: float = 1, 97 | lambda0: float = 1.0, 98 | psi: float = 0.0, 99 | ): 100 | super().__init__( 101 | tf=tf, t0=0.0, beta_0=value, beta_f=value, lambda0=lambda0, psi=psi 102 | ) 103 | 104 | 105 | class TriangleBetaSchedule(BetaSchedule): 106 | def __init__( 107 | self, 108 | tf: float, 109 | t0: float = 0, 110 | beta_min: float = 0.1, 111 | beta_max: float = 20, 112 | peak_t: float = 0.5, 113 | lambda0: float = 1.0, 114 | psi: float = 0.0, 115 | ): 116 | self.tf = tf 117 | self.t0 = t0 118 | self.beta_min = beta_min 119 | self.beta_max = beta_max 120 | self.peak_t = peak_t 121 | self.lambda0 = lambda0 122 | self.psi = psi 123 | 124 | def log_mean_coeff(self, t): 125 | peak_t = (self.tf + self.t0) * self.peak_t 126 | up_t = jnp.minimum(peak_t, t) 127 | up_leg = ( 128 | -0.25 * up_t**2 * (self.beta_max - self.beta_min) 129 | - 0.5 * up_t * self.beta_min 130 | ) 131 | 132 | down_t = jnp.maximum(t - peak_t, 0) 133 | down_leg = ( 134 | 0.25 * down_t**2 * (self.beta_max - self.beta_min) 135 | - 0.5 * down_t * self.beta_max 136 | ) 137 | return down_leg + up_leg 138 | 139 | def beta_t(self, t): 140 | normed_t = (t - self.t0) / (self.tf - self.t0) 141 | up_leg = self.beta_min + normed_t * (self.beta_max - self.beta_min) 142 | down_leg = self.beta_max - (normed_t - 0.5) * (self.beta_max - self.beta_min) 143 | return jnp.where(normed_t < 0.5, up_leg, down_leg) 144 | 145 | def reverse(self): 146 | return TriangleBetaSchedule( 147 | tf=self.tf, 148 | t0=self.t0, 149 | beta_max=self.beta_max, 150 | beta_min=self.beta_min, 151 | peak_t=(1.0 - self.peak_t), 152 | ) 153 | 154 | 155 | class RVESchedule(BetaSchedule): 156 | tf = 1.0 157 | t0 = 0.0 158 | lambda0 = 1.0 159 | psi = 1.0 160 | 161 | def __init__(self, sigma0, sigma1): 162 | super().__init__() 163 | self.sigma0 = sigma0 164 | self.sigma1 = sigma1 165 | 166 | def rescale_t(self, t): 167 | return (self.sigma0 * (self.sigma1 / self.sigma0) ** t) ** 2 168 | 169 | def log_mean_coeff(self, t): 170 | return -0.5 * self.rescale_t(t) 171 | 172 | def beta_t(self, t): 173 | return ( 174 | self.sigma0 175 | * (self.sigma1 / self.sigma0) ** t 176 | * jnp.sqrt(2 * jnp.log(self.sigma1 / self.sigma0)) 177 | ) ** 2 178 | 179 | def reverse(self): 180 | return RVESchedule(self.sigma1, self.sigma0) 181 | 182 | 183 | class ReverseBetaSchedule(BetaSchedule): 184 | def 
__init__(self, forward_schedule): 185 | self.forward_schedule = forward_schedule 186 | self.tf = forward_schedule.tf 187 | self.t0 = forward_schedule.t0 188 | self.beta_f = forward_schedule.beta_0 189 | self.beta_0 = forward_schedule.beta_f 190 | 191 | def beta_t(self, t): 192 | t = self.tf - t 193 | return self.forward_schedule.beta_t(t) 194 | 195 | def reverse(self): 196 | return self.forward_schedule 197 | -------------------------------------------------------------------------------- /score_sde/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .jax import * 2 | from .training import * 3 | from .registry import * 4 | from .typing import * 5 | from .data import * 6 | from .random import * 7 | from .cfg import * 8 | from .schedule import * 9 | -------------------------------------------------------------------------------- /score_sde/utils/cfg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | 4 | import numpy as np 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | 9 | class NoneHydra: 10 | def __init__(self, *args, **kwargs): 11 | pass 12 | 13 | def __bool__(self): 14 | return False 15 | 16 | 17 | # Define useful resolver for hydra config 18 | # TODO: temp fix using replace due to double import in sweeps 19 | OmegaConf.register_new_resolver("int", lambda x: int(x), replace=True) 20 | OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True) 21 | OmegaConf.register_new_resolver("str", lambda x: str(x), replace=True) 22 | OmegaConf.register_new_resolver("prod", lambda x: np.prod(x), replace=True) 23 | OmegaConf.register_new_resolver( 24 | "where", lambda condition, x, y: x if condition else y, replace=True 25 | ) 26 | OmegaConf.register_new_resolver("isequal", lambda x, y: x == y, replace=True) 27 | OmegaConf.register_new_resolver("pi", lambda x: x * math.pi, replace=True) 28 | OmegaConf.register_new_resolver("min", min, replace=True) 29 | 30 | 31 | def partialclass(cls, *args, **kwds): 32 | """Return a class instance with partial __init__ 33 | Input: 34 | cls [str]: class to instantiate 35 | """ 36 | cls = hydra.utils.get_class(cls) 37 | 38 | class NewCls(cls): 39 | __init__ = functools.partialmethod(cls.__init__, *args, **kwds) 40 | 41 | return NewCls 42 | 43 | 44 | def partialfunction(func, *args, **kwargs): 45 | return functools.partial(func, *args, **kwargs) 46 | -------------------------------------------------------------------------------- /score_sde/utils/data.py: -------------------------------------------------------------------------------- 1 | def get_data_scaler(centred): 2 | """Data normalizer. Assume data are always in [0, 1].""" 3 | if centred: 4 | # Rescale to [-1, 1] 5 | return lambda x: x * 2. - 1. 6 | else: 7 | return lambda x: x 8 | 9 | 10 | def get_data_inverse_scaler(centred): 11 | """Inverse data normalizer.""" 12 | if centred: 13 | # Rescale [-1, 1] to [0, 1] 14 | return lambda x: (x + 1.) / 2. 
15 | else: 16 | return lambda x: x -------------------------------------------------------------------------------- /score_sde/utils/jax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import jax 5 | import numpy as np 6 | import jax.numpy as jnp 7 | import jax.lib.xla_bridge as xb 8 | from jax.tree_util import tree_map 9 | 10 | from .typing import ScoreFunction 11 | 12 | 13 | def batch_add(a, b): 14 | return jax.vmap(lambda a, b: a + b)(a, b) 15 | 16 | 17 | def batch_mul(a, b): 18 | return jax.vmap(lambda a, b: a * b)(a, b) 19 | 20 | 21 | def get_estimate_div_fn(fn: ScoreFunction): 22 | """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.""" 23 | 24 | def div_fn(y: jnp.ndarray, t: float, context: jnp.ndarray, eps: jnp.ndarray): 25 | eps = eps.reshape(eps.shape[0], -1) 26 | grad_fn = lambda y: jnp.sum(fn(y, t, context) * eps) 27 | grad_fn_eps = jax.grad(grad_fn)(y).reshape(y.shape[0], -1) 28 | return jnp.sum(grad_fn_eps * eps, axis=tuple(range(1, len(eps.shape)))) 29 | 30 | return div_fn 31 | 32 | 33 | def get_exact_div_fn(fn): 34 | "flatten all but the last axis and compute the true divergence" 35 | 36 | def div_fn(y: jnp.ndarray, t: float, context: jnp.ndarray): 37 | y_shape = y.shape 38 | dim = np.prod(y_shape[1:]) 39 | t = jnp.expand_dims(t.reshape(-1), axis=-1) 40 | y = jnp.expand_dims(y, 1) # NOTE: need leading batch dim after vmap 41 | if context is not None: 42 | context = jnp.expand_dims(context, 1) 43 | t = jnp.expand_dims(t, 1) 44 | jac = jax.vmap(jax.jacrev(fn, argnums=0))(y, t, context) 45 | 46 | jac = jac.reshape([y_shape[0], dim, dim]) 47 | return jnp.trace(jac, axis1=-1, axis2=-2) 48 | 49 | return div_fn 50 | 51 | 52 | def to_flattened_numpy(x: jnp.ndarray) -> np.ndarray: 53 | """Flatten a JAX array `x` and convert it to numpy.""" 54 | return np.asarray(x.reshape((-1,))) 55 | 56 | 57 | def from_flattened_numpy(x: np.ndarray, shape: tuple) -> jnp.ndarray: 58 | """Form a JAX array with the given `shape` from a flattened numpy array `x`.""" 59 | return jnp.asarray(x).reshape(shape) 60 | 61 | 62 | # Borrowed from flax 63 | def replicate(tree, devices=None): 64 | """Replicates arrays to multiple devices. 65 | 66 | Args: 67 | tree: a pytree containing the arrays that should be replicated. 68 | devices: the devices the data is replicated to 69 | (default: `jax.local_devices()`). 70 | Returns: 71 | A new pytree containing the replicated arrays. 
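    Example (a sketch mirroring the example scripts in this repo, where a
    `TrainState` is copied onto every local device before a pmapped step):

        p_train_state = replicate(train_state)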
72 | """ 73 | if devices is None: 74 | # match the default device assignments used in pmap: 75 | # for single-host, that's the XLA default device assignment 76 | # for multi-host, it's the order of jax.local_devices() 77 | if jax.process_count() == 1: 78 | devices = [ 79 | d 80 | for d in xb.get_backend().get_default_device_assignment( 81 | jax.device_count() 82 | ) 83 | if d.process_index == jax.process_index() 84 | ] 85 | else: 86 | devices = jax.local_devices() 87 | 88 | return jax.device_put_replicated(tree, devices) 89 | 90 | 91 | # Borrowed from flax 92 | def unreplicate(tree): 93 | """Returns a single instance of a replicated array.""" 94 | return tree_map(lambda x: x[0], tree) 95 | 96 | 97 | def save(ckpt_dir: str, state, postfix=None) -> None: 98 | postfix = f".{postfix}" if postfix is not None else '' 99 | with open(os.path.join(ckpt_dir, f"arrays{postfix}.npy"), "wb") as f: 100 | for x in jax.tree_util.tree_leaves(state): 101 | np.save(f, x, allow_pickle=False) 102 | 103 | tree_struct = jax.tree_util.tree_map(lambda t: 0, state) 104 | with open(os.path.join(ckpt_dir, f"tree{postfix}.pkl"), "wb") as f: 105 | pickle.dump(tree_struct, f) 106 | if postfix != '': 107 | save(ckpt_dir, state, postfix=None) 108 | 109 | 110 | def restore(ckpt_dir, postfix=None): 111 | postfix = f".{postfix}" if postfix is not None else '' 112 | with open(os.path.join(ckpt_dir, f"tree{postfix}.pkl"), "rb") as f: 113 | tree_struct = pickle.load(f) 114 | 115 | leaves, treedef = jax.tree_util.tree_flatten(tree_struct) 116 | with open(os.path.join(ckpt_dir, f"arrays{postfix}.npy"), "rb") as f: 117 | flat_state = [np.load(f) for _ in leaves] 118 | 119 | return jax.tree_util.tree_unflatten(treedef, flat_state) 120 | 121 | -------------------------------------------------------------------------------- /score_sde/utils/loggers_pl/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from .base import LoggerCollection, LightningLoggerBase 4 | from .csv_log import CSVLogger 5 | from .wandb import WandbLogger 6 | 7 | from hydra.core.singleton import Singleton 8 | 9 | 10 | class Logger(metaclass=Singleton): 11 | def __init__(self) -> None: 12 | self.logger: Optional[LightningLoggerBase] = None 13 | 14 | def set_logger(self, logger: LightningLoggerBase): 15 | assert logger is not None 16 | assert isinstance(logger, LightningLoggerBase) 17 | self.logger = logger 18 | 19 | @staticmethod 20 | def get() -> LightningLoggerBase: 21 | instance = Logger.instance() 22 | if instance.logger is None: 23 | raise ValueError("Logger not set") 24 | return instance.logger 25 | 26 | @staticmethod 27 | def initialized() -> bool: 28 | instance = Logger.instance() 29 | return instance.logger is not None 30 | 31 | @staticmethod 32 | def instance(*args: Any, **kwargs: Any) -> "Logger": 33 | return Singleton.instance(Logger, *args, **kwargs) # type: ignore 34 | -------------------------------------------------------------------------------- /score_sde/utils/random.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.random as jr 3 | 4 | 5 | class GlobalRNG: 6 | def __init__(self, seed: int = None): 7 | self.key = jr.PRNGKey(seed if seed is not None else np.random.randint(2147483647)) 8 | 9 | def __iter__(self): 10 | return self 11 | 12 | def __next__(self): 13 | (ret_key, self.key) = jr.split(self.key) 14 | return ret_key 15 | -------------------------------------------------------------------------------- 
/score_sde/utils/registry.py: -------------------------------------------------------------------------------- 1 | _REGISTRY = {} 2 | 3 | 4 | def get_item(category: str, item: str) -> object: 5 | _category = get_category(category) 6 | if item not in _category: 7 | raise ValueError(f"Item {item} not in category {category}") 8 | 9 | return _category[item] 10 | 11 | 12 | def get_category(category: str) -> dict: 13 | if category not in _REGISTRY: 14 | raise ValueError(f"Category {category} not in registry") 15 | 16 | return _REGISTRY[category] 17 | 18 | 19 | def register_item(category: str, item: object, name: str) -> object: 20 | _category = get_category(category) 21 | if name in _category: 22 | raise ValueError(f"Item {name} already in category {category}") 23 | 24 | _category[name] = item 25 | 26 | return item 27 | 28 | 29 | def register_category(category: str): 30 | assert isinstance(category, str) 31 | 32 | if category in _REGISTRY: 33 | raise ValueError(f"Category {category} in registry already") 34 | 35 | _REGISTRY[category] = {} 36 | 37 | def get_func(item: str): 38 | return get_item(category, item) 39 | 40 | def register_func(obj: object, *, name: str = None): 41 | name = name if name is not None else obj.__name__.split(".")[-1] 42 | return register_item(category, obj, name) 43 | 44 | return get_func, register_func 45 | 46 | 47 | get_model, register_model = register_category("model") 48 | get_dataset, register_dataset = register_category("dataset") 49 | -------------------------------------------------------------------------------- /score_sde/utils/schedule.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | 4 | def loglinear_schedule( 5 | init_value, 6 | end_value, 7 | decay_steps, 8 | ): 9 | 10 | log_init = jnp.log(init_value) 11 | log_end = jnp.log(end_value) 12 | 13 | def schedule(count): 14 | t = count / decay_steps 15 | return jnp.exp(log_init + t * (log_end - log_init)) 16 | 17 | return schedule 18 | -------------------------------------------------------------------------------- /score_sde/utils/switch.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class SwitchObject: 4 | def __init__(self, objects, **kwargs): 5 | self.objects = objects 6 | self.index = 0 7 | self.update() 8 | 9 | def update(self, index=None): 10 | if index is not None: 11 | self.index = index 12 | self.__dict__.update( 13 | {attr: getattr(self.objects[self.index], attr) 14 | for attr in dir(self.objects[self.index])} 15 | ) 16 | 17 | def iter_index(self): 18 | self.update((self.index + 1) % len(self.objects)) 19 | 20 | def randomize_index(self): 21 | self.index, = random.choices( 22 | range(len(self.objects)), 23 | weights=[ 24 | len(obj) if hasattr(obj, "__len__") else 1 25 | for obj in self.objects 26 | ] 27 | ) 28 | self.update() 29 | 30 | def __next__(self): 31 | return next(self.objects[self.index]) 32 | 33 | def __len__(self): 34 | return len(self.objects[self.index]) 35 | -------------------------------------------------------------------------------- /score_sde/utils/training.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | TrainState = namedtuple( 5 | "TrainState", 6 | [ 7 | "opt_state", 8 | "model_state", 9 | "step", 10 | "params", 11 | "ema_rate", 12 | "params_ema", 13 | "rng", 14 | ], 15 | ) 16 | -------------------------------------------------------------------------------- /score_sde/utils/typing.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | ParametrisedScoreFunction = Callable[[dict, dict, jnp.ndarray, float], jnp.ndarray] 7 | ScoreFunction = Callable[[jnp.ndarray, float], jnp.ndarray] 8 | 9 | SDEUpdateFunction = Callable[ 10 | [jax.random.KeyArray, jnp.ndarray, float], 11 | Tuple[ 12 | jnp.ndarray, 13 | jnp.ndarray, 14 | ], 15 | ] 16 | -------------------------------------------------------------------------------- /scripts/abdb/download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tarfile 4 | import urllib.request 5 | from tqdm import tqdm 6 | 7 | from score_sde.utils import register_dataset 8 | 9 | import geomstats as gs 10 | import jax.numpy as jnp 11 | 12 | data_dir = "/data/localhost/not-backed-up/mhutchin/score-sde/data/abdb" 13 | 14 | _dataset_url = "http://www.abybank.org/abdb/Data/LH_Combined_Martin.tar.bz2" 15 | _redundant_antibody_url = ( 16 | "http://www.abybank.org/abdb/Data/Redundant_files/Redundant_H_Combined_Martin.txt" 17 | ) 18 | _free_antibody_complex_list = ( 19 | "http://www.abybank.org/abdb/Data/Martin_logs/Heavy_HeavyAntigen.list" 20 | ) 21 | 22 | 23 | class DownloadProgressBar(tqdm): 24 | def update_to(self, b=1, bsize=1, tsize=None): 25 | if tsize is not None: 26 | self.total = tsize 27 | self.update(b * bsize - self.n) 28 | 29 | 30 | def download_url(url, output_path): 31 | with DownloadProgressBar( 32 | unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1] 33 | ) as t: 34 | urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) 35 | 36 | 37 | def clean_files(): 38 | if os.path.isdir(data_dir): 39 | shutil.rmtree(data_dir) 40 | 41 | 42 | def download_files(): 43 | os.makedirs(data_dir, exist_ok=True) 44 | 45 | download_url(_dataset_url, os.path.join(data_dir, "raw_data.tar.bz2")) 46 | download_url( 47 | _redundant_antibody_url, 48 | os.path.join(data_dir, "redundant_antibodies.txt"), 49 | ) 50 | download_url( 51 | _free_antibody_complex_list, 52 | os.path.join(data_dir, "free_antibody_complex_list.txt"), 53 | ) 54 | 55 | 56 | def extract_files(): 57 | tar = tarfile.open(os.path.join(data_dir, "raw_data.tar.bz2"), "r:bz2") 58 | tar.extractall(data_dir) 59 | os.rename( 60 | os.path.join(data_dir, "LH_Combined_Martin"), 61 | os.path.join(data_dir, "raw_antibodies"), 62 | ) 63 | 64 | 65 | clean_files() 66 | download_files() 67 | extract_files() 68 | -------------------------------------------------------------------------------- /scripts/approximate_forward.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | 4 | os.environ["GEOMSTATS_BACKEND"] = "jax" 5 | # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 6 | 7 | import setGPU 8 | from functools import partial 9 | 10 | import numpy as np 11 | import jax.numpy as jnp 12 | from jax.scipy.stats import norm 13 | from jax import vmap, jit, random, lax 14 | from jax.lax import fori_loop 15 | 16 | from matplotlib import pyplot as plt 17 | import matplotlib.ticker as ticker 18 | import seaborn as sns 19 | 20 | from geomstats.geometry.euclidean import Euclidean 21 | from geomstats.geometry.hypersphere import Hypersphere, gegenbauer_polynomials 22 | from score_sde.utils import batch_mul 23 | from riemannian_score_sde.sde import VPSDE, Brownian 24 | from score_sde.models.flow import PushForward, CNF 25 | from 
riemannian_score_sde.sampling import get_pc_sampler 26 | 27 | # cmap_name = "plasma_r" 28 | cmap_name = "viridis_r" 29 | #%% 30 | rng = random.PRNGKey(1) 31 | B = 512 * 4 32 | manifold = Hypersphere(2) 33 | # sde = Brownian(manifold, tf=4, t0=0.0, beta_0=1, beta_f=1) 34 | sde = Brownian(manifold, tf=1, t0=0, beta_0=0.001, beta_f=10) 35 | timesteps = jnp.linspace(sde.t0, sde.tf, 20) 36 | timesteps = jnp.expand_dims(timesteps, -1) 37 | x0 = jnp.array([[1.0, 0.0, 0.0]]) 38 | x0b = jnp.repeat(x0, B, 0) 39 | 40 | # def MMD 41 | # def k(x, y, kappa=0.5): 42 | # return jnp.exp(-jnp.linalg.norm(x - y, axis=-1) ** 2 / kappa**2 / 2) 43 | 44 | 45 | def k(x, x0, kappa=0.5, n_max=10): 46 | d = manifold.dim 47 | n = jnp.expand_dims(jnp.arange(0, n_max + 1), axis=-1) 48 | # t = jnp.expand_dims(t, axis=0) 49 | t = jnp.array(kappa**2 / 2) 50 | coeffs = ( 51 | jnp.exp(-n * (n + 1) * t) * (2 * n + d - 1) / (d - 1) / manifold.metric.volume 52 | ) 53 | inner_prod = jnp.sum(x0 * x, axis=-1) 54 | cos_theta = jnp.clip(inner_prod, -1.0, 1.0) 55 | P_n = gegenbauer_polynomials(alpha=(d - 1) / 2, l_max=n_max, x=cos_theta) 56 | prob = jnp.sum(coeffs * P_n, axis=0) 57 | return prob 58 | 59 | 60 | def mmd(xs, ys): 61 | @jit 62 | def k_matrix(xs, ys): 63 | return vmap( 64 | lambda x: vmap(lambda y: k(x, y), in_axes=1, out_axes=1)(xs), 65 | in_axes=1, 66 | out_axes=1, 67 | )(ys) 68 | 69 | m = xs.shape[1] 70 | n = ys.shape[1] 71 | 72 | # biased but positive 73 | k_xx = k_matrix(xs, xs).sum(axis=(-2, -1)) 74 | k_yy = k_matrix(ys, ys).sum(axis=(-2, -1)) 75 | k_xy = k_matrix(xs, ys).sum(axis=(-2, -1)) 76 | sq_mmd = k_xx / m / m + k_yy / n / n - 2 * k_xy / m / n 77 | 78 | # unbiased (but can be negative) 79 | # k_xx = k_matrix(xs, xs) 80 | # k_xx = jnp.sum(k_xx, axis=(-2, -1)) - jnp.trace(k_xx, axis1=-2, axis2=-1) 81 | # k_yy = k_matrix(ys, ys) 82 | # k_yy = jnp.sum(k_yy, axis=(-2, -1)) - jnp.trace(k_yy, axis1=-2, axis2=-1) 83 | # k_xy = k_matrix(xs, ys).sum(axis=(-2, -1)) 84 | # sq_mmd = k_xx / m / (m - 1) + k_yy / n / (n - 1) - 2 * k_xy / m / n 85 | 86 | # return sq_mmd 87 | return jnp.sqrt(sq_mmd) 88 | 89 | 90 | # Get "exact" samples (with N = 1000) 91 | N = 1000 92 | rng, next_rng = random.split(rng) 93 | sampler = get_pc_sampler(sde, N, predictor="GRW") 94 | xt_true = vmap(lambda t: sampler(next_rng, x0b, tf=t))(timesteps) 95 | assert vmap(partial(manifold.belongs, atol=1e-5))(xt_true).all() 96 | 97 | # Get approximate samples (sweeping over N) 98 | Ns = np.array([1, 2, 5, 50, 100, 1000]) 99 | # Ns = np.array([2, 5, 10, 20, 50, 100]) 100 | 101 | fig, ax = plt.subplots(1, 1, figsize=(10, 6), sharey=True, sharex=True) 102 | colors = sns.color_palette(cmap_name, len(Ns)) 103 | fontsize = 30 104 | 105 | for i, N in enumerate(Ns): 106 | xt_approx = vmap( 107 | lambda t: get_pc_sampler(sde, N, predictor="GRW")( 108 | next_rng, x0b, tf=t 109 | ) 110 | )(timesteps) 111 | assert vmap(partial(manifold.belongs, atol=1e-5))(xt_approx).all() 112 | dists = mmd(xt_approx, xt_true) 113 | 114 | ax.plot(timesteps, dists, color=colors[i], label=f"N={N}", lw=5) 115 | 116 | ax.set_xlabel("t", fontsize=fontsize) 117 | ax.set_ylabel( 118 | r"MMD$(\hat{\mathbf{X}}_t|\mathbf{X}_0, \mathbf{X}_t|\mathbf{X}_0)$", 119 | fontsize=fontsize, 120 | ) 121 | ax.set_yscale("log") 122 | ax.legend(fontsize=4 / 5 * fontsize, loc="upper right") 123 | ax.tick_params(axis="both", which="major", labelsize=4 / 5 * fontsize) 124 | 125 | # fig.tight_layout(pad=0.5) 126 | fig_name = f"../doc/images/approximate_forward.pdf" 127 | fig.savefig(fig_name, 
bbox_inches="tight", transparent=True) 128 | -------------------------------------------------------------------------------- /scripts/deploy/config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rsa_key=id_rsa_ziz # setup_ssh_keys_on_zizgpus.sh is required for this key to exist 4 | venv=venv 5 | parent_dir=/data/localhost/not-backed-up/$USER/score-sde -------------------------------------------------------------------------------- /scripts/deploy/make_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to execute from zizgpu02 4 | 5 | dir_path="`dirname \"$0\"`" 6 | source $dir_path/config.sh 7 | 8 | deactivate 9 | rm -rf $parent_dir 10 | mkdir $parent_dir 11 | 12 | virtualenv -p python3.9 $parent_dir/$venv 13 | source $parent_dir/$venv/bin/activate 14 | pip install -r requirements.txt -------------------------------------------------------------------------------- /scripts/deploy/setup_ssh_keys_on_zizgpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script will create ssh keys on zizgpu0x and ssh-copy-id them into ziz, 4 | # so that you can easily ssh or rsync from zizgpu0x to ziz within the SLURM jobs. 5 | # This script will prompt you to enter your statistics departmental password for 6 | # each of the 4 nodes zizgpu0x when executing ssh-copy-id. 7 | 8 | # to execute from ziz 9 | dir_path="`dirname \"$0\"`" 10 | source $dir_path/config.sh 11 | 12 | ssh-keygen -q -t rsa -b 2048 -N "" -f ~/.ssh/$rsa_key <<< y 13 | 14 | for NODE_ID in '1' '2' '3' '5' 15 | do 16 | srun --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug ssh-keygen -q -t rsa -b 2048 -N "" -f ~/.ssh/$rsa_key <<< y 17 | srun --pty --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug ssh-copy-id -i ~/.ssh/$rsa_key ziz.stats.ox.ac.uk 18 | # srun --pty --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug scp ziz.stats.ox.ac.uk:~/.ssh/$rsa_key.pub ~/.ssh/ 19 | # srun --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug cat ~/.ssh/$rsa_key.pub >> ~/.ssh/authorized_keys 20 | done -------------------------------------------------------------------------------- /scripts/deploy/sync_keys.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to execute from ziz 4 | dir_path="`dirname \"$0\"`" 5 | source $dir_path/config.sh 6 | 7 | for NODE_ID in '1' '2' '3' '5' 8 | do 9 | srun --pty --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug ssh-copy-id -i ~/.ssh/$rsa_key zizgpu04.cpu.stats.ox.ac.uk 10 | done -------------------------------------------------------------------------------- /scripts/deploy/sync_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to execute from ziz 4 | dir_path="`dirname \"$0\"`" 5 | source $dir_path/config.sh 6 | 7 | for NODE_ID in '1' '2' '3' '5' 8 | do 9 | srun --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug mkdir -p $parent_dir/$venv & 10 | srun --clusters=srf_gpu_01 --partition=zizgpu0$NODE_ID-debug rsync -e "ssh -i ~/.ssh/id_rsa_ziz" -a --info=progress2 ziz.stats.ox.ac.uk:$parent_dir/$venv/* $parent_dir/$venv/. 
& 11 | done 12 | -------------------------------------------------------------------------------- /scripts/diffusion_sphere/blender_diffusion_sphere_stills.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import bmesh 3 | import math 4 | from functools import partial 5 | import numpy as np 6 | import os 7 | from mathutils import Vector 8 | from mathutils import Euler 9 | 10 | # directory = os.getcwd() 11 | base_dir = os.path.expanduser( 12 | "~/Documents/projects/ExtrinsicGaugeEquivariantVectorGPs/" 13 | ) 14 | scripts_dir = os.path.join(base_dir, "blender_scripts") 15 | data_dir = os.path.join(base_dir, "blender") 16 | texture_path = os.path.join(base_dir, "blender", "textures") 17 | col_dir = os.path.join(base_dir, "blender", "col") 18 | 19 | os.makedirs(os.path.join(data_dir, "blank_to_wrong", "renders"), exist_ok=True) 20 | 21 | with open(os.path.join(scripts_dir, "render.py")) as file: 22 | exec(file.read()) 23 | 24 | reset_scene() 25 | set_renderer_settings(num_samples=16 if bpy.app.background else 128) 26 | (cam_axis, cam_obj) = setup_camera( 27 | distance=10, 28 | angle=(0, 0, 0), 29 | lens=85, 30 | height=1500, 31 | ) 32 | setup_lighting( 33 | shifts=(-10, -10, 10), 34 | sizes=(9, 18, 15), 35 | energies=(1500, 150, 1125), 36 | horizontal_angles=(-np.pi / 6, np.pi / 3, np.pi / 3), 37 | vertical_angles=(-np.pi / 3, -np.pi / 6, np.pi / 4), 38 | ) 39 | set_resolution(1080 / 2, aspect=(16, 9)) 40 | 41 | bd_obj = create_backdrop(location=(0, 0, -2), scale=(10, 5, 5)) 42 | arr_obj = create_vector_arrow(color=(1, 0, 0, 1)) 43 | # set_object_collections(backdrop=[bd_obj], instancing=[arr_obj]) 44 | 45 | bm = import_bmesh(os.path.join(data_dir, "unwrap_sphere", "frame_0.obj")) 46 | # bm = import_bmesh(os.path.join(data_dir, "kernels", "S2.obj")) 47 | # import_color(bm, name='white', color = (1,1,1,1)) 48 | import_color( 49 | bm, 50 | data_file=os.path.join(data_dir, "kernels", "s_wrong.csv"), 51 | palette_file=os.path.join(col_dir, "viridis.csv"), 52 | ) 53 | earth_obj = add_mesh(bm, name="Earth") 54 | earth_mat = add_vertex_colors(earth_obj) 55 | add_texture(earth_mat, os.path.join(texture_path, "mercator_rot.png")) 56 | 57 | # VECTOR FIELD 58 | vf_bm = import_vector_field(os.path.join(data_dir, "unwrap_sphere", f"frame_0.csv")) 59 | vf_obj = add_vector_field(vf_bm, arr_obj, scale=3, name="observations") 60 | 61 | arr_obj = create_vector_arrow(color=(0.3, 0.3, 0.3, 1.0)) # (0.0, 0.75, 1.0, 1) 62 | mean_bm = import_vector_field( 63 | os.path.join(data_dir, "kernels", f"mean_wrong_sphere.csv") 64 | ) 65 | mean_obj = add_vector_field(mean_bm, arr_obj, scale=3, name="means") 66 | 67 | bpy.ops.object.empty_add( 68 | type="PLAIN_AXES", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) 69 | ) 70 | empty = bpy.context.selected_objects[0] 71 | earth_obj.parent = empty 72 | vf_obj.parent = empty 73 | mean_obj.parent = empty 74 | 75 | bpy.context.scene.render.filepath = os.path.join( 76 | data_dir, "kernels", "renders", "wrong_front.png" 77 | ) 78 | empty.rotation_euler = Euler((0, 0, math.radians(90)), "XYZ") 79 | bpy.ops.render.render(use_viewport=True, write_still=True) 80 | 81 | bpy.context.scene.render.filepath = os.path.join( 82 | data_dir, "kernels", "renders", "wrong_back.png" 83 | ) 84 | empty.rotation_euler = Euler((0, 0, -math.radians(90)), "XYZ") 85 | bpy.ops.render.render(use_viewport=True, write_still=True) 86 | 87 | bpy.context.scene.render.filepath = os.path.join( 88 | data_dir, "kernels", "renders", "wrong_top.png" 89 | ) 90 | 
empty.rotation_euler = Euler((0, -math.radians(90), math.radians(90)), "XYZ") 91 | bpy.ops.render.render(use_viewport=True, write_still=True) 92 | 93 | bpy.context.scene.render.filepath = os.path.join( 94 | data_dir, "kernels", "renders", "wrong_bottom.png" 95 | ) 96 | empty.rotation_euler = Euler((0, math.radians(90), math.radians(90)), "XYZ") 97 | bpy.ops.render.render(use_viewport=True, write_still=True) 98 | -------------------------------------------------------------------------------- /scripts/examples/1D_example.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | import functools 6 | 7 | import jax 8 | import numpy as np 9 | import jax.numpy as jnp 10 | import haiku as hk 11 | import optax 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from score_sde.datasets import GaussianMixture 16 | from score_sde.models import MLP 17 | from score_sde.sampling import get_pc_sampler, get_predictor, get_corrector 18 | from score_sde.likelihood import get_likelihood_fn 19 | from score_sde.utils import TrainState, get_data_inverse_scaler, ScoreFunctionWrapper, replicate 20 | from score_sde.sde import subVPSDE, VPSDE 21 | from score_sde.losses import get_pmap_step_fn, get_step_fn 22 | # %% 23 | batch_size = 256 24 | dataset = GaussianMixture([batch_size], jax.random.PRNGKey(0), stds=[0.1,0.1]) 25 | 26 | # score_model = hk.transform_with_state( 27 | # lambda x, t: (1000000000000000.0) * ScoreFunctionWrapper(MLP(hidden_shapes=3*[128], output_shape=1, act='sin'))(x, t) 28 | # ) 29 | score_model = hk.transform_with_state( 30 | lambda x, t: jnp.zeros_like(x) - 0.5 31 | ) 32 | dummy_input = next(dataset) 33 | params, state = score_model.init(rng=jax.random.PRNGKey(0), x=dummy_input, t=0) 34 | 35 | out = score_model.apply(params, state, jax.random.PRNGKey(0), x=dummy_input, t=0) 36 | # %% 37 | 38 | steps = 100000 39 | warmup_steps = 2000 40 | 41 | schedule_fn = optax.join_schedules([ 42 | optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=warmup_steps), 43 | # optax.linear_schedule(init_value=1.0, end_value=1.0, transition_steps=steps - warmup_steps), 44 | optax.cosine_decay_schedule(init_value=1.0, decay_steps = steps - warmup_steps, alpha=0.0), 45 | ], [warmup_steps]) 46 | 47 | lr=2e-4 48 | grad_clip=jnp.inf 49 | 50 | optimiser = optax.chain( 51 | optax.clip_by_global_norm(grad_clip), 52 | optax.adam(lr, b1=.9, b2=0.999,eps=1e-8), 53 | optax.scale_by_schedule(schedule_fn) 54 | ) 55 | 56 | opt_state = optimiser.init(params) 57 | 58 | # %% 59 | # lrs = jnp.array([ 60 | # schedule_fn(step) for step in range(steps) 61 | # ]) 62 | # plt.plot(lrs) 63 | 64 | # %% 65 | train_state = TrainState( 66 | opt_state=opt_state, model_state=state, step=0, params=params, ema_rate=0.999, params_ema=params, rng=jax.random.PRNGKey(0) 67 | ) 68 | p_train_state = replicate(train_state) 69 | # %% 70 | sde = VPSDE( 71 | beta_min=0.1, 72 | beta_max=20., 73 | N=1000 74 | ) 75 | # %% 76 | train_step_fn = get_pmap_step_fn(sde, score_model, optimiser, True, reduce_mean=False, continuous=True, like_w=False) 77 | p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn), axis_name='batch', donate_argnums=1) 78 | # %% 79 | 80 | batch_size = 256 81 | dataset = GaussianMixture([1,1,batch_size], jax.random.PRNGKey(0), stds=[0.1,0.1]) 82 | 83 | # %% 84 | losses = [] 85 | for i in range(steps): 86 | batch = { 87 | 'data': next(dataset), 88 | 'label': None 89 | } 90 | rng = p_train_state.rng[0] 91 | next_rng = 
jax.random.split(rng, num=jax.local_device_count() + 1) 92 | rng = next_rng[0] 93 | next_rng = next_rng[1:] 94 | (_, p_train_state), loss = p_train_step((next_rng,p_train_state), batch) 95 | losses.append(float(loss)) 96 | if i%100 == 0: 97 | print(i, ': ', loss) 98 | 99 | 100 | print(loss) 101 | losses = jnp.array(losses) 102 | # %% 103 | plt.plot(losses[::10]) 104 | plt.yscale('log') 105 | # %% 106 | sampler = get_pc_sampler(sde, score_model, (2**12,1), get_predictor("EulerMaruyamaPredictor"), get_corrector("NoneCorrector"), lambda x: x, 0.2) 107 | posterior_samples = sampler(replicate(jax.random.PRNGKey(0)), p_train_state) 108 | 109 | target_samples = next(dataset) 110 | import seaborn as sns 111 | sns.kdeplot(posterior_samples[0][0,:,0], color='tab:blue') 112 | # sns.kdeplot(target_samples[0,0,:,0], color='tab:orange') 113 | 114 | # %% 115 | 116 | likelihood_fn = get_likelihood_fn(sde, score_model, lambda x: x, bits_per_dimension=False) 117 | 118 | # %% 119 | 120 | x = jnp.linspace(-5, 5)[np.newaxis, :, np.newaxis] 121 | 122 | prior_likelihood = jnp.exp(sde.prior_logp(x[0])) 123 | pushforward_likelihood = jnp.exp(likelihood_fn(replicate(jax.random.PRNGKey(0)), p_train_state, x)[0]) 124 | 125 | plt.plot(x[0,:,0], prior_likelihood) 126 | plt.plot(x[0,:,0], pushforward_likelihood[0]) 127 | sns.kdeplot(posterior_samples[0][0,:,0], color='tab:green') 128 | 129 | 130 | # %% 131 | from scipy.stats import gaussian_kde 132 | 133 | kernel = gaussian_kde(posterior_samples[0][0,:,:]) 134 | 135 | plt.plot(kernel(x)) 136 | 137 | # %% 138 | import seaborn as sns 139 | samples = jnp.concatenate([next(dataset)[0,:,0] for i in range(10)], axis=0) 140 | 141 | sns.distplot(samples) 142 | # %% 143 | -------------------------------------------------------------------------------- /scripts/examples/1D_example2.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | import functools 6 | 7 | import jax 8 | import numpy as np 9 | import jax.numpy as jnp 10 | import haiku as hk 11 | import optax 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from score_sde.datasets import GaussianMixture 16 | from score_sde.models import MLP 17 | from score_sde.sampling import get_pc_sampler, get_predictor, get_corrector, get_ode_sampler 18 | from score_sde.likelihood import get_likelihood_fn 19 | from score_sde.utils import TrainState, get_data_inverse_scaler, ScoreFunctionWrapper, replicate 20 | from score_sde.sde import subVPSDE, VPSDE 21 | from score_sde.losses import get_pmap_step_fn, get_step_fn 22 | # %% 23 | batch_size = 256 24 | # dataset = GaussianMixture([batch_size], jax.random.PRNGKey(0), stds=[0.1,0.1]) 25 | a = 5 26 | dataset = GaussianMixture([batch_size], jax.random.PRNGKey(0), means=[a], stds=[1.0], weights=[1.]) 27 | 28 | def score_1d_gaussian(x, t): 29 | t = jnp.array(t) 30 | if len(t.shape) == 0: 31 | t = t * jnp.ones(x.shape[:-1]) 32 | 33 | if len(t.shape) == len(x.shape) - 1: 34 | t = jnp.expand_dims(t, axis=-1) 35 | 36 | return jnp.ones_like(x) * a * jnp.exp(-2 * t) 37 | 38 | score_model = hk.transform_with_state(score_1d_gaussian) 39 | 40 | # score_model = hk.transform_with_state( 41 | # lambda x, t: jnp.zeros_like(x) - 0.5 42 | # lambda x, t: ScoreFunctionWrapper(MLP(hidden_shapes=3*[128], output_shape=1, act='sin'))(x, t) 43 | # ) 44 | 45 | dummy_input = next(dataset) 46 | params, state = score_model.init(rng=jax.random.PRNGKey(0), x=dummy_input, t=0) 47 | 48 | out = score_model.apply(params, state, 
jax.random.PRNGKey(0), x=dummy_input, t=0) 49 | # %% 50 | 51 | # steps = 100000 52 | # warmup_steps = 2000 53 | steps = 100000 // 100 54 | warmup_steps = 2000 // 100 55 | 56 | schedule_fn = optax.join_schedules([ 57 | optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=warmup_steps), 58 | # optax.linear_schedule(init_value=1.0, end_value=1.0, transition_steps=steps - warmup_steps), 59 | optax.cosine_decay_schedule(init_value=1.0, decay_steps = steps - warmup_steps, alpha=0.0), 60 | ], [warmup_steps]) 61 | 62 | lr=2e-4 63 | grad_clip=jnp.inf 64 | 65 | optimiser = optax.chain( 66 | optax.clip_by_global_norm(grad_clip), 67 | optax.adam(lr, b1=.9, b2=0.999,eps=1e-8), 68 | optax.scale_by_schedule(schedule_fn) 69 | ) 70 | 71 | opt_state = optimiser.init(params) 72 | 73 | # %% 74 | # lrs = jnp.array([ 75 | # schedule_fn(step) for step in range(steps) 76 | # ]) 77 | # plt.plot(lrs) 78 | 79 | # %% 80 | train_state = TrainState( 81 | opt_state=opt_state, model_state=state, step=0, params=params, ema_rate=0.999, params_ema=params, rng=jax.random.PRNGKey(0) 82 | ) 83 | p_train_state = replicate(train_state) 84 | # %% 85 | sde = VPSDE( 86 | beta_min=0.1, 87 | beta_max=20., 88 | N=1000 89 | ) 90 | # %% 91 | train_step_fn = get_pmap_step_fn(sde, score_model, optimiser, True, reduce_mean=False, continuous=True, like_w=False) 92 | p_train_step = jax.pmap(functools.partial(jax.lax.scan, train_step_fn), axis_name='batch', donate_argnums=1) 93 | # %% 94 | 95 | batch_size = 256 96 | # dataset = GaussianMixture([1,1,batch_size], jax.random.PRNGKey(0), stds=[0.1,0.1]) 97 | dataset = GaussianMixture([1,1,batch_size], jax.random.PRNGKey(0), means=[a], stds=[1.0], weights=[1.]) 98 | 99 | # # %% 100 | # losses = [] 101 | # for i in range(steps): 102 | # batch = { 103 | # 'data': next(dataset), 104 | # 'label': None 105 | # } 106 | # rng = p_train_state.rng[0] 107 | # next_rng = jax.random.split(rng, num=jax.local_device_count() + 1) 108 | # rng = next_rng[0] 109 | # next_rng = next_rng[1:] 110 | # (_, p_train_state), loss = p_train_step((next_rng,p_train_state), batch) 111 | # losses.append(float(loss)) 112 | # if i%100 == 0: 113 | # print(i, ': ', loss) 114 | 115 | # print(loss) 116 | # losses = jnp.array(losses) 117 | 118 | # plt.plot(losses[::10]) 119 | # plt.yscale('log') 120 | # %% 121 | sampler = get_pc_sampler(sde, score_model, (2**12,1), get_predictor("EulerMaruyamaPredictor"), get_corrector("NoneCorrector"), lambda x: x, 0.2) 122 | # sampler = get_ode_sampler(sde, score_model, (2**12,1), lambda x: x) 123 | posterior_samples = sampler(replicate(jax.random.PRNGKey(0)), p_train_state) 124 | 125 | # %% 126 | target_samples = next(dataset) 127 | import seaborn as sns 128 | # sns.kdeplot(posterior_samples[0][0,:,0], color='tab:orange') 129 | # sns.kdeplot(target_samples[0,0,:,0], color='tab:green') 130 | 131 | likelihood_fn = get_likelihood_fn(sde, score_model, lambda x: x, bits_per_dimension=False) 132 | 133 | x = jnp.linspace(-5, 5)[np.newaxis, :, np.newaxis] 134 | 135 | prior_likelihood = jnp.exp(sde.prior_logp(x[0])) 136 | pushforward_likelihood = jnp.exp(likelihood_fn(replicate(jax.random.PRNGKey(0)), p_train_state, x)[0]) 137 | 138 | plt.plot(x[0,:,0], prior_likelihood, color='tab:blue') 139 | sns.kdeplot(target_samples[0,0,:,0], color='tab:green') 140 | plt.plot(x[0,:,0], pushforward_likelihood[0], color='tab:orange') 141 | sns.kdeplot(posterior_samples[0][0,:,0], color='tab:red') 142 | 143 | # %% 144 | -------------------------------------------------------------------------------- 
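The two 1D examples above lean on the fact that the VP-SDE has a Gaussian perturbation kernel, so a Gaussian target stays Gaussian and its score is available in closed form to compare a learned model against. A minimal sketch of that closed form, assuming the standard linear schedule beta(t) = beta_min + t * (beta_max - beta_min) from Song et al. (2021); the helper names below are illustrative and not part of the repo:

import jax.numpy as jnp

def vp_mean_coeff(t, beta_min=0.1, beta_max=20.0):
    # m(t) = exp(-0.5 * int_0^t beta(s) ds) for the linear beta schedule
    log_m = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    return jnp.exp(log_m)

def analytic_score(x, t, a=5.0):
    # A N(a, 1) target stays Gaussian under the VP-SDE: its marginal at time t
    # is N(a * m(t), 1), since the variance m(t)^2 * 1 + (1 - m(t)^2) is exactly 1.
    # The score of N(a * m(t), 1) is therefore a * m(t) - x.
    m = vp_mean_coeff(t)
    return a * m - x

Plotting this against the score_1d_gaussian used above is a quick way to check the time and scaling conventions of the training loop.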
/scripts/examples/earth_datasets.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | import os 6 | os.environ["GEOMSTATS_BACKEND"] = 'jax' 7 | os.chdir('/data/localhost/not-backed-up/mhutchin/score-sde') 8 | 9 | import jax 10 | # %% 11 | 12 | from riemannian_score_sde.datasets import * 13 | from riemannian_score_sde.utils.vis import setup_sphere_plot, scatter_earth 14 | from score_sde.datasets import DataLoader, SubDataset, TensorDataset, random_split 15 | 16 | # %% 17 | 18 | data = VolcanicErruption() 19 | fig, ax = setup_sphere_plot() 20 | scatter_earth(data.data, ax=ax) 21 | 22 | # %% 23 | 24 | data = Fire() 25 | fig, ax = setup_sphere_plot(azim=-45, elev=45) 26 | scatter_earth(data.data, ax=ax) 27 | 28 | # %% 29 | 30 | data = Flood() 31 | fig, ax = setup_sphere_plot() 32 | scatter_earth(data.data, ax=ax) 33 | 34 | # %% 35 | 36 | data = Earthquake() 37 | fig, ax = setup_sphere_plot(azim=90, elev=-0) 38 | scatter_earth(data.data, ax=ax) 39 | # ax.set_aspect('equal') 40 | 41 | # %% 42 | dataloader = DataLoader(Earthquake(), 100, jax.random.PRNGKey(0)) 43 | for batch in dataloader: 44 | print(batch.shape) 45 | 46 | # %% 47 | len(SubDataset(Earthquake(), jnp.arange(100))) 48 | # %% 49 | td = TensorDataset(jnp.arange(100)[:, None]) 50 | subset = SubDataset(td, jnp.arange(50)) 51 | 52 | for b in DataLoader(subset, 10, jax.random.PRNGKey(0), shuffle=False): 53 | print(b) 54 | # %% 55 | 56 | print([len(ds) for ds in random_split(td, [80,10,10], jax.random.PRNGKey(0))]) 57 | # %% 58 | -------------------------------------------------------------------------------- /scripts/gaussian_pushforward/make_gaussian_pushforward_data.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | import os 5 | from hydra import compose, initialize 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from score_sde.utils import cfg 9 | import jax 10 | import jax.numpy as jnp 11 | import matplotlib.pyplot as plt 12 | 13 | os.environ['GEOMSTATS_BACKEND'] = 'jax' 14 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 16 | # %% 17 | import os 18 | import socket 19 | import logging 20 | from functools import partial 21 | from timeit import default_timer as timer 22 | 23 | import jax 24 | import optax 25 | import numpy as np 26 | import haiku as hk 27 | from tqdm import tqdm 28 | from jax import numpy as jnp 29 | from score_sde.models.flow import SDEPushForward, MoserFlow 30 | from geomstats.geometry.hypersphere import Hypersphere 31 | 32 | # %% 33 | with open("/data/ziz/not-backed-up/mhutchin/score-sde/scripts/blender/mesh_utils.py") as file: 34 | exec(file.read()) 35 | # %% 36 | 37 | def m_to_e(M): 38 | phi = M[..., 0] 39 | theta = M[..., 1] 40 | 41 | return jnp.stack( 42 | [ 43 | jnp.sin(phi) * jnp.cos(theta), 44 | jnp.sin(phi) * jnp.sin(theta), 45 | jnp.cos(phi), 46 | ], 47 | axis=-1, 48 | ) 49 | 50 | 51 | # S2 = EmbeddedS2(1.0) 52 | import numpy as np 53 | num_points = 30 54 | phi = np.linspace(0, np.pi, num_points) 55 | theta = np.linspace(0, 2 * np.pi, num_points + 1)[:-1] 56 | phi, theta = np.meshgrid(phi, theta, indexing="ij") 57 | phi = phi.flatten() 58 | theta = theta.flatten() 59 | m = np.stack( 60 | [phi, theta], axis=-1 61 | ) ### NOTE this ordering, I can change it but it'll be a pain, its latitude, longitude 62 | density = 10 63 | m_dense = np.stack( 64 
| np.meshgrid( 65 | np.linspace(0, np.pi, density * num_points), 66 | np.linspace(0, 2 * np.pi, 2 * density * num_points + 1)[:-1], 67 | indexing="ij", 68 | ), 69 | axis=-1, 70 | ) 71 | x = m_to_e(m) 72 | verticies, faces = mesh_to_polyscope( 73 | x.reshape((num_points, num_points, 3)), wrap_y=False 74 | ) 75 | 76 | uv = m / m.max(axis=0, keepdims=True) 77 | mesh_obj = regular_square_mesh_to_obj( 78 | x.reshape((num_points, num_points, 3)), wrap_y=True 79 | ) 80 | 81 | save_obj(mesh_obj, "data/sphere.obj") 82 | 83 | # %% 84 | sphere = Hypersphere(2) 85 | x_dense = m_to_e(m_dense) 86 | # %% 87 | base_point = jnp.array([0,-1,0]) 88 | 89 | def wrap_prob(x, sigma): 90 | tv = sphere.log(x, base_point=base_point[None, :]) 91 | dist = jnp.linalg.norm(tv, axis=-1) ** 2 92 | norm_pdf = 1/jnp.sqrt(2*jnp.pi*sigma**2) * jnp.exp(- 0.5 * dist / sigma**2) 93 | logdetexp = sphere.metric.logdetexp(base_point[None, :], x) 94 | return norm_pdf + logdetexp 95 | 96 | def wrap_logprob(x, sigma): 97 | tv = sphere.log(x, base_point=base_point[None, :]) 98 | dist = jnp.linalg.norm(tv, axis=-1) ** 2 99 | norm_pdf = - 0.5 * dist / sigma**2 100 | logdetexp = sphere.metric.logdetexp(base_point[None, :], x) 101 | return jnp.exp(norm_pdf + logdetexp) 102 | 103 | def rnorm_logprob(x, sigma): 104 | tv = sphere.log(x, base_point=base_point[None, :]) 105 | dist = jnp.linalg.norm(tv, axis=-1) ** 2 106 | norm_pdf = - 0.5 * dist / sigma**2 107 | return jnp.exp(norm_pdf) 108 | 109 | # %% 110 | import geometric_kernels 111 | # Import a space and an appropriate kernel. 112 | from geometric_kernels.spaces.hypersphere import Hypersphere as Hypersphere2 113 | from geometric_kernels.kernels.geometric_kernels import MaternKarhunenLoeveKernel 114 | 115 | # Create a manifold (2-dim sphere). 116 | hypersphere = Hypersphere2(dim=2) 117 | 118 | kernel = MaternKarhunenLoeveKernel(hypersphere, 1000) 119 | params, state = kernel.init_params_and_state() 120 | params["nu"] = np.inf 121 | params["lengthscale"] = np.array([1.]) 122 | 123 | def kernel_prob(x, sigma): 124 | params["lengthscale"] = np.array([sigma]) 125 | return kernel.K(params, state, np.array(x), np.array(base_point[None, :])) 126 | 127 | 128 | # %% 129 | 130 | sigmas = np.power(10,np.linspace(np.log10(0.08), 0.7, 30)) 131 | 132 | for sigma in sigmas: 133 | make_scalar_texture( 134 | partial(wrap_prob, sigma=sigma), 135 | m_to_e(m_dense), 136 | # m_to_e(m_dense).swapaxes(0,1), 137 | f"data/exp_pushforward_{sigma:0.2f}.png", 138 | cmap='viridis', 139 | ) 140 | 141 | make_scalar_texture( 142 | partial(kernel_prob, sigma=sigma), 143 | m_to_e(m_dense), 144 | # m_to_e(m_dense).swapaxes(0,1), 145 | f"data/brownian_{sigma:0.2f}.png", 146 | cmap='viridis', 147 | ) 148 | # %% 149 | 150 | from riemannian_score_sde.models.distribution import WrapNormDistribution 151 | 152 | wrap_norm = WrapNormDistribution(Hypersphere(2)) 153 | 154 | # wrap_norm.log_prob( 155 | # x_dense 156 | # ) 157 | 158 | make_scalar_texture( 159 | partial(wrap_logprob, sigma=0.5), 160 | x_dense, 161 | # m_to_e(m_dense).swapaxes(0,1), 162 | f"data/wrap_norm.png", 163 | cmap='viridis', 164 | ) 165 | 166 | make_scalar_texture( 167 | partial(rnorm_logprob, sigma=0.5), 168 | x_dense, 169 | # m_to_e(m_dense).swapaxes(0,1), 170 | f"data/rnorm.png", 171 | cmap='viridis', 172 | ) 173 | 174 | make_scalar_texture( 175 | lambda x: jnp.zeros_like(x[..., 0]), 176 | x_dense, 177 | # m_to_e(m_dense).swapaxes(0,1), 178 | f"data/unif.png", 179 | cmap='viridis', 180 | ) 181 | # %% 182 | 
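For reference, the normalised wrapped-normal log-density that the textures above approximate can be written with the same sphere, base_point and metric.logdetexp objects already defined in this script (the function name below is illustrative). If x = exp_mu(v) with v ~ N(0, sigma^2 I) in the tangent space at mu, the change-of-variables formula gives log p(x) = log N(v; 0, sigma^2 I) - log|det d_v exp_mu|:

def wrapped_normal_log_prob(x, sigma):
    # map x back to the tangent space at base_point
    v = sphere.log(x, base_point=base_point[None, :])
    sq_dist = jnp.sum(v ** 2, axis=-1)
    # log N(v; 0, sigma^2 I) on the dim-dimensional tangent space
    log_norm = -0.5 * sq_dist / sigma ** 2 - 0.5 * sphere.dim * jnp.log(2 * jnp.pi * sigma ** 2)
    # change of variables through exp: p(x) = N(v) / |det d_v exp|, so subtract the log-Jacobian
    return log_norm - sphere.metric.logdetexp(base_point[None, :], x)

The wrap_prob/wrap_logprob helpers above return unnormalised quantities meant only for colour maps; this sketch is the properly normalised log-space form.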
-------------------------------------------------------------------------------- /scripts/gaussian_random_walk/make_grw_data.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | import os 5 | from hydra import compose, initialize 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from score_sde.utils import cfg 9 | import jax 10 | import jax.numpy as jnp 11 | import matplotlib.pyplot as plt 12 | 13 | os.environ['GEOMSTATS_BACKEND'] = 'jax' 14 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 15 | 16 | # %% 17 | import os 18 | import socket 19 | import logging 20 | from functools import partial 21 | from timeit import default_timer as timer 22 | 23 | import jax 24 | import optax 25 | import numpy as np 26 | import haiku as hk 27 | from tqdm import tqdm 28 | from jax import numpy as jnp 29 | from score_sde.models.flow import SDEPushForward, MoserFlow 30 | from geomstats.geometry.hypersphere import Hypersphere 31 | 32 | # %% 33 | with open("/data/ziz/not-backed-up/mhutchin/score-sde/scripts/blender/mesh_utils.py") as file: 34 | exec(file.read()) 35 | # %% 36 | 37 | def m_to_e(M): 38 | phi = M[..., 0] 39 | theta = M[..., 1] 40 | 41 | return jnp.stack( 42 | [ 43 | jnp.sin(phi) * jnp.cos(theta), 44 | jnp.sin(phi) * jnp.sin(theta), 45 | jnp.cos(phi), 46 | ], 47 | axis=-1, 48 | ) 49 | 50 | 51 | # S2 = EmbeddedS2(1.0) 52 | import numpy as np 53 | num_points = 30 54 | phi = np.linspace(0, np.pi, num_points) 55 | theta = np.linspace(0, 2 * np.pi, num_points + 1)[1:] 56 | phi, theta = np.meshgrid(phi, theta, indexing="ij") 57 | phi = phi.flatten() 58 | theta = theta.flatten() 59 | m = np.stack( 60 | [phi, theta], axis=-1 61 | ) ### NOTE this ordering, I can change it but it'll be a pain, its latitude, longitude 62 | density = 10 63 | m_dense = np.stack( 64 | np.meshgrid( 65 | np.linspace(0, np.pi, density * num_points), 66 | np.linspace(0, 2 * np.pi, 2 * density * num_points + 1)[1:], 67 | indexing="ij", 68 | ), 69 | axis=-1, 70 | ) 71 | x = m_to_e(m) 72 | verticies, faces = mesh_to_polyscope( 73 | x.reshape((num_points, num_points, 3)), wrap_y=False 74 | ) 75 | 76 | uv = m / m.max(axis=0, keepdims=True) 77 | mesh_obj = regular_square_mesh_to_obj( 78 | x.reshape((num_points, num_points, 3)), wrap_y=True 79 | ) 80 | 81 | # %% 82 | 83 | save_obj(mesh_obj, "data/sphere.obj") 84 | 85 | # %% 86 | 87 | density=10 88 | m_dense = np.stack( 89 | np.meshgrid( 90 | np.linspace(0, np.pi, density * num_points), 91 | (np.linspace(0, 2 * np.pi, 2 * density * num_points + 1)[1:]) % (2*np.pi), 92 | indexing="ij", 93 | ), 94 | axis=-1, 95 | ) 96 | 97 | # %% 98 | sphere = Hypersphere(2) 99 | 100 | 101 | points = [] 102 | tangent_vecs = [] 103 | geodesics = [] 104 | 105 | points.append(jnp.array([0.0, -1.0, 0.0])) 106 | 107 | rng = jax.random.PRNGKey(0) 108 | 109 | length=10 110 | step=0.1 111 | k=10 112 | n=int(length/step) 113 | for i in range(n): 114 | rng, next_rng = jax.random.split(rng) 115 | tangent_vecs.append(step * sphere.random_normal_tangent(rng, points[-1])[1][0]) 116 | geodesics.append(sphere.exp(jnp.linspace(0,1,k)[:, None] * tangent_vecs[-1][None, :], points[-1][None, :])) 117 | points.append(sphere.exp(tangent_vecs[-1], points[-1])) 118 | 119 | 120 | for i in range(n): 121 | np.savetxt(f'data/point_{i}.csv', points[i], delimiter=',') 122 | np.savetxt(f'data/vec_{i}{i+1}.csv', jnp.concatenate([points[i], tangent_vecs[i]])[None, :], delimiter=',') 123 | 
np.savetxt(f'data/geodesic_{i}{i+1}.csv', geodesics[i], delimiter=',') 124 | 125 | i=n 126 | np.savetxt(f'data/point_{i}.csv', points[i], delimiter=',') 127 | 128 | # %% 129 | -------------------------------------------------------------------------------- /scripts/gaussian_random_walk_step/make_grw_step_data.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | import os 5 | from hydra import compose, initialize 6 | from hydra.utils import instantiate 7 | from omegaconf import OmegaConf 8 | from score_sde.utils import cfg 9 | import jax 10 | import jax.numpy as jnp 11 | import matplotlib.pyplot as plt 12 | 13 | os.environ['GEOMSTATS_BACKEND'] = 'jax' 14 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 15 | 16 | # %% 17 | import os 18 | import socket 19 | import logging 20 | from functools import partial 21 | from timeit import default_timer as timer 22 | 23 | import jax 24 | import optax 25 | import numpy as np 26 | import haiku as hk 27 | from tqdm import tqdm 28 | from jax import numpy as jnp 29 | from score_sde.models.flow import SDEPushForward, MoserFlow 30 | from geomstats.geometry.hypersphere import Hypersphere 31 | 32 | # %% 33 | with open("/data/ziz/not-backed-up/mhutchin/score-sde/scripts/blender/mesh_utils.py") as file: 34 | exec(file.read()) 35 | # %% 36 | 37 | def m_to_e(M): 38 | phi = M[..., 0] 39 | theta = M[..., 1] 40 | 41 | return jnp.stack( 42 | [ 43 | jnp.sin(phi) * jnp.cos(theta), 44 | jnp.sin(phi) * jnp.sin(theta), 45 | jnp.cos(phi), 46 | ], 47 | axis=-1, 48 | ) 49 | 50 | 51 | # S2 = EmbeddedS2(1.0) 52 | import numpy as np 53 | num_points = 30 54 | phi = np.linspace(0, np.pi, num_points) 55 | theta = np.linspace(0, 2 * np.pi, num_points + 1)[:-1] 56 | phi, theta = np.meshgrid(phi, theta, indexing="ij") 57 | phi = phi.flatten() 58 | theta = theta.flatten() 59 | m = np.stack( 60 | [phi, theta], axis=-1 61 | ) ### NOTE this ordering, I can change it but it'll be a pain, its latitude, longitude 62 | density = 10 63 | m_dense = np.stack( 64 | np.meshgrid( 65 | np.linspace(0, np.pi, density * num_points), 66 | np.linspace(0, 2 * np.pi, 2 * density * num_points + 1)[:-1], 67 | indexing="ij", 68 | ), 69 | axis=-1, 70 | ) 71 | x = m_to_e(m) 72 | verticies, faces = mesh_to_polyscope( 73 | x.reshape((num_points, num_points, 3)), wrap_y=False 74 | ) 75 | 76 | uv = m / m.max(axis=0, keepdims=True) 77 | mesh_obj = regular_square_mesh_to_obj( 78 | x.reshape((num_points, num_points, 3)), wrap_y=True 79 | ) 80 | 81 | # %% 82 | 83 | save_obj(mesh_obj, "data/sphere.obj") 84 | 85 | # %% 86 | sphere = Hypersphere(2) 87 | 88 | base_point = jnp.array([0,-1,0]) 89 | ts_base_x = jnp.array([1,0,0]) 90 | ts_base_z = jnp.array([0,0,1]) 91 | sigma = 1.0 92 | 93 | grid = jnp.meshgrid( 94 | jnp.linspace(-1,1,num_points*density),jnp.linspace(-1,1,num_points*density) 95 | ) 96 | 97 | grid = grid[0][..., None] * ts_base_x + grid[1][..., None] * ts_base_z 98 | 99 | def prob(x, sigma): 100 | dist = jnp.linalg.norm(x, axis=-1) ** 2 101 | return 1/jnp.sqrt(2*jnp.pi*sigma**2) * jnp.exp(- dist / sigma**2) 102 | 103 | make_scalar_texture( 104 | partial(prob, sigma=sigma), 105 | grid, 106 | f"data/plane_texture.png", 107 | cmap='viridis', 108 | ) 109 | 110 | # %% 111 | 112 | tv = jnp.array([0.6,0,0]) 113 | geodesic = sphere.exp(jnp.linspace(0,1,100)[:, None] * tv[None, :], base_point[None, :]) 114 | 115 | 116 | np.savetxt(f'data/point.csv', geodesic[-1], delimiter=',') 117 | np.savetxt(f'data/vec.csv', 
jnp.concatenate([base_point, tv])[None, :], delimiter=',') 118 | np.savetxt(f'data/geodesic.csv', geodesic, delimiter=',') 119 | 120 | # i=n 121 | # np.savetxt(f'data/point_{i}.csv', points[i], delimiter=',') 122 | # %% 123 | -------------------------------------------------------------------------------- /scripts/kent/agg_results.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pandas as pd 3 | import numpy as np 4 | import pickle 5 | 6 | def unpickle(fp): 7 | with open(fp, 'rb') as f: 8 | return pickle.load(f) 9 | 10 | save_dir = './kent' 11 | fns = os.listdir(save_dir) 12 | 13 | rows = [] 14 | for fn in fns: 15 | fp = os.path.join(save_dir, fn) 16 | out = unpickle(fp) 17 | dataset = "_".join(fp.split('_')[:-4]).split('/')[-1] 18 | row = {**out['inputs'], **{'train_loglik': out['train_loglik'], 19 | 'test_loglik': out['test_loglik'], 20 | 'eval_loglik': out['eval_loglik'], 21 | 'dataset': dataset}} 22 | rows.append(row) 23 | 24 | 25 | 26 | 27 | df = pd.DataFrame.from_records(rows) 28 | 29 | datasets = df.dataset.unique() 30 | 31 | print(datasets) 32 | 33 | df['mean_loglik'] = df.test_loglik 34 | df['std_loglik'] = df.test_loglik 35 | df['count'] = df.seed 36 | 37 | # hack in case eval group too small and has nan - should not do anything 38 | df['eval_loglik'] = np.where(pd.isna(df.eval_loglik), df.train_loglik, df.eval_loglik ) 39 | 40 | 41 | results = [] 42 | for data_tag in datasets: 43 | print(data_tag) 44 | rview = df.dataset == data_tag 45 | res = df[rview].groupby(['dataset', 46 | 'n_components', 'iterations']) \ 47 | .agg({'train_loglik': 'mean', 48 | 'eval_loglik': 'mean', 49 | 'mean_loglik': 'mean', 50 | 'std_loglik': 'std', 51 | 'count': 'count'}) \ 52 | .reset_index() 53 | 54 | res = res[res['count']>1] \ 55 | .sort_values(by='eval_loglik', ascending=False) \ 56 | .iloc[0:1,:] 57 | results.append(res) 58 | 59 | 60 | print(pd.concat(results, axis=0)) -------------------------------------------------------------------------------- /scripts/kent/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | - override hydra/launcher: joblib 4 | 5 | data_folder: /Users/jamesthornton/Documents/diffusions/score-sde/data 6 | seed: 0 7 | dataset: volerup 8 | initializations: 20 9 | n_components: 20 10 | iterations: 100 11 | output_dir: '/Users/jamesthornton/Documents/diffusions/score-sde/scripts/kent/output' 12 | 13 | # ../../../miniconda3/envs/diffusions/bin/python scripts/kent/run_kent.py -m seed=1,2,3,4,5 dataset=fire,flood,quakes_all,volerup n_components=5,10,15,20,25 iterations=50,100,150 -------------------------------------------------------------------------------- /scripts/kent/run_kent.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import os, sys 4 | import hydra 5 | import pickle 6 | 7 | 8 | 9 | from scripts.kent.kent_model import KentMixture, to_cartesian_coords 10 | from score_sde.datasets.split import random_split 11 | 12 | 13 | 14 | def run_kent(X, initializations, n_components, iterations=100, seed=0): 15 | 16 | inputs = {'initializations': initializations, 17 | 'n_components' : n_components, 18 | 'iterations': iterations, 19 | 'seed': seed 20 | } 21 | 22 | random_state = np.random.RandomState(seed=seed) 23 | rng = jax.random.PRNGKey(seed) 24 | 25 | datasets= random_split(X, [0.8,0.1,0.1], rng) 26 | train_X, eval_X, test_X = [d.dataset[d.indices] for d in datasets] 27 | 28 
| klf = KentMixture(n_components=n_components, n_iter=iterations, 29 | n_init=initializations, random_state=random_state) 30 | klf.fit(train_X) 31 | 32 | kmm_lpr, kmm_responsibilities = klf.score_samples(train_X) 33 | train_log_lik = np.mean(kmm_lpr) 34 | 35 | kmm_lpr, kmm_responsibilities = klf.score_samples(eval_X) 36 | eval_log_lik = np.mean(kmm_lpr) 37 | 38 | kmm_lpr, kmm_responsibilities = klf.score_samples(test_X) 39 | test_log_lik = np.mean(kmm_lpr) 40 | 41 | 42 | return {'train_loglik' : train_log_lik, 43 | 'test_loglik' : test_log_lik, 44 | 'eval_loglik' : eval_log_lik, 45 | 'params': klf.get_params(), 46 | 'inputs': inputs, 47 | 'converged': klf.converged_} 48 | 49 | def get_loglik(params, X): 50 | klf = KentMixture() 51 | klf.set_params(params) 52 | kmm_lpr, kmm_responsibilities = klf.score_samples(X) 53 | log_lik = np.mean(kmm_lpr) 54 | return log_lik 55 | 56 | @hydra.main(config_path="./", config_name="config") 57 | def main(cfg): 58 | 59 | # load and format data 60 | folder = cfg.data_folder 61 | file = os.path.join(folder, cfg.dataset +'.csv') 62 | 63 | X = np.genfromtxt(file, delimiter=",", skip_header=2) 64 | N = X.shape[0] 65 | intrinsic_data = ( 66 | np.pi * (X / 180.0) + np.array([np.pi / 2, np.pi])[None, :] 67 | ) 68 | X = to_cartesian_coords(intrinsic_data) 69 | 70 | output = run_kent(X, cfg.initializations, cfg.n_components, cfg.iterations, cfg.seed) 71 | 72 | output_dir = cfg.output_dir 73 | file_name = f'{cfg.dataset}_{cfg.seed}_{cfg.n_components}_{cfg.initializations}_{cfg.iterations}.pkl' 74 | fp = os.path.join(output_dir, file_name) 75 | with open(fp, 'wb') as f: 76 | pickle.dump(output, f) 77 | 78 | if __name__ == '__main__': 79 | 80 | main() 81 | 82 | -------------------------------------------------------------------------------- /scripts/make_fisher_data.py: -------------------------------------------------------------------------------- 1 | import setGPU 2 | import numpy as np 3 | from pathlib import Path 4 | import os 5 | 6 | os.environ["GEOMSTATS_BACKEND"] = "jax" 7 | 8 | import jax.numpy as jnp 9 | from riemannian_score_sde.datasets import Langevin 10 | from geomstats.geometry.special_orthogonal import ( 11 | SpecialOrthogonal, 12 | _SpecialOrthogonal3Vectors, 13 | ) 14 | 15 | 16 | if __name__ == "__main__": 17 | manifold = SpecialOrthogonal(n=3, point_type="matrix") 18 | batch_size = 512 19 | K = 4 20 | scale = 2 21 | dir_path = f"{os.getcwd()}/data/so3_langevin/K_{K}_s_{scale}" 22 | 23 | params = dict( 24 | scale=scale, 25 | K=K, 26 | batch_dims=[batch_size], 27 | manifold=manifold, 28 | seed=0, 29 | conditional=True, 30 | ) 31 | dataset = Langevin(**params) 32 | 33 | # N = batch_size * 100 34 | N = batch_size * 10 35 | Xs = [] 36 | ks = [] 37 | for n in range(N // batch_size): 38 | print(n, N // batch_size) 39 | X, k = next(dataset) 40 | Xs.append(X) 41 | ks.append(k) 42 | Xs = jnp.concatenate(Xs, axis=0) 43 | ks = jnp.concatenate(ks, axis=0) 44 | Xs = Xs.reshape(N, -1) 45 | ks = ks.reshape(N, -1) 46 | 47 | Path(dir_path).mkdir(parents=True, exist_ok=True) 48 | np.savetxt(f"{dir_path}/X.csv", Xs, delimiter=",") 49 | np.savetxt(f"{dir_path}/k.csv", ks, delimiter=",") 50 | -------------------------------------------------------------------------------- /scripts/sabdab/download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import zipfile 4 | import urllib.request 5 | from tqdm import tqdm 6 | 7 | import geomstats as gs 8 | import jax.numpy as jnp 9 | 10 | data_dir = 
"/data/localhost/not-backed-up/mhutchin/score-sde/data/sabdab" 11 | 12 | _dataset_url = "http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/" 13 | 14 | 15 | class DownloadProgressBar(tqdm): 16 | def update_to(self, b=1, bsize=1, tsize=None): 17 | if tsize is not None: 18 | self.total = tsize 19 | self.update(b * bsize - self.n) 20 | 21 | 22 | def download_url(url, output_path): 23 | with DownloadProgressBar( 24 | unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1] 25 | ) as t: 26 | urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) 27 | 28 | 29 | def clean_files(): 30 | if os.path.isdir(data_dir): 31 | shutil.rmtree(data_dir) 32 | 33 | 34 | def download_files(): 35 | os.makedirs(data_dir, exist_ok=True) 36 | 37 | download_url(_dataset_url, os.path.join(data_dir, "all_structures.zip")) 38 | 39 | 40 | def extract_files(): 41 | with zipfile.ZipFile(os.path.join(data_dir, "all_structures.zip"), "r") as zip_ref: 42 | zip_ref.extractall(os.path.join(data_dir)) 43 | 44 | 45 | clean_files() 46 | download_files() 47 | extract_files() 48 | -------------------------------------------------------------------------------- /scripts/test_abdb.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | # %% 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | # %% 12 | from Bio import PDB 13 | from Bio.PDB import internal_coords 14 | import nglview as nv 15 | # ic_structure = PDB.internal_coords.IC_Chain(structure) 16 | # ic_structure.internal_to_atom_coordinates() 17 | import glob 18 | datadir = "/data/localhost/not-backed-up/mhutchin/score-sde/data/abdb/raw_antibodies" 19 | files = glob.glob(datadir + '/*.pdb') 20 | # %% 21 | parser = PDB.PDBParser() 22 | structure = parser.get_structure('test_struct', "/data/localhost/not-backed-up/mhutchin/score-sde/data/abdb/raw_antibodies/1A2Y_1.pdb") 23 | structure.atom_to_internal_coordinates() 24 | 25 | chain = next(structure.get_chains()) 26 | angles = np.zeros((len(list(chain.get_residues())), 3)) 27 | lengths = np.zeros((len(list(chain.get_residues())), 3)) 28 | for i, r in enumerate(chain.get_residues()): 29 | if r.internal_coord.rprev: 30 | angles[i-1, 0] = r.internal_coord.get_angle("omega") 31 | angles[i-1, 1] = r.internal_coord.get_angle("phi") 32 | 33 | lengths[i, 0] = r.internal_coord.get_length("N:CA") 34 | lengths[i, 1] = r.internal_coord.get_length("CA:C") 35 | 36 | if r.internal_coord.rnext: 37 | angles[i, 2] = r.internal_coord.get_angle("psi") 38 | lengths[i, 2] = r.internal_coord.get_length("C:1N") 39 | 40 | # %% 41 | sns.histplot(lengths[:, 0]) 42 | sns.histplot(lengths[:, 1]) 43 | sns.histplot(lengths[:, 2]) 44 | 45 | # %% 46 | sns.set_theme(style="dark") 47 | 48 | 49 | x=angles[:, 1] 50 | y=angles[:,2] 51 | sns.scatterplot(x=x,y=y, s=5, color=".15") 52 | # sns.histplot(x=x, y=y, bins=50, pthresh=-1, cmap="mako") 53 | sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1) 54 | 55 | plt.gca().set_aspect('equal') 56 | 57 | # %% 58 | sns.set_theme(style="dark") 59 | 60 | 61 | x=angles[:,1] 62 | y=angles[:,0] 63 | sns.scatterplot(x=x,y=y, s=5, color=".15") 64 | # sns.histplot(x=x, y=y, bins=50, pthresh=-1, cmap="mako") 65 | sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1) 66 | 67 | plt.gca().set_aspect('equal') 68 | # %% 69 | sns.set_theme(style="dark") 70 | 71 | 72 | x=angles[:,2] 73 | y=angles[:,0] 74 | 
sns.scatterplot(x=x,y=y, s=5, color=".15") 75 | # sns.histplot(x=x, y=y, bins=50, pthresh=-1, cmap="mako") 76 | sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1) 77 | 78 | plt.gca().set_aspect('equal') 79 | # %% 80 | import random 81 | f = random.choice(files) 82 | structure = parser.get_structure('test_struct', f) 83 | chains = list(structure.get_chains()) 84 | nv.show_biopython(chains[2]) 85 | 86 | # %% 87 | new_structure = PDB.StructureBuilder() 88 | 89 | # %% 90 | def get_angles_lengths(chain): 91 | 92 | angles = np.empty((len(list(chain.get_residues())), 3)) 93 | lengths = np.empty((len(list(chain.get_residues())), 4)) 94 | angles[:] = np.nan 95 | lengths[:] = np.nan 96 | 97 | for i, r in enumerate(chain.get_residues()): 98 | try: 99 | if i != 0: 100 | angles[i-1, 0] = r.internal_coord.get_angle("omega") 101 | angles[i-1, 1] = r.internal_coord.get_angle("phi") 102 | 103 | lengths[i, 0] = r.internal_coord.get_length("N:CA") 104 | lengths[i, 1] = r.internal_coord.get_length("CA:C") 105 | lengths[i, 3] = r.internal_coord.get_length("C:O") 106 | 107 | if i != len(list(structure.get_residues()))-1: 108 | angles[i, 2] = r.internal_coord.get_angle("psi") 109 | lengths[i, 2] = r.internal_coord.get_length("C:1N") 110 | except AttributeError: 111 | pass 112 | 113 | return angles, lengths 114 | 115 | def process_pdb_file(file): 116 | 117 | parser = PDB.PDBParser() 118 | structure = parser.get_structure('struct', file) 119 | structure.atom_to_internal_coordinates() 120 | 121 | angles = [] 122 | lengths = [] 123 | 124 | for c in structure.get_chains(): 125 | a, l = get_angles_lengths(c) 126 | angles.append(a) 127 | lengths.append(l) 128 | 129 | return angles, lengths 130 | 131 | 132 | angles = [[] for _ in range(5)] # note: 5*[[]] would alias the same inner list five times 133 | lengths = [[] for _ in range(5)] 134 | for f in tqdm(files[:10]): 135 | a, l = process_pdb_file(f) 136 | # angles += a 137 | # lengths += l 138 | print(len(a), len(l)) 139 | for i in range(len(a)): 140 | angles[i].append(a[i]) 141 | lengths[i].append(l[i]) 142 | 143 | # combined_angles = [np.concatenate(a, axis=0) for a in angles] 144 | # combined_lengths = [np.concatenate(l, axis=0) for l in lengths] 145 | # %% 146 | sns.histplot([a.shape[0] for a in angles]) 147 | 148 | # %% 149 | 150 | x=combined_angles[:, 1] 151 | y=combined_angles[:, 2] 152 | sns.scatterplot(x=x,y=y, s=3, alpha=0.3, color=".15") 153 | plt.gca().set_aspect('equal') 154 | 155 | # %% 156 | sns.histplot(combined_angles[:, 0] % 360) 157 | plt.xlim((90,270)) 158 | 159 | # %% 160 | sns.histplot(combined_lengths[:, 0], color='red') 161 | sns.histplot(combined_lengths[:, 1], color='green') 162 | sns.histplot(combined_lengths[:, 2], color='blue') 163 | sns.histplot(combined_lengths[:, 3], color='yellow') 164 | plt.xlim((np.nanmin(combined_lengths), np.nanmax(combined_lengths))) 165 | # %% 166 | 167 | for f in tqdm(files[:10]): 168 | parser = PDB.PDBParser() 169 | structure = parser.get_structure('struct', f) 170 | structure.atom_to_internal_coordinates() 171 | 172 | print(len(list(structure.get_chains()))) 173 | # %% 174 | 
-------------------------------------------------------------------------------- /scripts/test_sabdab.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | # %% 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | # %% 12 | from Bio import PDB 13 | from Bio.PDB import internal_coords 14 | import nglview as nv 15 | # ic_structure = PDB.internal_coords.IC_Chain(structure) 16 | # ic_structure.internal_to_atom_coordinates() 17 | # %% 18 | parser = PDB.PDBParser() 19 | structure = parser.get_structure('test_struct', "/data/localhost/not-backed-up/mhutchin/score-sde/data/sabdab/all_structures/raw/1a0q.pdb") 20 | structure.atom_to_internal_coordinates() 21 | nv.show_biopython(structure) 22 | 23 | # %% 24 | def get_angles_lengths(chain): 25 | 26 | angles = np.empty((len(list(chain.get_residues())), 3)) 27 | lengths = np.empty((len(list(chain.get_residues())), 4)) 28 | angles[:] = np.nan 29 | lengths[:] = np.nan 30 | 31 | for i, r in enumerate(chain.get_residues()): 32 | try: 33 | if i != 0: 34 | angles[i-1, 0] = r.internal_coord.get_angle("omega") 35 | angles[i-1, 1] = r.internal_coord.get_angle("phi") 36 | 37 | lengths[i, 0] = r.internal_coord.get_length("N:CA") 38 | lengths[i, 1] = r.internal_coord.get_length("CA:C") 39 | lengths[i, 3] = r.internal_coord.get_length("C:O") 40 | 41 | if i != len(list(structure.get_residues()))-1: 42 | angles[i, 2] = r.internal_coord.get_angle("psi") 43 | lengths[i, 2] = r.internal_coord.get_length("C:1N") 44 | except AttributeError: 45 | pass 46 | 47 | return angles, lengths 48 | 49 | def process_pdb_file(file): 50 | 51 | parser = PDB.PDBParser() 52 | structure = parser.get_structure('struct', file) 53 | structure.atom_to_internal_coordinates() 54 | 55 | angles = [] 56 | lengths = [] 57 | 58 | for c in structure.get_chains(): 59 | a, l = get_angles_lengths(c) 60 | angles.append(a) 61 | lengths.append(l) 62 | 63 | return angles, lengths 64 | 65 | 66 | 67 | import glob 68 | datadir = "/data/localhost/not-backed-up/mhutchin/score-sde/data/sabdab/all_structures/imgt" 69 | files = glob.glob(datadir + '/*.pdb') 70 | 71 | angles = [] 72 | lengths = [] 73 | for f in tqdm(files[:10]): 74 | a, l = process_pdb_file(f) 75 | angles += a 76 | lengths += l 77 | 78 | combined_angles = np.concatenate(angles, axis=0) 79 | combined_lengths = np.concatenate(lengths, axis=0) 80 | # %% 81 | sns.histplot([a.shape[0] for a in angles]) 82 | 83 | # %% 84 | 85 | x=combined_angles[:, 1] 86 | y=combined_angles[:, 2] 87 | sns.scatterplot(x=x,y=y, s=3, alpha=0.3, color=".15") 88 | plt.gca().set_aspect('equal') 89 | 90 | # %% 91 | sns.histplot(combined_angles[:, 0] % 360) 92 | plt.xlim((90,270)) 93 | 94 | # %% 95 | sns.histplot(combined_lengths[:, 0], color='red') 96 | sns.histplot(combined_lengths[:, 1], color='green') 97 | sns.histplot(combined_lengths[:, 2], color='blue') 98 | sns.histplot(combined_lengths[:, 3], color='yellow') 99 | plt.xlim((np.nanmin(combined_lengths), np.nanmax(combined_lengths))) 100 | # %% 101 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from setuptools import find_namespace_packages, setup 5 | 6 | with open("requirements.txt", "r") as file: 7 | requirements = [line.strip() for line in file] 8 | 9 | with open("README.md", "r") as file: 10 | long_description = file.read() 11 | 12 | with open("VERSION", "r") as file: 13 | version = file.read().strip() 14 | 15 | setup( 16 | name="score-sde", 17 | version=version, 18 | author="Michael Hutchinson, Emile Mathieu", 19 | author_email="michael.hutchinson@stats.ox.ac.uk", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | description="", 23 | # license="Apache License 2.0", 24 | keywords="", 25 | # 
install_requires=requirements, 26 | packages=find_namespace_packages( 27 | include=["score_sde", "riemannian_score_sde", "stochastic_process_score_sde"] # underscores: hyphens are not valid Python package names 28 | ), 29 | classifiers=[ 30 | # "License :: OSI Approved :: Apache Software License", 31 | # "Programming Language :: Python :: 3.6", 32 | # "Programming Language :: Python :: 3.7", 33 | "Programming Language :: Python :: 3.8", 34 | "Operating System :: OS Independent", 35 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 36 | ], 37 | ) 38 | -------------------------------------------------------------------------------- /tests/test_haiku.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | 3 | class Pass(hk.Module): 4 | def __call__(self, x): 5 | return x 6 | 7 | print(hk.transform(Pass)) -------------------------------------------------------------------------------- /tests/test_likelihood.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["GEOMSTATS_BACKEND"] = "jax" 4 | from functools import partial 5 | import hydra 6 | from hydra.utils import instantiate, get_class 7 | 8 | import jax 9 | from jax import numpy as jnp 10 | import haiku as hk 11 | 12 | from riemannian_score_sde.utils.normalization import compute_normalization 13 | 14 | 15 | @hydra.main(config_path="../config", config_name="main") 16 | def main(cfg): 17 | 18 | data_manifold = instantiate(cfg.manifold) 19 | transform = instantiate(cfg.transform, data_manifold) 20 | model_manifold = transform.domain 21 | flow = instantiate(cfg.flow, manifold=model_manifold) 22 | base = instantiate(cfg.base, model_manifold, flow) 23 | pushforward = instantiate(cfg.pushf, flow, base, transform=transform) 24 | 25 | rng = jax.random.PRNGKey(cfg.seed) 26 | rng, next_rng = jax.random.split(rng) 27 | dataset = instantiate(cfg.dataset, rng=next_rng, manifold=data_manifold) 28 | y = transform.inv(next(dataset)[0]) 29 | 30 | def score_model(y, t, context=None): 31 | output_shape = get_class(cfg.generator._target_).output_shape(model_manifold) 32 | score = instantiate( 33 | cfg.generator, 34 | cfg.architecture, 35 | cfg.embedding, 36 | output_shape, 37 | manifold=model_manifold, 38 | ) 39 | return score(y, t) 40 | 41 | score_model = hk.transform_with_state(score_model) 42 | 43 | rng, next_rng = jax.random.split(rng) 44 | params, state = score_model.init(rng=next_rng, y=y, t=jnp.zeros((y.shape[0], 1))) 45 | 46 | model_w_dicts = (score_model, params, state) 47 | likelihood_fn = pushforward.get_log_prob(model_w_dicts, train=False) 48 | 49 | Z = compute_normalization(likelihood_fn, data_manifold, N=200) 50 | print(f"Z = {Z:.2f}") 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /tests/test_likelihood_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["GEOMSTATS_BACKEND"] = "jax" 4 | import jax 5 | import jax.numpy as jnp 6 | from riemannian_score_sde.models.transform import * 7 | 8 | rng = jax.random.PRNGKey(0) 9 | 10 | # SO(3) & exponential map 11 | print("SO(3)") 12 | from geomstats.geometry.special_orthogonal import SpecialOrthogonal 13 | from riemannian_score_sde.utils.normalization import compute_normalization 14 | from score_sde.models import get_likelihood_fn_w_transform 15 | import numpy as np 16 | 17 | K = 4 18 | manifold = SpecialOrthogonal(n=3, point_type="matrix") 19 | rng, next_rng = 
jax.random.split(rng) 20 | base_point = manifold.identity 21 | # base_point = manifold.random_uniform(next_rng, n_samples=1) 22 | transform = TanhExpMap(manifold) 23 | 24 | 25 | class NormalDistribution: 26 | def __init__(self, mean=0.0, scale=1.0): 27 | super().__init__() 28 | self.mean = mean 29 | self.scale = scale 30 | 31 | def sample(self, rng, shape): 32 | return self.mean + self.scale * jax.random.normal(rng, shape) 33 | 34 | def log_prob(self, z): 35 | shape = z.shape 36 | d = np.prod(shape[1:]) 37 | logp_fn = ( 38 | lambda z: -d / 2.0 * jnp.log(2 * np.pi) 39 | - d * jnp.log(self.scale) 40 | - jnp.sum(((z - self.mean) / self.scale) ** 2) / 2.0 41 | ) 42 | return jax.vmap(logp_fn)(z) 43 | 44 | 45 | flow = transform 46 | base = NormalDistribution(mean=0.0, scale=1.0) 47 | likelihood_fn = lambda y, **kwargs: (base.log_prob(y), 0) 48 | 49 | # Normalization constant of base distribition in R^3 50 | N = 100 51 | d = 3 52 | bound = 5 53 | xs = jnp.linspace(-bound, bound, N) 54 | xs = d * [xs] 55 | xs = jnp.meshgrid(*xs) 56 | xs = jnp.concatenate([x.reshape(-1, 1) for x in xs], axis=-1) 57 | logp = likelihood_fn(xs)[0] 58 | prob = jnp.exp(logp) 59 | Z = (prob).mean() * ((2 * bound) ** d) 60 | print(f"Z = {Z:.2f}") 61 | 62 | # Normalization constant of pushforward in SO(3) 63 | likelihood_fn = get_likelihood_fn_w_transform(likelihood_fn, transform) 64 | Z = compute_normalization(likelihood_fn, manifold) 65 | print(f"Z = {Z:.2f}") 66 | -------------------------------------------------------------------------------- /tests/test_registry.py: -------------------------------------------------------------------------------- 1 | # %% 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | # %% 6 | from score_sde.utils.registry import register_category, _REGISTRY 7 | import functools 8 | # %% 9 | get_test, register_test = register_category("test") 10 | 11 | # %% 12 | register_test("string", name="test_string") 13 | 14 | @register_test 15 | def test_func(): 16 | return 0 17 | 18 | @register_test 19 | class test_class: 20 | pass 21 | 22 | # %% 23 | -------------------------------------------------------------------------------- /tests/test_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["GEOMSTATS_BACKEND"] = "jax" 4 | from functools import partial 5 | from geomstats.geometry.hypersphere import Hypersphere 6 | import jax 7 | import jax.numpy as jnp 8 | import geomstats.backend as gs 9 | from geomstats.geometry.special_orthogonal import SpecialOrthogonal 10 | from riemannian_score_sde.models.transform import * 11 | from riemannian_score_sde.datasets import Wrapped 12 | 13 | sq_log_det = lambda jac: jnp.log(jnp.abs(jnp.linalg.det(jac))) 14 | rec_log_det = lambda jac: jnp.log( 15 | jnp.sqrt(jnp.linalg.det(jac.transpose((0, 2, 1)) @ jac)) 16 | ) 17 | 18 | rng = jax.random.PRNGKey(0) 19 | 20 | ## Inverse of stereographic map: R^d -> S^d 21 | print("S^2") 22 | 23 | dim = 2 24 | K = 4 25 | manifold = Hypersphere(dim) 26 | transform = InvStereographic(manifold) 27 | rng, next_rng = jax.random.split(rng) 28 | # rng, x = manifold.random_normal_tangent(state=rng, base_point=transform.base_point, n_samples=K) 29 | x = jax.random.normal(next_rng, shape=[K, transform.domain.dim]) 30 | y = transform(x) 31 | x_prime = transform.inv(y) 32 | assert jnp.isclose(x, x_prime).all() 33 | jac = jax.vmap(jax.jacrev(transform))(x) 34 | logdet_numerical = rec_log_det(jac) 35 | logdet = transform.log_abs_det_jacobian(x, y) 36 | assert jnp.isclose(logdet, 
logdet_numerical).all() 37 | 38 | # Exponential map 39 | base_point = jnp.array([0.0, 1.0, 0.0]) 40 | transform = ExpMap(manifold, base_point) 41 | rng, next_rng = jax.random.split(rng) 42 | rng, x = manifold.random_normal_tangent( 43 | state=rng, base_point=transform.base_point, n_samples=K 44 | ) 45 | y = transform(x) 46 | x_prime = transform.inv(y) 47 | assert jnp.isclose(x, x_prime).all() 48 | 49 | ## Radial tanh 50 | print("tanh") 51 | # transform = RadialTanhTransform(1., manifold.dim) 52 | transform = RadialTanhTransform(manifold.injectivity_radius, manifold.dim) 53 | rng, next_rng = jax.random.split(rng) 54 | x = jax.random.normal(next_rng, shape=[K, 1]) 55 | x = 2 * x 56 | y = transform(x) 57 | x_prime = transform.inv(y) 58 | assert jnp.isclose(x, x_prime).all() 59 | jac = jax.vmap(jax.jacrev(transform))(x) 60 | logdet_numerical = sq_log_det(jac) 61 | logdet = transform.log_abs_det_jacobian(x, y) 62 | assert jnp.isclose(logdet, logdet_numerical).all() 63 | 64 | # SO(3) & exponential map 65 | print("SO(3)") 66 | 67 | K = 4 68 | # K = 10 69 | manifold = SpecialOrthogonal(n=3, point_type="matrix") 70 | rng, next_rng = jax.random.split(rng) 71 | 72 | dataset = Wrapped( 73 | K=1, 74 | scale=0.1, 75 | scale_type="random", 76 | mean="anti", 77 | conditional=False, 78 | batch_dims=(K,), 79 | manifold=manifold, 80 | seed=0, 81 | ) 82 | P, _ = next(dataset) 83 | A = gs.linalg.logm(P) 84 | A_prime = manifold.log_from_identity(P) 85 | assert jnp.isclose(A, A_prime).all() 86 | 87 | base_point = manifold.identity 88 | base_point = manifold.random_uniform(next_rng, n_samples=1) 89 | radius = manifold.injectivity_radius 90 | transform = TanhExpMap(manifold, radius=radius) 91 | # transform = RadialTanhTransform(manifold.injectivity_radius, manifold.dim) 92 | rng, next_rng = jax.random.split(rng) 93 | _, X = manifold.random_normal_tangent( 94 | state=rng, base_point=transform.base_point, n_samples=K 95 | ) 96 | 97 | # testing that exp and log map works 98 | assert jnp.isclose(X, manifold.hat(manifold.vee(X))).all() 99 | P = gs.linalg.expm(X) 100 | P_prime = manifold.exp_from_identity(X) 101 | assert jnp.isclose(P, P_prime).all() 102 | A = gs.linalg.logm(P) 103 | A_prime = manifold.log_from_identity(P_prime) 104 | assert jnp.isclose(A, A_prime).all() 105 | assert jnp.isclose(A, X).all() 106 | assert jnp.isclose( 107 | manifold.metric.squared_norm(manifold.lie_algebra.basis_normed[..., 0]), 1.0 108 | ) 109 | 110 | delta_ij = manifold.metric.inner_product( 111 | manifold.lie_algebra.basis_normed[..., 0], manifold.lie_algebra.basis_normed[..., 1] 112 | ) 113 | delta_ii = manifold.metric.inner_product( 114 | manifold.lie_algebra.basis_normed[..., 0], manifold.lie_algebra.basis_normed[..., 0] 115 | ) 116 | assert jnp.isclose(delta_ij, 0.0).all() 117 | assert jnp.isclose(delta_ii, 1.0).all() 118 | 119 | x = manifold.vector_from_skew_matrix(X) 120 | 121 | Y = transform(x) 122 | x_prime = transform.inv(Y) 123 | assert jnp.isclose(x, x_prime).all() 124 | func = lambda x: transform(x) 125 | # func = lambda x: jax.scipy.linalg.expm(manifold.hat(x)) 126 | jac = jax.vmap(jax.jacrev(func))(x) 127 | inv_f_x = jnp.repeat(jnp.expand_dims(manifold.inverse(transform(x)), -1), 3, -1) 128 | jac = jax.vmap(manifold.compose, in_axes=-1, out_axes=-1)(inv_f_x, jac) 129 | is_tangent = jax.vmap(partial(manifold.is_tangent, atol=1e-2), in_axes=-1, out_axes=-1)( 130 | jac 131 | ) 132 | assert is_tangent.all() 133 | jac = jax.vmap(manifold.vee, in_axes=-1, out_axes=-1)(jac) 134 | logdet_numerical = sq_log_det(jac) 135 | logdet = 
transform.log_abs_det_jacobian(x, Y) 136 | assert jnp.isclose(logdet, logdet_numerical).all() 137 | -------------------------------------------------------------------------------- /tests/test_vmf.py: -------------------------------------------------------------------------------- 1 | import math 2 | from jax import numpy as jnp 3 | 4 | from riemannian_score_sde.utils.normalization import compute_normalization 5 | from riemannian_score_sde.datasets import vMFDataset 6 | from geomstats.geometry.hypersphere import Hypersphere 7 | 8 | 9 | class PowerSpherical: 10 | def __init__(self, loc, scale): 11 | self.loc = loc 12 | self.scale = scale 13 | 14 | def log_prob(self, value): 15 | return self.log_normalizer() + self.scale * jnp.log1p((self.loc * value).sum(-1)) 16 | 17 | def log_normalizer(self): 18 | dim = 2 + 1 19 | alpha = (dim - 1) / 2 + self.scale 20 | beta = (dim - 1) / 2 21 | return -( 22 | (alpha + beta) * math.log(2) 23 | + math.lgamma(alpha) 24 | - math.lgamma(alpha + beta) 25 | + beta * math.log(math.pi) 26 | ) 27 | 28 | 29 | print("S^2") 30 | manifold = Hypersphere(2) 31 | mu = [1.0, 0.0, 0.0] 32 | kappa = 100 33 | distribution = vMFDataset(None, None, manifold, mu, kappa) 34 | # distribution = PowerSpherical(mu, kappa) 35 | 36 | likelihood_fn = lambda y, *args, **kwargs: distribution.log_prob(y) 37 | 38 | Z = compute_normalization(likelihood_fn, manifold, N=1000) 39 | print(f"Z = {Z:.2f}") 40 | --------------------------------------------------------------------------------
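The vMF test above (like test_likelihood.py earlier) reduces to checking that a log-density integrates to one over the manifold. As a self-contained sanity check of the same kind, independent of compute_normalization (whose implementation lives elsewhere in the repo), the normalising constant of a density on S^2 can be estimated with a latitude/longitude quadrature; the function name and grid sizes below are illustrative:

import numpy as np

def sphere_normalization(log_prob_fn, n_phi=200, n_theta=400):
    # latitude/longitude grid on S^2, integrated with the sin(phi) area element
    phi = np.linspace(0.0, np.pi, n_phi)          # colatitude
    theta = np.linspace(0.0, 2 * np.pi, n_theta)  # longitude
    phi_g, theta_g = np.meshgrid(phi, theta, indexing="ij")
    xs = np.stack(
        [np.sin(phi_g) * np.cos(theta_g),
         np.sin(phi_g) * np.sin(theta_g),
         np.cos(phi_g)],
        axis=-1,
    ).reshape(-1, 3)
    probs = np.exp(log_prob_fn(xs)).reshape(n_phi, n_theta)
    # dA = sin(phi) dphi dtheta
    return np.trapz(np.trapz(probs * np.sin(phi_g), theta, axis=1), phi)

# The uniform density 1 / (4 pi) should give Z ~= 1:
# sphere_normalization(lambda x: np.full(x.shape[0], -np.log(4 * np.pi)))

Passing distribution.log_prob from the vMF test above should likewise return a value close to 1 when the density is well normalised.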