├── .coveragerc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── coverage.yml │ └── tests_bench.yml ├── .gitignore ├── .readthedocs.yml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── basic.css │ ├── _templates │ ├── _variables.scss │ └── layout.html │ ├── conf.py │ ├── imgs │ ├── CNNs.jpeg │ ├── Case_study_1.jpg │ ├── DA_diagram(1).pdf │ ├── DA_diagram.png │ ├── baseline_results.png │ ├── logo_pyraug_2.jpeg │ ├── nine_digits-rot.pdf │ ├── nine_digits-rot.png │ ├── nine_digits.pdf │ ├── nine_digits.png │ ├── optimized_results.png │ ├── pyraug_diagram.jpg │ └── pyraug_diagram_simplified.jpg │ ├── index.rst │ ├── models │ ├── autoencoders │ │ ├── aae.rst │ │ ├── ae.rst │ │ ├── auto_model.rst │ │ ├── baseAE.rst │ │ ├── betatcvae.rst │ │ ├── betavae.rst │ │ ├── ciwae.rst │ │ ├── disentangled_betavae.rst │ │ ├── factorvae.rst │ │ ├── hvae.rst │ │ ├── infovae.rst │ │ ├── iwae.rst │ │ ├── miwae.rst │ │ ├── models.rst │ │ ├── msssimvae.rst │ │ ├── piwae.rst │ │ ├── pvae.rst │ │ ├── rae_gp.rst │ │ ├── rae_l2.rst │ │ ├── rhvae.rst │ │ ├── svae.rst │ │ ├── vae.rst │ │ ├── vae_iaf.rst │ │ ├── vae_lin_nf.rst │ │ ├── vaegan.rst │ │ ├── vamp.rst │ │ ├── vqvae.rst │ │ └── wae.rst │ ├── nn │ │ ├── celeba │ │ │ ├── convnets.rst │ │ │ ├── pythae_benchmarks_nn_celeba.rst │ │ │ └── resnets.rst │ │ ├── cifar │ │ │ ├── convnets.rst │ │ │ ├── pythae_benchmarks_nn_cifar.rst │ │ │ └── resnets.rst │ │ ├── mnist │ │ │ ├── convnets.rst │ │ │ ├── pythae_benchmarks_nn_mnist.rst │ │ │ └── resnets.rst │ │ ├── nn.rst │ │ ├── pythae_base_nn.rst │ │ └── pythae_benchmarks_nn.rst │ ├── normalizing_flows │ │ ├── basenf.rst │ │ ├── iaf.rst │ │ ├── made.rst │ │ ├── maf.rst │ │ ├── normalizing_flows.rst │ │ ├── pixelcnn.rst │ │ ├── planar_flow.rst │ │ └── radial_flow.rst │ └── pythae.models.rst │ ├── pipelines │ ├── generation.rst │ ├── pythae.pipelines.rst │ └── training.rst │ ├── references.bib │ ├── samplers │ ├── basesampler.rst │ ├── gmm_sampler.rst │ ├── iaf_sampler.rst │ ├── maf_sampler.rst │ ├── normal_sampler.rst │ ├── pixelcnn_sampler.rst │ ├── poincare_disk_sampler.rst │ ├── pythae.samplers.rst │ ├── rhvae_sampler.rst │ ├── twostage_sampler.rst │ ├── unit_sphere_unif_sampler.rst │ └── vamp_sampler.rst │ └── trainers │ ├── pythae.trainer.rst │ ├── pythae.trainers.adversarial_trainer.rst │ ├── pythae.trainers.base_trainer.rst │ ├── pythae.trainers.coupled_optimizer_adversarial_trainer.rst │ ├── pythae.trainers.coupled_trainer.rst │ └── pythae.training_callbacks.rst ├── examples ├── notebooks │ ├── README.md │ ├── comet_experiment_monitoring.ipynb │ ├── custom_dataset.ipynb │ ├── hf_hub_models_sharing.ipynb │ ├── making_your_own_autoencoder.ipynb │ ├── mlflow_experiment_monitoring.ipynb │ ├── models_training │ │ ├── adversarial_ae_training.ipynb │ │ ├── ae_training.ipynb │ │ ├── beta_tc_vae_training.ipynb │ │ ├── beta_vae_training.ipynb │ │ ├── ciwae_training.ipynb │ │ ├── disentangled_beta_vae_training.ipynb │ │ ├── factor_vae_training.ipynb │ │ ├── hrqvae_training.ipynb │ │ ├── hvae_training.ipynb │ │ ├── info_vae_training.ipynb │ │ ├── iwae_training.ipynb │ │ ├── miwae_training.ipynb │ │ ├── ms_ssim_vae_training.ipynb │ │ ├── normalizing_flows_training.ipynb │ │ ├── piwae_training.ipynb │ │ ├── pvae_training.ipynb │ │ ├── rae_gp_training.ipynb │ │ ├── rae_l2_training.ipynb │ │ ├── rhvae_training.ipynb │ │ ├── svae_training.ipynb 
│ │ ├── vae_iaf_training.ipynb │ │ ├── vae_lin_nf_training.ipynb │ │ ├── vae_lstm.ipynb │ │ ├── vae_training.ipynb │ │ ├── vaegan_training.ipynb │ │ ├── vamp_training.ipynb │ │ ├── vqvae_training.ipynb │ │ └── wae_training.ipynb │ ├── overview_notebook.ipynb │ ├── requirements.txt │ └── wandb_experiment_monitoring.ipynb ├── scripts │ ├── README.md │ ├── configs │ │ ├── binary_mnist │ │ │ ├── base_training_config.json │ │ │ ├── hvae_config.json │ │ │ ├── iwae_config.json │ │ │ ├── rhvae_config.json │ │ │ ├── svae_config.json │ │ │ ├── vae_config.json │ │ │ ├── vae_iaf_config.json │ │ │ ├── vae_lin_nf_config.json │ │ │ └── vamp_config.json │ │ ├── celeba │ │ │ ├── aae_config.json │ │ │ ├── ae_config.json │ │ │ ├── base_training_config.json │ │ │ ├── beta_tc_vae_config.json │ │ │ ├── beta_vae_config.json │ │ │ ├── disentangled_beta_vae_config.json │ │ │ ├── factor_vae_config.json │ │ │ ├── hvae_config.json │ │ │ ├── info_vae_config.json │ │ │ ├── msssim_vae_config.json │ │ │ ├── rae_gp_config.json │ │ │ ├── rae_l2_config.json │ │ │ ├── rhvae_config.json │ │ │ ├── svae_config.json │ │ │ ├── vae_config.json │ │ │ ├── vae_iaf_config.json │ │ │ ├── vae_lin_nf_config.json │ │ │ ├── vaegan_config.json │ │ │ ├── vamp_config.json │ │ │ ├── vqvae_config.json │ │ │ └── wae_config.json │ │ ├── cifar10 │ │ │ ├── ae_config.json │ │ │ ├── base_training_config.json │ │ │ ├── beta_vae_config.json │ │ │ ├── hvae_config.json │ │ │ ├── info_vae_config.json │ │ │ ├── rae_gp_config.json │ │ │ ├── rae_l2_config.json │ │ │ ├── rhvae_config.json │ │ │ ├── vae_config.json │ │ │ ├── vamp_config.json │ │ │ ├── vqvae_config.json │ │ │ └── wae_config.json │ │ ├── dsprites │ │ │ ├── beta_tc_vae │ │ │ │ ├── base_training_config.json │ │ │ │ └── beta_tc_vae_config.json │ │ │ └── factorvae │ │ │ │ ├── base_training_config.json │ │ │ │ └── factorvae_config.json │ │ └── mnist │ │ │ ├── aae_config.json │ │ │ ├── ae_config.json │ │ │ ├── base_training_config.json │ │ │ ├── beta_tc_vae_config.json │ │ │ ├── beta_vae_config.json │ │ │ ├── disentangled_beta_vae_config.json │ │ │ ├── factor_vae_config.json │ │ │ ├── hvae_config.json │ │ │ ├── info_vae_config.json │ │ │ ├── iwae_config.json │ │ │ ├── msssim_vae_config.json │ │ │ ├── rae_gp_config.json │ │ │ ├── rae_l2_config.json │ │ │ ├── rhvae_config.json │ │ │ ├── svae_config.json │ │ │ ├── vae_config.json │ │ │ ├── vae_iaf_config.json │ │ │ ├── vae_lin_nf_config.json │ │ │ ├── vaegan_config.json │ │ │ ├── vamp_config.json │ │ │ ├── vqvae_config.json │ │ │ └── wae_config.json │ ├── custom_nn.py │ ├── data-download.py │ ├── distributed_training_ffhq.py │ ├── distributed_training_imagenet.py │ ├── distributed_training_mnist.py │ ├── reproducibility │ │ ├── README.md │ │ ├── aae.py │ │ ├── betatcvae.py │ │ ├── ciwae.py │ │ ├── hvae.py │ │ ├── iwae.py │ │ ├── miwae.py │ │ ├── piwae.py │ │ ├── pvae.py │ │ ├── rae_gp.py │ │ ├── rae_l2.py │ │ ├── svae.py │ │ ├── vae.py │ │ ├── vamp.py │ │ └── wae.py │ └── training.py └── showcases │ ├── aae_normal_sampling_celeba.png │ ├── aae_normal_sampling_mnist.png │ ├── aae_reconstruction_celeba.png │ ├── aae_reconstruction_mnist.png │ ├── ae_gmm_sampling_celeba.png │ ├── ae_gmm_sampling_mnist.png │ ├── ae_normal_sampling_celeba.png │ ├── ae_normal_sampling_mnist.png │ ├── ae_reconstruction_celeba.png │ ├── ae_reconstruction_mnist.png │ ├── beta_tc_vae_normal_sampling_celeba.png │ ├── beta_tc_vae_normal_sampling_mnist.png │ ├── beta_tc_vae_reconstruction_celeba.png │ ├── beta_tc_vae_reconstruction_mnist.png │ ├── beta_vae_normal_sampling_celeba.png 
│ ├── beta_vae_normal_sampling_mnist.png │ ├── beta_vae_reconstruction_celeba.png │ ├── beta_vae_reconstruction_mnist.png │ ├── disentangled_beta_vae_normal_sampling_celeba.png │ ├── disentangled_beta_vae_normal_sampling_mnist.png │ ├── disentangled_beta_vae_reconstruction_celeba.png │ ├── disentangled_beta_vae_reconstruction_mnist.png │ ├── eval_reconstruction_celeba.png │ ├── eval_reconstruction_mnist.png │ ├── factor_vae_normal_sampling_celeba.png │ ├── factor_vae_normal_sampling_mnist.png │ ├── factor_vae_reconstruction_celeba.png │ ├── factor_vae_reconstruction_mnist.png │ ├── hvae_gmm_sampling_mnist.png │ ├── hvae_normal_sampling_celeba.png │ ├── hvae_normal_sampling_mnist.png │ ├── hvae_reconstruction_celeba.png │ ├── hvae_reconstruction_mnist.png │ ├── hvvae_gmm_sampling_mnist.png │ ├── hvvae_normal_sampling_mnist.png │ ├── infovae_normal_sampling_celeba.png │ ├── infovae_normal_sampling_mnist.png │ ├── infovae_reconstruction_celeba.png │ ├── infovae_reconstruction_mnist.png │ ├── iwae_gmm_sampling_mnist.png │ ├── iwae_normal_sampling_celeba.png │ ├── iwae_normal_sampling_mnist.png │ ├── iwae_reconstruction_celeba.png │ ├── iwae_reconstruction_mnist.png │ ├── msssim_vae_normal_sampling_celeba.png │ ├── msssim_vae_normal_sampling_mnist.png │ ├── msssim_vae_reconstruction_celeba.png │ ├── msssim_vae_reconstruction_mnist.png │ ├── rae_gp_gmm_sampling_celeba.png │ ├── rae_gp_gmm_sampling_mnist.png │ ├── rae_gp_normal_sampling_celeba.png │ ├── rae_gp_normal_sampling_mnist.png │ ├── rae_gp_reconstruction_celeba.png │ ├── rae_gp_reconstruction_mnist.png │ ├── rae_l2_gmm_sampling_celeba.png │ ├── rae_l2_gmm_sampling_mnist.png │ ├── rae_l2_normal_sampling_celeba.png │ ├── rae_l2_normal_sampling_mnist.png │ ├── rae_l2_reconstruction_celeba.png │ ├── rae_l2_reconstruction_mnist.png │ ├── rhvae_reconstruction_celeba.png │ ├── rhvae_reconstruction_mnist.png │ ├── rhvae_rhvae_sampling_celeba.png │ ├── rhvae_rhvae_sampling_mnist.png │ ├── svae_hypersphere_uniform_sampling_celeba.png │ ├── svae_hypersphere_uniform_sampling_mnist.png │ ├── svae_reconstruction_celeba.png │ ├── svae_reconstruction_mnist.png │ ├── vae_gmm_sampling_celeba.png │ ├── vae_gmm_sampling_mnist.png │ ├── vae_iaf_normal_sampling_celeba.png │ ├── vae_iaf_normal_sampling_mnist.png │ ├── vae_iaf_reconstruction_celeba.png │ ├── vae_iaf_reconstruction_mnist.png │ ├── vae_lin_nf_normal_sampling_celeba.png │ ├── vae_lin_nf_normal_sampling_mnist.png │ ├── vae_lin_nf_reconstruction_celeba.png │ ├── vae_lin_nf_reconstruction_mnist.png │ ├── vae_maf_sampling_celeba.png │ ├── vae_maf_sampling_mnist.png │ ├── vae_normal_sampling_celeba.png │ ├── vae_normal_sampling_mnist.png │ ├── vae_reconstruction_celeba.png │ ├── vae_reconstruction_mnist.png │ ├── vae_second_stage_sampling_celeba.png │ ├── vae_second_stage_sampling_mnist.png │ ├── vaegan_normal_sampling_celeba.png │ ├── vaegan_normal_sampling_mnist.png │ ├── vaegan_reconstruction_celeba.png │ ├── vaegan_reconstruction_mnist.png │ ├── vamp_reconstruction_celeba.png │ ├── vamp_reconstruction_mnist.png │ ├── vamp_vamp_sampling_celeba.png │ ├── vamp_vamp_sampling_mnist.png │ ├── vqvae_maf_sampling_celeba.png │ ├── vqvae_maf_sampling_mnist.png │ ├── vqvae_pixelcnn_sampling_mnist.png │ ├── vqvae_reconstruction_celeba.png │ ├── vqvae_reconstruction_mnist.png │ ├── wae_gmm_sampling_celeba.png │ ├── wae_gmm_sampling_mnist.png │ ├── wae_normal_sampling_celeba.png │ ├── wae_normal_sampling_mnist.png │ ├── wae_reconstruction_celeba.png │ └── wae_reconstruction_mnist.png ├── pyproject.toml ├── 
requirements.txt ├── setup.cfg ├── setup.py ├── src └── pythae │ ├── __init__.py │ ├── config.py │ ├── customexception.py │ ├── data │ ├── __init__.py │ ├── datasets.py │ └── preprocessors.py │ ├── models │ ├── __init__.py │ ├── adversarial_ae │ │ ├── __init__.py │ │ ├── adversarial_ae_config.py │ │ └── adversarial_ae_model.py │ ├── ae │ │ ├── __init__.py │ │ ├── ae_config.py │ │ └── ae_model.py │ ├── auto_model │ │ ├── __init__.py │ │ ├── auto_config.py │ │ └── auto_model.py │ ├── base │ │ ├── __init__.py │ │ ├── base_config.py │ │ ├── base_model.py │ │ └── base_utils.py │ ├── beta_tc_vae │ │ ├── __init__.py │ │ ├── beta_tc_vae_config.py │ │ └── beta_tc_vae_model.py │ ├── beta_vae │ │ ├── __init__.py │ │ ├── beta_vae_config.py │ │ └── beta_vae_model.py │ ├── ciwae │ │ ├── __init__.py │ │ ├── ciwae_config.py │ │ └── ciwae_model.py │ ├── disentangled_beta_vae │ │ ├── __init__.py │ │ ├── disentangled_beta_vae_config.py │ │ └── disentangled_beta_vae_model.py │ ├── factor_vae │ │ ├── __init__.py │ │ ├── factor_vae_config.py │ │ ├── factor_vae_model.py │ │ └── factor_vae_utils.py │ ├── hrq_vae │ │ ├── __init__.py │ │ ├── hrq_vae_config.py │ │ ├── hrq_vae_model.py │ │ └── hrq_vae_utils.py │ ├── hvae │ │ ├── __init__.py │ │ ├── hvae_config.py │ │ └── hvae_model.py │ ├── info_vae │ │ ├── __init__.py │ │ ├── info_vae_config.py │ │ └── info_vae_model.py │ ├── iwae │ │ ├── __init__.py │ │ ├── iwae_config.py │ │ └── iwae_model.py │ ├── miwae │ │ ├── __init__.py │ │ ├── miwae_config.py │ │ └── miwae_model.py │ ├── msssim_vae │ │ ├── __init__.py │ │ ├── msssim_vae_config.py │ │ ├── msssim_vae_model.py │ │ └── msssim_vae_utils.py │ ├── nn │ │ ├── __init__.py │ │ ├── base_architectures.py │ │ ├── benchmarks │ │ │ ├── __init__.py │ │ │ ├── celeba │ │ │ │ ├── __init__.py │ │ │ │ ├── convnets.py │ │ │ │ └── resnets.py │ │ │ ├── cifar │ │ │ │ ├── __init__.py │ │ │ │ ├── convnets.py │ │ │ │ └── resnets.py │ │ │ ├── mnist │ │ │ │ ├── __init__.py │ │ │ │ ├── convnets.py │ │ │ │ └── resnets.py │ │ │ └── utils.py │ │ └── default_architectures.py │ ├── normalizing_flows │ │ ├── __init__.py │ │ ├── base │ │ │ ├── __init__.py │ │ │ ├── base_nf_config.py │ │ │ └── base_nf_model.py │ │ ├── iaf │ │ │ ├── __init__.py │ │ │ ├── iaf_config.py │ │ │ └── iaf_model.py │ │ ├── layers.py │ │ ├── made │ │ │ ├── __init__.py │ │ │ ├── made_config.py │ │ │ └── made_model.py │ │ ├── maf │ │ │ ├── __init__.py │ │ │ ├── maf_config.py │ │ │ └── maf_model.py │ │ ├── pixelcnn │ │ │ ├── __init__.py │ │ │ ├── pixelcnn_config.py │ │ │ ├── pixelcnn_model.py │ │ │ └── utils.py │ │ ├── planar_flow │ │ │ ├── __init__.py │ │ │ ├── planar_flow_config.py │ │ │ └── planar_flow_model.py │ │ └── radial_flow │ │ │ ├── __init__.py │ │ │ ├── radial_flow_config.py │ │ │ └── radial_flow_model.py │ ├── piwae │ │ ├── __init__.py │ │ ├── piwae_config.py │ │ └── piwae_model.py │ ├── pvae │ │ ├── __init__.py │ │ ├── pvae_config.py │ │ ├── pvae_model.py │ │ └── pvae_utils.py │ ├── rae_gp │ │ ├── __init__.py │ │ ├── rae_gp_config.py │ │ └── rae_gp_model.py │ ├── rae_l2 │ │ ├── __init__.py │ │ ├── rae_l2_config.py │ │ └── rae_l2_model.py │ ├── rhvae │ │ ├── __init__.py │ │ ├── rhvae_config.py │ │ ├── rhvae_model.py │ │ └── rhvae_utils.py │ ├── svae │ │ ├── __init__.py │ │ ├── svae_config.py │ │ ├── svae_model.py │ │ └── svae_utils.py │ ├── vae │ │ ├── __init__.py │ │ ├── vae_config.py │ │ └── vae_model.py │ ├── vae_gan │ │ ├── __init__.py │ │ ├── vae_gan_config.py │ │ └── vae_gan_model.py │ ├── vae_iaf │ │ ├── __init__.py │ │ ├── vae_iaf_config.py │ │ └── 
vae_iaf_model.py │ ├── vae_lin_nf │ │ ├── __init__.py │ │ ├── vae_lin_nf_config.py │ │ └── vae_lin_nf_model.py │ ├── vamp │ │ ├── __init__.py │ │ ├── vamp_config.py │ │ └── vamp_model.py │ ├── vq_vae │ │ ├── __init__.py │ │ ├── vq_vae_config.py │ │ ├── vq_vae_model.py │ │ └── vq_vae_utils.py │ └── wae_mmd │ │ ├── __init__.py │ │ ├── wae_mmd_config.py │ │ └── wae_mmd_model.py │ ├── pipelines │ ├── __init__.py │ ├── base_pipeline.py │ ├── generation.py │ ├── pipeline_utils.py │ └── training.py │ ├── py.typed │ ├── samplers │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── base_sampler.py │ │ └── base_sampler_config.py │ ├── gaussian_mixture │ │ ├── __init__.py │ │ ├── gaussian_mixture_config.py │ │ └── gaussian_mixture_sampler.py │ ├── hypersphere_uniform_sampler │ │ ├── __init__.py │ │ ├── hypersphere_uniform_config.py │ │ └── hypersphere_uniform_sampler.py │ ├── iaf_sampler │ │ ├── __init__.py │ │ ├── iaf_sampler.py │ │ └── iaf_sampler_config.py │ ├── maf_sampler │ │ ├── __init__.py │ │ ├── maf_sampler.py │ │ └── maf_sampler_config.py │ ├── manifold_sampler │ │ ├── __init__.py │ │ ├── rhvae_sampler.py │ │ └── rhvae_sampler_config.py │ ├── normal_sampling │ │ ├── __init__.py │ │ ├── normal_config.py │ │ └── normal_sampler.py │ ├── pixelcnn_sampler │ │ ├── __init__.py │ │ ├── pixelcnn_sampler.py │ │ └── pixelcnn_sampler_config.py │ ├── pvae_sampler │ │ ├── __init__.py │ │ ├── pvae_sampler.py │ │ └── pvae_sampler_config.py │ ├── two_stage_vae_sampler │ │ ├── __init__.py │ │ ├── two_stage_sampler.py │ │ └── two_stage_sampler_config.py │ └── vamp_sampler │ │ ├── __init__.py │ │ ├── vamp_sampler.py │ │ └── vamp_sampler_config.py │ └── trainers │ ├── __init__.py │ ├── adversarial_trainer │ ├── __init__.py │ ├── adversarial_trainer.py │ └── adversarial_trainer_config.py │ ├── base_trainer │ ├── __init__.py │ ├── base_trainer.py │ └── base_training_config.py │ ├── coupled_optimizer_adversarial_trainer │ ├── __init__.py │ ├── coupled_optimizer_adversarial_trainer.py │ └── coupled_optimizer_adversarial_trainer_config.py │ ├── coupled_optimizer_trainer │ ├── __init__.py │ ├── coupled_optimizer_trainer.py │ └── coupled_optimizer_trainer_config.py │ ├── trainer_utils.py │ └── training_callbacks.py └── tests ├── README.md ├── __init__.py ├── conftest.py ├── data ├── baseAE │ └── configs │ │ ├── corrupted_model_config.json │ │ ├── generation_config00.json │ │ ├── model_config00.json │ │ ├── not_json_file.md │ │ └── training_config00.json ├── corrupted_config │ └── model_config.json ├── custom_architectures.py ├── loading │ └── dummy_data_folder │ │ ├── example0.bmp │ │ ├── example0.jpeg │ │ ├── example0.jpg │ │ ├── example0.png │ │ └── example0_downsampled_12_12.jpg ├── mnist_clean_train_dataset_sample ├── rhvae │ └── configs │ │ ├── model_config00.json │ │ └── trained_model_folder │ │ ├── model.pt │ │ └── model_config.json ├── unnormalized_mnist_data_array └── unnormalized_mnist_data_list_of_array ├── pytest.ini ├── test_AE.py ├── test_Adversarial_AE.py ├── test_BetaTCVAE.py ├── test_BetaVAE.py ├── test_CIWAE.py ├── test_DisentangledBetaVAE.py ├── test_FactorVAE.py ├── test_HRQVAE.py ├── test_HVAE.py ├── test_IAF.py ├── test_IWAE.py ├── test_MADE.py ├── test_MAF.py ├── test_MIWAE.py ├── test_MSSSIMVAE.py ├── test_PIWAE.py ├── test_PixelCNN.py ├── test_PoincareVAE.py ├── test_RHVAE.py ├── test_SVAE.py ├── test_VAE.py ├── test_VAEGAN.py ├── test_VAE_IAF.py ├── test_VAE_LinFlow.py ├── test_VAMP.py ├── test_VQVAE.py ├── test_WAE_MMD.py ├── test_adversarial_trainer.py ├── test_amp.py ├── 
test_auto_model.py ├── test_baseAE.py ├── test_baseSampler.py ├── test_base_trainer.py ├── test_config.py ├── test_coupled_optimizers_adversarial_trainer.py ├── test_coupled_optimizers_trainer.py ├── test_datasets.py ├── test_gaussian_mixture_sampler.py ├── test_hypersphere_uniform_sampler.py ├── test_iaf_sampler.py ├── test_info_vae_mmd.py ├── test_maf_sampler.py ├── test_nn_benchmark.py ├── test_normal_sampler.py ├── test_pipeline_standalone.py ├── test_pixelcnn_sampler.py ├── test_planar_flow.py ├── test_preprocessing.py ├── test_pvae_sampler.py ├── test_radial_flow.py ├── test_rae_gp.py ├── test_rae_l2.py ├── test_rhvae_sampler.py ├── test_training_callbacks.py ├── test_two_stage_sampler.py ├── test_vamp_sampler.py └── your_file.jpeg /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | **If possible please also provide code snippets or errors** 20 | 21 | 22 | **Expected behavior** 23 | A clear and concise description of what you expected to happen. 24 | 25 | **Screenshots** 26 | If applicable, add screenshots to help explain your problem. 27 | 28 | 29 | 30 | 31 | **Desktop (please complete the following information):** 32 | - OS: [e.g. iOS] 33 | - Browser [e.g. chrome, safari] 34 | - Version [e.g. 22] 35 | 36 | **Smartphone (please complete the following information):** 37 | - Device: [e.g. iPhone6] 38 | - OS: [e.g. iOS8.1] 39 | - Browser [e.g. stock browser, safari] 40 | - Version [e.g. 22] 41 | 42 | **Additional context** 43 | Add any other context about the problem here. 44 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Codecov coverage computation 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | run: 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest] 12 | env: 13 | OS: ${{ matrix.os }} 14 | PYTHON: '3.8' 15 | steps: 16 | - uses: actions/checkout@main 17 | - name: Setup Python 18 | uses: actions/setup-python@main 19 | with: 20 | python-version: 3.8 21 | - name: Generate coverage report 22 | run: | 23 | pip install . 24 | pip install pytest 25 | pip install pytest-cov 26 | pytest --cov=pythae ./ --cov-report=xml --runslow 27 | - name: Upload coverage to Codecov 28 | uses: codecov/codecov-action@v2 29 | with: 30 | token: ${{ secrets.CODECOV_TOKEN }} 31 | directory: ./coverage/reports/ 32 | env_vars: OS,PYTHON 33 | fail_ci_if_error: true 34 | files: ./coverage.xml 35 | flags: unittests 36 | name: codecov-umbrella 37 | path_to_write_report: ./coverage/codecov_report.txt 38 | verbose: true 39 | -------------------------------------------------------------------------------- /.github/workflows/tests_bench.yml: -------------------------------------------------------------------------------- 1 | name: Running tests (avoid slow) 2 | on: 3 | push: 4 | jobs: 5 | run: 6 | runs-on: ${{ matrix.os }} 7 | strategy: 8 | matrix: 9 | os: [ubuntu-latest, macos-latest, windows-latest] 10 | env: 11 | OS: ${{ matrix.os }} 12 | PYTHON: '3.8' 13 | steps: 14 | - uses: actions/checkout@main 15 | - name: Setup Python 16 | uses: actions/setup-python@main 17 | with: 18 | python-version: 3.8 19 | - name: Run tests 20 | run: | 21 | pip install . 22 | pip install pytest 23 | pip install pytest-cov 24 | pytest . 25 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # File: .readthedocs.yaml 2 | 3 | version: 2 4 | 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "3.9" 9 | 10 | # Build from the docs/ directory with Sphinx 11 | sphinx: 12 | configuration: docs/source/conf.py 13 | 14 | # Explicitly set the version of Python and its requirements 15 | python: 16 | install: 17 | - requirements: docs/requirements.txt 18 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | date-released: 2022-06 3 | message: "If you use this software, please cite it as below." 4 | title: "Pythae: Unifying Generative Autoencoders in Python -- A Benchmarking Use Case" 5 | url: "https://github.com/clementchadebec/benchmark_VAE" 6 | authors: 7 | - family-names: Chadebec 8 | given-names: Clément 9 | - family-names: Vincent 10 | given-names: Louis J. 11 | - family-names: Allassonnière 12 | given-names: Stéphanie 13 | preferred-citation: 14 | type: conference-paper 15 | title: "Pythae: Unifying Generative Autoencoders in Python -- A Benchmarking Use Case" 16 | authors: 17 | - family-names: Chadebec 18 | given-names: Clément 19 | - family-names: Vincent 20 | given-names: Louis J. 21 | - family-names: Allassonnière 22 | given-names: Stéphanie 23 | collection-title: Advances in Neural Information Processing Systems 35 24 | collection-type: proceedings 25 | editors: 26 | - family-names: Koyejo 27 | given-names: S. 
28 | - family-names: Mohamed 29 | given-names: S. 30 | - family-names: Agarwal 31 | given-names: A. 32 | - family-names: Belgrave 33 | given-names: D. 34 | - family-names: Cho 35 | given-names: K. 36 | - family-names: Oh 37 | given-names: A. 38 | start: 21575 39 | end: 21589 40 | year: 2022 41 | publisher: 42 | name: Curran Associates, Inc. 43 | url: "https://arxiv.org/abs/2206.08309" 44 | address: "Online" -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Declare all files and folders to include/exclude 2 | 3 | include README.md 4 | include src/pythae/py.typed # marker file for PEP 561 5 | 6 | include LICENSE 7 | # Include all citation (*.cff) files 8 | include *.cff # citation info 9 | 10 | include MANIFEST.in 11 | include pyproject.toml 12 | include setup.py 13 | include setup.cfg 14 | 15 | # include requirements.txt 16 | 17 | recursive-exclude tests * 18 | recursive-exclude docs * 19 | recursive-exclude examples * 20 | 21 | recursive-exclude .github * 22 | recursive-exclude dist * 23 | 24 | exclude .gitignore 25 | exclude .readthedocs.yml 26 | exclude codecov.yml 27 | 28 | exclude CONTRIBUTING.md 29 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # Commits pushed to main should not make the overall 6 | # project coverage decrease by more than 1%: 7 | target: 0% 8 | threshold: 1% 9 | patch: 10 | default: 11 | enabled: no 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | **The latest version of the documentation is available on [readthedocs](https://pythae.readthedocs.io/en/latest/)** 2 | 3 | To generate the documentation locally, do the following: 4 | 5 | - Make sure you installed the requirements at the root of this repo 6 | ``` 7 | pip install -r ../requirements.txt 8 | ``` 9 | 10 | - Install [sphinx](https://www.sphinx-doc.org/en/master/usage/quickstart.html) with pip and dependencies 11 | 12 | ``` 13 | pip install sphinx 14 | pip install sphinx_rtd_theme 15 | pip install sphinxcontrib-bibtex 16 | ``` 17 | 18 | - Then build the documentation (alternatively, run `make html` from the `docs` folder, which uses the provided Makefile): 19 | 20 | ``` 21 | sphinx-build -b html source build 22 | ``` 23 | 24 | **note** you can safely ignore the warnings in the logs 25 | 26 | Then go to the `build` folder and open the `index.html` file (in a browser for example) 27 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | 	set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | 	echo. 18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | 	echo.installed, then set the SPHINXBUILD environment variable to point 20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | 	echo.may add the Sphinx directory to PATH. 22 | 	echo. 23 | 	echo.If you don't have Sphinx installed, grab it from 24 | 	echo.http://sphinx-doc.org/ 25 | 	exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==4.1.2 2 | sphinx-rtd-theme==0.5.2 3 | sphinxcontrib-applehelp==1.0.2 4 | sphinxcontrib-bibtex==2.3.0 5 | sphinxcontrib-devhelp==1.0.2 6 | sphinxcontrib-htmlhelp==2.0.0 7 | sphinxcontrib-jsmath==1.0.1 8 | sphinxcontrib-qthelp==1.0.3 9 | sphinxcontrib-serializinghtml==1.1.5 10 | -e . 11 | -------------------------------------------------------------------------------- /docs/source/_templates/_variables.scss: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * Variables used throughout the theme. 3 | * To adjust anything, simply edit the variables below and rebuild the theme.
4 | ******************************************************************************/ 5 | 6 | 7 | // Colors 8 | $red-color: #FF3636 !default; 9 | $red-color-dark: #B71C1C !default; 10 | $orange-color: #F29105 !default; 11 | $blue-color: #0076df !default; 12 | $blue-color-dark: #00369f !default; 13 | $cyan-color: #2698BA !default; 14 | $light-cyan-color: lighten($cyan-color, 25%); 15 | $green-color: #00ab37 !default; 16 | $green-color-lime: #B7D12A !default; 17 | $green-color-dark: #009f06 !default; 18 | $green-color-light: #ddffdd !default; 19 | $green-color-bright: #11D68B !default; 20 | $purple-color: #B509AC !default; 21 | $light-purple-color: lighten($purple-color, 25%); 22 | $pink-color: #f92080 !default; 23 | $pink-color-light: #ffdddd !default; 24 | $yellow-color: #efcc00 !default; 25 | 26 | $grey-color: #828282 !default; 27 | $grey-color-light: lighten($grey-color, 40%); 28 | $grey-color-dark: darken($grey-color, 25%); 29 | 30 | $white-color: #ffffff !default; 31 | $black-color: #000000 !default; 32 | 33 | 34 | // Theme colors 35 | 36 | $code-bg-color-light: rgba($purple-color, 0.05); 37 | $code-bg-color-dark: #2c3237 !default; 38 | 39 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block footer %} {{ super() }} 3 | 4 | 7 | {% endblock %} -------------------------------------------------------------------------------- /docs/source/imgs/CNNs.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/CNNs.jpeg -------------------------------------------------------------------------------- /docs/source/imgs/Case_study_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/Case_study_1.jpg -------------------------------------------------------------------------------- /docs/source/imgs/DA_diagram(1).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/DA_diagram(1).pdf -------------------------------------------------------------------------------- /docs/source/imgs/DA_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/DA_diagram.png -------------------------------------------------------------------------------- /docs/source/imgs/baseline_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/baseline_results.png -------------------------------------------------------------------------------- /docs/source/imgs/logo_pyraug_2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/logo_pyraug_2.jpeg 
-------------------------------------------------------------------------------- /docs/source/imgs/nine_digits-rot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/nine_digits-rot.pdf -------------------------------------------------------------------------------- /docs/source/imgs/nine_digits-rot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/nine_digits-rot.png -------------------------------------------------------------------------------- /docs/source/imgs/nine_digits.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/nine_digits.pdf -------------------------------------------------------------------------------- /docs/source/imgs/nine_digits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/nine_digits.png -------------------------------------------------------------------------------- /docs/source/imgs/optimized_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/optimized_results.png -------------------------------------------------------------------------------- /docs/source/imgs/pyraug_diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/pyraug_diagram.jpg -------------------------------------------------------------------------------- /docs/source/imgs/pyraug_diagram_simplified.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/docs/source/imgs/pyraug_diagram_simplified.jpg -------------------------------------------------------------------------------- /docs/source/models/autoencoders/aae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Adversarial Autoencoder 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.adversarial_ae 8 | 9 | .. autoclass:: pythae.models.Adversarial_AE_Config 10 | :members: 11 | 12 | .. autoclass:: pythae.models.Adversarial_AE 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/models/autoencoders/ae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Autoencoder 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.ae 8 | 9 | .. autoclass:: pythae.models.AEConfig 10 | :members: 11 | 12 | .. 
autoclass:: pythae.models.AE 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/models/autoencoders/auto_model.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | AutoModel 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.auto_model 8 | 9 | .. autoclass:: pythae.models.AutoModel 10 | :members: 11 | -------------------------------------------------------------------------------- /docs/source/models/autoencoders/baseAE.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | BaseAE 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.base 8 | 9 | .. autoclass:: pythae.models.BaseAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.BaseAE 13 | :members: 14 | 15 | .. autoclass:: pythae.models.base.base_utils.ModelOutput 16 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/betatcvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | BetaTCVAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.beta_tc_vae 7 | 8 | .. autoclass:: pythae.models.BetaTCVAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.BetaTCVAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/betavae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | BetaVAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.beta_vae 7 | 8 | .. autoclass:: pythae.models.BetaVAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.BetaVAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/ciwae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | CIWAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.ciwae 7 | 8 | .. autoclass:: pythae.models.CIWAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.CIWAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/disentangled_betavae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | DisentangledBetaVAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.disentangled_beta_vae 7 | 8 | .. autoclass:: pythae.models.DisentangledBetaVAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.DisentangledBetaVAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/factorvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | FactorVAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.factor_vae 7 | 8 | .. autoclass:: pythae.models.FactorVAEConfig 9 | :members: 10 | 11 | .. 
autoclass:: pythae.models.FactorVAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/hvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Hamiltonian VAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.hvae 7 | 8 | 9 | .. autoclass:: pythae.models.HVAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.HVAE 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/infovae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | INFO-VAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.info_vae 7 | 8 | .. autoclass:: pythae.models.INFOVAE_MMD_Config 9 | :members: 10 | 11 | .. autoclass:: pythae.models.INFOVAE_MMD 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/iwae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | IWAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.iwae 7 | 8 | .. autoclass:: pythae.models.IWAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.IWAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/miwae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | MIWAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.miwae 7 | 8 | .. autoclass:: pythae.models.MIWAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.MIWAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/models.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Autoencoders 3 | ********************************** 4 | 5 | .. toctree:: 6 | :hidden: 7 | :maxdepth: 1 8 | 9 | baseAE 10 | auto_model 11 | ae 12 | vae 13 | betavae 14 | vae_lin_nf 15 | vae_iaf 16 | disentangled_betavae 17 | factorvae 18 | betatcvae 19 | iwae 20 | ciwae 21 | miwae 22 | piwae 23 | msssimvae 24 | wae 25 | infovae 26 | vamp 27 | svae 28 | pvae 29 | aae 30 | vaegan 31 | vqvae 32 | hvae 33 | rae_gp 34 | rae_l2 35 | rhvae 36 | 37 | .. automodule:: 38 | pythae.models 39 | 40 | Available Models 41 | ----------------- 42 | 43 | .. 
autosummary:: 44 | ~pythae.models.BaseAE 45 | ~pythae.models.AutoModel 46 | ~pythae.models.AE 47 | ~pythae.models.VAE 48 | ~pythae.models.BetaVAE 49 | ~pythae.models.VAE_LinNF 50 | ~pythae.models.VAE_IAF 51 | ~pythae.models.DisentangledBetaVAE 52 | ~pythae.models.FactorVAE 53 | ~pythae.models.BetaTCVAE 54 | ~pythae.models.IWAE 55 | ~pythae.models.CIWAE 56 | ~pythae.models.MIWAE 57 | ~pythae.models.PIWAE 58 | ~pythae.models.MSSSIM_VAE 59 | ~pythae.models.WAE_MMD 60 | ~pythae.models.INFOVAE_MMD 61 | ~pythae.models.VAMP 62 | ~pythae.models.SVAE 63 | ~pythae.models.PoincareVAE 64 | ~pythae.models.Adversarial_AE 65 | ~pythae.models.VAEGAN 66 | ~pythae.models.VQVAE 67 | ~pythae.models.HVAE 68 | ~pythae.models.RAE_GP 69 | ~pythae.models.RAE_L2 70 | ~pythae.models.RHVAE 71 | :nosignatures: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/msssimvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | MSSSIM_VAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.msssim_vae 7 | 8 | .. autoclass:: pythae.models.MSSSIM_VAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.MSSSIM_VAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/piwae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | PIWAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.piwae 7 | 8 | .. autoclass:: pythae.models.PIWAEConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.models.PIWAE 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/pvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | PoincareVAE 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.pvae 8 | 9 | .. autoclass:: pythae.models.PoincareVAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.PoincareVAE 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/rae_gp.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | RAE-GP 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.rae_gp 7 | 8 | .. autoclass:: pythae.models.RAE_GP_Config 9 | :members: 10 | 11 | .. autoclass:: pythae.models.RAE_GP 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/rae_l2.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | RAE-L2 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.rae_l2 7 | 8 | .. autoclass:: pythae.models.RAE_L2_Config 9 | :members: 10 | 11 | .. autoclass:: pythae.models.RAE_L2 12 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/rhvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Riemannian Hamiltonian VAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.rhvae 7 | 8 | 9 | .. 
autoclass:: pythae.models.RHVAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.RHVAE 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/svae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | SVAE 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.svae 8 | 9 | .. autoclass:: pythae.models.SVAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.SVAE 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAE 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.vae 8 | 9 | .. autoclass:: pythae.models.VAEConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.VAE 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vae_iaf.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAE_IAF 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.vae_iaf 8 | 9 | .. autoclass:: pythae.models.VAE_IAF_Config 10 | :members: 11 | 12 | .. autoclass:: pythae.models.VAE_IAF 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vae_lin_nf.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAE_LinNF 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.vae_lin_nf 8 | 9 | .. autoclass:: pythae.models.VAE_LinNF_Config 10 | :members: 11 | 12 | .. autoclass:: pythae.models.VAE_LinNF 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vaegan.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAEGAN 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.vae_gan 8 | 9 | .. autoclass:: pythae.models.VAEGANConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.VAEGAN 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vamp.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAMP 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.vamp 8 | 9 | .. autoclass:: pythae.models.VAMPConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.VAMP 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/autoencoders/vqvae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Vector Quantized VAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.vq_vae 7 | 8 | .. autoclass:: pythae.models.VQVAEConfig 9 | :members: 10 | 11 | .. 
autoclass:: pythae.models.VQVAE 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/source/models/autoencoders/wae.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | WAE 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.wae_mmd 7 | 8 | 9 | .. autoclass:: pythae.models.WAE_MMD_Config 10 | :members: 11 | 12 | .. autoclass:: pythae.models.WAE_MMD 13 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/celeba/convnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ConvNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA 6 | :members: 7 | 8 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA 18 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | CELEBA 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.nn.benchmarks.celeba 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | convnets 12 | resnets 13 | 14 | ConvNets 15 | ~~~~~~~~~~~~~~~ 16 | 17 | .. autosummary:: 18 | ~pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA 19 | ~pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA 20 | ~pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA 21 | ~pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA 22 | ~pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA 23 | :nosignatures: 24 | 25 | ResNets 26 | ~~~~~~~~~~~~~~~ 27 | 28 | .. autosummary:: 29 | ~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA 30 | ~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA 31 | ~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA 32 | ~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA 33 | ~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA 34 | ~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA 35 | :nosignatures: 36 | -------------------------------------------------------------------------------- /docs/source/models/nn/celeba/resnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ResNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA 6 | :members: 7 | 8 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA 18 | :members: 19 | 20 | .. 
autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA 21 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/cifar/convnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ConvNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR 6 | :members: 7 | 8 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR 18 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | CIFAR 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.nn.benchmarks.cifar 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | convnets 12 | resnets 13 | 14 | ConvNets 15 | ~~~~~~~~~~~~~~~ 16 | 17 | .. autosummary:: 18 | ~pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR 19 | ~pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR 20 | ~pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR 21 | ~pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR 22 | ~pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR 23 | :nosignatures: 24 | 25 | ResNets 26 | ~~~~~~~~~~~~~~~ 27 | 28 | .. autosummary:: 29 | ~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR 30 | ~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR 31 | ~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR 32 | ~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR 33 | ~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR 34 | ~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR 35 | :nosignatures: -------------------------------------------------------------------------------- /docs/source/models/nn/cifar/resnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ResNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR 6 | :members: 7 | 8 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR 18 | :members: 19 | 20 | .. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR 21 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/mnist/convnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ConvNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_AE_MNIST 6 | :members: 7 | 8 | .. 
autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_VAE_MNIST 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_SVAE_MNIST 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Decoder_Conv_AE_MNIST 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Discriminator_Conv_MNIST 18 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/mnist/pythae_benchmarks_nn_mnist.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | MNIST 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.models.nn.benchmarks.mnist 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | convnets 12 | resnets 13 | 14 | ConvNets 15 | ~~~~~~~~~~~~~~~ 16 | 17 | .. autosummary:: 18 | ~pythae.models.nn.benchmarks.mnist.Encoder_Conv_AE_MNIST 19 | ~pythae.models.nn.benchmarks.mnist.Encoder_Conv_VAE_MNIST 20 | ~pythae.models.nn.benchmarks.mnist.Encoder_Conv_SVAE_MNIST 21 | ~pythae.models.nn.benchmarks.mnist.Decoder_Conv_AE_MNIST 22 | ~pythae.models.nn.benchmarks.mnist.Discriminator_Conv_MNIST 23 | :nosignatures: 24 | 25 | ResNets 26 | ~~~~~~~~~~~~~~~ 27 | 28 | .. autosummary:: 29 | ~pythae.models.nn.benchmarks.mnist.Encoder_ResNet_AE_MNIST 30 | ~pythae.models.nn.benchmarks.mnist.Encoder_ResNet_VAE_MNIST 31 | ~pythae.models.nn.benchmarks.mnist.Encoder_ResNet_SVAE_MNIST 32 | ~pythae.models.nn.benchmarks.mnist.Encoder_ResNet_VQVAE_MNIST 33 | ~pythae.models.nn.benchmarks.mnist.Decoder_ResNet_AE_MNIST 34 | ~pythae.models.nn.benchmarks.mnist.Decoder_ResNet_VQVAE_MNIST 35 | :nosignatures: -------------------------------------------------------------------------------- /docs/source/models/nn/mnist/resnets.rst: -------------------------------------------------------------------------------- 1 | ******************************** 2 | ResNets 3 | ******************************** 4 | 5 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_ResNet_AE_MNIST 6 | :members: 7 | 8 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_ResNet_VAE_MNIST 9 | :members: 10 | 11 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_ResNet_SVAE_MNIST 12 | :members: 13 | 14 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_ResNet_VQVAE_MNIST 15 | :members: 16 | 17 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Decoder_ResNet_AE_MNIST 18 | :members: 19 | 20 | .. autoclass:: pythae.models.nn.benchmarks.mnist.Decoder_ResNet_VQVAE_MNIST 21 | :members: -------------------------------------------------------------------------------- /docs/source/models/nn/pythae_base_nn.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Base Neural Nets 3 | ********************************** 4 | 5 | The base neural nets from which all further neural nets must inherit. In particular, if you decide 6 | to provide your own architectures, make them inherit from these base classes. See the tutorials for further 7 | details. 8 | 9 | .. autoclass:: pythae.models.nn.BaseEncoder 10 | :members: 11 | 12 | .. autoclass:: pythae.models.nn.BaseDecoder 13 | :members: 14 | 15 | .. autoclass:: pythae.models.nn.BaseMetric 16 | :members: 17 | 18 | .. autoclass:: pythae.models.nn.BaseDiscriminator 19 | :members: 20 | 
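21 | Example
22 | ~~~~~~~~~~~~~~~
23 | 
24 | Below is a minimal sketch of a custom encoder (the class name, layer sizes and
25 | activation are illustrative choices only, not part of the library). It inherits
26 | from :class:`~pythae.models.nn.BaseEncoder` and returns its outputs in a
27 | :class:`~pythae.models.base.base_utils.ModelOutput` instance.
28 | 
29 | .. code-block:: python
30 | 
31 |     import torch
32 |     import torch.nn as nn
33 | 
34 |     from pythae.models.base.base_utils import ModelOutput
35 |     from pythae.models.nn import BaseEncoder
36 | 
37 |     class My_Encoder(BaseEncoder):
38 |         def __init__(self, config):  # config: the model configuration
39 |             BaseEncoder.__init__(self)
40 |             # Illustrative architecture: flatten the input and map it to a
41 |             # hidden representation
42 |             in_features = int(torch.prod(torch.tensor(config.input_dim)))
43 |             self.layers = nn.Sequential(nn.Linear(in_features, 512), nn.ReLU())
44 |             self.embedding = nn.Linear(512, config.latent_dim)
45 |             self.log_var = nn.Linear(512, config.latent_dim)
46 | 
47 |         def forward(self, x: torch.Tensor) -> ModelOutput:
48 |             h = self.layers(x.reshape(x.shape[0], -1))
49 |             return ModelOutput(
50 |                 embedding=self.embedding(h),  # mean of the posterior
51 |                 log_covariance=self.log_var(h),  # needed by VAE-like models only
52 |             )
53 | 
54 | The custom network can then be passed to a model when it is built, e.g.
55 | ``model = VAE(model_config=my_config, encoder=My_Encoder(my_config))``.
56 | 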
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/basenf.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | BaseNF
 3 | **********************************
 4 | 
 5 | .. automodule::
 6 |     pythae.models.normalizing_flows.base
 7 | 
 8 | .. autoclass:: pythae.models.normalizing_flows.BaseNF
 9 |    :members:
10 | 
11 | .. autoclass:: pythae.models.normalizing_flows.BaseNFConfig
12 |    :members:
13 | 
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/iaf.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | IAF
 3 | **********************************
 4 | 
 5 | 
 6 | .. automodule::
 7 |     pythae.models.normalizing_flows.iaf
 8 | 
 9 | .. autoclass:: pythae.models.normalizing_flows.IAFConfig
10 |    :members:
11 | 
12 | .. autoclass:: pythae.models.normalizing_flows.IAF
13 |    :members:
14 | 
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/made.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | MADE
 3 | **********************************
 4 | 
 5 | 
 6 | .. automodule::
 7 |     pythae.models.normalizing_flows.made
 8 | 
 9 | .. autoclass:: pythae.models.normalizing_flows.MADEConfig
10 |    :members:
11 | 
12 | .. autoclass:: pythae.models.normalizing_flows.MADE
13 |    :members:
14 | 
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/maf.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | MAF
 3 | **********************************
 4 | 
 5 | 
 6 | .. automodule::
 7 |     pythae.models.normalizing_flows.maf
 8 | 
 9 | .. autoclass:: pythae.models.normalizing_flows.MAFConfig
10 |    :members:
11 | 
12 | .. autoclass:: pythae.models.normalizing_flows.MAF
13 |    :members:
14 | 
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/normalizing_flows.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | Normalizing Flows
 3 | **********************************
 4 | 
 5 | 
 6 | .. toctree::
 7 |     :hidden:
 8 |     :maxdepth: 1
 9 | 
10 |     basenf
11 |     planar_flow
12 |     radial_flow
13 |     made
14 |     maf
15 |     iaf
16 |     pixelcnn
17 | 
18 | .. automodule::
19 |     pythae.models.normalizing_flows
20 | 
21 | Available Flows
22 | -----------------
23 | 
24 | .. autosummary::
25 |     ~pythae.models.normalizing_flows.BaseNF
26 |     ~pythae.models.normalizing_flows.PlanarFlow
27 |     ~pythae.models.normalizing_flows.RadialFlow
28 |     ~pythae.models.normalizing_flows.MADE
29 |     ~pythae.models.normalizing_flows.MAF
30 |     ~pythae.models.normalizing_flows.IAF
31 |     ~pythae.models.normalizing_flows.PixelCNN
32 |     :nosignatures:
33 | 
--------------------------------------------------------------------------------
/docs/source/models/normalizing_flows/pixelcnn.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | PixelCNN
 3 | **********************************
 4 | 
 5 | 
 6 | .. automodule::
 7 |     pythae.models.normalizing_flows.pixelcnn
 8 | 
 9 | .. 
autoclass:: pythae.models.normalizing_flows.PixelCNNConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.normalizing_flows.PixelCNN 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/models/normalizing_flows/planar_flow.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Planar Flow 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.normalizing_flows.planar_flow 8 | 9 | .. autoclass:: pythae.models.normalizing_flows.PlanarFlowConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.normalizing_flows.PlanarFlow 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/models/normalizing_flows/radial_flow.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Radial Flow 3 | ********************************** 4 | 5 | 6 | .. automodule:: 7 | pythae.models.normalizing_flows.radial_flow 8 | 9 | .. autoclass:: pythae.models.normalizing_flows.RadialFlowConfig 10 | :members: 11 | 12 | .. autoclass:: pythae.models.normalizing_flows.RadialFlow 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/source/pipelines/generation.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | GenerationPipeline 3 | ********************************** 4 | 5 | .. autoclass:: 6 | pythae.pipelines.GenerationPipeline 7 | :members: __call__ 8 | -------------------------------------------------------------------------------- /docs/source/pipelines/training.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | TrainingPipeline 3 | ********************************** 4 | 5 | .. autoclass:: 6 | pythae.pipelines.TrainingPipeline 7 | :members: __call__ 8 | -------------------------------------------------------------------------------- /docs/source/samplers/basesampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | BaseSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.base 7 | 8 | .. autoclass:: pythae.samplers.BaseSamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.BaseSampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/gmm_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | GaussianMixtureSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.gaussian_mixture 7 | 8 | .. autoclass:: pythae.samplers.GaussianMixtureSamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.GaussianMixtureSampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/iaf_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | IAFSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.iaf_sampler 7 | 8 | .. autoclass:: pythae.samplers.IAFSamplerConfig 9 | :members: 10 | 11 | .. 
autoclass:: pythae.samplers.IAFSampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/maf_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | MAFSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.maf_sampler 7 | 8 | .. autoclass:: pythae.samplers.MAFSamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.MAFSampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/normal_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | NormalSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.normal_sampling 7 | 8 | .. autoclass:: pythae.samplers.NormalSampler 9 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/pixelcnn_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | PixelCNNSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.pixelcnn_sampler 7 | 8 | .. autoclass:: pythae.samplers.PixelCNNSamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.PixelCNNSampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/poincare_disk_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | PoincareDiskSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.pvae_sampler 7 | 8 | .. autoclass:: pythae.samplers.PoincareDiskSampler 9 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/rhvae_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | RHVAESampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.manifold_sampler 7 | 8 | .. autoclass:: pythae.samplers.RHVAESamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.RHVAESampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/twostage_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | TwoStageVAESampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.two_stage_vae_sampler 7 | 8 | .. autoclass:: pythae.samplers.TwoStageVAESamplerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.samplers.TwoStageVAESampler 12 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/unit_sphere_unif_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | HypersphereUniformSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.hypersphere_uniform_sampler 7 | 8 | .. 
autoclass:: pythae.samplers.HypersphereUniformSampler 9 | :members: -------------------------------------------------------------------------------- /docs/source/samplers/vamp_sampler.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | VAMPSampler 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.samplers.vamp_sampler 7 | 8 | .. autoclass:: pythae.samplers.VAMPSampler 9 | :members: -------------------------------------------------------------------------------- /docs/source/trainers/pythae.trainer.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | Trainers 3 | ********************************** 4 | 5 | .. toctree:: 6 | :hidden: 7 | :maxdepth: 1 8 | 9 | pythae.trainers.base_trainer 10 | pythae.trainers.coupled_trainer 11 | pythae.trainers.adversarial_trainer 12 | pythae.trainers.coupled_optimizer_adversarial_trainer 13 | pythae.training_callbacks 14 | 15 | 16 | .. automodule:: 17 | pythae.trainers 18 | 19 | 20 | .. autosummary:: 21 | ~pythae.trainers.BaseTrainer 22 | ~pythae.trainers.CoupledOptimizerTrainer 23 | ~pythae.trainers.AdversarialTrainer 24 | ~pythae.trainers.CoupledOptimizerAdversarialTrainer 25 | :nosignatures: 26 | 27 | ---------------------------------- 28 | Training Callbacks 29 | ---------------------------------- 30 | 31 | .. automodule:: 32 | pythae.trainers.training_callbacks 33 | 34 | .. autosummary:: 35 | ~pythae.trainers.training_callbacks.TrainingCallback 36 | ~pythae.trainers.training_callbacks.CallbackHandler 37 | ~pythae.trainers.training_callbacks.MetricConsolePrinterCallback 38 | ~pythae.trainers.training_callbacks.ProgressBarCallback 39 | ~pythae.trainers.training_callbacks.WandbCallback 40 | ~pythae.trainers.training_callbacks.MLFlowCallback 41 | ~pythae.trainers.training_callbacks.CometCallback 42 | :nosignatures: -------------------------------------------------------------------------------- /docs/source/trainers/pythae.trainers.adversarial_trainer.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | AdversarialTrainer 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.trainers.adversarial_trainer 7 | 8 | .. autoclass:: pythae.trainers.AdversarialTrainerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.trainers.AdversarialTrainer 12 | :members: -------------------------------------------------------------------------------- /docs/source/trainers/pythae.trainers.base_trainer.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | BaseTrainer 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.trainers.base_trainer 7 | 8 | .. autoclass:: pythae.trainers.BaseTrainerConfig 9 | :members: 10 | 11 | .. autoclass:: pythae.trainers.BaseTrainer 12 | :members: -------------------------------------------------------------------------------- /docs/source/trainers/pythae.trainers.coupled_optimizer_adversarial_trainer.rst: -------------------------------------------------------------------------------- 1 | ********************************** 2 | CoupledOptimizerAdversarialTrainer 3 | ********************************** 4 | 5 | .. automodule:: 6 | pythae.trainers.coupled_optimizer_adversarial_trainer 7 | 8 | .. autoclass:: pythae.trainers.CoupledOptimizerAdversarialTrainerConfig 9 | :members: 10 | 11 | .. 
autoclass:: pythae.trainers.CoupledOptimizerAdversarialTrainer
12 |    :members:
--------------------------------------------------------------------------------
/docs/source/trainers/pythae.trainers.coupled_trainer.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | CoupledOptimizerTrainer
 3 | **********************************
 4 | 
 5 | .. automodule::
 6 |     pythae.trainers.coupled_optimizer_trainer
 7 | 
 8 | .. autoclass:: pythae.trainers.CoupledOptimizerTrainerConfig
 9 |    :members:
10 | 
11 | .. autoclass:: pythae.trainers.CoupledOptimizerTrainer
12 |    :members:
--------------------------------------------------------------------------------
/docs/source/trainers/pythae.training_callbacks.rst:
--------------------------------------------------------------------------------
 1 | **********************************
 2 | TrainingCallbacks
 3 | **********************************
 4 | 
 5 | .. automodule::
 6 |     pythae.trainers.training_callbacks
 7 | 
 8 | .. autoclass:: pythae.trainers.training_callbacks.TrainingCallback
 9 |    :members:
10 | 
11 | .. autoclass:: pythae.trainers.training_callbacks.CallbackHandler
12 |    :members:
13 | 
14 | .. autoclass:: pythae.trainers.training_callbacks.MetricConsolePrinterCallback
15 |    :members:
16 | 
17 | .. autoclass:: pythae.trainers.training_callbacks.ProgressBarCallback
18 |    :members:
19 | 
20 | .. autoclass:: pythae.trainers.training_callbacks.WandbCallback
21 |    :members:
22 | 
23 | .. autoclass:: pythae.trainers.training_callbacks.MLFlowCallback
24 |    :members:
25 | 
26 | .. autoclass:: pythae.trainers.training_callbacks.CometCallback
27 |    :members:
--------------------------------------------------------------------------------
/examples/notebooks/README.md:
--------------------------------------------------------------------------------
 1 | ## Tutorials
 2 | 
 3 | To run the notebook tutorials, install the required dependencies by running
 4 | ```bash
 5 | $ pip install -r requirements.txt
 6 | ```
--------------------------------------------------------------------------------
/examples/notebooks/requirements.txt:
--------------------------------------------------------------------------------
 1 | matplotlib>=3.3.2
 2 | torchvision>=0.9.1
--------------------------------------------------------------------------------
/examples/scripts/README.md:
--------------------------------------------------------------------------------
 1 | ## Scripts
 2 | 
 3 | We also provide an example training script that allows you to train VAE models on well-known benchmark datasets (mnist, cifar10, celeba ...).
 4 | The script can be launched with the following command line:
 5 | 
 6 | ```bash
 7 | python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'
 8 | ```
 9 | 
10 | The folder structure should be as follows:
11 | ```bash
12 | .
13 | ├── configs # the model & training config files (you can amend these files as desired or pass the location of your own via '--model_config')
14 | │   ├── ae_config.json
15 | │   ├── base_training_config.json
16 | │   ├── beta_vae_config.json
17 | │   ├── hvae_config.json
18 | │   ├── rhvae_config.json
19 | │   ├── vae_config.json
20 | │   └── vamp_config.json
21 | ├── data # the dataset with train_data.npz and eval_data.npz files
22 | │   ├── celeba
23 | │   │   ├── eval_data.npz
24 | │   │   └── train_data.npz
25 | │   ├── cifar10
26 | │   │   ├── eval_data.npz
27 | │   │   └── train_data.npz
28 | │   └── mnist
29 | │       ├── eval_data.npz
30 | │       └── train_data.npz
31 | ├── my_models # trained models are saved here
32 | │   ├── AE_training_2021-10-15_16-07-04
33 | │   └── RHVAE_training_2021-10-15_15-54-27
34 | ├── README.md
35 | └── training.py
36 | ```
37 | 
38 | **Note**: The data in the `train_data.npz` and `eval_data.npz` files must be loadable as follows
39 | 
40 | ```python
41 | train_data = np.load(os.path.join(PATH, f'data/{args.dataset}', 'train_data.npz'))['data']
42 | eval_data = np.load(os.path.join(PATH, f'data/{args.dataset}', 'eval_data.npz'))['data']
43 | ```
44 | where `train_data` and `eval_data` then have shape (n_img x im_channel x height x width). A minimal end-to-end sketch of this flow in Python appears further below.
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/base_training_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "BaseTrainerConfig",
 3 |     "output_dir": "reproducibility/binary_mnist",
 4 |     "per_device_train_batch_size": 100,
 5 |     "per_device_eval_batch_size": 100,
 6 |     "num_epochs": 500,
 7 |     "learning_rate": 1e-4,
 8 |     "steps_saving": null,
 9 |     "steps_predict": 20,
10 |     "no_cuda": false
11 | }
12 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/hvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "HVAEConfig",
 3 |     "latent_dim": 16,
 4 |     "reconstruction_loss": "bce",
 5 |     "n_lf": 3,
 6 |     "eps_lf": 0.001,
 7 |     "beta_zero": 0.3,
 8 |     "learn_eps_lf": false,
 9 |     "learn_beta_zero": false
10 | }
11 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/iwae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "IWAEConfig",
 3 |     "latent_dim": 16,
 4 |     "reconstruction_loss": "bce",
 5 |     "number_samples": 50,
 6 |     "beta": 1
 7 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/rhvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "RHVAEConfig",
 3 |     "latent_dim": 16,
 4 |     "reconstruction_loss": "bce",
 5 |     "n_lf": 3,
 6 |     "eps_lf": 0.001,
 7 |     "beta_zero": 0.3,
 8 |     "temperature": 0.8,
 9 |     "regularization": 0.01
10 | }
11 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/svae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "SVAEConfig",
 3 |     "latent_dim": 16,
 4 |     "reconstruction_loss": "bce"
 5 | }
 6 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/binary_mnist/vae_config.json:
-------------------------------------------------------------------------------- 1 | { 2 | "name": "VAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "bce" 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/binary_mnist/vae_iaf_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAE_IAF_Config", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "bce", 5 | "n_made_blocks": 2, 6 | "n_hidden_in_made":2, 7 | "hidden_size": 32 8 | } 9 | -------------------------------------------------------------------------------- /examples/scripts/configs/binary_mnist/vae_lin_nf_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAE_LinNF_Config", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "bce", 5 | "flows": ["Planar", "Planar", "Planar"] 6 | } 7 | -------------------------------------------------------------------------------- /examples/scripts/configs/binary_mnist/vamp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAMPConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "bce", 5 | "number_components": 50 6 | } 7 | -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/aae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AdversarialAEConfig", 3 | "latent_dim": 64, 4 | "reconstruction_loss": "mse", 5 | "adversarial_loss_scale": 0.9 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/ae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AEConfig", 3 | "latent_dim": 64 4 | } -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/base_training_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseTrainerConfig", 3 | "output_dir": "my_models_on_celeba", 4 | "per_device_train_batch_size": 100, 5 | "per_device_eval_batch_size": 100, 6 | "num_epochs": 50, 7 | "learning_rate": 0.001, 8 | "steps_saving": null, 9 | "steps_predict": 50, 10 | "no_cuda": false 11 | } 12 | -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/beta_tc_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BetaTCVAEConfig", 3 | "latent_dim": 64, 4 | "reconstruction_loss": "mse", 5 | "beta": 2, 6 | "gamma": 1, 7 | "alpha": 1, 8 | "use_mss": true 9 | } -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/beta_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BetaVAEConfig", 3 | "latent_dim": 64, 4 | "reconstruction_loss": "mse", 5 | "beta": 2 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/celeba/disentangled_beta_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DisentangledBetaVAEConfig", 3 | "latent_dim": 64, 4 | "reconstruction_loss": "mse", 5 | "beta": 5, 6 | "C": 30.0, 7 | "warmup_epoch": 25 8 | } 
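These JSON files are exactly what `training.py` consumes. As the minimal end-to-end sketch promised in the README above (the mnist paths, the plain AE model, and setting `input_dim` by hand are illustrative assumptions, not a description of the script's exact internals):

```python
import numpy as np

from pythae.models import AE, AEConfig, AutoModel
from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig

# Data laid out as described in examples/scripts/README.md
train_data = np.load("data/mnist/train_data.npz")["data"]  # (n_img, C, H, W)
eval_data = np.load("data/mnist/eval_data.npz")["data"]

# Configs are plain JSON files like the ones listed here
model_config = AEConfig.from_json_file("configs/ae_config.json")
model_config.input_dim = tuple(train_data.shape[1:])  # assumed to be filled in from the data
training_config = BaseTrainerConfig.from_json_file("configs/base_training_config.json")

pipeline = TrainingPipeline(model=AE(model_config=model_config), training_config=training_config)
pipeline(train_data=train_data, eval_data=eval_data)

# A finished run can later be reloaded without re-specifying the config class
# (run folder name illustrative, following the timestamped layout shown above)
model = AutoModel.load_from_folder("my_models/AE_training_2021-10-15_16-07-04/final_model")
```

Trained runs are written under the `output_dir` set in the training config, which is where the timestamped folders in the layout above come from.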
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/factor_vae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "FactorVAEConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "gamma": 10
 6 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/hvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "HVAEConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "n_lf": 1,
 6 |     "eps_lf": 0.001,
 7 |     "beta_zero": 0.3,
 8 |     "learn_eps_lf": false,
 9 |     "learn_beta_zero": false
10 | }
11 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/info_vae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "INFOVAE_MMD_Config",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "kernel_choice": "imq",
 6 |     "alpha": 0,
 7 |     "lbd": 1000,
 8 |     "kernel_bandwidth": 1.0
 9 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/msssim_vae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "MSSSIM_VAEConfig",
 3 |     "latent_dim": 64,
 4 |     "window_size": 11,
 5 |     "beta": 1e-3
 6 | }
 7 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/rae_gp_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "RAE_GP_Config",
 3 |     "latent_dim": 64,
 4 |     "embedding_weight": 1e-4,
 5 |     "reg_weight": 1e-7
 6 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/rae_l2_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "RAE_L2_Config",
 3 |     "latent_dim": 64,
 4 |     "embedding_weight": 1e-4,
 5 |     "reg_weight": 1e-7
 6 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/rhvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "RHVAEConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "n_lf": 1,
 6 |     "eps_lf": 0.001,
 7 |     "beta_zero": 0.3,
 8 |     "temperature": 1.5,
 9 |     "regularization": 0.01
10 | }
11 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/svae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "SVAEConfig",
 3 |     "latent_dim": 32,
 4 |     "reconstruction_loss": "mse"
 5 | }
 6 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VAEConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse"
 5 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vae_iaf_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VAE_IAF_Config",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "n_made_blocks": 3,
 6 |     "n_hidden_in_made": 2,
 7 |     "hidden_size": 128
 8 | }
 9 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vae_lin_nf_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VAE_LinNF_Config",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "flows": ["Planar", "Planar", "Planar", "Planar", "Planar"]
 6 | }
 7 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vaegan_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VAEGANConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "adversarial_loss_scale": 0.99999,
 6 |     "reconstruction_layer": 3,
 7 |     "margin": 0.4,
 8 |     "equilibrium": 0.68
 9 | }
10 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vamp_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VAMPConfig",
 3 |     "latent_dim": 64,
 4 |     "reconstruction_loss": "mse",
 5 |     "number_components": 50
 6 | }
 7 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/vqvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "VQVAEConfig",
 3 |     "latent_dim": 64,
 4 |     "commitment_loss_factor": 0.25,
 5 |     "quantization_loss_factor": 1,
 6 |     "num_embeddings": 1024,
 7 |     "use_ema": true,
 8 |     "decay": 0.99
 9 | }
10 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/celeba/wae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "WAE_MMD_Config",
 3 |     "latent_dim": 64,
 4 |     "kernel_choice": "imq",
 5 |     "reg_weight": 100,
 6 |     "kernel_bandwidth": 1.0
 7 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/cifar10/ae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "latent_dim": 16
 3 | }
--------------------------------------------------------------------------------
/examples/scripts/configs/cifar10/base_training_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "BaseTrainerConfig",
 3 |     "output_dir": "my_models_on_cifar",
 4 |     "per_device_train_batch_size": 100,
 5 |     "per_device_eval_batch_size": 100,
 6 |     "num_epochs": 100,
 7 |     "learning_rate": 1e-4,
 8 |     "steps_saving": null,
 9 |     "steps_predict": 100,
10 |     "no_cuda": false
11 | }
12 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/cifar10/beta_vae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "latent_dim": 16,
 3 |     "reconstruction_loss": "mse",
 4 |     "name": "BetaVAEConfig"
 5 | }
 6 | 
--------------------------------------------------------------------------------
/examples/scripts/configs/cifar10/hvae_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "latent_dim": 16,
 3 |     "reconstruction_loss": "mse",
 4 |     "n_lf": 3,
 5 |     "eps_lf": 0.001,
 6 |     "beta_zero": 0.3,
 7 |     "learn_eps_lf": false,
 8 |     "learn_beta_zero": false
 9 | }
10 | 
-------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/info_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "reconstruction_loss": "mse", 4 | "kernel_choice": "imq", 5 | "alpha": 0.5, 6 | "lbd": 3e-2, 7 | "kernel_bandwidth": 1.0 8 | } -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/rae_gp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "embedding_weight": 1e-2, 4 | "reg_weight": 1e-4 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/rae_l2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "embedding_weight": 1e-2, 4 | "reg_weight": 1e-4 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/rhvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "reconstruction_loss": "mse", 4 | "n_lf": 3, 5 | "eps_lf": 0.001, 6 | "beta_zero": 0.3, 7 | "temperature": 1.5, 8 | "regularization": 0.01 9 | } 10 | -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse" 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/vamp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "reconstruction_loss": "mse", 4 | "number_components": 50 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/vqvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VQVAEConfig", 3 | "latent_dim": 16, 4 | "commitment_loss_factor": 0.25, 5 | "quantization_loss_factor": 1, 6 | "num_embeddings": 256, 7 | "decay": 0.99, 8 | "use_ema": true 9 | } 10 | -------------------------------------------------------------------------------- /examples/scripts/configs/cifar10/wae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 16, 3 | "kernel_choice": "imq", 4 | "reg_weight": 1, 5 | "kernel_bandwidth": 2.0 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/dsprites/beta_tc_vae/base_training_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseTrainerConfig", 3 | "output_dir": "reproducibility/dsprites", 4 | "per_device_train_batch_size": 1000, 5 | "per_device_eval_batch_size": 1000, 6 | "num_epochs": 50, 7 | "learning_rate": 1e-3, 8 | "steps_saving": null, 9 | "steps_predict": null, 10 | "no_cuda": false 11 | } 12 | -------------------------------------------------------------------------------- /examples/scripts/configs/dsprites/beta_tc_vae/beta_tc_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BetaTCVAEConfig", 3 | 
"latent_dim": 10, 4 | "reconstruction_loss": "bce", 5 | "beta": 6, 6 | "gamma": 1, 7 | "alpha": 1, 8 | "use_mss": false 9 | } -------------------------------------------------------------------------------- /examples/scripts/configs/dsprites/factorvae/base_training_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseTrainerConfig", 3 | "output_dir": "reproducibility/dsprites", 4 | "per_device_train_batch_size": 64, 5 | "per_device_eval_batch_size": 64, 6 | "num_epochs": 500, 7 | "learning_rate": 1e-3, 8 | "steps_saving": 50, 9 | "steps_predict": null, 10 | "no_cuda": false 11 | } 12 | -------------------------------------------------------------------------------- /examples/scripts/configs/dsprites/factorvae/factorvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "FactorVAEConfig", 3 | "latent_dim": 10, 4 | "reconstruction_loss": "bce", 5 | "gamma": 35 6 | } 7 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/aae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AdversarialAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "adversarial_loss_scale": 0.9 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/ae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "AEConfig", 3 | "latent_dim": 16 4 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/base_training_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseTrainerConfig", 3 | "output_dir": "my_models_on_mnist", 4 | "per_device_train_batch_size": 100, 5 | "per_device_eval_batch_size": 100, 6 | "num_epochs": 100, 7 | "learning_rate": 1e-3, 8 | "steps_saving": null, 9 | "steps_predict": 100, 10 | "no_cuda": false 11 | } 12 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/beta_tc_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BetaTCVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "beta": 2, 6 | "gamma": 1, 7 | "alpha": 1, 8 | "use_mss": true 9 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/beta_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BetaVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "beta": 2 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/disentangled_beta_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DisentangledBetaVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "beta": 5, 6 | "C": 30.0, 7 | "warmup_epoch": 25 8 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/factor_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "FactorVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | 
"gamma": 5 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/hvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "HVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "n_lf": 3, 6 | "eps_lf": 0.001, 7 | "beta_zero": 0.3, 8 | "learn_eps_lf": false, 9 | "learn_beta_zero": false 10 | } 11 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/info_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "INFOVAE_MMD_Config", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "kernel_choice": "imq", 6 | "alpha": 0, 7 | "lbd": 1000, 8 | "kernel_bandwidth": 1.0 9 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/iwae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "IWAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "number_samples": 5, 6 | "beta": 1 7 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/msssim_vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MSSSIM_VAEConfig", 3 | "latent_dim": 16, 4 | "window_size": 3, 5 | "beta": 1e-2 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/rae_gp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RAE_GP_Config", 3 | "latent_dim": 16, 4 | "embedding_weight": 1e-2, 5 | "reg_weight": 1e-4 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/rae_l2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RAE_L2_Config", 3 | "latent_dim": 16, 4 | "embedding_weight": 1e-2, 5 | "reg_weight": 1e-4 6 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/rhvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "RHVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "n_lf": 3, 6 | "eps_lf": 0.001, 7 | "beta_zero": 0.3, 8 | "temperature": 0.8, 9 | "regularization": 0.01 10 | } 11 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/svae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "SVAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse" 5 | } 6 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAEConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse" 5 | } -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vae_iaf_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAE_IAF_Config", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "n_made_blocks": 2, 6 | 
"n_hidden_in_made":2, 7 | "hidden_size": 32 8 | } 9 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vae_lin_nf_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAE_LinNF_Config", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "flows": ["Planar", "Planar", "Planar", "Planar", "Planar"] 6 | } 7 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vaegan_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAEGANConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "adversarial_loss_scale": 0.8, 6 | "reconstruction_layer": 3, 7 | "margin":0.4, 8 | "equilibrium": 0.68 9 | } 10 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vamp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VAMPConfig", 3 | "latent_dim": 16, 4 | "reconstruction_loss": "mse", 5 | "number_components": 50 6 | } 7 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/vqvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "VQVAEConfig", 3 | "latent_dim": 16, 4 | "commitment_loss_factor": 0.25, 5 | "quantization_loss_factor": 1, 6 | "num_embeddings": 256, 7 | "use_ema": true 8 | } 9 | -------------------------------------------------------------------------------- /examples/scripts/configs/mnist/wae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "WAE_MMD_Config", 3 | "latent_dim": 16, 4 | "kernel_choice": "imq", 5 | "reg_weight": 10, 6 | "kernel_bandwidth": 1.0 7 | } -------------------------------------------------------------------------------- /examples/showcases/aae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/aae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/aae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/aae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/aae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/aae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/aae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/aae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/ae_gmm_sampling_celeba.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_gmm_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/ae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/ae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/ae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/ae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/ae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/ae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/beta_tc_vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_tc_vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/beta_tc_vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_tc_vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/beta_tc_vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_tc_vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/beta_tc_vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_tc_vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/beta_vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/beta_vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/beta_vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/beta_vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/beta_vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/disentangled_beta_vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/disentangled_beta_vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/disentangled_beta_vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/disentangled_beta_vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/disentangled_beta_vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/disentangled_beta_vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/disentangled_beta_vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/disentangled_beta_vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/eval_reconstruction_celeba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/eval_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/eval_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/eval_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/factor_vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/factor_vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/factor_vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/factor_vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/factor_vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/factor_vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/factor_vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/factor_vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/hvae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/hvae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/hvae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/hvae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvae_reconstruction_celeba.png 
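The showcase images in this folder (the `*_reconstruction_*`, `*_normal_sampling_*` and `*_gmm_sampling_*` grids) are produced by sampling trained models. A minimal sketch, assuming a trained pythae model is at hand (`trained_model` and `train_data` below are placeholders, not defined here):

```python
from pythae.samplers import (
    GaussianMixtureSampler,
    GaussianMixtureSamplerConfig,
    NormalSampler,
)

# *_normal_sampling_* grids: draw latent codes from a standard normal prior
sampler = NormalSampler(model=trained_model)
gen_data = sampler.sample(num_samples=25)

# *_gmm_sampling_* grids: fit a Gaussian mixture in latent space on training data first
gmm_sampler = GaussianMixtureSampler(
    sampler_config=GaussianMixtureSamplerConfig(n_components=10),
    model=trained_model,
)
gmm_sampler.fit(train_data)
gen_data = gmm_sampler.sample(num_samples=25)
```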
-------------------------------------------------------------------------------- /examples/showcases/hvae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/hvvae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvvae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/hvvae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/hvvae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/infovae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/infovae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/infovae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/infovae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/infovae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/infovae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/infovae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/infovae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/iwae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/iwae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/iwae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/iwae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/iwae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/iwae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/iwae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/iwae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/iwae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/iwae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/msssim_vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/msssim_vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/msssim_vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/msssim_vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/msssim_vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/msssim_vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/msssim_vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/msssim_vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_gp_gmm_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_gmm_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rae_gp_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_gp_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_normal_sampling_celeba.png 
-------------------------------------------------------------------------------- /examples/showcases/rae_gp_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_gp_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rae_gp_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_gp_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_gmm_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_gmm_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rae_l2_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rae_l2_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rhvae_reconstruction_celeba.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rhvae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rhvae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rhvae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/rhvae_rhvae_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rhvae_rhvae_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/rhvae_rhvae_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/rhvae_rhvae_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/svae_hypersphere_uniform_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/svae_hypersphere_uniform_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/svae_hypersphere_uniform_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/svae_hypersphere_uniform_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/svae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/svae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/svae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/svae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_gmm_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_gmm_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_iaf_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_iaf_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_iaf_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_iaf_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_iaf_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_iaf_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_iaf_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_iaf_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_lin_nf_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_lin_nf_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_lin_nf_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_lin_nf_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_lin_nf_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_lin_nf_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_lin_nf_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_lin_nf_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_maf_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_maf_sampling_celeba.png 
-------------------------------------------------------------------------------- /examples/showcases/vae_maf_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_maf_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vae_second_stage_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_second_stage_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vae_second_stage_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vae_second_stage_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vaegan_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vaegan_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vaegan_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vaegan_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vaegan_reconstruction_celeba.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vaegan_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vaegan_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vaegan_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vamp_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vamp_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vamp_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vamp_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vamp_vamp_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vamp_vamp_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vamp_vamp_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vamp_vamp_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vqvae_maf_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vqvae_maf_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/vqvae_maf_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vqvae_maf_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vqvae_pixelcnn_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vqvae_pixelcnn_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/vqvae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vqvae_reconstruction_celeba.png -------------------------------------------------------------------------------- 
/examples/showcases/vqvae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/vqvae_reconstruction_mnist.png -------------------------------------------------------------------------------- /examples/showcases/wae_gmm_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_gmm_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/wae_gmm_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_gmm_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/wae_normal_sampling_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_normal_sampling_celeba.png -------------------------------------------------------------------------------- /examples/showcases/wae_normal_sampling_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_normal_sampling_mnist.png -------------------------------------------------------------------------------- /examples/showcases/wae_reconstruction_celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_reconstruction_celeba.png -------------------------------------------------------------------------------- /examples/showcases/wae_reconstruction_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/examples/showcases/wae_reconstruction_mnist.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | multi_line_output = 3 3 | include_trailing_comma = true 4 | force_grid_wrap = 0 5 | use_parentheses = true 6 | ensure_newline_before_comments = true 7 | line_length = 88 8 | 9 | [tool.black] 10 | line-length = 88 11 | target-version = ['py37', 'py38', 'py39'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle>=2.1.0 2 | imageio 3 | numpy>=1.19 4 | pydantic>=2.0 5 | scikit-learn 6 | scipy>=1.7.1 7 | torch>=1.10.1 8 | tqdm 9 | typing_extensions 10 | dataclasses>=0.6 11 | pickle5; python_version == '3.7.*' -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore= W503 3 | 
max-line-length = 119 --------------------------------------------------------------------------------
/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name="pythae", 8 | version="0.1.2", 9 | author="Clement Chadebec", 10 | author_email="clement.chadebec@gmail.com", 11 | description="Unifying Generative Autoencoders in Python", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/clementchadebec/benchmark_VAE", 15 | project_urls={ 16 | "Bug Tracker": "https://github.com/clementchadebec/benchmark_VAE/issues" 17 | }, 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.7", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "License :: OSI Approved :: Apache Software License", 24 | "Operating System :: OS Independent", 25 | ], 26 | package_dir={"": "src"}, 27 | packages=find_packages(where="src"), 28 | install_requires=[ 29 | "cloudpickle>=2.1.0", 30 | "imageio", 31 | "numpy>=1.19", 32 | "pydantic>=2.0", 33 | "scikit-learn", 34 | "scipy>=1.7.1", 35 | "torch>=1.10.1", 36 | "tqdm", 37 | "typing_extensions", 38 | "dataclasses>=0.6", 39 | ], 40 | extras_require={':python_version == "3.7.*"': ["pickle5"]}, 41 | python_requires=">=3.7", 42 | ) 43 | --------------------------------------------------------------------------------
/src/pythae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/src/pythae/__init__.py --------------------------------------------------------------------------------
/src/pythae/customexception.py: -------------------------------------------------------------------------------- 1 | class SizeMismatchError(Exception): 2 | pass 3 | 4 | 5 | class LoadError(Exception): 6 | pass 7 | 8 | 9 | class EncoderInputError(Exception): 10 | pass 11 | 12 | 13 | class EncoderOutputError(Exception): 14 | pass 15 | 16 | 17 | class DecoderInputError(Exception): 18 | pass 19 | 20 | 21 | class DecoderOutputError(Exception): 22 | pass 23 | 24 | 25 | class MetricInputError(Exception): 26 | pass 27 | 28 | 29 | class MetricOutputError(Exception): 30 | pass 31 | 32 | 33 | class BadInheritanceError(Exception): 34 | pass 35 | 36 | 37 | class ModelError(Exception): 38 | pass 39 | 40 | 41 | class DatasetError(Exception): 42 | pass 43 | --------------------------------------------------------------------------------
/src/pythae/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains the methods to load and preprocess the data. 2 | 3 | .. note:: 4 | 5 | As of now, only the imaging modality is handled by pythae. Other 6 | modalities should be added in the near future. 7 | """ 8 | 9 | from .datasets import BaseDataset 10 | 11 | __all__ = ["BaseDataset"] 12 | --------------------------------------------------------------------------------
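A minimal usage sketch for the BaseDataset class exposed above; the tensor shapes are illustrative and the dict-like return value is an assumption based on how the library's trainers consume datasets, not a guarantee:

    import torch
    from pythae.data.datasets import BaseDataset

    # Hypothetical toy data: 100 single-channel 28x28 images with dummy labels.
    data = torch.rand(100, 1, 28, 28)
    labels = torch.zeros(100)
    dataset = BaseDataset(data, labels)
    sample = dataset[0]  # expected to expose the image tensor under a "data" entry

/src/pythae/models/adversarial_ae/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of an Adversarial Autoencoder model as proposed in 2 | (https://arxiv.org/abs/1511.05644).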
This model tries to make the posterior distribution match 3 | the prior using adversarial training. 4 | 5 | Available samplers 6 | ------------------- 7 | 8 | .. autosummary:: 9 | ~pythae.samplers.NormalSampler 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.TwoStageVAESampler 12 | ~pythae.samplers.MAFSampler 13 | ~pythae.samplers.IAFSampler 14 | :nosignatures: 15 | 16 | 17 | """ 18 | 19 | from .adversarial_ae_config import Adversarial_AE_Config 20 | from .adversarial_ae_model import Adversarial_AE 21 | 22 | __all__ = ["Adversarial_AE", "Adversarial_AE_Config"] 23 | --------------------------------------------------------------------------------
/src/pythae/models/adversarial_ae/adversarial_ae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class Adversarial_AE_Config(VAEConfig): 8 | """Adversarial AE model config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | adversarial_loss_scale (float): Parameter scaling the adversarial loss. Default: 0.5 15 | reconstruction_loss_scale (float): Parameter scaling the reconstruction loss. Default: 1 16 | deterministic_posterior (bool): Whether to use a deterministic posterior (Dirac). Default: 17 | False 18 | """ 19 | 20 | adversarial_loss_scale: float = 0.5 21 | reconstruction_loss_scale: float = 1.0 22 | deterministic_posterior: bool = False 23 | uses_default_discriminator: bool = True 24 | discriminator_input_dim: int = None 25 | --------------------------------------------------------------------------------
/src/pythae/models/ae/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of a Vanilla Autoencoder model. 2 | 3 | Available samplers 4 | ------------------- 5 | 6 | .. autosummary:: 7 | ~pythae.samplers.NormalSampler 8 | ~pythae.samplers.GaussianMixtureSampler 9 | ~pythae.samplers.MAFSampler 10 | ~pythae.samplers.IAFSampler 11 | :nosignatures: 12 | 13 | """ 14 | 15 | from .ae_config import AEConfig 16 | from .ae_model import AE 17 | 18 | __all__ = ["AE", "AEConfig"] 19 | --------------------------------------------------------------------------------
/src/pythae/models/ae/ae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base.base_config import BaseAEConfig 4 | 5 | 6 | @dataclass 7 | class AEConfig(BaseAEConfig): 8 | """This is the autoencoder model configuration instance deriving from 9 | :class:`~BaseAEConfig`. 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | default_encoder (bool): Whether to use the default encoder. Default: True. 15 | default_decoder (bool): Whether to use the default decoder. Default: True. 16 | """ 17 | --------------------------------------------------------------------------------
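Building on the AEConfig just defined, a minimal sketch of constructing a vanilla autoencoder with the default networks; the input_dim and latent_dim values are illustrative, not prescribed:

    from pythae.models import AE, AEConfig

    # Assumes MNIST-like inputs; with no encoder/decoder passed,
    # the model falls back on its default architectures.
    config = AEConfig(input_dim=(1, 28, 28), latent_dim=16)
    model = AE(model_config=config)

/src/pythae/models/auto_model/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility class allowing any :class:`pythae.models` instance to be reloaded automatically with the following 2 | lines of code. 3 | 4 | .. code-block::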
5 | 6 | >>> from pythae.models import AutoModel 7 | >>> model = AutoModel.load_from_folder(dir_path='path/to/my_model') 8 | """ 9 | 10 | from .auto_config import AutoConfig 11 | from .auto_model import AutoModel 12 | --------------------------------------------------------------------------------
/src/pythae/models/base/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | **Abstract class** 3 | 4 | This is the base AutoEncoder architecture module from which all future autoencoder based 5 | models should inherit. 6 | 7 | It contains: 8 | 9 | - | a :class:`~pythae.models.base.base_config.BaseAEConfig` instance containing the main model's 10 | parameters (*e.g.* latent dimension ...) 11 | - | a :class:`~pythae.models.BaseAE` instance which creates a BaseAE model having a basic 12 | autoencoding architecture 13 | - | the :class:`~pythae.models.base.base_utils.ModelOutput` instance used for the neural nets' outputs and 14 | the model outputs of the :class:`forward` method. 15 | """ 16 | 17 | from .base_config import BaseAEConfig 18 | from .base_model import BaseAE 19 | 20 | __all__ = ["BaseAE", "BaseAEConfig"] 21 | --------------------------------------------------------------------------------
/src/pythae/models/base/base_config.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from pythae.config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class BaseAEConfig(BaseConfig): 10 | """This is the base configuration instance of the models deriving from 11 | :class:`~pythae.config.BaseConfig`. 12 | 13 | Parameters: 14 | input_dim (tuple): The input_data dimension (channels X x_dim X y_dim) 15 | latent_dim (int): The latent space dimension. Default: 10. 16 | """ 17 | 18 | input_dim: Union[Tuple[int, ...], None] = None 19 | latent_dim: int = 10 20 | uses_default_encoder: bool = True 21 | uses_default_decoder: bool = True 22 | 23 | 24 | @dataclass 25 | class EnvironmentConfig(BaseConfig): 26 | python_version: str = "3.8" 27 | --------------------------------------------------------------------------------
/src/pythae/models/beta_tc_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the BetaTCVAE proposed in 2 | (https://arxiv.org/abs/1802.04942). 3 | 4 | 5 | Available samplers 6 | ------------------- 7 | 8 | .. autosummary:: 9 | ~pythae.samplers.NormalSampler 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.TwoStageVAESampler 12 | ~pythae.samplers.MAFSampler 13 | ~pythae.samplers.IAFSampler 14 | :nosignatures: 15 | """ 16 | 17 | from .beta_tc_vae_config import BetaTCVAEConfig 18 | from .beta_tc_vae_model import BetaTCVAE 19 | 20 | __all__ = ["BetaTCVAE", "BetaTCVAEConfig"] 21 | --------------------------------------------------------------------------------
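Tying the config classes and the AutoModel helper shown earlier together, a sketch of the save/reload round trip; the directory path is hypothetical and the exact set of files written by save() is an assumption:

    from pythae.models import AutoModel, BetaTCVAE, BetaTCVAEConfig

    config = BetaTCVAEConfig(input_dim=(1, 28, 28), latent_dim=16, beta=2.0)
    model = BetaTCVAE(model_config=config)
    model.save(dir_path="my_model")  # should persist the weights and the model config
    reloaded = AutoModel.load_from_folder(dir_path="my_model")

/src/pythae/models/beta_tc_vae/beta_tc_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class BetaTCVAEConfig(VAEConfig): 8 | r""" 9 | :math:`\beta`-TCVAE model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse'].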
Default: 'mse' 15 | alpha (float): The balancing factor before the Index code Mutual Info. Default: 1 16 | beta (float): The balancing factor before the Total Correlation. Default: 1 17 | gamma (float): The balancing factor before the dimension-wise KL. Default: 1 18 | use_mss (bool): Use Minibatch Stratified Sampling. If False: uses Minibatch Weighted 19 | Sampling. Default: True 20 | """ 21 | alpha: float = 1.0 22 | beta: float = 1.0 23 | gamma: float = 1.0 24 | use_mss: bool = True 25 | --------------------------------------------------------------------------------
/src/pythae/models/beta_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the BetaVAE proposed in 2 | (https://openreview.net/pdf?id=Sy2fzU9gl). 3 | This model adds a new parameter to the VAE loss function balancing the weight of the 4 | reconstruction term and the KL term. 5 | 6 | 7 | Available samplers 8 | ------------------- 9 | 10 | .. autosummary:: 11 | ~pythae.samplers.NormalSampler 12 | ~pythae.samplers.GaussianMixtureSampler 13 | ~pythae.samplers.TwoStageVAESampler 14 | ~pythae.samplers.MAFSampler 15 | ~pythae.samplers.IAFSampler 16 | :nosignatures: 17 | """ 18 | 19 | from .beta_vae_config import BetaVAEConfig 20 | from .beta_vae_model import BetaVAE 21 | 22 | __all__ = ["BetaVAE", "BetaVAEConfig"] 23 | --------------------------------------------------------------------------------
/src/pythae/models/beta_vae/beta_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class BetaVAEConfig(VAEConfig): 8 | r""" 9 | :math:`\beta`-VAE model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | beta (float): The balancing factor. Default: 1 16 | """ 17 | 18 | beta: float = 1.0 19 | --------------------------------------------------------------------------------
/src/pythae/models/ciwae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Combination Importance Weighted Autoencoder 2 | proposed in (https://arxiv.org/abs/1802.04537). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .ciwae_config import CIWAEConfig 18 | from .ciwae_model import CIWAE 19 | 20 | __all__ = ["CIWAE", "CIWAEConfig"] 21 | --------------------------------------------------------------------------------
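A minimal sketch for the BetaVAEConfig defined above (values illustrative): a beta larger than 1 upweights the KL term relative to the reconstruction term:

    from pythae.models import BetaVAE, BetaVAEConfig

    config = BetaVAEConfig(input_dim=(1, 28, 28), latent_dim=16, beta=4.0)
    model = BetaVAE(model_config=config)

/src/pythae/models/ciwae/ciwae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class CIWAEConfig(VAEConfig): 8 | """Combination IWAE model config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse'].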
Default: 'mse' 14 | number_samples (int): Number of samples to use in the Monte-Carlo estimation. Default: 10. 15 | beta (float): The value of the factor in the convex combination of the VAE and IWAE ELBO. 16 | Default: 0.5. 17 | """ 18 | 19 | number_samples: int = 10 20 | beta: float = 0.5 21 | 22 | def __post_init__(self): 23 | super().__post_init__() 24 | assert 0 <= self.beta <= 1, f"Beta parameter must be in [0-1]. Got {self.beta}." 25 | --------------------------------------------------------------------------------
/src/pythae/models/disentangled_beta_vae/__init__.py: -------------------------------------------------------------------------------- 1 | r"""This module is the implementation of the Disentangled Beta VAE proposed in 2 | (https://arxiv.org/abs/1804.03599). 3 | This model adds a new parameter to the :math:`\beta`-VAE loss function corresponding to the target 4 | value for the KL between the prior and the posterior distribution. It is progressively increased 5 | throughout training. 6 | 7 | 8 | Available samplers 9 | ------------------- 10 | 11 | .. autosummary:: 12 | ~pythae.samplers.NormalSampler 13 | ~pythae.samplers.GaussianMixtureSampler 14 | ~pythae.samplers.TwoStageVAESampler 15 | ~pythae.samplers.MAFSampler 16 | ~pythae.samplers.IAFSampler 17 | :nosignatures: 18 | """ 19 | 20 | from .disentangled_beta_vae_config import DisentangledBetaVAEConfig 21 | from .disentangled_beta_vae_model import DisentangledBetaVAE 22 | 23 | __all__ = ["DisentangledBetaVAE", "DisentangledBetaVAEConfig"] 24 | --------------------------------------------------------------------------------
/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class DisentangledBetaVAEConfig(VAEConfig): 8 | r""" 9 | Disentangled :math:`\beta`-VAE model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | beta (float): The balancing factor. Default: 10. 16 | C (float): The value of the KL divergence term of the ELBO we wish to approach, measured in 17 | nats. Default: 50. 18 | warmup_epoch (int): The number of epochs during which the KL divergence objective will 19 | increase from 0 to C (should be smaller or equal to nb_epochs). Default: 25 20 | """ 21 | beta: float = 10.0 22 | C: float = 50.0 23 | warmup_epoch: int = 25 24 | --------------------------------------------------------------------------------
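A minimal sketch for the DisentangledBetaVAEConfig above (values illustrative): the KL target C is approached linearly over warmup_epoch epochs:

    from pythae.models import DisentangledBetaVAE, DisentangledBetaVAEConfig

    config = DisentangledBetaVAEConfig(
        input_dim=(1, 28, 28), latent_dim=16, beta=10.0, C=30.0, warmup_epoch=25
    )
    model = DisentangledBetaVAE(model_config=config)

/src/pythae/models/factor_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the FactorVAE proposed in 2 | (https://arxiv.org/abs/1802.05983). 3 | This model adds a new parameter to the VAE loss function balancing the weight of the 4 | reconstruction term and the Total Correlation. 5 | 6 | 7 | Available samplers 8 | ------------------- 9 | 10 | .. autosummary::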
11 | ~pythae.samplers.NormalSampler 12 | ~pythae.samplers.GaussianMixtureSampler 13 | ~pythae.samplers.TwoStageVAESampler 14 | ~pythae.samplers.MAFSampler 15 | ~pythae.samplers.IAFSampler 16 | :nosignatures: 17 | """ 18 | 19 | from .factor_vae_config import FactorVAEConfig 20 | from .factor_vae_model import FactorVAE 21 | 22 | __all__ = ["FactorVAE", "FactorVAEConfig"] 23 | --------------------------------------------------------------------------------
/src/pythae/models/factor_vae/factor_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class FactorVAEConfig(VAEConfig): 8 | r""" 9 | FactorVAE model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | gamma (float): The balancing factor before the Total Correlation. Default: 2 16 | """ 17 | gamma: float = 2.0 18 | uses_default_discriminator: bool = True 19 | --------------------------------------------------------------------------------
/src/pythae/models/factor_vae/factor_vae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FactorVAEDiscriminator(nn.Module): 6 | def __init__(self, latent_dim=16, hidden_units=1000) -> None: 7 | nn.Module.__init__(self) 8 | 9 | self.layers = nn.Sequential( 10 | nn.Linear(latent_dim, hidden_units), 11 | nn.LeakyReLU(0.2), 12 | nn.Linear(hidden_units, hidden_units), 13 | nn.LeakyReLU(0.2), 14 | nn.Linear(hidden_units, hidden_units), 15 | nn.LeakyReLU(0.2), 16 | nn.Linear(hidden_units, hidden_units), 17 | nn.LeakyReLU(0.2), 18 | nn.Linear(hidden_units, hidden_units), 19 | nn.LeakyReLU(0.2), 20 | nn.Linear(hidden_units, 2), 21 | ) 22 | 23 | def forward(self, z: torch.Tensor): 24 | return self.layers(z) 25 | --------------------------------------------------------------------------------
/src/pythae/models/hrq_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Hierarchical Residual Quantization VAE, a hierarchical 2 | extension of the Vector Quantized VAE proposed in (https://arxiv.org/abs/1711.00937). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | Normalizing flows sampler to come. 8 | 9 | .. autosummary:: 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | """ 15 | 16 | from .hrq_vae_config import HRQVAEConfig 17 | from .hrq_vae_model import HRQVAE 18 | 19 | __all__ = ["HRQVAE", "HRQVAEConfig"] 20 | --------------------------------------------------------------------------------
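A quick sanity check of the FactorVAEDiscriminator defined just above; everything here follows directly from the code shown, and the batch size is arbitrary:

    import torch
    from pythae.models.factor_vae.factor_vae_utils import FactorVAEDiscriminator

    disc = FactorVAEDiscriminator(latent_dim=16, hidden_units=1000)
    z = torch.randn(8, 16)
    print(disc(z).shape)  # torch.Size([8, 2]): one logit pair per latent code

/src/pythae/models/hrq_vae/hrq_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from ..ae import AEConfig 5 | 6 | 7 | @dataclass 8 | class HRQVAEConfig(AEConfig): 9 | r""" 10 | Hierarchical Residual Quantization VAE model config class 11 | 12 | Parameters: 13 | input_dim (tuple): The input_data dimension. 14 | latent_dim (int): The latent space dimension. Default: 10. 15 | num_embeddings (int): The number of embedding points. Default: 64 16 | num_levels (int): Depth of hierarchy.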
Default: 4 17 | kl_weight (float): Weighting of the KL term. Default: 0.1 18 | init_scale (float): Magnitude of the embedding initialisation, should roughly match the encoder. Default: 1.0 19 | init_decay_weight (float): Factor by which the magnitude of each successive level is multiplied. Default: 0.5 20 | norm_loss_weight (float): Weighting of the norm loss term. Default: 0.5 21 | norm_loss_scale (float): Scale for the norm loss. Default: 1.5 22 | temp_schedule_gamma (float): Decay constant for the Gumbel temperature - will be (epoch/gamma). Default: 33.333 23 | depth_drop_rate (float): Probability of dropping each level during training. Default: 0.1 24 | """ 25 | 26 | num_embeddings: int = 64 27 | num_levels: int = 4 28 | kl_weight: float = 0.1 29 | init_scale: float = 1.0 30 | init_decay_weight: float = 0.5 31 | norm_loss_weight: Optional[float] = 0.5 32 | norm_loss_scale: float = 1.5 33 | temp_schedule_gamma: float = 33.333 34 | depth_drop_rate: float = 0.1 35 | 36 | def __post_init__(self): 37 | super().__post_init__() 38 | --------------------------------------------------------------------------------
/src/pythae/models/hvae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Hamiltonian VAE proposed in 2 | (https://arxiv.org/abs/1805.11328). 3 | This model combines Hamiltonian Monte Carlo sampling and normalizing flows to improve the 4 | true posterior estimate within the VAE framework. 5 | 6 | Available samplers 7 | ------------------- 8 | 9 | .. autosummary:: 10 | ~pythae.samplers.NormalSampler 11 | ~pythae.samplers.GaussianMixtureSampler 12 | ~pythae.samplers.TwoStageVAESampler 13 | ~pythae.samplers.MAFSampler 14 | ~pythae.samplers.IAFSampler 15 | :nosignatures: 16 | """ 17 | 18 | from .hvae_config import HVAEConfig 19 | from .hvae_model import HVAE 20 | 21 | __all__ = ["HVAE", "HVAEConfig"] 22 | --------------------------------------------------------------------------------
/src/pythae/models/hvae/hvae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class HVAEConfig(VAEConfig): 8 | r"""Hamiltonian Variational Autoencoder config class. 9 | 10 | Parameters: 11 | latent_dim (int): The latent dimension used for the latent space. Default: 10 12 | n_lf (int): The number of leapfrog steps to use in the integrator. Default: 3 13 | eps_lf (float): The leapfrog stepsize. Default: 1e-3 14 | beta_zero (float): The tempering factor in the Hamiltonian Monte Carlo sampler. 15 | Default: 0.3 16 | learn_eps_lf (bool): Whether the leapfrog stepsize should be learned. Default: False 17 | learn_beta_zero (bool): Whether the temperature beta_zero should be learned. Default: False. 18 | """ 19 | n_lf: int = 3 20 | eps_lf: float = 0.001 21 | beta_zero: float = 0.3 22 | learn_eps_lf: bool = False 23 | learn_beta_zero: bool = False 24 | --------------------------------------------------------------------------------
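A minimal sketch for the HVAEConfig above, mirroring the documented defaults (all values illustrative):

    from pythae.models import HVAE, HVAEConfig

    # Three leapfrog steps with a fixed stepsize and tempering factor.
    config = HVAEConfig(
        input_dim=(1, 28, 28), latent_dim=16, n_lf=3, eps_lf=1e-3, beta_zero=0.3
    )
    model = HVAE(model_config=config)

/src/pythae/models/info_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of an Info Variational Autoencoder as proposed in 2 | (https://arxiv.org/abs/1706.02262). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary::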
8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .info_vae_config import INFOVAE_MMD_Config 18 | from .info_vae_model import INFOVAE_MMD 19 | 20 | __all__ = ["INFOVAE_MMD", "INFOVAE_MMD_Config"] 21 | --------------------------------------------------------------------------------
/src/pythae/models/info_vae/info_vae_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List, Union 3 | 4 | from pydantic.dataclasses import dataclass 5 | from typing_extensions import Literal 6 | 7 | from ..vae import VAEConfig 8 | 9 | 10 | @dataclass 11 | class INFOVAE_MMD_Config(VAEConfig): 12 | """Info-VAE model config class. 13 | 14 | Parameters: 15 | input_dim (tuple): The input_data dimension. 16 | latent_dim (int): The latent space dimension. Default: None. 17 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 18 | kernel_choice (str): The kernel to choose. Available options are ['rbf', 'imq'] i.e. 19 | radial basis functions or inverse multiquadratic kernel. Default: 'imq'. 20 | alpha (float): The alpha factor balancing the weight. Default: 0 21 | lbd (float): The lambda factor. Default: 1e-3 22 | kernel_bandwidth (float): The kernel bandwidth. Default: 1 23 | scales (list): The scales to apply if using multi-scale imq kernels. If None, use a unique 24 | imq kernel. Default: [.1, .2, .5, 1., 2., 5, 10.]. 25 | """ 26 | 27 | kernel_choice: Literal["rbf", "imq"] = "imq" 28 | alpha: float = 0 29 | lbd: float = 1e-3 30 | kernel_bandwidth: float = 1.0 31 | scales: Union[List[float], None] = field( 32 | default_factory=lambda: [0.1, 0.2, 0.5, 1.0, 2.0, 5, 10.0] 33 | ) 34 | --------------------------------------------------------------------------------
/src/pythae/models/iwae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Importance Weighted Autoencoder 2 | proposed in (https://arxiv.org/abs/1509.00519v4). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .iwae_config import IWAEConfig 18 | from .iwae_model import IWAE 19 | 20 | __all__ = ["IWAE", "IWAEConfig"] 21 | --------------------------------------------------------------------------------
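A minimal sketch for the INFOVAE_MMD_Config above (values illustrative), keeping the default multi-scale imq kernel for the MMD term:

    from pythae.models import INFOVAE_MMD, INFOVAE_MMD_Config

    config = INFOVAE_MMD_Config(
        input_dim=(1, 28, 28), latent_dim=16, kernel_choice="imq", lbd=1e-3
    )
    model = INFOVAE_MMD(model_config=config)

/src/pythae/models/iwae/iwae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class IWAEConfig(VAEConfig): 8 | """IWAE model config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | number_samples (int): Number of samples to use in the Monte-Carlo estimation.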
Default: 10 15 | """ 16 | 17 | number_samples: int = 10 18 | --------------------------------------------------------------------------------
/src/pythae/models/miwae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Multiply Importance Weighted Autoencoder 2 | proposed in (https://arxiv.org/abs/1802.04537). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .miwae_config import MIWAEConfig 18 | from .miwae_model import MIWAE 19 | 20 | __all__ = ["MIWAE", "MIWAEConfig"] 21 | --------------------------------------------------------------------------------
/src/pythae/models/miwae/miwae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class MIWAEConfig(VAEConfig): 8 | """Multiply IWAE model config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | number_gradient_estimates (int): Number of (M-)estimates to use for the gradient 15 | estimate. Default: 5 16 | number_samples (int): Number of samples to use in the Monte-Carlo estimation. Default: 10 17 | """ 18 | 19 | number_gradient_estimates: int = 5 20 | number_samples: int = 10 21 | --------------------------------------------------------------------------------
/src/pythae/models/msssim_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the MSSSIM_VAE proposed in 2 | (https://arxiv.org/abs/1511.06409). 3 | This model uses a perceptual similarity metric for reconstruction. 4 | 5 | 6 | Available samplers 7 | ------------------- 8 | 9 | .. autosummary:: 10 | ~pythae.samplers.NormalSampler 11 | ~pythae.samplers.GaussianMixtureSampler 12 | ~pythae.samplers.TwoStageVAESampler 13 | ~pythae.samplers.MAFSampler 14 | ~pythae.samplers.IAFSampler 15 | :nosignatures: 16 | """ 17 | 18 | from .msssim_vae_config import MSSSIM_VAEConfig 19 | from .msssim_vae_model import MSSSIM_VAE 20 | 21 | __all__ = ["MSSSIM_VAE", "MSSSIM_VAEConfig"] 22 | --------------------------------------------------------------------------------
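A minimal sketch for the IWAEConfig above (values illustrative): increasing number_samples tightens the importance-weighted bound at extra compute cost:

    from pythae.models import IWAE, IWAEConfig

    config = IWAEConfig(input_dim=(1, 28, 28), latent_dim=16, number_samples=10)
    model = IWAE(model_config=config)

/src/pythae/models/msssim_vae/msssim_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class MSSSIM_VAEConfig(VAEConfig): 8 | r""" 9 | VAE using perceptual similarity metrics (MS-SSIM) model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | beta (float): The balancing factor.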
Default: 1 16 | window_size (int): The size of the sliding window used in the MS-SSIM computation. Default: 11 17 | """ 18 | 19 | beta: float = 1.0 20 | window_size: int = 11 21 | --------------------------------------------------------------------------------
/src/pythae/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this module are stored the main Neural Networks Architectures. 3 | """ 4 | 5 | 6 | from .base_architectures import BaseDecoder, BaseDiscriminator, BaseEncoder, BaseMetric 7 | 8 | __all__ = ["BaseDecoder", "BaseEncoder", "BaseMetric", "BaseDiscriminator"] 9 | --------------------------------------------------------------------------------
/src/pythae/models/nn/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this module are gathered some predefined neural net architectures that may be used on 3 | benchmark datasets such as MNIST, CIFAR or CELEBA. 4 | """ 5 | --------------------------------------------------------------------------------
/src/pythae/models/nn/benchmarks/celeba/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of Neural nets used to perform the benchmark on CELEBA""" 2 | 3 | from .convnets import * 4 | from .resnets import * 5 | 6 | __all__ = [ 7 | "Encoder_Conv_AE_CELEBA", 8 | "Encoder_Conv_VAE_CELEBA", 9 | "Encoder_Conv_SVAE_CELEBA", 10 | "Decoder_Conv_AE_CELEBA", 11 | "Discriminator_Conv_CELEBA", 12 | "Encoder_ResNet_AE_CELEBA", 13 | "Encoder_ResNet_VAE_CELEBA", 14 | "Encoder_ResNet_SVAE_CELEBA", 15 | "Encoder_ResNet_VQVAE_CELEBA", 16 | "Decoder_ResNet_AE_CELEBA", 17 | "Decoder_ResNet_VQVAE_CELEBA", 18 | ] 19 | --------------------------------------------------------------------------------
/src/pythae/models/nn/benchmarks/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of Neural nets used to perform the benchmark on CIFAR""" 2 | 3 | from .convnets import * 4 | from .resnets import * 5 | 6 | __all__ = [ 7 | "Encoder_Conv_AE_CIFAR", 8 | "Encoder_Conv_VAE_CIFAR", 9 | "Encoder_Conv_SVAE_CIFAR", 10 | "Decoder_Conv_AE_CIFAR", 11 | "Discriminator_Conv_CIFAR", 12 | "Encoder_ResNet_AE_CIFAR", 13 | "Encoder_ResNet_VAE_CIFAR", 14 | "Encoder_ResNet_SVAE_CIFAR", 15 | "Encoder_ResNet_VQVAE_CIFAR", 16 | "Decoder_ResNet_AE_CIFAR", 17 | "Decoder_ResNet_VQVAE_CIFAR", 18 | ] 19 | --------------------------------------------------------------------------------
/src/pythae/models/nn/benchmarks/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | """A collection of Neural nets used to perform the benchmark on MNIST""" 2 | 3 | from .convnets import * 4 | from .resnets import * 5 | 6 | __all__ = [ 7 | "Encoder_Conv_AE_MNIST", 8 | "Encoder_Conv_VAE_MNIST", 9 | "Encoder_Conv_SVAE_MNIST", 10 | "Decoder_Conv_AE_MNIST", 11 | "Discriminator_Conv_MNIST", 12 | "Encoder_ResNet_AE_MNIST", 13 | "Encoder_ResNet_VAE_MNIST", 14 | "Encoder_ResNet_SVAE_MNIST", 15 | "Encoder_ResNet_VQVAE_MNIST", 16 | "Decoder_ResNet_AE_MNIST", 17 | "Decoder_ResNet_VQVAE_MNIST", 18 | ] 19 | --------------------------------------------------------------------------------
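A sketch of plugging one of the benchmark encoder/decoder pairs listed above into a model in place of the default networks; the pairing follows the naming in the __all__ lists, and the shapes are illustrative:

    from pythae.models import VAE, VAEConfig
    from pythae.models.nn.benchmarks.mnist import (
        Decoder_Conv_AE_MNIST,
        Encoder_Conv_VAE_MNIST,
    )

    # The benchmark nets are assumed to take the model config as their
    # single constructor argument.
    config = VAEConfig(input_dim=(1, 28, 28), latent_dim=16)
    model = VAE(
        model_config=config,
        encoder=Encoder_Conv_VAE_MNIST(config),
        decoder=Decoder_Conv_AE_MNIST(config),
    )

/src/pythae/models/nn/benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResBlock(nn.Module): 6 | def __init__(self, in_channels, out_channels): 7 | nn.Module.__init__(self) 8 | 9 | self.conv_block = nn.Sequential(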
10 | nn.ReLU(), 11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 12 | nn.ReLU(), 13 | nn.Conv2d(out_channels, in_channels, kernel_size=1, stride=1, padding=0), 14 | ) 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | return x + self.conv_block(x) 18 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this module are implemented some normalizing flows that can be used within the VAE models or for 3 | sampling. 4 | 5 | By convention, each implemented model is stored in a folder located in 6 | :class:`pythae.models.normalizing_flows` and named after the model. The following modules can be 7 | found in this folder: 8 | 9 | - | *modelname_config.py*: Contains a :class:`ModelNameConfig` instance inheriting 10 | from :class:`~pythae.models.normalizing_flows.BaseNFConfig`. 11 | - | *modelname_model.py*: An implementation of the model inheriting from 12 | :class:`~pythae.models.normalizing_flows.BaseNF`. 13 | - *modelname_utils.py* (optional): A module where utils methods are stored. 14 | 15 | All the code is inspired by: 16 | 17 | - | https://github.com/kamenbliznashki/normalizing_flows 18 | - | https://github.com/karpathy/pytorch-normalizing-flows 19 | - | https://github.com/ikostrikov/pytorch-flows 20 | """ 21 | 22 | from .base import BaseNF, BaseNFConfig, NFModel 23 | from .iaf import IAF, IAFConfig 24 | from .made import MADE, MADEConfig 25 | from .maf import MAF, MAFConfig 26 | from .pixelcnn import PixelCNN, PixelCNNConfig 27 | from .planar_flow import PlanarFlow, PlanarFlowConfig 28 | from .radial_flow import RadialFlow, RadialFlowConfig 29 | 30 | __all__ = [ 31 | "BaseNF", 32 | "BaseNFConfig", 33 | "NFModel", 34 | "MADE", 35 | "MADEConfig", 36 | "MAF", 37 | "MAFConfig", 38 | "IAF", 39 | "IAFConfig", 40 | "PlanarFlow", 41 | "PlanarFlowConfig", 42 | "RadialFlow", 43 | "RadialFlowConfig", 44 | "PixelCNN", 45 | "PixelCNNConfig", 46 | ] 47 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/base/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | **Abstract class** 4 | 5 | Base module for Normalizing Flows implementation""" 6 | 7 | from .base_nf_config import BaseNFConfig 8 | from .base_nf_model import BaseNF, NFModel 9 | 10 | __all__ = ["NFModel", "BaseNF", "BaseNFConfig"] 11 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/base/base_nf_config.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from pythae.config import BaseConfig 6 | 7 | 8 | @dataclass 9 | class BaseNFConfig(BaseConfig): 10 | """This is the Base Normalizing Flow config instance. 11 | 12 | Parameters: 13 | input_dim (tuple): The input data dimension. Default: None. 14 | """ 15 | 16 | input_dim: Union[Tuple[int, ...], None] = None 17 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/iaf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Inverse Autoregressive Flows (IAF) proposed 3 | in (https://arxiv.org/abs/1606.04934). 
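A minimal sketch of the folder convention described above for adding a new flow. `MyFlow` and `MyFlowConfig` are hypothetical names, and the assumption that `BaseNF.__init__` accepts the config as `model_config` mirrors the pattern used by the flows in this module:

from pydantic.dataclasses import dataclass

from pythae.models.normalizing_flows import BaseNF, BaseNFConfig


@dataclass
class MyFlowConfig(BaseNFConfig):
    # hypothetical hyper-parameter of the new flow
    n_layers: int = 4


class MyFlow(BaseNF):
    def __init__(self, model_config: MyFlowConfig):
        BaseNF.__init__(self, model_config=model_config)
        self.n_layers = model_config.n_layers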
4 | """ 5 | 6 | from .iaf_config import IAFConfig 7 | from .iaf_model import IAF 8 | 9 | __all__ = ["IAF", "IAFConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/iaf/iaf_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseNFConfig 4 | 5 | 6 | @dataclass 7 | class IAFConfig(BaseNFConfig): 8 | """This is the MADE model configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | n_made_blocks (int): The number of MADE model to consider in the IAF. Default: 2. 13 | n_hidden_in_made (int): The number of hidden layers in the MADE models. Default: 3. 14 | hidden_size (list): The number of unit in each hidder layer. The same number of units is 15 | used across the `n_hidden_in_made` and `n_made_blocks`. Default: 128. 16 | include_batch_norm (bool): Whether to include batch normalization after each 17 | :class:`~pythae.models.normalizing_flows.MADE` layers. Default: False. 18 | """ 19 | 20 | n_made_blocks: int = 2 21 | n_hidden_in_made: int = 3 22 | hidden_size: int = 128 23 | include_batch_norm: bool = False 24 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/made/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Masked Autoencoder model (MADE) proposed 3 | in (https://arxiv.org/abs/1502.03509). 4 | """ 5 | 6 | from .made_config import MADEConfig 7 | from .made_model import MADE 8 | 9 | __all__ = ["MADE", "MADEConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/made/made_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List, Tuple, Union 3 | 4 | from pydantic.dataclasses import dataclass 5 | from typing_extensions import Literal 6 | 7 | from ..base import BaseNFConfig 8 | 9 | 10 | @dataclass 11 | class MADEConfig(BaseNFConfig): 12 | """This is the MADE model configuration instance. 13 | 14 | Parameters: 15 | input_dim (tuple): The input data dimension. Default: None. 16 | output_dim (tuple): The output data dimension. Default: None. 17 | hidden_sizes (list): The list of the number of hidden units in the Autoencoder. 18 | Default: [128]. 19 | degrees_ordering (str): The ordering to use for the mask creation. Can be either 20 | `sequential` or `random`. Default: `sequential`. 
21 | """ 22 | 23 | output_dim: Union[Tuple[int, ...], None] = None 24 | hidden_sizes: List[int] = field(default_factory=lambda: [128]) 25 | degrees_ordering: Literal["sequential", "random"] = "sequential" 26 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/maf/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Masked Autoregressive Flow (MAF) proposed 3 | in (https://arxiv.org/abs/1502.03509) 4 | """ 5 | 6 | from .maf_config import MAFConfig 7 | from .maf_model import MAF 8 | 9 | __all__ = ["MAF", "MAFConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/maf/maf_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseNFConfig 4 | 5 | 6 | @dataclass 7 | class MAFConfig(BaseNFConfig): 8 | """This is the Masked Autoregressive Flows model configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | n_made_blocks (int): The number of MADE model to consider in the MAF. Default: 2. 13 | n_hidden_in_made (int): The number of hidden layers in the MADE models. Default: 3. 14 | hidden_size (list): The number of unit in each hidder layer. The same number of units is 15 | used across the `n_hidden_in_made` and `n_made_blocks` 16 | include_batch_norm (bool): Whether to include batch normalization after each 17 | :class:`~pythae.models.normalizing_flows.MADE` layers. Default: False. 18 | """ 19 | 20 | n_made_blocks: int = 2 21 | n_hidden_in_made: int = 3 22 | hidden_size: int = 128 23 | include_batch_norm: bool = False 24 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/pixelcnn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of PixelCNN model proposed 3 | in (https://arxiv.org/abs/1601.06759) 4 | """ 5 | 6 | from .pixelcnn_config import PixelCNNConfig 7 | from .pixelcnn_model import PixelCNN 8 | 9 | __all__ = ["PixelCNN", "PixelCNNConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseNFConfig 4 | 5 | 6 | @dataclass 7 | class PixelCNNConfig(BaseNFConfig): 8 | """This is the PixelCNN model configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | n_embeddings (int): The number of possible values for the image. Default: 256. 13 | n_layers (int): The number of convolutional layers in the model. Default: 10. 14 | kernel_size (int): The kernel size in the convolutional layers. It must be odd. Default: 5 15 | """ 16 | 17 | n_embeddings: int = 256 18 | n_layers: int = 10 19 | kernel_size: int = 5 20 | 21 | def __post_init__(self): 22 | super().__post_init__() 23 | assert ( 24 | self.kernel_size % 2 == 1 25 | ), f"Wrong kernel size provided. The kernel size must be odd. Got {self.kernel_size}." 
26 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/pixelcnn/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MaskedConv2d(nn.Conv2d): 9 | def __init__( 10 | self, 11 | mask_type: str, 12 | in_channels: int, 13 | out_channels: int, 14 | kernel_size: int, 15 | stride: int = 1, 16 | padding: Union[str, int] = 0, 17 | dilation: int = 1, 18 | groups: int = 1, 19 | bias: bool = True, 20 | padding_mode: str = "zeros", 21 | device=None, 22 | dtype=None, 23 | ) -> None: 24 | super().__init__( 25 | in_channels, 26 | out_channels, 27 | kernel_size, 28 | stride, 29 | padding, 30 | dilation, 31 | groups, 32 | bias, 33 | padding_mode, 34 | device, 35 | dtype, 36 | ) 37 | 38 | self.register_buffer("mask", torch.ones_like(self.weight)) 39 | 40 | _, _, kH, kW = self.weight.shape 41 | 42 | if mask_type == "A": 43 | self.mask[:, :, kH // 2, kW // 2 :] = 0 44 | self.mask[:, :, kH // 2 + 1 :] = 0 45 | 46 | else: 47 | self.mask[:, :, kH // 2, kW // 2 + 1 :] = 0 48 | self.mask[:, :, kH // 2 + 1 :] = 0 49 | 50 | def forward(self, input: torch.Tensor) -> torch.Tensor: 51 | # self.weight.data *= self.mask 52 | 53 | return F.conv2d( 54 | input, 55 | self.weight * self.mask, 56 | self.bias, 57 | self.stride, 58 | self.padding, 59 | self.dilation, 60 | ) 61 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/planar_flow/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Planar Flows (PlanarFlow) proposed 3 | in (https://arxiv.org/abs/1505.05770) 4 | """ 5 | 6 | from .planar_flow_config import PlanarFlowConfig 7 | from .planar_flow_model import PlanarFlow 8 | 9 | __all__ = ["PlanarFlow", "PlanarFlowConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseNFConfig 4 | 5 | 6 | @dataclass 7 | class PlanarFlowConfig(BaseNFConfig): 8 | """This is the PlanarFlow model configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | activation (str): The activation function to be applied. Choices: ['linear', 'tanh', 'elu']. 13 | Default: 'tanh'. 14 | """ 15 | 16 | activation: str = "tanh" 17 | 18 | def __post_init__(self): 19 | super().__post_init__() 20 | assert self.activation in ["linear", "tanh", "elu"], ( 21 | f"'{self.activation}' doesn't correspond to an activation handled by the model. 
" 22 | "Available activations ['linear', 'tanh', 'elu']" 23 | ) 24 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/radial_flow/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Radial Flows (RadialFlow) proposed 3 | in (https://arxiv.org/abs/1505.05770) 4 | """ 5 | 6 | from .radial_flow_config import RadialFlowConfig 7 | from .radial_flow_model import RadialFlow 8 | 9 | __all__ = ["RadialFlow", "RadialFlowConfig"] 10 | -------------------------------------------------------------------------------- /src/pythae/models/normalizing_flows/radial_flow/radial_flow_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseNFConfig 4 | 5 | 6 | @dataclass 7 | class RadialFlowConfig(BaseNFConfig): 8 | """This is the RadialFlow model configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | """ 13 | -------------------------------------------------------------------------------- /src/pythae/models/piwae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Partially Importance Weighted Autoencoder 2 | proposed in (https://arxiv.org/abs/1802.04537). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .piwae_config import PIWAEConfig 18 | from .piwae_model import PIWAE 19 | 20 | __all__ = ["PIWAE", "PIWAEConfig"] 21 | -------------------------------------------------------------------------------- /src/pythae/models/piwae/piwae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class PIWAEConfig(VAEConfig): 8 | """Partially IWAE model config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | number_gradient_estimates (int): Number of (M-)estimates to use for the gradient 15 | estimate for the encoder. Default: 5 16 | number_samples (int): Number of samples to use on the Monte-Carlo estimation. Default: 10 17 | """ 18 | 19 | number_gradient_estimates: int = 5 20 | number_samples: int = 10 21 | -------------------------------------------------------------------------------- /src/pythae/models/pvae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of a Poincaré Disk Variational Autoencoder 2 | (https://arxiv.org/abs/1901.06033). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. 
autosummary:: 8 | ~pythae.samplers.PoincareDiskSampler 9 | ~pythae.samplers.NormalSampler 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.TwoStageVAESampler 12 | ~pythae.samplers.MAFSampler 13 | ~pythae.samplers.IAFSampler 14 | :nosignatures: 15 | """ 16 | 17 | from .pvae_config import PoincareVAEConfig 18 | from .pvae_model import PoincareVAE 19 | 20 | __all__ = ["PoincareVAE", "PoincareVAEConfig"] 21 | -------------------------------------------------------------------------------- /src/pythae/models/pvae/pvae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | from typing_extensions import Literal 3 | 4 | from ..vae import VAEConfig 5 | 6 | 7 | @dataclass 8 | class PoincareVAEConfig(VAEConfig): 9 | """Poincaré VAE config class. 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | prior_distribution (str): The distribution to use as prior 16 | ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal" 17 | posterior_distribution (str): The distribution to use as posterior 18 | ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal" 19 | curvature (float): The curvature of the manifold. Default: 1 20 | """ 21 | 22 | prior_distribution: Literal[ 23 | "wrapped_normal", "riemannian_normal" 24 | ] = "wrapped_normal" 25 | posterior_distribution: Literal[ 26 | "wrapped_normal", "riemannian_normal" 27 | ] = "wrapped_normal" 28 | curvature: float = 1 29 | -------------------------------------------------------------------------------- /src/pythae/models/rae_gp/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Regularized AE with gradient penalty regularization 2 | proposed in (https://arxiv.org/abs/1903.12436). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.MAFSampler 11 | ~pythae.samplers.IAFSampler 12 | :nosignatures: 13 | """ 14 | 15 | from .rae_gp_config import RAE_GP_Config 16 | from .rae_gp_model import RAE_GP 17 | 18 | __all__ = ["RAE_GP", "RAE_GP_Config"] 19 | -------------------------------------------------------------------------------- /src/pythae/models/rae_gp/rae_gp_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..ae import AEConfig 4 | 5 | 6 | @dataclass 7 | class RAE_GP_Config(AEConfig): 8 | """RAE_GP config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | embedding_weight (float): The factor before the L2 embedding regularization term in the 14 | loss. Default: 1e-4 15 | reg_weight (float): The factor before the gradient penalty regularization term. Default: 1e-7 
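A sketch instantiating the RAE_GP_Config documented here, assuming it is re-exported at the `pythae.models` level like the other model configs; the values simply restate the defaults above:

from pythae.models import RAE_GP_Config

config = RAE_GP_Config(
    input_dim=(1, 28, 28),
    latent_dim=16,
    embedding_weight=1e-4,  # L2 embedding regularization factor
    reg_weight=1e-7,        # gradient penalty factor
)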
16 | """ 17 | 18 | embedding_weight: float = 1e-4 19 | reg_weight: float = 1e-7 20 | -------------------------------------------------------------------------------- /src/pythae/models/rae_l2/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Regularized AE with L2 decoder parameter regularization 2 | proposed in (https://arxiv.org/abs/1903.12436). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.MAFSampler 11 | ~pythae.samplers.IAFSampler 12 | :nosignatures: 13 | """ 14 | 15 | from .rae_l2_config import RAE_L2_Config 16 | from .rae_l2_model import RAE_L2 17 | 18 | __all__ = ["RAE_L2", "RAE_L2_Config"] 19 | -------------------------------------------------------------------------------- /src/pythae/models/rae_l2/rae_l2_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..ae import AEConfig 4 | 5 | 6 | @dataclass 7 | class RAE_L2_Config(AEConfig): 8 | """RAE_L2 config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | embedding_weight (float): The factor before the L2 regularization term in the loss. 14 | Default: 1e-4 15 | reg_weight (float): The weight decay to apply. 16 | """ 17 | 18 | embedding_weight: float = 1e-4 19 | reg_weight: float = 1e-7 20 | -------------------------------------------------------------------------------- /src/pythae/models/rhvae/__init__.py: -------------------------------------------------------------------------------- 1 | r"""This is an implementation of the Riemannian Hamiltonian VAE model proposed in 2 | (https://arxiv.org/abs/2105.00026). This model provides a way to 3 | learn the Riemannian latent structure of a given set of data set through a parametrized 4 | Riemannian metric having the following shape: 5 | :math:`\mathbf{G}^{-1}(z) = \sum \limits _{i=1}^N L_{\psi_i} L_{\psi_i}^{\top} \exp 6 | \Big(-\frac{\lVert z - c_i \rVert_2^2}{T^2} \Big) + \lambda I_d` 7 | 8 | It is particularly well suited for High 9 | Dimensional data combined with low sample number and proved relevant for Data Augmentation as 10 | proved in (https://arxiv.org/abs/2105.00026). 11 | 12 | Available samplers 13 | ------------------- 14 | 15 | .. autosummary:: 16 | ~pythae.samplers.RHVAESampler 17 | ~pythae.samplers.NormalSampler 18 | ~pythae.samplers.GaussianMixtureSampler 19 | ~pythae.samplers.TwoStageVAESampler 20 | ~pythae.samplers.MAFSampler 21 | ~pythae.samplers.IAFSampler 22 | :nosignatures: 23 | """ 24 | 25 | from .rhvae_config import RHVAEConfig 26 | from .rhvae_model import RHVAE 27 | 28 | __all__ = ["RHVAE", "RHVAEConfig"] 29 | -------------------------------------------------------------------------------- /src/pythae/models/rhvae/rhvae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class RHVAEConfig(VAEConfig): 8 | r"""RHVAE config class. 9 | 10 | Parameters: 11 | latent_dim (int): The latent dimension used for the latent space. Default: 10 12 | n_lf (int): The number of leapfrog steps to used in the integrator: Default: 3 13 | eps_lf (int): The leapfrog stepsize. 
Default: 1e-3 14 | beta_zero (int): The tempering factor in the Riemannian Hamiltonian Monte Carlo Sampler. 15 | Default: 0.3 16 | temperature (float): The metric temperature :math:`T`. Default: 1.5 17 | regularization (float): The metric regularization factor :math:`\lambda` 18 | """ 19 | n_lf: int = 3 20 | eps_lf: float = 0.001 21 | beta_zero: float = 0.3 22 | temperature: float = 1.5 23 | regularization: float = 0.01 24 | uses_default_metric: bool = True 25 | -------------------------------------------------------------------------------- /src/pythae/models/rhvae/rhvae_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_metric(model): 5 | def G(z): 6 | return torch.inverse( 7 | ( 8 | model.M_tens.unsqueeze(0).to(z.device) 9 | * torch.exp( 10 | -torch.norm( 11 | model.centroids_tens.unsqueeze(0).to(z.device) - z.unsqueeze(1), 12 | dim=-1, 13 | ) 14 | ** 2 15 | / (model.temperature**2) 16 | ) 17 | .unsqueeze(-1) 18 | .unsqueeze(-1) 19 | ).sum(dim=1) 20 | + model.lbd * torch.eye(model.latent_dim).to(z.device) 21 | ) 22 | 23 | return G 24 | 25 | 26 | def create_inverse_metric(model): 27 | def G_inv(z): 28 | return ( 29 | model.M_tens.unsqueeze(0).to(z.device) 30 | * torch.exp( 31 | -torch.norm( 32 | model.centroids_tens.unsqueeze(0).to(z.device) - z.unsqueeze(1), 33 | dim=-1, 34 | ) 35 | ** 2 36 | / (model.temperature**2) 37 | ) 38 | .unsqueeze(-1) 39 | .unsqueeze(-1) 40 | ).sum(dim=1) + model.lbd * torch.eye(model.latent_dim).to(z.device) 41 | 42 | return G_inv 43 | -------------------------------------------------------------------------------- /src/pythae/models/svae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Hyperspherical VAE proposed in 2 | (https://arxiv.org/abs/1804.00891). 3 | This models uses a Hyperspherical latent space. 4 | 5 | 6 | Available samplers 7 | ------------------- 8 | 9 | .. autosummary:: 10 | ~pythae.samplers.HypersphereUniformSampler 11 | :nosignatures: 12 | """ 13 | 14 | from .svae_config import SVAEConfig 15 | from .svae_model import SVAE 16 | 17 | __all__ = ["SVAE", "SVAEConfig"] 18 | -------------------------------------------------------------------------------- /src/pythae/models/svae/svae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class SVAEConfig(VAEConfig): 8 | r""" 9 | :math:`\mathcal{S}`-VAE model config config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension in which lives the hypersphere. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. 
Default: 'mse' 15 | """ 16 | -------------------------------------------------------------------------------- /src/pythae/models/svae/svae_utils.py: -------------------------------------------------------------------------------- 1 | import scipy.special 2 | import torch 3 | 4 | # Using exp(-x)I_v(x) for numerical stability (vanishes in ratios) 5 | 6 | 7 | class ModifiedBesselFn(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, nu, inp): 10 | ctx._nu = nu 11 | ctx.save_for_backward(inp) 12 | inp_cpu = inp.data.cpu().numpy() 13 | return torch.from_numpy(scipy.special.ive(nu, inp_cpu)).to(inp.device) 14 | 15 | @staticmethod 16 | def backward(ctx, grad_out): 17 | inp = ctx.saved_tensors[-1] 18 | nu = ctx._nu 19 | return (None, grad_out * (ive(nu - 1, inp) - ive(nu, inp) * (nu + inp) / inp)) 20 | 21 | 22 | ive = ModifiedBesselFn.apply 23 | -------------------------------------------------------------------------------- /src/pythae/models/vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of a Vanilla Variational Autoencoder 2 | (https://arxiv.org/abs/1312.6114). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | """ 15 | 16 | from .vae_config import VAEConfig 17 | from .vae_model import VAE 18 | 19 | __all__ = ["VAE", "VAEConfig"] 20 | -------------------------------------------------------------------------------- /src/pythae/models/vae/vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | from typing_extensions import Literal 3 | 4 | from ..base.base_config import BaseAEConfig 5 | 6 | 7 | @dataclass 8 | class VAEConfig(BaseAEConfig): 9 | """VAE config class. 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 15 | """ 16 | 17 | reconstruction_loss: Literal["bce", "mse"] = "mse" 18 | -------------------------------------------------------------------------------- /src/pythae/models/vae_gan/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the VAE-GAN model 2 | proposed in (https://arxiv.org/abs/1512.09300). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | 15 | """ 16 | 17 | from .vae_gan_config import VAEGANConfig 18 | from .vae_gan_model import VAEGAN 19 | 20 | __all__ = ["VAEGAN", "VAEGANConfig"] 21 | -------------------------------------------------------------------------------- /src/pythae/models/vae_gan/vae_gan_config.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from pydantic.dataclasses import dataclass 4 | 5 | from ..vae import VAEConfig 6 | 7 | 8 | @dataclass 9 | class VAEGANConfig(VAEConfig): 10 | """VAEGAN model config class. 
11 | 12 | Parameters: 13 | input_dim (tuple): The input_data dimension. 14 | latent_dim (int): The latent space dimension. Default: None. 15 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 16 | adversarial_loss_scale (float): Parameter scaling the adversarial loss 17 | reconstruction_layer (int): The reconstruction layer depth used for reconstruction metric 18 | """ 19 | 20 | adversarial_loss_scale: float = 0.5 21 | reconstruction_layer: int = -1 22 | uses_default_discriminator: bool = True 23 | discriminator_input_dim: Union[Tuple[int, ...], None] = None 24 | equilibrium: float = 0.68 25 | margin: float = 0.4 26 | -------------------------------------------------------------------------------- /src/pythae/models/vae_iaf/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of a Variational Autoencoder with Inverse Autoregressive Flow 2 | to enhance the expressiveness of the posterior distribution. 3 | (https://arxiv.org/abs/1606.04934). 4 | 5 | Available samplers 6 | ------------------- 7 | 8 | .. autosummary:: 9 | ~pythae.samplers.NormalSampler 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.TwoStageVAESampler 12 | ~pythae.samplers.MAFSampler 13 | ~pythae.samplers.IAFSampler 14 | :nosignatures: 15 | """ 16 | 17 | from .vae_iaf_config import VAE_IAF_Config 18 | from .vae_iaf_model import VAE_IAF 19 | 20 | __all__ = ["VAE_IAF", "VAE_IAF_Config"] 21 | -------------------------------------------------------------------------------- /src/pythae/models/vae_iaf/vae_iaf_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class VAE_IAF_Config(VAEConfig): 8 | """VAE with Inverse Autoregressive Normalizing Flow config class. 9 | 10 | Parameters: 11 | input_dim (int): The input_data dimension 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'. 14 | n_made_blocks (int): The number of :class:`~pythae.models.normalizing_flows.MADE` models 15 | to consider in the IAF used in the VAE. Default: 2. 16 | n_hidden_in_made (int): The number of hidden layers in the 17 | :class:`~pythae.models.normalizing_flows.MADE` models composing the IAF in the VAE. 18 | Default: 3. 19 | hidden_size (list): The number of unit in each hidder layerin the 20 | :class:`~pythae.models.normalizing_flows.MADE` model. The same number of 21 | units is used across the `n_hidden_in_made` and `n_made_blocks`. Default: 128. 22 | """ 23 | 24 | n_made_blocks: int = 2 25 | n_hidden_in_made: int = 3 26 | hidden_size: int = 128 27 | -------------------------------------------------------------------------------- /src/pythae/models/vae_lin_nf/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of a Variational Auto Encoder with Normalizing Flow 2 | (https://arxiv.org/pdf/1505.05770.pdf). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. 
autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | """ 15 | 16 | from .vae_lin_nf_config import VAE_LinNF_Config 17 | from .vae_lin_nf_model import VAE_LinNF 18 | 19 | __all__ = ["VAE_LinNF", "VAE_LinNF_Config"] 20 | -------------------------------------------------------------------------------- /src/pythae/models/vae_lin_nf/vae_lin_nf_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List 3 | 4 | from pydantic.dataclasses import dataclass 5 | 6 | from ..vae import VAEConfig 7 | 8 | 9 | @dataclass 10 | class VAE_LinNF_Config(VAEConfig): 11 | """VAE with linear Normalizing Flow config class. 12 | 13 | Parameters: 14 | input_dim (tuple): The input_data dimension 15 | latent_dim (int): The latent space dimension. Default: None. 16 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 17 | flows (List[str]): A list of strings corresponding to the class of each flow to be applied. 18 | Default: ['Planar', 'Planar']. Flow choices: ['Planar', 'Radial']. 19 | """ 20 | 21 | flows: List[str] = field(default_factory=lambda: ["Planar", "Planar"]) 22 | 23 | def __post_init__(self): 24 | super().__post_init__() 25 | for i, f in enumerate(self.flows): 26 | assert f in ["Planar", "Radial"], ( 27 | f"Flow name number {i+1}: '{f}' doesn't correspond " 28 | "to one of the classes. Available linear flows ['Planar', 'Radial']" 29 | ) 30 | -------------------------------------------------------------------------------- /src/pythae/models/vamp/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of a Variational Mixture of Posterior prior 2 | Variational Auto Encoder proposed in (https://arxiv.org/abs/1705.07120). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.TwoStageVAESampler 11 | ~pythae.samplers.VAMPSampler 12 | ~pythae.samplers.MAFSampler 13 | ~pythae.samplers.IAFSampler 14 | :nosignatures: 15 | """ 16 | 17 | from .vamp_config import VAMPConfig 18 | from .vamp_model import VAMP 19 | 20 | __all__ = ["VAMP", "VAMPConfig"] 21 | -------------------------------------------------------------------------------- /src/pythae/models/vamp/vamp_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..vae import VAEConfig 4 | 5 | 6 | @dataclass 7 | class VAMPConfig(VAEConfig): 8 | """VAMP config class. 9 | 10 | Parameters: 11 | input_dim (tuple): The input_data dimension. 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | number_components (int): The number of components to use in the VAMP prior. Default: 50 15 | linear_scheduling_steps (int): The number of warmup steps to perform using a linear 16 | scheduling. 
Default: 0 17 | """ 18 | 19 | number_components: int = 50 20 | linear_scheduling_steps: int = 0 21 | -------------------------------------------------------------------------------- /src/pythae/models/vq_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Vector Quantized VAE proposed in 2 | (https://arxiv.org/abs/1711.00937). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | Normalizing flows sampler to come. 8 | 9 | .. autosummary:: 10 | ~pythae.samplers.GaussianMixtureSampler 11 | ~pythae.samplers.MAFSampler 12 | ~pythae.samplers.IAFSampler 13 | :nosignatures: 14 | """ 15 | 16 | from .vq_vae_config import VQVAEConfig 17 | from .vq_vae_model import VQVAE 18 | 19 | __all__ = ["VQVAE", "VQVAEConfig"] 20 | -------------------------------------------------------------------------------- /src/pythae/models/vq_vae/vq_vae_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..ae import AEConfig 4 | 5 | 6 | @dataclass 7 | class VQVAEConfig(AEConfig): 8 | r""" 9 | Vector Quantized VAE model config class 10 | 11 | Parameters: 12 | input_dim (tuple): The input_data dimension. 13 | latent_dim (int): The latent space dimension. Default: None. 14 | commitment_loss_factor (float): The commitment loss factor in the loss. Default: 0.25. 15 | quantization_loss_factor (float): The quantization loss factor in the loss. Default: 1. 16 | num_embeddings (int): The number of embedding points. Default: 512 17 | use_ema (bool): Whether to use the Exponential Moving Average (EMA) update. Default: False. 18 | decay (float): The decay to apply in the EMA update. Must be in [0, 1]. Default: 0.99. 19 | """ 20 | commitment_loss_factor: float = 0.25 21 | quantization_loss_factor: float = 1.0 22 | num_embeddings: int = 512 23 | use_ema: bool = False 24 | decay: float = 0.99 25 | 26 | def __post_init__(self): 27 | super().__post_init__() 28 | if self.use_ema: 29 | assert 0 <= self.decay <= 1, ( 30 | "The decay in the EMA update must be in [0, 1]. " f"Got {self.decay}." 31 | ) 32 | -------------------------------------------------------------------------------- /src/pythae/models/wae_mmd/__init__.py: -------------------------------------------------------------------------------- 1 | """This module is the implementation of the Wasserstein Autoencoder proposed in 2 | (https://arxiv.org/abs/1711.01558). 3 | 4 | Available samplers 5 | ------------------- 6 | 7 | .. autosummary:: 8 | ~pythae.samplers.NormalSampler 9 | ~pythae.samplers.GaussianMixtureSampler 10 | ~pythae.samplers.MAFSampler 11 | ~pythae.samplers.IAFSampler 12 | :nosignatures: 13 | """ 14 | 15 | from .wae_mmd_config import WAE_MMD_Config 16 | from .wae_mmd_model import WAE_MMD 17 | 18 | __all__ = ["WAE_MMD", "WAE_MMD_Config"] 19 | -------------------------------------------------------------------------------- /src/pythae/models/wae_mmd/wae_mmd_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | from typing import List, Union 3 | 4 | from pydantic.dataclasses import dataclass 5 | from typing_extensions import Literal 6 | 7 | from ..ae import AEConfig 8 | 9 | 10 | @dataclass 11 | class WAE_MMD_Config(AEConfig): 12 | """Wasserstein autoencoder model config class. 13 | 14 | Parameters: 15 | input_dim (tuple): The input_data dimension. 16 | latent_dim (int): The latent space dimension. 
Default: None. 17 | kernel_choice (str): The kernel to choose. Available options are ['rbf', 'imq'] i.e. 18 | radial basis functions or inverse multiquadratic kernel. Default: 'imq'. 19 | reg_weight (float): The weight to apply between reconstruction and Maximum Mean 20 | Discrepancy. Default: 3e-2 21 | kernel_bandwidth (float): The kernel bandwidth. Default: 1 22 | scales (list): The scales to apply if using multi-scale imq kernels. If None, use a unique 23 | imq kernel. Default: [.1, .2, .5, 1., 2., 5, 10.]. 24 | reconstruction_loss_scale (float): Parameter scaling the reconstruction loss. Default: 1 25 | """ 26 | 27 | kernel_choice: Literal["rbf", "imq"] = "imq" 28 | reg_weight: float = 3e-2 29 | kernel_bandwidth: float = 1.0 30 | scales: Union[List[float], None] = field( 31 | default_factory=lambda: [0.1, 0.2, 0.5, 1.0, 2.0, 5, 10.0] 32 | ) 33 | reconstruction_loss_scale: float = 1.0 34 | -------------------------------------------------------------------------------- /src/pythae/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | """The Pipelines module is created to facilitate the use of the library. It provides ways to 2 | perform end-to-end operations such as model training or generation. A typical Pipeline is composed of 3 | several pythae instances which are articulated together. 4 | 5 | A :class:`__call__` function is defined and used to launch the Pipeline. """ 6 | 7 | from .base_pipeline import Pipeline 8 | from .generation import GenerationPipeline 9 | from .training import TrainingPipeline 10 | 11 | __all__ = ["Pipeline", "TrainingPipeline", "GenerationPipeline"] 12 | -------------------------------------------------------------------------------- /src/pythae/pipelines/base_pipeline.py: -------------------------------------------------------------------------------- 1 | class Pipeline: 2 | def __init__(self): 3 | pass 4 | 5 | def __call__(self): 6 | """ 7 | This method runs the Pipeline process and must be implemented in a child class 8 | """ 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /src/pythae/pipelines/pipeline_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/src/pythae/pipelines/pipeline_utils.py -------------------------------------------------------------------------------- /src/pythae/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/src/pythae/py.typed -------------------------------------------------------------------------------- /src/pythae/samplers/base/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | **Abstract class** 3 | 4 | This is the base Sampler architecture module from which all future samplers should inherit. 
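A sketch of the Pipeline pattern described above, with a vanilla VAE and the base trainer. Here `train_dataset` and `eval_dataset` stand for data you provide (e.g. numpy arrays or tensors), and the argument names follow the library's documented usage; exact trainer options may vary across versions:

from pythae.models import VAE, VAEConfig
from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig

model = VAE(model_config=VAEConfig(input_dim=(1, 28, 28), latent_dim=16))
pipeline = TrainingPipeline(
    model=model,
    training_config=BaseTrainerConfig(num_epochs=10, learning_rate=1e-3),
)
pipeline(train_data=train_dataset, eval_data=eval_dataset)  # __call__ launches training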
5 | """ 6 | 7 | from .base_sampler import BaseSampler 8 | from .base_sampler_config import BaseSamplerConfig 9 | 10 | __all__ = ["BaseSampler", "BaseSamplerConfig"] 11 | -------------------------------------------------------------------------------- /src/pythae/samplers/base/base_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from pythae.config import BaseConfig 4 | 5 | 6 | @dataclass 7 | class BaseSamplerConfig(BaseConfig): 8 | """ 9 | BaseSampler config class. 10 | """ 11 | 12 | pass 13 | -------------------------------------------------------------------------------- /src/pythae/samplers/gaussian_mixture/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of a Gaussian mixture sampler. 2 | 3 | Available models: 4 | ------------------ 5 | 6 | .. autosummary:: 7 | ~pythae.models.AE 8 | ~pythae.models.VAE 9 | ~pythae.models.BetaVAE 10 | ~pythae.models.VAE_LinNF 11 | ~pythae.models.VAE_IAF 12 | ~pythae.models.DisentangledBetaVAE 13 | ~pythae.models.FactorVAE 14 | ~pythae.models.BetaTCVAE 15 | ~pythae.models.IWAE 16 | ~pythae.models.MSSSIM_VAE 17 | ~pythae.models.WAE_MMD 18 | ~pythae.models.INFOVAE_MMD 19 | ~pythae.models.VAMP 20 | ~pythae.models.SVAE 21 | ~pythae.models.Adversarial_AE 22 | ~pythae.models.VAEGAN 23 | ~pythae.models.VQVAE 24 | ~pythae.models.HVAE 25 | ~pythae.models.RAE_GP 26 | ~pythae.models.RAE_L2 27 | ~pythae.models.RHVAE 28 | :nosignatures: 29 | """ 30 | 31 | from .gaussian_mixture_config import GaussianMixtureSamplerConfig 32 | from .gaussian_mixture_sampler import GaussianMixtureSampler 33 | 34 | __all__ = ["GaussianMixtureSampler", "GaussianMixtureSamplerConfig"] 35 | -------------------------------------------------------------------------------- /src/pythae/samplers/gaussian_mixture/gaussian_mixture_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class GaussianMixtureSamplerConfig(BaseSamplerConfig): 8 | """Gaussian mixture sampler config class. 9 | 10 | Parameters: 11 | n_components (int): The number of Gaussians in the mixture 12 | """ 13 | 14 | n_components: int = 10 15 | -------------------------------------------------------------------------------- /src/pythae/samplers/hypersphere_uniform_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Basic sampler sampling from a uniform distribution on the hypersphere 2 | in the Autoencoder's latent space. 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. autosummary:: 8 | ~pythae.models.SVAE 9 | :nosignatures: 10 | """ 11 | 12 | from .hypersphere_uniform_config import HypersphereUniformSamplerConfig 13 | from .hypersphere_uniform_sampler import HypersphereUniformSampler 14 | 15 | __all__ = ["HypersphereUniformSampler", "HypersphereUniformSamplerConfig"] 16 | -------------------------------------------------------------------------------- /src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class HypersphereUniformSamplerConfig(BaseSamplerConfig): 8 | """HypershpereUniformSampler config class. 
9 | 10 | N/A 11 | """ 12 | 13 | pass 14 | -------------------------------------------------------------------------------- /src/pythae/samplers/iaf_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Sampler fitting an Inverse Autoregressive Flow (:class:`~pythae.models.normalizing_flows.IAF`) 2 | in the Autoencoder's latent space. 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. autosummary:: 8 | ~pythae.models.AE 9 | ~pythae.models.VAE 10 | ~pythae.models.BetaVAE 11 | ~pythae.models.VAE_LinNF 12 | ~pythae.models.VAE_IAF 13 | ~pythae.models.DisentangledBetaVAE 14 | ~pythae.models.FactorVAE 15 | ~pythae.models.BetaTCVAE 16 | ~pythae.models.IWAE 17 | ~pythae.models.MSSSIM_VAE 18 | ~pythae.models.WAE_MMD 19 | ~pythae.models.INFOVAE_MMD 20 | ~pythae.models.VAMP 21 | ~pythae.models.SVAE 22 | ~pythae.models.Adversarial_AE 23 | ~pythae.models.VAEGAN 24 | ~pythae.models.VQVAE 25 | ~pythae.models.HVAE 26 | ~pythae.models.RAE_GP 27 | ~pythae.models.RAE_L2 28 | ~pythae.models.RHVAE 29 | :nosignatures: 30 | """ 31 | 32 | from .iaf_sampler import IAFSampler 33 | from .iaf_sampler_config import IAFSamplerConfig 34 | 35 | __all__ = ["IAFSampler", "IAFSamplerConfig"] 36 | -------------------------------------------------------------------------------- /src/pythae/samplers/iaf_sampler/iaf_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class IAFSamplerConfig(BaseSamplerConfig): 8 | """This is the IAF sampler model configuration instance. 9 | 10 | Parameters: 11 | n_made_blocks (int): The number of MADE model to consider in the IAF. Default: 2. 12 | n_hidden_in_made (int): The number of hidden layers in the MADE models. Default: 3. 13 | hidden_size (list): The number of unit in each hidder layer. The same number of units is 14 | used across the `n_hidden_in_made` and `n_made_blocks` 15 | include_batch_norm (bool): Whether to include batch normalization after each 16 | :class:`~pythae.models.normalizing_flows.MADE` layers. Default: False. 17 | """ 18 | 19 | n_made_blocks: int = 2 20 | n_hidden_in_made: int = 3 21 | hidden_size: int = 128 22 | include_batch_norm: bool = False 23 | -------------------------------------------------------------------------------- /src/pythae/samplers/maf_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Sampler fitting a Masked Autoregressive Flow (:class:`~pythae.models.normalizing_flows.MAF`) 2 | in the Autoencoder's latent space. 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. 
autosummary:: 8 | ~pythae.models.AE 9 | ~pythae.models.VAE 10 | ~pythae.models.BetaVAE 11 | ~pythae.models.VAE_LinNF 12 | ~pythae.models.VAE_IAF 13 | ~pythae.models.DisentangledBetaVAE 14 | ~pythae.models.FactorVAE 15 | ~pythae.models.BetaTCVAE 16 | ~pythae.models.IWAE 17 | ~pythae.models.MSSSIM_VAE 18 | ~pythae.models.WAE_MMD 19 | ~pythae.models.INFOVAE_MMD 20 | ~pythae.models.VAMP 21 | ~pythae.models.SVAE 22 | ~pythae.models.Adversarial_AE 23 | ~pythae.models.VAEGAN 24 | ~pythae.models.VQVAE 25 | ~pythae.models.HVAE 26 | ~pythae.models.RAE_GP 27 | ~pythae.models.RAE_L2 28 | ~pythae.models.RHVAE 29 | :nosignatures: 30 | """ 31 | 32 | from .maf_sampler import MAFSampler 33 | from .maf_sampler_config import MAFSamplerConfig 34 | 35 | __all__ = ["MAFSampler", "MAFSamplerConfig"] 36 | -------------------------------------------------------------------------------- /src/pythae/samplers/maf_sampler/maf_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class MAFSamplerConfig(BaseSamplerConfig): 8 | """This is the MAF sampler model configuration instance. 9 | 10 | Parameters: 11 | n_made_blocks (int): The number of MADE model to consider in the MAF. Default: 2. 12 | n_hidden_in_made (int): The number of hidden layers in the MADE models. Default: 3. 13 | hidden_size (list): The number of unit in each hidder layer. The same number of units is 14 | used across the `n_hidden_in_made` and `n_made_blocks` 15 | include_batch_norm (bool): Whether to include batch normalization after each 16 | :class:`~pythae.models.normalizing_flows.MADE` layers. Default: False. 17 | """ 18 | 19 | n_made_blocks: int = 2 20 | n_hidden_in_made: int = 3 21 | hidden_size: int = 128 22 | include_batch_norm: bool = False 23 | -------------------------------------------------------------------------------- /src/pythae/samplers/manifold_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of a Manifold sampler proposed in (https://arxiv.org/abs/2105.00026). 2 | 3 | Available models: 4 | ------------------ 5 | 6 | .. autosummary:: 7 | ~pythae.models.RHVAE 8 | :nosignatures: 9 | 10 | """ 11 | 12 | from .rhvae_sampler import RHVAESampler 13 | from .rhvae_sampler_config import RHVAESamplerConfig 14 | 15 | __all__ = ["RHVAESampler", "RHVAESamplerConfig"] 16 | -------------------------------------------------------------------------------- /src/pythae/samplers/manifold_sampler/rhvae_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class RHVAESamplerConfig(BaseSamplerConfig): 8 | """RHVAESampler config class. 9 | 10 | Parameters: 11 | num_samples (int): The number of samples to generate. Default: 1 12 | batch_size (int): The number of samples per batch. Batching is used to speed up 13 | generation and avoid memory overflows. Default: 50 14 | mcmc_steps (int): The number of MCMC steps to use in the latent space HMC sampler. 15 | Default: 100 16 | n_lf (int): The number of leapfrog to use in the integrator of the HMC sampler. 17 | Default: 15 18 | eps_lf (float): The leapfrog stepsize in the integrator of the HMC sampler. Default: 3e-2 19 | random_start (bool): Initialization of the latent space sampler. 
If False, the sampler 20 | starts the Markov chain on the metric centroids. If True , a random start is applied. 21 | Default: False 22 | """ 23 | 24 | mcmc_steps_nbr: int = 100 25 | n_lf: int = 15 26 | eps_lf: float = 0.03 27 | beta_zero: float = 1.0 28 | -------------------------------------------------------------------------------- /src/pythae/samplers/normal_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """Basic sampler sampling from a N(0, 1) in the Autoencoder's latent space. 2 | 3 | Available models: 4 | ------------------ 5 | 6 | .. autosummary:: 7 | ~pythae.models.AE 8 | ~pythae.models.VAE 9 | ~pythae.models.BetaVAE 10 | ~pythae.models.VAE_LinNF 11 | ~pythae.models.VAE_IAF 12 | ~pythae.models.DisentangledBetaVAE 13 | ~pythae.models.FactorVAE 14 | ~pythae.models.BetaTCVAE 15 | ~pythae.models.IWAE 16 | ~pythae.models.MSSSIM_VAE 17 | ~pythae.models.WAE_MMD 18 | ~pythae.models.INFOVAE_MMD 19 | ~pythae.models.VAMP 20 | ~pythae.models.SVAE 21 | ~pythae.models.Adversarial_AE 22 | ~pythae.models.VAEGAN 23 | ~pythae.models.VQVAE 24 | ~pythae.models.HVAE 25 | ~pythae.models.RAE_GP 26 | ~pythae.models.RAE_L2 27 | ~pythae.models.RHVAE 28 | :nosignatures: 29 | """ 30 | 31 | from .normal_config import NormalSamplerConfig 32 | from .normal_sampler import NormalSampler 33 | 34 | __all__ = ["NormalSampler", "NormalSamplerConfig"] 35 | -------------------------------------------------------------------------------- /src/pythae/samplers/normal_sampling/normal_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class NormalSamplerConfig(BaseSamplerConfig): 8 | """NormalSampler config class. 9 | 10 | N/A 11 | """ 12 | 13 | pass 14 | -------------------------------------------------------------------------------- /src/pythae/samplers/pixelcnn_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Sampler fitting a :class:`~pythae.models.normalizing_flows.PixelCNN` 2 | in the VQVAE's latent space. 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. autosummary:: 8 | ~pythae.models.VQVAE 9 | :nosignatures: 10 | """ 11 | 12 | from .pixelcnn_sampler import PixelCNNSampler 13 | from .pixelcnn_sampler_config import PixelCNNSamplerConfig 14 | 15 | __all__ = ["PixelCNNSampler", "PixelCNNSamplerConfig"] 16 | -------------------------------------------------------------------------------- /src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class PixelCNNSamplerConfig(BaseSamplerConfig): 8 | """This is the PixelCNN sampler configuration instance. 9 | 10 | Parameters: 11 | input_dim (tuple): The input data dimension. Default: None. 12 | n_layers (int): The number of convolutional layers in the model. Default: 10. 13 | kernel_size (int): The kernel size in the convolutional layers. It must be odd. Default: 5 14 | """ 15 | 16 | n_layers: int = 10 17 | kernel_size: int = 5 18 | 19 | def __post_init__(self): 20 | super().__post_init__() 21 | assert ( 22 | self.kernel_size % 2 == 1 23 | ), f"Wrong kernel size provided. The kernel size must be odd. Got {self.kernel_size}." 
24 | -------------------------------------------------------------------------------- /src/pythae/samplers/pvae_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of the sampling scheme drawing from a wrapped normal or Riemannian normal 2 | distribution on the Poincaré Disk as proposed in (https://arxiv.org/abs/1901.06033). 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. autosummary:: 8 | ~pythae.models.PoincareVAE 9 | :nosignatures: 10 | """ 11 | 12 | from .pvae_sampler import PoincareDiskSampler 13 | from .pvae_sampler_config import PoincareDiskSamplerConfig 14 | 15 | __all__ = ["PoincareDiskSampler", "PoincareDiskSamplerConfig"] 16 | -------------------------------------------------------------------------------- /src/pythae/samplers/pvae_sampler/pvae_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class PoincareDiskSamplerConfig(BaseSamplerConfig): 8 | """This is the Poincare Disk prior sampler configuration instance deriving from 9 | :class:`BaseSamplerConfig`. 10 | """ 11 | 12 | pass 13 | -------------------------------------------------------------------------------- /src/pythae/samplers/two_stage_vae_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of a Two Stage VAE sampler as proposed in 2 | (https://openreview.net/pdf?id=B1e0X3C9tQ). 3 | 4 | Available models: 5 | ------------------ 6 | 7 | .. autosummary:: 8 | ~pythae.models.VAE 9 | ~pythae.models.BetaVAE 10 | ~pythae.models.VAE_IAF 11 | ~pythae.models.DisentangledBetaVAE 12 | ~pythae.models.FactorVAE 13 | ~pythae.models.BetaTCVAE 14 | ~pythae.models.IWAE 15 | ~pythae.models.MSSSIM_VAE 16 | ~pythae.models.WAE_MMD 17 | ~pythae.models.INFOVAE_MMD 18 | ~pythae.models.VAMP 19 | ~pythae.models.VAEGAN 20 | ~pythae.models.HVAE 21 | ~pythae.models.RHVAE 22 | :nosignatures: 23 | 24 | """ 25 | 26 | from .two_stage_sampler import TwoStageVAESampler 27 | from .two_stage_sampler_config import TwoStageVAESamplerConfig 28 | 29 | __all__ = ["TwoStageVAESampler", "TwoStageVAESamplerConfig"] 30 | -------------------------------------------------------------------------------- /src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | from typing_extensions import Literal 3 | 4 | from ..base import BaseSamplerConfig 5 | 6 | 7 | @dataclass 8 | class TwoStageVAESamplerConfig(BaseSamplerConfig): 9 | """Two Stage VAE sampler config class. 10 | 11 | Parameters: 12 | latent_dim (int): The latent space dimension. Default: None. 13 | reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' 14 | second_stage_depth (int): The number of layers in the second stage VAE. Default: 2 15 | second_layers_dim (int): The size of the fully connected layers to use in the second VAE 16 | architecture. Default: 1024 
17 | """ 18 | 19 | reconstruction_loss: Literal["bce", "mse"] = "mse" 20 | second_stage_depth: int = 2 21 | second_layers_dim: int = 1024 22 | -------------------------------------------------------------------------------- /src/pythae/samplers/vamp_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """Implementation of a VAMP prior sampler as proposed in (https://arxiv.org/abs/1705.07120). 2 | 3 | Available models: 4 | ------------------ 5 | 6 | .. autosummary:: 7 | ~pythae.models.VAMP 8 | :nosignatures: 9 | """ 10 | 11 | from .vamp_sampler import VAMPSampler 12 | from .vamp_sampler_config import VAMPSamplerConfig 13 | 14 | __all__ = ["VAMPSampler", "VAMPSamplerConfig"] 15 | -------------------------------------------------------------------------------- /src/pythae/samplers/vamp_sampler/vamp_sampler_config.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass 2 | 3 | from ..base import BaseSamplerConfig 4 | 5 | 6 | @dataclass 7 | class VAMPSamplerConfig(BaseSamplerConfig): 8 | """This is the VAMP prior sampler configuration instance deriving from 9 | :class:`BaseSamplerConfig`. 10 | """ 11 | 12 | pass 13 | -------------------------------------------------------------------------------- /src/pythae/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ Here are implemented the trainers used to train the Autoencoder models 2 | """ 3 | 4 | from .adversarial_trainer import AdversarialTrainer, AdversarialTrainerConfig 5 | from .base_trainer import BaseTrainer, BaseTrainerConfig 6 | from .coupled_optimizer_adversarial_trainer import ( 7 | CoupledOptimizerAdversarialTrainer, 8 | CoupledOptimizerAdversarialTrainerConfig, 9 | ) 10 | from .coupled_optimizer_trainer import ( 11 | CoupledOptimizerTrainer, 12 | CoupledOptimizerTrainerConfig, 13 | ) 14 | 15 | __all__ = [ 16 | "BaseTrainer", 17 | "BaseTrainerConfig", 18 | "CoupledOptimizerTrainer", 19 | "CoupledOptimizerTrainerConfig", 20 | "AdversarialTrainer", 21 | "AdversarialTrainerConfig", 22 | "CoupledOptimizerAdversarialTrainer", 23 | "CoupledOptimizerAdversarialTrainerConfig", 24 | ] 25 | -------------------------------------------------------------------------------- /src/pythae/trainers/adversarial_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the trainer to be used when using adversarial models. It uses two distinct 2 | optimizers, one for the encoder and decoder of the AE and one for the discriminator. 3 | It is suitable for adversarial models. 4 | 5 | Available models: 6 | ------------------ 7 | 8 | .. autosummary:: 9 | ~pythae.models.Adversarial_AE 10 | ~pythae.models.FactorVAE 11 | :nosignatures: 12 | """ 13 | 14 | from .adversarial_trainer import AdversarialTrainer 15 | from .adversarial_trainer_config import AdversarialTrainerConfig 16 | 17 | __all__ = ["AdversarialTrainer", "AdversarialTrainerConfig"] 18 | -------------------------------------------------------------------------------- /src/pythae/trainers/base_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the base trainer allowing you to train the models implemented in pythae. 2 | 3 | Available models: 4 | ------------------ 5 | 6 | .. 
autosummary:: 7 | ~pythae.models.AE 8 | ~pythae.models.VAE 9 | ~pythae.models.BetaVAE 10 | ~pythae.models.DisentangledBetaVAE 11 | ~pythae.models.BetaTCVAE 12 | ~pythae.models.IWAE 13 | ~pythae.models.MSSSIM_VAE 14 | ~pythae.models.INFOVAE_MMD 15 | ~pythae.models.WAE_MMD 16 | ~pythae.models.VAMP 17 | ~pythae.models.SVAE 18 | ~pythae.models.VQVAE 19 | ~pythae.models.RAE_GP 20 | ~pythae.models.HVAE 21 | ~pythae.models.RHVAE 22 | :nosignatures: 23 | """ 24 | 25 | from .base_trainer import BaseTrainer 26 | from .base_training_config import BaseTrainerConfig 27 | 28 | __all__ = ["BaseTrainer", "BaseTrainerConfig"] 29 | -------------------------------------------------------------------------------- /src/pythae/trainers/coupled_optimizer_adversarial_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the trainer to be used with adversarial models. Contrary to 2 | :class:`~pythae.trainers.AdversarialTrainer`, it uses *three* distinct 3 | optimizers, one for the encoder, one for the decoder of the AE and one for the discriminator. 4 | It is suitable for GAN-based models. 5 | 6 | Available models: 7 | ------------------ 8 | 9 | .. autosummary:: 10 | ~pythae.models.Adversarial_AE 11 | ~pythae.models.VAEGAN 12 | :nosignatures: 13 | """ 14 | 15 | from .coupled_optimizer_adversarial_trainer import CoupledOptimizerAdversarialTrainer 16 | from .coupled_optimizer_adversarial_trainer_config import ( 17 | CoupledOptimizerAdversarialTrainerConfig, 18 | ) 19 | 20 | __all__ = [ 21 | "CoupledOptimizerAdversarialTrainer", 22 | "CoupledOptimizerAdversarialTrainerConfig", 23 | ] 24 | -------------------------------------------------------------------------------- /src/pythae/trainers/coupled_optimizer_trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements the dual optimizer trainer using two distinct optimizers for the encoder 2 | and the decoder. It is suitable for all models and should in particular be used to train a 3 | :class:`~pythae.models.RAE_L2`. 4 | 5 | Available models: 6 | ------------------ 7 | 8 | ..
autosummary:: 9 | ~pythae.models.AE 10 | ~pythae.models.VAE 11 | ~pythae.models.BetaVAE 12 | ~pythae.models.DisentangledBetaVAE 13 | ~pythae.models.BetaTCVAE 14 | ~pythae.models.IWAE 15 | ~pythae.models.MSSSIM_VAE 16 | ~pythae.models.INFOVAE_MMD 17 | ~pythae.models.WAE_MMD 18 | ~pythae.models.VAMP 19 | ~pythae.models.SVAE 20 | ~pythae.models.VQVAE 21 | ~pythae.models.RAE_GP 22 | ~pythae.models.RAE_L2 23 | ~pythae.models.HVAE 24 | ~pythae.models.RHVAE 25 | :nosignatures: 26 | """ 27 | 28 | from .coupled_optimizer_trainer import CoupledOptimizerTrainer 29 | from .coupled_optimizer_trainer_config import CoupledOptimizerTrainerConfig 30 | 31 | __all__ = ["CoupledOptimizerTrainer", "CoupledOptimizerTrainerConfig"] 32 | -------------------------------------------------------------------------------- /src/pythae/trainers/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_seed(seed: int): 8 | """ 9 | Function setting the seed for reproducibility on ``random``, ``numpy``, 10 | and ``torch``. 11 | 12 | Args: 13 | 14 | seed (int): The seed to be applied 15 | """ 16 | 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ## Tests 2 | 3 | To run the tests locally, first install pytest 4 | 5 | ```bash 6 | $ pip install pytest 7 | ``` 8 | and then launch the tests. For instance, from this folder run 9 | ```bash 10 | $ pytest . 11 | ``` 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption( 6 | "--runslow", action="store_true", default=False, help="run slow tests" 7 | ) 8 | 9 | 10 | def pytest_collection_modifyitems(config, items): 11 | if config.getoption("--runslow"): 12 | # --runslow given in cli: do not skip slow tests 13 | return 14 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 15 | for item in items: 16 | if "slow" in item.keywords: 17 | item.add_marker(skip_slow) 18 | -------------------------------------------------------------------------------- /tests/data/baseAE/configs/corrupted_model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseAEConfig", 3 | "latnt_dim": 11 4 | } -------------------------------------------------------------------------------- /tests/data/baseAE/configs/generation_config00.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseSamplerConfig" 3 | } -------------------------------------------------------------------------------- /tests/data/baseAE/configs/model_config00.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseAEConfig", 3 | "latent_dim": 11 4 | }
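The `conftest.py` above pairs with the `slow` marker registered in `tests/pytest.ini` (further below): without `--runslow`, marked tests are collected but skipped. A minimal sketch of a test opting into that mechanism, reusing `set_seed` from `trainer_utils.py` for reproducibility; the test body itself is illustrative only:

```python
import pytest

from pythae.trainers.trainer_utils import set_seed


@pytest.mark.slow
def test_long_training_run():
    # Skipped under a plain `pytest .`; runs only with `pytest . --runslow`,
    # because pytest_collection_modifyitems adds a skip marker otherwise.
    set_seed(42)  # seeds random, numpy and torch (incl. CUDA) in one call
    assert True  # placeholder for an expensive end-to-end check
```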
-------------------------------------------------------------------------------- /tests/data/baseAE/configs/not_json_file.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/baseAE/configs/not_json_file.md -------------------------------------------------------------------------------- /tests/data/baseAE/configs/training_config00.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BaseTrainerConfig", 3 | "per_device_train_batch_size": 13, 4 | "per_device_eval_batch_size": 42, 5 | "num_epochs": 2, 6 | "learning_rate": 1e-5 7 | } -------------------------------------------------------------------------------- /tests/data/corrupted_config/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "BadName" 3 | } -------------------------------------------------------------------------------- /tests/data/loading/dummy_data_folder/example0.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/loading/dummy_data_folder/example0.bmp -------------------------------------------------------------------------------- /tests/data/loading/dummy_data_folder/example0.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/loading/dummy_data_folder/example0.jpeg -------------------------------------------------------------------------------- /tests/data/loading/dummy_data_folder/example0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/loading/dummy_data_folder/example0.jpg -------------------------------------------------------------------------------- /tests/data/loading/dummy_data_folder/example0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/loading/dummy_data_folder/example0.png -------------------------------------------------------------------------------- /tests/data/loading/dummy_data_folder/example0_downsampled_12_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/loading/dummy_data_folder/example0_downsampled_12_12.jpg -------------------------------------------------------------------------------- /tests/data/mnist_clean_train_dataset_sample: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/mnist_clean_train_dataset_sample -------------------------------------------------------------------------------- /tests/data/rhvae/configs/model_config00.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_dim": 11, 3 | "n_lf": 2, 4 | "eps_lf": 0.00001, 5 | "temperature": 0.5, 6 | "regularization": 0.1, 7 | 
"beta_zero": 0.8 8 | } -------------------------------------------------------------------------------- /tests/data/rhvae/configs/trained_model_folder/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/rhvae/configs/trained_model_folder/model.pt -------------------------------------------------------------------------------- /tests/data/rhvae/configs/trained_model_folder/model_config.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/rhvae/configs/trained_model_folder/model_config.json -------------------------------------------------------------------------------- /tests/data/unnormalized_mnist_data_array: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/unnormalized_mnist_data_array -------------------------------------------------------------------------------- /tests/data/unnormalized_mnist_data_list_of_array: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/data/unnormalized_mnist_data_list_of_array -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | slow: mark test as slow. -------------------------------------------------------------------------------- /tests/test_auto_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from pythae.models import AutoModel 6 | 7 | PATH = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | 10 | @pytest.fixture() 11 | def corrupted_config(): 12 | return os.path.join(PATH, "data", "corrupted_config") 13 | 14 | 15 | def test_raises_file_not_found(): 16 | 17 | with pytest.raises(FileNotFoundError): 18 | AutoModel.load_from_folder("wrong_file_dir") 19 | 20 | 21 | def test_raises_name_error(corrupted_config): 22 | with pytest.raises(NameError): 23 | AutoModel.load_from_folder(corrupted_config) 24 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | from pythae.data.datasets import BaseDataset 7 | 8 | PATH = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | 11 | @pytest.fixture 12 | def data(): 13 | return torch.randn(1000, 2) 14 | 15 | 16 | @pytest.fixture 17 | def labels(): 18 | return torch.randint(10, (1000,)) 19 | 20 | 21 | class Test_Dataset: 22 | def test_dataset_call(self, data, labels): 23 | dataset = BaseDataset(data, labels) 24 | 25 | assert torch.all(dataset[0]["data"] == data[0]) 26 | assert torch.all(dataset[1]["labels"] == labels[1]) 27 | assert torch.all(dataset.data == data) 28 | assert torch.all(dataset.labels == labels) 29 | 30 | index = torch.randperm(1000)[:3] 31 | 32 | assert torch.all(dataset[index]["data"] == data[index]) 33 | assert torch.all(dataset[index]["labels"] == labels[index]) 34 | 
-------------------------------------------------------------------------------- /tests/your_file.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clementchadebec/benchmark_VAE/6419e21558f2a6abc2da99944bddda846ded30f4/tests/your_file.jpeg --------------------------------------------------------------------------------
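Finally, `tests/test_auto_model.py` above documents `AutoModel`'s failure modes: a missing folder raises `FileNotFoundError`, and an unrecognized `name` in `model_config.json` (like the `BadName` fixture) raises `NameError`. The happy path is the converse; a minimal sketch with a placeholder path:

```python
from pythae.models import AutoModel

# AutoModel dispatches on the "name" field of model_config.json in the folder
# (e.g. "BaseAEConfig" in the fixtures above) and rebuilds the matching model.
model = AutoModel.load_from_folder("path/to/trained_model_folder")  # placeholder path
```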