├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── PyTorch-VAE.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE.md
├── README.md
├── assets
│   ├── BetaTCVAE_49.png
│   ├── BetaVAE_B_20.png
│   ├── BetaVAE_B_35.png
│   ├── BetaVAE_H_20.png
│   ├── CategoricalVAE_49.png
│   ├── ConditionalVAE_20.png
│   ├── DFCVAE_49.png
│   ├── DIPVAE_83.png
│   ├── IWAE_19.png
│   ├── InfoVAE_31.png
│   ├── InfoVAE_7.png
│   ├── JointVAE_49.png
│   ├── LogCoshVAE_49.png
│   ├── MIWAE_29.png
│   ├── MSSIMVAE_29.png
│   ├── SWAE_49.png
│   ├── Vanilla VAE_25.png
│   ├── WAE_IMQ_15.png
│   ├── WAE_RBF_18.png
│   ├── recons_BetaTCVAE_49.png
│   ├── recons_BetaVAE_B_20.png
│   ├── recons_BetaVAE_B_35.png
│   ├── recons_BetaVAE_H_20.png
│   ├── recons_CategoricalVAE_49.png
│   ├── recons_ConditionalVAE_20.png
│   ├── recons_DFCVAE_49.png
│   ├── recons_DIPVAE_83.png
│   ├── recons_IWAE_19.png
│   ├── recons_InfoVAE_31.png
│   ├── recons_JointVAE_49.png
│   ├── recons_LogCoshVAE_49.png
│   ├── recons_MIWAE_29.png
│   ├── recons_MSSIMVAE_29.png
│   ├── recons_SWAE_49.png
│   ├── recons_VQVAE_29.png
│   ├── recons_Vanilla VAE_25.png
│   ├── recons_WAE_IMQ_15.png
│   └── recons_WAE_RBF_19.png
├── configs
│   ├── bbvae.yaml
│   ├── betatc_vae.yaml
│   ├── bhvae.yaml
│   ├── cat_vae.yaml
│   ├── cvae.yaml
│   ├── dfc_vae.yaml
│   ├── dip_vae.yaml
│   ├── factorvae.yaml
│   ├── gammavae.yaml
│   ├── hvae.yaml
│   ├── infovae.yaml
│   ├── iwae.yaml
│   ├── joint_vae.yaml
│   ├── logcosh_vae.yaml
│   ├── lvae.yaml
│   ├── miwae.yaml
│   ├── mssim_vae.yaml
│   ├── swae.yaml
│   ├── vae.yaml
│   ├── vampvae.yaml
│   ├── vq_vae.yaml
│   ├── wae_mmd_imq.yaml
│   └── wae_mmd_rbf.yaml
├── dataset.py
├── experiment.py
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── beta_vae.py
│   ├── betatc_vae.py
│   ├── cat_vae.py
│   ├── cvae.py
│   ├── dfcvae.py
│   ├── dip_vae.py
│   ├── fvae.py
│   ├── gamma_vae.py
│   ├── hvae.py
│   ├── info_vae.py
│   ├── iwae.py
│   ├── joint_vae.py
│   ├── logcosh_vae.py
│   ├── lvae.py
│   ├── miwae.py
│   ├── mssim_vae.py
│   ├── swae.py
│   ├── twostage_vae.py
│   ├── types_.py
│   ├── vampvae.py
│   ├── vanilla_vae.py
│   ├── vq_vae.py
│   └── wae_mmd.py
├── requirements.txt
├── run.py
├── tests
│   ├── bvae.py
│   ├── test_betatcvae.py
│   ├── test_cat_vae.py
│   ├── test_dfc.py
│   ├── test_dipvae.py
│   ├── test_fvae.py
│   ├── test_gvae.py
│   ├── test_hvae.py
│   ├── test_iwae.py
│   ├── test_joint_Vae.py
│   ├── test_logcosh.py
│   ├── test_lvae.py
│   ├── test_miwae.py
│   ├── test_mssimvae.py
│   ├── test_swae.py
│   ├── test_vae.py
│   ├── test_vq_vae.py
│   ├── test_wae.py
│   ├── text_cvae.py
│   └── text_vamp.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | Data/
3 | logs/
4 |
5 | VanillaVAE/version_0/
6 |
7 | __pycache__/
8 | .ipynb_checkpoints/
9 |
10 | Run.ipynb
11 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/assets/BetaTCVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/BetaTCVAE_49.png
--------------------------------------------------------------------------------
/assets/BetaVAE_B_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/BetaVAE_B_20.png
--------------------------------------------------------------------------------
/assets/BetaVAE_B_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/BetaVAE_B_35.png
--------------------------------------------------------------------------------
/assets/BetaVAE_H_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/BetaVAE_H_20.png
--------------------------------------------------------------------------------
/assets/CategoricalVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/CategoricalVAE_49.png
--------------------------------------------------------------------------------
/assets/ConditionalVAE_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/ConditionalVAE_20.png
--------------------------------------------------------------------------------
/assets/DFCVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/DFCVAE_49.png
--------------------------------------------------------------------------------
/assets/DIPVAE_83.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/DIPVAE_83.png
--------------------------------------------------------------------------------
/assets/IWAE_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/IWAE_19.png
--------------------------------------------------------------------------------
/assets/InfoVAE_31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/InfoVAE_31.png
--------------------------------------------------------------------------------
/assets/InfoVAE_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/InfoVAE_7.png
--------------------------------------------------------------------------------
/assets/JointVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/JointVAE_49.png
--------------------------------------------------------------------------------
/assets/LogCoshVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/LogCoshVAE_49.png
--------------------------------------------------------------------------------
/assets/MIWAE_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/MIWAE_29.png
--------------------------------------------------------------------------------
/assets/MSSIMVAE_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/MSSIMVAE_29.png
--------------------------------------------------------------------------------
/assets/SWAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/SWAE_49.png
--------------------------------------------------------------------------------
/assets/Vanilla VAE_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/Vanilla VAE_25.png
--------------------------------------------------------------------------------
/assets/WAE_IMQ_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/WAE_IMQ_15.png
--------------------------------------------------------------------------------
/assets/WAE_RBF_18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/WAE_RBF_18.png
--------------------------------------------------------------------------------
/assets/recons_BetaTCVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_BetaTCVAE_49.png
--------------------------------------------------------------------------------
/assets/recons_BetaVAE_B_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_BetaVAE_B_20.png
--------------------------------------------------------------------------------
/assets/recons_BetaVAE_B_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_BetaVAE_B_35.png
--------------------------------------------------------------------------------
/assets/recons_BetaVAE_H_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_BetaVAE_H_20.png
--------------------------------------------------------------------------------
/assets/recons_CategoricalVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_CategoricalVAE_49.png
--------------------------------------------------------------------------------
/assets/recons_ConditionalVAE_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_ConditionalVAE_20.png
--------------------------------------------------------------------------------
/assets/recons_DFCVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_DFCVAE_49.png
--------------------------------------------------------------------------------
/assets/recons_DIPVAE_83.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_DIPVAE_83.png
--------------------------------------------------------------------------------
/assets/recons_IWAE_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_IWAE_19.png
--------------------------------------------------------------------------------
/assets/recons_InfoVAE_31.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_InfoVAE_31.png
--------------------------------------------------------------------------------
/assets/recons_JointVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_JointVAE_49.png
--------------------------------------------------------------------------------
/assets/recons_LogCoshVAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_LogCoshVAE_49.png
--------------------------------------------------------------------------------
/assets/recons_MIWAE_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_MIWAE_29.png
--------------------------------------------------------------------------------
/assets/recons_MSSIMVAE_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_MSSIMVAE_29.png
--------------------------------------------------------------------------------
/assets/recons_SWAE_49.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_SWAE_49.png
--------------------------------------------------------------------------------
/assets/recons_VQVAE_29.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_VQVAE_29.png
--------------------------------------------------------------------------------
/assets/recons_Vanilla VAE_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_Vanilla VAE_25.png
--------------------------------------------------------------------------------
/assets/recons_WAE_IMQ_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_WAE_IMQ_15.png
--------------------------------------------------------------------------------
/assets/recons_WAE_RBF_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AntixK/PyTorch-VAE/a6896b944c918dd7030e7d795a8c13e5c6345ec7/assets/recons_WAE_RBF_19.png
--------------------------------------------------------------------------------
/configs/bbvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'BetaVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   loss_type: 'B'
6 |   gamma: 10.0
7 |   max_capacity: 25
8 |   Capacity_max_iter: 10000
9 |
10 | data_params:
11 |   data_path: "Data/"
12 |   train_batch_size: 64
13 |   val_batch_size: 64
14 |   patch_size: 64
15 |   num_workers: 4
16 |
17 | exp_params:
18 |   LR: 0.005
19 |   weight_decay: 0.0
20 |   scheduler_gamma: 0.95
21 |   kld_weight: 0.00025
22 |   manual_seed: 1265
23 |
24 | trainer_params:
25 |   gpus: [1]
26 |   max_epochs: 10
27 |
28 | logging_params:
29 |   save_dir: "logs/"
30 |   manual_seed: 1265
31 |   name: 'BetaVAE'
32 |
--------------------------------------------------------------------------------
/configs/betatc_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'BetaTCVAE'
3 |   in_channels: 3
4 |   latent_dim: 10
5 |   anneal_steps: 10000
6 |   alpha: 1.
7 |   beta: 6.
8 |   gamma: 1.
9 |
10 | data_params:
11 |   data_path: "Data/"
12 |   train_batch_size: 64
13 |   val_batch_size: 64
14 |   patch_size: 64
15 |   num_workers: 4
16 |
17 |
18 | exp_params:
19 |   LR: 0.005
20 |   weight_decay: 0.0
21 |   scheduler_gamma: 0.95
22 |   kld_weight: 0.00025
23 |   manual_seed: 1265
24 |
25 | trainer_params:
26 |   gpus: [1]
27 |   max_epochs: 10
28 |
29 | logging_params:
30 |   save_dir: "logs/"
31 |   name: 'BetaTCVAE'
32 |
--------------------------------------------------------------------------------
/configs/bhvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'BetaVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   loss_type: 'H'
6 |   beta: 10.
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: 'BetaVAE'
30 |
--------------------------------------------------------------------------------
/configs/cat_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'CategoricalVAE'
3 |   in_channels: 3
4 |   latent_dim: 512
5 |   categorical_dim: 40
6 |   temperature: 0.5
7 |   anneal_rate: 0.00003
8 |   anneal_interval: 100
9 |   alpha: 1.0
10 |
11 | data_params:
12 |   data_path: "Data/"
13 |   train_batch_size: 64
14 |   val_batch_size: 64
15 |   patch_size: 64
16 |   num_workers: 4
17 |
18 |
19 | exp_params:
20 |   LR: 0.005
21 |   weight_decay: 0.0
22 |   scheduler_gamma: 0.95
23 |   kld_weight: 0.00025
24 |   manual_seed: 1265
25 |
26 | trainer_params:
27 |   gpus: [1]
28 |   max_epochs: 10
29 |
30 | logging_params:
31 |   save_dir: "logs/"
32 |   name: "CategoricalVAE"
33 |
--------------------------------------------------------------------------------
/configs/cvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'ConditionalVAE'
3 |   in_channels: 3
4 |   num_classes: 40
5 |   latent_dim: 128
6 |
7 | data_params:
8 |   data_path: "Data/"
9 |   train_batch_size: 64
10 |   val_batch_size: 64
11 |   patch_size: 64
12 |   num_workers: 4
13 |
14 |
15 | exp_params:
16 |   LR: 0.005
17 |   weight_decay: 0.0
18 |   scheduler_gamma: 0.95
19 |   kld_weight: 0.00025
20 |   manual_seed: 1265
21 |
22 | trainer_params:
23 |   gpus: [1]
24 |   max_epochs: 10
25 |
26 | logging_params:
27 |   save_dir: "logs/"
28 |   name: "ConditionalVAE"
--------------------------------------------------------------------------------
/configs/dfc_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'DFCVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |
6 | data_params:
7 |   data_path: "Data/"
8 |   train_batch_size: 64
9 |   val_batch_size: 64
10 |   patch_size: 64
11 |   num_workers: 4
12 |
13 |
14 | exp_params:
15 |   LR: 0.005
16 |   weight_decay: 0.0
17 |   scheduler_gamma: 0.95
18 |   kld_weight: 0.00025
19 |   manual_seed: 1265
20 |
21 | trainer_params:
22 |   gpus: [1]
23 |   max_epochs: 10
24 |
25 | logging_params:
26 |   save_dir: "logs/"
27 |   name: "DFCVAE"
28 |
--------------------------------------------------------------------------------
/configs/dip_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'DIPVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   lambda_diag: 0.05
6 |   lambda_offdiag: 0.1
7 |
8 |
9 | data_params:
10 |   data_path: "Data/"
11 |   train_batch_size: 64
12 |   val_batch_size: 64
13 |   patch_size: 64
14 |   num_workers: 4
15 |
16 |
17 | exp_params:
18 |   LR: 0.001
19 |   weight_decay: 0.0
20 |   scheduler_gamma: 0.97
21 |   kld_weight: 1
22 |   manual_seed: 1265
23 |
24 | trainer_params:
25 |   gpus: [1]
26 |   max_epochs: 10
27 |
28 | logging_params:
29 |   save_dir: "logs/"
30 |   name: "DIPVAE"
31 |   manual_seed: 1265
32 |
--------------------------------------------------------------------------------
/configs/factorvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'FactorVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   gamma: 6.4
6 |
7 | data_params:
8 |   data_path: "Data/"
9 |   train_batch_size: 64
10 |   val_batch_size: 64
11 |   patch_size: 64
12 |   num_workers: 4
13 |
14 |
15 | exp_params:
16 |   submodel: 'discriminator'
17 |   retain_first_backpass: True
18 |   LR: 0.005
19 |   weight_decay: 0.0
20 |   LR_2: 0.005
21 |   scheduler_gamma_2: 0.95
22 |   scheduler_gamma: 0.95
23 |   kld_weight: 0.00025
24 |   manual_seed: 1265
25 |
26 | trainer_params:
27 |   gpus: [1]
28 |   max_epochs: 10
29 |
30 | logging_params:
31 |   save_dir: "logs/"
32 |   name: "FactorVAE"
33 |
34 |
35 |
--------------------------------------------------------------------------------
/configs/gammavae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'GammaVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   gamma_shape: 8.
6 |   prior_shape: 2.
7 |   prior_rate: 1.
8 |
9 |
10 | data_params:
11 |   data_path: "Data/"
12 |   train_batch_size: 64
13 |   val_batch_size: 64
14 |   patch_size: 64
15 |   num_workers: 4
16 |
17 |
18 | exp_params:
19 |   LR: 0.003
20 |   weight_decay: 0.00005
21 |   scheduler_gamma: 0.95
22 |   kld_weight: 0.00025
23 |   manual_seed: 1265
24 |
25 | trainer_params:
26 |   gpus: [1]
27 |   max_epochs: 10
28 |   gradient_clip_val: 0.8
29 |
30 | logging_params:
31 |   save_dir: "logs/"
32 |   name: "GammaVAE"
33 |
--------------------------------------------------------------------------------
/configs/hvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'HVAE'
3 |   in_channels: 3
4 |   latent1_dim: 64
5 |   latent2_dim: 64
6 |   pseudo_input_size: 128
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "HVAE"
30 |
--------------------------------------------------------------------------------
/configs/infovae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'InfoVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   reg_weight: 110  # MMD weight
6 |   kernel_type: 'imq'
7 |   alpha: -9.0  # KLD weight
8 |   beta: 10.5  # Reconstruction weight
9 |
10 | data_params:
11 |   data_path: "Data/"
12 |   train_batch_size: 64
13 |   val_batch_size: 64
14 |   patch_size: 64
15 |   num_workers: 4
16 |
17 |
18 | exp_params:
19 |   LR: 0.005
20 |   weight_decay: 0.0
21 |   scheduler_gamma: 0.95
22 |   kld_weight: 0.00025
23 |   manual_seed: 1265
24 |
25 | trainer_params:
26 |   gpus: [1]
27 |   max_epochs: 10
28 |   gradient_clip_val: 0.8
29 |
30 | logging_params:
31 |   save_dir: "logs/"
32 |   name: "InfoVAE"
33 |   manual_seed: 1265
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/configs/iwae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'IWAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   num_samples: 5
6 |
7 | data_params:
8 |   data_path: "Data/"
9 |   train_batch_size: 64
10 |   val_batch_size: 64
11 |   patch_size: 64
12 |   num_workers: 4
13 |
14 |
15 | exp_params:
16 |   LR: 0.007
17 |   weight_decay: 0.0
18 |   scheduler_gamma: 0.95
19 |   kld_weight: 0.00025
20 |   manual_seed: 1265
21 |
22 | trainer_params:
23 |   gpus: [1]
24 |   max_epochs: 10
25 |
26 | logging_params:
27 |   save_dir: "logs/"
28 |   name: "IWAE"
29 |
--------------------------------------------------------------------------------
/configs/joint_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'JointVAE'
3 |   in_channels: 3
4 |   latent_dim: 512
5 |   categorical_dim: 40
6 |   latent_min_capacity: 0.0
7 |   latent_max_capacity: 20.0
8 |   latent_gamma: 10.
9 |   latent_num_iter: 25000
10 |   categorical_min_capacity: 0.0
11 |   categorical_max_capacity: 20.0
12 |   categorical_gamma: 10.
13 |   categorical_num_iter: 25000
14 |   temperature: 0.5
15 |   anneal_rate: 0.00003
16 |   anneal_interval: 100
17 |   alpha: 10.0
18 |
19 | data_params:
20 |   data_path: "Data/"
21 |   train_batch_size: 64
22 |   val_batch_size: 64
23 |   patch_size: 64
24 |   num_workers: 4
25 |
26 |
27 | exp_params:
28 |   LR: 0.005
29 |   weight_decay: 0.0
30 |   scheduler_gamma: 0.95
31 |   kld_weight: 0.00025
32 |   manual_seed: 1265
33 |
34 | trainer_params:
35 |   gpus: [1]
36 |   max_epochs: 10
37 |
38 | logging_params:
39 |   save_dir: "logs/"
40 |   name: "JointVAE"
41 |
42 |
--------------------------------------------------------------------------------
/configs/logcosh_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'LogCoshVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   alpha: 10.0
6 |   beta: 1.0
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.97
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "LogCoshVAE"
30 |
31 |
--------------------------------------------------------------------------------
/configs/lvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'LVAE'
3 |   in_channels: 3
4 |   latent_dims: [4, 8, 16, 32, 128]
5 |   hidden_dims: [32, 64, 128, 256, 512]
6 |
7 | data_params:
8 |   data_path: "Data/"
9 |   train_batch_size: 64
10 |   val_batch_size: 64
11 |   patch_size: 64
12 |   num_workers: 4
13 |
14 |
15 | exp_params:
16 |   LR: 0.005
17 |   weight_decay: 0.0
18 |   scheduler_gamma: 0.95
19 |   kld_weight: 0.00025
20 |   manual_seed: 1265
21 |
22 | trainer_params:
23 |   gpus: [1]
24 |   max_epochs: 10
25 |
26 | logging_params:
27 |   save_dir: "logs/"
28 |   name: "LVAE"
29 |
--------------------------------------------------------------------------------
/configs/miwae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'MIWAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   num_samples: 5
6 |   num_estimates: 3
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "MIWAE"
30 |
31 |
--------------------------------------------------------------------------------
/configs/mssim_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'MSSIMVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |
6 | data_params:
7 |   data_path: "Data/"
8 |   train_batch_size: 64
9 |   val_batch_size: 64
10 |   patch_size: 64
11 |   num_workers: 4
12 |
13 |
14 | exp_params:
15 |   LR: 0.005
16 |   weight_decay: 0.0
17 |   scheduler_gamma: 0.95
18 |   kld_weight: 0.00025
19 |   manual_seed: 1265
20 |
21 | trainer_params:
22 |   gpus: [1]
23 |   max_epochs: 10
24 |
25 | logging_params:
26 |   save_dir: "logs/"
27 |   name: "MSSIMVAE"
28 |
--------------------------------------------------------------------------------
/configs/swae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'SWAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   reg_weight: 100
6 |   wasserstein_deg: 2.0
7 |   num_projections: 200
8 |   projection_dist: "normal"  # "cauchy"
9 |
10 | data_params:
11 |   data_path: "Data/"
12 |   train_batch_size: 64
13 |   val_batch_size: 64
14 |   patch_size: 64
15 |   num_workers: 4
16 |
17 |
18 | exp_params:
19 |   LR: 0.005
20 |   weight_decay: 0.0
21 |   scheduler_gamma: 0.95
22 |   kld_weight: 0.00025
23 |   manual_seed: 1265
24 |
25 | trainer_params:
26 |   gpus: [1]
27 |   max_epochs: 10
28 |
29 | logging_params:
30 |   save_dir: "logs/"
31 |   name: "SWAE"
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/configs/vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'VanillaVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |
6 |
7 | data_params:
8 |   data_path: "Data/"
9 |   train_batch_size: 64
10 |   val_batch_size: 64
11 |   patch_size: 64
12 |   num_workers: 4
13 |
14 |
15 | exp_params:
16 |   LR: 0.005
17 |   weight_decay: 0.0
18 |   scheduler_gamma: 0.95
19 |   kld_weight: 0.00025
20 |   manual_seed: 1265
21 |
22 | trainer_params:
23 |   gpus: [1]
24 |   max_epochs: 100
25 |
26 | logging_params:
27 |   save_dir: "logs/"
28 |   name: "VanillaVAE"
29 |
30 |
--------------------------------------------------------------------------------
/configs/vampvae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'VampVAE'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |
6 | exp_params:
7 |   dataset: celeba
8 |   data_path: "../../shared/Data/"
9 |   img_size: 64
10 |   batch_size: 144  # Better to have a square number
11 |   LR: 0.005
12 |   weight_decay: 0.0
13 |   scheduler_gamma: 0.95
14 |
15 | trainer_params:
16 |   gpus: 1
17 |   max_nb_epochs: 50
18 |   max_epochs: 50
19 |
20 | logging_params:
21 |   save_dir: "logs/"
22 |   name: "VampVAE"
23 |   manual_seed: 1265
24 |
--------------------------------------------------------------------------------
/configs/vq_vae.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'VQVAE'
3 |   in_channels: 3
4 |   embedding_dim: 64
5 |   num_embeddings: 512
6 |   img_size: 64
7 |   beta: 0.25
8 |
9 | data_params:
10 |   data_path: "Data/"
11 |   train_batch_size: 64
12 |   val_batch_size: 64
13 |   patch_size: 64
14 |   num_workers: 4
15 |
16 |
17 | exp_params:
18 |   LR: 0.005
19 |   weight_decay: 0.0
20 |   scheduler_gamma: 0.0
21 |   kld_weight: 0.00025
22 |   manual_seed: 1265
23 |
24 | trainer_params:
25 |   gpus: [1]
26 |   max_epochs: 10
27 |
28 | logging_params:
29 |   save_dir: "logs/"
30 |   name: 'VQVAE'
31 |
--------------------------------------------------------------------------------
/configs/wae_mmd_imq.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'WAE_MMD'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   reg_weight: 100
6 |   kernel_type: 'imq'
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "WassersteinVAE_IMQ"
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/configs/wae_mmd_rbf.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 |   name: 'WAE_MMD'
3 |   in_channels: 3
4 |   latent_dim: 128
5 |   reg_weight: 5000
6 |   kernel_type: 'rbf'
7 |
8 | data_params:
9 |   data_path: "Data/"
10 |   train_batch_size: 64
11 |   val_batch_size: 64
12 |   patch_size: 64
13 |   num_workers: 4
14 |
15 |
16 | exp_params:
17 |   LR: 0.005
18 |   weight_decay: 0.0
19 |   scheduler_gamma: 0.95
20 |   kld_weight: 0.00025
21 |   manual_seed: 1265
22 |
23 | trainer_params:
24 |   gpus: [1]
25 |   max_epochs: 10
26 |
27 | logging_params:
28 |   save_dir: "logs/"
29 |   name: "WassersteinVAE_RBF"
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
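
All of the configs above share one schema: model_params go to the model constructor, data_params to the data module, exp_params to the experiment wrapper, and trainer_params to the Lightning Trainer. The entry script run.py is not reproduced in this excerpt; a minimal sketch of reading one config with plain PyYAML, under that assumption:

import yaml

# Load a config and inspect the shared schema keys.
with open('configs/vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config['model_params']['name'])      # 'VanillaVAE'
print(config['exp_params']['kld_weight'])  # 0.00025
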
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import Tensor
4 | from pathlib import Path
5 | from typing import List, Optional, Sequence, Union, Any, Callable
6 | from torchvision.datasets.folder import default_loader
7 | from pytorch_lightning import LightningDataModule
8 | from torch.utils.data import DataLoader, Dataset
9 | from torchvision import transforms
10 | from torchvision.datasets import CelebA
11 | import zipfile
12 |
13 |
14 | # Add your custom dataset class here
15 | class MyDataset(Dataset):
16 |     def __init__(self):
17 |         pass
18 |
19 |
20 |     def __len__(self):
21 |         pass
22 |
23 |     def __getitem__(self, idx):
24 |         pass
25 |
26 |
27 | class MyCelebA(CelebA):
28 |     """
29 |     A work-around to address issues with pytorch's celebA dataset class.
30 |
31 |     Download and Extract
32 |     URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
33 |     """
34 |
35 |     def _check_integrity(self) -> bool:
36 |         return True
37 |
38 |
39 |
40 | class OxfordPets(Dataset):
41 |     """
42 |     URL = https://www.robots.ox.ac.uk/~vgg/data/pets/
43 |     """
44 |     def __init__(self,
45 |                  data_path: str,
46 |                  split: str,
47 |                  transform: Callable,
48 |                  **kwargs):
49 |         self.data_dir = Path(data_path) / "OxfordPets"
50 |         self.transforms = transform
51 |         imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])
52 |
53 |         self.imgs = imgs[:int(len(imgs) * 0.75)] if split == "train" else imgs[int(len(imgs) * 0.75):]
54 |
55 |     def __len__(self):
56 |         return len(self.imgs)
57 |
58 |     def __getitem__(self, idx):
59 |         img = default_loader(self.imgs[idx])
60 |
61 |         if self.transforms is not None:
62 |             img = self.transforms(img)
63 |
64 |         return img, 0.0  # dummy label to prevent breaking
65 |
66 | class VAEDataset(LightningDataModule):
67 |     """
68 |     PyTorch Lightning data module
69 |
70 |     Args:
71 |         data_dir: root directory of your dataset.
72 |         train_batch_size: the batch size to use during training.
73 |         val_batch_size: the batch size to use during validation.
74 |         patch_size: the size of the crop to take from the original images.
75 |         num_workers: the number of parallel workers to create to load data
76 |             items (see PyTorch's Dataloader documentation for more details).
77 |         pin_memory: whether prepared items should be loaded into pinned memory
78 |             or not. This can improve performance on GPUs.
79 |     """
80 |
81 |     def __init__(
82 |         self,
83 |         data_path: str,
84 |         train_batch_size: int = 8,
85 |         val_batch_size: int = 8,
86 |         patch_size: Union[int, Sequence[int]] = (256, 256),
87 |         num_workers: int = 0,
88 |         pin_memory: bool = False,
89 |         **kwargs,
90 |     ):
91 |         super().__init__()
92 |
93 |         self.data_dir = data_path
94 |         self.train_batch_size = train_batch_size
95 |         self.val_batch_size = val_batch_size
96 |         self.patch_size = patch_size
97 |         self.num_workers = num_workers
98 |         self.pin_memory = pin_memory
99 |
100 |     def setup(self, stage: Optional[str] = None) -> None:
101 |         # ========================= OxfordPets Dataset =========================
102 |
103 |         # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
104 |         #                                        transforms.CenterCrop(self.patch_size),
105 |         #                                        # transforms.Resize(self.patch_size),
106 |         #                                        transforms.ToTensor(),
107 |         #                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
108 |
109 |         # val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
110 |         #                                      transforms.CenterCrop(self.patch_size),
111 |         #                                      # transforms.Resize(self.patch_size),
112 |         #                                      transforms.ToTensor(),
113 |         #                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
114 |
115 |         # self.train_dataset = OxfordPets(
116 |         #     self.data_dir,
117 |         #     split='train',
118 |         #     transform=train_transforms,
119 |         # )
120 |
121 |         # self.val_dataset = OxfordPets(
122 |         #     self.data_dir,
123 |         #     split='val',
124 |         #     transform=val_transforms,
125 |         # )
126 |
127 |         # ========================= CelebA Dataset =========================
128 |
129 |         train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
130 |                                                transforms.CenterCrop(148),
131 |                                                transforms.Resize(self.patch_size),
132 |                                                transforms.ToTensor(),])
133 |
134 |         val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
135 |                                              transforms.CenterCrop(148),
136 |                                              transforms.Resize(self.patch_size),
137 |                                              transforms.ToTensor(),])
138 |
139 |         self.train_dataset = MyCelebA(
140 |             self.data_dir,
141 |             split='train',
142 |             transform=train_transforms,
143 |             download=False,
144 |         )
145 |
146 |         # Replace CelebA with your dataset
147 |         self.val_dataset = MyCelebA(
148 |             self.data_dir,
149 |             split='test',
150 |             transform=val_transforms,
151 |             download=False,
152 |         )
153 |         # ===============================================================
154 |
155 |     def train_dataloader(self) -> DataLoader:
156 |         return DataLoader(
157 |             self.train_dataset,
158 |             batch_size=self.train_batch_size,
159 |             num_workers=self.num_workers,
160 |             shuffle=True,
161 |             pin_memory=self.pin_memory,
162 |         )
163 |
164 |     def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
165 |         return DataLoader(
166 |             self.val_dataset,
167 |             batch_size=self.val_batch_size,
168 |             num_workers=self.num_workers,
169 |             shuffle=False,
170 |             pin_memory=self.pin_memory,
171 |         )
172 |
173 |     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
174 |         return DataLoader(
175 |             self.val_dataset,
176 |             batch_size=144,
177 |             num_workers=self.num_workers,
178 |             shuffle=True,
179 |             pin_memory=self.pin_memory,
180 |         )
181 |
--------------------------------------------------------------------------------
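
A minimal usage sketch for the VAEDataset module above. Note that setup() passes download=False, so it assumes CelebA has already been downloaded and extracted under the data_path (see the Google Drive link in MyCelebA); the parameter values below simply mirror the data_params blocks of the configs:

from dataset import VAEDataset

data = VAEDataset(data_path="Data/",
                  train_batch_size=64,
                  val_batch_size=64,
                  patch_size=64,
                  num_workers=4,
                  pin_memory=True)
data.setup()
imgs, labels = next(iter(data.train_dataloader()))
print(imgs.shape)  # torch.Size([64, 3, 64, 64])
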
/experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | from torch import optim
5 | from models import BaseVAE
6 | from models.types_ import *
7 | from utils import data_loader
8 | import pytorch_lightning as pl
9 | from torchvision import transforms
10 | import torchvision.utils as vutils
11 | from torchvision.datasets import CelebA
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class VAEXperiment(pl.LightningModule):
16 |
17 |     def __init__(self,
18 |                  vae_model: BaseVAE,
19 |                  params: dict) -> None:
20 |         super(VAEXperiment, self).__init__()
21 |
22 |         self.model = vae_model
23 |         self.params = params
24 |         self.curr_device = None
25 |         self.hold_graph = False
26 |         try:
27 |             self.hold_graph = self.params['retain_first_backpass']
28 |         except:
29 |             pass
30 |
31 |     def forward(self, input: Tensor, **kwargs) -> Tensor:
32 |         return self.model(input, **kwargs)
33 |
34 |     def training_step(self, batch, batch_idx, optimizer_idx = 0):
35 |         real_img, labels = batch
36 |         self.curr_device = real_img.device
37 |
38 |         results = self.forward(real_img, labels = labels)
39 |         train_loss = self.model.loss_function(*results,
40 |                                               M_N = self.params['kld_weight'],  # real_img.shape[0] / self.num_train_imgs,
41 |                                               optimizer_idx=optimizer_idx,
42 |                                               batch_idx = batch_idx)
43 |
44 |         self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)
45 |
46 |         return train_loss['loss']
47 |
48 |     def validation_step(self, batch, batch_idx, optimizer_idx = 0):
49 |         real_img, labels = batch
50 |         self.curr_device = real_img.device
51 |
52 |         results = self.forward(real_img, labels = labels)
53 |         val_loss = self.model.loss_function(*results,
54 |                                             M_N = 1.0,  # real_img.shape[0] / self.num_val_imgs,
55 |                                             optimizer_idx = optimizer_idx,
56 |                                             batch_idx = batch_idx)
57 |
58 |         self.log_dict({f"val_{key}": val.item() for key, val in val_loss.items()}, sync_dist=True)
59 |
60 |
61 |     def on_validation_end(self) -> None:
62 |         self.sample_images()
63 |
64 |     def sample_images(self):
65 |         # Get sample reconstruction image
66 |         test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
67 |         test_input = test_input.to(self.curr_device)
68 |         test_label = test_label.to(self.curr_device)
69 |
70 |         # test_input, test_label = batch
71 |         recons = self.model.generate(test_input, labels = test_label)
72 |         vutils.save_image(recons.data,
73 |                           os.path.join(self.logger.log_dir,
74 |                                        "Reconstructions",
75 |                                        f"recons_{self.logger.name}_Epoch_{self.current_epoch}.png"),
76 |                           normalize=True,
77 |                           nrow=12)
78 |
79 |         try:
80 |             samples = self.model.sample(144,
81 |                                         self.curr_device,
82 |                                         labels = test_label)
83 |             vutils.save_image(samples.cpu().data,
84 |                               os.path.join(self.logger.log_dir,
85 |                                            "Samples",
86 |                                            f"{self.logger.name}_Epoch_{self.current_epoch}.png"),
87 |                               normalize=True,
88 |                               nrow=12)
89 |         except Warning:
90 |             pass
91 |
92 |     def configure_optimizers(self):
93 |
94 |         optims = []
95 |         scheds = []
96 |
97 |         optimizer = optim.Adam(self.model.parameters(),
98 |                                lr=self.params['LR'],
99 |                                weight_decay=self.params['weight_decay'])
100 |         optims.append(optimizer)
101 |         # Check if more than 1 optimizer is required (Used for adversarial training)
102 |         try:
103 |             if self.params['LR_2'] is not None:
104 |                 optimizer2 = optim.Adam(getattr(self.model, self.params['submodel']).parameters(),
105 |                                         lr=self.params['LR_2'])
106 |                 optims.append(optimizer2)
107 |         except:
108 |             pass
109 |
110 |         try:
111 |             if self.params['scheduler_gamma'] is not None:
112 |                 scheduler = optim.lr_scheduler.ExponentialLR(optims[0],
113 |                                                              gamma = self.params['scheduler_gamma'])
114 |                 scheds.append(scheduler)
115 |
116 |                 # Check if another scheduler is required for the second optimizer
117 |                 try:
118 |                     if self.params['scheduler_gamma_2'] is not None:
119 |                         scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1],
120 |                                                                       gamma = self.params['scheduler_gamma_2'])
121 |                         scheds.append(scheduler2)
122 |                 except:
123 |                     pass
124 |                 return optims, scheds
125 |         except:
126 |             return optims
127 |
--------------------------------------------------------------------------------
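
Putting the pieces together: a sketch of wiring VAEXperiment into a pytorch_lightning Trainer, along the lines the config schema suggests (trainer_params splatted into Trainer, exp_params into the experiment). The repository's actual run.py is not part of this excerpt and may differ in details; the flags used here (e.g. gpus) are the PL 1.x-era ones that match the configs above:

import yaml
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from models import vae_models
from experiment import VAEXperiment
from dataset import VAEDataset

with open('configs/vae.yaml', 'r') as f:
    config = yaml.safe_load(f)

tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                              name=config['logging_params']['name'])

model = vae_models[config['model_params']['name']](**config['model_params'])
experiment = VAEXperiment(model, config['exp_params'])
data = VAEDataset(**config['data_params'], pin_memory=True)

# sample_images() writes into these folders, so create them up front.
Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)

runner = Trainer(logger=tb_logger, **config['trainer_params'])
runner.fit(experiment, datamodule=data)
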
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .vanilla_vae import *
3 | from .gamma_vae import *
4 | from .beta_vae import *
5 | from .wae_mmd import *
6 | from .cvae import *
7 | from .hvae import *
8 | from .vampvae import *
9 | from .iwae import *
10 | from .dfcvae import *
11 | from .mssim_vae import MSSIMVAE
12 | from .fvae import *
13 | from .cat_vae import *
14 | from .joint_vae import *
15 | from .info_vae import *
16 | # from .twostage_vae import *
17 | from .lvae import LVAE
18 | from .logcosh_vae import *
19 | from .swae import *
20 | from .miwae import *
21 | from .vq_vae import *
22 | from .betatc_vae import *
23 | from .dip_vae import *
24 |
25 |
26 | # Aliases
27 | VAE = VanillaVAE
28 | GaussianVAE = VanillaVAE
29 | CVAE = ConditionalVAE
30 | GumbelVAE = CategoricalVAE
31 |
32 | vae_models = {'HVAE':HVAE,
33 |               'LVAE':LVAE,
34 |               'IWAE':IWAE,
35 |               'SWAE':SWAE,
36 |               'MIWAE':MIWAE,
37 |               'VQVAE':VQVAE,
38 |               'DFCVAE':DFCVAE,
39 |               'DIPVAE':DIPVAE,
40 |               'BetaVAE':BetaVAE,
41 |               'InfoVAE':InfoVAE,
42 |               'WAE_MMD':WAE_MMD,
43 |               'VampVAE':VampVAE,
44 |               'GammaVAE':GammaVAE,
45 |               'MSSIMVAE':MSSIMVAE,
46 |               'JointVAE':JointVAE,
47 |               'BetaTCVAE':BetaTCVAE,
48 |               'FactorVAE':FactorVAE,
49 |               'LogCoshVAE':LogCoshVAE,
50 |               'VanillaVAE':VanillaVAE,
51 |               'ConditionalVAE':ConditionalVAE,
52 |               'CategoricalVAE':CategoricalVAE}
53 |
--------------------------------------------------------------------------------
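
The vae_models dictionary is what keeps the training entry point model-agnostic: the name field in a config's model_params selects the class. A small sketch of that lookup, with keyword arguments taken from configs/bbvae.yaml:

from models import vae_models

# 'BetaVAE' is the same string as model_params.name in configs/bbvae.yaml.
model_cls = vae_models['BetaVAE']
model = model_cls(in_channels=3, latent_dim=128, loss_type='B', gamma=10.0)
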
/models/base.py:
--------------------------------------------------------------------------------
1 | from .types_ import *
2 | from torch import nn
3 | from abc import abstractmethod
4 |
5 | class BaseVAE(nn.Module):
6 |
7 |     def __init__(self) -> None:
8 |         super(BaseVAE, self).__init__()
9 |
10 |     def encode(self, input: Tensor) -> List[Tensor]:
11 |         raise NotImplementedError
12 |
13 |     def decode(self, input: Tensor) -> Any:
14 |         raise NotImplementedError
15 |
16 |     def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
17 |         raise NotImplementedError
18 |
19 |     def generate(self, x: Tensor, **kwargs) -> Tensor:
20 |         raise NotImplementedError
21 |
22 |     @abstractmethod
23 |     def forward(self, *inputs: Tensor) -> Tensor:
24 |         pass
25 |
26 |     @abstractmethod
27 |     def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
28 |         pass
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/models/beta_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class BetaVAE(BaseVAE):
9 |
10 |     num_iter = 0  # Global static variable to keep track of iterations
11 |
12 |     def __init__(self,
13 |                  in_channels: int,
14 |                  latent_dim: int,
15 |                  hidden_dims: List = None,
16 |                  beta: int = 4,
17 |                  gamma: float = 1000.,
18 |                  max_capacity: int = 25,
19 |                  Capacity_max_iter: int = 1e5,
20 |                  loss_type: str = 'B',
21 |                  **kwargs) -> None:
22 |         super(BetaVAE, self).__init__()
23 |
24 |         self.latent_dim = latent_dim
25 |         self.beta = beta
26 |         self.gamma = gamma
27 |         self.loss_type = loss_type
28 |         self.C_max = torch.Tensor([max_capacity])
29 |         self.C_stop_iter = Capacity_max_iter
30 |
31 |         modules = []
32 |         if hidden_dims is None:
33 |             hidden_dims = [32, 64, 128, 256, 512]
34 |
35 |         # Build Encoder
36 |         for h_dim in hidden_dims:
37 |             modules.append(
38 |                 nn.Sequential(
39 |                     nn.Conv2d(in_channels, out_channels=h_dim,
40 |                               kernel_size=3, stride=2, padding=1),
41 |                     nn.BatchNorm2d(h_dim),
42 |                     nn.LeakyReLU())
43 |             )
44 |             in_channels = h_dim
45 |
46 |         self.encoder = nn.Sequential(*modules)
47 |         self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
48 |         self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
49 |
50 |
51 |         # Build Decoder
52 |         modules = []
53 |
54 |         self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
55 |
56 |         hidden_dims.reverse()
57 |
58 |         for i in range(len(hidden_dims) - 1):
59 |             modules.append(
60 |                 nn.Sequential(
61 |                     nn.ConvTranspose2d(hidden_dims[i],
62 |                                        hidden_dims[i + 1],
63 |                                        kernel_size=3,
64 |                                        stride=2,
65 |                                        padding=1,
66 |                                        output_padding=1),
67 |                     nn.BatchNorm2d(hidden_dims[i + 1]),
68 |                     nn.LeakyReLU())
69 |             )
70 |
71 |
72 |
73 |         self.decoder = nn.Sequential(*modules)
74 |
75 |         self.final_layer = nn.Sequential(
76 |             nn.ConvTranspose2d(hidden_dims[-1],
77 |                                hidden_dims[-1],
78 |                                kernel_size=3,
79 |                                stride=2,
80 |                                padding=1,
81 |                                output_padding=1),
82 |             nn.BatchNorm2d(hidden_dims[-1]),
83 |             nn.LeakyReLU(),
84 |             nn.Conv2d(hidden_dims[-1], out_channels=3,
85 |                       kernel_size=3, padding=1),
86 |             nn.Tanh())
87 |
88 |     def encode(self, input: Tensor) -> List[Tensor]:
89 |         """
90 |         Encodes the input by passing it through the encoder network
91 |         and returns the latent codes.
92 |         :param input: (Tensor) Input tensor to encoder [N x C x H x W]
93 |         :return: (Tensor) List of latent codes
94 |         """
95 |         result = self.encoder(input)
96 |         result = torch.flatten(result, start_dim=1)
97 |
98 |         # Split the result into mu and var components
99 |         # of the latent Gaussian distribution
100 |         mu = self.fc_mu(result)
101 |         log_var = self.fc_var(result)
102 |
103 |         return [mu, log_var]
104 |
105 |     def decode(self, z: Tensor) -> Tensor:
106 |         result = self.decoder_input(z)
107 |         result = result.view(-1, 512, 2, 2)
108 |         result = self.decoder(result)
109 |         result = self.final_layer(result)
110 |         return result
111 |
112 |     def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
113 |         """
114 |         Will a single z be enough to compute the expectation
115 |         for the loss??
116 |         :param mu: (Tensor) Mean of the latent Gaussian
117 |         :param logvar: (Tensor) Log variance of the latent Gaussian
118 |         :return:
119 |         """
120 |         std = torch.exp(0.5 * logvar)
121 |         eps = torch.randn_like(std)
122 |         return eps * std + mu
123 |
124 |     def forward(self, input: Tensor, **kwargs) -> Tensor:
125 |         mu, log_var = self.encode(input)
126 |         z = self.reparameterize(mu, log_var)
127 |         return [self.decode(z), input, mu, log_var]
128 |
129 |     def loss_function(self,
130 |                       *args,
131 |                       **kwargs) -> dict:
132 |         self.num_iter += 1
133 |         recons = args[0]
134 |         input = args[1]
135 |         mu = args[2]
136 |         log_var = args[3]
137 |         kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
138 |
139 |         recons_loss = F.mse_loss(recons, input)
140 |
141 |         kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
142 |
143 |         if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
144 |             loss = recons_loss + self.beta * kld_weight * kld_loss
145 |         elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
146 |             self.C_max = self.C_max.to(input.device)
147 |             C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
148 |             loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
149 |         else:
150 |             raise ValueError('Undefined loss type.')
151 |
152 |         return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}
153 |
154 |     def sample(self,
155 |                num_samples: int,
156 |                current_device: int, **kwargs) -> Tensor:
157 |         """
158 |         Samples from the latent space and returns the corresponding
159 |         image space map.
160 |         :param num_samples: (Int) Number of samples
161 |         :param current_device: (Int) Device to run the model
162 |         :return: (Tensor)
163 |         """
164 |         z = torch.randn(num_samples,
165 |                         self.latent_dim)
166 |
167 |         z = z.to(current_device)
168 |
169 |         samples = self.decode(z)
170 |         return samples
171 |
172 |     def generate(self, x: Tensor, **kwargs) -> Tensor:
173 |         """
174 |         Given an input image x, returns the reconstructed image
175 |         :param x: (Tensor) [B x C x H x W]
176 |         :return: (Tensor) [B x C x H x W]
177 |         """
178 |
179 |         return self.forward(x)[0]
--------------------------------------------------------------------------------
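
A quick smoke test for the model above. The architecture assumes 64x64 inputs (five stride-2 convolutions take 64 down to 2, matching the hidden_dims[-1]*4 flattened size), so this sketch feeds random 64x64 images and evaluates the 'H' variant of the loss; M_N is the kld_weight value from the configs:

import torch
from models import BetaVAE

model = BetaVAE(in_channels=3, latent_dim=128, loss_type='H', beta=4)
x = torch.randn(8, 3, 64, 64)        # batch of random 64x64 RGB inputs
recons, inp, mu, log_var = model(x)  # forward returns [recons, input, mu, log_var]
loss = model.loss_function(recons, inp, mu, log_var, M_N=0.00025)
print(recons.shape)                  # torch.Size([8, 3, 64, 64])
print(loss['loss'].item(), loss['KLD'].item())
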
/models/betatc_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 | import math
7 |
8 |
9 | class BetaTCVAE(BaseVAE):
10 | num_iter = 0 # Global static variable to keep track of iterations
11 |
12 | def __init__(self,
13 | in_channels: int,
14 | latent_dim: int,
15 | hidden_dims: List = None,
16 | anneal_steps: int = 200,
17 | alpha: float = 1.,
18 | beta: float = 6.,
19 | gamma: float = 1.,
20 | **kwargs) -> None:
21 | super(BetaTCVAE, self).__init__()
22 |
23 | self.latent_dim = latent_dim
24 | self.anneal_steps = anneal_steps
25 |
26 | self.alpha = alpha
27 | self.beta = beta
28 | self.gamma = gamma
29 |
30 | modules = []
31 | if hidden_dims is None:
32 | hidden_dims = [32, 32, 32, 32]
33 |
34 | # Build Encoder
35 | for h_dim in hidden_dims:
36 | modules.append(
37 | nn.Sequential(
38 | nn.Conv2d(in_channels, out_channels=h_dim,
39 | kernel_size= 4, stride= 2, padding = 1),
40 | nn.LeakyReLU())
41 | )
42 | in_channels = h_dim
43 |
44 | self.encoder = nn.Sequential(*modules)
45 |
46 | self.fc = nn.Linear(hidden_dims[-1]*16, 256)
47 | self.fc_mu = nn.Linear(256, latent_dim)
48 | self.fc_var = nn.Linear(256, latent_dim)
49 |
50 |
51 | # Build Decoder
52 | modules = []
53 |
54 | self.decoder_input = nn.Linear(latent_dim, 256 * 2)
55 |
56 | hidden_dims.reverse()
57 |
58 | for i in range(len(hidden_dims) - 1):
59 | modules.append(
60 | nn.Sequential(
61 | nn.ConvTranspose2d(hidden_dims[i],
62 | hidden_dims[i + 1],
63 | kernel_size=3,
64 | stride = 2,
65 | padding=1,
66 | output_padding=1),
67 | nn.LeakyReLU())
68 | )
69 |
70 | self.decoder = nn.Sequential(*modules)
71 |
72 | self.final_layer = nn.Sequential(
73 | nn.ConvTranspose2d(hidden_dims[-1],
74 | hidden_dims[-1],
75 | kernel_size=3,
76 | stride=2,
77 | padding=1,
78 | output_padding=1),
79 | nn.LeakyReLU(),
80 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
81 | kernel_size= 3, padding= 1),
82 | nn.Tanh())
83 |
84 | def encode(self, input: Tensor) -> List[Tensor]:
85 | """
86 | Encodes the input by passing through the encoder network
87 | and returns the latent codes.
88 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
89 | :return: (Tensor) List of latent codes
90 | """
91 | result = self.encoder(input)
92 |
93 | result = torch.flatten(result, start_dim=1)
94 | result = self.fc(result)
95 | # Split the result into mu and var components
96 | # of the latent Gaussian distribution
97 | mu = self.fc_mu(result)
98 | log_var = self.fc_var(result)
99 |
100 | return [mu, log_var]
101 |
102 | def decode(self, z: Tensor) -> Tensor:
103 | """
104 | Maps the given latent codes
105 | onto the image space.
106 | :param z: (Tensor) [B x D]
107 | :return: (Tensor) [B x C x H x W]
108 | """
109 | result = self.decoder_input(z)
110 | result = result.view(-1, 32, 4, 4)
111 | result = self.decoder(result)
112 | result = self.final_layer(result)
113 | return result
114 |
115 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
116 | """
117 | Reparameterization trick to sample from N(mu, var) from
118 | N(0,1).
119 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
120 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
121 | :return: (Tensor) [B x D]
122 | """
123 | std = torch.exp(0.5 * logvar)
124 | eps = torch.randn_like(std)
125 | return eps * std + mu
126 |
127 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
128 | mu, log_var = self.encode(input)
129 | z = self.reparameterize(mu, log_var)
130 | return [self.decode(z), input, mu, log_var, z]
131 |
132 | def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
133 | """
134 | Computes the log pdf of the Gaussian with parameters mu and logvar at x
135 | :param x: (Tensor) Point at whichGaussian PDF is to be evaluated
136 | :param mu: (Tensor) Mean of the Gaussian distribution
137 | :param logvar: (Tensor) Log variance of the Gaussian distribution
138 | :return:
139 | """
140 | norm = - 0.5 * (math.log(2 * math.pi) + logvar)
141 | log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
142 | return log_density
143 |
144 | def loss_function(self,
145 | *args,
146 | **kwargs) -> dict:
147 | """
148 | Computes the VAE loss function.
149 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
150 | :param args:
151 | :param kwargs:
152 | :return:
153 | """
154 |
155 | recons = args[0]
156 | input = args[1]
157 | mu = args[2]
158 | log_var = args[3]
159 | z = args[4]
160 |
161 | weight = 1 #kwargs['M_N'] # Account for the minibatch samples from the dataset
162 |
163 | recons_loss =F.mse_loss(recons, input, reduction='sum')
164 |
165 | log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)
166 |
167 | zeros = torch.zeros_like(z)
168 | log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim = 1)
169 |
170 | batch_size, latent_dim = z.shape
171 | mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
172 | mu.view(1, batch_size, latent_dim),
173 | log_var.view(1, batch_size, latent_dim))
174 |
175 | # Reference
176 | # [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
177 | dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
178 | strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
179 | importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
180 | importance_weights.view(-1)[::batch_size] = 1 / dataset_size
181 | importance_weights.view(-1)[1::batch_size] = strat_weight
182 | importance_weights[batch_size - 2, 0] = strat_weight
183 | log_importance_weights = importance_weights.log()
184 |
185 | mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)
186 |
187 | log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
188 | log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)
189 |
190 | mi_loss = (log_q_zx - log_q_z).mean()
191 | tc_loss = (log_q_z - log_prod_q_z).mean()
192 | kld_loss = (log_prod_q_z - log_p_z).mean()
193 |
194 | # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
195 |
196 | if self.training:
197 | self.num_iter += 1
198 | anneal_rate = min(self.num_iter / self.anneal_steps, 1)
199 | else:
200 | anneal_rate = 1.
201 |
202 | loss = recons_loss/batch_size + \
203 | self.alpha * mi_loss + \
204 | weight * (self.beta * tc_loss +
205 | anneal_rate * self.gamma * kld_loss)
206 |
207 | return {'loss': loss,
208 | 'Reconstruction_Loss':recons_loss,
209 | 'KLD':kld_loss,
210 | 'TC_Loss':tc_loss,
211 | 'MI_Loss':mi_loss}
212 |
213 | def sample(self,
214 | num_samples:int,
215 | current_device: int, **kwargs) -> Tensor:
216 | """
217 | Samples from the latent space and return the corresponding
218 | image space map.
219 | :param num_samples: (Int) Number of samples
220 | :param current_device: (Int) Device to run the model
221 | :return: (Tensor)
222 | """
223 | z = torch.randn(num_samples,
224 | self.latent_dim)
225 |
226 | z = z.to(current_device)
227 |
228 | samples = self.decode(z)
229 | return samples
230 |
231 | def generate(self, x: Tensor, **kwargs) -> Tensor:
232 | """
233 | Given an input image x, returns the reconstructed image
234 | :param x: (Tensor) [B x C x H x W]
235 | :return: (Tensor) [B x C x H x W]
236 | """
237 |
238 | return self.forward(x)[0]
--------------------------------------------------------------------------------
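A quick standalone check of the log_density_gaussian helper used by BetaTCVAE above (this sketch is not part of the repository): the standard normal log-density at zero should equal -0.5 * log(2*pi).

import math
import torch

def log_density_gaussian(x, mu, logvar):
    # Elementwise log N(x; mu, exp(logvar)), mirroring BetaTCVAE above
    norm = -0.5 * (math.log(2 * math.pi) + logvar)
    return norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))

val = log_density_gaussian(torch.zeros(1), torch.zeros(1), torch.zeros(1))
assert torch.allclose(val, torch.tensor(-0.5 * math.log(2 * math.pi)))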
/models/cat_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from models import BaseVAE
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from .types_ import *
7 |
8 |
9 | class CategoricalVAE(BaseVAE):
10 |
11 | def __init__(self,
12 | in_channels: int,
13 | latent_dim: int,
14 | categorical_dim: int = 40, # Num classes
15 | hidden_dims: List = None,
16 | temperature: float = 0.5,
17 | anneal_rate: float = 3e-5,
18 | anneal_interval: int = 100, # every 100 batches
19 | alpha: float = 30.,
20 | **kwargs) -> None:
21 | super(CategoricalVAE, self).__init__()
22 |
23 | self.latent_dim = latent_dim
24 | self.categorical_dim = categorical_dim
25 | self.temp = temperature
26 | self.min_temp = temperature
27 | self.anneal_rate = anneal_rate
28 | self.anneal_interval = anneal_interval
29 | self.alpha = alpha
30 |
31 | modules = []
32 | if hidden_dims is None:
33 | hidden_dims = [32, 64, 128, 256, 512]
34 |
35 | # Build Encoder
36 | for h_dim in hidden_dims:
37 | modules.append(
38 | nn.Sequential(
39 | nn.Conv2d(in_channels, out_channels=h_dim,
40 | kernel_size= 3, stride= 2, padding = 1),
41 | nn.BatchNorm2d(h_dim),
42 | nn.LeakyReLU())
43 | )
44 | in_channels = h_dim
45 |
46 | self.encoder = nn.Sequential(*modules)
47 | self.fc_z = nn.Linear(hidden_dims[-1]*4,
48 | self.latent_dim * self.categorical_dim)
49 |
50 | # Build Decoder
51 | modules = []
52 |
53 | self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim
54 | , hidden_dims[-1] * 4)
55 |
56 | hidden_dims.reverse()
57 |
58 | for i in range(len(hidden_dims) - 1):
59 | modules.append(
60 | nn.Sequential(
61 | nn.ConvTranspose2d(hidden_dims[i],
62 | hidden_dims[i + 1],
63 | kernel_size=3,
64 | stride = 2,
65 | padding=1,
66 | output_padding=1),
67 | nn.BatchNorm2d(hidden_dims[i + 1]),
68 | nn.LeakyReLU())
69 | )
70 |
71 |
72 |
73 | self.decoder = nn.Sequential(*modules)
74 |
75 | self.final_layer = nn.Sequential(
76 | nn.ConvTranspose2d(hidden_dims[-1],
77 | hidden_dims[-1],
78 | kernel_size=3,
79 | stride=2,
80 | padding=1,
81 | output_padding=1),
82 | nn.BatchNorm2d(hidden_dims[-1]),
83 | nn.LeakyReLU(),
84 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
85 | kernel_size= 3, padding= 1),
86 | nn.Tanh())
87 | self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))
88 |
89 | def encode(self, input: Tensor) -> List[Tensor]:
90 | """
91 | Encodes the input by passing through the encoder network
92 | and returns the latent codes.
93 | :param input: (Tensor) Input tensor to encoder [B x C x H x W]
94 | :return: (Tensor) Latent code [B x D x Q]
95 | """
96 | result = self.encoder(input)
97 | result = torch.flatten(result, start_dim=1)
98 |
99 | # Map the encoded features to the logits of the
100 | # latent categorical distribution
101 | z = self.fc_z(result)
102 | z = z.view(-1, self.latent_dim, self.categorical_dim)
103 | return [z]
104 |
105 | def decode(self, z: Tensor) -> Tensor:
106 | """
107 | Maps the given latent codes
108 | onto the image space.
109 | :param z: (Tensor) [B x D x Q]
110 | :return: (Tensor) [B x C x H x W]
111 | """
112 | result = self.decoder_input(z)
113 | result = result.view(-1, 512, 2, 2)
114 | result = self.decoder(result)
115 | result = self.final_layer(result)
116 | return result
117 |
118 | def reparameterize(self, z: Tensor, eps: float = 1e-7) -> Tensor:
119 | """
120 | Gumbel-softmax trick to sample from the categorical distribution
121 | :param z: (Tensor) Latent logits [B x D x Q]
122 | :return: (Tensor) Relaxed one-hot sample, flattened to [B x D*Q]
123 | """
124 | # Sample from Gumbel
125 | u = torch.rand_like(z)
126 | g = - torch.log(- torch.log(u + eps) + eps)
127 |
128 | # Gumbel-Softmax sample
129 | s = F.softmax((z + g) / self.temp, dim=-1)
130 | s = s.view(-1, self.latent_dim * self.categorical_dim)
131 | return s
132 |
133 |
134 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
135 | q = self.encode(input)[0]
136 | z = self.reparameterize(q)
137 | return [self.decode(z), input, q]
138 |
139 | def loss_function(self,
140 | *args,
141 | **kwargs) -> dict:
142 | """
143 | Computes the Categorical VAE loss: reconstruction plus the KL divergence
144 | between the relaxed posterior and a uniform categorical prior.
145 | :param args:
146 | :param kwargs:
147 | :return:
148 | """
149 | recons = args[0]
150 | input = args[1]
151 | q = args[2]
152 |
153 | q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities
154 |
155 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
156 | batch_idx = kwargs['batch_idx']
157 |
158 | # Anneal the temperature at regular intervals
159 | if batch_idx % self.anneal_interval == 0 and self.training:
160 | self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
161 | self.min_temp)
162 |
163 | recons_loss = F.mse_loss(recons, input, reduction='mean')
164 |
165 | # KL divergence between the relaxed posterior and the uniform categorical prior
166 | eps = 1e-7
167 |
168 | # Negative entropy of q
169 | h1 = q_p * torch.log(q_p + eps)
170 |
171 | # Negative cross entropy with the uniform prior
172 | h2 = q_p * np.log(1. / self.categorical_dim + eps)
173 | kld_loss = torch.mean(torch.sum(h1 - h2, dim=(1, 2)), dim=0)
174 |
175 | # kld_weight = 1.2
176 | loss = self.alpha * recons_loss + kld_weight * kld_loss
177 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
178 |
179 | def sample(self,
180 | num_samples:int,
181 | current_device: int, **kwargs) -> Tensor:
182 | """
183 | Samples from the latent space and return the corresponding
184 | image space map.
185 | :param num_samples: (Int) Number of samples
186 | :param current_device: (Int) Device to run the model
187 | :return: (Tensor)
188 | """
189 | # [S x D x Q]
190 |
191 | M = num_samples * self.latent_dim
192 | np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)
193 | np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1
194 | np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim])
195 | z = torch.from_numpy(np_y)
196 |
197 | # z = self.sampling_dist.sample((num_samples * self.latent_dim, ))
198 | z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device)
199 | samples = self.decode(z)
200 | return samples
201 |
202 | def generate(self, x: Tensor, **kwargs) -> Tensor:
203 | """
204 | Given an input image x, returns the reconstructed image
205 | :param x: (Tensor) [B x C x H x W]
206 | :return: (Tensor) [B x C x H x W]
207 | """
208 |
209 | return self.forward(x)[0]
--------------------------------------------------------------------------------
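The Gumbel-softmax sampler in CategoricalVAE.reparameterize, shown in isolation (a standalone sketch, not repository code): at high temperature the relaxed sample is near-uniform, and as the temperature is annealed towards zero it approaches a one-hot vector.

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temp, eps=1e-7):
    # g ~ Gumbel(0, 1) via the inverse CDF applied to uniform noise
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u + eps) + eps)
    return F.softmax((logits + g) / temp, dim=-1)

logits = torch.randn(1, 10)
print(gumbel_softmax_sample(logits, temp=5.0))  # close to uniform
print(gumbel_softmax_sample(logits, temp=0.1))  # close to one-hot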
/models/cvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class ConditionalVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | num_classes: int,
13 | latent_dim: int,
14 | hidden_dims: List = None,
15 | img_size:int = 64,
16 | **kwargs) -> None:
17 | super(ConditionalVAE, self).__init__()
18 |
19 | self.latent_dim = latent_dim
20 | self.img_size = img_size
21 |
22 | self.embed_class = nn.Linear(num_classes, img_size * img_size)
23 | self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)
24 |
25 | modules = []
26 | if hidden_dims is None:
27 | hidden_dims = [32, 64, 128, 256, 512]
28 |
29 | in_channels += 1 # To account for the extra label channel
30 | # Build Encoder
31 | for h_dim in hidden_dims:
32 | modules.append(
33 | nn.Sequential(
34 | nn.Conv2d(in_channels, out_channels=h_dim,
35 | kernel_size= 3, stride= 2, padding = 1),
36 | nn.BatchNorm2d(h_dim),
37 | nn.LeakyReLU())
38 | )
39 | in_channels = h_dim
40 |
41 | self.encoder = nn.Sequential(*modules)
42 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
43 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
44 |
45 |
46 | # Build Decoder
47 | modules = []
48 |
49 | self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4)
50 |
51 | hidden_dims.reverse()
52 |
53 | for i in range(len(hidden_dims) - 1):
54 | modules.append(
55 | nn.Sequential(
56 | nn.ConvTranspose2d(hidden_dims[i],
57 | hidden_dims[i + 1],
58 | kernel_size=3,
59 | stride = 2,
60 | padding=1,
61 | output_padding=1),
62 | nn.BatchNorm2d(hidden_dims[i + 1]),
63 | nn.LeakyReLU())
64 | )
65 |
66 |
67 |
68 | self.decoder = nn.Sequential(*modules)
69 |
70 | self.final_layer = nn.Sequential(
71 | nn.ConvTranspose2d(hidden_dims[-1],
72 | hidden_dims[-1],
73 | kernel_size=3,
74 | stride=2,
75 | padding=1,
76 | output_padding=1),
77 | nn.BatchNorm2d(hidden_dims[-1]),
78 | nn.LeakyReLU(),
79 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
80 | kernel_size= 3, padding= 1),
81 | nn.Tanh())
82 |
83 | def encode(self, input: Tensor) -> List[Tensor]:
84 | """
85 | Encodes the input by passing through the encoder network
86 | and returns the latent codes.
87 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
88 | :return: (Tensor) List of latent codes
89 | """
90 | result = self.encoder(input)
91 | result = torch.flatten(result, start_dim=1)
92 |
93 | # Split the result into mu and var components
94 | # of the latent Gaussian distribution
95 | mu = self.fc_mu(result)
96 | log_var = self.fc_var(result)
97 |
98 | return [mu, log_var]
99 |
100 | def decode(self, z: Tensor) -> Tensor:
101 | result = self.decoder_input(z)
102 | result = result.view(-1, 512, 2, 2)
103 | result = self.decoder(result)
104 | result = self.final_layer(result)
105 | return result
106 |
107 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
108 | """
109 | Will a single z be enough ti compute the expectation
110 | for the loss??
111 | :param mu: (Tensor) Mean of the latent Gaussian
112 | :param logvar: (Tensor) Standard deviation of the latent Gaussian
113 | :return:
114 | """
115 | std = torch.exp(0.5 * logvar)
116 | eps = torch.randn_like(std)
117 | return eps * std + mu
118 |
119 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
120 | y = kwargs['labels'].float()
121 | embedded_class = self.embed_class(y)
122 | embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
123 | embedded_input = self.embed_data(input)
124 |
125 | x = torch.cat([embedded_input, embedded_class], dim = 1)
126 | mu, log_var = self.encode(x)
127 |
128 | z = self.reparameterize(mu, log_var)
129 |
130 | z = torch.cat([z, y], dim = 1)
131 | return [self.decode(z), input, mu, log_var]
132 |
133 | def loss_function(self,
134 | *args,
135 | **kwargs) -> dict:
136 | recons = args[0]
137 | input = args[1]
138 | mu = args[2]
139 | log_var = args[3]
140 |
141 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
142 | recons_loss = F.mse_loss(recons, input)
143 |
144 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
145 |
146 | loss = recons_loss + kld_weight * kld_loss
147 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
148 |
149 | def sample(self,
150 | num_samples:int,
151 | current_device: int,
152 | **kwargs) -> Tensor:
153 | """
154 | Samples from the latent space and return the corresponding
155 | image space map.
156 | :param num_samples: (Int) Number of samples
157 | :param current_device: (Int) Device to run the model
158 | :return: (Tensor)
159 | """
160 | y = kwargs['labels'].float()
161 | z = torch.randn(num_samples,
162 | self.latent_dim)
163 |
164 | z = z.to(current_device)
165 |
166 | z = torch.cat([z, y], dim=1)
167 | samples = self.decode(z)
168 | return samples
169 |
170 | def generate(self, x: Tensor, **kwargs) -> Tensor:
171 | """
172 | Given an input image x, returns the reconstructed image
173 | :param x: (Tensor) [B x C x H x W]
174 | :return: (Tensor) [B x C x H x W]
175 | """
176 |
177 | return self.forward(x, **kwargs)[0]
--------------------------------------------------------------------------------
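A usage sketch for ConditionalVAE (assuming, as elsewhere in this repository, that models/__init__.py exports the class): both forward() and sample() expect one-hot labels under the labels keyword.

import torch
import torch.nn.functional as F
from models import ConditionalVAE

model = ConditionalVAE(in_channels=3, num_classes=10, latent_dim=128)
x = torch.randn(4, 3, 64, 64)
y = F.one_hot(torch.tensor([0, 1, 2, 3]), num_classes=10)

recons, _, mu, log_var = model(x, labels=y)   # [4 x 3 x 64 x 64]
samples = model.sample(4, x.device, labels=y) # [4 x 3 x 64 x 64]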
/models/dfcvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torchvision.models import vgg19_bn
5 | from torch.nn import functional as F
6 | from .types_ import *
7 |
8 |
9 | class DFCVAE(BaseVAE):
10 |
11 | def __init__(self,
12 | in_channels: int,
13 | latent_dim: int,
14 | hidden_dims: List = None,
15 | alpha:float = 1,
16 | beta:float = 0.5,
17 | **kwargs) -> None:
18 | super(DFCVAE, self).__init__()
19 |
20 | self.latent_dim = latent_dim
21 | self.alpha = alpha
22 | self.beta = beta
23 |
24 | modules = []
25 | if hidden_dims is None:
26 | hidden_dims = [32, 64, 128, 256, 512]
27 |
28 | # Build Encoder
29 | for h_dim in hidden_dims:
30 | modules.append(
31 | nn.Sequential(
32 | nn.Conv2d(in_channels, out_channels=h_dim,
33 | kernel_size= 3, stride= 2, padding = 1),
34 | nn.BatchNorm2d(h_dim),
35 | nn.LeakyReLU())
36 | )
37 | in_channels = h_dim
38 |
39 | self.encoder = nn.Sequential(*modules)
40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
42 |
43 |
44 | # Build Decoder
45 | modules = []
46 |
47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
48 |
49 | hidden_dims.reverse()
50 |
51 | for i in range(len(hidden_dims) - 1):
52 | modules.append(
53 | nn.Sequential(
54 | nn.ConvTranspose2d(hidden_dims[i],
55 | hidden_dims[i + 1],
56 | kernel_size=3,
57 | stride = 2,
58 | padding=1,
59 | output_padding=1),
60 | nn.BatchNorm2d(hidden_dims[i + 1]),
61 | nn.LeakyReLU())
62 | )
63 |
64 |
65 |
66 | self.decoder = nn.Sequential(*modules)
67 |
68 | self.final_layer = nn.Sequential(
69 | nn.ConvTranspose2d(hidden_dims[-1],
70 | hidden_dims[-1],
71 | kernel_size=3,
72 | stride=2,
73 | padding=1,
74 | output_padding=1),
75 | nn.BatchNorm2d(hidden_dims[-1]),
76 | nn.LeakyReLU(),
77 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
78 | kernel_size= 3, padding= 1),
79 | nn.Tanh())
80 |
81 | self.feature_network = vgg19_bn(pretrained=True)
82 |
83 | # Freeze the pretrained feature network
84 | for param in self.feature_network.parameters():
85 | param.requires_grad = False
86 |
87 | self.feature_network.eval()
88 |
89 |
90 | def encode(self, input: Tensor) -> List[Tensor]:
91 | """
92 | Encodes the input by passing through the encoder network
93 | and returns the latent codes.
94 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
95 | :return: (Tensor) List of latent codes
96 | """
97 | result = self.encoder(input)
98 | result = torch.flatten(result, start_dim=1)
99 |
100 | # Split the result into mu and var components
101 | # of the latent Gaussian distribution
102 | mu = self.fc_mu(result)
103 | log_var = self.fc_var(result)
104 |
105 | return [mu, log_var]
106 |
107 | def decode(self, z: Tensor) -> Tensor:
108 | """
109 | Maps the given latent codes
110 | onto the image space.
111 | :param z: (Tensor) [B x D]
112 | :return: (Tensor) [B x C x H x W]
113 | """
114 | result = self.decoder_input(z)
115 | result = result.view(-1, 512, 2, 2)
116 | result = self.decoder(result)
117 | result = self.final_layer(result)
118 | return result
119 |
120 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
121 | """
122 | Reparameterization trick to sample from N(mu, var) using
123 | an auxiliary N(0,1) sample.
124 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
125 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
126 | :return: (Tensor) [B x D]
127 | """
128 | std = torch.exp(0.5 * logvar)
129 | eps = torch.randn_like(std)
130 | return eps * std + mu
131 |
132 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
133 | mu, log_var = self.encode(input)
134 | z = self.reparameterize(mu, log_var)
135 | recons = self.decode(z)
136 |
137 | recons_features = self.extract_features(recons)
138 | input_features = self.extract_features(input)
139 |
140 | return [recons, input, recons_features, input_features, mu, log_var]
141 |
142 | def extract_features(self,
143 | input: Tensor,
144 | feature_layers: List = None) -> List[Tensor]:
145 | """
146 | Extracts the features from the pretrained model
147 | at the layers indicated by feature_layers.
148 | :param input: (Tensor) [B x C x H x W]
149 | :param feature_layers: List of string of IDs
150 | :return: List of the extracted features
151 | """
152 | if feature_layers is None:
153 | feature_layers = ['14', '24', '34', '43']
154 | features = []
155 | result = input
156 | for (key, module) in self.feature_network.features._modules.items():
157 | result = module(result)
158 | if key in feature_layers:
159 | features.append(result)
160 |
161 | return features
162 |
163 | def loss_function(self,
164 | *args,
165 | **kwargs) -> dict:
166 | """
167 | Computes the VAE loss function.
168 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
169 | :param args:
170 | :param kwargs:
171 | :return:
172 | """
173 | recons = args[0]
174 | input = args[1]
175 | recons_features = args[2]
176 | input_features = args[3]
177 | mu = args[4]
178 | log_var = args[5]
179 |
180 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
181 | recons_loss = F.mse_loss(recons, input)
182 |
183 | feature_loss = 0.0
184 | for (r, i) in zip(recons_features, input_features):
185 | feature_loss += F.mse_loss(r, i)
186 |
187 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
188 |
189 | loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss
190 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
191 |
192 | def sample(self,
193 | num_samples:int,
194 | current_device: int, **kwargs) -> Tensor:
195 | """
196 | Samples from the latent space and return the corresponding
197 | image space map.
198 | :param num_samples: (Int) Number of samples
199 | :param current_device: (Int) Device to run the model
200 | :return: (Tensor)
201 | """
202 | z = torch.randn(num_samples,
203 | self.latent_dim)
204 |
205 | z = z.to(current_device)
206 |
207 | samples = self.decode(z)
208 | return samples
209 |
210 | def generate(self, x: Tensor, **kwargs) -> Tensor:
211 | """
212 | Given an input image x, returns the reconstructed image
213 | :param x: (Tensor) [B x C x H x W]
214 | :return: (Tensor) [B x C x H x W]
215 | """
216 |
217 | return self.forward(x)[0]
--------------------------------------------------------------------------------
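The deep-feature-consistent idea in miniature (a standalone sketch, not repository code; DFCVAE above taps several intermediate VGG layers rather than the final feature map): a frozen pretrained network maps both the input and its reconstruction to feature space, and the mismatch there is penalised alongside the pixel loss.

import torch
import torch.nn.functional as F
from torchvision.models import vgg19_bn

feature_net = vgg19_bn(pretrained=True).features.eval()
for p in feature_net.parameters():
    p.requires_grad = False  # the feature extractor stays frozen

x = torch.randn(2, 3, 64, 64)      # stand-in input batch
x_hat = torch.randn(2, 3, 64, 64)  # stand-in reconstruction

feature_loss = F.mse_loss(feature_net(x_hat), feature_net(x))
pixel_loss = F.mse_loss(x_hat, x)
loss = pixel_loss + feature_loss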
/models/dip_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class DIPVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | lambda_diag: float = 10.,
15 | lambda_offdiag: float = 5.,
16 | **kwargs) -> None:
17 | super(DIPVAE, self).__init__()
18 |
19 | self.latent_dim = latent_dim
20 | self.lambda_diag = lambda_diag
21 | self.lambda_offdiag = lambda_offdiag
22 |
23 | modules = []
24 | if hidden_dims is None:
25 | hidden_dims = [32, 64, 128, 256, 512]
26 |
27 | # Build Encoder
28 | for h_dim in hidden_dims:
29 | modules.append(
30 | nn.Sequential(
31 | nn.Conv2d(in_channels, out_channels=h_dim,
32 | kernel_size= 3, stride= 2, padding = 1),
33 | nn.BatchNorm2d(h_dim),
34 | nn.LeakyReLU())
35 | )
36 | in_channels = h_dim
37 |
38 | self.encoder = nn.Sequential(*modules)
39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
41 |
42 |
43 | # Build Decoder
44 | modules = []
45 |
46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
47 |
48 | hidden_dims.reverse()
49 |
50 | for i in range(len(hidden_dims) - 1):
51 | modules.append(
52 | nn.Sequential(
53 | nn.ConvTranspose2d(hidden_dims[i],
54 | hidden_dims[i + 1],
55 | kernel_size=3,
56 | stride = 2,
57 | padding=1,
58 | output_padding=1),
59 | nn.BatchNorm2d(hidden_dims[i + 1]),
60 | nn.LeakyReLU())
61 | )
62 |
63 | self.decoder = nn.Sequential(*modules)
64 |
65 | self.final_layer = nn.Sequential(
66 | nn.ConvTranspose2d(hidden_dims[-1],
67 | hidden_dims[-1],
68 | kernel_size=3,
69 | stride=2,
70 | padding=1,
71 | output_padding=1),
72 | nn.BatchNorm2d(hidden_dims[-1]),
73 | nn.LeakyReLU(),
74 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
75 | kernel_size= 3, padding= 1),
76 | nn.Tanh())
77 |
78 | def encode(self, input: Tensor) -> List[Tensor]:
79 | """
80 | Encodes the input by passing through the encoder network
81 | and returns the latent codes.
82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
83 | :return: (Tensor) List of latent codes
84 | """
85 | result = self.encoder(input)
86 | result = torch.flatten(result, start_dim=1)
87 |
88 | # Split the result into mu and var components
89 | # of the latent Gaussian distribution
90 | mu = self.fc_mu(result)
91 | log_var = self.fc_var(result)
92 |
93 | return [mu, log_var]
94 |
95 | def decode(self, z: Tensor) -> Tensor:
96 | """
97 | Maps the given latent codes
98 | onto the image space.
99 | :param z: (Tensor) [B x D]
100 | :return: (Tensor) [B x C x H x W]
101 | """
102 | result = self.decoder_input(z)
103 | result = result.view(-1, 512, 2, 2)
104 | result = self.decoder(result)
105 | result = self.final_layer(result)
106 | return result
107 |
108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
109 | """
110 | Reparameterization trick to sample from N(mu, var) using
111 | an auxiliary N(0,1) sample.
112 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
113 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
114 | :return: (Tensor) [B x D]
115 | """
116 | std = torch.exp(0.5 * logvar)
117 | eps = torch.randn_like(std)
118 | return eps * std + mu
119 |
120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
121 | mu, log_var = self.encode(input)
122 | z = self.reparameterize(mu, log_var)
123 | return [self.decode(z), input, mu, log_var]
124 |
125 | def loss_function(self,
126 | *args,
127 | **kwargs) -> dict:
128 | """
129 | Computes the VAE loss function.
130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
131 | :param args:
132 | :param kwargs:
133 | :return:
134 | """
135 | recons = args[0]
136 | input = args[1]
137 | mu = args[2]
138 | log_var = args[3]
139 |
140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
141 | recons_loss = F.mse_loss(recons, input, reduction='sum')
142 |
143 |
144 | kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
145 |
146 | # DIP Loss
147 | centered_mu = mu - mu.mean(dim=1, keepdim=True) # [B x D]
148 | cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D x D]
149 |
150 | # Add the expected variance for DIP Loss II
151 | cov_z = cov_mu + torch.mean(torch.diagonal((2. * log_var).exp(), dim1=0), dim=0) # [D x D]
152 | # For DIP Loss I
153 | # cov_z = cov_mu
154 |
155 | cov_diag = torch.diag(cov_z) # [D]
156 | cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D]
157 | dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
158 | self.lambda_diag * torch.sum((cov_diag - 1) ** 2)
159 |
160 | loss = recons_loss + kld_weight * kld_loss + dip_loss
161 | return {'loss': loss,
162 | 'Reconstruction_Loss':recons_loss,
163 | 'KLD':-kld_loss,
164 | 'DIP_Loss':dip_loss}
165 |
166 | def sample(self,
167 | num_samples:int,
168 | current_device: int, **kwargs) -> Tensor:
169 | """
170 | Samples from the latent space and return the corresponding
171 | image space map.
172 | :param num_samples: (Int) Number of samples
173 | :param current_device: (Int) Device to run the model
174 | :return: (Tensor)
175 | """
176 | z = torch.randn(num_samples,
177 | self.latent_dim)
178 |
179 | z = z.to(current_device)
180 |
181 | samples = self.decode(z)
182 | return samples
183 |
184 | def generate(self, x: Tensor, **kwargs) -> Tensor:
185 | """
186 | Given an input image x, returns the reconstructed image
187 | :param x: (Tensor) [B x C x H x W]
188 | :return: (Tensor) [B x C x H x W]
189 | """
190 |
191 | return self.forward(x)[0]
--------------------------------------------------------------------------------
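The DIP regulariser in isolation (a standalone sketch, not repository code; here the encoder means are centred over the batch dimension, which is how the DIP-VAE paper defines the covariance): off-diagonal entries of the latent covariance are pushed to zero, diagonal entries towards one.

import torch

lambda_diag, lambda_offdiag = 10., 5.   # same defaults as DIPVAE above

mu = torch.randn(16, 10)                # [B x D] encoder means
centered = mu - mu.mean(dim=0, keepdim=True)
cov = centered.t() @ centered / (mu.size(0) - 1)  # [D x D] sample covariance

diag = torch.diag(cov)                  # [D]
offdiag = cov - torch.diag(diag)        # [D x D] with zeroed diagonal
dip_loss = lambda_offdiag * (offdiag ** 2).sum() + \
           lambda_diag * ((diag - 1) ** 2).sum()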
/models/fvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class FactorVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | gamma: float = 40.,
15 | **kwargs) -> None:
16 | super(FactorVAE, self).__init__()
17 |
18 | self.latent_dim = latent_dim
19 | self.gamma = gamma
20 |
21 | modules = []
22 | if hidden_dims is None:
23 | hidden_dims = [32, 64, 128, 256, 512]
24 |
25 | # Build Encoder
26 | for h_dim in hidden_dims:
27 | modules.append(
28 | nn.Sequential(
29 | nn.Conv2d(in_channels, out_channels=h_dim,
30 | kernel_size= 3, stride= 2, padding = 1),
31 | nn.BatchNorm2d(h_dim),
32 | nn.LeakyReLU())
33 | )
34 | in_channels = h_dim
35 |
36 | self.encoder = nn.Sequential(*modules)
37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
39 |
40 |
41 | # Build Decoder
42 | modules = []
43 |
44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
45 |
46 | hidden_dims.reverse()
47 |
48 | for i in range(len(hidden_dims) - 1):
49 | modules.append(
50 | nn.Sequential(
51 | nn.ConvTranspose2d(hidden_dims[i],
52 | hidden_dims[i + 1],
53 | kernel_size=3,
54 | stride = 2,
55 | padding=1,
56 | output_padding=1),
57 | nn.BatchNorm2d(hidden_dims[i + 1]),
58 | nn.LeakyReLU())
59 | )
60 |
61 |
62 |
63 | self.decoder = nn.Sequential(*modules)
64 |
65 | self.final_layer = nn.Sequential(
66 | nn.ConvTranspose2d(hidden_dims[-1],
67 | hidden_dims[-1],
68 | kernel_size=3,
69 | stride=2,
70 | padding=1,
71 | output_padding=1),
72 | nn.BatchNorm2d(hidden_dims[-1]),
73 | nn.LeakyReLU(),
74 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
75 | kernel_size= 3, padding= 1),
76 | nn.Tanh())
77 |
78 | # Discriminator network for the Total Correlation (TC) loss
79 | self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),
80 | nn.BatchNorm1d(1000),
81 | nn.LeakyReLU(0.2),
82 | nn.Linear(1000, 1000),
83 | nn.BatchNorm1d(1000),
84 | nn.LeakyReLU(0.2),
85 | nn.Linear(1000, 1000),
86 | nn.BatchNorm1d(1000),
87 | nn.LeakyReLU(0.2),
88 | nn.Linear(1000, 2))
89 | self.D_z_reserve = None
90 |
91 |
92 | def encode(self, input: Tensor) -> List[Tensor]:
93 | """
94 | Encodes the input by passing through the encoder network
95 | and returns the latent codes.
96 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
97 | :return: (Tensor) List of latent codes
98 | """
99 | result = self.encoder(input)
100 | result = torch.flatten(result, start_dim=1)
101 |
102 | # Split the result into mu and var components
103 | # of the latent Gaussian distribution
104 | mu = self.fc_mu(result)
105 | log_var = self.fc_var(result)
106 |
107 | return [mu, log_var]
108 |
109 | def decode(self, z: Tensor) -> Tensor:
110 | """
111 | Maps the given latent codes
112 | onto the image space.
113 | :param z: (Tensor) [B x D]
114 | :return: (Tensor) [B x C x H x W]
115 | """
116 | result = self.decoder_input(z)
117 | result = result.view(-1, 512, 2, 2)
118 | result = self.decoder(result)
119 | result = self.final_layer(result)
120 | return result
121 |
122 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
123 | """
124 | Reparameterization trick to sample from N(mu, var) using
125 | an auxiliary N(0,1) sample.
126 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
127 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
128 | :return: (Tensor) [B x D]
129 | """
130 | std = torch.exp(0.5 * logvar)
131 | eps = torch.randn_like(std)
132 | return eps * std + mu
133 |
134 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
135 | mu, log_var = self.encode(input)
136 | z = self.reparameterize(mu, log_var)
137 | return [self.decode(z), input, mu, log_var, z]
138 |
139 | def permute_latent(self, z: Tensor) -> Tensor:
140 | """
141 | Permutes each of the latent codes in the batch
142 | :param z: [B x D]
143 | :return: [B x D]
144 | """
145 | B, D = z.size()
146 |
147 | # Build shuffled indices for each latent code in the batch
148 | inds = torch.cat([(D * i) + torch.randperm(D) for i in range(B)])
149 | return z.view(-1)[inds].view(B, D)
150 |
151 | def loss_function(self,
152 | *args,
153 | **kwargs) -> dict:
154 | """
155 | Computes the VAE loss function.
156 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
157 | :param args:
158 | :param kwargs:
159 | :return:
160 | """
161 | recons = args[0]
162 | input = args[1]
163 | mu = args[2]
164 | log_var = args[3]
165 | z = args[4]
166 |
167 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
168 | optimizer_idx = kwargs['optimizer_idx']
169 |
170 | # Update the VAE
171 | if optimizer_idx == 0:
172 | recons_loss = F.mse_loss(recons, input)
173 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
174 |
175 | self.D_z_reserve = self.discriminator(z)
176 | vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()
177 |
178 | loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss
179 |
180 | # print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}')
181 | return {'loss': loss,
182 | 'Reconstruction_Loss':recons_loss,
183 | 'KLD':-kld_loss,
184 | 'VAE_TC_Loss': vae_tc_loss}
185 |
186 | # Update the Discriminator
187 | elif optimizer_idx == 1:
188 | device = input.device
189 | true_labels = torch.ones(input.size(0), dtype= torch.long,
190 | requires_grad=False).to(device)
191 | false_labels = torch.zeros(input.size(0), dtype= torch.long,
192 | requires_grad=False).to(device)
193 |
194 | z = z.detach() # Detach so that VAE is not trained again
195 | z_perm = self.permute_latent(z)
196 | D_z_perm = self.discriminator(z_perm)
197 | D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) +
198 | F.cross_entropy(D_z_perm, true_labels))
199 | # print(f'D_TC: {D_tc_loss}')
200 | return {'loss': D_tc_loss,
201 | 'D_TC_Loss':D_tc_loss}
202 |
203 | def sample(self,
204 | num_samples:int,
205 | current_device: int, **kwargs) -> Tensor:
206 | """
207 | Samples from the latent space and return the corresponding
208 | image space map.
209 | :param num_samples: (Int) Number of samples
210 | :param current_device: (Int) Device to run the model
211 | :return: (Tensor)
212 | """
213 | z = torch.randn(num_samples,
214 | self.latent_dim)
215 |
216 | z = z.to(current_device)
217 |
218 | samples = self.decode(z)
219 | return samples
220 |
221 | def generate(self, x: Tensor, **kwargs) -> Tensor:
222 | """
223 | Given an input image x, returns the reconstructed image
224 | :param x: (Tensor) [B x C x H x W]
225 | :return: (Tensor) [B x C x H x W]
226 | """
227 |
228 | return self.forward(x)[0]
--------------------------------------------------------------------------------
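The density-ratio trick behind FactorVAE, in miniature (a standalone sketch, not repository code): the discriminator is trained to tell joint samples of q(z) from samples whose dimensions have been independently permuted, and its logits give the VAE's total-correlation estimate. Note that this sketch uses the across-batch permutation described in the FactorVAE paper, whereas permute_latent above shuffles the dimensions within each sample.

import torch
import torch.nn.functional as F

def permute_dims(z):
    # Permute each latent dimension independently across the batch
    B, D = z.size()
    return torch.stack([z[torch.randperm(B), d] for d in range(D)], dim=1)

D_net = torch.nn.Sequential(torch.nn.Linear(10, 100), torch.nn.LeakyReLU(0.2),
                            torch.nn.Linear(100, 2))
z = torch.randn(8, 10)

logits = D_net(z)
vae_tc_loss = (logits[:, 0] - logits[:, 1]).mean()   # TC estimate for the VAE

z_perm = permute_dims(z.detach())
ones = torch.ones(8, dtype=torch.long)
zeros = torch.zeros(8, dtype=torch.long)
d_tc_loss = 0.5 * (F.cross_entropy(D_net(z.detach()), zeros) +
                   F.cross_entropy(D_net(z_perm), ones))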
/models/gamma_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.distributions import Gamma
5 | from torch.nn import functional as F
6 | from .types_ import *
7 | import torch.nn.init as init
8 |
9 |
10 | class GammaVAE(BaseVAE):
11 |
12 | def __init__(self,
13 | in_channels: int,
14 | latent_dim: int,
15 | hidden_dims: List = None,
16 | gamma_shape: float = 8.,
17 | prior_shape: float = 2.0,
18 | prior_rate: float = 1.,
19 | **kwargs) -> None:
20 | super(GammaVAE, self).__init__()
21 | self.latent_dim = latent_dim
22 | self.B = gamma_shape
23 |
24 | self.prior_alpha = torch.tensor([prior_shape])
25 | self.prior_beta = torch.tensor([prior_rate])
26 |
27 | modules = []
28 | if hidden_dims is None:
29 | hidden_dims = [32, 64, 128, 256, 512]
30 |
31 | # Build Encoder
32 | for h_dim in hidden_dims:
33 | modules.append(
34 | nn.Sequential(
35 | nn.Conv2d(in_channels, out_channels=h_dim,
36 | kernel_size=3, stride=2, padding=1),
37 | nn.BatchNorm2d(h_dim),
38 | nn.LeakyReLU())
39 | )
40 | in_channels = h_dim
41 |
42 | self.encoder = nn.Sequential(*modules)
43 | self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
44 | nn.Softmax(dim=1))
45 | self.fc_var = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
46 | nn.Softmax(dim=1))
47 |
48 | # Build Decoder
49 | modules = []
50 |
51 | self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims[-1] * 4))
52 |
53 | hidden_dims.reverse()
54 |
55 | for i in range(len(hidden_dims) - 1):
56 | modules.append(
57 | nn.Sequential(
58 | nn.ConvTranspose2d(hidden_dims[i],
59 | hidden_dims[i + 1],
60 | kernel_size=3,
61 | stride=2,
62 | padding=1,
63 | output_padding=1),
64 | nn.BatchNorm2d(hidden_dims[i + 1]),
65 | nn.LeakyReLU())
66 | )
67 |
68 | self.decoder = nn.Sequential(*modules)
69 |
70 | self.final_layer = nn.Sequential(
71 | nn.ConvTranspose2d(hidden_dims[-1],
72 | hidden_dims[-1],
73 | kernel_size=3,
74 | stride=2,
75 | padding=1,
76 | output_padding=1),
77 | nn.BatchNorm2d(hidden_dims[-1]),
78 | nn.LeakyReLU(),
79 | nn.Conv2d(hidden_dims[-1], out_channels=3,
80 | kernel_size=3, padding=1),
81 | nn.Sigmoid())
82 |
83 | self.weight_init()
84 |
85 | def weight_init(self):
86 |
87 | # print(self._modules)
88 | for block in self._modules:
89 | for m in self._modules[block]:
90 | init_(m)
91 |
92 | def encode(self, input: Tensor) -> List[Tensor]:
93 | """
94 | Encodes the input by passing through the encoder network
95 | and returns the latent codes.
96 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
97 | :return: (Tensor) List of latent codes
98 | """
99 | result = self.encoder(input)
100 | result = torch.flatten(result, start_dim=1)
101 |
102 | # Split the result into mu and var components
103 | # of the latent Gaussian distribution
104 | alpha = self.fc_mu(result)
105 | beta = self.fc_var(result)
106 |
107 | return [alpha, beta]
108 |
109 | def decode(self, z: Tensor) -> Tensor:
110 | result = self.decoder_input(z)
111 | result = result.view(-1, 512, 2, 2)
112 | result = self.decoder(result)
113 | result = self.final_layer(result)
114 | return result
115 |
116 | def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
117 | """
118 | Reparameterize the Gamma distribution by the shape augmentation trick.
119 | Reference:
120 | [1] https://arxiv.org/pdf/1610.05683.pdf
121 |
122 | :param alpha: (Tensor) Shape parameter of the latent Gamma
123 | :param beta: (Tensor) Rate parameter of the latent Gamma
124 | :return:
125 | """
126 | # Sample from Gamma to guarantee acceptance
127 | alpha_ = alpha.clone().detach()
128 | z_hat = Gamma(alpha_ + self.B, torch.ones_like(alpha_)).sample()
129 |
130 | # Compute the eps ~ N(0,1) that produces z_hat
131 | eps = self.inv_h_func(alpha + self.B, z_hat)
132 | z = self.h_func(alpha + self.B, eps)
133 |
134 | # When beta != 1, scale by beta
135 | return z / beta
136 |
137 | def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
138 | """
139 | Reparameterize a sample eps ~ N(0, 1) so that h(eps) ~ Gamma(alpha, 1)
140 | :param alpha: (Tensor) Shape parameter
141 | :param eps: (Tensor) Random sample to reparameterize
142 | :return: (Tensor)
143 | """
144 |
145 | z = (alpha - 1./3.) * (1 + eps / torch.sqrt(9. * alpha - 3.))**3
146 | return z
147 |
148 | def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:
149 | """
150 | Inverse reparameterize the given z into eps.
151 | :param alpha: (Tensor)
152 | :param z: (Tensor)
153 | :return: (Tensor)
154 | """
155 | eps = torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1./3.))**(1. / 3.) - 1.)
156 | return eps
157 |
158 | def forward(self, input: Tensor, **kwargs) -> Tensor:
159 | alpha, beta = self.encode(input)
160 | z = self.reparameterize(alpha, beta)
161 | return [self.decode(z), input, alpha, beta]
162 |
163 | # def I_function(self, alpha_p, beta_p, alpha_q, beta_q):
164 | # return - (alpha_q * beta_q) / alpha_p - \
165 | # beta_p * torch.log(alpha_p) - torch.lgamma(beta_p) + \
166 | # (beta_p - 1) * torch.digamma(beta_q) + \
167 | # (beta_p - 1) * torch.log(alpha_q)
168 | def I_function(self, a, b, c, d):
169 | return - c * d / a - b * torch.log(a) - torch.lgamma(b) + (b - 1) * (torch.digamma(d) + torch.log(c))
170 |
171 | def vae_gamma_kl_loss(self, a, b, c, d):
172 | """
173 | https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions
174 | b and d are Gamma shape parameters and
175 | a and c are scale parameters.
176 | (All, therefore, must be positive.)
177 | """
178 |
179 | a = 1 / a
180 | c = 1 / c
181 | losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d)
182 | return torch.sum(losses, dim=1)
183 |
184 | def loss_function(self,
185 | *args,
186 | **kwargs) -> dict:
187 | recons = args[0]
188 | input = args[1]
189 | alpha = args[2]
190 | beta = args[3]
191 |
192 | curr_device = input.device
193 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
194 | recons_loss = torch.mean(F.mse_loss(recons, input, reduction='none'), dim=(1, 2, 3))
195 |
196 | # https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions
197 | # alpha = 1./ alpha
198 |
199 |
200 | self.prior_alpha = self.prior_alpha.to(curr_device)
201 | self.prior_beta = self.prior_beta.to(curr_device)
202 |
203 | # kld_loss = - self.I_function(alpha, beta, self.prior_alpha, self.prior_beta)
204 |
205 | kld_loss = self.vae_gamma_kl_loss(alpha, beta, self.prior_alpha, self.prior_beta)
206 |
207 | # kld_loss = torch.sum(kld_loss, dim=1)
208 |
209 | loss = recons_loss + kld_loss
210 | loss = torch.mean(loss, dim = 0)
211 | # print(loss, recons_loss, kld_loss)
212 | return {'loss': loss} #, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
213 |
214 | def sample(self,
215 | num_samples:int,
216 | current_device: int, **kwargs) -> Tensor:
217 | """
218 | Samples from the latent space and return the corresponding
219 | image space map.
220 | :param num_samples: (Int) Number of samples
221 | :param current_device: (Int) Device to run the model
222 | :return: (Tensor)
223 | """
224 | z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim))
225 | z = z.squeeze().to(current_device)
226 |
227 | samples = self.decode(z)
228 | return samples
229 |
230 | def generate(self, x: Tensor, **kwargs) -> Tensor:
231 | """
232 | Given an input image x, returns the reconstructed image
233 | :param x: (Tensor) [B x C x H x W]
234 | :return: (Tensor) [B x C x H x W]
235 | """
236 |
237 | return self.forward(x)[0]
238 |
239 | def init_(m):
240 | if isinstance(m, (nn.Linear, nn.Conv2d)):
241 | init.orthogonal_(m.weight)
242 | if m.bias is not None:
243 | m.bias.data.fill_(0)
244 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
245 | m.weight.data.fill_(1)
246 | if m.bias is not None:
247 | m.bias.data.fill_(0)
248 |
--------------------------------------------------------------------------------
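The shape-augmentation transform used by GammaVAE.h_func above can be checked standalone (a sketch, not repository code): h(eps) = (alpha - 1/3) * (1 + eps / sqrt(9*alpha - 3))^3 maps eps ~ N(0, 1) to an approximate Gamma(alpha, 1) sample, and its expectation works out to exactly alpha.

import torch

alpha = torch.tensor(8.0)   # matches the default gamma_shape offset above
eps = torch.randn(100_000)
z = (alpha - 1. / 3.) * (1 + eps / torch.sqrt(9. * alpha - 3.)) ** 3
print(z.mean())             # ~ 8.0, the Gamma(8, 1) mean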
/models/hvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class HVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent1_dim: int,
13 | latent2_dim: int,
14 | hidden_dims: List = None,
15 | img_size:int = 64,
16 | pseudo_input_size: int = 128,
17 | **kwargs) -> None:
18 | super(HVAE, self).__init__()
19 |
20 | self.latent1_dim = latent1_dim
21 | self.latent2_dim = latent2_dim
22 | self.img_size = img_size
23 |
24 | modules = []
25 | if hidden_dims is None:
26 | hidden_dims = [32, 64, 128, 256, 512]
27 | channels = in_channels
28 |
29 | # Build z2 Encoder
30 | for h_dim in hidden_dims:
31 | modules.append(
32 | nn.Sequential(
33 | nn.Conv2d(channels, out_channels=h_dim,
34 | kernel_size= 3, stride= 2, padding = 1),
35 | nn.BatchNorm2d(h_dim),
36 | nn.LeakyReLU())
37 | )
38 | channels = h_dim
39 |
40 | self.encoder_z2_layers = nn.Sequential(*modules)
41 | self.fc_z2_mu = nn.Linear(hidden_dims[-1]*4, latent2_dim)
42 | self.fc_z2_var = nn.Linear(hidden_dims[-1]*4, latent2_dim)
43 | # ========================================================================#
44 | # Build z1 Encoder
45 | self.embed_z2_code = nn.Linear(latent2_dim, img_size * img_size)
46 | self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)
47 |
48 | modules = []
49 | channels = in_channels + 1 # One more channel for the latent code
50 | for h_dim in hidden_dims:
51 | modules.append(
52 | nn.Sequential(
53 | nn.Conv2d(channels, out_channels=h_dim,
54 | kernel_size= 3, stride= 2, padding = 1),
55 | nn.BatchNorm2d(h_dim),
56 | nn.LeakyReLU())
57 | )
58 | channels = h_dim
59 |
60 | self.encoder_z1_layers = nn.Sequential(*modules)
61 | self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim)
62 | self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim)
63 |
64 | #========================================================================#
65 | # Build z2 Decoder
66 | self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim)
67 | self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim)
68 |
69 | # ========================================================================#
70 | # Build z1 Decoder
71 | self.debed_z1_code = nn.Linear(latent1_dim, 1024)
72 | self.debed_z2_code = nn.Linear(latent2_dim, 1024)
73 | modules = []
74 | hidden_dims.reverse()
75 |
76 | for i in range(len(hidden_dims) - 1):
77 | modules.append(
78 | nn.Sequential(
79 | nn.ConvTranspose2d(hidden_dims[i],
80 | hidden_dims[i + 1],
81 | kernel_size=3,
82 | stride = 2,
83 | padding=1,
84 | output_padding=1),
85 | nn.BatchNorm2d(hidden_dims[i + 1]),
86 | nn.LeakyReLU())
87 | )
88 |
89 |
90 |
91 | self.decoder = nn.Sequential(*modules)
92 |
93 | self.final_layer = nn.Sequential(
94 | nn.ConvTranspose2d(hidden_dims[-1],
95 | hidden_dims[-1],
96 | kernel_size=3,
97 | stride=2,
98 | padding=1,
99 | output_padding=1),
100 | nn.BatchNorm2d(hidden_dims[-1]),
101 | nn.LeakyReLU(),
102 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
103 | kernel_size= 3, padding= 1),
104 | nn.Tanh())
105 |
106 | # ========================================================================#
107 | # Pseudo input for the VampPrior
108 | # self.pseudo_input = torch.eye(pseudo_input_size,
109 | # requires_grad=False).view(1, 1, pseudo_input_size, -1)
110 | #
111 | #
112 | # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,
113 | # kernel_size=3, stride=2, padding=1)
114 |
115 | def encode_z2(self, input: Tensor) -> List[Tensor]:
116 | """
117 | Encodes the input by passing through the encoder network
118 | and returns the latent codes.
119 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
120 | :return: (Tensor) List of latent codes
121 | """
122 | result = self.encoder_z2_layers(input)
123 | result = torch.flatten(result, start_dim=1)
124 |
125 | # Split the result into mu and var components
126 | # of the latent Gaussian distribution
127 | z2_mu = self.fc_z2_mu(result)
128 | z2_log_var = self.fc_z2_var(result)
129 |
130 | return [z2_mu, z2_log_var]
131 |
132 | def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
133 | x = self.embed_data(input)
134 | z2 = self.embed_z2_code(z2)
135 | z2 = z2.view(-1, self.img_size, self.img_size).unsqueeze(1)
136 | result = torch.cat([x, z2], dim=1)
137 |
138 | result = self.encoder_z1_layers(result)
139 | result = torch.flatten(result, start_dim=1)
140 | z1_mu = self.fc_z1_mu(result)
141 | z1_log_var = self.fc_z1_var(result)
142 |
143 | return [z1_mu, z1_log_var]
144 |
145 | def encode(self, input: Tensor) -> List[Tensor]:
146 | z2_mu, z2_log_var = self.encode_z2(input)
147 | z2 = self.reparameterize(z2_mu, z2_log_var)
148 |
149 | # z1 ~ q(z1|x, z2)
150 | z1_mu, z1_log_var = self.encode_z1(input, z2)
151 | return [z1_mu, z1_log_var, z2_mu, z2_log_var, z2]
152 |
153 | def decode(self, input: Tensor) -> Tensor:
154 | result = self.decoder(input)
155 | result = self.final_layer(result)
156 | return result
157 |
158 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
159 | """
160 | Reparameterization trick to sample from N(mu, var) using
161 | an auxiliary N(0,1) sample.
162 | :param mu: (Tensor) Mean of the latent Gaussian
163 | :param logvar: (Tensor) Log variance of the latent Gaussian
164 | :return:
165 | """
166 | std = torch.exp(0.5 * logvar)
167 | eps = torch.randn_like(std)
168 | return eps * std + mu
169 |
170 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
171 |
172 | # Encode the input into the latent codes z1 and z2
173 | # z2 ~q(z2 | x)
174 | # z1 ~ q(z1|x, z2)
175 | z1_mu, z1_log_var, z2_mu, z2_log_var, z2 = self.encode(input)
176 | z1 = self.reparameterize(z1_mu, z1_log_var)
177 |
178 | # Reconstruct the image using both the latent codes
179 | # x ~ p(x|z1, z2)
180 | debedded_z1 = self.debed_z1_code(z1)
181 | debedded_z2 = self.debed_z2_code(z2)
182 | result = torch.cat([debedded_z1, debedded_z2], dim=1)
183 | result = result.view(-1, 512, 2, 2)
184 | recons = self.decode(result)
185 |
186 | return [recons,
187 | input,
188 | z1_mu, z1_log_var,
189 | z2_mu, z2_log_var,
190 | z1, z2]
191 |
192 | def loss_function(self,
193 | *args,
194 | **kwargs) -> dict:
195 | recons = args[0]
196 | input = args[1]
197 |
198 | z1_mu = args[2]
199 | z1_log_var = args[3]
200 |
201 | z2_mu = args[4]
202 | z2_log_var = args[5]
203 |
204 | z1= args[6]
205 | z2 = args[7]
206 |
207 | # Reconstruct (decode) z2 into z1
208 | # z1 ~ p(z1|z2) [This for the loss calculation]
209 | z1_p_mu = self.recons_z1_mu(z2)
210 | z1_p_log_var = self.recons_z1_log_var(z2)
211 |
212 |
213 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
214 | recons_loss = F.mse_loss(recons, input)
215 |
216 | z1_kld = torch.mean(-0.5 * torch.sum(1 + z1_log_var - z1_mu ** 2 - z1_log_var.exp(), dim = 1),
217 | dim = 0)
218 | z2_kld = torch.mean(-0.5 * torch.sum(1 + z2_log_var - z2_mu ** 2 - z2_log_var.exp(), dim = 1),
219 | dim = 0)
220 |
221 | z1_p_kld = torch.mean(-0.5 * torch.sum(1 + z1_p_log_var - (z1 - z1_p_mu) ** 2 - z1_p_log_var.exp(),
222 | dim = 1),
223 | dim = 0)
224 |
225 | z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0)
226 |
227 | kld_loss = -(z1_p_kld - z1_kld - z2_kld)
228 | loss = recons_loss + kld_weight * kld_loss
229 | # print(z2_p_kld)
230 |
231 | return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
232 |
233 | def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
234 | z2 = torch.randn(batch_size,
235 | self.latent2_dim)
236 |
237 | z2 = z2.to(current_device)
238 |
239 | z1_mu = self.recons_z1_mu(z2)
240 | z1_log_var = self.recons_z1_log_var(z2)
241 | z1 = self.reparameterize(z1_mu, z1_log_var)
242 |
243 | debedded_z1 = self.debed_z1_code(z1)
244 | debedded_z2 = self.debed_z2_code(z2)
245 |
246 | result = torch.cat([debedded_z1, debedded_z2], dim=1)
247 | result = result.view(-1, 512, 2, 2)
248 | samples = self.decode(result)
249 |
250 | return samples
251 |
252 | def generate(self, x: Tensor, **kwargs) -> Tensor:
253 | """
254 | Given an input image x, returns the reconstructed image
255 | :param x: (Tensor) [B x C x H x W]
256 | :return: (Tensor) [B x C x H x W]
257 | """
258 |
259 | return self.forward(x)[0]
260 |
--------------------------------------------------------------------------------
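A usage sketch for HVAE (assuming models/__init__.py exports the class): the generative path mirrors sample() above, drawing z2 from the standard normal prior, then z1 from the learned conditional p(z1|z2), before decoding both codes into an image.

import torch
from models import HVAE

model = HVAE(in_channels=3, latent1_dim=10, latent2_dim=10)
x = torch.randn(4, 3, 64, 64)

recons = model.generate(x)           # reconstruction, [4 x 3 x 64 x 64]
samples = model.sample(4, x.device)  # prior samples, [4 x 3 x 64 x 64]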
/models/info_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class InfoVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | alpha: float = -0.5,
15 | beta: float = 5.0,
16 | reg_weight: int = 100,
17 | kernel_type: str = 'imq',
18 | latent_var: float = 2.,
19 | **kwargs) -> None:
20 | super(InfoVAE, self).__init__()
21 |
22 | self.latent_dim = latent_dim
23 | self.reg_weight = reg_weight
24 | self.kernel_type = kernel_type
25 | self.z_var = latent_var
26 |
27 | assert alpha <= 0, 'alpha must be negative or zero.'
28 |
29 | self.alpha = alpha
30 | self.beta = beta
31 |
32 | modules = []
33 | if hidden_dims is None:
34 | hidden_dims = [32, 64, 128, 256, 512]
35 |
36 | # Build Encoder
37 | for h_dim in hidden_dims:
38 | modules.append(
39 | nn.Sequential(
40 | nn.Conv2d(in_channels, out_channels=h_dim,
41 | kernel_size= 3, stride= 2, padding = 1),
42 | nn.BatchNorm2d(h_dim),
43 | nn.LeakyReLU())
44 | )
45 | in_channels = h_dim
46 |
47 | self.encoder = nn.Sequential(*modules)
48 | self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
49 | self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)
50 |
51 | # Build Decoder
52 | modules = []
53 |
54 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
55 |
56 | hidden_dims.reverse()
57 |
58 | for i in range(len(hidden_dims) - 1):
59 | modules.append(
60 | nn.Sequential(
61 | nn.ConvTranspose2d(hidden_dims[i],
62 | hidden_dims[i + 1],
63 | kernel_size=3,
64 | stride = 2,
65 | padding=1,
66 | output_padding=1),
67 | nn.BatchNorm2d(hidden_dims[i + 1]),
68 | nn.LeakyReLU())
69 | )
70 |
71 |
72 |
73 | self.decoder = nn.Sequential(*modules)
74 |
75 | self.final_layer = nn.Sequential(
76 | nn.ConvTranspose2d(hidden_dims[-1],
77 | hidden_dims[-1],
78 | kernel_size=3,
79 | stride=2,
80 | padding=1,
81 | output_padding=1),
82 | nn.BatchNorm2d(hidden_dims[-1]),
83 | nn.LeakyReLU(),
84 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
85 | kernel_size= 3, padding= 1),
86 | nn.Tanh())
87 |
88 | def encode(self, input: Tensor) -> List[Tensor]:
89 | """
90 | Encodes the input by passing through the encoder network
91 | and returns the latent codes.
92 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
93 | :return: (Tensor) List of latent codes
94 | """
95 | result = self.encoder(input)
96 | result = torch.flatten(result, start_dim=1)
97 |
98 | # Split the result into mu and var components
99 | # of the latent Gaussian distribution
100 | mu = self.fc_mu(result)
101 | log_var = self.fc_var(result)
102 | return [mu, log_var]
103 |
104 | def decode(self, z: Tensor) -> Tensor:
105 | result = self.decoder_input(z)
106 | result = result.view(-1, 512, 2, 2)
107 | result = self.decoder(result)
108 | result = self.final_layer(result)
109 | return result
110 |
111 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
112 | """
113 | Reparameterization trick to sample from N(mu, var) using
114 | an auxiliary N(0,1) sample.
115 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
116 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
117 | :return: (Tensor) [B x D]
118 | """
119 | std = torch.exp(0.5 * logvar)
120 | eps = torch.randn_like(std)
121 | return eps * std + mu
122 |
123 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
124 | mu, log_var = self.encode(input)
125 | z = self.reparameterize(mu, log_var)
126 | return [self.decode(z), input, z, mu, log_var]
127 |
128 | def loss_function(self,
129 | *args,
130 | **kwargs) -> dict:
131 | recons = args[0]
132 | input = args[1]
133 | z = args[2]
134 | mu = args[3]
135 | log_var = args[4]
136 |
137 | batch_size = input.size(0)
138 | bias_corr = batch_size * (batch_size - 1)
139 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
140 |
141 | recons_loss = F.mse_loss(recons, input)
142 | mmd_loss = self.compute_mmd(z)
143 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
144 |
145 | loss = self.beta * recons_loss + \
146 | (1. - self.alpha) * kld_weight * kld_loss + \
147 | (self.alpha + self.reg_weight - 1.)/bias_corr * mmd_loss
148 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss}
149 |
150 | def compute_kernel(self,
151 | x1: Tensor,
152 | x2: Tensor) -> Tensor:
153 | # Convert the tensors into row and column vectors
154 | D = x1.size(1)
155 | N = x1.size(0)
156 |
157 | x1 = x1.unsqueeze(-2) # Make it into a column tensor
158 | x2 = x2.unsqueeze(-3) # Make it into a row tensor
159 |
160 | """
161 | Usually the below lines are not required, especially in our case,
162 | but this is useful when x1 and x2 have different sizes
163 | along the 0th dimension.
164 | """
165 | x1 = x1.expand(N, N, D)
166 | x2 = x2.expand(N, N, D)
167 |
168 | if self.kernel_type == 'rbf':
169 | result = self.compute_rbf(x1, x2)
170 | elif self.kernel_type == 'imq':
171 | result = self.compute_inv_mult_quad(x1, x2)
172 | else:
173 | raise ValueError('Undefined kernel type.')
174 |
175 | return result
176 |
177 |
178 | def compute_rbf(self,
179 | x1: Tensor,
180 | x2: Tensor,
181 | eps: float = 1e-7) -> Tensor:
182 | """
183 | Computes the RBF Kernel between x1 and x2.
184 | :param x1: (Tensor)
185 | :param x2: (Tensor)
186 | :param eps: (Float)
187 | :return:
188 | """
189 | z_dim = x2.size(-1)
190 | sigma = 2. * z_dim * self.z_var
191 |
192 | result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
193 | return result
194 |
195 | def compute_inv_mult_quad(self,
196 | x1: Tensor,
197 | x2: Tensor,
198 | eps: float = 1e-7) -> Tensor:
199 | """
200 | Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
201 | given by
202 |
203 | k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}
204 | :param x1: (Tensor)
205 | :param x2: (Tensor)
206 | :param eps: (Float)
207 | :return:
208 | """
209 | z_dim = x2.size(-1)
210 | C = 2 * z_dim * self.z_var
211 | kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))
212 |
213 | # Exclude diagonal elements
214 | result = kernel.sum() - kernel.diag().sum()
215 |
216 | return result
217 |
218 | def compute_mmd(self, z: Tensor) -> Tensor:
219 | # Sample from prior (Gaussian) distribution
220 | prior_z = torch.randn_like(z)
221 |
222 | prior_z__kernel = self.compute_kernel(prior_z, prior_z)
223 | z__kernel = self.compute_kernel(z, z)
224 | priorz_z__kernel = self.compute_kernel(prior_z, z)
225 |
226 | mmd = prior_z__kernel.mean() + \
227 | z__kernel.mean() - \
228 | 2 * priorz_z__kernel.mean()
229 | return mmd
230 |
231 | def sample(self,
232 | num_samples:int,
233 | current_device: int, **kwargs) -> Tensor:
234 | """
235 | Samples from the latent space and return the corresponding
236 | image space map.
237 | :param num_samples: (Int) Number of samples
238 | :param current_device: (Int) Device to run the model
239 | :return: (Tensor)
240 | """
241 | z = torch.randn(num_samples,
242 | self.latent_dim)
243 |
244 | z = z.to(current_device)
245 |
246 | samples = self.decode(z)
247 | return samples
248 |
249 | def generate(self, x: Tensor, **kwargs) -> Tensor:
250 | """
251 | Given an input image x, returns the reconstructed image
252 | :param x: (Tensor) [B x C x H x W]
253 | :return: (Tensor) [B x C x H x W]
254 | """
255 |
256 | return self.forward(x)[0]
--------------------------------------------------------------------------------
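The MMD term above is a kernel two-sample statistic between the encoded codes and fresh draws from the N(0, I) prior. As a quick sanity check, the RBF path can be exercised on toy tensors; the following standalone sketch mirrors compute_mmd/compute_rbf, with `z_var` standing in for the model's `self.z_var`:

import torch

def rbf_mmd(z: torch.Tensor, z_var: float = 2.0) -> torch.Tensor:
    # Mirrors compute_mmd + compute_rbf above on a [N x D] batch of codes.
    prior_z = torch.randn_like(z)                  # samples from the N(0, I) prior

    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        sigma = 2. * a.size(-1) * z_var
        # Broadcast to all pairs: [N x N] kernel matrix
        return torch.exp(-((a.unsqueeze(-2) - b.unsqueeze(-3)).pow(2).mean(-1) / sigma))

    return kernel(prior_z, prior_z).mean() + kernel(z, z).mean() \
           - 2 * kernel(prior_z, z).mean()

print(rbf_mmd(torch.randn(64, 128)))               # near zero: z already matches the prior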
/models/iwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class IWAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | num_samples: int = 5,
15 | **kwargs) -> None:
16 | super(IWAE, self).__init__()
17 |
18 | self.latent_dim = latent_dim
19 | self.num_samples = num_samples
20 |
21 | modules = []
22 | if hidden_dims is None:
23 | hidden_dims = [32, 64, 128, 256, 512]
24 |
25 | # Build Encoder
26 | for h_dim in hidden_dims:
27 | modules.append(
28 | nn.Sequential(
29 | nn.Conv2d(in_channels, out_channels=h_dim,
30 | kernel_size= 3, stride= 2, padding = 1),
31 | nn.BatchNorm2d(h_dim),
32 | nn.LeakyReLU())
33 | )
34 | in_channels = h_dim
35 |
36 | self.encoder = nn.Sequential(*modules)
37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
39 |
40 |
41 | # Build Decoder
42 | modules = []
43 |
44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
45 |
46 | hidden_dims.reverse()
47 |
48 | for i in range(len(hidden_dims) - 1):
49 | modules.append(
50 | nn.Sequential(
51 | nn.ConvTranspose2d(hidden_dims[i],
52 | hidden_dims[i + 1],
53 | kernel_size=3,
54 | stride = 2,
55 | padding=1,
56 | output_padding=1),
57 | nn.BatchNorm2d(hidden_dims[i + 1]),
58 | nn.LeakyReLU())
59 | )
60 |
61 |
62 |
63 | self.decoder = nn.Sequential(*modules)
64 |
65 | self.final_layer = nn.Sequential(
66 | nn.ConvTranspose2d(hidden_dims[-1],
67 | hidden_dims[-1],
68 | kernel_size=3,
69 | stride=2,
70 | padding=1,
71 | output_padding=1),
72 | nn.BatchNorm2d(hidden_dims[-1]),
73 | nn.LeakyReLU(),
74 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
75 | kernel_size= 3, padding= 1),
76 | nn.Tanh())
77 |
78 | def encode(self, input: Tensor) -> List[Tensor]:
79 | """
80 | Encodes the input by passing through the encoder network
81 | and returns the latent codes.
82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
83 | :return: (Tensor) List of latent codes
84 | """
85 | result = self.encoder(input)
86 | result = torch.flatten(result, start_dim=1)
87 |
88 | # Split the result into mu and var components
89 | # of the latent Gaussian distribution
90 | mu = self.fc_mu(result)
91 | log_var = self.fc_var(result)
92 |
93 | return [mu, log_var]
94 |
95 | def decode(self, z: Tensor) -> Tensor:
96 | """
97 | Maps the given latent codes of S samples
98 | onto the image space.
99 | :param z: (Tensor) [B x S x D]
100 | :return: (Tensor) [B x S x C x H x W]
101 | """
102 | B, _, _ = z.size()
103 | z = z.view(-1, self.latent_dim) #[BS x D]
104 | result = self.decoder_input(z)
105 | result = result.view(-1, 512, 2, 2)
106 | result = self.decoder(result)
107 | result = self.final_layer(result) #[BS x C x H x W ]
108 | result = result.view([B, -1, result.size(1), result.size(2), result.size(3)]) #[B x S x C x H x W]
109 | return result
110 |
111 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
112 | """
113 | :param mu: (Tensor) Mean of the latent Gaussian
114 |         :param logvar: (Tensor) Log-variance of the latent Gaussian
115 |         :return: (Tensor) Sampled latent codes, z = mu + std * eps
116 | """
117 | std = torch.exp(0.5 * logvar)
118 | eps = torch.randn_like(std)
119 | return eps * std + mu
120 |
121 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
122 | mu, log_var = self.encode(input)
123 | mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
124 | log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
125 |         z = self.reparameterize(mu, log_var) # [B x S x D]
126 |         eps = (z - mu) / torch.exp(0.5 * log_var)  # Recover the N(0, I) noise
127 | return [self.decode(z), input, mu, log_var, z, eps]
128 |
129 | def loss_function(self,
130 | *args,
131 | **kwargs) -> dict:
132 | """
133 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
134 | :param args:
135 | :param kwargs:
136 | :return:
137 | """
138 | recons = args[0]
139 | input = args[1]
140 | mu = args[2]
141 | log_var = args[3]
142 | z = args[4]
143 | eps = args[5]
144 |
145 | input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W]
146 |
147 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
148 |
149 | log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
150 | kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S]
151 | # Get importance weights
152 | log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data
153 |
154 | # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
155 | weight = F.softmax(log_weight, dim = -1)
156 | # kld_loss = torch.mean(kld_loss, dim = 0)
157 |
158 | loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0)
159 |
160 | return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}
161 |
162 | def sample(self,
163 | num_samples:int,
164 | current_device: int, **kwargs) -> Tensor:
165 | """
166 |         Samples from the latent space and returns the corresponding
167 | image space map.
168 | :param num_samples: (Int) Number of samples
169 | :param current_device: (Int) Device to run the model
170 | :return: (Tensor)
171 | """
172 | z = torch.randn(num_samples, 1,
173 | self.latent_dim)
174 |
175 | z = z.to(current_device)
176 |
177 | samples = self.decode(z).squeeze()
178 | return samples
179 |
180 | def generate(self, x: Tensor, **kwargs) -> Tensor:
181 | """
182 | Given an input image x, returns the reconstructed image.
183 | Returns only the first reconstructed sample
184 | :param x: (Tensor) [B x C x H x W]
185 | :return: (Tensor) [B x C x H x W]
186 | """
187 |
188 | return self.forward(x)[0][:, 0, :]
189 |
--------------------------------------------------------------------------------
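The heart of the IWAE objective is the softmax weighting over the S importance samples in loss_function. That reduction in isolation, with random stand-ins for the per-sample reconstruction and KL terms and M_N = 0.005 as used in the tests:

import torch
import torch.nn.functional as F

B, S = 16, 5                                   # batch size, importance samples
log_p_x_z = torch.rand(B, S)                   # stand-in per-sample reconstruction term
kld = torch.rand(B, S)                         # stand-in per-sample KL term
log_weight = log_p_x_z + 0.005 * kld           # kld_weight = M_N

weight = F.softmax(log_weight, dim=-1)         # [B x S]; each row sums to 1
loss = (weight * log_weight).sum(-1).mean()    # same reduction as loss_function above
print(loss)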
/models/logcosh_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from models import BaseVAE
4 | from torch import nn
5 | from .types_ import *
6 |
7 |
8 | class LogCoshVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | alpha: float = 100.,
15 | beta: float = 10.,
16 | **kwargs) -> None:
17 | super(LogCoshVAE, self).__init__()
18 |
19 | self.latent_dim = latent_dim
20 | self.alpha = alpha
21 | self.beta = beta
22 |
23 | modules = []
24 | if hidden_dims is None:
25 | hidden_dims = [32, 64, 128, 256, 512]
26 |
27 | # Build Encoder
28 | for h_dim in hidden_dims:
29 | modules.append(
30 | nn.Sequential(
31 | nn.Conv2d(in_channels, out_channels=h_dim,
32 | kernel_size= 3, stride= 2, padding = 1),
33 | nn.BatchNorm2d(h_dim),
34 | nn.LeakyReLU())
35 | )
36 | in_channels = h_dim
37 |
38 | self.encoder = nn.Sequential(*modules)
39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
41 |
42 |
43 | # Build Decoder
44 | modules = []
45 |
46 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
47 |
48 | hidden_dims.reverse()
49 |
50 | for i in range(len(hidden_dims) - 1):
51 | modules.append(
52 | nn.Sequential(
53 | nn.ConvTranspose2d(hidden_dims[i],
54 | hidden_dims[i + 1],
55 | kernel_size=3,
56 | stride = 2,
57 | padding=1,
58 | output_padding=1),
59 | nn.BatchNorm2d(hidden_dims[i + 1]),
60 | nn.LeakyReLU())
61 | )
62 |
63 | self.decoder = nn.Sequential(*modules)
64 |
65 | self.final_layer = nn.Sequential(
66 | nn.ConvTranspose2d(hidden_dims[-1],
67 | hidden_dims[-1],
68 | kernel_size=3,
69 | stride=2,
70 | padding=1,
71 | output_padding=1),
72 | nn.BatchNorm2d(hidden_dims[-1]),
73 | nn.LeakyReLU(),
74 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
75 | kernel_size= 3, padding= 1),
76 | nn.Tanh())
77 |
78 | def encode(self, input: Tensor) -> List[Tensor]:
79 | """
80 | Encodes the input by passing through the encoder network
81 | and returns the latent codes.
82 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
83 | :return: (Tensor) List of latent codes
84 | """
85 | result = self.encoder(input)
86 | result = torch.flatten(result, start_dim=1)
87 |
88 | # Split the result into mu and var components
89 | # of the latent Gaussian distribution
90 | mu = self.fc_mu(result)
91 | log_var = self.fc_var(result)
92 |
93 | return [mu, log_var]
94 |
95 | def decode(self, z: Tensor) -> Tensor:
96 | """
97 | Maps the given latent codes
98 | onto the image space.
99 | :param z: (Tensor) [B x D]
100 | :return: (Tensor) [B x C x H x W]
101 | """
102 | result = self.decoder_input(z)
103 | result = result.view(-1, 512, 2, 2)
104 | result = self.decoder(result)
105 | result = self.final_layer(result)
106 | return result
107 |
108 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
109 | """
110 |         Reparameterization trick to sample from N(mu, var) using
111 |         N(0, 1).
112 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
113 |         :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
114 | :return: (Tensor) [B x D]
115 | """
116 | std = torch.exp(0.5 * logvar)
117 | eps = torch.randn_like(std)
118 | return eps * std + mu
119 |
120 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
121 | mu, log_var = self.encode(input)
122 | z = self.reparameterize(mu, log_var)
123 | return [self.decode(z), input, mu, log_var]
124 |
125 | def loss_function(self,
126 | *args,
127 | **kwargs) -> dict:
128 | """
129 | Computes the VAE loss function.
130 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
131 | :param args:
132 | :param kwargs:
133 | :return:
134 | """
135 | recons = args[0]
136 | input = args[1]
137 | mu = args[2]
138 | log_var = args[3]
139 |
140 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
141 |         t = recons - input
142 |         # log(cosh(x)) expanded via the exact identity
143 |         # log(cosh(x)) = x + log(1 + exp(-2x)) - log(2),
144 |         # which avoids evaluating cosh directly for large |x|.
145 |
146 |         recons_loss = self.alpha * t + \
147 |                       torch.log(1. + torch.exp(- 2 * self.alpha * t)) - \
148 |                       torch.log(torch.tensor(2.0))
149 |
150 |         recons_loss = (1. / self.alpha) * recons_loss.mean()
151 |
152 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
153 |
154 | loss = recons_loss + self.beta * kld_weight * kld_loss
155 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
156 |
157 | def sample(self,
158 | num_samples:int,
159 | current_device: int, **kwargs) -> Tensor:
160 | """
161 |         Samples from the latent space and returns the corresponding
162 | image space map.
163 | :param num_samples: (Int) Number of samples
164 | :param current_device: (Int) Device to run the model
165 | :return: (Tensor)
166 | """
167 | z = torch.randn(num_samples,
168 | self.latent_dim)
169 |
170 | z = z.to(current_device)
171 |
172 | samples = self.decode(z)
173 | return samples
174 |
175 | def generate(self, x: Tensor, **kwargs) -> Tensor:
176 | """
177 | Given an input image x, returns the reconstructed image
178 | :param x: (Tensor) [B x C x H x W]
179 | :return: (Tensor) [B x C x H x W]
180 | """
181 |
182 | return self.forward(x)[0]
--------------------------------------------------------------------------------
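The expanded reconstruction term above rests on the exact identity log(cosh(x)) = x + log(1 + exp(-2x)) - log(2), which follows from cosh(x) = exp(x) * (1 + exp(-2x)) / 2. A quick numerical check, using |x| in the exponent (the evenness of cosh permits this, and it keeps the exponential bounded):

import torch

alpha = 100.0                                    # the class default
t = torch.linspace(-0.05, 0.05, 11)              # residuals in the post-Tanh range

naive = torch.log(torch.cosh(alpha * t))
stable = alpha * t.abs() \
         + torch.log1p(torch.exp(-2. * alpha * t.abs())) \
         - torch.log(torch.tensor(2.0))
print(torch.allclose(naive, stable, atol=1e-5))  # True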
/models/miwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 | from torch.distributions import Normal
7 |
8 |
9 | class MIWAE(BaseVAE):
10 |
11 | def __init__(self,
12 | in_channels: int,
13 | latent_dim: int,
14 | hidden_dims: List = None,
15 | num_samples: int = 5,
16 | num_estimates: int = 5,
17 | **kwargs) -> None:
18 | super(MIWAE, self).__init__()
19 |
20 | self.latent_dim = latent_dim
21 | self.num_samples = num_samples # K
22 | self.num_estimates = num_estimates # M
23 |
24 | modules = []
25 | if hidden_dims is None:
26 | hidden_dims = [32, 64, 128, 256, 512]
27 |
28 | # Build Encoder
29 | for h_dim in hidden_dims:
30 | modules.append(
31 | nn.Sequential(
32 | nn.Conv2d(in_channels, out_channels=h_dim,
33 | kernel_size= 3, stride= 2, padding = 1),
34 | nn.BatchNorm2d(h_dim),
35 | nn.LeakyReLU())
36 | )
37 | in_channels = h_dim
38 |
39 | self.encoder = nn.Sequential(*modules)
40 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
41 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
42 |
43 |
44 | # Build Decoder
45 | modules = []
46 |
47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
48 |
49 | hidden_dims.reverse()
50 |
51 | for i in range(len(hidden_dims) - 1):
52 | modules.append(
53 | nn.Sequential(
54 | nn.ConvTranspose2d(hidden_dims[i],
55 | hidden_dims[i + 1],
56 | kernel_size=3,
57 | stride = 2,
58 | padding=1,
59 | output_padding=1),
60 | nn.BatchNorm2d(hidden_dims[i + 1]),
61 | nn.LeakyReLU())
62 | )
63 |
64 |
65 |
66 | self.decoder = nn.Sequential(*modules)
67 |
68 | self.final_layer = nn.Sequential(
69 | nn.ConvTranspose2d(hidden_dims[-1],
70 | hidden_dims[-1],
71 | kernel_size=3,
72 | stride=2,
73 | padding=1,
74 | output_padding=1),
75 | nn.BatchNorm2d(hidden_dims[-1]),
76 | nn.LeakyReLU(),
77 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
78 | kernel_size= 3, padding= 1),
79 | nn.Tanh())
80 |
81 | def encode(self, input: Tensor) -> List[Tensor]:
82 | """
83 | Encodes the input by passing through the encoder network
84 | and returns the latent codes.
85 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
86 | :return: (Tensor) List of latent codes
87 | """
88 | result = self.encoder(input)
89 | result = torch.flatten(result, start_dim=1)
90 |
91 | # Split the result into mu and var components
92 | # of the latent Gaussian distribution
93 | mu = self.fc_mu(result)
94 | log_var = self.fc_var(result)
95 |
96 | return [mu, log_var]
97 |
98 | def decode(self, z: Tensor) -> Tensor:
99 | """
100 |         Maps the given latent codes (M estimates of S samples each)
101 |         onto the image space.
102 |         :param z: (Tensor) [B x M x S x D]
103 |         :return: (Tensor) [B x M x S x C x H x W]
104 | """
105 |         B, M, S, D = z.size()
106 | z = z.contiguous().view(-1, self.latent_dim) #[BMS x D]
107 | result = self.decoder_input(z)
108 | result = result.view(-1, 512, 2, 2)
109 | result = self.decoder(result)
110 | result = self.final_layer(result) #[BMS x C x H x W ]
111 | result = result.view([B, M, S,result.size(-3), result.size(-2), result.size(-1)]) #[B x M x S x C x H x W]
112 | return result
113 |
114 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
115 | """
116 | :param mu: (Tensor) Mean of the latent Gaussian
117 |         :param logvar: (Tensor) Log-variance of the latent Gaussian
118 | :return:
119 | """
120 | std = torch.exp(0.5 * logvar)
121 | eps = torch.randn_like(std)
122 | return eps * std + mu
123 |
124 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
125 | mu, log_var = self.encode(input)
126 | mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
127 | log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
128 | z = self.reparameterize(mu, log_var) # [B x M x S x D]
129 |         eps = (z - mu) / torch.exp(0.5 * log_var)  # Recover the N(0, I) noise
130 | return [self.decode(z), input, mu, log_var, z, eps]
131 |
132 | def loss_function(self,
133 | *args,
134 | **kwargs) -> dict:
135 | """
136 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
137 | :param args:
138 | :param kwargs:
139 | :return:
140 | """
141 | recons = args[0]
142 | input = args[1]
143 | mu = args[2]
144 | log_var = args[3]
145 | z = args[4]
146 | eps = args[5]
147 |
148 | input = input.repeat(self.num_estimates,
149 | self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W]
150 |
151 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
152 |
153 | log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S]
154 |
155 | kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S]
156 | # Get importance weights
157 | log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data
158 |
159 | # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
160 | weight = F.softmax(log_weight, dim = -1) # [B x M x S]
161 |
162 | loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0)
163 |
164 | return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}
165 |
166 | def sample(self,
167 | num_samples:int,
168 | current_device: int, **kwargs) -> Tensor:
169 | """
170 |         Samples from the latent space and returns the corresponding
171 | image space map.
172 | :param num_samples: (Int) Number of samples
173 | :param current_device: (Int) Device to run the model
174 | :return: (Tensor)
175 | """
176 | z = torch.randn(num_samples, 1, 1,
177 | self.latent_dim)
178 |
179 | z = z.to(current_device)
180 |
181 | samples = self.decode(z).squeeze()
182 | return samples
183 |
184 | def generate(self, x: Tensor, **kwargs) -> Tensor:
185 | """
186 | Given an input image x, returns the reconstructed image.
187 | Returns only the first reconstructed sample
188 | :param x: (Tensor) [B x C x H x W]
189 | :return: (Tensor) [B x C x H x W]
190 | """
191 |
192 | return self.forward(x)[0][:, 0, 0, :]
193 |
--------------------------------------------------------------------------------
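MIWAE differs from IWAE only in the extra estimate dimension M: the softmax normalizes over the S importance samples, and the resulting M estimates are averaged. The reduction on its own, with a random stand-in for the log weights:

import torch
import torch.nn.functional as F

B, M, S = 16, 5, 5                               # batch, estimates, importance samples
log_weight = torch.rand(B, M, S)                 # stand-in for log_p_x_z + kld_weight * kld_loss

weight = F.softmax(log_weight, dim=-1)           # normalize over the S samples only
per_estimate = (weight * log_weight).sum(-1)     # [B x M]
loss = per_estimate.mean(-1).mean(0)             # average the M estimates, then the batch
print(loss)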
/models/swae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch import distributions as dist
6 | from .types_ import *
7 |
8 |
9 | class SWAE(BaseVAE):
10 |
11 | def __init__(self,
12 | in_channels: int,
13 | latent_dim: int,
14 | hidden_dims: List = None,
15 | reg_weight: int = 100,
16 | wasserstein_deg: float= 2.,
17 | num_projections: int = 50,
18 | projection_dist: str = 'normal',
19 | **kwargs) -> None:
20 | super(SWAE, self).__init__()
21 |
22 | self.latent_dim = latent_dim
23 | self.reg_weight = reg_weight
24 | self.p = wasserstein_deg
25 | self.num_projections = num_projections
26 | self.proj_dist = projection_dist
27 |
28 | modules = []
29 | if hidden_dims is None:
30 | hidden_dims = [32, 64, 128, 256, 512]
31 |
32 | # Build Encoder
33 | for h_dim in hidden_dims:
34 | modules.append(
35 | nn.Sequential(
36 | nn.Conv2d(in_channels, out_channels=h_dim,
37 | kernel_size= 3, stride= 2, padding = 1),
38 | nn.BatchNorm2d(h_dim),
39 | nn.LeakyReLU())
40 | )
41 | in_channels = h_dim
42 |
43 | self.encoder = nn.Sequential(*modules)
44 | self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim)
45 |
46 |
47 | # Build Decoder
48 | modules = []
49 |
50 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
51 |
52 | hidden_dims.reverse()
53 |
54 | for i in range(len(hidden_dims) - 1):
55 | modules.append(
56 | nn.Sequential(
57 | nn.ConvTranspose2d(hidden_dims[i],
58 | hidden_dims[i + 1],
59 | kernel_size=3,
60 | stride = 2,
61 | padding=1,
62 | output_padding=1),
63 | nn.BatchNorm2d(hidden_dims[i + 1]),
64 | nn.LeakyReLU())
65 | )
66 |
67 |
68 |
69 | self.decoder = nn.Sequential(*modules)
70 |
71 | self.final_layer = nn.Sequential(
72 | nn.ConvTranspose2d(hidden_dims[-1],
73 | hidden_dims[-1],
74 | kernel_size=3,
75 | stride=2,
76 | padding=1,
77 | output_padding=1),
78 | nn.BatchNorm2d(hidden_dims[-1]),
79 | nn.LeakyReLU(),
80 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
81 | kernel_size= 3, padding= 1),
82 | nn.Tanh())
83 |
84 | def encode(self, input: Tensor) -> Tensor:
85 | """
86 | Encodes the input by passing through the encoder network
87 | and returns the latent codes.
88 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
89 |         :return: (Tensor) Latent codes [N x D]
90 | """
91 | result = self.encoder(input)
92 | result = torch.flatten(result, start_dim=1)
93 |
94 |         # The SWAE encoder is deterministic: a single fully-connected
95 |         # layer maps the features directly to the latent code z
96 | z = self.fc_z(result)
97 | return z
98 |
99 | def decode(self, z: Tensor) -> Tensor:
100 | result = self.decoder_input(z)
101 | result = result.view(-1, 512, 2, 2)
102 | result = self.decoder(result)
103 | result = self.final_layer(result)
104 | return result
105 |
106 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
107 | z = self.encode(input)
108 | return [self.decode(z), input, z]
109 |
110 | def loss_function(self,
111 | *args,
112 | **kwargs) -> dict:
113 | recons = args[0]
114 | input = args[1]
115 | z = args[2]
116 |
117 | batch_size = input.size(0)
118 | bias_corr = batch_size * (batch_size - 1)
119 | reg_weight = self.reg_weight / bias_corr
120 |
121 | recons_loss_l2 = F.mse_loss(recons, input)
122 | recons_loss_l1 = F.l1_loss(recons, input)
123 |
124 | swd_loss = self.compute_swd(z, self.p, reg_weight)
125 |
126 | loss = recons_loss_l2 + recons_loss_l1 + swd_loss
127 | return {'loss': loss, 'Reconstruction_Loss':(recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss}
128 |
129 | def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
130 | """
131 |         Returns random unit-norm directions on the latent-space sphere,
132 |         used to project both the encoded samples and the prior
133 |         (Gaussian) samples onto one dimension.
134 |
135 | :param latent_dim: (Int) Dimensionality of the latent space (D)
136 | :param num_samples: (Int) Number of samples required (S)
137 | :return: Random projections from the latent unit sphere
138 | """
139 | if self.proj_dist == 'normal':
140 | rand_samples = torch.randn(num_samples, latent_dim)
141 | elif self.proj_dist == 'cauchy':
142 | rand_samples = dist.Cauchy(torch.tensor([0.0]),
143 | torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()
144 | else:
145 | raise ValueError('Unknown projection distribution.')
146 |
147 | rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1,1)
148 | return rand_proj # [S x D]
149 |
150 |
151 | def compute_swd(self,
152 | z: Tensor,
153 | p: float,
154 | reg_weight: float) -> Tensor:
155 | """
156 | Computes the Sliced Wasserstein Distance (SWD) - which consists of
157 | randomly projecting the encoded and prior vectors and computing
158 | their Wasserstein distance along those projections.
159 |
160 | :param z: Latent samples # [N x D]
161 | :param p: Value for the p^th Wasserstein distance
162 | :param reg_weight:
163 | :return:
164 | """
165 | prior_z = torch.randn_like(z) # [N x D]
166 | device = z.device
167 |
168 | proj_matrix = self.get_random_projections(self.latent_dim,
169 | num_samples=self.num_projections).transpose(0,1).to(device)
170 |
171 | latent_projections = z.matmul(proj_matrix) # [N x S]
172 | prior_projections = prior_z.matmul(proj_matrix) # [N x S]
173 |
174 |         # The 1-D Wasserstein distance along each projection is obtained by
175 |         # sorting both sets over the batch and comparing them rank by rank
176 | w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
177 | torch.sort(prior_projections.t(), dim=1)[0]
178 | w_dist = w_dist.pow(p)
179 | return reg_weight * w_dist.mean()
180 |
181 | def sample(self,
182 | num_samples:int,
183 | current_device: int, **kwargs) -> Tensor:
184 | """
185 |         Samples from the latent space and returns the corresponding
186 | image space map.
187 | :param num_samples: (Int) Number of samples
188 | :param current_device: (Int) Device to run the model
189 | :return: (Tensor)
190 | """
191 | z = torch.randn(num_samples,
192 | self.latent_dim)
193 |
194 | z = z.to(current_device)
195 |
196 | samples = self.decode(z)
197 | return samples
198 |
199 | def generate(self, x: Tensor, **kwargs) -> Tensor:
200 | """
201 | Given an input image x, returns the reconstructed image
202 | :param x: (Tensor) [B x C x H x W]
203 | :return: (Tensor) [B x C x H x W]
204 | """
205 |
206 | return self.forward(x)[0]
207 |
--------------------------------------------------------------------------------
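Because the 1-D Wasserstein distance between two equal-size samples reduces to comparing their sorted values, compute_swd needs nothing beyond random unit directions, a matrix multiply, and a sort. The same computation as a self-contained sketch (normal projections only):

import torch

def sliced_wd(z: torch.Tensor, num_projections: int = 50, p: float = 2.) -> torch.Tensor:
    prior_z = torch.randn_like(z)                      # [N x D] prior samples
    proj = torch.randn(num_projections, z.size(1))
    proj = proj / proj.norm(dim=1, keepdim=True)       # [S x D] unit directions

    latent_proj = z @ proj.t()                         # [N x S]
    prior_proj = prior_z @ proj.t()                    # [N x S]

    # Sort along the batch and compare rank by rank, one row per projection
    w = torch.sort(latent_proj.t(), dim=1)[0] - torch.sort(prior_proj.t(), dim=1)[0]
    return w.pow(p).mean()

print(sliced_wd(torch.randn(128, 64)))                 # small: z is already prior-distributed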
/models/twostage_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class TwoStageVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | hidden_dims2: List = None,
15 | **kwargs) -> None:
16 | super(TwoStageVAE, self).__init__()
17 |
18 | self.latent_dim = latent_dim
19 |
20 | modules = []
21 | if hidden_dims is None:
22 | hidden_dims = [32, 64, 128, 256, 512]
23 |
24 | if hidden_dims2 is None:
25 | hidden_dims2 = [1024, 1024]
26 |
27 | # Build Encoder
28 | for h_dim in hidden_dims:
29 | modules.append(
30 | nn.Sequential(
31 | nn.Conv2d(in_channels, out_channels=h_dim,
32 | kernel_size= 3, stride= 2, padding = 1),
33 | nn.BatchNorm2d(h_dim),
34 | nn.LeakyReLU())
35 | )
36 | in_channels = h_dim
37 |
38 | self.encoder = nn.Sequential(*modules)
39 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
40 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
41 |
42 |
43 | # Build Decoder
44 | modules = []
45 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
46 | hidden_dims.reverse()
47 |
48 | for i in range(len(hidden_dims) - 1):
49 | modules.append(
50 | nn.Sequential(
51 | nn.ConvTranspose2d(hidden_dims[i],
52 | hidden_dims[i + 1],
53 | kernel_size=3,
54 | stride = 2,
55 | padding=1,
56 | output_padding=1),
57 | nn.BatchNorm2d(hidden_dims[i + 1]),
58 | nn.LeakyReLU())
59 | )
60 | self.decoder = nn.Sequential(*modules)
61 |
62 | self.final_layer = nn.Sequential(
63 | nn.ConvTranspose2d(hidden_dims[-1],
64 | hidden_dims[-1],
65 | kernel_size=3,
66 | stride=2,
67 | padding=1,
68 | output_padding=1),
69 | nn.BatchNorm2d(hidden_dims[-1]),
70 | nn.LeakyReLU(),
71 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
72 | kernel_size= 3, padding= 1),
73 | nn.Tanh())
74 |
75 | #---------------------- Second VAE ---------------------------#
76 | encoder2 = []
77 | in_channels = self.latent_dim
78 | for h_dim in hidden_dims2:
79 | encoder2.append(nn.Sequential(
80 | nn.Linear(in_channels, h_dim),
81 | nn.BatchNorm1d(h_dim),
82 | nn.LeakyReLU()))
83 | in_channels = h_dim
84 | self.encoder2 = nn.Sequential(*encoder2)
85 | self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim)
86 | self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim)
87 |
88 | decoder2 = []
89 | hidden_dims2.reverse()
90 |
91 | in_channels = self.latent_dim
92 | for h_dim in hidden_dims2:
93 | decoder2.append(nn.Sequential(
94 | nn.Linear(in_channels, h_dim),
95 | nn.BatchNorm1d(h_dim),
96 | nn.LeakyReLU()))
97 | in_channels = h_dim
98 | self.decoder2 = nn.Sequential(*decoder2)
99 |
100 | def encode(self, input: Tensor) -> List[Tensor]:
101 | """
102 | Encodes the input by passing through the encoder network
103 | and returns the latent codes.
104 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
105 | :return: (Tensor) List of latent codes
106 | """
107 | result = self.encoder(input)
108 | result = torch.flatten(result, start_dim=1)
109 |
110 | # Split the result into mu and var components
111 | # of the latent Gaussian distribution
112 | mu = self.fc_mu(result)
113 | log_var = self.fc_var(result)
114 |
115 | return [mu, log_var]
116 |
117 | def decode(self, z: Tensor) -> Tensor:
118 | """
119 | Maps the given latent codes
120 | onto the image space.
121 | :param z: (Tensor) [B x D]
122 | :return: (Tensor) [B x C x H x W]
123 | """
124 | result = self.decoder_input(z)
125 | result = result.view(-1, 512, 2, 2)
126 | result = self.decoder(result)
127 | result = self.final_layer(result)
128 | return result
129 |
130 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
131 | """
132 |         Reparameterization trick to sample from N(mu, var) using
133 |         N(0, 1).
134 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
135 |         :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
136 | :return: (Tensor) [B x D]
137 | """
138 | std = torch.exp(0.5 * logvar)
139 | eps = torch.randn_like(std)
140 | return eps * std + mu
141 |
142 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
143 | mu, log_var = self.encode(input)
144 | z = self.reparameterize(mu, log_var)
145 |
146 | return [self.decode(z), input, mu, log_var]
147 |
148 | def loss_function(self,
149 | *args,
150 | **kwargs) -> dict:
151 | """
152 | Computes the VAE loss function.
153 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
154 | :param args:
155 | :param kwargs:
156 | :return:
157 | """
158 | recons = args[0]
159 | input = args[1]
160 | mu = args[2]
161 | log_var = args[3]
162 |
163 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
164 |         recons_loss = F.mse_loss(recons, input)
165 |
166 |
167 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
168 |
169 | loss = recons_loss + kld_weight * kld_loss
170 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
171 |
172 | def sample(self,
173 | num_samples:int,
174 | current_device: int, **kwargs) -> Tensor:
175 | """
176 |         Samples from the latent space and returns the corresponding
177 | image space map.
178 | :param num_samples: (Int) Number of samples
179 | :param current_device: (Int) Device to run the model
180 | :return: (Tensor)
181 | """
182 | z = torch.randn(num_samples,
183 | self.latent_dim)
184 |
185 | z = z.to(current_device)
186 |
187 | samples = self.decode(z)
188 | return samples
189 |
190 | def generate(self, x: Tensor, **kwargs) -> Tensor:
191 | """
192 | Given an input image x, returns the reconstructed image
193 | :param x: (Tensor) [B x C x H x W]
194 | :return: (Tensor) [B x C x H x W]
195 | """
196 |
197 | return self.forward(x)[0]
--------------------------------------------------------------------------------
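Note that the second-stage MLP built above (encoder2, fc_mu2, fc_var2, decoder2) is never used by forward or loss_function, and decoder2 stops at hidden_dims2[0] without a final projection back to latent_dim. One plausible wiring of that stage, shown here purely as an assumption rather than what this file implements, is a small VAE fitted on the first stage's latent codes:

import torch
from torch import nn

latent_dim, hidden = 128, 1024
z = torch.randn(16, latent_dim)                            # stand-in for first-stage codes

encoder2 = nn.Sequential(nn.Linear(latent_dim, hidden),
                         nn.BatchNorm1d(hidden), nn.LeakyReLU())
fc_mu2 = nn.Linear(hidden, latent_dim)
fc_var2 = nn.Linear(hidden, latent_dim)
decoder2 = nn.Sequential(nn.Linear(latent_dim, hidden),
                         nn.BatchNorm1d(hidden), nn.LeakyReLU(),
                         nn.Linear(hidden, latent_dim))    # final Linear added for this sketch

h = encoder2(z)
mu, log_var = fc_mu2(h), fc_var2(h)
u = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)   # reparameterize
z_recon = decoder2(u)
print(z_recon.shape)                                       # torch.Size([16, 128])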
/models/types_.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable, Union, Any, TypeVar, Tuple
2 | # from torch import tensor as Tensor
3 |
4 | Tensor = TypeVar('Tensor')  # stand-in for torch.Tensor in type hints
5 |
--------------------------------------------------------------------------------
/models/vampvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class VampVAE(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | num_components: int = 50,
15 | **kwargs) -> None:
16 | super(VampVAE, self).__init__()
17 |
18 | self.latent_dim = latent_dim
19 | self.num_components = num_components
20 |
21 | modules = []
22 | if hidden_dims is None:
23 | hidden_dims = [32, 64, 128, 256, 512]
24 |
25 | # Build Encoder
26 | for h_dim in hidden_dims:
27 | modules.append(
28 | nn.Sequential(
29 | nn.Conv2d(in_channels, out_channels=h_dim,
30 | kernel_size= 3, stride= 2, padding = 1),
31 | nn.BatchNorm2d(h_dim),
32 | nn.LeakyReLU())
33 | )
34 | in_channels = h_dim
35 |
36 | self.encoder = nn.Sequential(*modules)
37 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
38 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
39 |
40 |
41 | # Build Decoder
42 | modules = []
43 |
44 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
45 |
46 | hidden_dims.reverse()
47 |
48 | for i in range(len(hidden_dims) - 1):
49 | modules.append(
50 | nn.Sequential(
51 | nn.ConvTranspose2d(hidden_dims[i],
52 | hidden_dims[i + 1],
53 | kernel_size=3,
54 | stride = 2,
55 | padding=1,
56 | output_padding=1),
57 | nn.BatchNorm2d(hidden_dims[i + 1]),
58 | nn.LeakyReLU())
59 | )
60 |
61 |
62 |
63 | self.decoder = nn.Sequential(*modules)
64 |
65 | self.final_layer = nn.Sequential(
66 | nn.ConvTranspose2d(hidden_dims[-1],
67 | hidden_dims[-1],
68 | kernel_size=3,
69 | stride=2,
70 | padding=1,
71 | output_padding=1),
72 | nn.BatchNorm2d(hidden_dims[-1]),
73 | nn.LeakyReLU(),
74 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
75 | kernel_size= 3, padding= 1),
76 | nn.Tanh())
77 |
78 | self.pseudo_input = torch.eye(self.num_components, requires_grad= False)
79 | self.embed_pseudo = nn.Sequential(nn.Linear(self.num_components, 12288),
80 | nn.Hardtanh(0.0, 1.0)) # 3x64x64 = 12288
81 |
82 | def encode(self, input: Tensor) -> List[Tensor]:
83 | """
84 | Encodes the input by passing through the encoder network
85 | and returns the latent codes.
86 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
87 | :return: (Tensor) List of latent codes
88 | """
89 | result = self.encoder(input)
90 | result = torch.flatten(result, start_dim=1)
91 |
92 | # Split the result into mu and var components
93 | # of the latent Gaussian distribution
94 | mu = self.fc_mu(result)
95 | log_var = self.fc_var(result)
96 |
97 | return [mu, log_var]
98 |
99 | def decode(self, z: Tensor) -> Tensor:
100 | result = self.decoder_input(z)
101 | result = result.view(-1, 512, 2, 2)
102 | result = self.decoder(result)
103 | result = self.final_layer(result)
104 | return result
105 |
106 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
107 | """
108 |         Will a single z be enough to compute the expectation
109 |         for the loss?
110 |         :param mu: (Tensor) Mean of the latent Gaussian
111 |         :param logvar: (Tensor) Log-variance of the latent Gaussian
112 | :return:
113 | """
114 | std = torch.exp(0.5 * logvar)
115 | eps = torch.randn_like(std)
116 | return eps * std + mu
117 |
118 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
119 | mu, log_var = self.encode(input)
120 | z = self.reparameterize(mu, log_var)
121 | return [self.decode(z), input, mu, log_var, z]
122 |
123 | def loss_function(self,
124 | *args,
125 | **kwargs) -> dict:
126 | recons = args[0]
127 | input = args[1]
128 | mu = args[2]
129 | log_var = args[3]
130 | z = args[4]
131 |
132 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
133 | recons_loss =F.mse_loss(recons, input)
134 |
135 |         E_log_q_z = torch.mean(torch.sum(-0.5 * (log_var + (z - mu) ** 2 / log_var.exp()),
136 | dim = 1),
137 | dim = 0)
138 |
139 | # Original Prior
140 | # E_log_p_z = torch.mean(torch.sum(-0.5 * (z ** 2), dim = 1), dim = 0)
141 |
142 | # Vamp Prior
143 | M, C, H, W = input.size()
144 | curr_device = input.device
145 |         self.pseudo_input = self.pseudo_input.to(curr_device)
146 | x = self.embed_pseudo(self.pseudo_input)
147 | x = x.view(-1, C, H, W)
148 | prior_mu, prior_log_var = self.encode(x)
149 |
150 | z_expand = z.unsqueeze(1)
151 | prior_mu = prior_mu.unsqueeze(0)
152 | prior_log_var = prior_log_var.unsqueeze(0)
153 |
154 | E_log_p_z = torch.sum(-0.5 *
155 |                               (prior_log_var + (z_expand - prior_mu) ** 2 / prior_log_var.exp()),
156 | dim = 2) - torch.log(torch.tensor(self.num_components).float())
157 |
158 |
159 | E_log_p_z = torch.logsumexp(E_log_p_z, dim = 1)
160 | E_log_p_z = torch.mean(E_log_p_z, dim = 0)
161 |
162 | # KLD = E_q log q - E_q log p
163 | kld_loss = -(E_log_p_z - E_log_q_z)
164 |
165 |
166 |
167 | loss = recons_loss + kld_weight * kld_loss
168 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
169 |
170 | def sample(self,
171 | num_samples:int,
172 | current_device: int, **kwargs) -> Tensor:
173 | """
174 |         Samples from the latent space and returns the corresponding
175 | image space map.
176 | :param num_samples: (Int) Number of samples
177 | :param current_device: (Int) Device to run the model
178 | :return: (Tensor)
179 | """
180 | z = torch.randn(num_samples,
181 | self.latent_dim)
182 |
183 |         z = z.to(current_device)
184 |
185 | samples = self.decode(z)
186 | return samples
187 |
188 | def generate(self, x: Tensor, **kwargs) -> Tensor:
189 | """
190 | Given an input image x, returns the reconstructed image
191 | :param x: (Tensor) [B x C x H x W]
192 | :return: (Tensor) [B x C x H x W]
193 | """
194 |
195 | return self.forward(x)[0]
--------------------------------------------------------------------------------
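The VampPrior is a uniform mixture of K Gaussians whose parameters come from encoding the K learned pseudo-inputs, so log p(z) is a logsumexp over components. That mixture computation in isolation, with random stand-ins for the encoded pseudo-inputs (the constant -D/2 * log(2*pi) is dropped here, as above, since it cancels against the matching term in E[log q(z|x)]):

import torch

B, K, D = 16, 50, 128
z = torch.randn(B, D)
prior_mu = torch.randn(K, D)                     # stand-in for encode(embed_pseudo(eye(K)))
prior_log_var = torch.zeros(K, D)

log_p_z_given_k = torch.sum(
    -0.5 * (prior_log_var.unsqueeze(0)
            + (z.unsqueeze(1) - prior_mu.unsqueeze(0)) ** 2 / prior_log_var.unsqueeze(0).exp()),
    dim=2)                                       # [B x K], constants dropped
log_p_z = torch.logsumexp(log_p_z_given_k - torch.log(torch.tensor(float(K))), dim=1)
print(log_p_z.mean())                            # E[log p(z)] under the Vamp mixture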
/models/vanilla_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class VanillaVAE(BaseVAE):
9 |
10 |
11 | def __init__(self,
12 | in_channels: int,
13 | latent_dim: int,
14 | hidden_dims: List = None,
15 | **kwargs) -> None:
16 | super(VanillaVAE, self).__init__()
17 |
18 | self.latent_dim = latent_dim
19 |
20 | modules = []
21 | if hidden_dims is None:
22 | hidden_dims = [32, 64, 128, 256, 512]
23 |
24 | # Build Encoder
25 | for h_dim in hidden_dims:
26 | modules.append(
27 | nn.Sequential(
28 | nn.Conv2d(in_channels, out_channels=h_dim,
29 | kernel_size= 3, stride= 2, padding = 1),
30 | nn.BatchNorm2d(h_dim),
31 | nn.LeakyReLU())
32 | )
33 | in_channels = h_dim
34 |
35 | self.encoder = nn.Sequential(*modules)
36 | self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
37 | self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
38 |
39 |
40 | # Build Decoder
41 | modules = []
42 |
43 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
44 |
45 | hidden_dims.reverse()
46 |
47 | for i in range(len(hidden_dims) - 1):
48 | modules.append(
49 | nn.Sequential(
50 | nn.ConvTranspose2d(hidden_dims[i],
51 | hidden_dims[i + 1],
52 | kernel_size=3,
53 | stride = 2,
54 | padding=1,
55 | output_padding=1),
56 | nn.BatchNorm2d(hidden_dims[i + 1]),
57 | nn.LeakyReLU())
58 | )
59 |
60 |
61 |
62 | self.decoder = nn.Sequential(*modules)
63 |
64 | self.final_layer = nn.Sequential(
65 | nn.ConvTranspose2d(hidden_dims[-1],
66 | hidden_dims[-1],
67 | kernel_size=3,
68 | stride=2,
69 | padding=1,
70 | output_padding=1),
71 | nn.BatchNorm2d(hidden_dims[-1]),
72 | nn.LeakyReLU(),
73 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
74 | kernel_size= 3, padding= 1),
75 | nn.Tanh())
76 |
77 | def encode(self, input: Tensor) -> List[Tensor]:
78 | """
79 | Encodes the input by passing through the encoder network
80 | and returns the latent codes.
81 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
82 | :return: (Tensor) List of latent codes
83 | """
84 | result = self.encoder(input)
85 | result = torch.flatten(result, start_dim=1)
86 |
87 | # Split the result into mu and var components
88 | # of the latent Gaussian distribution
89 | mu = self.fc_mu(result)
90 | log_var = self.fc_var(result)
91 |
92 | return [mu, log_var]
93 |
94 | def decode(self, z: Tensor) -> Tensor:
95 | """
96 | Maps the given latent codes
97 | onto the image space.
98 | :param z: (Tensor) [B x D]
99 | :return: (Tensor) [B x C x H x W]
100 | """
101 | result = self.decoder_input(z)
102 | result = result.view(-1, 512, 2, 2)
103 | result = self.decoder(result)
104 | result = self.final_layer(result)
105 | return result
106 |
107 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
108 | """
109 |         Reparameterization trick to sample from N(mu, var) using
110 |         N(0, 1).
111 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
112 |         :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
113 | :return: (Tensor) [B x D]
114 | """
115 | std = torch.exp(0.5 * logvar)
116 | eps = torch.randn_like(std)
117 | return eps * std + mu
118 |
119 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
120 | mu, log_var = self.encode(input)
121 | z = self.reparameterize(mu, log_var)
122 | return [self.decode(z), input, mu, log_var]
123 |
124 | def loss_function(self,
125 | *args,
126 | **kwargs) -> dict:
127 | """
128 | Computes the VAE loss function.
129 | KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
130 | :param args:
131 | :param kwargs:
132 | :return:
133 | """
134 | recons = args[0]
135 | input = args[1]
136 | mu = args[2]
137 | log_var = args[3]
138 |
139 | kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
140 |         recons_loss = F.mse_loss(recons, input)
141 |
142 |
143 | kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
144 |
145 | loss = recons_loss + kld_weight * kld_loss
146 | return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}
147 |
148 | def sample(self,
149 | num_samples:int,
150 | current_device: int, **kwargs) -> Tensor:
151 | """
152 |         Samples from the latent space and returns the corresponding
153 | image space map.
154 | :param num_samples: (Int) Number of samples
155 | :param current_device: (Int) Device to run the model
156 | :return: (Tensor)
157 | """
158 | z = torch.randn(num_samples,
159 | self.latent_dim)
160 |
161 | z = z.to(current_device)
162 |
163 | samples = self.decode(z)
164 | return samples
165 |
166 | def generate(self, x: Tensor, **kwargs) -> Tensor:
167 | """
168 | Given an input image x, returns the reconstructed image
169 | :param x: (Tensor) [B x C x H x W]
170 | :return: (Tensor) [B x C x H x W]
171 | """
172 |
173 | return self.forward(x)[0]
--------------------------------------------------------------------------------
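The KLD term uses the closed form for KL(N(mu, sigma^2) || N(0, 1)) stated in the docstring. It can be verified against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mu, log_var = torch.randn(16, 128), torch.randn(16, 128)

closed_form = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
reference = kl_divergence(Normal(mu, torch.exp(0.5 * log_var)),
                          Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum(dim=1).mean()
print(torch.allclose(closed_form, reference, atol=1e-5))  # True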
/models/vq_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 | class VectorQuantizer(nn.Module):
8 | """
9 | Reference:
10 | [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
11 | """
12 | def __init__(self,
13 | num_embeddings: int,
14 | embedding_dim: int,
15 | beta: float = 0.25):
16 | super(VectorQuantizer, self).__init__()
17 | self.K = num_embeddings
18 | self.D = embedding_dim
19 | self.beta = beta
20 |
21 | self.embedding = nn.Embedding(self.K, self.D)
22 | self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
23 |
24 |     def forward(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
25 | latents = latents.permute(0, 2, 3, 1).contiguous() # [B x D x H x W] -> [B x H x W x D]
26 | latents_shape = latents.shape
27 | flat_latents = latents.view(-1, self.D) # [BHW x D]
28 |
29 | # Compute L2 distance between latents and embedding weights
30 | dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
31 | torch.sum(self.embedding.weight ** 2, dim=1) - \
32 | 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # [BHW x K]
33 |
34 | # Get the encoding that has the min distance
35 | encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1) # [BHW, 1]
36 |
37 | # Convert to one-hot encodings
38 | device = latents.device
39 | encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
40 | encoding_one_hot.scatter_(1, encoding_inds, 1) # [BHW x K]
41 |
42 | # Quantize the latents
43 | quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) # [BHW, D]
44 | quantized_latents = quantized_latents.view(latents_shape) # [B x H x W x D]
45 |
46 | # Compute the VQ Losses
47 | commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
48 | embedding_loss = F.mse_loss(quantized_latents, latents.detach())
49 |
50 | vq_loss = commitment_loss * self.beta + embedding_loss
51 |
52 | # Add the residue back to the latents
53 | quantized_latents = latents + (quantized_latents - latents).detach()
54 |
55 | return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss # [B x D x H x W]
56 |
57 | class ResidualLayer(nn.Module):
58 |
59 | def __init__(self,
60 | in_channels: int,
61 | out_channels: int):
62 | super(ResidualLayer, self).__init__()
63 | self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels,
64 | kernel_size=3, padding=1, bias=False),
65 | nn.ReLU(True),
66 | nn.Conv2d(out_channels, out_channels,
67 | kernel_size=1, bias=False))
68 |
69 | def forward(self, input: Tensor) -> Tensor:
70 | return input + self.resblock(input)
71 |
72 |
73 | class VQVAE(BaseVAE):
74 |
75 | def __init__(self,
76 | in_channels: int,
77 | embedding_dim: int,
78 | num_embeddings: int,
79 | hidden_dims: List = None,
80 | beta: float = 0.25,
81 | img_size: int = 64,
82 | **kwargs) -> None:
83 | super(VQVAE, self).__init__()
84 |
85 | self.embedding_dim = embedding_dim
86 | self.num_embeddings = num_embeddings
87 | self.img_size = img_size
88 | self.beta = beta
89 |
90 | modules = []
91 | if hidden_dims is None:
92 | hidden_dims = [128, 256]
93 |
94 | # Build Encoder
95 | for h_dim in hidden_dims:
96 | modules.append(
97 | nn.Sequential(
98 | nn.Conv2d(in_channels, out_channels=h_dim,
99 | kernel_size=4, stride=2, padding=1),
100 | nn.LeakyReLU())
101 | )
102 | in_channels = h_dim
103 |
104 | modules.append(
105 | nn.Sequential(
106 | nn.Conv2d(in_channels, in_channels,
107 | kernel_size=3, stride=1, padding=1),
108 | nn.LeakyReLU())
109 | )
110 |
111 | for _ in range(6):
112 | modules.append(ResidualLayer(in_channels, in_channels))
113 | modules.append(nn.LeakyReLU())
114 |
115 | modules.append(
116 | nn.Sequential(
117 | nn.Conv2d(in_channels, embedding_dim,
118 | kernel_size=1, stride=1),
119 | nn.LeakyReLU())
120 | )
121 |
122 | self.encoder = nn.Sequential(*modules)
123 |
124 | self.vq_layer = VectorQuantizer(num_embeddings,
125 | embedding_dim,
126 | self.beta)
127 |
128 | # Build Decoder
129 | modules = []
130 | modules.append(
131 | nn.Sequential(
132 | nn.Conv2d(embedding_dim,
133 | hidden_dims[-1],
134 | kernel_size=3,
135 | stride=1,
136 | padding=1),
137 | nn.LeakyReLU())
138 | )
139 |
140 | for _ in range(6):
141 | modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))
142 |
143 | modules.append(nn.LeakyReLU())
144 |
145 | hidden_dims.reverse()
146 |
147 | for i in range(len(hidden_dims) - 1):
148 | modules.append(
149 | nn.Sequential(
150 | nn.ConvTranspose2d(hidden_dims[i],
151 | hidden_dims[i + 1],
152 | kernel_size=4,
153 | stride=2,
154 | padding=1),
155 | nn.LeakyReLU())
156 | )
157 |
158 | modules.append(
159 | nn.Sequential(
160 | nn.ConvTranspose2d(hidden_dims[-1],
161 | out_channels=3,
162 | kernel_size=4,
163 | stride=2, padding=1),
164 | nn.Tanh()))
165 |
166 | self.decoder = nn.Sequential(*modules)
167 |
168 | def encode(self, input: Tensor) -> List[Tensor]:
169 | """
170 | Encodes the input by passing through the encoder network
171 | and returns the latent codes.
172 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
173 | :return: (Tensor) List of latent codes
174 | """
175 | result = self.encoder(input)
176 | return [result]
177 |
178 | def decode(self, z: Tensor) -> Tensor:
179 | """
180 | Maps the given latent codes
181 | onto the image space.
182 | :param z: (Tensor) [B x D x H x W]
183 | :return: (Tensor) [B x C x H x W]
184 | """
185 |
186 | result = self.decoder(z)
187 | return result
188 |
189 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
190 | encoding = self.encode(input)[0]
191 | quantized_inputs, vq_loss = self.vq_layer(encoding)
192 | return [self.decode(quantized_inputs), input, vq_loss]
193 |
194 | def loss_function(self,
195 | *args,
196 | **kwargs) -> dict:
197 | """
198 | :param args:
199 | :param kwargs:
200 | :return:
201 | """
202 | recons = args[0]
203 | input = args[1]
204 | vq_loss = args[2]
205 |
206 | recons_loss = F.mse_loss(recons, input)
207 |
208 | loss = recons_loss + vq_loss
209 | return {'loss': loss,
210 | 'Reconstruction_Loss': recons_loss,
211 | 'VQ_Loss':vq_loss}
212 |
213 | def sample(self,
214 | num_samples: int,
215 | current_device: Union[int, str], **kwargs) -> Tensor:
217 |         raise NotImplementedError('VQVAE sampler is not implemented.')
217 |
218 | def generate(self, x: Tensor, **kwargs) -> Tensor:
219 | """
220 | Given an input image x, returns the reconstructed image
221 | :param x: (Tensor) [B x C x H x W]
222 | :return: (Tensor) [B x C x H x W]
223 | """
224 |
225 | return self.forward(x)[0]
--------------------------------------------------------------------------------
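Two details carry the VectorQuantizer above: the expanded distance ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e used for the codebook lookup, and the straight-through trick that lets gradients bypass the non-differentiable argmin. Both on toy data:

import torch

K, D = 512, 64
codebook = torch.randn(K, D)
latents = torch.randn(1024, D, requires_grad=True)   # flattened [BHW x D] encoder output

# Nearest-codebook lookup via the expanded squared distance
dist = (latents ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) \
       - 2 * latents @ codebook.t()                  # [BHW x K]
quantized = codebook[dist.argmin(dim=1)]             # [BHW x D]

# Straight-through: the forward pass uses the codes, the backward pass
# treats quantization as the identity
quantized_st = latents + (quantized - latents).detach()
quantized_st.sum().backward()
print(bool((latents.grad == 1).all()))               # True: gradients reach the encoder output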
/models/wae_mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import BaseVAE
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from .types_ import *
6 |
7 |
8 | class WAE_MMD(BaseVAE):
9 |
10 | def __init__(self,
11 | in_channels: int,
12 | latent_dim: int,
13 | hidden_dims: List = None,
14 | reg_weight: int = 100,
15 | kernel_type: str = 'imq',
16 | latent_var: float = 2.,
17 | **kwargs) -> None:
18 | super(WAE_MMD, self).__init__()
19 |
20 | self.latent_dim = latent_dim
21 | self.reg_weight = reg_weight
22 | self.kernel_type = kernel_type
23 | self.z_var = latent_var
24 |
25 | modules = []
26 | if hidden_dims is None:
27 | hidden_dims = [32, 64, 128, 256, 512]
28 |
29 | # Build Encoder
30 | for h_dim in hidden_dims:
31 | modules.append(
32 | nn.Sequential(
33 | nn.Conv2d(in_channels, out_channels=h_dim,
34 | kernel_size= 3, stride= 2, padding = 1),
35 | nn.BatchNorm2d(h_dim),
36 | nn.LeakyReLU())
37 | )
38 | in_channels = h_dim
39 |
40 | self.encoder = nn.Sequential(*modules)
41 | self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim)
42 |
43 |
44 | # Build Decoder
45 | modules = []
46 |
47 | self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
48 |
49 | hidden_dims.reverse()
50 |
51 | for i in range(len(hidden_dims) - 1):
52 | modules.append(
53 | nn.Sequential(
54 | nn.ConvTranspose2d(hidden_dims[i],
55 | hidden_dims[i + 1],
56 | kernel_size=3,
57 | stride = 2,
58 | padding=1,
59 | output_padding=1),
60 | nn.BatchNorm2d(hidden_dims[i + 1]),
61 | nn.LeakyReLU())
62 | )
63 |
64 |
65 |
66 | self.decoder = nn.Sequential(*modules)
67 |
68 | self.final_layer = nn.Sequential(
69 | nn.ConvTranspose2d(hidden_dims[-1],
70 | hidden_dims[-1],
71 | kernel_size=3,
72 | stride=2,
73 | padding=1,
74 | output_padding=1),
75 | nn.BatchNorm2d(hidden_dims[-1]),
76 | nn.LeakyReLU(),
77 | nn.Conv2d(hidden_dims[-1], out_channels= 3,
78 | kernel_size= 3, padding= 1),
79 | nn.Tanh())
80 |
81 | def encode(self, input: Tensor) -> Tensor:
82 | """
83 | Encodes the input by passing through the encoder network
84 | and returns the latent codes.
85 | :param input: (Tensor) Input tensor to encoder [N x C x H x W]
86 |         :return: (Tensor) Latent codes [N x D]
87 | """
88 | result = self.encoder(input)
89 | result = torch.flatten(result, start_dim=1)
90 |
91 |         # The WAE encoder is deterministic: a single fully-connected
92 |         # layer maps the features directly to the latent code z
93 | z = self.fc_z(result)
94 | return z
95 |
96 | def decode(self, z: Tensor) -> Tensor:
97 | result = self.decoder_input(z)
98 | result = result.view(-1, 512, 2, 2)
99 | result = self.decoder(result)
100 | result = self.final_layer(result)
101 | return result
102 |
103 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
104 | z = self.encode(input)
105 | return [self.decode(z), input, z]
106 |
107 | def loss_function(self,
108 | *args,
109 | **kwargs) -> dict:
110 | recons = args[0]
111 | input = args[1]
112 | z = args[2]
113 |
114 | batch_size = input.size(0)
115 | bias_corr = batch_size * (batch_size - 1)
116 | reg_weight = self.reg_weight / bias_corr
117 |
118 |         recons_loss = F.mse_loss(recons, input)
119 |
120 | mmd_loss = self.compute_mmd(z, reg_weight)
121 |
122 | loss = recons_loss + mmd_loss
123 | return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss}
124 |
125 | def compute_kernel(self,
126 | x1: Tensor,
127 | x2: Tensor) -> Tensor:
128 | # Convert the tensors into row and column vectors
129 | D = x1.size(1)
130 | N = x1.size(0)
131 |
132 | x1 = x1.unsqueeze(-2) # Make it into a column tensor
133 | x2 = x2.unsqueeze(-3) # Make it into a row tensor
134 |
135 | """
136 | Usually the below lines are not required, especially in our case,
137 | but this is useful when x1 and x2 have different sizes
138 | along the 0th dimension.
139 | """
140 | x1 = x1.expand(N, N, D)
141 | x2 = x2.expand(N, N, D)
142 |
143 | if self.kernel_type == 'rbf':
144 | result = self.compute_rbf(x1, x2)
145 | elif self.kernel_type == 'imq':
146 | result = self.compute_inv_mult_quad(x1, x2)
147 | else:
148 | raise ValueError('Undefined kernel type.')
149 |
150 | return result
151 |
152 |
153 | def compute_rbf(self,
154 | x1: Tensor,
155 | x2: Tensor,
156 | eps: float = 1e-7) -> Tensor:
157 | """
158 | Computes the RBF Kernel between x1 and x2.
159 | :param x1: (Tensor)
160 | :param x2: (Tensor)
161 | :param eps: (Float)
162 | :return:
163 | """
164 | z_dim = x2.size(-1)
165 | sigma = 2. * z_dim * self.z_var
166 |
167 | result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
168 | return result
169 |
170 | def compute_inv_mult_quad(self,
171 | x1: Tensor,
172 | x2: Tensor,
173 | eps: float = 1e-7) -> Tensor:
174 | """
175 |         Computes the inverse multiquadric (IMQ) kernel between x1 and x2,
176 | given by
177 |
178 | k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}
179 | :param x1: (Tensor)
180 | :param x2: (Tensor)
181 | :param eps: (Float)
182 | :return:
183 | """
184 | z_dim = x2.size(-1)
185 | C = 2 * z_dim * self.z_var
186 | kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))
187 |
188 | # Exclude diagonal elements
189 | result = kernel.sum() - kernel.diag().sum()
190 |
191 | return result
192 |
193 | def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:
194 | # Sample from prior (Gaussian) distribution
195 | prior_z = torch.randn_like(z)
196 |
197 | prior_z__kernel = self.compute_kernel(prior_z, prior_z)
198 | z__kernel = self.compute_kernel(z, z)
199 | priorz_z__kernel = self.compute_kernel(prior_z, z)
200 |
201 | mmd = reg_weight * prior_z__kernel.mean() + \
202 | reg_weight * z__kernel.mean() - \
203 | 2 * reg_weight * priorz_z__kernel.mean()
204 | return mmd
205 |
206 | def sample(self,
207 | num_samples:int,
208 | current_device: int, **kwargs) -> Tensor:
209 | """
210 |         Samples from the latent space and returns the corresponding
211 | image space map.
212 | :param num_samples: (Int) Number of samples
213 | :param current_device: (Int) Device to run the model
214 | :return: (Tensor)
215 | """
216 | z = torch.randn(num_samples,
217 | self.latent_dim)
218 |
219 | z = z.to(current_device)
220 |
221 | samples = self.decode(z)
222 | return samples
223 |
224 | def generate(self, x: Tensor, **kwargs) -> Tensor:
225 | """
226 | Given an input image x, returns the reconstructed image
227 | :param x: (Tensor) [B x C x H x W]
228 | :return: (Tensor) [B x C x H x W]
229 | """
230 |
231 | return self.forward(x)[0]
--------------------------------------------------------------------------------
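Unlike the InfoVAE variant, WAE_MMD folds the 1/(N(N-1)) bias correction into reg_weight before calling compute_mmd, so the IMQ path's off-diagonal kernel sum becomes an average over the N(N-1) ordered pairs. The scaling of the z-z term on its own:

import torch

N, D, latent_var = 64, 128, 2.
z = torch.randn(N, D)
reg_weight = 100 / (N * (N - 1))                     # default reg_weight over ordered pairs

C = 2 * D * latent_var                               # IMQ scale, as in compute_inv_mult_quad
k = C / (1e-7 + C + (z.unsqueeze(1) - z.unsqueeze(0)).pow(2).sum(-1))  # [N x N]
off_diag_sum = k.sum() - k.diag().sum()
print(reg_weight * off_diag_sum)                     # the z-z term of the MMD estimate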
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.5.6
2 | PyYAML==6.0
3 | tensorboard>=2.2.0
4 | torch>=1.6.1
5 | torchsummary==1.5.1
6 | torchvision>=0.10.1
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import argparse
4 | import numpy as np
5 | from pathlib import Path
6 | from models import *
7 | from experiment import VAEXperiment
8 | import torch.backends.cudnn as cudnn
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.loggers import TensorBoardLogger
11 | from pytorch_lightning.utilities.seed import seed_everything
12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
13 | from dataset import VAEDataset
14 | from pytorch_lightning.plugins import DDPPlugin
15 |
16 |
17 | parser = argparse.ArgumentParser(description='Generic runner for VAE models')
18 | parser.add_argument('--config', '-c',
19 | dest="filename",
20 | metavar='FILE',
21 | help = 'path to the config file',
22 | default='configs/vae.yaml')
23 |
24 | args = parser.parse_args()
25 | with open(args.filename, 'r') as file:
26 | try:
27 | config = yaml.safe_load(file)
28 | except yaml.YAMLError as exc:
29 |         print(exc)
30 |         raise  # a valid config is required; don't continue with `config` undefined
31 |
32 | tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
33 | name=config['model_params']['name'],)
34 |
35 | # For reproducibility
36 | seed_everything(config['exp_params']['manual_seed'], True)
37 |
38 | model = vae_models[config['model_params']['name']](**config['model_params'])
39 | experiment = VAEXperiment(model,
40 | config['exp_params'])
41 |
42 | data = VAEDataset(**config["data_params"], pin_memory=len(config['trainer_params']['gpus']) != 0)
43 |
44 | data.setup()
45 | runner = Trainer(logger=tb_logger,
46 | callbacks=[
47 | LearningRateMonitor(),
48 | ModelCheckpoint(save_top_k=2,
49 |                                      dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
50 |                                      monitor="val_loss",
51 |                                      save_last=True),
52 | ],
53 | strategy=DDPPlugin(find_unused_parameters=False),
54 | **config['trainer_params'])
55 |
56 |
57 | Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
58 | Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
59 |
60 |
61 | print(f"======= Training {config['model_params']['name']} =======")
62 | runner.fit(experiment, datamodule=data)
--------------------------------------------------------------------------------
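Usage note: run.py is meant to be launched from the repository root. It loads the YAML config, seeds everything, and writes TensorBoard logs plus the top-2 checkpoints (under <log_dir>/checkpoints) and the Samples/Reconstructions folders created above. A typical session, assuming logging_params.save_dir in the config is set to logs/ (an assumption; substitute your own value):

$ python run.py -c configs/vae.yaml
$ tensorboard --logdir logs/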
/tests/bvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import BetaVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = BetaVAE(3, 10, loss_type='H').cuda()
12 |
13 | def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cuda'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64).cuda()
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(16, 3, 64, 64).cuda()
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 |
31 | if __name__ == '__main__':
32 | unittest.main()
--------------------------------------------------------------------------------
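A note on the test files in this directory: they import from models at the top level, so they are meant to be run from the repository root, and several of them (this one included) call .cuda() and therefore need a GPU. Assuming the checkout root as the working directory, something like the following should work:

$ python -m unittest discover -s tests -p "test_*.py" -t .
$ python -m unittest tests.test_vae

Note that bvae.py, text_cvae.py, and text_vamp.py do not match the default test_*.py pattern, so discovery skips them unless the pattern is widened.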
/tests/test_betatcvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import BetaTCVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestBetaTCVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = BetaTCVAE(3, 64, anneal_steps= 100)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 | # print("Model2 Output size:", self.model2(x)[0].size())
23 |
24 | def test_loss(self):
25 | x = torch.randn(16, 3, 64, 64)
26 |
27 | result = self.model(x)
28 | loss = self.model.loss_function(*result, M_N = 0.005)
29 | print(loss)
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(8, 'cuda')
34 | print(y.shape)
35 |
36 | def test_generate(self):
37 | x = torch.randn(16, 3, 64, 64)
38 | y = self.model.generate(x)
39 | print(y.shape)
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_cat_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import GumbelVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = GumbelVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(128, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)
28 | print(loss)
29 |
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(144, 0)
34 | print(y.shape)
35 |
36 |
37 | if __name__ == '__main__':
38 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_dfc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import DFCVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestDFCVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = DFCVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 |
16 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 |
18 | def test_forward(self):
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 |
23 | # print("Model2 Output size:", self.model2(x)[0].size())
24 |
25 | def test_loss(self):
26 | x = torch.randn(16, 3, 64, 64)
27 |
28 | result = self.model(x)
29 | loss = self.model.loss_function(*result, M_N = 0.005)
30 | print(loss)
31 |
32 | def test_sample(self):
33 | self.model.cuda()
34 | y = self.model.sample(144, 0)
35 |
36 |
37 |
38 |
39 | if __name__ == '__main__':
40 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_dipvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import DIPVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestDIPVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = DIPVAE(3, 64)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 | # print("Model2 Output size:", self.model2(x)[0].size())
23 |
24 | def test_loss(self):
25 | x = torch.randn(16, 3, 64, 64)
26 |
27 | result = self.model(x)
28 | loss = self.model.loss_function(*result, M_N = 0.005)
29 | print(loss)
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(8, 'cuda')
34 | print(y.shape)
35 |
36 | def test_generate(self):
37 | x = torch.randn(16, 3, 64, 64)
38 | y = self.model.generate(x)
39 | print(y.shape)
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_fvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import FactorVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestFVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = FactorVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | #
16 | # print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
17 |
18 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
19 |
20 | def test_forward(self):
21 | x = torch.randn(16, 3, 64, 64)
22 | y = self.model(x)
23 | print("Model Output size:", y[0].size())
24 |
25 | # print("Model2 Output size:", self.model2(x)[0].size())
26 |
27 | def test_loss(self):
28 | x = torch.randn(16, 3, 64, 64)
29 | x2 = torch.randn(16,3, 64, 64)
30 |
31 | result = self.model(x)
32 |         loss_vae = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=0, secondary_input=x2)
33 |         loss_disc = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=1, secondary_input=x2)
34 |         print(loss_vae, loss_disc)
35 |
36 | def test_optim(self):
37 | optim1 = torch.optim.Adam(self.model.parameters(), lr = 0.001)
38 |         optim2 = torch.optim.Adam(self.model.discriminator.parameters(), lr = 0.001)
39 |
40 | def test_sample(self):
41 | self.model.cuda()
42 | y = self.model.sample(144, 0)
43 |
44 |
45 |
46 |
47 | if __name__ == '__main__':
48 | unittest.main()
--------------------------------------------------------------------------------
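The optimizer_idx plumbing in test_loss and test_optim above reflects FactorVAE's adversarial setup: one optimizer updates the VAE, a second updates the total-correlation discriminator, and secondary_input supplies the extra batch whose latents are dimension-permuted for the discriminator. A minimal sketch of that generic two-optimizer pattern, using toy stand-in modules rather than FactorVAE's actual losses (illustrative only, not the repo's implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Linear(8, 4)     # stand-in for the VAE branch
disc = nn.Linear(4, 2)    # stand-in for the TC discriminator

opt_vae = torch.optim.Adam(enc.parameters(), lr=0.001)
opt_disc = torch.optim.Adam(disc.parameters(), lr=0.001)

x = torch.randn(16, 8)

# optimizer_idx == 0: update the VAE side (toy objective)
z = enc(x)
vae_loss = z.pow(2).mean() - disc(z)[:, 0].mean()
opt_vae.zero_grad()
vae_loss.backward()
opt_vae.step()

# optimizer_idx == 1: update the discriminator on detached latents
z = enc(x).detach()
disc_loss = F.cross_entropy(disc(z), torch.zeros(z.size(0), dtype=torch.long))
opt_disc.zero_grad()
disc_loss.backward()
opt_disc.step()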
/tests/test_gvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import GammaVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestGammaVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = GammaVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 |
16 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 |
18 | def test_forward(self):
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 |
23 | # print("Model2 Output size:", self.model2(x)[0].size())
24 |
25 | def test_loss(self):
26 | x = torch.randn(16, 3, 64, 64)
27 |
28 | result = self.model(x)
29 | loss = self.model.loss_function(*result, M_N = 0.005)
30 | print(loss)
31 |
32 | def test_sample(self):
33 | self.model.cuda()
34 | y = self.model.sample(144, 0)
35 |
36 |
37 |
38 |
39 | if __name__ == '__main__':
40 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_hvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import HVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestHVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = HVAE(3, latent1_dim=10, latent2_dim=20)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(16, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 |
31 | if __name__ == '__main__':
32 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_iwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import IWAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestIWAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = IWAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(16, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 | def test_sample(self):
31 | self.model.cuda()
32 | y = self.model.sample(144, 0)
33 |
34 |
35 | if __name__ == '__main__':
36 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_joint_Vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import JointVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = JointVAE(3, 10, 40, 0.0)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(128, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)
28 | print(loss)
29 |
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(144, 0)
34 | print(y.shape)
35 |
36 |
37 | if __name__ == '__main__':
38 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_logcosh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import LogCoshVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = LogCoshVAE(3, 10, alpha=10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.rand(16, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 |
31 | if __name__ == '__main__':
32 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_lvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import LVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestLVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 |         self.model = LVAE(3, [4, 8, 16, 32, 128], hidden_dims=[32, 64, 128, 256, 512])
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 |
21 | print("Model Output size:", y[0].size())
22 | # print("Model2 Output size:", self.model2(x)[0].size())
23 |
24 | def test_loss(self):
25 | x = torch.randn(16, 3, 64, 64)
26 |
27 | result = self.model(x)
28 | loss = self.model.loss_function(*result, M_N = 0.005)
29 | print(loss)
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(144, 0)
34 | print(y.shape)
35 |
36 |
37 | if __name__ == '__main__':
38 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_miwae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import MIWAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestMIWAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = MIWAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(16, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 | def test_sample(self):
31 | self.model.cuda()
32 | y = self.model.sample(144, 0)
33 | print(y.shape)
34 |
35 | def test_generate(self):
36 | x = torch.randn(16, 3, 64, 64)
37 | y = self.model.generate(x)
38 | print(y.shape)
39 |
40 |
41 | if __name__ == '__main__':
42 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_mssimvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import MSSIMVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestMSSIMVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = MSSIMVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 |
16 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
17 |
18 | def test_forward(self):
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 |
23 | # print("Model2 Output size:", self.model2(x)[0].size())
24 |
25 | def test_loss(self):
26 | x = torch.randn(16, 3, 64, 64)
27 |
28 | result = self.model(x)
29 | loss = self.model.loss_function(*result, M_N = 0.005)
30 | print(loss)
31 |
32 | def test_sample(self):
33 | self.model.cuda()
34 | y = self.model.sample(144, 0)
35 |
36 |
37 |
38 |
39 | if __name__ == '__main__':
40 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_swae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import SWAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestSWAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | self.model = SWAE(3, 10, reg_weight = 100)
11 |
12 | def test_summary(self):
13 | print(summary(self.model, (3, 64, 64), device='cpu'))
14 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
15 |
16 | def test_forward(self):
17 | x = torch.randn(16, 3, 64, 64)
18 | y = self.model(x)
19 | print("Model Output size:", y[0].size())
20 | # print("Model2 Output size:", self.model2(x)[0].size())
21 |
22 | def test_loss(self):
23 | x = torch.randn(16, 3, 64, 64)
24 |
25 | result = self.model(x)
26 | loss = self.model.loss_function(*result)
27 | print(loss)
28 |
29 |
30 | if __name__ == '__main__':
31 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VanillaVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = VanillaVAE(3, 10)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | x = torch.randn(16, 3, 64, 64)
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(16, 3, 64, 64)
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 |
31 | if __name__ == '__main__':
32 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_vq_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VQVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVQVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = VQVAE(3, 64, 512)
12 |
13 | def test_summary(self):
14 | print(summary(self.model, (3, 64, 64), device='cpu'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 | print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
19 | x = torch.randn(16, 3, 64, 64)
20 | y = self.model(x)
21 | print("Model Output size:", y[0].size())
22 | # print("Model2 Output size:", self.model2(x)[0].size())
23 |
24 | def test_loss(self):
25 | x = torch.randn(16, 3, 64, 64)
26 |
27 | result = self.model(x)
28 | loss = self.model.loss_function(*result, M_N = 0.005)
29 | print(loss)
30 |
31 | def test_sample(self):
32 | self.model.cuda()
33 | y = self.model.sample(8, 'cuda')
34 | print(y.shape)
35 |
36 | def test_generate(self):
37 | x = torch.randn(16, 3, 64, 64)
38 | y = self.model.generate(x)
39 | print(y.shape)
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
--------------------------------------------------------------------------------
/tests/test_wae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import WAE_MMD
4 | from torchsummary import summary
5 |
6 |
7 | class TestWAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | self.model = WAE_MMD(3, 10, reg_weight = 100)
11 |
12 | def test_summary(self):
13 | print(summary(self.model, (3, 64, 64), device='cpu'))
14 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
15 |
16 | def test_forward(self):
17 | x = torch.randn(16, 3, 64, 64)
18 | y = self.model(x)
19 | print("Model Output size:", y[0].size())
20 | # print("Model2 Output size:", self.model2(x)[0].size())
21 |
22 | def test_loss(self):
23 | x = torch.randn(16, 3, 64, 64)
24 |
25 | result = self.model(x)
26 | loss = self.model.loss_function(*result)
27 | print(loss)
28 |
29 |
30 | if __name__ == '__main__':
31 | unittest.main()
--------------------------------------------------------------------------------
/tests/text_cvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import CVAE
4 |
5 |
6 | class TestCVAE(unittest.TestCase):
7 |
8 | def setUp(self) -> None:
9 | # self.model2 = VAE(3, 10)
10 | self.model = CVAE(3, 40, 10)
11 |
12 | def test_forward(self):
13 | x = torch.randn(16, 3, 64, 64)
14 | c = torch.randn(16, 40)
15 |         y = self.model(x, labels = c)
16 | print("Model Output size:", y[0].size())
17 | # print("Model2 Output size:", self.model2(x)[0].size())
18 |
19 | def test_loss(self):
20 | x = torch.randn(16, 3, 64, 64)
21 | c = torch.randn(16, 40)
22 | result = self.model(x, labels = c)
23 | loss = self.model.loss_function(*result, M_N = 0.005)
24 | print(loss)
25 |
26 |
27 | if __name__ == '__main__':
28 | unittest.main()
--------------------------------------------------------------------------------
/tests/text_vamp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | from models import VampVAE
4 | from torchsummary import summary
5 |
6 |
7 | class TestVVAE(unittest.TestCase):
8 |
9 | def setUp(self) -> None:
10 | # self.model2 = VAE(3, 10)
11 | self.model = VampVAE(3, latent_dim=10).cuda()
12 |
13 | def test_summary(self):
14 |         print(summary(self.model, (3, 64, 64), device='cuda'))
15 | # print(summary(self.model2, (3, 64, 64), device='cpu'))
16 |
17 | def test_forward(self):
18 |         x = torch.randn(16, 3, 64, 64).cuda()
19 | y = self.model(x)
20 | print("Model Output size:", y[0].size())
21 | # print("Model2 Output size:", self.model2(x)[0].size())
22 |
23 | def test_loss(self):
24 | x = torch.randn(144, 3, 64, 64).cuda()
25 |
26 | result = self.model(x)
27 | loss = self.model.loss_function(*result, M_N = 0.005)
28 | print(loss)
29 |
30 |
31 | if __name__ == '__main__':
32 | unittest.main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 |
4 | ## Utils to handle newer PyTorch Lightning changes from version 0.6
5 | ## ==================================================================================================== ##
6 |
7 |
8 | def data_loader(fn):
9 | """
10 | Decorator to handle the deprecation of data_loader from 0.7
11 | :param fn: User defined data loader function
12 | :return: A wrapper for the data_loader function
13 | """
14 |
15 | def func_wrapper(self):
16 | try: # Works for version 0.6.0
17 | return pl.data_loader(fn)(self)
18 |
19 |         except AttributeError: # Works for version > 0.6.0, where pl.data_loader was removed
20 | return fn(self)
21 |
22 | return func_wrapper
23 |
--------------------------------------------------------------------------------
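For context, a hypothetical use of the data_loader shim above (ToyExperiment and the random dataset are illustrative only):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from utils import data_loader

class ToyExperiment(pl.LightningModule):

    @data_loader
    def train_dataloader(self):
        # On Lightning 0.6.0 the decorator re-wraps this hook with
        # pl.data_loader; on newer releases it falls through to the
        # undecorated function.
        return DataLoader(TensorDataset(torch.randn(8, 3, 64, 64)), batch_size=4)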