├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── PyTorch-VAE.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE.md
├── README.md
├── assets
│   ├── BetaTCVAE_49.png
│   ├── BetaVAE_B_20.png
│   ├── BetaVAE_B_35.png
│   ├── BetaVAE_H_20.png
│   ├── CategoricalVAE_49.png
│   ├── ConditionalVAE_20.png
│   ├── DFCVAE_49.png
│   ├── DIPVAE_83.png
│   ├── IWAE_19.png
│   ├── InfoVAE_31.png
│   ├── InfoVAE_7.png
│   ├── JointVAE_49.png
│   ├── LogCoshVAE_49.png
│   ├── MIWAE_29.png
│   ├── MSSIMVAE_29.png
│   ├── SWAE_49.png
│   ├── Vanilla VAE_25.png
│   ├── WAE_IMQ_15.png
│   ├── WAE_RBF_18.png
│   ├── recons_BetaTCVAE_49.png
│   ├── recons_BetaVAE_B_20.png
│   ├── recons_BetaVAE_B_35.png
│   ├── recons_BetaVAE_H_20.png
│   ├── recons_CategoricalVAE_49.png
│   ├── recons_ConditionalVAE_20.png
│   ├── recons_DFCVAE_49.png
│   ├── recons_DIPVAE_83.png
│   ├── recons_IWAE_19.png
│   ├── recons_InfoVAE_31.png
│   ├── recons_JointVAE_49.png
│   ├── recons_LogCoshVAE_49.png
│   ├── recons_MIWAE_29.png
│   ├── recons_MSSIMVAE_29.png
│   ├── recons_SWAE_49.png
│   ├── recons_VQVAE_29.png
│   ├── recons_Vanilla VAE_25.png
│   ├── recons_WAE_IMQ_15.png
│   └── recons_WAE_RBF_19.png
├── configs
│   ├── bbvae.yaml
│   ├── betatc_vae.yaml
│   ├── bhvae.yaml
│   ├── cat_vae.yaml
│   ├── cvae.yaml
│   ├── dfc_vae.yaml
│   ├── dip_vae.yaml
│   ├── factorvae.yaml
│   ├── gammavae.yaml
│   ├── hvae.yaml
│   ├── infovae.yaml
│   ├── iwae.yaml
│   ├── joint_vae.yaml
│   ├── logcosh_vae.yaml
│   ├── lvae.yaml
│   ├── miwae.yaml
│   ├── mssim_vae.yaml
│   ├── swae.yaml
│   ├── vae.yaml
│   ├── vampvae.yaml
│   ├── vq_vae.yaml
│   ├── wae_mmd_imq.yaml
│   └── wae_mmd_rbf.yaml
├── dataset.py
├── experiment.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── beta_vae.py
│   ├── betatc_vae.py
│   ├── cat_vae.py
│   ├── cvae.py
│   ├── dfcvae.py
│   ├── dip_vae.py
│   ├── fvae.py
│   ├── gamma_vae.py
│   ├── hvae.py
│   ├── info_vae.py
│   ├── iwae.py
│   ├── joint_vae.py
│   ├── logcosh_vae.py
│   ├── lvae.py
│   ├── miwae.py
│   ├── mssim_vae.py
│   ├── swae.py
│   ├── twostage_vae.py
│   ├── types_.py
│   ├── vampvae.py
│   ├── vanilla_vae.py
│   ├── vq_vae.py
│   └── wae_mmd.py
├── requirements.txt
├── run.py
├── tests
│   ├── bvae.py
│   ├── test_betatcvae.py
│   ├── test_cat_vae.py
│   ├── test_dfc.py
│   ├── test_dipvae.py
│   ├── test_fvae.py
│   ├── test_gvae.py
│   ├── test_hvae.py
│   ├── test_iwae.py
│   ├── test_joint_Vae.py
│   ├── test_logcosh.py
│   ├── test_lvae.py
│   ├── test_miwae.py
│   ├── test_mssimvae.py
│   ├── test_swae.py
│   ├── test_vae.py
│   ├── test_vq_vae.py
│   ├── test_wae.py
│   ├── text_cvae.py
│   └── text_vamp.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

Data/
logs/

VanillaVAE/version_0/

__pycache__/
.ipynb_checkpoints/

Run.ipynb

--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/workspace.xml

--------------------------------------------------------------------------------
/.idea/PyTorch-VAE.iml, /.idea/inspectionProfiles/profiles_settings.xml,
/.idea/misc.xml, /.idea/modules.xml, /.idea/vcs.xml:
--------------------------------------------------------------------------------
(JetBrains IDE project metadata; the XML content of these files is not
preserved in this dump.)
--------------------------------------------------------------------------------
/assets/*.png:
--------------------------------------------------------------------------------
(Binary sample and reconstruction images, one pair per model, as listed in the
tree above. Each file is served from
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/<filename>.png)
--------------------------------------------------------------------------------
/configs/bbvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'BetaVAE'
  in_channels: 3
  latent_dim: 128
  loss_type: 'B'
  gamma: 10.0
  max_capacity: 25
  Capacity_max_iter: 10000

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  manual_seed: 1265
  name: 'BetaVAE'

--------------------------------------------------------------------------------
/configs/betatc_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'BetaTCVAE'
  in_channels: 3
  latent_dim: 10
  anneal_steps: 10000
  alpha: 1.
  beta: 6.
  gamma: 1.

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'BetaTCVAE'
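All of these config files share the same top-level layout: model_params are
forwarded to the model constructor, data_params to the data module, exp_params
to the training wrapper, and trainer_params to the PyTorch Lightning Trainer.
A minimal sketch of loading one with PyYAML (run.py, which does the actual
parsing, is not shown in this section):

    import yaml

    with open("configs/bbvae.yaml") as f:
        config = yaml.safe_load(f)

    print(config["model_params"]["name"])      # 'BetaVAE'
    print(config["exp_params"]["kld_weight"])  # 0.00025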
--------------------------------------------------------------------------------
/configs/bhvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'BetaVAE'
  in_channels: 3
  latent_dim: 128
  loss_type: 'H'
  beta: 10.

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'BetaVAE'

--------------------------------------------------------------------------------
/configs/cat_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'CategoricalVAE'
  in_channels: 3
  latent_dim: 512
  categorical_dim: 40
  temperature: 0.5
  anneal_rate: 0.00003
  anneal_interval: 100
  alpha: 1.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "CategoricalVAE"

--------------------------------------------------------------------------------
/configs/cvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'ConditionalVAE'
  in_channels: 3
  num_classes: 40
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "ConditionalVAE"

--------------------------------------------------------------------------------
/configs/dfc_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'DFCVAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "DFCVAE"

--------------------------------------------------------------------------------
/configs/dip_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'DIPVAE'
  in_channels: 3
  latent_dim: 128
  lambda_diag: 0.05
  lambda_offdiag: 0.1

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.001
  weight_decay: 0.0
  scheduler_gamma: 0.97
  kld_weight: 1
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "DIPVAE"
  manual_seed: 1265
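Because every model constructor below accepts **kwargs, a model_params block
can be splatted directly into the class selected by its name key. A sketch,
assuming the vae_models lookup table defined later in models/__init__.py:

    import yaml
    from models import vae_models

    with open("configs/dip_vae.yaml") as f:
        config = yaml.safe_load(f)

    # 'name' picks the class; the remaining keys (and 'name' itself,
    # absorbed by **kwargs) become constructor arguments.
    model = vae_models[config["model_params"]["name"]](**config["model_params"])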
--------------------------------------------------------------------------------
/configs/factorvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'FactorVAE'
  in_channels: 3
  latent_dim: 128
  gamma: 6.4

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  submodel: 'discriminator'
  retain_first_backpass: True
  LR: 0.005
  weight_decay: 0.0
  LR_2: 0.005
  scheduler_gamma_2: 0.95
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "FactorVAE"

--------------------------------------------------------------------------------
/configs/gammavae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'GammaVAE'
  in_channels: 3
  latent_dim: 128
  gamma_shape: 8.
  prior_shape: 2.
  prior_rate: 1.

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.003
  weight_decay: 0.00005
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10
  gradient_clip_val: 0.8

logging_params:
  save_dir: "logs/"
  name: "GammaVAE"

--------------------------------------------------------------------------------
/configs/hvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'HVAE'
  in_channels: 3
  latent1_dim: 64
  latent2_dim: 64
  pseudo_input_size: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "HVAE"

--------------------------------------------------------------------------------
/configs/infovae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'InfoVAE'
  in_channels: 3
  latent_dim: 128
  reg_weight: 110   # MMD weight
  kernel_type: 'imq'
  alpha: -9.0       # KLD weight
  beta: 10.5        # Reconstruction weight

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10
  gradient_clip_val: 0.8

logging_params:
  save_dir: "logs/"
  name: "InfoVAE"
  manual_seed: 1265
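The scheduler_gamma key in exp_params drives an ExponentialLR scheduler (see
configure_optimizers in experiment.py below), so the learning rate is
multiplied by gamma after every epoch:

    # lr at epoch t is LR * gamma ** t; for LR = 0.005 and gamma = 0.95
    # the rate decays to roughly 0.0031 by the last of the 10 epochs.
    lrs = [0.005 * 0.95 ** t for t in range(10)]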
--------------------------------------------------------------------------------
/configs/iwae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'IWAE'
  in_channels: 3
  latent_dim: 128
  num_samples: 5

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.007
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "IWAE"

--------------------------------------------------------------------------------
/configs/joint_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'JointVAE'
  in_channels: 3
  latent_dim: 512
  categorical_dim: 40
  latent_min_capacity: 0.0
  latent_max_capacity: 20.0
  latent_gamma: 10.
  latent_num_iter: 25000
  categorical_min_capacity: 0.0
  categorical_max_capacity: 20.0
  categorical_gamma: 10.
  categorical_num_iter: 25000
  temperature: 0.5
  anneal_rate: 0.00003
  anneal_interval: 100
  alpha: 10.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "JointVAE"

--------------------------------------------------------------------------------
/configs/logcosh_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'LogCoshVAE'
  in_channels: 3
  latent_dim: 128
  alpha: 10.0
  beta: 1.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.97
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "LogCoshVAE"

--------------------------------------------------------------------------------
/configs/lvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'LVAE'
  in_channels: 3
  latent_dims: [4, 8, 16, 32, 128]
  hidden_dims: [32, 64, 128, 256, 512]

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "LVAE"

--------------------------------------------------------------------------------
/configs/miwae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'MIWAE'
  in_channels: 3
  latent_dim: 128
  num_samples: 5
  num_estimates: 3

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "MIWAE"
--------------------------------------------------------------------------------
/configs/mssim_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'MSSIMVAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "MSSIMVAE"

--------------------------------------------------------------------------------
/configs/swae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'SWAE'
  in_channels: 3
  latent_dim: 128
  reg_weight: 100
  wasserstein_deg: 2.0
  num_projections: 200
  projection_dist: "normal"  # "cauchy"

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "SWAE"

--------------------------------------------------------------------------------
/configs/vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'VanillaVAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 100

logging_params:
  save_dir: "logs/"
  name: "VanillaVAE"

--------------------------------------------------------------------------------
/configs/vampvae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'VampVAE'
  in_channels: 3
  latent_dim: 128

exp_params:
  dataset: celeba
  data_path: "../../shared/Data/"
  img_size: 64
  batch_size: 144  # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  max_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "VampVAE"
  manual_seed: 1265

--------------------------------------------------------------------------------
/configs/vq_vae.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'VQVAE'
  in_channels: 3
  embedding_dim: 64
  num_embeddings: 512
  img_size: 64
  beta: 0.25

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.0
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'VQVAE'
--------------------------------------------------------------------------------
/configs/wae_mmd_imq.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'WAE_MMD'
  in_channels: 3
  latent_dim: 128
  reg_weight: 100
  kernel_type: 'imq'

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "WassersteinVAE_IMQ"

--------------------------------------------------------------------------------
/configs/wae_mmd_rbf.yaml:
--------------------------------------------------------------------------------
model_params:
  name: 'WAE_MMD'
  in_channels: 3
  latent_dim: 128
  reg_weight: 5000
  kernel_type: 'rbf'

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "WassersteinVAE_RBF"
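run.py itself is not included in this section. A hypothetical minimal driver,
consistent with the classes shown here (vae_models, VAEXperiment, VAEDataset)
but omitting the logging and checkpoint setup of the real script:

    import yaml
    from pytorch_lightning import Trainer, seed_everything

    from models import vae_models
    from experiment import VAEXperiment
    from dataset import VAEDataset

    with open("configs/vae.yaml") as f:
        config = yaml.safe_load(f)

    seed_everything(config["exp_params"]["manual_seed"], True)
    model = vae_models[config["model_params"]["name"]](**config["model_params"])
    experiment = VAEXperiment(model, config["exp_params"])
    data = VAEDataset(**config["data_params"], pin_memory=True)

    # trainer_params maps onto the Lightning Trainer of the version these
    # files target (gpus, max_epochs, gradient_clip_val, ...).
    runner = Trainer(**config["trainer_params"])
    runner.fit(experiment, datamodule=data)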
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import torch
from torch import Tensor
from pathlib import Path
from typing import List, Optional, Sequence, Union, Any, Callable
from torchvision.datasets.folder import default_loader
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CelebA
import zipfile


# Add your custom dataset class here
class MyDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass


class MyCelebA(CelebA):
    """
    A work-around to address issues with pytorch's celebA dataset class.

    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        return True


class OxfordPets(Dataset):
    """
    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/
    """
    def __init__(self,
                 data_path: str,
                 split: str,
                 transform: Callable,
                 **kwargs):
        self.data_dir = Path(data_path) / "OxfordPets"
        self.transforms = transform
        imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])

        self.imgs = imgs[:int(len(imgs) * 0.75)] if split == "train" else imgs[int(len(imgs) * 0.75):]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = default_loader(self.imgs[idx])

        if self.transforms is not None:
            img = self.transforms(img)

        return img, 0.0  # dummy data to prevent breaking


class VAEDataset(LightningDataModule):
    """
    PyTorch Lightning data module

    Args:
        data_dir: root directory of your dataset.
        train_batch_size: the batch size to use during training.
        val_batch_size: the batch size to use during validation.
        patch_size: the size of the crop to take from the original images.
        num_workers: the number of parallel workers to create to load data
            items (see PyTorch's Dataloader documentation for more details).
        pin_memory: whether prepared items should be loaded into pinned memory
            or not. This can improve performance on GPUs.
    """

    def __init__(
        self,
        data_path: str,
        train_batch_size: int = 8,
        val_batch_size: int = 8,
        patch_size: Union[int, Sequence[int]] = (256, 256),
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.data_dir = data_path
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None) -> None:
        # ========================= OxfordPets Dataset =========================

        # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
        #                                        transforms.CenterCrop(self.patch_size),
        #                                        # transforms.Resize(self.patch_size),
        #                                        transforms.ToTensor(),
        #                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        # val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
        #                                      transforms.CenterCrop(self.patch_size),
        #                                      # transforms.Resize(self.patch_size),
        #                                      transforms.ToTensor(),
        #                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        # self.train_dataset = OxfordPets(
        #     self.data_dir,
        #     split='train',
        #     transform=train_transforms,
        # )

        # self.val_dataset = OxfordPets(
        #     self.data_dir,
        #     split='val',
        #     transform=val_transforms,
        # )

        # ========================= CelebA Dataset =========================

        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                               transforms.CenterCrop(148),
                                               transforms.Resize(self.patch_size),
                                               transforms.ToTensor(),])

        val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                             transforms.CenterCrop(148),
                                             transforms.Resize(self.patch_size),
                                             transforms.ToTensor(),])

        self.train_dataset = MyCelebA(
            self.data_dir,
            split='train',
            transform=train_transforms,
            download=False,
        )

        # Replace CelebA with your dataset
        self.val_dataset = MyCelebA(
            self.data_dir,
            split='test',
            transform=val_transforms,
            download=False,
        )
        # ===============================================================

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=144,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )
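A quick smoke test of the datamodule (assumes the CelebA archive has already
been extracted under Data/celeba, as required by torchvision's CelebA class):

    data = VAEDataset(data_path="Data/", train_batch_size=64,
                      val_batch_size=64, patch_size=64, num_workers=4)
    data.setup()
    imgs, labels = next(iter(data.train_dataloader()))
    print(imgs.shape)  # torch.Size([64, 3, 64, 64])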
--------------------------------------------------------------------------------
/experiment.py:
--------------------------------------------------------------------------------
import os
import math
import torch
from torch import optim
from models import BaseVAE
from models.types_ import *
from utils import data_loader
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader


class VAEXperiment(pl.LightningModule):

    def __init__(self,
                 vae_model: BaseVAE,
                 params: dict) -> None:
        super(VAEXperiment, self).__init__()

        self.model = vae_model
        self.params = params
        self.curr_device = None
        self.hold_graph = False
        try:
            self.hold_graph = self.params['retain_first_backpass']
        except KeyError:
            pass

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return self.model(input, **kwargs)

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels=labels)
        train_loss = self.model.loss_function(*results,
                                              M_N=self.params['kld_weight'],  # real_img.shape[0] / self.num_train_imgs
                                              optimizer_idx=optimizer_idx,
                                              batch_idx=batch_idx)

        self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)

        return train_loss['loss']

    def validation_step(self, batch, batch_idx, optimizer_idx=0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels=labels)
        val_loss = self.model.loss_function(*results,
                                            M_N=1.0,  # real_img.shape[0] / self.num_val_imgs
                                            optimizer_idx=optimizer_idx,
                                            batch_idx=batch_idx)

        self.log_dict({f"val_{key}": val.item() for key, val in val_loss.items()}, sync_dist=True)

    def on_validation_end(self) -> None:
        self.sample_images()

    def sample_images(self):
        # Get sample reconstruction image
        test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
        test_input = test_input.to(self.curr_device)
        test_label = test_label.to(self.curr_device)

        # test_input, test_label = batch
        recons = self.model.generate(test_input, labels=test_label)
        vutils.save_image(recons.data,
                          os.path.join(self.logger.log_dir,
                                       "Reconstructions",
                                       f"recons_{self.logger.name}_Epoch_{self.current_epoch}.png"),
                          normalize=True,
                          nrow=12)

        try:
            samples = self.model.sample(144,
                                        self.curr_device,
                                        labels=test_label)
            vutils.save_image(samples.cpu().data,
                              os.path.join(self.logger.log_dir,
                                           "Samples",
                                           f"{self.logger.name}_Epoch_{self.current_epoch}.png"),
                              normalize=True,
                              nrow=12)
        except Warning:
            pass

    def configure_optimizers(self):

        optims = []
        scheds = []

        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.params['LR'],
                               weight_decay=self.params['weight_decay'])
        optims.append(optimizer)
        # Check if more than one optimizer is required (used for adversarial training)
        try:
            if self.params['LR_2'] is not None:
                optimizer2 = optim.Adam(getattr(self.model, self.params['submodel']).parameters(),
                                        lr=self.params['LR_2'])
                optims.append(optimizer2)
        except KeyError:
            pass

        try:
            if self.params['scheduler_gamma'] is not None:
                scheduler = optim.lr_scheduler.ExponentialLR(optims[0],
                                                             gamma=self.params['scheduler_gamma'])
                scheds.append(scheduler)

                # Check if another scheduler is required for the second optimizer
                try:
                    if self.params['scheduler_gamma_2'] is not None:
                        scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1],
                                                                      gamma=self.params['scheduler_gamma_2'])
                        scheds.append(scheduler2)
                except KeyError:
                    pass
                return optims, scheds
        except KeyError:
            return optims
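The loss contract that training_step relies on is the same for every model in
the repo: loss_function receives forward()'s outputs plus M_N, the weight that
rescales the KL term (kld_weight during training, 1.0 during validation).
Exercised directly with the BetaVAE defined below:

    import torch
    from models import BetaVAE

    model = BetaVAE(in_channels=3, latent_dim=128)
    imgs = torch.randn(4, 3, 64, 64)            # fake 64x64 RGB batch
    outputs = model(imgs)                       # [recons, input, mu, log_var]
    loss_dict = model.loss_function(*outputs, M_N=0.00025)
    print(loss_dict["loss"], loss_dict["Reconstruction_Loss"], loss_dict["KLD"])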
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .base import *
from .vanilla_vae import *
from .gamma_vae import *
from .beta_vae import *
from .wae_mmd import *
from .cvae import *
from .hvae import *
from .vampvae import *
from .iwae import *
from .dfcvae import *
from .mssim_vae import MSSIMVAE
from .fvae import *
from .cat_vae import *
from .joint_vae import *
from .info_vae import *
# from .twostage_vae import *
from .lvae import LVAE
from .logcosh_vae import *
from .swae import *
from .miwae import *
from .vq_vae import *
from .betatc_vae import *
from .dip_vae import *


# Aliases
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
GumbelVAE = CategoricalVAE

vae_models = {'HVAE': HVAE,
              'LVAE': LVAE,
              'IWAE': IWAE,
              'SWAE': SWAE,
              'MIWAE': MIWAE,
              'VQVAE': VQVAE,
              'DFCVAE': DFCVAE,
              'DIPVAE': DIPVAE,
              'BetaVAE': BetaVAE,
              'InfoVAE': InfoVAE,
              'WAE_MMD': WAE_MMD,
              'VampVAE': VampVAE,
              'GammaVAE': GammaVAE,
              'MSSIMVAE': MSSIMVAE,
              'JointVAE': JointVAE,
              'BetaTCVAE': BetaTCVAE,
              'FactorVAE': FactorVAE,
              'LogCoshVAE': LogCoshVAE,
              'VanillaVAE': VanillaVAE,
              'ConditionalVAE': ConditionalVAE,
              'CategoricalVAE': CategoricalVAE}

--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
from .types_ import *
from torch import nn
from abc import abstractmethod

class BaseVAE(nn.Module):

    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass
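BaseVAE only fixes the interface: a subclass must implement forward and
loss_function, while encode/decode/sample/generate are optional overrides. A
toy subclass (not part of the repo) showing the minimum required:

    import torch
    from torch import nn, Tensor
    from typing import List
    from models.base import BaseVAE

    class TinyVAE(BaseVAE):
        def __init__(self, dim: int = 8) -> None:
            super().__init__()
            self.fc_mu = nn.Linear(dim, dim)
            self.fc_var = nn.Linear(dim, dim)
            self.dec = nn.Linear(dim, dim)

        def forward(self, x: Tensor, **kwargs) -> List[Tensor]:
            mu, log_var = self.fc_mu(x), self.fc_var(x)
            z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()
            return [self.dec(z), x, mu, log_var]

        def loss_function(self, *args, **kwargs) -> dict:
            recons, x, mu, log_var = args
            rec = nn.functional.mse_loss(recons, x)
            kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))
            return {'loss': rec + kwargs['M_N'] * kld,
                    'Reconstruction_Loss': rec, 'KLD': kld}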
--------------------------------------------------------------------------------
/models/beta_vae.py:
--------------------------------------------------------------------------------
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class BetaVAE(BaseVAE):

    num_iter = 0  # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma: float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 1e5,
                 loss_type: str = 'B',
                 **kwargs) -> None:
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        (Will a single z be enough to compute the expectation for the loss?)
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log variance of the latent Gaussian
        :return: (Tensor)
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(input.device)
            C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
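A numeric sanity check of the capacity schedule used by loss_type 'B' above:
C ramps linearly from 0 to max_capacity over Capacity_max_iter steps and then
saturates, matching configs/bbvae.yaml (max_capacity: 25, Capacity_max_iter:
10000):

    import torch

    C_max, C_stop_iter = torch.Tensor([25.0]), 10000
    for it in (0, 5000, 10000, 20000):
        C = torch.clamp(C_max / C_stop_iter * it, 0, C_max.data[0])
        print(it, C.item())  # 0.0, 12.5, 25.0, 25.0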
--------------------------------------------------------------------------------
/models/betatc_vae.py:
--------------------------------------------------------------------------------
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
import math


class BetaTCVAE(BaseVAE):
    num_iter = 0  # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 anneal_steps: int = 200,
                 alpha: float = 1.,
                 beta: float = 6.,
                 gamma: float = 1.,
                 **kwargs) -> None:
        super(BetaTCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.anneal_steps = anneal_steps

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 32, 32, 32]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        self.fc = nn.Linear(hidden_dims[-1] * 16, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, 256 * 2)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)

        result = torch.flatten(result, start_dim=1)
        result = self.fc(result)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 32, 4, 4)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var, z]

    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
        """
        Computes the log pdf of the Gaussian with parameters mu and logvar at x
        :param x: (Tensor) Point at which the Gaussian PDF is to be evaluated
        :param mu: (Tensor) Mean of the Gaussian distribution
        :param logvar: (Tensor) Log variance of the Gaussian distribution
        :return:
        """
        norm = - 0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """

        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        weight = 1  # kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input, reduction='sum')

        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim=1)

        zeros = torch.zeros_like(z)
        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim=1)

        batch_size, latent_dim = z.shape
        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                                mu.view(1, batch_size, latent_dim),
                                                log_var.view(1, batch_size, latent_dim))

        # Reference
        # [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
        dataset_size = (1 / kwargs['M_N']) * batch_size  # dataset size
        strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
        importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size - 1)).to(input.device)
        importance_weights.view(-1)[::batch_size] = 1 / dataset_size
        importance_weights.view(-1)[1::batch_size] = strat_weight
        importance_weights[batch_size - 2, 0] = strat_weight
        log_importance_weights = importance_weights.log()

        mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)

        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

        mi_loss = (log_q_zx - log_q_z).mean()
        tc_loss = (log_q_z - log_prod_q_z).mean()
        kld_loss = (log_prod_q_z - log_p_z).mean()

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        if self.training:
            self.num_iter += 1
            anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)
        else:
            anneal_rate = 1.

        loss = recons_loss / batch_size + \
               self.alpha * mi_loss + \
               weight * (self.beta * tc_loss +
                         anneal_rate * self.gamma * kld_loss)

        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'KLD': kld_loss,
                'TC_Loss': tc_loss,
                'MI_Loss': mi_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
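log_density_gaussian above is just the elementwise Gaussian log-pdf; it can be
cross-checked against torch.distributions:

    import math
    import torch
    from torch.distributions import Normal

    x, mu, logvar = torch.randn(4), torch.zeros(4), torch.zeros(4)
    ours = -0.5 * (math.log(2 * math.pi) + logvar) \
           - 0.5 * (x - mu) ** 2 * torch.exp(-logvar)
    ref = Normal(mu, (0.5 * logvar).exp()).log_prob(x)
    print(torch.allclose(ours, ref))  # True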
219 | :param num_samples: (Int) Number of samples 220 | :param current_device: (Int) Device to run the model 221 | :return: (Tensor) 222 | """ 223 | z = torch.randn(num_samples, 224 | self.latent_dim) 225 | 226 | z = z.to(current_device) 227 | 228 | samples = self.decode(z) 229 | return samples 230 | 231 | def generate(self, x: Tensor, **kwargs) -> Tensor: 232 | """ 233 | Given an input image x, returns the reconstructed image 234 | :param x: (Tensor) [B x C x H x W] 235 | :return: (Tensor) [B x C x H x W] 236 | """ 237 | 238 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/cat_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models import BaseVAE 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from .types_ import * 7 | 8 | 9 | class CategoricalVAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | categorical_dim: int = 40, # Num classes 15 | hidden_dims: List = None, 16 | temperature: float = 0.5, 17 | anneal_rate: float = 3e-5, 18 | anneal_interval: int = 100, # every 100 batches 19 | alpha: float = 30., 20 | **kwargs) -> None: 21 | super(CategoricalVAE, self).__init__() 22 | 23 | self.latent_dim = latent_dim 24 | self.categorical_dim = categorical_dim 25 | self.temp = temperature 26 | self.min_temp = temperature 27 | self.anneal_rate = anneal_rate 28 | self.anneal_interval = anneal_interval 29 | self.alpha = alpha 30 | 31 | modules = [] 32 | if hidden_dims is None: 33 | hidden_dims = [32, 64, 128, 256, 512] 34 | 35 | # Build Encoder 36 | for h_dim in hidden_dims: 37 | modules.append( 38 | nn.Sequential( 39 | nn.Conv2d(in_channels, out_channels=h_dim, 40 | kernel_size= 3, stride= 2, padding = 1), 41 | nn.BatchNorm2d(h_dim), 42 | nn.LeakyReLU()) 43 | ) 44 | in_channels = h_dim 45 | 46 | self.encoder = nn.Sequential(*modules) 47 | self.fc_z = nn.Linear(hidden_dims[-1]*4, 48 | self.latent_dim * self.categorical_dim) 49 | 50 | # Build Decoder 51 | modules = [] 52 | 53 | self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim 54 | , hidden_dims[-1] * 4) 55 | 56 | hidden_dims.reverse() 57 | 58 | for i in range(len(hidden_dims) - 1): 59 | modules.append( 60 | nn.Sequential( 61 | nn.ConvTranspose2d(hidden_dims[i], 62 | hidden_dims[i + 1], 63 | kernel_size=3, 64 | stride = 2, 65 | padding=1, 66 | output_padding=1), 67 | nn.BatchNorm2d(hidden_dims[i + 1]), 68 | nn.LeakyReLU()) 69 | ) 70 | 71 | 72 | 73 | self.decoder = nn.Sequential(*modules) 74 | 75 | self.final_layer = nn.Sequential( 76 | nn.ConvTranspose2d(hidden_dims[-1], 77 | hidden_dims[-1], 78 | kernel_size=3, 79 | stride=2, 80 | padding=1, 81 | output_padding=1), 82 | nn.BatchNorm2d(hidden_dims[-1]), 83 | nn.LeakyReLU(), 84 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 85 | kernel_size= 3, padding= 1), 86 | nn.Tanh()) 87 | self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1))) 88 | 89 | def encode(self, input: Tensor) -> List[Tensor]: 90 | """ 91 | Encodes the input by passing through the encoder network 92 | and returns the latent codes. 
93 | :param input: (Tensor) Input tensor to encoder [B x C x H x W] 94 | :return: (Tensor) Latent code [B x D x Q] 95 | """ 96 | result = self.encoder(input) 97 | result = torch.flatten(result, start_dim=1) 98 | 99 | # Split the result into mu and var components 100 | # of the latent Gaussian distribution 101 | z = self.fc_z(result) 102 | z = z.view(-1, self.latent_dim, self.categorical_dim) 103 | return [z] 104 | 105 | def decode(self, z: Tensor) -> Tensor: 106 | """ 107 | Maps the given latent codes 108 | onto the image space. 109 | :param z: (Tensor) [B x D x Q] 110 | :return: (Tensor) [B x C x H x W] 111 | """ 112 | result = self.decoder_input(z) 113 | result = result.view(-1, 512, 2, 2) 114 | result = self.decoder(result) 115 | result = self.final_layer(result) 116 | return result 117 | 118 | def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor: 119 | """ 120 | Gumbel-softmax trick to sample from Categorical Distribution 121 | :param z: (Tensor) Latent Codes [B x D x Q] 122 | :return: (Tensor) [B x D] 123 | """ 124 | # Sample from Gumbel 125 | u = torch.rand_like(z) 126 | g = - torch.log(- torch.log(u + eps) + eps) 127 | 128 | # Gumbel-Softmax sample 129 | s = F.softmax((z + g) / self.temp, dim=-1) 130 | s = s.view(-1, self.latent_dim * self.categorical_dim) 131 | return s 132 | 133 | 134 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 135 | q = self.encode(input)[0] 136 | z = self.reparameterize(q) 137 | return [self.decode(z), input, q] 138 | 139 | def loss_function(self, 140 | *args, 141 | **kwargs) -> dict: 142 | """ 143 | Computes the VAE loss function. 144 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 145 | :param args: 146 | :param kwargs: 147 | :return: 148 | """ 149 | recons = args[0] 150 | input = args[1] 151 | q = args[2] 152 | 153 | q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities 154 | 155 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 156 | batch_idx = kwargs['batch_idx'] 157 | 158 | # Anneal the temperature at regular intervals 159 | if batch_idx % self.anneal_interval == 0 and self.training: 160 | self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx), 161 | self.min_temp) 162 | 163 | recons_loss =F.mse_loss(recons, input, reduction='mean') 164 | 165 | # KL divergence between gumbel-softmax distribution 166 | eps = 1e-7 167 | 168 | # Entropy of the logits 169 | h1 = q_p * torch.log(q_p + eps) 170 | 171 | # Cross entropy with the categorical distribution 172 | h2 = q_p * np.log(1. / self.categorical_dim + eps) 173 | kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0) 174 | 175 | # kld_weight = 1.2 176 | loss = self.alpha * recons_loss + kld_weight * kld_loss 177 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 178 | 179 | def sample(self, 180 | num_samples:int, 181 | current_device: int, **kwargs) -> Tensor: 182 | """ 183 | Samples from the latent space and return the corresponding 184 | image space map. 
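# --- Standalone sketch of the Gumbel-softmax relaxation used by
# reparameterize() above (toy sizes; eps guards the double log against log(0)).
# PyTorch also ships an equivalent helper, torch.nn.functional.gumbel_softmax.
import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temp, eps=1e-7):
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + eps) + eps)   # g ~ Gumbel(0, 1)
    return F.softmax((logits + g) / temp, dim=-1)

logits = torch.randn(2, 4, 10)                   # [B x D x Q]
soft = gumbel_softmax_sample(logits, temp=1.0)
hard = gumbel_softmax_sample(logits, temp=0.1)   # low temp -> near one-hot
print(soft.max(-1).values.mean(), hard.max(-1).values.mean())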
185 | :param num_samples: (Int) Number of samples 186 | :param current_device: (Int) Device to run the model 187 | :return: (Tensor) 188 | """ 189 | # [S x D x Q] 190 | 191 | M = num_samples * self.latent_dim 192 | np_y = np.zeros((M, self.categorical_dim), dtype=np.float32) 193 | np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1 194 | np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim]) 195 | z = torch.from_numpy(np_y) 196 | 197 | # z = self.sampling_dist.sample((num_samples * self.latent_dim, )) 198 | z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device) 199 | samples = self.decode(z) 200 | return samples 201 | 202 | def generate(self, x: Tensor, **kwargs) -> Tensor: 203 | """ 204 | Given an input image x, returns the reconstructed image 205 | :param x: (Tensor) [B x C x H x W] 206 | :return: (Tensor) [B x C x H x W] 207 | """ 208 | 209 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/cvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class ConditionalVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | num_classes: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | img_size:int = 64, 16 | **kwargs) -> None: 17 | super(ConditionalVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.img_size = img_size 21 | 22 | self.embed_class = nn.Linear(num_classes, img_size * img_size) 23 | self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1) 24 | 25 | modules = [] 26 | if hidden_dims is None: 27 | hidden_dims = [32, 64, 128, 256, 512] 28 | 29 | in_channels += 1 # To account for the extra label channel 30 | # Build Encoder 31 | for h_dim in hidden_dims: 32 | modules.append( 33 | nn.Sequential( 34 | nn.Conv2d(in_channels, out_channels=h_dim, 35 | kernel_size= 3, stride= 2, padding = 1), 36 | nn.BatchNorm2d(h_dim), 37 | nn.LeakyReLU()) 38 | ) 39 | in_channels = h_dim 40 | 41 | self.encoder = nn.Sequential(*modules) 42 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 43 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 44 | 45 | 46 | # Build Decoder 47 | modules = [] 48 | 49 | self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4) 50 | 51 | hidden_dims.reverse() 52 | 53 | for i in range(len(hidden_dims) - 1): 54 | modules.append( 55 | nn.Sequential( 56 | nn.ConvTranspose2d(hidden_dims[i], 57 | hidden_dims[i + 1], 58 | kernel_size=3, 59 | stride = 2, 60 | padding=1, 61 | output_padding=1), 62 | nn.BatchNorm2d(hidden_dims[i + 1]), 63 | nn.LeakyReLU()) 64 | ) 65 | 66 | 67 | 68 | self.decoder = nn.Sequential(*modules) 69 | 70 | self.final_layer = nn.Sequential( 71 | nn.ConvTranspose2d(hidden_dims[-1], 72 | hidden_dims[-1], 73 | kernel_size=3, 74 | stride=2, 75 | padding=1, 76 | output_padding=1), 77 | nn.BatchNorm2d(hidden_dims[-1]), 78 | nn.LeakyReLU(), 79 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 80 | kernel_size= 3, padding= 1), 81 | nn.Tanh()) 82 | 83 | def encode(self, input: Tensor) -> List[Tensor]: 84 | """ 85 | Encodes the input by passing through the encoder network 86 | and returns the latent codes. 
87 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 88 | :return: (Tensor) List of latent codes 89 | """ 90 | result = self.encoder(input) 91 | result = torch.flatten(result, start_dim=1) 92 | 93 | # Split the result into mu and var components 94 | # of the latent Gaussian distribution 95 | mu = self.fc_mu(result) 96 | log_var = self.fc_var(result) 97 | 98 | return [mu, log_var] 99 | 100 | def decode(self, z: Tensor) -> Tensor: 101 | result = self.decoder_input(z) 102 | result = result.view(-1, 512, 2, 2) 103 | result = self.decoder(result) 104 | result = self.final_layer(result) 105 | return result 106 | 107 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 108 | """ 109 | Will a single z be enough ti compute the expectation 110 | for the loss?? 111 | :param mu: (Tensor) Mean of the latent Gaussian 112 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 113 | :return: 114 | """ 115 | std = torch.exp(0.5 * logvar) 116 | eps = torch.randn_like(std) 117 | return eps * std + mu 118 | 119 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 120 | y = kwargs['labels'].float() 121 | embedded_class = self.embed_class(y) 122 | embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1) 123 | embedded_input = self.embed_data(input) 124 | 125 | x = torch.cat([embedded_input, embedded_class], dim = 1) 126 | mu, log_var = self.encode(x) 127 | 128 | z = self.reparameterize(mu, log_var) 129 | 130 | z = torch.cat([z, y], dim = 1) 131 | return [self.decode(z), input, mu, log_var] 132 | 133 | def loss_function(self, 134 | *args, 135 | **kwargs) -> dict: 136 | recons = args[0] 137 | input = args[1] 138 | mu = args[2] 139 | log_var = args[3] 140 | 141 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 142 | recons_loss =F.mse_loss(recons, input) 143 | 144 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 145 | 146 | loss = recons_loss + kld_weight * kld_loss 147 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 148 | 149 | def sample(self, 150 | num_samples:int, 151 | current_device: int, 152 | **kwargs) -> Tensor: 153 | """ 154 | Samples from the latent space and return the corresponding 155 | image space map. 
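# --- Sketch of the conditioning scheme ConditionalVAE.forward() above relies
# on: the one-hot label is embedded into an extra image plane for the encoder
# and concatenated to z for the decoder. Sizes below are toy assumptions.
import torch
import torch.nn.functional as F

B, C, H, num_classes, latent_dim = 4, 3, 64, 10, 128
x = torch.randn(B, C, H, H)
y = F.one_hot(torch.randint(0, num_classes, (B,)), num_classes).float()

embed_class = torch.nn.Linear(num_classes, H * H)
label_plane = embed_class(y).view(B, 1, H, H)      # label as an extra channel
encoder_in = torch.cat([x, label_plane], dim=1)    # [B x (C+1) x H x W]

z = torch.randn(B, latent_dim)
decoder_in = torch.cat([z, y], dim=1)              # [B x (D + num_classes)]
print(encoder_in.shape, decoder_in.shape)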
156 | :param num_samples: (Int) Number of samples 157 | :param current_device: (Int) Device to run the model 158 | :return: (Tensor) 159 | """ 160 | y = kwargs['labels'].float() 161 | z = torch.randn(num_samples, 162 | self.latent_dim) 163 | 164 | z = z.to(current_device) 165 | 166 | z = torch.cat([z, y], dim=1) 167 | samples = self.decode(z) 168 | return samples 169 | 170 | def generate(self, x: Tensor, **kwargs) -> Tensor: 171 | """ 172 | Given an input image x, returns the reconstructed image 173 | :param x: (Tensor) [B x C x H x W] 174 | :return: (Tensor) [B x C x H x W] 175 | """ 176 | 177 | return self.forward(x, **kwargs)[0] -------------------------------------------------------------------------------- /models/dfcvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torchvision.models import vgg19_bn 5 | from torch.nn import functional as F 6 | from .types_ import * 7 | 8 | 9 | class DFCVAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | alpha:float = 1, 16 | beta:float = 0.5, 17 | **kwargs) -> None: 18 | super(DFCVAE, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.alpha = alpha 22 | self.beta = beta 23 | 24 | modules = [] 25 | if hidden_dims is None: 26 | hidden_dims = [32, 64, 128, 256, 512] 27 | 28 | # Build Encoder 29 | for h_dim in hidden_dims: 30 | modules.append( 31 | nn.Sequential( 32 | nn.Conv2d(in_channels, out_channels=h_dim, 33 | kernel_size= 3, stride= 2, padding = 1), 34 | nn.BatchNorm2d(h_dim), 35 | nn.LeakyReLU()) 36 | ) 37 | in_channels = h_dim 38 | 39 | self.encoder = nn.Sequential(*modules) 40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 42 | 43 | 44 | # Build Decoder 45 | modules = [] 46 | 47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 48 | 49 | hidden_dims.reverse() 50 | 51 | for i in range(len(hidden_dims) - 1): 52 | modules.append( 53 | nn.Sequential( 54 | nn.ConvTranspose2d(hidden_dims[i], 55 | hidden_dims[i + 1], 56 | kernel_size=3, 57 | stride = 2, 58 | padding=1, 59 | output_padding=1), 60 | nn.BatchNorm2d(hidden_dims[i + 1]), 61 | nn.LeakyReLU()) 62 | ) 63 | 64 | 65 | 66 | self.decoder = nn.Sequential(*modules) 67 | 68 | self.final_layer = nn.Sequential( 69 | nn.ConvTranspose2d(hidden_dims[-1], 70 | hidden_dims[-1], 71 | kernel_size=3, 72 | stride=2, 73 | padding=1, 74 | output_padding=1), 75 | nn.BatchNorm2d(hidden_dims[-1]), 76 | nn.LeakyReLU(), 77 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 78 | kernel_size= 3, padding= 1), 79 | nn.Tanh()) 80 | 81 | self.feature_network = vgg19_bn(pretrained=True) 82 | 83 | # Freeze the pretrained feature network 84 | for param in self.feature_network.parameters(): 85 | param.requires_grad = False 86 | 87 | self.feature_network.eval() 88 | 89 | 90 | def encode(self, input: Tensor) -> List[Tensor]: 91 | """ 92 | Encodes the input by passing through the encoder network 93 | and returns the latent codes. 
94 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 95 | :return: (Tensor) List of latent codes 96 | """ 97 | result = self.encoder(input) 98 | result = torch.flatten(result, start_dim=1) 99 | 100 | # Split the result into mu and var components 101 | # of the latent Gaussian distribution 102 | mu = self.fc_mu(result) 103 | log_var = self.fc_var(result) 104 | 105 | return [mu, log_var] 106 | 107 | def decode(self, z: Tensor) -> Tensor: 108 | """ 109 | Maps the given latent codes 110 | onto the image space. 111 | :param z: (Tensor) [B x D] 112 | :return: (Tensor) [B x C x H x W] 113 | """ 114 | result = self.decoder_input(z) 115 | result = result.view(-1, 512, 2, 2) 116 | result = self.decoder(result) 117 | result = self.final_layer(result) 118 | return result 119 | 120 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 121 | """ 122 | Reparameterization trick to sample from N(mu, var) from 123 | N(0,1). 124 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 125 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 126 | :return: (Tensor) [B x D] 127 | """ 128 | std = torch.exp(0.5 * logvar) 129 | eps = torch.randn_like(std) 130 | return eps * std + mu 131 | 132 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 133 | mu, log_var = self.encode(input) 134 | z = self.reparameterize(mu, log_var) 135 | recons = self.decode(z) 136 | 137 | recons_features = self.extract_features(recons) 138 | input_features = self.extract_features(input) 139 | 140 | return [recons, input, recons_features, input_features, mu, log_var] 141 | 142 | def extract_features(self, 143 | input: Tensor, 144 | feature_layers: List = None) -> List[Tensor]: 145 | """ 146 | Extracts the features from the pretrained model 147 | at the layers indicated by feature_layers. 148 | :param input: (Tensor) [B x C x H x W] 149 | :param feature_layers: List of string of IDs 150 | :return: List of the extracted features 151 | """ 152 | if feature_layers is None: 153 | feature_layers = ['14', '24', '34', '43'] 154 | features = [] 155 | result = input 156 | for (key, module) in self.feature_network.features._modules.items(): 157 | result = module(result) 158 | if(key in feature_layers): 159 | features.append(result) 160 | 161 | return features 162 | 163 | def loss_function(self, 164 | *args, 165 | **kwargs) -> dict: 166 | """ 167 | Computes the VAE loss function. 168 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 169 | :param args: 170 | :param kwargs: 171 | :return: 172 | """ 173 | recons = args[0] 174 | input = args[1] 175 | recons_features = args[2] 176 | input_features = args[3] 177 | mu = args[4] 178 | log_var = args[5] 179 | 180 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 181 | recons_loss =F.mse_loss(recons, input) 182 | 183 | feature_loss = 0.0 184 | for (r, i) in zip(recons_features, input_features): 185 | feature_loss += F.mse_loss(r, i) 186 | 187 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 188 | 189 | loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss 190 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 191 | 192 | def sample(self, 193 | num_samples:int, 194 | current_device: int, **kwargs) -> Tensor: 195 | """ 196 | Samples from the latent space and return the corresponding 197 | image space map. 
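# --- Minimal sketch of the deep-feature (perceptual) loss computed above,
# assuming the torchvision vgg19_bn weights are available for download. The
# layer ids mirror the defaults ['14', '24', '34', '43']; ImageNet input
# normalisation is skipped here for brevity.
import torch
import torch.nn.functional as F
from torchvision.models import vgg19_bn

net = vgg19_bn(pretrained=True).features.eval()
for p in net.parameters():
    p.requires_grad = False

def features(x, layer_ids=('14', '24', '34', '43')):
    out = []
    for key, module in net._modules.items():
        x = module(x)
        if key in layer_ids:
            out.append(x)
    return out

x_hat, x = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
feature_loss = sum(F.mse_loss(a, b) for a, b in zip(features(x_hat), features(x)))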
198 | :param num_samples: (Int) Number of samples 199 | :param current_device: (Int) Device to run the model 200 | :return: (Tensor) 201 | """ 202 | z = torch.randn(num_samples, 203 | self.latent_dim) 204 | 205 | z = z.to(current_device) 206 | 207 | samples = self.decode(z) 208 | return samples 209 | 210 | def generate(self, x: Tensor, **kwargs) -> Tensor: 211 | """ 212 | Given an input image x, returns the reconstructed image 213 | :param x: (Tensor) [B x C x H x W] 214 | :return: (Tensor) [B x C x H x W] 215 | """ 216 | 217 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/dip_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class DIPVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | lambda_diag: float = 10., 15 | lambda_offdiag: float = 5., 16 | **kwargs) -> None: 17 | super(DIPVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.lambda_diag = lambda_diag 21 | self.lambda_offdiag = lambda_offdiag 22 | 23 | modules = [] 24 | if hidden_dims is None: 25 | hidden_dims = [32, 64, 128, 256, 512] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | 46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 47 | 48 | hidden_dims.reverse() 49 | 50 | for i in range(len(hidden_dims) - 1): 51 | modules.append( 52 | nn.Sequential( 53 | nn.ConvTranspose2d(hidden_dims[i], 54 | hidden_dims[i + 1], 55 | kernel_size=3, 56 | stride = 2, 57 | padding=1, 58 | output_padding=1), 59 | nn.BatchNorm2d(hidden_dims[i + 1]), 60 | nn.LeakyReLU()) 61 | ) 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | def encode(self, input: Tensor) -> List[Tensor]: 79 | """ 80 | Encodes the input by passing through the encoder network 81 | and returns the latent codes. 82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 83 | :return: (Tensor) List of latent codes 84 | """ 85 | result = self.encoder(input) 86 | result = torch.flatten(result, start_dim=1) 87 | 88 | # Split the result into mu and var components 89 | # of the latent Gaussian distribution 90 | mu = self.fc_mu(result) 91 | log_var = self.fc_var(result) 92 | 93 | return [mu, log_var] 94 | 95 | def decode(self, z: Tensor) -> Tensor: 96 | """ 97 | Maps the given latent codes 98 | onto the image space. 
99 | :param z: (Tensor) [B x D] 100 | :return: (Tensor) [B x C x H x W] 101 | """ 102 | result = self.decoder_input(z) 103 | result = result.view(-1, 512, 2, 2) 104 | result = self.decoder(result) 105 | result = self.final_layer(result) 106 | return result 107 | 108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 109 | """ 110 | Reparameterization trick to sample from N(mu, var) from 111 | N(0,1). 112 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 113 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 114 | :return: (Tensor) [B x D] 115 | """ 116 | std = torch.exp(0.5 * logvar) 117 | eps = torch.randn_like(std) 118 | return eps * std + mu 119 | 120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 121 | mu, log_var = self.encode(input) 122 | z = self.reparameterize(mu, log_var) 123 | return [self.decode(z), input, mu, log_var] 124 | 125 | def loss_function(self, 126 | *args, 127 | **kwargs) -> dict: 128 | """ 129 | Computes the VAE loss function. 130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 131 | :param args: 132 | :param kwargs: 133 | :return: 134 | """ 135 | recons = args[0] 136 | input = args[1] 137 | mu = args[2] 138 | log_var = args[3] 139 | 140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 141 | recons_loss =F.mse_loss(recons, input, reduction='sum') 142 | 143 | 144 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 145 | 146 | # DIP Loss 147 | centered_mu = mu - mu.mean(dim=1, keepdim = True) # [B x D] 148 | cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D X D] 149 | 150 | # Add Variance for DIP Loss II 151 | cov_z = cov_mu + torch.mean(torch.diagonal((2. * log_var).exp(), dim1 = 0), dim = 0) # [D x D] 152 | # For DIp Loss I 153 | # cov_z = cov_mu 154 | 155 | cov_diag = torch.diag(cov_z) # [D] 156 | cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D] 157 | dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \ 158 | self.lambda_diag * torch.sum((cov_diag - 1) ** 2) 159 | 160 | loss = recons_loss + kld_weight * kld_loss + dip_loss 161 | return {'loss': loss, 162 | 'Reconstruction_Loss':recons_loss, 163 | 'KLD':-kld_loss, 164 | 'DIP_Loss':dip_loss} 165 | 166 | def sample(self, 167 | num_samples:int, 168 | current_device: int, **kwargs) -> Tensor: 169 | """ 170 | Samples from the latent space and return the corresponding 171 | image space map. 
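# --- Sketch of the DIP regulariser computed above: push the covariance of the
# inferred means toward the identity, penalising off-diagonal and diagonal
# deviations separately. Toy tensors; this sketch uses the DIP-VAE-I variant
# and centres over the batch (dim 0), as in the DIP-VAE paper.
import torch

B, D = 32, 10
mu = torch.randn(B, D)
lambda_diag, lambda_offdiag = 10.0, 5.0

centered = mu - mu.mean(dim=0, keepdim=True)
cov = centered.t() @ centered / (B - 1)            # [D x D] sample covariance

diag = torch.diagonal(cov)
offdiag = cov - torch.diag(diag)
dip_loss = (lambda_offdiag * (offdiag ** 2).sum() +
            lambda_diag * ((diag - 1) ** 2).sum())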
172 | :param num_samples: (Int) Number of samples 173 | :param current_device: (Int) Device to run the model 174 | :return: (Tensor) 175 | """ 176 | z = torch.randn(num_samples, 177 | self.latent_dim) 178 | 179 | z = z.to(current_device) 180 | 181 | samples = self.decode(z) 182 | return samples 183 | 184 | def generate(self, x: Tensor, **kwargs) -> Tensor: 185 | """ 186 | Given an input image x, returns the reconstructed image 187 | :param x: (Tensor) [B x C x H x W] 188 | :return: (Tensor) [B x C x H x W] 189 | """ 190 | 191 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/fvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class FactorVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | gamma: float = 40., 15 | **kwargs) -> None: 16 | super(FactorVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | self.gamma = gamma 20 | 21 | modules = [] 22 | if hidden_dims is None: 23 | hidden_dims = [32, 64, 128, 256, 512] 24 | 25 | # Build Encoder 26 | for h_dim in hidden_dims: 27 | modules.append( 28 | nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels=h_dim, 30 | kernel_size= 3, stride= 2, padding = 1), 31 | nn.BatchNorm2d(h_dim), 32 | nn.LeakyReLU()) 33 | ) 34 | in_channels = h_dim 35 | 36 | self.encoder = nn.Sequential(*modules) 37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | 40 | 41 | # Build Decoder 42 | modules = [] 43 | 44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 45 | 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | 61 | 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | # Discriminator network for the Total Correlation (TC) loss 79 | self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000), 80 | nn.BatchNorm1d(1000), 81 | nn.LeakyReLU(0.2), 82 | nn.Linear(1000, 1000), 83 | nn.BatchNorm1d(1000), 84 | nn.LeakyReLU(0.2), 85 | nn.Linear(1000, 1000), 86 | nn.BatchNorm1d(1000), 87 | nn.LeakyReLU(0.2), 88 | nn.Linear(1000, 2)) 89 | self.D_z_reserve = None 90 | 91 | 92 | def encode(self, input: Tensor) -> List[Tensor]: 93 | """ 94 | Encodes the input by passing through the encoder network 95 | and returns the latent codes. 
96 |         :param input: (Tensor) Input tensor to encoder [N x C x H x W]
97 |         :return: (Tensor) List of latent codes
98 |         """
99 |         result = self.encoder(input)
100 |         result = torch.flatten(result, start_dim=1)
101 | 
102 |         # Split the result into mu and var components
103 |         # of the latent Gaussian distribution
104 |         mu = self.fc_mu(result)
105 |         log_var = self.fc_var(result)
106 | 
107 |         return [mu, log_var]
108 | 
109 |     def decode(self, z: Tensor) -> Tensor:
110 |         """
111 |         Maps the given latent codes
112 |         onto the image space.
113 |         :param z: (Tensor) [B x D]
114 |         :return: (Tensor) [B x C x H x W]
115 |         """
116 |         result = self.decoder_input(z)
117 |         result = result.view(-1, 512, 2, 2)
118 |         result = self.decoder(result)
119 |         result = self.final_layer(result)
120 |         return result
121 | 
122 |     def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
123 |         """
124 |         Reparameterization trick to sample from N(mu, var) using an
125 |         auxiliary N(0, 1) sample.
126 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
127 |         :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
128 |         :return: (Tensor) [B x D]
129 |         """
130 |         std = torch.exp(0.5 * logvar)
131 |         eps = torch.randn_like(std)
132 |         return eps * std + mu
133 | 
134 |     def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
135 |         mu, log_var = self.encode(input)
136 |         z = self.reparameterize(mu, log_var)
137 |         return [self.decode(z), input, mu, log_var, z]
138 | 
139 |     def permute_latent(self, z: Tensor) -> Tensor:
140 |         """
141 |         Permutes each of the latent codes in the batch
142 |         :param z: [B x D]
143 |         :return: [B x D]
144 |         """
145 |         B, D = z.size()
146 | 
147 |         # Build shuffled indices for each latent code in the batch
148 |         inds = torch.cat([(D * i) + torch.randperm(D) for i in range(B)])
149 |         return z.view(-1)[inds].view(B, D)
150 | 
151 |     def loss_function(self,
152 |                       *args,
153 |                       **kwargs) -> dict:
154 |         """
155 |         Computes the VAE loss function.
156 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 157 | :param args: 158 | :param kwargs: 159 | :return: 160 | """ 161 | recons = args[0] 162 | input = args[1] 163 | mu = args[2] 164 | log_var = args[3] 165 | z = args[4] 166 | 167 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 168 | optimizer_idx = kwargs['optimizer_idx'] 169 | 170 | # Update the VAE 171 | if optimizer_idx == 0: 172 | recons_loss =F.mse_loss(recons, input) 173 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 174 | 175 | self.D_z_reserve = self.discriminator(z) 176 | vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean() 177 | 178 | loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss 179 | 180 | # print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}') 181 | return {'loss': loss, 182 | 'Reconstruction_Loss':recons_loss, 183 | 'KLD':-kld_loss, 184 | 'VAE_TC_Loss': vae_tc_loss} 185 | 186 | # Update the Discriminator 187 | elif optimizer_idx == 1: 188 | device = input.device 189 | true_labels = torch.ones(input.size(0), dtype= torch.long, 190 | requires_grad=False).to(device) 191 | false_labels = torch.zeros(input.size(0), dtype= torch.long, 192 | requires_grad=False).to(device) 193 | 194 | z = z.detach() # Detach so that VAE is not trained again 195 | z_perm = self.permute_latent(z) 196 | D_z_perm = self.discriminator(z_perm) 197 | D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) + 198 | F.cross_entropy(D_z_perm, true_labels)) 199 | # print(f'D_TC: {D_tc_loss}') 200 | return {'loss': D_tc_loss, 201 | 'D_TC_Loss':D_tc_loss} 202 | 203 | def sample(self, 204 | num_samples:int, 205 | current_device: int, **kwargs) -> Tensor: 206 | """ 207 | Samples from the latent space and return the corresponding 208 | image space map. 
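# --- Sketch of the density-ratio trick behind the two optimiser branches
# above: a classifier separates joint samples q(z) from factorised samples,
# and its logit difference estimates the total correlation. Note this sketch
# shuffles each dimension independently across the batch, as in the FactorVAE
# paper; permute_latent above instead shuffles dimensions within each code.
import torch
import torch.nn.functional as F

def permute_dims(z):
    B, D = z.size()
    return torch.stack([z[torch.randperm(B), d] for d in range(D)], dim=1)

B, D = 64, 10
z = torch.randn(B, D)
disc = torch.nn.Sequential(torch.nn.Linear(D, 64), torch.nn.LeakyReLU(0.2),
                           torch.nn.Linear(64, 2))

d_z = disc(z)
vae_tc_loss = (d_z[:, 0] - d_z[:, 1]).mean()     # drives the VAE update

d_z_perm = disc(permute_dims(z))
ones = torch.ones(B, dtype=torch.long)
zeros = torch.zeros(B, dtype=torch.long)
# Discriminator update (z would be .detach()-ed during real training):
d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))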
209 | :param num_samples: (Int) Number of samples 210 | :param current_device: (Int) Device to run the model 211 | :return: (Tensor) 212 | """ 213 | z = torch.randn(num_samples, 214 | self.latent_dim) 215 | 216 | z = z.to(current_device) 217 | 218 | samples = self.decode(z) 219 | return samples 220 | 221 | def generate(self, x: Tensor, **kwargs) -> Tensor: 222 | """ 223 | Given an input image x, returns the reconstructed image 224 | :param x: (Tensor) [B x C x H x W] 225 | :return: (Tensor) [B x C x H x W] 226 | """ 227 | 228 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/gamma_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.distributions import Gamma 5 | from torch.nn import functional as F 6 | from .types_ import * 7 | import torch.nn.init as init 8 | 9 | 10 | class GammaVAE(BaseVAE): 11 | 12 | def __init__(self, 13 | in_channels: int, 14 | latent_dim: int, 15 | hidden_dims: List = None, 16 | gamma_shape: float = 8., 17 | prior_shape: float = 2.0, 18 | prior_rate: float = 1., 19 | **kwargs) -> None: 20 | super(GammaVAE, self).__init__() 21 | self.latent_dim = latent_dim 22 | self.B = gamma_shape 23 | 24 | self.prior_alpha = torch.tensor([prior_shape]) 25 | self.prior_beta = torch.tensor([prior_rate]) 26 | 27 | modules = [] 28 | if hidden_dims is None: 29 | hidden_dims = [32, 64, 128, 256, 512] 30 | 31 | # Build Encoder 32 | for h_dim in hidden_dims: 33 | modules.append( 34 | nn.Sequential( 35 | nn.Conv2d(in_channels, out_channels=h_dim, 36 | kernel_size=3, stride=2, padding=1), 37 | nn.BatchNorm2d(h_dim), 38 | nn.LeakyReLU()) 39 | ) 40 | in_channels = h_dim 41 | 42 | self.encoder = nn.Sequential(*modules) 43 | self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim), 44 | nn.Softmax()) 45 | self.fc_var = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim), 46 | nn.Softmax()) 47 | 48 | # Build Decoder 49 | modules = [] 50 | 51 | self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims[-1] * 4)) 52 | 53 | hidden_dims.reverse() 54 | 55 | for i in range(len(hidden_dims) - 1): 56 | modules.append( 57 | nn.Sequential( 58 | nn.ConvTranspose2d(hidden_dims[i], 59 | hidden_dims[i + 1], 60 | kernel_size=3, 61 | stride=2, 62 | padding=1, 63 | output_padding=1), 64 | nn.BatchNorm2d(hidden_dims[i + 1]), 65 | nn.LeakyReLU()) 66 | ) 67 | 68 | self.decoder = nn.Sequential(*modules) 69 | 70 | self.final_layer = nn.Sequential( 71 | nn.ConvTranspose2d(hidden_dims[-1], 72 | hidden_dims[-1], 73 | kernel_size=3, 74 | stride=2, 75 | padding=1, 76 | output_padding=1), 77 | nn.BatchNorm2d(hidden_dims[-1]), 78 | nn.LeakyReLU(), 79 | nn.Conv2d(hidden_dims[-1], out_channels=3, 80 | kernel_size=3, padding=1), 81 | nn.Sigmoid()) 82 | 83 | self.weight_init() 84 | 85 | def weight_init(self): 86 | 87 | # print(self._modules) 88 | for block in self._modules: 89 | for m in self._modules[block]: 90 | init_(m) 91 | 92 | def encode(self, input: Tensor) -> List[Tensor]: 93 | """ 94 | Encodes the input by passing through the encoder network 95 | and returns the latent codes. 
96 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 97 | :return: (Tensor) List of latent codes 98 | """ 99 | result = self.encoder(input) 100 | result = torch.flatten(result, start_dim=1) 101 | 102 | # Split the result into mu and var components 103 | # of the latent Gaussian distribution 104 | alpha = self.fc_mu(result) 105 | beta = self.fc_var(result) 106 | 107 | return [alpha, beta] 108 | 109 | def decode(self, z: Tensor) -> Tensor: 110 | result = self.decoder_input(z) 111 | result = result.view(-1, 512, 2, 2) 112 | result = self.decoder(result) 113 | result = self.final_layer(result) 114 | return result 115 | 116 | def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor: 117 | """ 118 | Reparameterize the Gamma distribution by the shape augmentation trick. 119 | Reference: 120 | [1] https://arxiv.org/pdf/1610.05683.pdf 121 | 122 | :param alpha: (Tensor) Shape parameter of the latent Gamma 123 | :param beta: (Tensor) Rate parameter of the latent Gamma 124 | :return: 125 | """ 126 | # Sample from Gamma to guarantee acceptance 127 | alpha_ = alpha.clone().detach() 128 | z_hat = Gamma(alpha_ + self.B, torch.ones_like(alpha_)).sample() 129 | 130 | # Compute the eps ~ N(0,1) that produces z_hat 131 | eps = self.inv_h_func(alpha + self.B , z_hat) 132 | z = self.h_func(alpha + self.B, eps) 133 | 134 | # When beta != 1, scale by beta 135 | return z / beta 136 | 137 | def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor: 138 | """ 139 | Reparameterize a sample eps ~ N(0, 1) so that h(z) ~ Gamma(alpha, 1) 140 | :param alpha: (Tensor) Shape parameter 141 | :param eps: (Tensor) Random sample to reparameterize 142 | :return: (Tensor) 143 | """ 144 | 145 | z = (alpha - 1./3.) * (1 + eps / torch.sqrt(9. * alpha - 3.))**3 146 | return z 147 | 148 | def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor: 149 | """ 150 | Inverse reparameterize the given z into eps. 151 | :param alpha: (Tensor) 152 | :param z: (Tensor) 153 | :return: (Tensor) 154 | """ 155 | eps = torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1./3.))**(1. / 3.) - 1.) 156 | return eps 157 | 158 | def forward(self, input: Tensor, **kwargs) -> Tensor: 159 | alpha, beta = self.encode(input) 160 | z = self.reparameterize(alpha, beta) 161 | return [self.decode(z), input, alpha, beta] 162 | 163 | # def I_function(self, alpha_p, beta_p, alpha_q, beta_q): 164 | # return - (alpha_q * beta_q) / alpha_p - \ 165 | # beta_p * torch.log(alpha_p) - torch.lgamma(beta_p) + \ 166 | # (beta_p - 1) * torch.digamma(beta_q) + \ 167 | # (beta_p - 1) * torch.log(alpha_q) 168 | def I_function(self, a, b, c, d): 169 | return - c * d / a - b * torch.log(a) - torch.lgamma(b) + (b - 1) * (torch.digamma(d) + torch.log(c)) 170 | 171 | def vae_gamma_kl_loss(self, a, b, c, d): 172 | """ 173 | https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions 174 | b and d are Gamma shape parameters and 175 | a and c are scale parameters. 176 | (All, therefore, must be positive.) 
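# --- Standalone sketch of the shape-augmentation reparameterisation used by
# reparameterize()/h_func()/inv_h_func() above (Naesseth et al.,
# https://arxiv.org/abs/1610.05683). One subtlety, assumed here: the recovered
# eps must be treated as a constant (computed from a detached alpha) so that
# gradients reach alpha only through h.
import torch
from torch.distributions import Gamma

def h(alpha, eps):
    return (alpha - 1. / 3.) * (1 + eps / torch.sqrt(9. * alpha - 3.)) ** 3

def h_inv(alpha, z):
    return torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1. / 3.)) ** (1. / 3.) - 1.)

B_shift = 8.0
alpha = torch.full((4,), 2.0, requires_grad=True)
z_hat = Gamma(alpha.detach() + B_shift, torch.ones(4)).sample()

eps = h_inv(alpha.detach() + B_shift, z_hat)   # constant noise source
z = h(alpha + B_shift, eps)                    # same value, differentiable in alpha
assert torch.allclose(z, z_hat, atol=1e-4)
z.sum().backward()
print(alpha.grad)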
177 | """ 178 | 179 | a = 1 / a 180 | c = 1 / c 181 | losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d) 182 | return torch.sum(losses, dim=1) 183 | 184 | def loss_function(self, 185 | *args, 186 | **kwargs) -> dict: 187 | recons = args[0] 188 | input = args[1] 189 | alpha = args[2] 190 | beta = args[3] 191 | 192 | curr_device = input.device 193 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 194 | recons_loss = torch.mean(F.mse_loss(recons, input, reduction = 'none'), dim = (1,2,3)) 195 | 196 | # https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions 197 | # alpha = 1./ alpha 198 | 199 | 200 | self.prior_alpha = self.prior_alpha.to(curr_device) 201 | self.prior_beta = self.prior_beta.to(curr_device) 202 | 203 | # kld_loss = - self.I_function(alpha, beta, self.prior_alpha, self.prior_beta) 204 | 205 | kld_loss = self.vae_gamma_kl_loss(alpha, beta, self.prior_alpha, self.prior_beta) 206 | 207 | # kld_loss = torch.sum(kld_loss, dim=1) 208 | 209 | loss = recons_loss + kld_loss 210 | loss = torch.mean(loss, dim = 0) 211 | # print(loss, recons_loss, kld_loss) 212 | return {'loss': loss} #, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss} 213 | 214 | def sample(self, 215 | num_samples:int, 216 | current_device: int, **kwargs) -> Tensor: 217 | """ 218 | Samples from the latent space and return the corresponding 219 | image space map. 220 | :param num_samples: (Int) Number of samples 221 | :param current_device: (Int) Device to run the modelSay 222 | :return: (Tensor) 223 | """ 224 | z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim)) 225 | z = z.squeeze().to(current_device) 226 | 227 | samples = self.decode(z) 228 | return samples 229 | 230 | def generate(self, x: Tensor, **kwargs) -> Tensor: 231 | """ 232 | Given an input image x, returns the reconstructed image 233 | :param x: (Tensor) [B x C x H x W] 234 | :return: (Tensor) [B x C x H x W] 235 | """ 236 | 237 | return self.forward(x)[0] 238 | 239 | def init_(m): 240 | if isinstance(m, (nn.Linear, nn.Conv2d)): 241 | init.orthogonal_(m.weight) 242 | if m.bias is not None: 243 | m.bias.data.fill_(0) 244 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 245 | m.weight.data.fill_(1) 246 | if m.bias is not None: 247 | m.bias.data.fill_(0) 248 | -------------------------------------------------------------------------------- /models/hvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class HVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent1_dim: int, 13 | latent2_dim: int, 14 | hidden_dims: List = None, 15 | img_size:int = 64, 16 | pseudo_input_size: int = 128, 17 | **kwargs) -> None: 18 | super(HVAE, self).__init__() 19 | 20 | self.latent1_dim = latent1_dim 21 | self.latent2_dim = latent2_dim 22 | self.img_size = img_size 23 | 24 | modules = [] 25 | if hidden_dims is None: 26 | hidden_dims = [32, 64, 128, 256, 512] 27 | channels = in_channels 28 | 29 | # Build z2 Encoder 30 | for h_dim in hidden_dims: 31 | modules.append( 32 | nn.Sequential( 33 | nn.Conv2d(channels, out_channels=h_dim, 34 | kernel_size= 3, stride= 2, padding = 1), 35 | nn.BatchNorm2d(h_dim), 36 | nn.LeakyReLU()) 37 | ) 38 | channels = h_dim 39 | 40 | self.encoder_z2_layers = nn.Sequential(*modules) 41 | self.fc_z2_mu = 
nn.Linear(hidden_dims[-1]*4, latent2_dim) 42 | self.fc_z2_var = nn.Linear(hidden_dims[-1]*4, latent2_dim) 43 | # ========================================================================# 44 | # Build z1 Encoder 45 | self.embed_z2_code = nn.Linear(latent2_dim, img_size * img_size) 46 | self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1) 47 | 48 | modules = [] 49 | channels = in_channels + 1 # One more channel for the latent code 50 | for h_dim in hidden_dims: 51 | modules.append( 52 | nn.Sequential( 53 | nn.Conv2d(channels, out_channels=h_dim, 54 | kernel_size= 3, stride= 2, padding = 1), 55 | nn.BatchNorm2d(h_dim), 56 | nn.LeakyReLU()) 57 | ) 58 | channels = h_dim 59 | 60 | self.encoder_z1_layers = nn.Sequential(*modules) 61 | self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim) 62 | self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim) 63 | 64 | #========================================================================# 65 | # Build z2 Decoder 66 | self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim) 67 | self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim) 68 | 69 | # ========================================================================# 70 | # Build z1 Decoder 71 | self.debed_z1_code = nn.Linear(latent1_dim, 1024) 72 | self.debed_z2_code = nn.Linear(latent2_dim, 1024) 73 | modules = [] 74 | hidden_dims.reverse() 75 | 76 | for i in range(len(hidden_dims) - 1): 77 | modules.append( 78 | nn.Sequential( 79 | nn.ConvTranspose2d(hidden_dims[i], 80 | hidden_dims[i + 1], 81 | kernel_size=3, 82 | stride = 2, 83 | padding=1, 84 | output_padding=1), 85 | nn.BatchNorm2d(hidden_dims[i + 1]), 86 | nn.LeakyReLU()) 87 | ) 88 | 89 | 90 | 91 | self.decoder = nn.Sequential(*modules) 92 | 93 | self.final_layer = nn.Sequential( 94 | nn.ConvTranspose2d(hidden_dims[-1], 95 | hidden_dims[-1], 96 | kernel_size=3, 97 | stride=2, 98 | padding=1, 99 | output_padding=1), 100 | nn.BatchNorm2d(hidden_dims[-1]), 101 | nn.LeakyReLU(), 102 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 103 | kernel_size= 3, padding= 1), 104 | nn.Tanh()) 105 | 106 | # ========================================================================# 107 | # Pesudo Input for the Vamp-Prior 108 | # self.pseudo_input = torch.eye(pseudo_input_size, 109 | # requires_grad=False).view(1, 1, pseudo_input_size, -1) 110 | # 111 | # 112 | # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels, 113 | # kernel_size=3, stride=2, padding=1) 114 | 115 | def encode_z2(self, input: Tensor) -> List[Tensor]: 116 | """ 117 | Encodes the input by passing through the encoder network 118 | and returns the latent codes. 
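# --- Schematic of the two-level latent hierarchy wired up above (toy sizes,
# plain linear maps standing in for the conv encoders): inference runs
# z2 ~ q(z2|x) then z1 ~ q(z1|x, z2); generation runs z2 -> p(z1|z2) -> x.
import torch

B, x_dim, d1, d2 = 4, 32, 8, 5
x = torch.randn(B, x_dim)

q_z2 = torch.nn.Linear(x_dim, 2 * d2)          # -> (mu2, logvar2)
q_z1 = torch.nn.Linear(x_dim + d2, 2 * d1)     # -> (mu1, logvar1), sees x and z2
p_z1 = torch.nn.Linear(d2, 2 * d1)             # prior p(z1|z2) used in the loss

def rsample(stats):
    mu, logvar = stats.chunk(2, dim=-1)
    return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

z2 = rsample(q_z2(x))
z1 = rsample(q_z1(torch.cat([x, z2], dim=-1)))
z1_prior_stats = p_z1(z2)                      # compared against q(z1|x, z2)
print(z1.shape, z2.shape, z1_prior_stats.shape)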
119 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 120 | :return: (Tensor) List of latent codes 121 | """ 122 | result = self.encoder_z2_layers(input) 123 | result = torch.flatten(result, start_dim=1) 124 | 125 | # Split the result into mu and var components 126 | # of the latent Gaussian distribution 127 | z2_mu = self.fc_z2_mu(result) 128 | z2_log_var = self.fc_z2_var(result) 129 | 130 | return [z2_mu, z2_log_var] 131 | 132 | def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]: 133 | x = self.embed_data(input) 134 | z2 = self.embed_z2_code(z2) 135 | z2 = z2.view(-1, self.img_size, self.img_size).unsqueeze(1) 136 | result = torch.cat([x, z2], dim=1) 137 | 138 | result = self.encoder_z1_layers(result) 139 | result = torch.flatten(result, start_dim=1) 140 | z1_mu = self.fc_z1_mu(result) 141 | z1_log_var = self.fc_z1_var(result) 142 | 143 | return [z1_mu, z1_log_var] 144 | 145 | def encode(self, input: Tensor) -> List[Tensor]: 146 | z2_mu, z2_log_var = self.encode_z2(input) 147 | z2 = self.reparameterize(z2_mu, z2_log_var) 148 | 149 | # z1 ~ q(z1|x, z2) 150 | z1_mu, z1_log_var = self.encode_z1(input, z2) 151 | return [z1_mu, z1_log_var, z2_mu, z2_log_var, z2] 152 | 153 | def decode(self, input: Tensor) -> Tensor: 154 | result = self.decoder(input) 155 | result = self.final_layer(result) 156 | return result 157 | 158 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 159 | """ 160 | Will a single z be enough ti compute the expectation 161 | for the loss?? 162 | :param mu: (Tensor) Mean of the latent Gaussian 163 | :param logvar: (Tensor) Standard deviation of the latent Gaussian 164 | :return: 165 | """ 166 | std = torch.exp(0.5 * logvar) 167 | eps = torch.randn_like(std) 168 | return eps * std + mu 169 | 170 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 171 | 172 | # Encode the input into the latent codes z1 and z2 173 | # z2 ~q(z2 | x) 174 | # z1 ~ q(z1|x, z2) 175 | z1_mu, z1_log_var, z2_mu, z2_log_var, z2 = self.encode(input) 176 | z1 = self.reparameterize(z1_mu, z1_log_var) 177 | 178 | # Reconstruct the image using both the latent codes 179 | # x ~ p(x|z1, z2) 180 | debedded_z1 = self.debed_z1_code(z1) 181 | debedded_z2 = self.debed_z2_code(z2) 182 | result = torch.cat([debedded_z1, debedded_z2], dim=1) 183 | result = result.view(-1, 512, 2, 2) 184 | recons = self.decode(result) 185 | 186 | return [recons, 187 | input, 188 | z1_mu, z1_log_var, 189 | z2_mu, z2_log_var, 190 | z1, z2] 191 | 192 | def loss_function(self, 193 | *args, 194 | **kwargs) -> dict: 195 | recons = args[0] 196 | input = args[1] 197 | 198 | z1_mu = args[2] 199 | z1_log_var = args[3] 200 | 201 | z2_mu = args[4] 202 | z2_log_var = args[5] 203 | 204 | z1= args[6] 205 | z2 = args[7] 206 | 207 | # Reconstruct (decode) z2 into z1 208 | # z1 ~ p(z1|z2) [This for the loss calculation] 209 | z1_p_mu = self.recons_z1_mu(z2) 210 | z1_p_log_var = self.recons_z1_log_var(z2) 211 | 212 | 213 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 214 | recons_loss =F.mse_loss(recons, input) 215 | 216 | z1_kld = torch.mean(-0.5 * torch.sum(1 + z1_log_var - z1_mu ** 2 - z1_log_var.exp(), dim = 1), 217 | dim = 0) 218 | z2_kld = torch.mean(-0.5 * torch.sum(1 + z2_log_var - z2_mu ** 2 - z2_log_var.exp(), dim = 1), 219 | dim = 0) 220 | 221 | z1_p_kld = torch.mean(-0.5 * torch.sum(1 + z1_p_log_var - (z1 - z1_p_mu) ** 2 - z1_p_log_var.exp(), 222 | dim = 1), 223 | dim = 0) 224 | 225 | z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0) 226 | 227 | kld_loss = 
-(z1_p_kld - z1_kld - z2_kld) 228 | loss = recons_loss + kld_weight * kld_loss 229 | # print(z2_p_kld) 230 | 231 | return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss} 232 | 233 | def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor: 234 | z2 = torch.randn(batch_size, 235 | self.latent2_dim) 236 | 237 | z2 = z2.cuda(current_device) 238 | 239 | z1_mu = self.recons_z1_mu(z2) 240 | z1_log_var = self.recons_z1_log_var(z2) 241 | z1 = self.reparameterize(z1_mu, z1_log_var) 242 | 243 | debedded_z1 = self.debed_z1_code(z1) 244 | debedded_z2 = self.debed_z2_code(z2) 245 | 246 | result = torch.cat([debedded_z1, debedded_z2], dim=1) 247 | result = result.view(-1, 512, 2, 2) 248 | samples = self.decode(result) 249 | 250 | return samples 251 | 252 | def generate(self, x: Tensor, **kwargs) -> Tensor: 253 | """ 254 | Given an input image x, returns the reconstructed image 255 | :param x: (Tensor) [B x C x H x W] 256 | :return: (Tensor) [B x C x H x W] 257 | """ 258 | 259 | return self.forward(x)[0] 260 | -------------------------------------------------------------------------------- /models/info_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class InfoVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | alpha: float = -0.5, 15 | beta: float = 5.0, 16 | reg_weight: int = 100, 17 | kernel_type: str = 'imq', 18 | latent_var: float = 2., 19 | **kwargs) -> None: 20 | super(InfoVAE, self).__init__() 21 | 22 | self.latent_dim = latent_dim 23 | self.reg_weight = reg_weight 24 | self.kernel_type = kernel_type 25 | self.z_var = latent_var 26 | 27 | assert alpha <= 0, 'alpha must be negative or zero.' 28 | 29 | self.alpha = alpha 30 | self.beta = beta 31 | 32 | modules = [] 33 | if hidden_dims is None: 34 | hidden_dims = [32, 64, 128, 256, 512] 35 | 36 | # Build Encoder 37 | for h_dim in hidden_dims: 38 | modules.append( 39 | nn.Sequential( 40 | nn.Conv2d(in_channels, out_channels=h_dim, 41 | kernel_size= 3, stride= 2, padding = 1), 42 | nn.BatchNorm2d(h_dim), 43 | nn.LeakyReLU()) 44 | ) 45 | in_channels = h_dim 46 | 47 | self.encoder = nn.Sequential(*modules) 48 | self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim) 49 | self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim) 50 | 51 | # Build Decoder 52 | modules = [] 53 | 54 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 55 | 56 | hidden_dims.reverse() 57 | 58 | for i in range(len(hidden_dims) - 1): 59 | modules.append( 60 | nn.Sequential( 61 | nn.ConvTranspose2d(hidden_dims[i], 62 | hidden_dims[i + 1], 63 | kernel_size=3, 64 | stride = 2, 65 | padding=1, 66 | output_padding=1), 67 | nn.BatchNorm2d(hidden_dims[i + 1]), 68 | nn.LeakyReLU()) 69 | ) 70 | 71 | 72 | 73 | self.decoder = nn.Sequential(*modules) 74 | 75 | self.final_layer = nn.Sequential( 76 | nn.ConvTranspose2d(hidden_dims[-1], 77 | hidden_dims[-1], 78 | kernel_size=3, 79 | stride=2, 80 | padding=1, 81 | output_padding=1), 82 | nn.BatchNorm2d(hidden_dims[-1]), 83 | nn.LeakyReLU(), 84 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 85 | kernel_size= 3, padding= 1), 86 | nn.Tanh()) 87 | 88 | def encode(self, input: Tensor) -> List[Tensor]: 89 | """ 90 | Encodes the input by passing through the encoder network 91 | and returns the latent codes. 
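# --- The InfoVAE class opened above adds a kernel two-sample (MMD) term to
# the ELBO, controlled by kernel_type and reg_weight. A minimal sketch of the
# estimator with an RBF kernel (toy sizes; the class also supports IMQ):
import torch

def rbf_mmd(z, prior_z, sigma=1.0):
    def k(a, b):
        d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)   # [N x N]
        return torch.exp(-d2 / (2 * sigma ** 2))
    return k(prior_z, prior_z).mean() + k(z, z).mean() - 2 * k(prior_z, z).mean()

z = torch.randn(128, 16) * 1.5 + 0.5    # pretend these came from the encoder
prior_z = torch.randn(128, 16)
print(rbf_mmd(z, prior_z))              # grows as q(z) drifts from the prior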
92 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 93 | :return: (Tensor) List of latent codes 94 | """ 95 | result = self.encoder(input) 96 | result = torch.flatten(result, start_dim=1) 97 | 98 | # Split the result into mu and var components 99 | # of the latent Gaussian distribution 100 | mu = self.fc_mu(result) 101 | log_var = self.fc_var(result) 102 | return [mu, log_var] 103 | 104 | def decode(self, z: Tensor) -> Tensor: 105 | result = self.decoder_input(z) 106 | result = result.view(-1, 512, 2, 2) 107 | result = self.decoder(result) 108 | result = self.final_layer(result) 109 | return result 110 | 111 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 112 | """ 113 | Reparameterization trick to sample from N(mu, var) from 114 | N(0,1). 115 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 116 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 117 | :return: (Tensor) [B x D] 118 | """ 119 | std = torch.exp(0.5 * logvar) 120 | eps = torch.randn_like(std) 121 | return eps * std + mu 122 | 123 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 124 | mu, log_var = self.encode(input) 125 | z = self.reparameterize(mu, log_var) 126 | return [self.decode(z), input, z, mu, log_var] 127 | 128 | def loss_function(self, 129 | *args, 130 | **kwargs) -> dict: 131 | recons = args[0] 132 | input = args[1] 133 | z = args[2] 134 | mu = args[3] 135 | log_var = args[4] 136 | 137 | batch_size = input.size(0) 138 | bias_corr = batch_size * (batch_size - 1) 139 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 140 | 141 | recons_loss =F.mse_loss(recons, input) 142 | mmd_loss = self.compute_mmd(z) 143 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) 144 | 145 | loss = self.beta * recons_loss + \ 146 | (1. - self.alpha) * kld_weight * kld_loss + \ 147 | (self.alpha + self.reg_weight - 1.)/bias_corr * mmd_loss 148 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss} 149 | 150 | def compute_kernel(self, 151 | x1: Tensor, 152 | x2: Tensor) -> Tensor: 153 | # Convert the tensors into row and column vectors 154 | D = x1.size(1) 155 | N = x1.size(0) 156 | 157 | x1 = x1.unsqueeze(-2) # Make it into a column tensor 158 | x2 = x2.unsqueeze(-3) # Make it into a row tensor 159 | 160 | """ 161 | Usually the below lines are not required, especially in our case, 162 | but this is useful when x1 and x2 have different sizes 163 | along the 0th dimension. 164 | """ 165 | x1 = x1.expand(N, N, D) 166 | x2 = x2.expand(N, N, D) 167 | 168 | if self.kernel_type == 'rbf': 169 | result = self.compute_rbf(x1, x2) 170 | elif self.kernel_type == 'imq': 171 | result = self.compute_inv_mult_quad(x1, x2) 172 | else: 173 | raise ValueError('Undefined kernel type.') 174 | 175 | return result 176 | 177 | 178 | def compute_rbf(self, 179 | x1: Tensor, 180 | x2: Tensor, 181 | eps: float = 1e-7) -> Tensor: 182 | """ 183 | Computes the RBF Kernel between x1 and x2. 184 | :param x1: (Tensor) 185 | :param x2: (Tensor) 186 | :param eps: (Float) 187 | :return: 188 | """ 189 | z_dim = x2.size(-1) 190 | sigma = 2. 
* z_dim * self.z_var 191 | 192 | result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma)) 193 | return result 194 | 195 | def compute_inv_mult_quad(self, 196 | x1: Tensor, 197 | x2: Tensor, 198 | eps: float = 1e-7) -> Tensor: 199 | """ 200 | Computes the Inverse Multi-Quadratics Kernel between x1 and x2, 201 | given by 202 | 203 | k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2} 204 | :param x1: (Tensor) 205 | :param x2: (Tensor) 206 | :param eps: (Float) 207 | :return: 208 | """ 209 | z_dim = x2.size(-1) 210 | C = 2 * z_dim * self.z_var 211 | kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1)) 212 | 213 | # Exclude diagonal elements 214 | result = kernel.sum() - kernel.diag().sum() 215 | 216 | return result 217 | 218 | def compute_mmd(self, z: Tensor) -> Tensor: 219 | # Sample from prior (Gaussian) distribution 220 | prior_z = torch.randn_like(z) 221 | 222 | prior_z__kernel = self.compute_kernel(prior_z, prior_z) 223 | z__kernel = self.compute_kernel(z, z) 224 | priorz_z__kernel = self.compute_kernel(prior_z, z) 225 | 226 | mmd = prior_z__kernel.mean() + \ 227 | z__kernel.mean() - \ 228 | 2 * priorz_z__kernel.mean() 229 | return mmd 230 | 231 | def sample(self, 232 | num_samples:int, 233 | current_device: int, **kwargs) -> Tensor: 234 | """ 235 | Samples from the latent space and return the corresponding 236 | image space map. 237 | :param num_samples: (Int) Number of samples 238 | :param current_device: (Int) Device to run the model 239 | :return: (Tensor) 240 | """ 241 | z = torch.randn(num_samples, 242 | self.latent_dim) 243 | 244 | z = z.to(current_device) 245 | 246 | samples = self.decode(z) 247 | return samples 248 | 249 | def generate(self, x: Tensor, **kwargs) -> Tensor: 250 | """ 251 | Given an input image x, returns the reconstructed image 252 | :param x: (Tensor) [B x C x H x W] 253 | :return: (Tensor) [B x C x H x W] 254 | """ 255 | 256 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/iwae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class IWAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | num_samples: int = 5, 15 | **kwargs) -> None: 16 | super(IWAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | self.num_samples = num_samples 20 | 21 | modules = [] 22 | if hidden_dims is None: 23 | hidden_dims = [32, 64, 128, 256, 512] 24 | 25 | # Build Encoder 26 | for h_dim in hidden_dims: 27 | modules.append( 28 | nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels=h_dim, 30 | kernel_size= 3, stride= 2, padding = 1), 31 | nn.BatchNorm2d(h_dim), 32 | nn.LeakyReLU()) 33 | ) 34 | in_channels = h_dim 35 | 36 | self.encoder = nn.Sequential(*modules) 37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | 40 | 41 | # Build Decoder 42 | modules = [] 43 | 44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 45 | 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | 61 | 62 | 63 | self.decoder 
= nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | def encode(self, input: Tensor) -> List[Tensor]: 79 | """ 80 | Encodes the input by passing through the encoder network 81 | and returns the latent codes. 82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 83 | :return: (Tensor) List of latent codes 84 | """ 85 | result = self.encoder(input) 86 | result = torch.flatten(result, start_dim=1) 87 | 88 | # Split the result into mu and var components 89 | # of the latent Gaussian distribution 90 | mu = self.fc_mu(result) 91 | log_var = self.fc_var(result) 92 | 93 | return [mu, log_var] 94 | 95 | def decode(self, z: Tensor) -> Tensor: 96 | """ 97 | Maps the given latent codes of S samples 98 | onto the image space. 99 | :param z: (Tensor) [B x S x D] 100 | :return: (Tensor) [B x S x C x H x W] 101 | """ 102 | B, _, _ = z.size() 103 | z = z.view(-1, self.latent_dim) #[BS x D] 104 | result = self.decoder_input(z) 105 | result = result.view(-1, 512, 2, 2) 106 | result = self.decoder(result) 107 | result = self.final_layer(result) #[BS x C x H x W ] 108 | result = result.view([B, -1, result.size(1), result.size(2), result.size(3)]) #[B x S x C x H x W] 109 | return result 110 | 111 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 112 | """ 113 | :param mu: (Tensor) Mean of the latent Gaussian 114 | :param logvar: (Tensor) Log variance of the latent Gaussian 115 | :return: (Tensor) Sample from N(mu, var) 116 | """ 117 | std = torch.exp(0.5 * logvar) 118 | eps = torch.randn_like(std) 119 | return eps * std + mu 120 | 121 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 122 | mu, log_var = self.encode(input) 123 | mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 124 | log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D] 125 | z = self.reparameterize(mu, log_var) # [B x S x D] 126 | eps = (z - mu) / torch.exp(0.5 * log_var) # Standard Gaussian noise that generated z (z = mu + std * eps) 127 | return [self.decode(z), input, mu, log_var, z, eps] 128 | 129 | def loss_function(self, 130 | *args, 131 | **kwargs) -> dict: 132 | """ 133 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 134 | :param args: 135 | :param kwargs: 136 | :return: 137 | """ 138 | recons = args[0] 139 | input = args[1] 140 | mu = args[2] 141 | log_var = args[3] 142 | z = args[4] 143 | eps = args[5] 144 | 145 | input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W] 146 | 147 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 148 | 149 | log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S] 150 | kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S] 151 | # Get importance weights 152 | log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data 153 | 154 | # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1 155 | weight = F.softmax(log_weight, dim = -1) 156 | # kld_loss = torch.mean(kld_loss, dim = 0) 157 | 158 | loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0) 159 | 160 | return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()} 161 | 162 | def sample(self,
num_samples:int, 164 | current_device: int, **kwargs) -> Tensor: 165 | """ 166 | Samples from the latent space and return the corresponding 167 | image space map. 168 | :param num_samples: (Int) Number of samples 169 | :param current_device: (Int) Device to run the model 170 | :return: (Tensor) 171 | """ 172 | z = torch.randn(num_samples, 1, 173 | self.latent_dim) 174 | 175 | z = z.to(current_device) 176 | 177 | samples = self.decode(z).squeeze() 178 | return samples 179 | 180 | def generate(self, x: Tensor, **kwargs) -> Tensor: 181 | """ 182 | Given an input image x, returns the reconstructed image. 183 | Returns only the first reconstructed sample 184 | :param x: (Tensor) [B x C x H x W] 185 | :return: (Tensor) [B x C x H x W] 186 | """ 187 | 188 | return self.forward(x)[0][:, 0, :] 189 | -------------------------------------------------------------------------------- /models/logcosh_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models import BaseVAE 4 | from torch import nn 5 | from .types_ import * 6 | 7 | 8 | class LogCoshVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | alpha: float = 100., 15 | beta: float = 10., 16 | **kwargs) -> None: 17 | super(LogCoshVAE, self).__init__() 18 | 19 | self.latent_dim = latent_dim 20 | self.alpha = alpha 21 | self.beta = beta 22 | 23 | modules = [] 24 | if hidden_dims is None: 25 | hidden_dims = [32, 64, 128, 256, 512] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | 46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 47 | 48 | hidden_dims.reverse() 49 | 50 | for i in range(len(hidden_dims) - 1): 51 | modules.append( 52 | nn.Sequential( 53 | nn.ConvTranspose2d(hidden_dims[i], 54 | hidden_dims[i + 1], 55 | kernel_size=3, 56 | stride = 2, 57 | padding=1, 58 | output_padding=1), 59 | nn.BatchNorm2d(hidden_dims[i + 1]), 60 | nn.LeakyReLU()) 61 | ) 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | def encode(self, input: Tensor) -> List[Tensor]: 79 | """ 80 | Encodes the input by passing through the encoder network 81 | and returns the latent codes. 82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 83 | :return: (Tensor) List of latent codes 84 | """ 85 | result = self.encoder(input) 86 | result = torch.flatten(result, start_dim=1) 87 | 88 | # Split the result into mu and var components 89 | # of the latent Gaussian distribution 90 | mu = self.fc_mu(result) 91 | log_var = self.fc_var(result) 92 | 93 | return [mu, log_var] 94 | 95 | def decode(self, z: Tensor) -> Tensor: 96 | """ 97 | Maps the given latent codes 98 | onto the image space. 
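A shape sketch for this decoder, assuming the default hidden_dims [32, 64, 128, 256, 512] and 64x64 images: z [B x D] -> decoder_input -> [B x 2048] -> view -> [B x 512 x 2 x 2] -> decoder (four stride-2 upsamplings) -> [B x 32 x 32 x 32] -> final_layer -> [B x 3 x 64 x 64].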
99 | :param z: (Tensor) [B x D] 100 | :return: (Tensor) [B x C x H x W] 101 | """ 102 | result = self.decoder_input(z) 103 | result = result.view(-1, 512, 2, 2) 104 | result = self.decoder(result) 105 | result = self.final_layer(result) 106 | return result 107 | 108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 109 | """ 110 | Reparameterization trick to sample from N(mu, var) using 111 | N(0,1). 112 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 113 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D] 114 | :return: (Tensor) [B x D] 115 | """ 116 | std = torch.exp(0.5 * logvar) 117 | eps = torch.randn_like(std) 118 | return eps * std + mu 119 | 120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 121 | mu, log_var = self.encode(input) 122 | z = self.reparameterize(mu, log_var) 123 | return [self.decode(z), input, mu, log_var] 124 | 125 | def loss_function(self, 126 | *args, 127 | **kwargs) -> dict: 128 | """ 129 | Computes the VAE loss function. 130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 131 | :param args: 132 | :param kwargs: 133 | :return: 134 | """ 135 | recons = args[0] 136 | input = args[1] 137 | mu = args[2] 138 | log_var = args[3] 139 | 140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 141 | t = recons - input 142 | # recons_loss = F.mse_loss(recons, input) 143 | # cosh = torch.cosh(self.alpha * t) 144 | # recons_loss = (1./self.alpha * torch.log(cosh)).mean() 145 | 146 | recons_loss = self.alpha * t + \ 147 | torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \ 148 | torch.log(torch.tensor(2.0)) 149 | # print(self.alpha* t.max(), self.alpha*t.min()) 150 | recons_loss = (1. / self.alpha) * recons_loss.mean() 151 | 152 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 153 | 154 | loss = recons_loss + self.beta * kld_weight * kld_loss 155 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 156 | 157 | def sample(self, 158 | num_samples:int, 159 | current_device: int, **kwargs) -> Tensor: 160 | """ 161 | Samples from the latent space and returns the corresponding 162 | image space map.
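The loss_function above relies on the numerically stable identity log cosh(x) = x + log(1 + exp(-2x)) - log(2), which avoids overflowing cosh for large alpha * t. A minimal self-check of that identity (a sketch, assuming only the torch import at the top of the file):

x = torch.linspace(-5., 5., 11)
lhs = torch.log(torch.cosh(x))
rhs = x + torch.log(1. + torch.exp(-2. * x)) - torch.log(torch.tensor(2.))
assert torch.allclose(lhs, rhs, atol=1e-6)  # holds for any real x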
163 | :param num_samples: (Int) Number of samples 164 | :param current_device: (Int) Device to run the model 165 | :return: (Tensor) 166 | """ 167 | z = torch.randn(num_samples, 168 | self.latent_dim) 169 | 170 | z = z.to(current_device) 171 | 172 | samples = self.decode(z) 173 | return samples 174 | 175 | def generate(self, x: Tensor, **kwargs) -> Tensor: 176 | """ 177 | Given an input image x, returns the reconstructed image 178 | :param x: (Tensor) [B x C x H x W] 179 | :return: (Tensor) [B x C x H x W] 180 | """ 181 | 182 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/miwae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | from torch.distributions import Normal 7 | 8 | 9 | class MIWAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | num_samples: int = 5, 16 | num_estimates: int = 5, 17 | **kwargs) -> None: 18 | super(MIWAE, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.num_samples = num_samples # K 22 | self.num_estimates = num_estimates # M 23 | 24 | modules = [] 25 | if hidden_dims is None: 26 | hidden_dims = [32, 64, 128, 256, 512] 27 | 28 | # Build Encoder 29 | for h_dim in hidden_dims: 30 | modules.append( 31 | nn.Sequential( 32 | nn.Conv2d(in_channels, out_channels=h_dim, 33 | kernel_size= 3, stride= 2, padding = 1), 34 | nn.BatchNorm2d(h_dim), 35 | nn.LeakyReLU()) 36 | ) 37 | in_channels = h_dim 38 | 39 | self.encoder = nn.Sequential(*modules) 40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 42 | 43 | 44 | # Build Decoder 45 | modules = [] 46 | 47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 48 | 49 | hidden_dims.reverse() 50 | 51 | for i in range(len(hidden_dims) - 1): 52 | modules.append( 53 | nn.Sequential( 54 | nn.ConvTranspose2d(hidden_dims[i], 55 | hidden_dims[i + 1], 56 | kernel_size=3, 57 | stride = 2, 58 | padding=1, 59 | output_padding=1), 60 | nn.BatchNorm2d(hidden_dims[i + 1]), 61 | nn.LeakyReLU()) 62 | ) 63 | 64 | 65 | 66 | self.decoder = nn.Sequential(*modules) 67 | 68 | self.final_layer = nn.Sequential( 69 | nn.ConvTranspose2d(hidden_dims[-1], 70 | hidden_dims[-1], 71 | kernel_size=3, 72 | stride=2, 73 | padding=1, 74 | output_padding=1), 75 | nn.BatchNorm2d(hidden_dims[-1]), 76 | nn.LeakyReLU(), 77 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 78 | kernel_size= 3, padding= 1), 79 | nn.Tanh()) 80 | 81 | def encode(self, input: Tensor) -> List[Tensor]: 82 | """ 83 | Encodes the input by passing through the encoder network 84 | and returns the latent codes. 85 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 86 | :return: (Tensor) List of latent codes 87 | """ 88 | result = self.encoder(input) 89 | result = torch.flatten(result, start_dim=1) 90 | 91 | # Split the result into mu and var components 92 | # of the latent Gaussian distribution 93 | mu = self.fc_mu(result) 94 | log_var = self.fc_var(result) 95 | 96 | return [mu, log_var] 97 | 98 | def decode(self, z: Tensor) -> Tensor: 99 | """ 100 | Maps the given latent codes of S samples 101 | onto the image space. 
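Here S = num_samples importance samples are drawn for each of M = num_estimates independent estimates of the bound, so the latents carry a [B x M x S x D] layout that is flattened to [BMS x D] for the convolutional decoder and reshaped afterwards.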
102 | :param z: (Tensor) [B x M x S x D] 103 | :return: (Tensor) [B x M x S x C x H x W] 104 | """ 105 | B, M, S, D = z.size() 106 | z = z.contiguous().view(-1, self.latent_dim) #[BMS x D] 107 | result = self.decoder_input(z) 108 | result = result.view(-1, 512, 2, 2) 109 | result = self.decoder(result) 110 | result = self.final_layer(result) #[BMS x C x H x W ] 111 | result = result.view([B, M, S, result.size(-3), result.size(-2), result.size(-1)]) #[B x M x S x C x H x W] 112 | return result 113 | 114 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 115 | """ 116 | :param mu: (Tensor) Mean of the latent Gaussian 117 | :param logvar: (Tensor) Log variance of the latent Gaussian 118 | :return: (Tensor) Sample from N(mu, var) 119 | """ 120 | std = torch.exp(0.5 * logvar) 121 | eps = torch.randn_like(std) 122 | return eps * std + mu 123 | 124 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 125 | mu, log_var = self.encode(input) 126 | mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D] 127 | log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D] 128 | z = self.reparameterize(mu, log_var) # [B x M x S x D] 129 | eps = (z - mu) / torch.exp(0.5 * log_var) # Standard Gaussian noise that generated z (z = mu + std * eps) 130 | return [self.decode(z), input, mu, log_var, z, eps] 131 | 132 | def loss_function(self, 133 | *args, 134 | **kwargs) -> dict: 135 | """ 136 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 137 | :param args: 138 | :param kwargs: 139 | :return: 140 | """ 141 | recons = args[0] 142 | input = args[1] 143 | mu = args[2] 144 | log_var = args[3] 145 | z = args[4] 146 | eps = args[5] 147 | 148 | input = input.repeat(self.num_estimates, 149 | self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W] 150 | 151 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 152 | 153 | log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S] 154 | 155 | kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S] 156 | # Get importance weights 157 | log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data 158 | 159 | # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1 160 | weight = F.softmax(log_weight, dim = -1) # [B x M x S] 161 | 162 | loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0) 163 | 164 | return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()} 165 | 166 | def sample(self, 167 | num_samples:int, 168 | current_device: int, **kwargs) -> Tensor: 169 | """ 170 | Samples from the latent space and returns the corresponding 171 | image space map. 172 | :param num_samples: (Int) Number of samples 173 | :param current_device: (Int) Device to run the model 174 | :return: (Tensor) 175 | """ 176 | z = torch.randn(num_samples, 1, 1, 177 | self.latent_dim) 178 | 179 | z = z.to(current_device) 180 | 181 | samples = self.decode(z).squeeze() 182 | return samples 183 | 184 | def generate(self, x: Tensor, **kwargs) -> Tensor: 185 | """ 186 | Given an input image x, returns the reconstructed image.
187 | Returns only the first reconstructed sample 188 | :param x: (Tensor) [B x C x H x W] 189 | :return: (Tensor) [B x C x H x W] 190 | """ 191 | 192 | return self.forward(x)[0][:, 0, 0, :] 193 | -------------------------------------------------------------------------------- /models/swae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch import distributions as dist 6 | from .types_ import * 7 | 8 | 9 | class SWAE(BaseVAE): 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | reg_weight: int = 100, 16 | wasserstein_deg: float= 2., 17 | num_projections: int = 50, 18 | projection_dist: str = 'normal', 19 | **kwargs) -> None: 20 | super(SWAE, self).__init__() 21 | 22 | self.latent_dim = latent_dim 23 | self.reg_weight = reg_weight 24 | self.p = wasserstein_deg 25 | self.num_projections = num_projections 26 | self.proj_dist = projection_dist 27 | 28 | modules = [] 29 | if hidden_dims is None: 30 | hidden_dims = [32, 64, 128, 256, 512] 31 | 32 | # Build Encoder 33 | for h_dim in hidden_dims: 34 | modules.append( 35 | nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels=h_dim, 37 | kernel_size= 3, stride= 2, padding = 1), 38 | nn.BatchNorm2d(h_dim), 39 | nn.LeakyReLU()) 40 | ) 41 | in_channels = h_dim 42 | 43 | self.encoder = nn.Sequential(*modules) 44 | self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim) 45 | 46 | 47 | # Build Decoder 48 | modules = [] 49 | 50 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 51 | 52 | hidden_dims.reverse() 53 | 54 | for i in range(len(hidden_dims) - 1): 55 | modules.append( 56 | nn.Sequential( 57 | nn.ConvTranspose2d(hidden_dims[i], 58 | hidden_dims[i + 1], 59 | kernel_size=3, 60 | stride = 2, 61 | padding=1, 62 | output_padding=1), 63 | nn.BatchNorm2d(hidden_dims[i + 1]), 64 | nn.LeakyReLU()) 65 | ) 66 | 67 | 68 | 69 | self.decoder = nn.Sequential(*modules) 70 | 71 | self.final_layer = nn.Sequential( 72 | nn.ConvTranspose2d(hidden_dims[-1], 73 | hidden_dims[-1], 74 | kernel_size=3, 75 | stride=2, 76 | padding=1, 77 | output_padding=1), 78 | nn.BatchNorm2d(hidden_dims[-1]), 79 | nn.LeakyReLU(), 80 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 81 | kernel_size= 3, padding= 1), 82 | nn.Tanh()) 83 | 84 | def encode(self, input: Tensor) -> Tensor: 85 | """ 86 | Encodes the input by passing through the encoder network 87 | and returns the latent codes. 
88 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 89 | :return: (Tensor) List of latent codes 90 | """ 91 | result = self.encoder(input) 92 | result = torch.flatten(result, start_dim=1) 93 | 94 | # Split the result into mu and var components 95 | # of the latent Gaussian distribution 96 | z = self.fc_z(result) 97 | return z 98 | 99 | def decode(self, z: Tensor) -> Tensor: 100 | result = self.decoder_input(z) 101 | result = result.view(-1, 512, 2, 2) 102 | result = self.decoder(result) 103 | result = self.final_layer(result) 104 | return result 105 | 106 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 107 | z = self.encode(input) 108 | return [self.decode(z), input, z] 109 | 110 | def loss_function(self, 111 | *args, 112 | **kwargs) -> dict: 113 | recons = args[0] 114 | input = args[1] 115 | z = args[2] 116 | 117 | batch_size = input.size(0) 118 | bias_corr = batch_size * (batch_size - 1) 119 | reg_weight = self.reg_weight / bias_corr 120 | 121 | recons_loss_l2 = F.mse_loss(recons, input) 122 | recons_loss_l1 = F.l1_loss(recons, input) 123 | 124 | swd_loss = self.compute_swd(z, self.p, reg_weight) 125 | 126 | loss = recons_loss_l2 + recons_loss_l1 + swd_loss 127 | return {'loss': loss, 'Reconstruction_Loss':(recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss} 128 | 129 | def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor: 130 | """ 131 | Returns random samples from latent distribution's (Gaussian) 132 | unit sphere for projecting the encoded samples and the 133 | distribution samples. 134 | 135 | :param latent_dim: (Int) Dimensionality of the latent space (D) 136 | :param num_samples: (Int) Number of samples required (S) 137 | :return: Random projections from the latent unit sphere 138 | """ 139 | if self.proj_dist == 'normal': 140 | rand_samples = torch.randn(num_samples, latent_dim) 141 | elif self.proj_dist == 'cauchy': 142 | rand_samples = dist.Cauchy(torch.tensor([0.0]), 143 | torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze() 144 | else: 145 | raise ValueError('Unknown projection distribution.') 146 | 147 | rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1,1) 148 | return rand_proj # [S x D] 149 | 150 | 151 | def compute_swd(self, 152 | z: Tensor, 153 | p: float, 154 | reg_weight: float) -> Tensor: 155 | """ 156 | Computes the Sliced Wasserstein Distance (SWD) - which consists of 157 | randomly projecting the encoded and prior vectors and computing 158 | their Wasserstein distance along those projections. 
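Concretely, for unit directions theta_1, ..., theta_S and batch codes z_1, ..., z_N, the estimator below computes roughly SW_p ~ mean over s, i of (sort_i(theta_s . z) - sort_i(theta_s . z_prior))^p, using the fact that in one dimension optimal transport between equal-size samples is given by sorting.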
159 | 160 | :param z: Latent samples # [N x D] 161 | :param p: Value for the p^th Wasserstein distance 162 | :param reg_weight: 163 | :return: 164 | """ 165 | prior_z = torch.randn_like(z) # [N x D] 166 | device = z.device 167 | 168 | proj_matrix = self.get_random_projections(self.latent_dim, 169 | num_samples=self.num_projections).transpose(0,1).to(device) 170 | 171 | latent_projections = z.matmul(proj_matrix) # [N x S] 172 | prior_projections = prior_z.matmul(proj_matrix) # [N x S] 173 | 174 | # The Wasserstein distance is computed by sorting the two projections 175 | # across the batches and computing their element-wise l2 distance 176 | w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \ 177 | torch.sort(prior_projections.t(), dim=1)[0] 178 | w_dist = w_dist.pow(p) 179 | return reg_weight * w_dist.mean() 180 | 181 | def sample(self, 182 | num_samples:int, 183 | current_device: int, **kwargs) -> Tensor: 184 | """ 185 | Samples from the latent space and return the corresponding 186 | image space map. 187 | :param num_samples: (Int) Number of samples 188 | :param current_device: (Int) Device to run the model 189 | :return: (Tensor) 190 | """ 191 | z = torch.randn(num_samples, 192 | self.latent_dim) 193 | 194 | z = z.to(current_device) 195 | 196 | samples = self.decode(z) 197 | return samples 198 | 199 | def generate(self, x: Tensor, **kwargs) -> Tensor: 200 | """ 201 | Given an input image x, returns the reconstructed image 202 | :param x: (Tensor) [B x C x H x W] 203 | :return: (Tensor) [B x C x H x W] 204 | """ 205 | 206 | return self.forward(x)[0] 207 | -------------------------------------------------------------------------------- /models/twostage_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class TwoStageVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | hidden_dims2: List = None, 15 | **kwargs) -> None: 16 | super(TwoStageVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | 20 | modules = [] 21 | if hidden_dims is None: 22 | hidden_dims = [32, 64, 128, 256, 512] 23 | 24 | if hidden_dims2 is None: 25 | hidden_dims2 = [1024, 1024] 26 | 27 | # Build Encoder 28 | for h_dim in hidden_dims: 29 | modules.append( 30 | nn.Sequential( 31 | nn.Conv2d(in_channels, out_channels=h_dim, 32 | kernel_size= 3, stride= 2, padding = 1), 33 | nn.BatchNorm2d(h_dim), 34 | nn.LeakyReLU()) 35 | ) 36 | in_channels = h_dim 37 | 38 | self.encoder = nn.Sequential(*modules) 39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 41 | 42 | 43 | # Build Decoder 44 | modules = [] 45 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | self.decoder = nn.Sequential(*modules) 61 | 62 | self.final_layer = nn.Sequential( 63 | nn.ConvTranspose2d(hidden_dims[-1], 64 | hidden_dims[-1], 65 | kernel_size=3, 66 | stride=2, 67 | padding=1, 68 | output_padding=1), 69 | nn.BatchNorm2d(hidden_dims[-1]), 70 | nn.LeakyReLU(), 71 | 
nn.Conv2d(hidden_dims[-1], out_channels= 3, 72 | kernel_size= 3, padding= 1), 73 | nn.Tanh()) 74 | 75 | #---------------------- Second VAE ---------------------------# 76 | encoder2 = [] 77 | in_channels = self.latent_dim 78 | for h_dim in hidden_dims2: 79 | encoder2.append(nn.Sequential( 80 | nn.Linear(in_channels, h_dim), 81 | nn.BatchNorm1d(h_dim), 82 | nn.LeakyReLU())) 83 | in_channels = h_dim 84 | self.encoder2 = nn.Sequential(*encoder2) 85 | self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim) 86 | self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim) 87 | 88 | decoder2 = [] 89 | hidden_dims2.reverse() 90 | 91 | in_channels = self.latent_dim 92 | for h_dim in hidden_dims2: 93 | decoder2.append(nn.Sequential( 94 | nn.Linear(in_channels, h_dim), 95 | nn.BatchNorm1d(h_dim), 96 | nn.LeakyReLU())) 97 | in_channels = h_dim 98 | self.decoder2 = nn.Sequential(*decoder2) 99 | 100 | def encode(self, input: Tensor) -> List[Tensor]: 101 | """ 102 | Encodes the input by passing through the encoder network 103 | and returns the latent codes. 104 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 105 | :return: (Tensor) List of latent codes 106 | """ 107 | result = self.encoder(input) 108 | result = torch.flatten(result, start_dim=1) 109 | 110 | # Split the result into mu and var components 111 | # of the latent Gaussian distribution 112 | mu = self.fc_mu(result) 113 | log_var = self.fc_var(result) 114 | 115 | return [mu, log_var] 116 | 117 | def decode(self, z: Tensor) -> Tensor: 118 | """ 119 | Maps the given latent codes 120 | onto the image space. 121 | :param z: (Tensor) [B x D] 122 | :return: (Tensor) [B x C x H x W] 123 | """ 124 | result = self.decoder_input(z) 125 | result = result.view(-1, 512, 2, 2) 126 | result = self.decoder(result) 127 | result = self.final_layer(result) 128 | return result 129 | 130 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 131 | """ 132 | Reparameterization trick to sample from N(mu, var) from 133 | N(0,1). 134 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 135 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 136 | :return: (Tensor) [B x D] 137 | """ 138 | std = torch.exp(0.5 * logvar) 139 | eps = torch.randn_like(std) 140 | return eps * std + mu 141 | 142 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 143 | mu, log_var = self.encode(input) 144 | z = self.reparameterize(mu, log_var) 145 | 146 | return [self.decode(z), input, mu, log_var] 147 | 148 | def loss_function(self, 149 | *args, 150 | **kwargs) -> dict: 151 | """ 152 | Computes the VAE loss function. 153 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 154 | :param args: 155 | :param kwargs: 156 | :return: 157 | """ 158 | recons = args[0] 159 | input = args[1] 160 | mu = args[2] 161 | log_var = args[3] 162 | 163 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 164 | recons_loss =F.mse_loss(recons, input) 165 | 166 | 167 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 168 | 169 | loss = recons_loss + kld_weight * kld_loss 170 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 171 | 172 | def sample(self, 173 | num_samples:int, 174 | current_device: int, **kwargs) -> Tensor: 175 | """ 176 | Samples from the latent space and return the corresponding 177 | image space map. 
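Note that only the first stage is exercised in this file: encoder2, fc_mu2, fc_var2 and decoder2 are built in __init__ but are not referenced by forward() or loss_function(), so a separate second-stage pass over the learned latents would be required to train them.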
178 | :param num_samples: (Int) Number of samples 179 | :param current_device: (Int) Device to run the model 180 | :return: (Tensor) 181 | """ 182 | z = torch.randn(num_samples, 183 | self.latent_dim) 184 | 185 | z = z.to(current_device) 186 | 187 | samples = self.decode(z) 188 | return samples 189 | 190 | def generate(self, x: Tensor, **kwargs) -> Tensor: 191 | """ 192 | Given an input image x, returns the reconstructed image 193 | :param x: (Tensor) [B x C x H x W] 194 | :return: (Tensor) [B x C x H x W] 195 | """ 196 | 197 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/types_.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Union, Any, TypeVar, Tuple 2 | # from torch import tensor as Tensor 3 | 4 | Tensor = TypeVar('torch.tensor') 5 | -------------------------------------------------------------------------------- /models/vampvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class VampVAE(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | num_components: int = 50, 15 | **kwargs) -> None: 16 | super(VampVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | self.num_components = num_components 20 | 21 | modules = [] 22 | if hidden_dims is None: 23 | hidden_dims = [32, 64, 128, 256, 512] 24 | 25 | # Build Encoder 26 | for h_dim in hidden_dims: 27 | modules.append( 28 | nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels=h_dim, 30 | kernel_size= 3, stride= 2, padding = 1), 31 | nn.BatchNorm2d(h_dim), 32 | nn.LeakyReLU()) 33 | ) 34 | in_channels = h_dim 35 | 36 | self.encoder = nn.Sequential(*modules) 37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 39 | 40 | 41 | # Build Decoder 42 | modules = [] 43 | 44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 45 | 46 | hidden_dims.reverse() 47 | 48 | for i in range(len(hidden_dims) - 1): 49 | modules.append( 50 | nn.Sequential( 51 | nn.ConvTranspose2d(hidden_dims[i], 52 | hidden_dims[i + 1], 53 | kernel_size=3, 54 | stride = 2, 55 | padding=1, 56 | output_padding=1), 57 | nn.BatchNorm2d(hidden_dims[i + 1]), 58 | nn.LeakyReLU()) 59 | ) 60 | 61 | 62 | 63 | self.decoder = nn.Sequential(*modules) 64 | 65 | self.final_layer = nn.Sequential( 66 | nn.ConvTranspose2d(hidden_dims[-1], 67 | hidden_dims[-1], 68 | kernel_size=3, 69 | stride=2, 70 | padding=1, 71 | output_padding=1), 72 | nn.BatchNorm2d(hidden_dims[-1]), 73 | nn.LeakyReLU(), 74 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 75 | kernel_size= 3, padding= 1), 76 | nn.Tanh()) 77 | 78 | self.pseudo_input = torch.eye(self.num_components, requires_grad= False) 79 | self.embed_pseudo = nn.Sequential(nn.Linear(self.num_components, 12288), 80 | nn.Hardtanh(0.0, 1.0)) # 3x64x64 = 12288 81 | 82 | def encode(self, input: Tensor) -> List[Tensor]: 83 | """ 84 | Encodes the input by passing through the encoder network 85 | and returns the latent codes. 
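The VampPrior replaces the standard Gaussian prior with a mixture of the encoder's own posteriors at K learned pseudo-inputs: p(z) = (1/K) * sum_k q(z | u_k), where u_k = embed_pseudo(e_k) and e_k is the k-th one-hot row of the identity matrix stored in pseudo_input.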
86 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 87 | :return: (Tensor) List of latent codes 88 | """ 89 | result = self.encoder(input) 90 | result = torch.flatten(result, start_dim=1) 91 | 92 | # Split the result into mu and var components 93 | # of the latent Gaussian distribution 94 | mu = self.fc_mu(result) 95 | log_var = self.fc_var(result) 96 | 97 | return [mu, log_var] 98 | 99 | def decode(self, z: Tensor) -> Tensor: 100 | result = self.decoder_input(z) 101 | result = result.view(-1, 512, 2, 2) 102 | result = self.decoder(result) 103 | result = self.final_layer(result) 104 | return result 105 | 106 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 107 | """ 108 | Will a single z be enough to compute the expectation 109 | for the loss? 110 | :param mu: (Tensor) Mean of the latent Gaussian 111 | :param logvar: (Tensor) Log variance of the latent Gaussian 112 | :return: 113 | """ 114 | std = torch.exp(0.5 * logvar) 115 | eps = torch.randn_like(std) 116 | return eps * std + mu 117 | 118 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 119 | mu, log_var = self.encode(input) 120 | z = self.reparameterize(mu, log_var) 121 | return [self.decode(z), input, mu, log_var, z] 122 | 123 | def loss_function(self, 124 | *args, 125 | **kwargs) -> dict: 126 | recons = args[0] 127 | input = args[1] 128 | mu = args[2] 129 | log_var = args[3] 130 | z = args[4] 131 | 132 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 133 | recons_loss = F.mse_loss(recons, input) 134 | 135 | E_log_q_z = torch.mean(torch.sum(-0.5 * (log_var + (z - mu) ** 2 / log_var.exp()), 136 | dim = 1), 137 | dim = 0) # log q(z|x), dropping the constant that cancels in the KL 138 | 139 | # Original Prior 140 | # E_log_p_z = torch.mean(torch.sum(-0.5 * (z ** 2), dim = 1), dim = 0) 141 | 142 | # Vamp Prior 143 | M, C, H, W = input.size() 144 | curr_device = input.device 145 | self.pseudo_input = self.pseudo_input.to(curr_device) # keep pseudo-inputs on the same device as the data 146 | x = self.embed_pseudo(self.pseudo_input) 147 | x = x.view(-1, C, H, W) 148 | prior_mu, prior_log_var = self.encode(x) 149 | 150 | z_expand = z.unsqueeze(1) 151 | prior_mu = prior_mu.unsqueeze(0) 152 | prior_log_var = prior_log_var.unsqueeze(0) 153 | 154 | E_log_p_z = torch.sum(-0.5 * 155 | (prior_log_var + (z_expand - prior_mu) ** 2 / prior_log_var.exp()), 156 | dim = 2) - torch.log(torch.tensor(self.num_components).float()) 157 | 158 | 159 | E_log_p_z = torch.logsumexp(E_log_p_z, dim = 1) 160 | E_log_p_z = torch.mean(E_log_p_z, dim = 0) 161 | 162 | # KLD = E_q log q - E_q log p 163 | kld_loss = -(E_log_p_z - E_log_q_z) 164 | # print(E_log_p_z, E_log_q_z) 165 | 166 | 167 | loss = recons_loss + kld_weight * kld_loss 168 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss} 169 | 170 | def sample(self, 171 | num_samples:int, 172 | current_device: int, **kwargs) -> Tensor: 173 | """ 174 | Samples from the latent space and returns the corresponding 175 | image space map.
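The prior term in loss_function above is evaluated in log space, log p(z) = logsumexp_k [log q(z | u_k)] - log K, the numerically stable log-mean-exp over the K mixture components. Note that this sampler still draws z from N(0, I) rather than from the Vamp mixture, so its outputs only approximate the learned prior.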
176 | :param num_samples: (Int) Number of samples 177 | :param current_device: (Int) Device to run the model 178 | :return: (Tensor) 179 | """ 180 | z = torch.randn(num_samples, 181 | self.latent_dim) 182 | 183 | z = z.cuda(current_device) 184 | 185 | samples = self.decode(z) 186 | return samples 187 | 188 | def generate(self, x: Tensor, **kwargs) -> Tensor: 189 | """ 190 | Given an input image x, returns the reconstructed image 191 | :param x: (Tensor) [B x C x H x W] 192 | :return: (Tensor) [B x C x H x W] 193 | """ 194 | 195 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/vanilla_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class VanillaVAE(BaseVAE): 9 | 10 | 11 | def __init__(self, 12 | in_channels: int, 13 | latent_dim: int, 14 | hidden_dims: List = None, 15 | **kwargs) -> None: 16 | super(VanillaVAE, self).__init__() 17 | 18 | self.latent_dim = latent_dim 19 | 20 | modules = [] 21 | if hidden_dims is None: 22 | hidden_dims = [32, 64, 128, 256, 512] 23 | 24 | # Build Encoder 25 | for h_dim in hidden_dims: 26 | modules.append( 27 | nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels=h_dim, 29 | kernel_size= 3, stride= 2, padding = 1), 30 | nn.BatchNorm2d(h_dim), 31 | nn.LeakyReLU()) 32 | ) 33 | in_channels = h_dim 34 | 35 | self.encoder = nn.Sequential(*modules) 36 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) 37 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) 38 | 39 | 40 | # Build Decoder 41 | modules = [] 42 | 43 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 44 | 45 | hidden_dims.reverse() 46 | 47 | for i in range(len(hidden_dims) - 1): 48 | modules.append( 49 | nn.Sequential( 50 | nn.ConvTranspose2d(hidden_dims[i], 51 | hidden_dims[i + 1], 52 | kernel_size=3, 53 | stride = 2, 54 | padding=1, 55 | output_padding=1), 56 | nn.BatchNorm2d(hidden_dims[i + 1]), 57 | nn.LeakyReLU()) 58 | ) 59 | 60 | 61 | 62 | self.decoder = nn.Sequential(*modules) 63 | 64 | self.final_layer = nn.Sequential( 65 | nn.ConvTranspose2d(hidden_dims[-1], 66 | hidden_dims[-1], 67 | kernel_size=3, 68 | stride=2, 69 | padding=1, 70 | output_padding=1), 71 | nn.BatchNorm2d(hidden_dims[-1]), 72 | nn.LeakyReLU(), 73 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 74 | kernel_size= 3, padding= 1), 75 | nn.Tanh()) 76 | 77 | def encode(self, input: Tensor) -> List[Tensor]: 78 | """ 79 | Encodes the input by passing through the encoder network 80 | and returns the latent codes. 81 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 82 | :return: (Tensor) List of latent codes 83 | """ 84 | result = self.encoder(input) 85 | result = torch.flatten(result, start_dim=1) 86 | 87 | # Split the result into mu and var components 88 | # of the latent Gaussian distribution 89 | mu = self.fc_mu(result) 90 | log_var = self.fc_var(result) 91 | 92 | return [mu, log_var] 93 | 94 | def decode(self, z: Tensor) -> Tensor: 95 | """ 96 | Maps the given latent codes 97 | onto the image space. 
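For reference, the two closed forms used below, as a minimal standalone sketch (arbitrary shapes, assuming only the torch import at the top of the file):

mu, log_var = torch.zeros(4, 8), torch.zeros(4, 8)
std = torch.exp(0.5 * log_var)        # sigma = exp(log_var / 2)
z = mu + std * torch.randn_like(std)  # reparameterized sample, differentiable w.r.t. mu and log_var
kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)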
98 | :param z: (Tensor) [B x D] 99 | :return: (Tensor) [B x C x H x W] 100 | """ 101 | result = self.decoder_input(z) 102 | result = result.view(-1, 512, 2, 2) 103 | result = self.decoder(result) 104 | result = self.final_layer(result) 105 | return result 106 | 107 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 108 | """ 109 | Reparameterization trick to sample from N(mu, var) from 110 | N(0,1). 111 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 112 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 113 | :return: (Tensor) [B x D] 114 | """ 115 | std = torch.exp(0.5 * logvar) 116 | eps = torch.randn_like(std) 117 | return eps * std + mu 118 | 119 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 120 | mu, log_var = self.encode(input) 121 | z = self.reparameterize(mu, log_var) 122 | return [self.decode(z), input, mu, log_var] 123 | 124 | def loss_function(self, 125 | *args, 126 | **kwargs) -> dict: 127 | """ 128 | Computes the VAE loss function. 129 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} 130 | :param args: 131 | :param kwargs: 132 | :return: 133 | """ 134 | recons = args[0] 135 | input = args[1] 136 | mu = args[2] 137 | log_var = args[3] 138 | 139 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset 140 | recons_loss =F.mse_loss(recons, input) 141 | 142 | 143 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 144 | 145 | loss = recons_loss + kld_weight * kld_loss 146 | return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()} 147 | 148 | def sample(self, 149 | num_samples:int, 150 | current_device: int, **kwargs) -> Tensor: 151 | """ 152 | Samples from the latent space and return the corresponding 153 | image space map. 
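A minimal end-to-end usage sketch (CPU, default architecture; M_N is the batch-to-dataset ratio and 16 / 3200 is just an assumed example value):

model = VanillaVAE(in_channels=3, latent_dim=128)
x = torch.randn(16, 3, 64, 64)
recons, _, mu, log_var = model(x)  # recons: [16 x 3 x 64 x 64]
losses = model.loss_function(recons, x, mu, log_var, M_N=16 / 3200)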
154 | :param num_samples: (Int) Number of samples 155 | :param current_device: (Int) Device to run the model 156 | :return: (Tensor) 157 | """ 158 | z = torch.randn(num_samples, 159 | self.latent_dim) 160 | 161 | z = z.to(current_device) 162 | 163 | samples = self.decode(z) 164 | return samples 165 | 166 | def generate(self, x: Tensor, **kwargs) -> Tensor: 167 | """ 168 | Given an input image x, returns the reconstructed image 169 | :param x: (Tensor) [B x C x H x W] 170 | :return: (Tensor) [B x C x H x W] 171 | """ 172 | 173 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/vq_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | class VectorQuantizer(nn.Module): 8 | """ 9 | Reference: 10 | [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py 11 | """ 12 | def __init__(self, 13 | num_embeddings: int, 14 | embedding_dim: int, 15 | beta: float = 0.25): 16 | super(VectorQuantizer, self).__init__() 17 | self.K = num_embeddings 18 | self.D = embedding_dim 19 | self.beta = beta 20 | 21 | self.embedding = nn.Embedding(self.K, self.D) 22 | self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K) 23 | 24 | def forward(self, latents: Tensor) -> Tensor: 25 | latents = latents.permute(0, 2, 3, 1).contiguous() # [B x D x H x W] -> [B x H x W x D] 26 | latents_shape = latents.shape 27 | flat_latents = latents.view(-1, self.D) # [BHW x D] 28 | 29 | # Compute L2 distance between latents and embedding weights 30 | dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \ 31 | torch.sum(self.embedding.weight ** 2, dim=1) - \ 32 | 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHW x K] 33 | 34 | # Get the encoding that has the min distance 35 | encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1) # [BHW, 1] 36 | 37 | # Convert to one-hot encodings 38 | device = latents.device 39 | encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) 40 | encoding_one_hot.scatter_(1, encoding_inds, 1) # [BHW x K] 41 | 42 | # Quantize the latents 43 | quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D] 44 | quantized_latents = quantized_latents.view(latents_shape) # [B x H x W x D] 45 | 46 | # Compute the VQ Losses 47 | commitment_loss = F.mse_loss(quantized_latents.detach(), latents) 48 | embedding_loss = F.mse_loss(quantized_latents, latents.detach()) 49 | 50 | vq_loss = commitment_loss * self.beta + embedding_loss 51 | 52 | # Add the residue back to the latents 53 | quantized_latents = latents + (quantized_latents - latents).detach() 54 | 55 | return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W] 56 | 57 | class ResidualLayer(nn.Module): 58 | 59 | def __init__(self, 60 | in_channels: int, 61 | out_channels: int): 62 | super(ResidualLayer, self).__init__() 63 | self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels, 64 | kernel_size=3, padding=1, bias=False), 65 | nn.ReLU(True), 66 | nn.Conv2d(out_channels, out_channels, 67 | kernel_size=1, bias=False)) 68 | 69 | def forward(self, input: Tensor) -> Tensor: 70 | return input + self.resblock(input) 71 | 72 | 73 | class VQVAE(BaseVAE): 74 | 75 | def __init__(self, 76 | in_channels: int, 77 | embedding_dim: int, 78 | num_embeddings: int, 79 | hidden_dims: List = None, 80 | beta: float = 0.25, 
81 | img_size: int = 64, 82 | **kwargs) -> None: 83 | super(VQVAE, self).__init__() 84 | 85 | self.embedding_dim = embedding_dim 86 | self.num_embeddings = num_embeddings 87 | self.img_size = img_size 88 | self.beta = beta 89 | 90 | modules = [] 91 | if hidden_dims is None: 92 | hidden_dims = [128, 256] 93 | 94 | # Build Encoder 95 | for h_dim in hidden_dims: 96 | modules.append( 97 | nn.Sequential( 98 | nn.Conv2d(in_channels, out_channels=h_dim, 99 | kernel_size=4, stride=2, padding=1), 100 | nn.LeakyReLU()) 101 | ) 102 | in_channels = h_dim 103 | 104 | modules.append( 105 | nn.Sequential( 106 | nn.Conv2d(in_channels, in_channels, 107 | kernel_size=3, stride=1, padding=1), 108 | nn.LeakyReLU()) 109 | ) 110 | 111 | for _ in range(6): 112 | modules.append(ResidualLayer(in_channels, in_channels)) 113 | modules.append(nn.LeakyReLU()) 114 | 115 | modules.append( 116 | nn.Sequential( 117 | nn.Conv2d(in_channels, embedding_dim, 118 | kernel_size=1, stride=1), 119 | nn.LeakyReLU()) 120 | ) 121 | 122 | self.encoder = nn.Sequential(*modules) 123 | 124 | self.vq_layer = VectorQuantizer(num_embeddings, 125 | embedding_dim, 126 | self.beta) 127 | 128 | # Build Decoder 129 | modules = [] 130 | modules.append( 131 | nn.Sequential( 132 | nn.Conv2d(embedding_dim, 133 | hidden_dims[-1], 134 | kernel_size=3, 135 | stride=1, 136 | padding=1), 137 | nn.LeakyReLU()) 138 | ) 139 | 140 | for _ in range(6): 141 | modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1])) 142 | 143 | modules.append(nn.LeakyReLU()) 144 | 145 | hidden_dims.reverse() 146 | 147 | for i in range(len(hidden_dims) - 1): 148 | modules.append( 149 | nn.Sequential( 150 | nn.ConvTranspose2d(hidden_dims[i], 151 | hidden_dims[i + 1], 152 | kernel_size=4, 153 | stride=2, 154 | padding=1), 155 | nn.LeakyReLU()) 156 | ) 157 | 158 | modules.append( 159 | nn.Sequential( 160 | nn.ConvTranspose2d(hidden_dims[-1], 161 | out_channels=3, 162 | kernel_size=4, 163 | stride=2, padding=1), 164 | nn.Tanh())) 165 | 166 | self.decoder = nn.Sequential(*modules) 167 | 168 | def encode(self, input: Tensor) -> List[Tensor]: 169 | """ 170 | Encodes the input by passing through the encoder network 171 | and returns the latent codes. 172 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 173 | :return: (Tensor) List of latent codes 174 | """ 175 | result = self.encoder(input) 176 | return [result] 177 | 178 | def decode(self, z: Tensor) -> Tensor: 179 | """ 180 | Maps the given latent codes 181 | onto the image space. 
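With the default 64x64 inputs and hidden_dims [128, 256], the two stride-2 encoder convolutions produce latents of shape [B x embedding_dim x 16 x 16], which decode mirrors back to [B x 3 x 64 x 64]. Gradients flow through the non-differentiable quantization via the straight-through estimator used in VectorQuantizer above, z_q = z_e + (z_q - z_e).detach(), which copies the decoder's gradient to the encoder unchanged.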
182 | :param z: (Tensor) [B x D x H x W] 183 | :return: (Tensor) [B x C x H x W] 184 | """ 185 | 186 | result = self.decoder(z) 187 | return result 188 | 189 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 190 | encoding = self.encode(input)[0] 191 | quantized_inputs, vq_loss = self.vq_layer(encoding) 192 | return [self.decode(quantized_inputs), input, vq_loss] 193 | 194 | def loss_function(self, 195 | *args, 196 | **kwargs) -> dict: 197 | """ 198 | :param args: 199 | :param kwargs: 200 | :return: 201 | """ 202 | recons = args[0] 203 | input = args[1] 204 | vq_loss = args[2] 205 | 206 | recons_loss = F.mse_loss(recons, input) 207 | 208 | loss = recons_loss + vq_loss 209 | return {'loss': loss, 210 | 'Reconstruction_Loss': recons_loss, 211 | 'VQ_Loss':vq_loss} 212 | 213 | def sample(self, 214 | num_samples: int, 215 | current_device: Union[int, str], **kwargs) -> Tensor: 216 | raise Warning('VQVAE sampler is not implemented.') 217 | 218 | def generate(self, x: Tensor, **kwargs) -> Tensor: 219 | """ 220 | Given an input image x, returns the reconstructed image 221 | :param x: (Tensor) [B x C x H x W] 222 | :return: (Tensor) [B x C x H x W] 223 | """ 224 | 225 | return self.forward(x)[0] -------------------------------------------------------------------------------- /models/wae_mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import BaseVAE 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .types_ import * 6 | 7 | 8 | class WAE_MMD(BaseVAE): 9 | 10 | def __init__(self, 11 | in_channels: int, 12 | latent_dim: int, 13 | hidden_dims: List = None, 14 | reg_weight: int = 100, 15 | kernel_type: str = 'imq', 16 | latent_var: float = 2., 17 | **kwargs) -> None: 18 | super(WAE_MMD, self).__init__() 19 | 20 | self.latent_dim = latent_dim 21 | self.reg_weight = reg_weight 22 | self.kernel_type = kernel_type 23 | self.z_var = latent_var 24 | 25 | modules = [] 26 | if hidden_dims is None: 27 | hidden_dims = [32, 64, 128, 256, 512] 28 | 29 | # Build Encoder 30 | for h_dim in hidden_dims: 31 | modules.append( 32 | nn.Sequential( 33 | nn.Conv2d(in_channels, out_channels=h_dim, 34 | kernel_size= 3, stride= 2, padding = 1), 35 | nn.BatchNorm2d(h_dim), 36 | nn.LeakyReLU()) 37 | ) 38 | in_channels = h_dim 39 | 40 | self.encoder = nn.Sequential(*modules) 41 | self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim) 42 | 43 | 44 | # Build Decoder 45 | modules = [] 46 | 47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) 48 | 49 | hidden_dims.reverse() 50 | 51 | for i in range(len(hidden_dims) - 1): 52 | modules.append( 53 | nn.Sequential( 54 | nn.ConvTranspose2d(hidden_dims[i], 55 | hidden_dims[i + 1], 56 | kernel_size=3, 57 | stride = 2, 58 | padding=1, 59 | output_padding=1), 60 | nn.BatchNorm2d(hidden_dims[i + 1]), 61 | nn.LeakyReLU()) 62 | ) 63 | 64 | 65 | 66 | self.decoder = nn.Sequential(*modules) 67 | 68 | self.final_layer = nn.Sequential( 69 | nn.ConvTranspose2d(hidden_dims[-1], 70 | hidden_dims[-1], 71 | kernel_size=3, 72 | stride=2, 73 | padding=1, 74 | output_padding=1), 75 | nn.BatchNorm2d(hidden_dims[-1]), 76 | nn.LeakyReLU(), 77 | nn.Conv2d(hidden_dims[-1], out_channels= 3, 78 | kernel_size= 3, padding= 1), 79 | nn.Tanh()) 80 | 81 | def encode(self, input: Tensor) -> Tensor: 82 | """ 83 | Encodes the input by passing through the encoder network 84 | and returns the latent codes. 
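As with the SWAE, this encoder is deterministic (a single fc_z head). The regularizer is the kernel maximum mean discrepancy between codes and prior samples, MMD^2(q, p) = E[k(z', z'')] + E[k(z1, z2)] - 2 E[k(z', z1)], which compute_mmd below estimates from one batch of prior draws using either the RBF or the inverse multiquadric kernel.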
85 | :param input: (Tensor) Input tensor to encoder [N x C x H x W] 86 | :return: (Tensor) List of latent codes 87 | """ 88 | result = self.encoder(input) 89 | result = torch.flatten(result, start_dim=1) 90 | 91 | # Split the result into mu and var components 92 | # of the latent Gaussian distribution 93 | z = self.fc_z(result) 94 | return z 95 | 96 | def decode(self, z: Tensor) -> Tensor: 97 | result = self.decoder_input(z) 98 | result = result.view(-1, 512, 2, 2) 99 | result = self.decoder(result) 100 | result = self.final_layer(result) 101 | return result 102 | 103 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 104 | z = self.encode(input) 105 | return [self.decode(z), input, z] 106 | 107 | def loss_function(self, 108 | *args, 109 | **kwargs) -> dict: 110 | recons = args[0] 111 | input = args[1] 112 | z = args[2] 113 | 114 | batch_size = input.size(0) 115 | bias_corr = batch_size * (batch_size - 1) 116 | reg_weight = self.reg_weight / bias_corr 117 | 118 | recons_loss =F.mse_loss(recons, input) 119 | 120 | mmd_loss = self.compute_mmd(z, reg_weight) 121 | 122 | loss = recons_loss + mmd_loss 123 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss} 124 | 125 | def compute_kernel(self, 126 | x1: Tensor, 127 | x2: Tensor) -> Tensor: 128 | # Convert the tensors into row and column vectors 129 | D = x1.size(1) 130 | N = x1.size(0) 131 | 132 | x1 = x1.unsqueeze(-2) # Make it into a column tensor 133 | x2 = x2.unsqueeze(-3) # Make it into a row tensor 134 | 135 | """ 136 | Usually the below lines are not required, especially in our case, 137 | but this is useful when x1 and x2 have different sizes 138 | along the 0th dimension. 139 | """ 140 | x1 = x1.expand(N, N, D) 141 | x2 = x2.expand(N, N, D) 142 | 143 | if self.kernel_type == 'rbf': 144 | result = self.compute_rbf(x1, x2) 145 | elif self.kernel_type == 'imq': 146 | result = self.compute_inv_mult_quad(x1, x2) 147 | else: 148 | raise ValueError('Undefined kernel type.') 149 | 150 | return result 151 | 152 | 153 | def compute_rbf(self, 154 | x1: Tensor, 155 | x2: Tensor, 156 | eps: float = 1e-7) -> Tensor: 157 | """ 158 | Computes the RBF Kernel between x1 and x2. 159 | :param x1: (Tensor) 160 | :param x2: (Tensor) 161 | :param eps: (Float) 162 | :return: 163 | """ 164 | z_dim = x2.size(-1) 165 | sigma = 2. 
* z_dim * self.z_var 166 | 167 | result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma)) 168 | return result 169 | 170 | def compute_inv_mult_quad(self, 171 | x1: Tensor, 172 | x2: Tensor, 173 | eps: float = 1e-7) -> Tensor: 174 | """ 175 | Computes the Inverse Multi-Quadratics Kernel between x1 and x2, 176 | given by 177 | 178 | k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2} 179 | :param x1: (Tensor) 180 | :param x2: (Tensor) 181 | :param eps: (Float) 182 | :return: 183 | """ 184 | z_dim = x2.size(-1) 185 | C = 2 * z_dim * self.z_var 186 | kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1)) 187 | 188 | # Exclude diagonal elements 189 | result = kernel.sum() - kernel.diag().sum() 190 | 191 | return result 192 | 193 | def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor: 194 | # Sample from prior (Gaussian) distribution 195 | prior_z = torch.randn_like(z) 196 | 197 | prior_z__kernel = self.compute_kernel(prior_z, prior_z) 198 | z__kernel = self.compute_kernel(z, z) 199 | priorz_z__kernel = self.compute_kernel(prior_z, z) 200 | 201 | mmd = reg_weight * prior_z__kernel.mean() + \ 202 | reg_weight * z__kernel.mean() - \ 203 | 2 * reg_weight * priorz_z__kernel.mean() 204 | return mmd 205 | 206 | def sample(self, 207 | num_samples:int, 208 | current_device: int, **kwargs) -> Tensor: 209 | """ 210 | Samples from the latent space and return the corresponding 211 | image space map. 212 | :param num_samples: (Int) Number of samples 213 | :param current_device: (Int) Device to run the model 214 | :return: (Tensor) 215 | """ 216 | z = torch.randn(num_samples, 217 | self.latent_dim) 218 | 219 | z = z.to(current_device) 220 | 221 | samples = self.decode(z) 222 | return samples 223 | 224 | def generate(self, x: Tensor, **kwargs) -> Tensor: 225 | """ 226 | Given an input image x, returns the reconstructed image 227 | :param x: (Tensor) [B x C x H x W] 228 | :return: (Tensor) [B x C x H x W] 229 | """ 230 | 231 | return self.forward(x)[0] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.5.6 2 | PyYAML==6.0 3 | tensorboard>=2.2.0 4 | torch>=1.6.1 5 | torchsummary==1.5.1 6 | torchvision>=0.10.1 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import numpy as np 5 | from pathlib import Path 6 | from models import * 7 | from experiment import VAEXperiment 8 | import torch.backends.cudnn as cudnn 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.loggers import TensorBoardLogger 11 | from pytorch_lightning.utilities.seed import seed_everything 12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 13 | from dataset import VAEDataset 14 | from pytorch_lightning.plugins import DDPPlugin 15 | 16 | 17 | parser = argparse.ArgumentParser(description='Generic runner for VAE models') 18 | parser.add_argument('--config', '-c', 19 | dest="filename", 20 | metavar='FILE', 21 | help = 'path to the config file', 22 | default='configs/vae.yaml') 23 | 24 | args = parser.parse_args() 25 | with open(args.filename, 'r') as file: 26 | try: 27 | config = yaml.safe_load(file) 28 | except yaml.YAMLError as exc: 29 | print(exc) 30 | 31 | 32 | tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'], 33 | 
name=config['model_params']['name'],) 34 | 35 | # For reproducibility 36 | seed_everything(config['exp_params']['manual_seed'], True) 37 | 38 | model = vae_models[config['model_params']['name']](**config['model_params']) 39 | experiment = VAEXperiment(model, 40 | config['exp_params']) 41 | 42 | data = VAEDataset(**config["data_params"], pin_memory=len(config['trainer_params']['gpus']) != 0) 43 | 44 | data.setup() 45 | runner = Trainer(logger=tb_logger, 46 | callbacks=[ 47 | LearningRateMonitor(), 48 | ModelCheckpoint(save_top_k=2, 49 | dirpath =os.path.join(tb_logger.log_dir , "checkpoints"), 50 | monitor= "val_loss", 51 | save_last= True), 52 | ], 53 | strategy=DDPPlugin(find_unused_parameters=False), 54 | **config['trainer_params']) 55 | 56 | 57 | Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True) 58 | Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True) 59 | 60 | 61 | print(f"======= Training {config['model_params']['name']} =======") 62 | runner.fit(experiment, datamodule=data) -------------------------------------------------------------------------------- /tests/bvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | from models import BetaVAE 4 | from torchsummary import summary 5 | 6 | 7 | class TestVAE(unittest.TestCase): 8 | 9 | def setUp(self) -> None: 10 | # self.model2 = VAE(3, 10) 11 | self.model = BetaVAE(3, 10, loss_type='H').cuda() 12 | 13 | def test_summary(self): 14 | print(summary(self.model, (3, 64, 64), device='cpu')) 15 | # print(summary(self.model2, (3, 64, 64), device='cpu')) 16 | 17 | def test_forward(self): 18 | x = torch.randn(16, 3, 64, 64) 19 | y = self.model(x) 20 | print("Model Output size:", y[0].size()) 21 | # print("Model2 Output size:", self.model2(x)[0].size()) 22 | 23 | def test_loss(self): 24 | x = torch.randn(16, 3, 64, 64).cuda() 25 | 26 | result = self.model(x) 27 | loss = self.model.loss_function(*result, M_N = 0.005) 28 | print(loss) 29 | 30 | 31 | if __name__ == '__main__': 32 | unittest.main() -------------------------------------------------------------------------------- /tests/test_betatcvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | from models import BetaTCVAE 4 | from torchsummary import summary 5 | 6 | 7 | class TestBetaTCVAE(unittest.TestCase): 8 | 9 | def setUp(self) -> None: 10 | # self.model2 = VAE(3, 10) 11 | self.model = BetaTCVAE(3, 64, anneal_steps= 100) 12 | 13 | def test_summary(self): 14 | print(summary(self.model, (3, 64, 64), device='cpu')) 15 | # print(summary(self.model2, (3, 64, 64), device='cpu')) 16 | 17 | def test_forward(self): 18 | print(sum(p.numel() for p in self.model.parameters() if p.requires_grad)) 19 | x = torch.randn(16, 3, 64, 64) 20 | y = self.model(x) 21 | print("Model Output size:", y[0].size()) 22 | # print("Model2 Output size:", self.model2(x)[0].size()) 23 | 24 | def test_loss(self): 25 | x = torch.randn(16, 3, 64, 64) 26 | 27 | result = self.model(x) 28 | loss = self.model.loss_function(*result, M_N = 0.005) 29 | print(loss) 30 | 31 | def test_sample(self): 32 | self.model.cuda() 33 | y = self.model.sample(8, 'cuda') 34 | print(y.shape) 35 | 36 | def test_generate(self): 37 | x = torch.randn(16, 3, 64, 64) 38 | y = self.model.generate(x) 39 | print(y.shape) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() -------------------------------------------------------------------------------- 
--------------------------------------------------------------------------------
/tests/test_cat_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import GumbelVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = GumbelVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(128, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)
28 |         print(loss)
29 | 
30 | 
31 |     def test_sample(self):
32 |         self.model.cuda()
33 |         y = self.model.sample(144, 0)
34 |         print(y.shape)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_dfc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import DFCVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestDFCVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = DFCVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 | 
16 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 | 
18 |     def test_forward(self):
19 |         x = torch.randn(16, 3, 64, 64)
20 |         y = self.model(x)
21 |         print("Model Output size:", y[0].size())
22 | 
23 |         # print("Model2 Output size:", self.model2(x)[0].size())
24 | 
25 |     def test_loss(self):
26 |         x = torch.randn(16, 3, 64, 64)
27 | 
28 |         result = self.model(x)
29 |         loss = self.model.loss_function(*result, M_N = 0.005)
30 |         print(loss)
31 | 
32 |     def test_sample(self):
33 |         self.model.cuda()
34 |         y = self.model.sample(144, 0)
35 | 
36 | 
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_dipvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import DIPVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestDIPVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = DIPVAE(3, 64)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
19 |         x = torch.randn(16, 3, 64, 64)
20 |         y = self.model(x)
21 |         print("Model Output size:", y[0].size())
22 |         # print("Model2 Output size:", self.model2(x)[0].size())
23 | 
24 |     def test_loss(self):
25 |         x = torch.randn(16, 3, 64, 64)
26 | 
27 |         result = self.model(x)
28 |         loss = self.model.loss_function(*result, M_N = 0.005)
29 |         print(loss)
30 | 
31 |     def test_sample(self):
32 |         self.model.cuda()
33 |         y = self.model.sample(8, 'cuda')
34 |         print(y.shape)
35 | 
36 |     def test_generate(self):
37 |         x = torch.randn(16, 3, 64, 64)
38 |         y = self.model.generate(x)
39 |         print(y.shape)
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     unittest.main()
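The test_sample methods above all call model.sample(num_samples, device). A hedged sketch of turning those samples into an image grid with torchvision's save_image; the sample(num_samples, current_device) signature and the output filename are assumptions inferred from the tests, not taken from the model code:

import torch
import torchvision.utils as vutils
from models import VanillaVAE

model = VanillaVAE(3, 10).cuda()

# sample() is assumed to take (num_samples, current_device), as in the tests above.
with torch.no_grad():
    samples = model.sample(64, 'cuda')

# Arrange the 64 samples into an 8x8 grid and write them to disk.
vutils.save_image(samples.cpu().data, "samples.png", normalize=True, nrow=8)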
--------------------------------------------------------------------------------
/tests/test_fvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import FactorVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestFAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = FactorVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         #
16 |         # print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
17 | 
18 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
19 | 
20 |     def test_forward(self):
21 |         x = torch.randn(16, 3, 64, 64)
22 |         y = self.model(x)
23 |         print("Model Output size:", y[0].size())
24 | 
25 |         # print("Model2 Output size:", self.model2(x)[0].size())
26 | 
27 |     def test_loss(self):
28 |         x = torch.randn(16, 3, 64, 64)
29 |         x2 = torch.randn(16, 3, 64, 64)
30 | 
31 |         result = self.model(x)
32 |         vae_loss = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=0, secondary_input=x2)
33 |         disc_loss = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=1, secondary_input=x2)
34 |         print(vae_loss, disc_loss)
35 | 
36 |     def test_optim(self):
37 |         optim1 = torch.optim.Adam(self.model.parameters(), lr = 0.001)
38 |         optim2 = torch.optim.Adam(self.model.discriminator.parameters(), lr = 0.001)
39 | 
40 |     def test_sample(self):
41 |         self.model.cuda()
42 |         y = self.model.sample(144, 0)
43 | 
44 | 
45 | 
46 | 
47 | if __name__ == '__main__':
48 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_gvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import GammaVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestGammaVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = GammaVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 | 
16 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 | 
18 |     def test_forward(self):
19 |         x = torch.randn(16, 3, 64, 64)
20 |         y = self.model(x)
21 |         print("Model Output size:", y[0].size())
22 | 
23 |         # print("Model2 Output size:", self.model2(x)[0].size())
24 | 
25 |     def test_loss(self):
26 |         x = torch.randn(16, 3, 64, 64)
27 | 
28 |         result = self.model(x)
29 |         loss = self.model.loss_function(*result, M_N = 0.005)
30 |         print(loss)
31 | 
32 |     def test_sample(self):
33 |         self.model.cuda()
34 |         y = self.model.sample(144, 0)
35 | 
36 | 
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_hvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import HVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestHVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = HVAE(3, latent1_dim=10, latent2_dim=20)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(16, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_iwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import IWAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestIWAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = IWAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(16, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 |     def test_sample(self):
31 |         self.model.cuda()
32 |         y = self.model.sample(144, 0)
33 | 
34 | 
35 | if __name__ == '__main__':
36 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_joint_Vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import JointVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = JointVAE(3, 10, 40, 0.0)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(128, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)
28 |         print(loss)
29 | 
30 | 
31 |     def test_sample(self):
32 |         self.model.cuda()
33 |         y = self.model.sample(144, 0)
34 |         print(y.shape)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_logcosh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import LogCoshVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = LogCoshVAE(3, 10, alpha=10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.rand(16, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     unittest.main()
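Each test file above is runnable on its own through its __main__ block. To run the whole directory in one go, standard unittest discovery works; note that tests/bvae.py, tests/text_cvae.py, and tests/text_vamp.py do not match the default test_*.py pattern and would need to be run individually. A sketch using only the standard library:

import unittest

# Discover everything matching test_*.py under tests/ and run it.
loader = unittest.TestLoader()
suite = loader.discover('tests', pattern='test_*.py')
unittest.TextTestRunner(verbosity=2).run(suite)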
--------------------------------------------------------------------------------
/tests/test_lvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import LVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestLVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = LVAE(3, [4, 8, 16, 32, 128], hidden_dims=[32, 64, 128, 256, 512])
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 | 
21 |         print("Model Output size:", y[0].size())
22 |         # print("Model2 Output size:", self.model2(x)[0].size())
23 | 
24 |     def test_loss(self):
25 |         x = torch.randn(16, 3, 64, 64)
26 | 
27 |         result = self.model(x)
28 |         loss = self.model.loss_function(*result, M_N = 0.005)
29 |         print(loss)
30 | 
31 |     def test_sample(self):
32 |         self.model.cuda()
33 |         y = self.model.sample(144, 0)
34 |         print(y.shape)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_miwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import MIWAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestMIWAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = MIWAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(16, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 |     def test_sample(self):
31 |         self.model.cuda()
32 |         y = self.model.sample(144, 0)
33 |         print(y.shape)
34 | 
35 |     def test_generate(self):
36 |         x = torch.randn(16, 3, 64, 64)
37 |         y = self.model.generate(x)
38 |         print(y.shape)
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_mssimvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import MSSIMVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestMSSIMVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = MSSIMVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 | 
16 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 | 
18 |     def test_forward(self):
19 |         x = torch.randn(16, 3, 64, 64)
20 |         y = self.model(x)
21 |         print("Model Output size:", y[0].size())
22 | 
23 |         # print("Model2 Output size:", self.model2(x)[0].size())
24 | 
25 |     def test_loss(self):
26 |         x = torch.randn(16, 3, 64, 64)
27 | 
28 |         result = self.model(x)
29 |         loss = self.model.loss_function(*result, M_N = 0.005)
30 |         print(loss)
31 | 
32 |     def test_sample(self):
33 |         self.model.cuda()
34 |         y = self.model.sample(144, 0)
35 | 
36 | 
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_swae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import SWAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestSWAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         self.model = SWAE(3, 10, reg_weight = 100)
11 | 
12 |     def test_summary(self):
13 |         print(summary(self.model, (3, 64, 64), device='cpu'))
14 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
15 | 
16 |     def test_forward(self):
17 |         x = torch.randn(16, 3, 64, 64)
18 |         y = self.model(x)
19 |         print("Model Output size:", y[0].size())
20 |         # print("Model2 Output size:", self.model2(x)[0].size())
21 | 
22 |     def test_loss(self):
23 |         x = torch.randn(16, 3, 64, 64)
24 | 
25 |         result = self.model(x)
26 |         loss = self.model.loss_function(*result)
27 |         print(loss)
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VanillaVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = VanillaVAE(3, 10)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64)
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(16, 3, 64, 64)
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_vq_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VQVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVQVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = VQVAE(3, 64, 512)
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cpu'))
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
19 |         x = torch.randn(16, 3, 64, 64)
20 |         y = self.model(x)
21 |         print("Model Output size:", y[0].size())
22 |         # print("Model2 Output size:", self.model2(x)[0].size())
23 | 
24 |     def test_loss(self):
25 |         x = torch.randn(16, 3, 64, 64)
26 | 
27 |         result = self.model(x)
28 |         loss = self.model.loss_function(*result, M_N = 0.005)
29 |         print(loss)
30 | 
31 |     def test_sample(self):
32 |         self.model.cuda()
33 |         y = self.model.sample(8, 'cuda')
34 |         print(y.shape)
35 | 
36 |     def test_generate(self):
37 |         x = torch.randn(16, 3, 64, 64)
38 |         y = self.model.generate(x)
39 |         print(y.shape)
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     unittest.main()
--------------------------------------------------------------------------------
/tests/test_wae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import WAE_MMD
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestWAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         self.model = WAE_MMD(3, 10, reg_weight = 100)
11 | 
12 |     def test_summary(self):
13 |         print(summary(self.model, (3, 64, 64), device='cpu'))
14 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
15 | 
16 |     def test_forward(self):
17 |         x = torch.randn(16, 3, 64, 64)
18 |         y = self.model(x)
19 |         print("Model Output size:", y[0].size())
20 |         # print("Model2 Output size:", self.model2(x)[0].size())
21 | 
22 |     def test_loss(self):
23 |         x = torch.randn(16, 3, 64, 64)
24 | 
25 |         result = self.model(x)
26 |         loss = self.model.loss_function(*result)
27 |         print(loss)
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     unittest.main()
--------------------------------------------------------------------------------
/tests/text_cvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import CVAE
4 | 
5 | 
6 | class TestCVAE(unittest.TestCase):
7 | 
8 |     def setUp(self) -> None:
9 |         # self.model2 = VAE(3, 10)
10 |         self.model = CVAE(3, 40, 10)
11 | 
12 |     def test_forward(self):
13 |         x = torch.randn(16, 3, 64, 64)
14 |         c = torch.randn(16, 40)
15 |         y = self.model(x, labels = c)  # the conditioning labels are passed as a keyword argument, as in test_loss
16 |         print("Model Output size:", y[0].size())
17 |         # print("Model2 Output size:", self.model2(x)[0].size())
18 | 
19 |     def test_loss(self):
20 |         x = torch.randn(16, 3, 64, 64)
21 |         c = torch.randn(16, 40)
22 |         result = self.model(x, labels = c)
23 |         loss = self.model.loss_function(*result, M_N = 0.005)
24 |         print(loss)
25 | 
26 | 
27 | if __name__ == '__main__':
28 |     unittest.main()
--------------------------------------------------------------------------------
/tests/text_vamp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VampVAE
4 | from torchsummary import summary
5 | 
6 | 
7 | class TestVVAE(unittest.TestCase):
8 | 
9 |     def setUp(self) -> None:
10 |         # self.model2 = VAE(3, 10)
11 |         self.model = VampVAE(3, latent_dim=10).cuda()
12 | 
13 |     def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cuda'))  # model lives on the GPU (see setUp)
15 |         # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 | 
17 |     def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64).cuda()  # input must be on the same device as the model
19 |         y = self.model(x)
20 |         print("Model Output size:", y[0].size())
21 |         # print("Model2 Output size:", self.model2(x)[0].size())
22 | 
23 |     def test_loss(self):
24 |         x = torch.randn(144, 3, 64, 64).cuda()
25 | 
26 |         result = self.model(x)
27 |         loss = self.model.loss_function(*result, M_N = 0.005)
28 |         print(loss)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     unittest.main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | 
3 | 
4 | ## Utils to handle PyTorch Lightning API changes from version 0.6 onwards
5 | ## ==================================================================================================== ##
6 | 
7 | 
8 | def data_loader(fn):
9 |     """
10 |     Decorator to handle the deprecation of pl.data_loader from version 0.7
11 |     :param fn: User defined data loader function
12 |     :return: A wrapper for the data_loader function
13 |     """
14 | 
15 |     def func_wrapper(self):
16 |         try:  # Works for version 0.6.0
17 |             return pl.data_loader(fn)(self)
18 | 
19 |         except AttributeError:  # pl.data_loader no longer exists in versions > 0.6.0
20 |             return fn(self)
21 | 
22 |     return func_wrapper
23 | 
--------------------------------------------------------------------------------
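For context, the data_loader decorator in utils.py exists so the same dataloader method works across PyTorch Lightning versions: on 0.6.0 it routes through pl.data_loader, and on anything newer it simply calls the method directly. A minimal sketch of how it would be applied; the ToyExperiment class and the dummy DataLoader here are illustrative, not part of this repository:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

from utils import data_loader


class ToyExperiment(pl.LightningModule):

    @data_loader  # falls back to a plain call on Lightning versions without pl.data_loader
    def train_dataloader(self):
        dummy = TensorDataset(torch.randn(8, 3, 64, 64))
        return DataLoader(dummy, batch_size=4)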