├── .gitignore
├── assets
│   ├── aae
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── age
│   │   └── mnist.jpg
│   ├── beta_vae
│   │   ├── celeba_sample.jpg
│   │   ├── celeba_traverse.jpg
│   │   ├── dsprites_sample.jpg
│   │   └── dsprites_traverse.jpg
│   ├── bigan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── convert.py
│   ├── cvae
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── ddpm
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── factor_vae
│   │   ├── fvae_celeba_traverse.jpg
│   │   ├── fvae_dsprites_sample.jpg
│   │   ├── fvae_dsprites_traverse.jpg
│   │   └── fvae_sample_celeba.jpg
│   ├── gan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── ggan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── infogan
│   │   ├── class.jpg
│   │   ├── random.jpg
│   │   ├── rotation.jpg
│   │   └── thickness.jpg
│   ├── lsgan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── made
│   │   └── mnist.jpg
│   ├── pixelcnn
│   │   ├── mnist.jpg
│   │   └── mnist_cond.jpg
│   ├── tar
│   │   ├── mnist.jpg
│   │   └── mnist_cond.jpg
│   ├── vae
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── vaegan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── vqvae
│   │   ├── celeba_real.jpg
│   │   ├── celeba_recon.jpg
│   │   ├── cifar10_real.jpg
│   │   ├── cifar10_recon.jpg
│   │   ├── mnist_real.jpg
│   │   └── mnist_recon.jpg
│   ├── wgan
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   ├── wgangp
│   │   ├── celeba.jpg
│   │   ├── cifar10.jpg
│   │   └── mnist.jpg
│   └── wiki
│       └── wgan
│           ├── changeD.jpg
│           ├── changeG.jpg
│           └── logits.png
├── configs
│   ├── callbacks
│   │   ├── ar_models.yaml
│   │   ├── default.yaml
│   │   ├── eval_fid.yaml
│   │   ├── latent_visual.yaml
│   │   ├── sample.yaml
│   │   ├── tqdm.yaml
│   │   └── traverse_latent.yaml
│   ├── config.yaml
│   ├── datamodule
│   │   ├── celeba.yaml
│   │   ├── cifar10.yaml
│   │   ├── dsprites.yaml
│   │   └── mnist.yaml
│   ├── experiment
│   │   ├── aae
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── age
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── beta_vae
│   │   │   ├── celeba.yaml
│   │   │   └── dsprites.yaml
│   │   ├── bigan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── contra_gan
│   │   │   └── dsprites.yaml
│   │   ├── contra_vae
│   │   │   └── dsprites.yaml
│   │   ├── cvae
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── ddpm
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── factor_vae
│   │   │   ├── celeba.yaml
│   │   │   └── dsprites.yaml
│   │   ├── ggan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist_conv.yaml
│   │   ├── infogan
│   │   │   └── mnist.yaml
│   │   ├── lsgan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── conv_mnist.yaml
│   │   │   └── mlp_mnist.yaml
│   │   ├── made
│   │   │   └── mnist.yaml
│   │   ├── pixelcnn
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── tar
│   │   │   ├── mnist.yaml
│   │   │   └── mnist_cond.yaml
│   │   ├── vae
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── mnist_conv.yaml
│   │   │   └── mnist_mlp.yaml
│   │   ├── vaegan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── vanilla_gan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── dsprites.yaml
│   │   │   ├── mnist_conv.yaml
│   │   │   └── mnist_mlp.yaml
│   │   ├── vqvae
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   └── mnist.yaml
│   │   ├── wgan
│   │   │   ├── celeba.yaml
│   │   │   ├── cifar10.yaml
│   │   │   ├── mnist_conv.yaml
│   │   │   └── mnist_mlp.yaml
│   │   └── wgan_gp
│   │       ├── celeba.yaml
│   │       ├── cifar10.yaml
│   │       ├── mnist_conv.yaml
│   │       └── mnist_mlp.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── logger
│   │   └── tensorboard.yaml
│   ├── model
│   │   ├── aae.yaml
│   │   ├── age.yaml
│   │   ├── bigan.yaml
│   │   ├── cvae.yaml
│   │   ├── ddpm.yaml
│   │   ├── factor_vae.yaml
│   │   ├── gan.yaml
│   │   ├── info_gan.yaml
│   │   ├── made.yaml
│   │   ├── pixelcnn.yaml
│   │   ├── speed_gan.yaml
│   │   ├── tar.yaml
│   │   ├── vae.yaml
│   │   ├── vae_gan.yaml
│   │   ├── vqvae.yaml
│   │   ├── wgan.yaml
│   │   └── wgan_gp.yaml
│   ├── networks
│   │   ├── conv_32.yaml
│   │   ├── conv_64.yaml
│   │   ├── conv_mnist.yaml
│   │   ├── mlp.yaml
│   │   ├── mlp_small.yaml
│   │   └── vqvae.yaml
│   └── trainer
│       └── default.yaml
├── readme.adoc
├── requirements.txt
├── run.py
├── src
│   ├── callbacks
│   │   ├── evaluation.py
│   │   ├── util.py
│   │   └── visualization.py
│   ├── datamodules
│   │   ├── base.py
│   │   ├── basic.py
│   │   ├── celeba.py
│   │   ├── cifar10.py
│   │   ├── dsprite.py
│   │   ├── lsun.py
│   │   ├── mnist.py
│   │   └── utils.py
│   ├── models
│   │   ├── BiGAN.py
│   │   ├── aae.py
│   │   ├── age.py
│   │   ├── base.py
│   │   ├── cvae.py
│   │   ├── ddpm.py
│   │   ├── factor_vae.py
│   │   ├── gan.py
│   │   ├── info_gan.py
│   │   ├── made.py
│   │   ├── pixelcnn.py
│   │   ├── speed_gan.py
│   │   ├── tar.py
│   │   ├── vae.py
│   │   ├── vae_gan.py
│   │   ├── vqvae.py
│   │   ├── wgan.py
│   │   └── wgan_gp.py
│   ├── networks
│   │   ├── base.py
│   │   ├── basic.py
│   │   ├── conv32.py
│   │   ├── conv64.py
│   │   ├── utils.py
│   │   └── vqvae.py
│   ├── train.py
│   └── utils
│       ├── distributions.py
│       ├── losses.py
│       ├── toy.py
│       ├── utils.py
│       └── visual.py
└── wiki
    └── wgan.adoc

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
logs/**
data/
--------------------------------------------------------------------------------
/assets/ (binary image files):
--------------------------------------------------------------------------------
The sample images listed under assets/ in the tree above (aae/, age/, beta_vae/,
bigan/, cvae/, ddpm/, factor_vae/, gan/, ggan/, infogan/, lsgan/, made/,
pixelcnn/, tar/, vae/, vaegan/, vqvae/, wgan/, wgangp/ and wiki/wgan/) are binary
files and are not inlined in this dump. Each one is available at
https://raw.githubusercontent.com/Victarry/Image-Generation-models/3609ddb78305943a0af8495d680599da6cadcd90/assets/<subdir>/<filename>
(for example assets/aae/celeba.jpg or assets/wiki/wgan/logits.png).
--------------------------------------------------------------------------------
/assets/convert.py:
--------------------------------------------------------------------------------
from PIL import Image
from pathlib import Path
import os

path = Path(".")
for x in path.glob("**/*.png"):
    img = Image.open(x)
    os.remove(x)
    new_x = x.with_suffix(".jpg")
    img.save(new_x)
--------------------------------------------------------------------------------
/configs/callbacks/ar_models.yaml:
--------------------------------------------------------------------------------
defaults:
  - sample
  - tqdm
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
defaults:
  - eval_fid
  - latent_visual
  - sample
  - traverse_latent
  - tqdm
--------------------------------------------------------------------------------
/configs/callbacks/eval_fid.yaml:
--------------------------------------------------------------------------------
eval_fid:
  _target_: src.callbacks.evaluation.FIDEvaluationCallback
--------------------------------------------------------------------------------
/configs/callbacks/latent_visual.yaml:
--------------------------------------------------------------------------------
latent_visual:
  _target_: src.callbacks.visualization.LatentVisualizationCallback
--------------------------------------------------------------------------------
/configs/callbacks/sample.yaml:
--------------------------------------------------------------------------------
sample:
  _target_: src.callbacks.visualization.SampleImagesCallback
  batch_size: 64
  every_n_epochs: 1
--------------------------------------------------------------------------------
/configs/callbacks/tqdm.yaml:
--------------------------------------------------------------------------------
tqdm:
  _target_: pytorch_lightning.callbacks.progress.TQDMProgressBar
  refresh_rate: 5
--------------------------------------------------------------------------------
/configs/callbacks/traverse_latent.yaml:
--------------------------------------------------------------------------------
traverse:
  _target_: src.callbacks.visualization.TraverseLatentCallback
--------------------------------------------------------------------------------
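Each callback config above is a mapping whose `_target_` key names the class to construct; the remaining keys become constructor arguments. A minimal sketch of how Hydra materializes such a node, shown with the `tqdm` entry because pytorch_lightning is an installable dependency that can be imported outside this repository (the repo's own callback classes in src/callbacks are presumably built the same way):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# The tqdm callback config from configs/callbacks/tqdm.yaml, as a config node.
cfg = OmegaConf.create({
    "tqdm": {
        "_target_": "pytorch_lightning.callbacks.progress.TQDMProgressBar",
        "refresh_rate": 5,
    }
})

# instantiate() imports the dotted path in _target_ and calls it with the
# remaining keys, i.e. TQDMProgressBar(refresh_rate=5).
bar = instantiate(cfg.tqdm)
print(type(bar).__name__)  # TQDMProgressBar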
/configs/config.yaml:
--------------------------------------------------------------------------------
# @package _global_

# specify here default training configuration
defaults:
  - callbacks: default
  - trainer: default
  - model: null
  - networks: null
  - datamodule: null
  - logger: tensorboard # set logger here or use command line (e.g. `python run.py logger=wandb`)
  - hydra: default
  - _self_
  - experiment: null

  # enable color logging
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog
  - override hydra/launcher: joblib

# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}

# path to folder with data
data_dir: ${work_dir}/data/

# path to logging
log_dir: logs

# exp_name
exp_name: ${now:%Y-%m-%d}/${now:%H-%M-%S}

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: False
--------------------------------------------------------------------------------
/configs/datamodule/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /networks: conv_64

datamodule:
  _target_: src.datamodules.celeba.CelebADataModule
  data_dir: ${data_dir}
  width: 64
  height: 64
  channels: 3
  batch_size: 128
  num_workers: 8
  n_classes: None
  transforms:
    convert: True
    normalize: True
    resize:
      width: ${datamodule.width}
      height: ${datamodule.height}
--------------------------------------------------------------------------------
/configs/datamodule/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /networks: conv_32

datamodule:
  _target_: src.datamodules.cifar10.CIFAR10DataModule
  data_dir: ${data_dir}
  channels: 3
  width: 32
  height: 32
  batch_size: 128
  num_workers: 8
  n_classes: 10
  transforms:
    convert: True
    normalize: True
--------------------------------------------------------------------------------
/configs/datamodule/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /networks: conv_64

datamodule:
  _target_: src.datamodules.dsprite.DataModule
  data_dir: ${data_dir}
  channels: 1
  width: 64
  height: 64
  batch_size: 128
  num_workers: 4
  transforms:
    grayscale: True
    normalize: False
--------------------------------------------------------------------------------
/configs/datamodule/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /networks: conv_mnist

datamodule:
  _target_: src.datamodules.mnist.MNISTDataModule
  data_dir: ${data_dir}
  channels: 1
  width: 28
  height: 28
  batch_size: 128
  num_workers: 8
  n_classes: 10
  transforms:
    convert: True
    normalize: True
    grayscale: True
--------------------------------------------------------------------------------
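The `${...}` references in these datamodule configs are OmegaConf interpolations: `${data_dir}` points at the global `data_dir` defined in config.yaml (which itself builds on `work_dir`), and `${datamodule.width}` refers back into the composed config. A small sketch of how they resolve; the checkout path is hypothetical, and `work_dir` is written as a literal here because `${hydra:runtime.cwd}` is only available inside a Hydra run:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "work_dir": "/home/user/Image-Generation-models",  # hypothetical checkout path
    "data_dir": "${work_dir}/data/",
    "datamodule": {
        "width": 64,
        "height": 64,
        "data_dir": "${data_dir}",
        "transforms": {"resize": {"width": "${datamodule.width}"}},
    },
})

# Interpolations are resolved lazily, on access.
print(cfg.datamodule.data_dir)                 # /home/user/Image-Generation-models/data/
print(cfg.datamodule.transforms.resize.width)  # 64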
/configs/experiment/aae/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: aae
  - override /networks: conv_64
  - override /datamodule: celeba

model:
  loss_mode: vanilla
  recon_weight: 10
  latent_dim: 64
trainer:
  max_epochs: 100
exp_name: aae/celeba
--------------------------------------------------------------------------------
/configs/experiment/aae/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: aae
  - override /networks: conv_32
  - override /datamodule: cifar10

model:
  loss_mode: vanilla
  recon_weight: 10
  latent_dim: 64
trainer:
  max_epochs: 100
exp_name: aae/cifar10
--------------------------------------------------------------------------------
/configs/experiment/aae/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: aae
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: vanilla
  latent_dim: 8
trainer:
  max_epochs: 50
exp_name: aae/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/age/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: age
  - override /networks: conv_64
  - override /datamodule: celeba

model:
  latent_dim: 64
  e_recon_z_weight: 0
  e_recon_x_weight: 10
  g_recon_z_weight: 1000
  g_recon_x_weight: 10
  lrG: 1e-3
  lrE: 1e-3
  drop_lr_epoch: 20
  g_updates: 3
trainer:
  max_epochs: 100
  check_val_every_n_epoch: 5
networks:
  encoder:
    norm_type: batch
  decoder:
    norm_type: batch
datamodule:
  batch_size: 64

exp_name: age_celeba/z${model.recon_z_weight}_x${model.recon_x_weight}_lrG${model.lrG}_lrE${model.lrE}_batch${datamodule.batch_size}
--------------------------------------------------------------------------------
/configs/experiment/age/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: age
  - override /networks: conv_32
  - override /datamodule: cifar10

model:
  latent_dim: 128
  e_recon_z_weight: 0
  e_recon_x_weight: 10
  g_recon_z_weight: 1000
  g_recon_x_weight: 10
  lrG: 1e-3
  lrE: 1e-3
  drop_lr_epoch: 20
  g_updates: 3
trainer:
  max_epochs: 150
networks:
  encoder:
    norm_type: batch
  decoder:
    norm_type: batch
datamodule:
  batch_size: 256

exp_name: age_cifar10/z${model.recon_z_weight}_x${model.recon_x_weight}_lrG${model.lrG}_lrE${model.lrE}_batch${datamodule.batch_size}
--------------------------------------------------------------------------------
/configs/experiment/age/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: age
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  latent_dim: 10
  e_recon_x_weight: 10
  e_recon_z_weight: 0
  g_recon_x_weight: 0
  g_recon_z_weight: 1000
  lrG: 7e-4
  lrE: 2e-3
  drop_lr_epoch: 30
trainer:
  max_epochs: 100
networks:
  encoder:
    norm_type: batch
  decoder:
    norm_type: batch
datamodule:
  batch_size: 512

exp_name: age_mnist/z${model.recon_z_weight}_x${model.recon_x_weight}_lrG${model.lrG}_lrE${model.lrE}_batch${datamodule.batch_size}
--------------------------------------------------------------------------------
/configs/experiment/beta_vae/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /networks: conv_64
  - override /datamodule: celeba

model:
  beta: 64
  latent_dim: 10
  decoder_dist: gaussian
networks:
  encoder:
    norm_type: null
  decoder:
    norm_type: null
trainer:
  max_epochs: 100
exp_name: beta_vae/celeba
--------------------------------------------------------------------------------
/configs/experiment/beta_vae/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /networks: conv_64
  - override /datamodule: dsprites

exp_name: beta_vae/dsprites
model:
  beta: 4
  latent_dim: 10
  decoder_dist: bernoulli
networks:
  encoder:
    norm_type: null
  decoder:
    norm_type: null
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/bigan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: bigan
  - override /networks: conv_64
  - override /datamodule: celeba

model:
  loss_mode: hinge
  lrG: 2e-4
  lrD: 1e-4
trainer:
  max_epochs: 100
networks:
  encoder:
    norm_type: null
  decoder:
    norm_type: batch
exp_name: bigan/celeba_hinge
--------------------------------------------------------------------------------
/configs/experiment/bigan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: bigan
  - override /networks: conv_32
  - override /datamodule: cifar10

model:
  loss_mode: vanilla
trainer:
  max_epochs: 200
exp_name: bigan/cifar10
--------------------------------------------------------------------------------
/configs/experiment/bigan/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: bigan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: vanilla
  hidden_dim: 128
trainer:
  max_epochs: 50
exp_name: bigan/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/contra_gan/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: contra_gan
  - override /networks: conv_64
  - override /datamodule: dsprites

exp_name: contra_gan/dsprites
model:
  loss_mode: lsgan
  lrG: 2e-4
  lrD: 2e-4
datamodule:
  batch_size: 64
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/contra_vae/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: contra_vae
  - override /networks: conv_64
  - override /datamodule: dsprites

exp_name: contra_vae/dsprites
model:
  beta: 4
  latent_dim: 10
  decoder_dist: bernoulli
networks:
  encoder:
    norm_type: False
  decoder:
    norm_type: False
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/cvae/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: cvae
  - override /datamodule: cifar10

exp_name: cvae/cifar10
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/cvae/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: cvae
  - override /networks: conv_mnist
  - override /datamodule: mnist

exp_name: cvae/mnist
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/ddpm/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ddpm
  - override /datamodule: celeba

exp_name: ddpm/celeba

trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10

model:
  dim_mults: [1, 2, 4, 8]
  timesteps: 1000
--------------------------------------------------------------------------------
/configs/experiment/ddpm/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ddpm
  - override /datamodule: cifar10

exp_name: ddpm/cifar10

trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10

model:
  dim_mults: [1, 2, 4]
  timesteps: 1000
--------------------------------------------------------------------------------
/configs/experiment/ddpm/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ddpm
  - override /datamodule: mnist

exp_name: ddpm/mnist

trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10

model:
  dim_mults: [2, 4]
  timesteps: 1000
--------------------------------------------------------------------------------
/configs/experiment/factor_vae/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: factor_vae
  - override /networks: conv_64
  - override /datamodule: celeba

exp_name: factor_vae/celeba
model:
  adv_weight: 6.4
  latent_dim: 10
  lr: 2e-4
  lrD: 1e-4
  decoder_dist: gaussian

trainer:
  max_epochs: 100

networks:
  encoder:
    norm_type: null
  decoder:
    norm_type: null
--------------------------------------------------------------------------------
/configs/experiment/factor_vae/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: factor_vae
  - override /networks: conv_64
  - override /datamodule: dsprites

exp_name: factor_vae/dsprites
model:
  adv_weight: 35
  latent_dim: 10
  lr: 2e-4
  lrD: 1e-4
  decoder_dist: bernoulli

trainer:
  max_epochs: 100

networks:
  encoder:
    norm_type: null
  decoder:
    norm_type: null
--------------------------------------------------------------------------------
/configs/experiment/ggan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_64
  - override /datamodule: celeba

exp_name: ggan/celeba
model:
  loss_mode: hinge
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/ggan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_32
  - override /datamodule: cifar10

exp_name: ggan/cifar10
model:
  loss_mode: hinge
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/ggan/mnist_conv.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: hinge
exp_name: ggan/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/infogan/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: info_gan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: lsgan
exp_name: info_gan/mnist

trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/lsgan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_64
  - override /datamodule: celeba

exp_name: lsgan/celeba
model:
  loss_mode: lsgan
  lrG: 2e-4
  lrD: 2e-4
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/lsgan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_32
  - override /datamodule: cifar10

exp_name: lsgan/cifar10
model:
  loss_mode: lsgan
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/lsgan/conv_mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: lsgan
exp_name: lsgan/mnist
--------------------------------------------------------------------------------
/configs/experiment/lsgan/mlp_mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: mlp_mnist
  - override /datamodule: mnist

model:
  loss_mode: lsgan
exp_name: lsgan/mnist_mlp
--------------------------------------------------------------------------------
/configs/experiment/made/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: made
  - override /datamodule: mnist

exp_name: made/mnist

datamodule:
  transforms:
    grayscale: True
    normalize: False
trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10
--------------------------------------------------------------------------------
/configs/experiment/pixelcnn/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: pixelcnn
  - override /datamodule: celeba

exp_name: pixelcnn/celeba

datamodule:
  transforms:
    grayscale: False
    normalize: False
trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10
--------------------------------------------------------------------------------
/configs/experiment/pixelcnn/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: pixelcnn
  - override /datamodule: cifar10

exp_name: pixelcnn/cifar10

datamodule:
  transforms:
    grayscale: False
    normalize: False
trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10
--------------------------------------------------------------------------------
/configs/experiment/pixelcnn/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: pixelcnn
  - override /datamodule: mnist

exp_name: pixelcnn/mnist

datamodule:
  transforms:
    grayscale: True
    normalize: False
trainer:
  max_epochs: 100
  check_val_every_n_epoch: 10
--------------------------------------------------------------------------------
/configs/experiment/tar/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: tar
  - override /datamodule: mnist

exp_name: tar/mnist

datamodule:
  transforms:
    grayscale: True
    normalize: False
trainer:
  max_epochs: 20
  check_val_every_n_epoch: 1
--------------------------------------------------------------------------------
/configs/experiment/tar/mnist_cond.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: tar
  - override /datamodule: mnist

exp_name: tar/mnist
model:
  class_cond: True

datamodule:
  transforms:
    grayscale: True
    normalize: False
trainer:
  max_epochs: 20
  check_val_every_n_epoch: 1
--------------------------------------------------------------------------------
/configs/experiment/vae/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /networks: conv_64
  - override /datamodule: celeba

exp_name: vae/celeba
trainer:
  max_epochs: 100
datamodule:
  batch_size: 128
--------------------------------------------------------------------------------
/configs/experiment/vae/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /networks: conv_32
  - override /datamodule: cifar10

exp_name: vae/cifar10_${model.lr}
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/vae/mnist_conv.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /datamodule: mnist

exp_name: vae/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/vae/mnist_mlp.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae
  - override /networks: mlp
  - override /datamodule: mnist

exp_name: vae/mnist_mlp
--------------------------------------------------------------------------------
/configs/experiment/vaegan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae_gan
  - override /networks: conv_64
  - override /datamodule: celeba

model:
  loss_mode: vanilla
  recon_weight: 1e-6

trainer:
  max_epochs: 100
exp_name: vaegan/celeba
--------------------------------------------------------------------------------
/configs/experiment/vaegan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae_gan
  - override /networks: conv_32
  - override /datamodule: cifar10

model:
  loss_mode: vanilla
  recon_weight: 1e-5

trainer:
  max_epochs: 100
exp_name: vaegan/cifar10
--------------------------------------------------------------------------------
/configs/experiment/vaegan/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vae_gan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: vanilla
  recon_weight: 1e-3
exp_name: vaegan/mnist
--------------------------------------------------------------------------------
/configs/experiment/vanilla_gan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_64
  - override /datamodule: celeba

exp_name: vanilla_gan/celeba
model:
  loss_mode: vanilla
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/vanilla_gan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_32
  - override /datamodule: cifar10

exp_name: vanilla_gan/cifar10
model:
  loss_mode: vanilla
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/vanilla_gan/dsprites.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_64
  - override /datamodule: dsprites

exp_name: vanilla_gan/dsprites
model:
  loss_mode: vanilla
  lrG: 2e-3
  lrD: 2e-3
datamodule:
  batch_size: 1024
trainer:
  max_epochs: 100
--------------------------------------------------------------------------------
/configs/experiment/vanilla_gan/mnist_conv.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: conv_mnist
  - override /datamodule: mnist

model:
  loss_mode: vanilla
exp_name: vanilla_gan/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/vanilla_gan/mnist_mlp.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: gan
  - override /networks: mlp
  - override /datamodule: mnist

model:
  loss_mode: vanilla
exp_name: vanilla_gan/mnist_mlp
--------------------------------------------------------------------------------
/configs/experiment/vqvae/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vqvae
  - override /networks: vqvae
  - override /datamodule: celeba

trainer:
  max_epochs: 100
exp_name: vqvae/celeba
--------------------------------------------------------------------------------
/configs/experiment/vqvae/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vqvae
  - override /networks: vqvae
  - override /datamodule: cifar10

trainer:
  max_epochs: 100
exp_name: vqvae/cifar10
--------------------------------------------------------------------------------
/configs/experiment/vqvae/mnist.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: vqvae
  - override /networks: vqvae
  - override /datamodule: mnist

trainer:
  max_epochs: 100
exp_name: vqvae/mnist
--------------------------------------------------------------------------------
/configs/experiment/wgan/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan
  - override /networks: conv_64
  - override /datamodule: celeba

trainer:
  max_epochs: 200
  check_val_every_n_epoch: 1
networks:
  encoder:
    ndf: 64
  decoder:
    ngf: 64

lr: 5e-5
model:
  lrG: ${lr}
  lrD: ${lr}
  n_critic: 5
  clip_weight: 0.01
  eval_fid: True
exp_name: wgan/celeba_lr${lr}_clip${model.clip_weight}
--------------------------------------------------------------------------------
/configs/experiment/wgan/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan
  - override /networks: conv_32
  - override /datamodule: cifar10

trainer:
  max_epochs: 2000
  check_val_every_n_epoch: 5
lr: 2e-4
model:
  lrG: ${lr}
  lrD: ${lr}
  n_critic: 5
  clip_weight: 0.01
  eval_fid: True
exp_name: wgan/cifar10_lr_${lr}
--------------------------------------------------------------------------------
/configs/experiment/wgan/mnist_conv.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan
  - override /networks: conv_mnist
  - override /datamodule: mnist

lr: 6e-4
trainer:
  max_epochs: 200
model:
  lrG: ${lr}
  lrD: ${lr}
  n_critic: 5
  clip_weight: 0.1
exp_name: wgan/mnist_conv
--------------------------------------------------------------------------------
/configs/experiment/wgan/mnist_mlp.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan
  - override /networks: mlp
  - override /datamodule: mnist

trainer:
  max_epochs: 50
exp_name: wgan_mlp_mnist
--------------------------------------------------------------------------------
/configs/experiment/wgan_gp/celeba.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan_gp
  - override /networks: conv_64
  - override /datamodule: celeba

networks:
  encoder:
    norm_type: instance
  decoder:
    norm_type: instance

trainer:
  max_epochs: 300
exp_name: wgangp/celeba
--------------------------------------------------------------------------------
/configs/experiment/wgan_gp/cifar10.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan_gp
  - override /networks: conv_32
  - override /datamodule: cifar10

networks:
  encoder:
    norm_type: instance
  decoder:
    norm_type: instance

model:
  lrD: 1e-3
  lrG: 1e-3

trainer:
  max_epochs: 100
exp_name: wgangp/cifar10
--------------------------------------------------------------------------------
/configs/experiment/wgan_gp/mnist_conv.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan_gp
  - override /networks: conv_mnist
  - override /datamodule: mnist

networks:
  encoder:
    norm_type: instance
  decoder:
    norm_type: instance

trainer:
  max_epochs: 100
exp_name: wgangp/mnist
--------------------------------------------------------------------------------
/configs/experiment/wgan_gp/mnist_mlp.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: wgan_gp
  - override /networks: mlp
  - override /datamodule: mnist

networks:
  encoder:
    norm_type: batch
  decoder:
    norm_type: batch

trainer:
  max_epochs: 100
exp_name: wgangp/mnist_mlp
--------------------------------------------------------------------------------
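Every experiment file above uses `# @package _global_` plus a `defaults` list of overrides, so selecting one experiment swaps in a model, network set, and datamodule and then patches individual fields. The intended entry point appears to be run.py, e.g. `python run.py experiment=vanilla_gan/mnist_conv`. A sketch of composing the same configuration programmatically with Hydra's compose API, assuming it is run from the repository root with the packages from requirements.txt (including the colorlog/joblib Hydra plugins referenced in config.yaml) installed:

from hydra import initialize, compose

# Build the config that `python run.py experiment=vanilla_gan/mnist_conv`
# would use, without launching a training run.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config", overrides=["experiment=vanilla_gan/mnist_conv"])

print(cfg.exp_name)          # vanilla_gan/mnist_conv
print(cfg.model.loss_mode)   # vanilla
print(cfg.datamodule.width)  # 28 (set by the mnist datamodule config)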
10 | job: 11 | env_set: 12 | EXAMPLE_VAR: "example_value" 13 | chdir: True 14 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 3 | save_dir: "tensorboard/" 4 | name: "" 5 | version: "" 6 | log_graph: False 7 | default_hp_metric: True 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /configs/model/aae.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.aae.AAE 2 | encoder: ${networks.encoder} 3 | decoder: ${networks.decoder} 4 | netD: ${networks.encoder} 5 | loss_mode: vanilla 6 | latent_dim: 128 7 | lrG: 2e-4 8 | lrD: 2e-4 9 | b1: 0.5 10 | b2: 0.999 11 | recon_weight: 1 12 | prior: "normal" -------------------------------------------------------------------------------- /configs/model/age.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.age.AGE 2 | encoder: ${networks.encoder} 3 | decoder: ${networks.decoder} 4 | latent_dim: 128 5 | lrE: 2e-4 6 | lrG: 2e-4 7 | b1: 0.5 8 | b2: 0.999 9 | e_recon_z_weight: 0 10 | e_recon_x_weight: 1 11 | g_recon_z_weight: 1 12 | g_recon_x_weight: 0 13 | norm_z: true -------------------------------------------------------------------------------- /configs/model/bigan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.BiGAN.BiGAN 2 | loss_mode: vanilla 3 | hidden_dim: 512 4 | latent_dim: 100 5 | lrG: 0.0002 6 | lrD: 0.0002 7 | b1: 0.5 8 | b2: 0.999 9 | encoder: ${networks.encoder} 10 | decoder: ${networks.decoder} -------------------------------------------------------------------------------- /configs/model/cvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | _target_: src.models.cvae.cVAE 4 | latent_dim: 128 5 | lr: 1e-4 6 | b1: 0.9 7 | b2: 0.999 8 | beta: 1 9 | encoder: ${networks.encoder} 10 | decoder: ${networks.decoder} 11 | decoder_dist: gaussian 12 | n_classes: ${datamodule.n_classes} 13 | encode_label: True -------------------------------------------------------------------------------- /configs/model/ddpm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | _target_: src.models.ddpm.DDPM 4 | hidden_dim: 64 5 | lr: 0.0001 6 | b1: 0.9 7 | b2: 0.999 8 | timesteps: 1000 9 | loss_type: l1 10 | optim: adam 11 | dim_mults: [1, 2, 4, 8] -------------------------------------------------------------------------------- /configs/model/factor_vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.factor_vae.FactorVAE 2 | encoder: ${networks.encoder} 3 | decoder: ${networks.decoder} 4 | loss_mode: lsgan 5 | latent_dim: 10 6 | lr: 2e-4 7 | lrD: 1e-4 8 | adv_weight: 4 9 | ae_b1: 0.9 10 | ae_b2: 0.999 11 | adv_b1: 0.5 12 | adv_b2: 0.9 13 | decoder_dist: gaussian 14 | -------------------------------------------------------------------------------- /configs/model/gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.gan.GAN 2 | loss_mode: vanilla 3 | latent_dim: 
100 4 | lrG: 0.0002 5 | lrD: 0.0002 6 | b1: 0.5 7 | b2: 0.999 8 | netG: ${networks.decoder} 9 | netD: ${networks.encoder} -------------------------------------------------------------------------------- /configs/model/info_gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.info_gan.InfoGAN 2 | loss_mode: vanilla 3 | discrete_dim: 1 4 | discrete_value: 10 5 | continuous_dim: 2 6 | noise_dim: 62 7 | encode_dim: 1024 8 | lambda_I: 1 9 | lrG: 0.001 10 | lrD: 0.0002 11 | lrQ: 0.0002 12 | b1: 0.5 13 | b2: 0.999 14 | netG: ${networks.decoder} 15 | netD: ${networks.encoder} -------------------------------------------------------------------------------- /configs/model/made.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | 4 | _target_: src.models.made.MADE 5 | hidden_dim: 1024 6 | n_layer: 3 7 | lr: 1e-3 -------------------------------------------------------------------------------- /configs/model/pixelcnn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | 4 | _target_: src.models.pixelcnn.PixelCNN 5 | hidden_dim: 64 6 | lr: 1e-3 7 | n_classes: ${datamodule.n_classes} 8 | class_condition: False -------------------------------------------------------------------------------- /configs/model/speed_gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.speed_gan.GAN 2 | loss_mode: vanilla 3 | latent_dim: 100 4 | lrG: 0.0002 5 | lrD: 0.0002 6 | b1: 0.5 7 | b2: 0.999 8 | netG: ${networks.decoder} 9 | netD: ${networks.encoder} -------------------------------------------------------------------------------- /configs/model/tar.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | _target_: src.models.tar.TAR 4 | lr: 1e-3 5 | b1: 0.9 6 | b2: 0.999 7 | d_model: 256 8 | nhead: 4 9 | num_layers: 4 10 | class_cond: false 11 | n_classes: ${datamodule.n_classes} -------------------------------------------------------------------------------- /configs/model/vae.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.vae.VAE 2 | latent_dim: 128 3 | lr: 1e-4 4 | b1: 0.9 5 | b2: 0.999 6 | beta: 1 7 | encoder: ${networks.encoder} 8 | decoder: ${networks.decoder} 9 | decoder_dist: gaussian -------------------------------------------------------------------------------- /configs/model/vae_gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.vae_gan.VAEGAN 2 | loss_mode: vanilla 3 | latent_dim: 100 4 | lr: 0.0002 5 | b1: 0.5 6 | b2: 0.99 7 | recon_weight: 1e-4 8 | encoder: ${networks.encoder} 9 | decoder: ${networks.decoder} -------------------------------------------------------------------------------- /configs/model/vqvae.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override /callbacks@_global_: ar_models 3 | _target_: src.models.vqvae.VQVAE 4 | latent_dim: 64 5 | lr: 0.001 # a proper learning rate will converge faster 6 | b1: 0.9 7 | b2: 0.999 8 | beta: 0.25 9 | K: 512 10 | optim: adam 11 | encoder: ${networks.encoder} 12 | decoder: ${networks.decoder} 
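The beta, K, and latent_dim entries above parameterize the vector-quantization bottleneck, so a minimal sketch of that step may help as a reference. This is a generic illustration of the standard VQ-VAE codebook lookup, straight-through estimator, and beta-weighted commitment loss, written under the assumption that the model follows the original formulation; it is not the code in src/models/vqvae.py, and the function and variable names are hypothetical.

import torch
import torch.nn.functional as F

def vector_quantize(z_e, codebook, beta=0.25):
    # z_e: (N, D) encoder outputs; codebook: (K, D) embedding table (K=512, D=latent_dim above)
    dists = torch.cdist(z_e, codebook)                   # (N, K) pairwise L2 distances
    idx = dists.argmin(dim=1)                            # index of the nearest code per latent
    z_q = codebook[idx]                                  # (N, D) quantized latents
    codebook_loss = F.mse_loss(z_q, z_e.detach())        # pulls codes toward encoder outputs
    commit_loss = beta * F.mse_loss(z_e, z_q.detach())   # commitment term, weighted by beta
    z_q_st = z_e + (z_q - z_e).detach()                  # straight-through gradient to the encoder
    return z_q_st, codebook_loss + commit_loss, idx

# Example: quantize 8 encoder outputs of dimension 64 against a 512-entry codebook.
z_e = torch.randn(8, 64, requires_grad=True)
codebook = torch.randn(512, 64, requires_grad=True)
z_q, vq_loss, idx = vector_quantize(z_e, codebook)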
-------------------------------------------------------------------------------- /configs/model/wgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.wgan.WGAN 2 | # model parameters 3 | latent_dim: 100 4 | lrG: 5e-5 5 | lrD: 5e-5 6 | alpha: 0.99 7 | netG: ${networks.decoder} 8 | netD: ${networks.encoder} 9 | # special for wgan 10 | n_critic: 5 11 | clip_weight: 0.01 -------------------------------------------------------------------------------- /configs/model/wgan_gp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.wgan_gp.WGAN 2 | # model parameters 3 | latent_dim: 100 4 | lrG: 1e-4 5 | lrD: 1e-4 6 | b1: 0 7 | b2: 0.9 8 | netG: ${networks.decoder} 9 | netD: ${networks.encoder} 10 | # special for wgan 11 | n_critic: 5 12 | gp_weight: 10 -------------------------------------------------------------------------------- /configs/networks/conv_32.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.conv32.Decoder 3 | input_channel: null 4 | output_channel: null 5 | ngf: 64 6 | encoder: 7 | _target_: src.networks.conv32.Encoder 8 | input_channel: null 9 | output_channel: null 10 | ndf: 64 -------------------------------------------------------------------------------- /configs/networks/conv_64.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.conv64.Decoder 3 | input_channel: null 4 | output_channel: null 5 | ngf: 64 6 | norm_type: batch 7 | encoder: 8 | _target_: src.networks.conv64.Encoder 9 | input_channel: null 10 | output_channel: null 11 | ndf: 64 12 | norm_type: batch -------------------------------------------------------------------------------- /configs/networks/conv_mnist.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.basic.ConvDecoder 3 | input_channel: null 4 | output_channel: null 5 | ngf: 32 6 | norm_type: batch 7 | encoder: 8 | _target_: src.networks.basic.ConvEncoder 9 | input_channel: null 10 | output_channel: null 11 | ndf: 32 12 | norm_type: batch -------------------------------------------------------------------------------- /configs/networks/mlp.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.basic.MLPDecoder 3 | input_channel: null 4 | output_channel: null 5 | width: ${datamodule.width} 6 | height: ${datamodule.height} 7 | hidden_dims: 8 | - 1200 9 | - 1200 10 | - 1200 11 | - 4096 12 | output_act: tanh 13 | norm_type: batch 14 | encoder: 15 | _target_: src.networks.basic.MLPEncoder 16 | input_channel: null 17 | output_channel: null 18 | width: ${datamodule.width} 19 | height: ${datamodule.height} 20 | hidden_dims: 21 | - 1200 22 | - 1200 23 | dropout: 0 24 | output_act: identity 25 | norm_type: batch -------------------------------------------------------------------------------- /configs/networks/mlp_small.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.basic.MLPDecoder 3 | input_channel: null 4 | output_channel: null 5 | width: ${datamodule.width} 6 | height: ${datamodule.height} 7 | hidden_dims: 8 | - 128 9 | - 256 10 | - 512 11 | output_act: tanh 12 | norm_type: batch 13 | encoder: 14 | _target_: src.networks.basic.MLPEncoder 15 | input_channel: null 16 | 
output_channel: null 17 | width: ${datamodule.width} 18 | height: ${datamodule.height} 19 | hidden_dims: 20 | - 128 21 | - 256 22 | - 512 23 | dropout: 0 24 | norm_type: batch -------------------------------------------------------------------------------- /configs/networks/vqvae.yaml: -------------------------------------------------------------------------------- 1 | decoder: 2 | _target_: src.networks.vqvae.Decoder 3 | input_channel: null 4 | output_channel: null 5 | encoder: 6 | _target_: src.networks.vqvae.Encoder 7 | input_channel: null 8 | output_channel: null -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # set `1` to train on GPU, `0` to train on CPU only 4 | devices: 1 5 | 6 | max_epochs: 20 7 | 8 | enable_model_summary: False 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.9.1 4 | pytorch-lightning>=2.0.0 5 | 6 | 7 | # --------- hydra --------- # 8 | hydra-core>=1.1.0 9 | hydra-colorlog>=1.1.0 10 | hydra-optuna-sweeper>=1.1.0 11 | hydra-joblib-launcher>=1.1.5 12 | torchmetrics>=0.5.1 13 | torch-fidelity>=0.3.0 14 | # hydra-ax-sweeper 15 | # hydra-ray-launcher 16 | # hydra-submitit-launcher 17 | 18 | # --------- loggers --------- # 19 | tensorboardx 20 | wandb 21 | # neptune-client 22 | # mlflow 23 | # comet-ml 24 | # torch_tb_profiler 25 | 26 | # --------- linters --------- # 27 | pre-commit # hooks for applying linters on commit 28 | black # code formatting 29 | isort # import sorting 30 | flake8 # code analysis 31 | 32 | # --------- others --------- # 33 | python-dotenv # loading env variables from .env file 34 | rich # beautiful text formatting in terminal 35 | pytest # tests 36 | sh # for running bash commands in some tests 37 | scikit-learn # used in some callbacks 38 | seaborn # used in some callbacks 39 | jupyterlab # better jupyter notebooks 40 | pudb # debugger 41 | GPUtil # get info about GPUs 42 | einops # Framework to use einstein-like notation 43 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | 5 | @hydra.main(config_path="configs", config_name="config", version_base='1.1') 6 | def train(config: DictConfig) -> None: 7 | from src.train import train 8 | from src.utils import utils 9 | 10 | # Pretty print config using Rich library 11 | if config.get("print_config"): 12 | utils.print_config(config, resolve=True) 13 | 14 | # Train model 15 | return train(config) 16 | 17 | 18 | if __name__ == "__main__": 19 | train() 20 | -------------------------------------------------------------------------------- /src/callbacks/evaluation.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchmetrics 3 | from torchmetrics.image.fid import FrechetInceptionDistance 4 | import torch 5 | from src.models.base import ValidationResult 6 | 7 | 8 | class FIDEvaluationCallback(pl.Callback): 9 | def __init__(self, every_n_epochs=1): 10 | self.every_n_epoch = every_n_epochs 11 | 12 | def image_float2int(self, imgs, pl_module): 13 | if 
pl_module.input_normalize: 14 | imgs = (imgs + 1) / 2 15 | imgs = (imgs * 255).to(torch.uint8) 16 | return imgs 17 | 18 | def on_validation_epoch_start(self, trainer, pl_module) -> None: 19 | if pl_module.channels == 3 and trainer.current_epoch % self.every_n_epoch == 0: 20 | self.fid = FrechetInceptionDistance().to(pl_module.device) 21 | 22 | def on_validation_batch_end(self, trainer, pl_module, outputs: ValidationResult, batch, batch_idx): 23 | if pl_module.channels == 3 and trainer.current_epoch % self.every_n_epoch == 0: 24 | real_imgs, fake_images = outputs.real_image, outputs.fake_image 25 | self.fid.update(self.image_float2int(real_imgs, pl_module), real=True) 26 | self.fid.update(self.image_float2int(fake_images, pl_module), real=False) 27 | 28 | def on_validation_epoch_end(self, trainer, pl_module: pl.LightningModule): 29 | if pl_module.channels == 3 and trainer.current_epoch % self.every_n_epoch == 0: 30 | pl_module.log("metrics/fid", self.fid.compute(), on_epoch=True) 31 | -------------------------------------------------------------------------------- /src/callbacks/util.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | import subprocess 3 | 4 | 5 | class GifCallback(Callback): 6 | def __init__(self) -> None: 7 | super().__init__() 8 | 9 | def on_train_end(self, *args, **kargs) -> None: 10 | subprocess.call( 11 | [ 12 | "ffmpeg", 13 | "-f", 14 | "image2", 15 | "-framerate", 16 | "4", 17 | "-i", 18 | "results/%d.jpg", 19 | "video.gif", 20 | ] 21 | ) 22 | -------------------------------------------------------------------------------- /src/callbacks/visualization.py: -------------------------------------------------------------------------------- 1 | from random import randint 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | from pathlib import Path 5 | import torchvision 6 | import matplotlib.pyplot as plt 7 | import io 8 | from torchvision.transforms import ToTensor 9 | import torch 10 | from src.models.base import ValidationResult 11 | import PIL 12 | 13 | class SampleImagesCallback(pl.Callback): 14 | def __init__(self, batch_size=64, every_n_epochs=1): 15 | self.batch_size = batch_size 16 | self.every_n_epochs = every_n_epochs 17 | 18 | def on_validation_batch_end(self, trainer, pl_module, outputs: ValidationResult, batch, batch_idx): 19 | if trainer.current_epoch % self.every_n_epochs == 0 and batch_idx == 0: 20 | result_path = Path("results") 21 | result_path.mkdir(parents=True, exist_ok=True) 22 | 23 | real_grid = get_grid_images(outputs.real_image, pl_module) 24 | trainer.logger.experiment.add_image("images/real", real_grid, global_step=trainer.current_epoch) 25 | 26 | if outputs.recon_image is not None: 27 | recon_grid = get_grid_images(outputs.recon_image, pl_module) 28 | trainer.logger.experiment.add_image("images/recon", recon_grid, global_step=trainer.current_epoch) 29 | 30 | if outputs.fake_image is not None: 31 | fake_grid = get_grid_images(outputs.fake_image, pl_module) 32 | trainer.logger.experiment.add_image("images/sample", fake_grid, global_step=trainer.current_epoch) 33 | torchvision.utils.save_image(fake_grid, result_path / f"{trainer.current_epoch}.jpg") 34 | 35 | for key in outputs.others: 36 | if outputs.others[key] is not None: 37 | grid = get_grid_images(outputs.others[key], pl_module) 38 | trainer.logger.experiment.add_image(f"images/{key}", grid, global_step=trainer.current_epoch) 39 | 40 | 41 | class TraverseLatentCallback(pl.Callback): 42 | def 
__init__(self, col=10, row=10) -> None: 43 | super().__init__() 44 | self.col = col 45 | self.row = row 46 | 47 | def generate_traverse_images(self, pl_module, fixed_z=None): 48 | row, col = 11, min(10, pl_module.hparams.latent_dim) 49 | if fixed_z is None: 50 | fixed_z = torch.randn(1, 1, pl_module.hparams.latent_dim).repeat(row, col, 1).reshape(row, col, -1).to(pl_module.device) 51 | else: 52 | fixed_z = fixed_z.reshape(1, 1, pl_module.hparams.latent_dim).repeat(row, col, 1).reshape(row, col, -1) 53 | variation_z = torch.linspace(-3, 3, row).to(pl_module.device) 54 | for i in range(col): 55 | fixed_z[:, i, i] = variation_z # i-th column correspondes to i-th latent unit variation 56 | imgs = pl_module.forward(fixed_z.reshape(row*col, -1)) 57 | grid = get_grid_images(imgs, pl_module, nimgs=row*col, nrow=col) 58 | return grid 59 | 60 | def on_validation_batch_end(self, trainer, pl_module, outputs: ValidationResult, batch, batch_idx): 61 | if batch_idx == 0: 62 | self.z = outputs.encode_latent 63 | 64 | def on_validation_epoch_end(self, trainer, pl_module): 65 | if self.z is not None: 66 | grid1 = self.generate_traverse_images(pl_module, self.z[3]) 67 | trainer.logger.experiment.add_image("sample/fixed_traverse_latents_1", grid1, global_step=trainer.current_epoch) 68 | if self.z is not None: 69 | grid2 = self.generate_traverse_images(pl_module, self.z[6]) 70 | trainer.logger.experiment.add_image("sample/fixed_traverse_latents_2", grid2, global_step=trainer.current_epoch) 71 | 72 | grid = self.generate_traverse_images(pl_module) 73 | trainer.logger.experiment.add_image("sample/random_traverse_latents", grid, global_step=trainer.current_epoch) 74 | 75 | class Visual2DSpaecCallback(pl.Callback): 76 | def __init__(self) -> None: 77 | super().__init__() 78 | 79 | def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 80 | if pl_module.hparams.latent_dim == 2: 81 | x = torch.tensor(np.linspace(-3, 3, 20)).to(pl_module.device) 82 | y = torch.tensor(np.linspace(3, -3, 20)).to(pl_module.device) 83 | xx, yy = torch.meshgrid([y, x]) 84 | latent = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1) # (20*20, 2) 85 | imgs = pl_module.forward(latent) 86 | grid_imgs = get_grid_images(imgs, pl_module, nimgs=400, nrow=20) 87 | trainer.logger.experiment.add_image("sample/grid_imgs", grid_imgs, global_step=trainer.current_epoch) 88 | 89 | class LatentVisualizationCallback(pl.Callback): 90 | def __init__(self) -> None: 91 | super().__init__() 92 | 93 | def on_validation_epoch_start(self, trainer, pl_module) -> None: 94 | if pl_module.hparams.latent_dim == 2: 95 | self.latents = [] 96 | self.labels = [] 97 | 98 | def on_validation_batch_end(self, trainer, pl_module, outputs: ValidationResult, batch, batch_idx): 99 | if pl_module.hparams.latent_dim == 2: 100 | self.latents.append(outputs.encode_latent) 101 | self.labels.append(outputs.label) 102 | 103 | def on_validation_epoch_end(self, trainer, pl_module): 104 | if pl_module.hparams.latent_dim == 2: 105 | latents_array = torch.cat(self.latents).cpu().numpy() 106 | labels_array = torch.cat(self.labels).cpu().numpy() 107 | sort_idx = np.argsort(labels_array) 108 | self.latents = [] 109 | self.labels = [] 110 | img = make_scatter(x=latents_array[:, 0][sort_idx], y=latents_array[:,1][sort_idx], 111 | c=labels_array[sort_idx], xlim=(-3, 3), ylim=(-3, 3)) 112 | trainer.logger.experiment.add_image("val/latent distributions", img, global_step=trainer.current_epoch) 113 | 114 | 115 | def tensor_to_array(*tensors): 116 | output = 
[] 117 | for tensor in tensors: 118 | if isinstance(tensor, torch.Tensor): 119 | output.append(np.array(tensor.detach().cpu().numpy())) 120 | else: 121 | output.append(tensor) 122 | return output 123 | 124 | def make_scatter(x, y, c=None, s=None, xlim=None, ylim=None): 125 | x, y, c, s = tensor_to_array(x, y, c, s) 126 | 127 | plt.figure() 128 | plt.scatter(x=x, y=y, s=s, c=c, cmap="tab10", alpha=1) 129 | if xlim: 130 | plt.xlim(xlim) 131 | if ylim: 132 | plt.ylim(ylim) 133 | plt.title("Latent distribution") 134 | buf = io.BytesIO() 135 | plt.savefig(buf, format='jpeg') 136 | plt.close() 137 | buf.seek(0) 138 | visual_image = ToTensor()(PIL.Image.open(buf)) 139 | return visual_image 140 | 141 | def get_grid_images(imgs, model, nimgs=64, nrow=8): 142 | if model.input_normalize: 143 | grid = torchvision.utils.make_grid( 144 | imgs[:nimgs], nrow=nrow, normalize=True, value_range=(-1, 1), pad_value=1 145 | ) 146 | else: 147 | grid = torchvision.utils.make_grid(imgs[:nimgs], normalize=False, nrow=nrow, pad_value=1) 148 | return grid -------------------------------------------------------------------------------- /src/datamodules/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torchvision import transforms 3 | from PIL import Image 4 | from torch.utils.data import DataLoader 5 | import pytorch_lightning as pl 6 | 7 | 8 | class BaseDatamodule(pl.LightningDataModule): 9 | def __init__(self, width, height, channels, batch_size, num_workers): 10 | super().__init__() 11 | self.batch_size = batch_size 12 | self.num_workers = num_workers 13 | 14 | def train_dataloader(self): 15 | return DataLoader( 16 | self.train_data, 17 | batch_size=self.batch_size, 18 | num_workers=self.num_workers, 19 | shuffle=True, 20 | multiprocessing_context='fork' 21 | ) 22 | 23 | def val_dataloader(self): 24 | return DataLoader(self.val_data, 25 | batch_size=self.batch_size, 26 | num_workers=self.num_workers, 27 | multiprocessing_context='fork') 28 | 29 | def get_interpolation_method(method): 30 | if method == 'nearest': 31 | return transforms.InterpolationMode.NEAREST 32 | elif method == 'bicubic': 33 | return transforms.InterpolationMode.BICUBIC 34 | elif method == 'bilinear': 35 | return transforms.InterpolationMode.BILINEAR 36 | 37 | def get_transform(config): 38 | transform_list = [] 39 | if config is None: 40 | return transforms.ToTensor() 41 | # if 'grayscale' in config: 42 | # transform_list.append(transforms.Grayscale(1)) 43 | if 'resize' in config: 44 | osize = [config.resize.height, config.resize.width] 45 | if 'method' not in config.resize: 46 | method = transforms.InterpolationMode.BICUBIC 47 | else: 48 | method = get_interpolation_method(config.resize.method) 49 | transform_list.append(transforms.Resize(osize, method)) 50 | 51 | if 'crop' in config: 52 | osize = [config.crop.height, config.crop.width] 53 | transform_list.append(transforms.RandomCrop(osize)) 54 | 55 | if 'flip' in config: 56 | transform_list.append(transforms.RandomHorizontalFlip()) 57 | 58 | if 'convert' in config: 59 | transform_list += [transforms.ToTensor()] 60 | if config.normalize: 61 | if 'grayscale' in config: 62 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 63 | else: 64 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 65 | 66 | if 'onehot' in config: 67 | transform_list += [transforms.PILToTensor()] 68 | def f(x): 69 | return F.one_hot(x.long().squeeze(),
num_classes=config.onehot.num_classes).float().permute(2, 0, 1) 70 | transform_list += [transforms.Lambda(lambda x : f(x))] 71 | return transforms.Compose(transform_list) 72 | -------------------------------------------------------------------------------- /src/datamodules/basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import cycle, islice 3 | from typing import List, Optional 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.utils import data 8 | from torchvision.datasets import VisionDataset 9 | from pathlib import Path 10 | import math 11 | 12 | IMG_EXTENSIONS = [ 13 | ".jpg", 14 | ".JPG", 15 | ".jpeg", 16 | ".JPEG", 17 | ".png", 18 | ".PNG", 19 | ".ppm", 20 | ".PPM", 21 | ".bmp", 22 | ".BMP", 23 | ".tif", 24 | ".TIF", 25 | ".tiff", 26 | ".TIFF", 27 | ] 28 | 29 | 30 | def is_image_file(file: Path): 31 | return file.suffix in IMG_EXTENSIONS 32 | 33 | 34 | def make_dataset(dir, max_dataset_size=float("inf")) -> List[Path]: 35 | images = [] 36 | root = Path(dir) 37 | assert root.is_dir(), "%s is not a valid directory" % dir 38 | 39 | for file in root.rglob("*"): 40 | if is_image_file(file): 41 | images.append(file) 42 | return images[: min(max_dataset_size, len(images))] 43 | 44 | 45 | def default_loader(path): 46 | return Image.open(path).convert("RGB") 47 | 48 | 49 | class ImageFolder(data.Dataset): 50 | def __init__( 51 | self, 52 | root, 53 | transform=None, 54 | return_paths=False, 55 | return_dict=False, 56 | sort=False, 57 | loader=default_loader, 58 | ): 59 | imgs = make_dataset(root) 60 | if sort: 61 | imgs = sorted(imgs) 62 | if len(imgs) == 0: 63 | raise ( 64 | RuntimeError( 65 | "Found 0 images in: " + root + "\n" 66 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) 67 | ) 68 | ) 69 | 70 | self.root = root 71 | self.imgs = imgs 72 | self.transform = transform 73 | self.return_paths = return_paths 74 | self.return_dict = return_dict 75 | self.loader = loader 76 | 77 | def __getitem__(self, index): 78 | path = self.imgs[index] 79 | img = self.loader(path) 80 | if self.transform is not None: 81 | img = self.transform(img) 82 | if self.return_paths: 83 | return img, str(path) 84 | else: 85 | if self.return_dict: 86 | return {"images": img} 87 | else: 88 | return img 89 | 90 | def __len__(self): 91 | return len(self.imgs) 92 | 93 | 94 | class MergeDataset(data.Dataset): 95 | def __init__(self, *datasets): 96 | """Merge multiple datasets to one dataset, and each time retrives a combinations of items in all sub datasets.""" 97 | self.datasets = datasets 98 | self.sizes = [len(dataset) for dataset in datasets] 99 | print("dataset size", self.sizes) 100 | 101 | def __getitem__(self, indexs: List[int]): 102 | return tuple(dataset[idx] for idx, dataset in zip(indexs, self.datasets)) 103 | 104 | def __len__(self): 105 | return max(self.sizes) 106 | 107 | 108 | class MultiRandomSampler(data.RandomSampler): 109 | """a Random Sampler for MergeDataset. NOTE will padding all dataset to same length 110 | Each time it generates an index for each subdataset in MergeDataset. 111 | 112 | Args: 113 | data_source (MergeDataset): MergeDataset object 114 | replacement (bool, optional): shuffle index use replacement. Defaults to True. 115 | num_samples ([type], optional): Defaults to None. 116 | generator ([type], optional): Defaults to None. 
117 | """ 118 | 119 | def __init__( 120 | self, 121 | data_source: MergeDataset, 122 | replacement=True, 123 | num_samples=None, 124 | generator=None, 125 | ): 126 | self.data_source: MergeDataset = data_source 127 | self.replacement = replacement 128 | self._num_samples = num_samples 129 | self.generator = generator 130 | self.maxn = len(self.data_source) 131 | 132 | @property 133 | def num_samples(self): 134 | # dataset size might change at runtime 135 | if self._num_samples is None: 136 | self._num_samples = self.data_source.sizes 137 | return self._num_samples 138 | 139 | def __iter__(self): 140 | rands = [] 141 | for size in self.num_samples: 142 | if self.maxn == size: 143 | rands.append(torch.randperm(size, generator=self.generator).tolist()) 144 | else: 145 | rands.append( 146 | torch.randint( 147 | high=size, 148 | size=(self.maxn,), 149 | dtype=torch.int64, 150 | generator=self.generator, 151 | ).tolist() 152 | ) 153 | return zip(*rands) 154 | 155 | def __len__(self): 156 | return len(self.data_source) 157 | 158 | 159 | class MultiSequentialSampler(data.Sampler): 160 | r"""Samples elements sequentially, always in the same order. 161 | NOTE: it will expand all datasets to the same length 162 | 163 | Arguments: 164 | data_source (Dataset): dataset to sample from 165 | """ 166 | 167 | def __init__(self, data_source: MergeDataset): 168 | self.data_source: MergeDataset = data_source 169 | self.num_samples = data_source.sizes 170 | self.maxn = len(data_source) 171 | 172 | def __iter__(self): 173 | ls = [] 174 | for size in self.num_samples: 175 | if self.maxn == size: 176 | ls.append(range(size)) 177 | else: 178 | ls.append(islice(cycle(range(size)), self.maxn)) 179 | return zip(*ls) 180 | 181 | def __len__(self): 182 | return len(self.data_source) 183 | 184 | 185 | class DistributedSamplerWrapper(data.DistributedSampler): 186 | def __init__( 187 | self, 188 | sampler, 189 | num_replicas: Optional[int] = None, 190 | rank: Optional[int] = None, 191 | shuffle: bool = True, 192 | ): 193 | super(DistributedSamplerWrapper, self).__init__( 194 | sampler.data_source, num_replicas, rank, shuffle 195 | ) 196 | self.sampler = sampler 197 | 198 | def __iter__(self): 199 | indices = list(self.sampler) 200 | indices = indices[self.rank : self.total_size : self.num_replicas] 201 | return iter(indices) 202 | 203 | def __len__(self): 204 | return len(self.sampler) // self.num_replicas 205 | 206 | 207 | class MultiBatchDataset(MergeDataset): 208 | """MultiBatchDataset for MultiBatchSampler 209 | NOTE: takes the same inputs as MergeDataset 210 | """ 211 | 212 | def __getitem__(self, indexs: List[int]): 213 | dataset_idxs, idxs = indexs 214 | return self.datasets[dataset_idxs][idxs] 215 | 216 | 217 | class MultiBatchSampler(data.Sampler): 218 | r"""Draws mini-batches of indices from several base samplers, interleaving them according to `repeats`. 219 | NOTE: always drops the last incomplete batch! 220 | Args: 221 | samplers (Sampler or Iterable): Base sampler. Can be any iterable object 222 | with ``__len__`` implemented. 223 | repeats (list): repeat counts, one per sampler 224 | batch_size (int): Size of mini-batch. 225 | 226 | What makes the dataloader stop sampling: the StopIteration raised by next, or len? 227 | When a dataloader is iterated directly with a for loop, len is not consulted; iteration ends only on StopIteration. 228 | But pytorch_lightning does seem to look at len... so len has to agree with what the iterator actually yields... 229 | 230 | How is that len computed, then? 231 | NOTE: the different repeat values must divide one another evenly... 232 | """ 233 | 234 | def __init__(self, samplers: list, repeats: list, batch_size, drop_last=True): 235 | # Since collections.abc.Iterable does not check for `__getitem__`, which 236 | # is one way for an object to be an iterable, we don't do an `isinstance` 237 | # check here. 238 | if ( 239 | not isinstance(batch_size, int) 240 | or isinstance(batch_size, bool) 241 | or batch_size <= 0 242 | ): 243 | raise ValueError( 244 | "batch_size should be a positive integer value, " 245 | "but got batch_size={}".format(batch_size) 246 | ) 247 | 248 | assert len(samplers) == len( 249 | repeats 250 | ), "Samplers number must equal repeats number" 251 | 252 | minweight = min( 253 | repeats 254 | ) # assume each sampling epoch fully traverses the dataset with the smallest repeat weight; every other dataset is then traversed for a corresponding multiple of that smallest dataset's length 255 | minlength = len(samplers[repeats.index(minweight)]) 256 | self.sampler_loop = cycle([i for i, w in enumerate(repeats) for _ in range(w)]) 257 | # expand to target length 258 | self.repeats = repeats 259 | self.sizes = [ 260 | minlength * math.ceil(w / minweight) for w in repeats 261 | ] # if the smallest weight is 1, every other dataset's size is the corresponding multiple of minlength 262 | self.size = sum(self.sizes) 263 | self.batch_size = batch_size 264 | self.samplers: List[data.Sampler] = samplers 265 | self.new_samplers = [] 266 | self.drop_last = True 267 | 268 | def __iter__(self): 269 | self.new_samplers.clear() 270 | self.new_samplers = [ 271 | islice(cycle(smp), size) # size caps where each iterator stops 272 | for smp, size in zip(self.samplers, self.sizes) 273 | ] 274 | return self 275 | 276 | def __next__(self): 277 | # NOTE: sampler_idx selects which dataset to draw from 278 | sampler_idx = next(self.sampler_loop) 279 | sampler: data.Sampler = self.new_samplers[sampler_idx] 280 | return [ 281 | (sampler_idx, next(sampler)) for _ in range(self.batch_size) 282 | ] # drops the last batch automatically: an incomplete final batch makes next raise StopIteration 283 | 284 | def __len__(self): 285 | # NOTE find min batch scale factor 286 | scale = (min(self.sizes) // self.batch_size) // min( 287 | self.repeats 288 | ) # find the dataset that hits StopIteration first 289 | return sum([n * scale for n in self.repeats]) 290 | -------------------------------------------------------------------------------- /src/datamodules/celeba.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CelebA 2 | from .base import get_transform, BaseDatamodule 3 | 4 | class CelebADataModule(BaseDatamodule): 5 | def __init__( 6 | self, 7 | data_dir: str = "./data", 8 | width=64, 9 | height=64, 10 | channels=3, 11 | batch_size: int = 64, 12 | num_workers: int = 8, 13 | transforms=None, 14 | **kargs 15 | ): 16 | super().__init__(width, height, channels, batch_size, num_workers) 17 | self.data_dir = data_dir 18 | self.transform = get_transform(transforms) 19 | print("Preparing celeba transforms", self.transform) 20 | 21 | def prepare_data(self): 22 | CelebA(self.data_dir, split="all", download=True) 23 | 24 | def setup(self, stage=None): 25 | self.train_data = CelebA(self.data_dir, split="train", transform=self.transform) 26 | self.val_data = CelebA(self.data_dir, split="test", transform=self.transform) 27 | -------------------------------------------------------------------------------- /src/datamodules/cifar10.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | from .base import BaseDatamodule, get_transform 3 | 4 | class CIFAR10DataModule(BaseDatamodule): 5 | def __init__( 6 | self, 7 | data_dir: str = "./data", 8 | width=64, 9 | height=64, 10 |
channels=3, 11 | batch_size: int = 64, 12 | num_workers: int = 8, 13 | transforms=None, 14 | **kargs 15 | ): 16 | super().__init__(width, height, channels, batch_size, num_workers) 17 | self.data_dir = data_dir 18 | self.transform = get_transform(transforms) 19 | 20 | def prepare_data(self): 21 | # download 22 | CIFAR10(self.data_dir, train=True, download=True) 23 | CIFAR10(self.data_dir, train=False, download=True) 24 | 25 | def setup(self, stage=None): 26 | # Assign train/val datasets for use in dataloaders 27 | self.train_data = CIFAR10(self.data_dir, train=True, transform=self.transform) 28 | self.val_data = CIFAR10(self.data_dir, train=False, transform=self.transform) 29 | -------------------------------------------------------------------------------- /src/datamodules/dsprite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytorch_lightning as pl 3 | from pathlib import Path 4 | import torch 5 | from .base import BaseDatamodule, get_transform 6 | from .utils import url_retrive, CustomTensorDataset 7 | from torch.utils.data import random_split 8 | 9 | 10 | class DataModule(BaseDatamodule): 11 | def __init__( 12 | self, 13 | data_dir: str = "./data", 14 | width=64, 15 | height=64, 16 | channels=3, 17 | batch_size: int = 64, 18 | num_workers: int = 8, 19 | transforms=None, 20 | ): 21 | super().__init__(width, height, channels, batch_size, num_workers) 22 | self.data_dir = data_dir 23 | self.transform = get_transform(transforms) 24 | 25 | self.data_dir = Path(data_dir) / 'dsprite' 26 | self.data_file = self.data_dir / 'dsprites_64x64.npz' 27 | 28 | def prepare_data(self): 29 | # download 30 | URL = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz" 31 | 32 | if not self.data_dir.exists(): 33 | self.data_dir.mkdir(parents=True) 34 | if not self.data_file.exists(): 35 | url_retrive(URL, self.data_file) 36 | 37 | def setup(self, stage=None): 38 | data = np.load(self.data_file, encoding='latin1') 39 | data = torch.from_numpy(data['imgs']).unsqueeze(1).float() 40 | length = data.shape[0] 41 | full_data = CustomTensorDataset(data, transform=self.transform) 42 | self.train_data, self.val_data = random_split(full_data, [8*length // 10, 2*length // 10], generator=torch.Generator().manual_seed(666)) 43 | -------------------------------------------------------------------------------- /src/datamodules/lsun.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import random_split, DataLoader 3 | from torchvision.datasets import LSUN 4 | 5 | from .base import BaseDatamodule, get_transform 6 | 7 | """ 8 | To download the LSUN dataset, refer to https://github.com/fyu/lsun. 9 | Place the dataset files under data_dir/lsun. 10 | """ 11 | 12 | class LSUNDataModule(BaseDatamodule): 13 | def __init__( 14 | self, 15 | data_dir: str = "./data", 16 | width=64, 17 | height=64, 18 | channels=3, 19 | batch_size: int = 64, 20 | num_workers: int = 8, 21 | transforms=None, 22 | categories=["bedroom"] 23 | ): 24 | super().__init__(width, height, channels, batch_size, num_workers) 25 | self.data_dir = data_dir + "/lsun" 26 | self.transform = get_transform(transforms) 27 | self.categories = categories 28 | 29 | def setup(self, stage=None): 30 | # Assign train/val datasets for use in dataloaders 31 | self.train_data = LSUN( 32 | self.data_dir, 33 | classes=[x + "_train" for x in self.categories], 34 |
transform=self.transform, 35 | ) 36 | self.val_data = LSUN( 37 | self.data_dir, 38 | classes=[x + "_test" for x in self.categories], 39 | transform=self.transform, 40 | ) 41 | 42 | if __name__ == "__main__": 43 | data = LSUNDataModule() 44 | data.prepare_data() 45 | data.setup() 46 | data.train_dataloader() -------------------------------------------------------------------------------- /src/datamodules/mnist.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torchvision.datasets import MNIST 3 | from .base import BaseDatamodule, get_transform 4 | 5 | 6 | class MNISTDataModule(BaseDatamodule): 7 | def __init__( 8 | self, 9 | data_dir: str = "./data", 10 | width=64, 11 | height=64, 12 | channels=3, 13 | batch_size: int = 64, 14 | num_workers: int = 8, 15 | transforms=None, 16 | **kargs 17 | ): 18 | super().__init__(width, height, channels, batch_size, num_workers) 19 | self.data_dir = data_dir 20 | self.transform = get_transform(transforms) 21 | 22 | def prepare_data(self): 23 | # download 24 | MNIST(self.data_dir, train=True, download=True) 25 | MNIST(self.data_dir, train=False, download=True) 26 | 27 | def setup(self, stage=None): 28 | # Assign train/val datasets for use in dataloaders 29 | self.train_data = MNIST(self.data_dir, train=True, transform=self.transform) 30 | self.val_data = MNIST(self.data_dir, train=False, transform=self.transform) -------------------------------------------------------------------------------- /src/datamodules/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tqdm import tqdm 3 | import urllib 4 | from torch.utils.data import Dataset 5 | 6 | 7 | USER_AGENT = "pytorch/vision" 8 | def url_retrive(url: str, filename: str, chunk_size: int = 1024) -> None: 9 | with open(filename, "wb") as fh: 10 | with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: 11 | with tqdm(total=response.length) as pbar: 12 | for chunk in iter(lambda: response.read(chunk_size), ""): 13 | if not chunk: 14 | break 15 | pbar.update(chunk_size) 16 | fh.write(chunk) 17 | 18 | class CustomTensorDataset(Dataset): 19 | def __init__(self, data_tensor, transform=None): 20 | self.data_tensor = data_tensor 21 | self.transform = transform 22 | self.indices = range(len(self)) 23 | 24 | def __getitem__(self, index1): 25 | index2 = random.choice(self.indices) 26 | 27 | img1 = self.data_tensor[index1] 28 | img2 = self.data_tensor[index2] 29 | if self.transform is not None: 30 | img1 = self.transform(img1) 31 | img2 = self.transform(img2) 32 | 33 | return img1, img2 34 | 35 | def __len__(self): 36 | return self.data_tensor.size(0) -------------------------------------------------------------------------------- /src/models/BiGAN.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | import torchmetrics 8 | import torchvision 9 | from src.networks.basic import MLPEncoder 10 | from src.utils.losses import adversarial_loss 11 | from torch import nn 12 | 13 | from .base import BaseModel, ValidationResult 14 | 15 | 16 | class BiGAN(BaseModel): 17 | def __init__( 18 | self, 19 | datamodule, 20 | encoder, 21 | decoder, 22 | latent_dim=100, 23 | hidden_dim=512, 24 | loss_mode="vanilla", 25 | lrG: float = 0.0002, 26 | lrD: float = 0.0002, 27 | b1: float = 
0.5, 28 | b2: float = 0.999, 29 | ): 30 | super().__init__(datamodule) 31 | self.save_hyperparameters() 32 | 33 | # networks 34 | self.decoder = hydra.utils.instantiate( 35 | decoder, input_channel=latent_dim, output_channel=self.channels 36 | ) 37 | self.encoder = hydra.utils.instantiate( 38 | encoder, input_channel=self.channels, output_channel=latent_dim 39 | ) 40 | self.discriminator = Discriminator(encoder, self.channels, latent_dim, hidden_dim) 41 | self.automatic_optimization = False 42 | 43 | def forward(self, z): 44 | output = self.decoder(z) 45 | output = output.reshape( 46 | z.shape[0], self.channels, self.height, self.width 47 | ) 48 | return output 49 | 50 | def configure_optimizers(self): 51 | lrG = self.hparams.lrG 52 | lrD = self.hparams.lrD 53 | b1 = self.hparams.b1 54 | b2 = self.hparams.b2 55 | g_param = itertools.chain(self.encoder.parameters(), self.decoder.parameters()) 56 | d_param = self.discriminator.parameters() 57 | opt_g = torch.optim.Adam(g_param, lr=lrG, betas=(b1, b2)) 58 | opt_d = torch.optim.Adam(d_param, lr=lrD, betas=(b1, b2)) 59 | return [opt_g, opt_d] 60 | 61 | def training_step(self, batch, batch_idx): 62 | imgs, _ = batch # (N, C, H, W) 63 | z = torch.randn(imgs.shape[0], self.hparams.latent_dim).to(self.device) # (N, latent_dim) 64 | 65 | optim_g, optim_d = self.optimizers() 66 | 67 | real_pair = imgs, self.encoder(imgs) 68 | fake_pair = self.decoder(z), z 69 | 70 | real_logit = self.discriminator(*real_pair) 71 | fake_logit = self.discriminator(*fake_pair) 72 | 73 | mode = self.hparams.loss_mode 74 | g_loss = adversarial_loss(real_logit, False, mode) + adversarial_loss(fake_logit, True, mode) 75 | d_loss = adversarial_loss(real_logit, True, mode) + adversarial_loss(fake_logit, False, mode) 76 | 77 | optim_g.zero_grad() 78 | self.manual_backward(g_loss, retain_graph=True) 79 | optim_g.step() 80 | 81 | optim_d.zero_grad() 82 | self.manual_backward(d_loss, inputs=list(self.discriminator.parameters()), retain_graph=True) 83 | optim_d.step() 84 | 85 | self.log("train_loss/g_loss", g_loss) 86 | self.log("train_loss/d_loss", d_loss) 87 | self.log("train_log/real_logit", real_logit.mean()) 88 | self.log("train_log/fake_logit", fake_logit.mean()) 89 | 90 | def validation_step(self, batch, batch_idx): 91 | img, _ = batch 92 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 93 | fake_img = self.forward(z) 94 | 95 | encode_z = self.encoder(img) 96 | recon_img = self.decoder(encode_z) 97 | return ValidationResult(real_image=img, fake_image=fake_img, recon_image=recon_img, encode_latent=encode_z) 98 | 99 | 100 | class Discriminator(nn.Module): 101 | def __init__(self, encoder, input_channel, latent_dim, hidden_dim) -> None: 102 | super().__init__() 103 | self.dis_z = MLPEncoder( 104 | input_channel=latent_dim, 105 | output_channel=hidden_dim, 106 | width=1, 107 | height=1, 108 | hidden_dims=[hidden_dim, hidden_dim], 109 | output_act="leaky_relu", 110 | ) 111 | self.dis_x = hydra.utils.instantiate( 112 | encoder, input_channel=input_channel, output_channel=hidden_dim 113 | ) 114 | self.dis_pair = MLPEncoder( 115 | input_channel=2 * hidden_dim, 116 | output_channel=1, 117 | width=1, 118 | height=1, 119 | hidden_dims=[hidden_dim], 120 | ) 121 | 122 | def forward(self, x, z): 123 | z_feature = self.dis_z(z) 124 | x_feature = self.dis_x(x) 125 | concat_feature = torch.cat((z_feature, x_feature), dim=1) 126 | return self.dis_pair(concat_feature) 127 | -------------------------------------------------------------------------------- 
/src/models/aae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adversarial Autoencoder 3 | https://arxiv.org/abs/1511.05644 4 | """ 5 | import itertools 6 | import numpy as np 7 | import hydra 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from src.utils.toy import ToyGMM 12 | from src.utils.losses import adversarial_loss 13 | from src.networks.basic import MLPEncoder 14 | from .base import BaseModel, ValidationResult 15 | 16 | 17 | class AAE(BaseModel): 18 | def __init__( 19 | self, 20 | datamodule, 21 | encoder, 22 | decoder, 23 | netD, 24 | latent_dim=100, 25 | loss_mode="vanilla", 26 | lrG: float = 0.0002, 27 | lrD: float = 0.0002, 28 | b1: float = 0.5, 29 | b2: float = 0.999, 30 | recon_weight=1, 31 | prior="normal", 32 | ): 33 | super().__init__(datamodule) 34 | self.save_hyperparameters() 35 | # networks 36 | self.decoder = hydra.utils.instantiate( 37 | decoder, input_channel=latent_dim, output_channel=self.channels 38 | ) 39 | self.encoder = hydra.utils.instantiate( 40 | encoder, input_channel=self.channels, output_channel=latent_dim 41 | ) 42 | self.discriminator = MLPEncoder( 43 | input_channel=latent_dim, output_channel=1, hidden_dims=[256, 256], width=1, height=1, norm_type="layer" 44 | ) 45 | self.automatic_optimization = False 46 | 47 | def forward(self, z): 48 | output = self.decoder(z) 49 | output = output.reshape( 50 | z.shape[0], self.channels, self.height, self.width 51 | ) 52 | return output 53 | 54 | def configure_optimizers(self): 55 | lrG = self.hparams.lrG 56 | lrD = self.hparams.lrD 57 | b1 = self.hparams.b1 58 | b2 = self.hparams.b2 59 | 60 | opt_g = torch.optim.Adam( 61 | itertools.chain(self.encoder.parameters(),self.decoder.parameters()), lr=lrG, betas=(b1, b2) 62 | ) 63 | opt_d = torch.optim.Adam( 64 | self.discriminator.parameters(), lr=lrD, betas=(b1, b2) 65 | ) 66 | return [opt_g, opt_d] 67 | 68 | def sample_prior(self, N): 69 | if self.hparams.prior == "normal": 70 | samples = torch.randn(N, self.hparams.latent_dim) 71 | elif self.hparams.prior == "toy_gmm": 72 | samples, _ = ToyGMM(10).sample(N) 73 | return samples.to(self.device) 74 | 75 | def training_step(self, batch, batch_idx, optimizer_idx): 76 | imgs, _ = batch # (N, C, H, W) 77 | N = imgs.shape[0] 78 | opt_g, opt_d = self.optimizers() 79 | 80 | # reconstruction phase 81 | q_z = self.encoder(imgs) # (N, hidden_dim) 82 | generated_imgs = self.decoder(q_z) 83 | recon_loss = F.mse_loss(imgs, generated_imgs) 84 | 85 | self.log("train_loss/recon_loss", recon_loss) 86 | opt_g.zero_grad() 87 | self.manual_backward(recon_loss*self.hparams.recon_weight) 88 | opt_g.step() 89 | 90 | # regularization phase 91 | # update discriminator 92 | real_prior = self.sample_prior(N) 93 | real_logit = self.discriminator(real_prior) 94 | real_loss = adversarial_loss(real_logit, True, self.hparams.loss_mode) 95 | fake_logit = self.discriminator(self.encoder(imgs)) 96 | fake_loss = adversarial_loss(fake_logit, False, self.hparams.loss_mode) 97 | d_adv_loss = (real_loss + fake_loss) / 2 98 | self.log("train_loss/d_loss", d_adv_loss) 99 | self.log("train_log/real_logit", real_logit.mean()) 100 | self.log("train_log/fake_logit", fake_logit.mean()) 101 | 102 | opt_d.zero_grad() 103 | self.manual_backward(d_adv_loss) 104 | opt_d.step() 105 | 106 | # update generator 107 | q_z = self.encoder(imgs) 108 | g_adv_loss = adversarial_loss(self.discriminator(q_z), True, self.hparams.loss_mode) 109 | self.log("train_loss/adv_encoder_loss", g_adv_loss) 110 | 111 | 
opt_g.zero_grad() 112 | self.manual_backward(g_adv_loss) 113 | opt_g.step() 114 | 115 | 116 | def validation_step(self, batch, batch_idx): 117 | imgs, label = batch 118 | z = self.encoder(imgs) 119 | recon_imgs = self.decoder(z) 120 | 121 | sample_z = self.sample_prior(imgs.shape[0]) 122 | sample_imgs = self.decoder(sample_z) 123 | return ValidationResult(real_image=imgs, fake_image=sample_imgs, recon_image=recon_imgs, label=label, encode_latent=z) 124 | -------------------------------------------------------------------------------- /src/models/age.py: -------------------------------------------------------------------------------- 1 | """ 2 | It Takes (Only) Two: Adversarial Generator-Encoder Networks 3 | https://arxiv.org/abs/1704.02304 4 | """ 5 | import hydra 6 | import torch 7 | import torch.nn.functional as F 8 | from .base import BaseModel, ValidationResult 9 | 10 | 11 | class AGE(BaseModel): 12 | def __init__( 13 | self, 14 | datamodule, 15 | encoder, 16 | decoder, 17 | lrE, 18 | lrG, 19 | latent_dim=128, 20 | b1: float = 0.5, 21 | b2: float = 0.999, 22 | e_recon_z_weight=1000, 23 | e_recon_x_weight=0, 24 | g_recon_z_weight=0, 25 | g_recon_x_weight=10, 26 | norm_z=True, 27 | drop_lr_epoch=20, 28 | g_updates=2, # number of decoder iterates relative to encoder 29 | ): 30 | super().__init__(datamodule) 31 | self.save_hyperparameters() 32 | # networks 33 | self.decoder = hydra.utils.instantiate( 34 | decoder, input_channel=latent_dim, output_channel=self.channels 35 | ) 36 | self.encoder = hydra.utils.instantiate( 37 | encoder, input_channel=self.channels, output_channel=latent_dim 38 | ) 39 | 40 | def forward(self, z): 41 | output = self.decoder(z) 42 | output = output.reshape( 43 | z.shape[0], self.channels, self.height, self.width 44 | ) 45 | return output 46 | 47 | def configure_optimizers(self): 48 | lrG = self.hparams.lrG 49 | lrE = self.hparams.lrE 50 | b1 = self.hparams.b1 51 | b2 = self.hparams.b2 52 | 53 | lambda_func = lambda epoch: 0.5 ** (epoch // self.hparams.drop_lr_epoch) 54 | opt_e = torch.optim.Adam(self.encoder.parameters(), lr=lrE, betas=(b1, b2)) 55 | e_scheduler = torch.optim.lr_scheduler.LambdaLR(opt_e, lr_lambda=lambda_func) 56 | 57 | opt_g = torch.optim.Adam(self.decoder.parameters(), lr=lrG, betas=(b1, b2)) 58 | g_scheduler = torch.optim.lr_scheduler.LambdaLR(opt_g, lr_lambda=lambda_func) 59 | return [ 60 | {"optimizer": opt_e, "frequency": 1, "scheduler": e_scheduler}, 61 | {"optimizer": opt_g, "frequency": self.hparams.g_updates, "scheduler": g_scheduler} 62 | ] 63 | 64 | def calculate_kl(self, samples: torch.Tensor, return_state=False): 65 | """Calcuate KL divergence between fitted gaussian distribution and standard normal distribution 66 | """ 67 | assert samples.dim() == 2 68 | mu = samples.mean(dim=0) # (d) 69 | var = samples.var(dim=0) # (d) 70 | kl_div = (mu**2+var-torch.log(var)).mean()/2 71 | if return_state: 72 | return kl_div, mu.mean(), var.mean() 73 | else: 74 | return kl_div 75 | 76 | def encode(self, imgs): 77 | N = imgs.shape[0] 78 | z = self.encoder(imgs).reshape(N, -1) 79 | if self.hparams.norm_z: 80 | z = F.normalize(z) 81 | return z 82 | 83 | def training_step(self, batch, batch_idx, optimizer_idx): 84 | imgs, labels = batch # (N, C, H, W) 85 | N = imgs.shape[0] 86 | 87 | # sample noise 88 | # NOTE: the input to decoder and output to encoder are both in sphere latent space 89 | # This is useful to prevent kl divergence from explosion 90 | z = torch.randn(imgs.shape[0], self.hparams.latent_dim).type_as(imgs) # (N, latent_dim) 91 | if 
self.hparams.norm_z: 92 | z = F.normalize(z) 93 | 94 | # train encoder 95 | if optimizer_idx == 0: 96 | # divergence between prior and encoded real samples 97 | real_z = self.encode(imgs) # (N, latent_dim) 98 | real_kl, real_mu, real_var = self.calculate_kl(real_z, return_state=True) 99 | 100 | # divergence between prior and encoded generated samples 101 | fake_imgs = self.decoder(z) 102 | fake_z = self.encode(fake_imgs) 103 | fake_kl, fake_mu, fake_var = self.calculate_kl(fake_z, return_state=True) 104 | 105 | recon_x_loss = 0 106 | if self.hparams.e_recon_x_weight > 0: 107 | # recon_x, also prevent encoder mode collapse 108 | recon_imgs = self.decoder(real_z) 109 | recon_x_loss = F.mse_loss(imgs, recon_imgs, reduction="mean") 110 | self.log("train_loss/recon_x", recon_x_loss) 111 | 112 | recon_z_loss = 0 113 | if self.hparams.e_recon_z_weight > 0: 114 | recon_z_loss = 1-F.cosine_similarity(fake_z, z).mean() 115 | 116 | total_e_loss = real_kl-fake_kl+self.hparams.e_recon_x_weight*recon_x_loss+self.hparams.e_recon_z_weight*recon_z_loss 117 | 118 | self.log("train_loss/real_kl", real_kl) 119 | self.log("train_loss/fake_kl", fake_kl) 120 | self.log("train_loss/total_e_loss", total_e_loss) 121 | self.log("train_log/real_mu", real_mu) 122 | self.log("train_log/real_var", real_var) 123 | self.log("train_log/fake_mu", fake_mu) 124 | self.log("train_log/fake_var", fake_var) 125 | return total_e_loss 126 | 127 | # train decoder 128 | if optimizer_idx == 1: 129 | fake_imgs = self.decoder(z) 130 | fake_z = self.encode(fake_imgs) 131 | fake_kl = self.calculate_kl(fake_z) 132 | 133 | # recon_loss = 1-F.cosine_similarity(fake_z, z).mean() 134 | recon_z_loss = 0 135 | if self.hparams.g_recon_z_weight > 0: 136 | recon_z_loss = F.mse_loss(fake_z, z) 137 | 138 | recon_x_loss = 0 139 | if self.hparams.g_recon_x_weight > 0: 140 | real_z = self.encode(imgs) 141 | recon_x = self.decoder(real_z) 142 | recon_x_loss = F.mse_loss(imgs, recon_x) 143 | 144 | total_g_loss = fake_kl + self.hparams.g_recon_z_weight*recon_z_loss + self.hparams.g_recon_x_weight*recon_x_loss 145 | 146 | self.log("train_loss/g_recon_z", recon_z_loss) 147 | self.log("train_loss/g_loss", total_g_loss) 148 | return total_g_loss 149 | 150 | def validation_step(self, batch, batch_idx): 151 | img, _ = batch 152 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 153 | if self.hparams.norm_z: 154 | z = F.normalize(z) 155 | 156 | fake_img = self.forward(z) 157 | encode_z = self.encode(img) 158 | recon_img = self.decoder(encode_z) 159 | return ValidationResult(real_image=img, fake_image=fake_img, recon_image=recon_img, encode_latent=encode_z) 160 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pytorch_lightning import LightningModule 3 | from src.utils.utils import get_logger 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | @dataclass 8 | class ValidationResult(): 9 | others: dict = field(default_factory=dict) 10 | real_image: torch.Tensor = None 11 | fake_image: torch.Tensor = None 12 | recon_image: torch.Tensor = None 13 | label: torch.Tensor = None 14 | encode_latent: torch.Tensor = None 15 | 16 | class BaseModel(LightningModule): 17 | def __init__(self, datamodule) -> None: 18 | super().__init__() 19 | self.console = get_logger() 20 | self.width = datamodule.width 21 | self.height = datamodule.height 22 | self.channels = 
datamodule.channels 23 | self.input_normalize = datamodule.transforms.normalize 24 | if self.input_normalize: 25 | self.output_act = "tanh" 26 | else: 27 | self.output_act = "sigmoid" 28 | 29 | def sample(self, N: int): 30 | z = torch.randn(N, self.hparams.latent_dim).to(self.device) 31 | return self.forward(z) 32 | 33 | -------------------------------------------------------------------------------- /src/models/cvae.py: -------------------------------------------------------------------------------- 1 | from hydra.utils import instantiate 2 | import torch 3 | from omegaconf import OmegaConf 4 | 5 | from src.models.base import BaseModel, ValidationResult 6 | from torch import distributions 7 | from src.utils.distributions import get_decode_dist 8 | from src.utils.losses import normal_kld 9 | import torch.nn.functional as F 10 | 11 | ## q(z | x, c): encoder, append one-hot vector to all pixels 12 | ## p(x | z, c): decoder, append embedding to z 13 | ## p(z | c): prior distribution, can be fixed or learned 14 | ## Loss: -KL(q(z|x,c) || p(z|c)) + E_{z~q(z|x,c)}(log p(x|z, c)) 15 | 16 | class cVAE(BaseModel): 17 | def __init__( 18 | self, 19 | datamodule: OmegaConf = None, 20 | encoder: OmegaConf = None, 21 | decoder: OmegaConf = None, 22 | latent_dim: int = 100, 23 | beta: float = 1.0, 24 | recon_weight: float = 1.0, 25 | lr: float = 1e-4, 26 | b1: float = 0.9, 27 | b2: float = 0.999, 28 | n_classes: int = None, 29 | encode_label: bool = True, 30 | decoder_dist = "gaussian" 31 | ): 32 | super().__init__(datamodule) 33 | self.save_hyperparameters() 34 | 35 | self.decoder = instantiate(decoder, input_channel=latent_dim*2, output_channel=self.channels, output_act=self.output_act) 36 | if encode_label: 37 | self.encoder = instantiate(encoder, input_channel=self.channels+n_classes, output_channel=2 * latent_dim) 38 | else: 39 | self.encoder = instantiate(encoder, input_channel=self.channels, output_channel=2 * latent_dim) 40 | self.class_embedding = torch.nn.Embedding(n_classes, latent_dim) 41 | self.decoder_dist = get_decode_dist(decoder_dist) 42 | self.n_classes = n_classes 43 | 44 | def forward(self, z, labels): 45 | """Generate images given latent code.""" 46 | embed = self.class_embedding(labels) 47 | z = torch.cat([z, embed], dim=1) 48 | output = self.decoder(z) 49 | output = output.reshape(output.shape[0], self.channels, self.height, self.width) 50 | return output 51 | 52 | def configure_optimizers(self): 53 | lr = self.hparams.lr 54 | b1 = self.hparams.b1 55 | b2 = self.hparams.b2 56 | opt = torch.optim.Adam(self.parameters(), lr=lr, betas=(b1, b2)) 57 | scheduler = torch.optim.lr_scheduler.StepLR(opt, 1, gamma=0.99) 58 | return [opt], [scheduler] 59 | 60 | def reparameterize(self, mu, log_sigma): 61 | post_dist = distributions.Normal(mu, torch.exp(log_sigma)) 62 | samples_z = post_dist.rsample() 63 | return samples_z 64 | 65 | def vae(self, imgs, labels): 66 | N, C, H, W = imgs.shape 67 | if self.hparams.encode_label: 68 | labels_onehot = F.one_hot(labels, num_classes=self.n_classes).reshape(N, self.n_classes, 1, 1).expand(N, self.n_classes, H, W) 69 | imgs = torch.cat([imgs, labels_onehot], dim=1) 70 | z_ = self.encoder(imgs) # (N, 2 * latent_dim) 71 | mu, log_sigma = torch.chunk(z_, chunks=2, dim=1) 72 | z = self.reparameterize(mu, log_sigma) 73 | recon_imgs = self.forward(z, labels) 74 | return mu, log_sigma, z, recon_imgs 75 | 76 | def training_step(self, batch, batch_idx): 77 | imgs, labels = batch # (N, C, H, W) 78 | mu, log_sigma, z, recon_imgs = self.vae(imgs, labels) 79 | kld =
normal_kld(mu, log_sigma) 80 | 81 | log_p_x_of_z = self.decoder_dist.prob(recon_imgs, imgs).mean(dim=0) 82 | elbo = -self.hparams.beta*kld + self.hparams.recon_weight * log_p_x_of_z 83 | 84 | self.log("train_log/elbo", elbo) 85 | self.log("train_log/kl_divergence", kld) 86 | self.log("train_log/log_p_x_of_z", log_p_x_of_z) 87 | return -elbo 88 | 89 | def sample(self, N): 90 | labels = torch.arange(self.n_classes, device=self.device).reshape(self.n_classes, 1).repeat(1, N).reshape(-1) 91 | z = torch.randn(N*self.n_classes, self.hparams.latent_dim, device=self.device) 92 | imgs = self.forward(z, labels) 93 | return imgs 94 | 95 | def validation_step(self, batch, batch_idx): 96 | imgs, labels = batch 97 | mu, log_sigma, z, recon_imgs = self.vae(imgs, labels) 98 | log_p_x_of_z = self.decoder_dist.prob(recon_imgs, imgs).mean(dim=0) 99 | 100 | fake_imgs = self.sample(8) 101 | self.log("val_log/log_p_x_of_z", log_p_x_of_z) 102 | return ValidationResult(real_image=imgs, fake_image=fake_imgs, 103 | recon_image=recon_imgs, label=labels, encode_latent=z) -------------------------------------------------------------------------------- /src/models/factor_vae.py: -------------------------------------------------------------------------------- 1 | """Autoencoding beyond pixels using a learned similarity metric""" 2 | import itertools 3 | import hydra 4 | import torch 5 | from omegaconf import OmegaConf 6 | import torch.distributions as D 7 | 8 | from .base import BaseModel, ValidationResult 9 | from src.utils.distributions import get_decode_dist 10 | from src.networks.basic import MLPEncoder 11 | from src.utils.losses import adversarial_loss, normal_kld 12 | 13 | def permute_dims(z): 14 | assert z.dim() == 2 15 | 16 | B, _ = z.size() 17 | perm_z = [] 18 | for z_j in z.split(1, 1): 19 | perm = torch.randperm(B).to(z.device) 20 | perm_z_j = z_j[perm] 21 | perm_z.append(perm_z_j) 22 | return torch.cat(perm_z, 1) 23 | 24 | class FactorVAE(BaseModel): 25 | def __init__( 26 | self, 27 | datamodule, 28 | encoder: OmegaConf = None, 29 | decoder: OmegaConf = None, 30 | loss_mode: str = 'lsgan', 31 | adv_weight: float = 1, 32 | latent_dim=10, 33 | lr: float = 0.0002, 34 | lrD: float = 0.0001, 35 | ae_b1: float = 0.9, # adam parameter for encoder and decoder 36 | ae_b2: float = 0.999, 37 | adv_b1: float = 0.5, # adam paramter for discriminator 38 | adv_b2: float = 0.999, 39 | decoder_dist="gaussian" 40 | ): 41 | super().__init__(datamodule) 42 | self.save_hyperparameters() 43 | 44 | self.decoder = hydra.utils.instantiate(decoder, input_channel=latent_dim, output_channel=self.channels, output_act=self.output_act) 45 | self.decoder_dist = get_decode_dist(decoder_dist) 46 | 47 | self.encoder = hydra.utils.instantiate(encoder, input_channel=self.channels, output_channel=latent_dim*2) 48 | self.netD = MLPEncoder(input_channel=latent_dim, hidden_dims=[256, 256], output_channel=1, width=1, height=1) 49 | self.automatic_optimization = False # disable automatic optimization 50 | 51 | def forward(self, z): 52 | output = self.decoder(z) 53 | output = output.reshape(output.shape[0], self.channels, self.height, self.width) 54 | output = self.decoder_dist.sample(output) 55 | return output 56 | 57 | def configure_optimizers(self): 58 | lr = self.hparams.lr 59 | lrD = self.hparams.lrD 60 | ae_b1 = self.hparams.ae_b1 61 | ae_b2 = self.hparams.ae_b2 62 | adv_b1 = self.hparams.adv_b1 63 | adv_b2 = self.hparams.adv_b2 64 | 65 | ae_optim = torch.optim.Adam(itertools.chain(self.encoder.parameters(), self.decoder.parameters()), lr=lr, 
betas=(ae_b1, ae_b2),) 66 | discriminator_optim = torch.optim.Adam(self.netD.parameters(), lr=lrD, betas=(adv_b1, adv_b2),) 67 | return ae_optim, discriminator_optim 68 | 69 | def reparameterize(self, mu, log_sigma): 70 | post_dist = D.Normal(mu, torch.exp(log_sigma)) 71 | samples_z = post_dist.rsample() 72 | return samples_z 73 | 74 | def encode(self, imgs): 75 | z_ = self.encoder(imgs) # (N, latent_dim) 76 | mu, log_sigma = torch.chunk(z_, chunks=2, dim=1) 77 | z = self.reparameterize(mu, log_sigma) 78 | return z, mu, log_sigma 79 | 80 | def vae(self, imgs): 81 | z, mu, log_sigma = self.encode(imgs) 82 | recon_imgs = self.decoder(z) 83 | return z, recon_imgs, mu, log_sigma, 84 | 85 | def training_step(self, batch, batch_idx): 86 | ae_optim, discriminator_optim = self.optimizers() 87 | imgs, _ = batch 88 | imgs1, imgs2 = torch.chunk(imgs, 2, dim=0) 89 | 90 | # auto-encoding 91 | z1_samples, recon_imgs, mu, log_sigma = self.vae(imgs1) 92 | 93 | reg_loss = normal_kld(mu, log_sigma) 94 | recon_loss = -self.decoder_dist.prob(recon_imgs, imgs1).mean(dim=0) 95 | 96 | fake_logit = self.netD(z1_samples) 97 | g_adv_loss = adversarial_loss(fake_logit, target_is_real=True, loss_mode=self.hparams.loss_mode) 98 | encoder_loss = recon_loss + reg_loss + self.hparams.adv_weight * g_adv_loss 99 | 100 | ae_optim.zero_grad() 101 | encoder_loss.backward(retain_graph=True) 102 | ae_optim.step() 103 | 104 | # # discrimination 105 | z2_samples, _, _ = self.encode(imgs2) # (N, latent_dim) 106 | perm_z = permute_dims(z2_samples) 107 | 108 | real_logit = self.netD(perm_z) 109 | d_adv_loss = adversarial_loss(real_logit, True, self.hparams.loss_mode) + adversarial_loss(fake_logit, False, self.hparams.loss_mode) 110 | 111 | discriminator_optim.zero_grad() 112 | d_adv_loss.backward(inputs=list(self.netD.parameters())) 113 | discriminator_optim.step() 114 | 115 | self.log("train_loss/reg_loss", reg_loss) 116 | self.log("train_loss/recon_loss", recon_loss, prog_bar=True) 117 | self.log("train_loss/d_adv_loss", d_adv_loss) 118 | self.log("train_loss/g_adv_loss", g_adv_loss) 119 | self.log("train_log/real_logit", torch.mean(real_logit)) 120 | self.log("train_log/fake_logit", torch.mean(fake_logit)) 121 | 122 | def validation_step(self, batch, batch_idx): 123 | imgs, label = batch 124 | N = imgs.shape[0] 125 | 126 | z, recon_imgs, mu, log_sigma = self.vae(imgs) 127 | fake_image = self.sample(N) 128 | return ValidationResult(real_image=imgs, fake_image=fake_image, 129 | recon_image=recon_imgs, encode_latent=z, label=label) -------------------------------------------------------------------------------- /src/models/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from .base import BaseModel, ValidationResult 4 | from hydra.utils import instantiate 5 | from src.utils.losses import adversarial_loss 6 | 7 | class GAN(BaseModel): 8 | def __init__( 9 | self, 10 | datamodule: OmegaConf, 11 | netG: OmegaConf, 12 | netD: OmegaConf, 13 | latent_dim: int = 100, 14 | loss_mode: str = "vanilla", 15 | lrG: float = 2e-4, 16 | lrD: float = 2e-4, 17 | b1: float = 0.5, 18 | b2: float = 0.999, 19 | ): 20 | super().__init__(datamodule) 21 | self.save_hyperparameters() 22 | self.netG = instantiate(netG, input_channel=latent_dim, output_channel=self.channels) 23 | self.netD = instantiate(netD, input_channel=self.channels, output_channel=1) 24 | self.automatic_optimization = False 25 | 26 | def forward(self, z): 27 | output = self.netG(z) 28 | output 
= output.reshape(z.shape[0], self.channels, self.height, self.width) 29 | return output 30 | 31 | def configure_optimizers(self): 32 | lrG, lrD = self.hparams.lrG, self.hparams.lrD 33 | b1, b2 = self.hparams.b1, self.hparams.b2 34 | opt_g = torch.optim.Adam(self.netG.parameters(), lr=lrG, betas=(b1, b2)) 35 | opt_d = torch.optim.Adam(self.netD.parameters(), lr=lrD, betas=(b1, b2)) 36 | return [opt_g, opt_d] 37 | 38 | def training_step(self, batch, batch_idx): 39 | imgs, _ = batch # (N, C, H, W) 40 | N, C, H, W = imgs.shape 41 | z = torch.randn(N, self.hparams.latent_dim).to(self.device) 42 | 43 | opt_g, opt_d = self.optimizers() 44 | 45 | if batch_idx % 2 == 0: 46 | self.toggle_optimizer(opt_g) 47 | fake_imgs = self.netG(z) 48 | pred_fake = self.netD(fake_imgs) 49 | g_loss = adversarial_loss(pred_fake, target_is_real=True, loss_mode=self.hparams.loss_mode) 50 | 51 | opt_g.zero_grad() 52 | self.manual_backward(g_loss) 53 | opt_g.step() 54 | self.untoggle_optimizer(opt_g) 55 | 56 | self.log("train_loss/g_loss", g_loss) 57 | else: 58 | self.toggle_optimizer(opt_d) 59 | pred_real = self.netD(imgs) 60 | real_loss = adversarial_loss(pred_real, target_is_real=True, loss_mode=self.hparams.loss_mode) 61 | 62 | fake_imgs = self.netG(z).detach() 63 | pred_fake = self.netD(fake_imgs) 64 | fake_loss = adversarial_loss(pred_fake, target_is_real=False, loss_mode=self.hparams.loss_mode) 65 | 66 | d_loss = (real_loss + fake_loss) / 2 67 | 68 | opt_d.zero_grad() 69 | self.manual_backward(d_loss) 70 | opt_d.step() 71 | self.untoggle_optimizer(opt_d) 72 | 73 | self.log("train_loss/d_loss", d_loss) 74 | self.log("train_log/pred_real", pred_real.mean()) 75 | self.log("train_log/pred_fake", pred_fake.mean()) 76 | 77 | def validation_step(self, batch, batch_idx): 78 | img, _ = batch 79 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 80 | fake_imgs = self.forward(z) 81 | return ValidationResult(real_image=img, fake_image=fake_imgs) 82 | -------------------------------------------------------------------------------- /src/models/info_gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from src.models.base import BaseModel 6 | from src.utils.losses import adversarial_loss 7 | from src.callbacks.visualization import get_grid_images 8 | import itertools 9 | 10 | 11 | class InfoGAN(BaseModel): 12 | def __init__( 13 | self, 14 | datamodule, 15 | netG, 16 | netD, 17 | lambda_I=1, # loss weight for mutual information 18 | discrete_dim=1, # discrete latent variable dimension 19 | discrete_value=10, # the value range of discrete latent variable 20 | continuous_dim=2, 21 | noise_dim=62, 22 | encode_dim=1024, # intermediate dim for common layer 23 | loss_mode="vanilla", 24 | lrG: float = 0.001, 25 | lrD: float = 0.0002, 26 | lrQ: float = 0.0002, 27 | b1: float = 0.5, 28 | b2: float = 0.999, 29 | ): 30 | super().__init__(datamodule) 31 | self.save_hyperparameters() 32 | 33 | self.latent_dim = discrete_dim*discrete_value + continuous_dim + noise_dim 34 | # networks 35 | self.netG = hydra.utils.instantiate(netG, input_channel=self.latent_dim, output_channel=self.channels) 36 | self.common_layer = hydra.utils.instantiate(netD, input_channel=self.channels, output_channel=encode_dim) 37 | self.netD = nn.Sequential(nn.LeakyReLU(), nn.Linear(encode_dim, 1)) 38 | self.netQ = nn.Sequential( 39 | nn.LeakyReLU(), 40 | nn.Linear(encode_dim, 128), 41 | nn.LeakyReLU(), 42 |
nn.Linear(128, discrete_dim*discrete_value + continuous_dim), 43 | ) 44 | 45 | def configure_optimizers(self): 46 | lrG = self.hparams.lrG 47 | lrD = self.hparams.lrD 48 | lrQ = self.hparams.lrQ 49 | b1 = self.hparams.b1 50 | b2 = self.hparams.b2 51 | q_param = self.netQ.parameters() 52 | g_param = self.netG.parameters() 53 | d_param = itertools.chain( 54 | self.netD.parameters(), self.common_layer.parameters() 55 | ) 56 | 57 | opt_g = torch.optim.Adam( 58 | [{"params": g_param, "lr": lrG}, {"params": q_param, "lr": lrQ}], 59 | betas=(b1, b2), 60 | ) 61 | opt_d = torch.optim.Adam(d_param, lr=lrD, betas=(b1, b2)) 62 | return [opt_g, opt_d] 63 | 64 | def encode(self, x, return_posterior=False): 65 | x = self.common_layer(x) 66 | adv_logit = self.netD(x) 67 | if return_posterior: 68 | output = self.netQ(x) 69 | dis_c_logits = output[:, :-self.hparams.continuous_dim].reshape(-1, 70 | self.hparams.discrete_value, self.hparams.discrete_dim) 71 | cont_c = output[:, -self.hparams.continuous_dim:] 72 | return adv_logit, dis_c_logits, cont_c 73 | else: 74 | return adv_logit 75 | 76 | def decode(self, N, dis_c_index=None, cont_c=None, z=None, return_latents=False): 77 | """ 78 | N: batch_size 79 | disc_c_index: tensor of shape (N, discrete_dim) 80 | """ 81 | if dis_c_index == None: 82 | dis_c_index = torch.randint(0, self.hparams.discrete_value, (N, self.hparams.discrete_dim)).to(self.device) # (N, discrete_dim) 83 | dis_c = torch.zeros(N, self.hparams.discrete_value, self.hparams.discrete_dim).to(self.device) # (N, discrete_value, disrete_dim) 84 | dis_c.scatter_(1, dis_c_index.unsqueeze(1), torch.ones_like(dis_c)) 85 | 86 | if cont_c == None: 87 | cont_c = torch.zeros(N, self.hparams.continuous_dim, device=self.device).uniform_(-1, 1) 88 | 89 | if z == None: 90 | z = torch.randn(N, self.hparams.noise_dim).to(self.device) 91 | 92 | output = self.netG(torch.cat([dis_c.reshape(N, -1), cont_c, z], dim=1)) 93 | output = output.reshape(z.shape[0], self.channels, self.height, self.width) 94 | if return_latents: 95 | return output, (dis_c_index, cont_c, z) 96 | else: 97 | return output 98 | 99 | def training_step(self, batch, batch_idx, optimizer_idx): 100 | imgs, _ = batch 101 | N = imgs.shape[0] 102 | 103 | # train generator 104 | if optimizer_idx == 0: 105 | generated_imgs, (dis_c, cont_c, z) = self.decode(N, return_latents=True) 106 | adv_logit, dis_c_logits, cont_c_hat = self.encode(generated_imgs, return_posterior=True) 107 | g_loss = adversarial_loss(adv_logit, target_is_real=True) 108 | 109 | # mutual information loss 110 | I_discete_loss = F.cross_entropy(dis_c_logits, dis_c) 111 | I_continuous_loss = F.mse_loss(cont_c_hat, cont_c) 112 | I_loss = I_discete_loss + I_continuous_loss 113 | 114 | self.log("train_loss/g_loss", g_loss) 115 | self.log("train_loss/I_discrete_loss", I_discete_loss) 116 | self.log("train_loss/I_continuous", I_continuous_loss) 117 | 118 | return g_loss + self.hparams.lambda_I * I_loss 119 | 120 | if optimizer_idx == 1: 121 | pred_real = self.netD(self.common_layer(imgs)) 122 | real_loss = adversarial_loss(pred_real, target_is_real=True) 123 | 124 | pred_fake = self.encode(self.decode(N).detach()) 125 | fake_loss = adversarial_loss(pred_fake, target_is_real=False) 126 | 127 | d_loss = (real_loss + fake_loss) / 2 128 | 129 | self.log("train_loss/d_loss", d_loss) 130 | self.log("train_log/pred_real", pred_real.mean()) 131 | self.log("train_log/pred_fake", pred_fake.mean()) 132 | 133 | return d_loss 134 | 135 | def on_train_epoch_end(self) -> None: 136 | generated_images = 
self.decode(64) 137 | grid_images = get_grid_images(generated_images, self, 64, 8) 138 | self.logger.experiment.add_image("images/sample", grid_images, global_step=self.current_epoch) 139 | 140 | N = 8 # row of images 141 | a, b, c = self.hparams.discrete_value, self.hparams.continuous_dim, self.hparams.noise_dim 142 | # each row has `a` values and totally N rows 143 | # Traverse over discrete latent value while other values are fixed for each N 144 | disc_c = torch.arange(a).reshape(1, a).repeat(N, 1).reshape(N*a, 1).to(self.device) 145 | cont_c = torch.randn(N, 1, b).repeat(1, a, 1).reshape(N*a, b).to(self.device) 146 | z = torch.randn(N, 1, c).repeat(1, a, 1).reshape(N*a, c).to(self.device) # (40, noise_dim) 147 | imgs = self.decode(N*a, disc_c, cont_c, z) 148 | 149 | grid_images = get_grid_images(imgs, self, N*a, a) 150 | self.logger.experiment.add_image("visual/traverse over discrete values", grid_images, global_step=self.current_epoch) 151 | 152 | col = 10 153 | # Traverse over continuous latent values while other values are fixed for each N 154 | disc_c = torch.randint(low=0, high=a, size=(N, 1)).repeat(1, col).reshape(N*col, 1).to(self.device) 155 | cont_c_variation = torch.linspace(-2, 2, col).reshape(1, col).repeat(N, 1).reshape(N*col).to(self.device) 156 | cont_c = torch.randn(N, 1, b).repeat(1, col, 1).reshape(N*col, b).to(self.device) 157 | z = torch.randn(N, 1, c).repeat(1, col, 1).reshape(N*col, c).to(self.device) # (N*a, noise_dim) 158 | 159 | cont_c_mix = cont_c.clone() 160 | cont_c_mix[:, 0] = cont_c_variation 161 | imgs = self.decode(N*col, disc_c, cont_c_mix, z) 162 | grid_images = get_grid_images(imgs, self, N*col, col) 163 | self.logger.experiment.add_image("visual/traverse over first continuous values", grid_images, global_step=self.current_epoch) 164 | 165 | cont_c_mix = cont_c.clone() 166 | cont_c_mix[:, 1] = cont_c_variation 167 | imgs = self.decode(N*col, disc_c, cont_c_mix, z) 168 | grid_images = get_grid_images(imgs, self, N*col, col) 169 | self.logger.experiment.add_image("visual/traverse over second continuous values", grid_images, global_step=self.current_epoch) -------------------------------------------------------------------------------- /src/models/made.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from types import new_class 3 | from torch import nn 4 | import torch 5 | from torch import optim 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | from src.models.base import BaseModel, ValidationResult 9 | from einops import rearrange 10 | 11 | 12 | class MaskedLinear(nn.Module): 13 | def __init__(self, in_channel, out_channel): 14 | super().__init__() 15 | self.model = nn.Linear(in_channel, out_channel) 16 | self.register_buffer("mask", torch.ones(out_channel, in_channel)) 17 | 18 | def set_mask(self, mask): 19 | self.mask = mask 20 | 21 | def forward(self, x): 22 | return F.linear(x, weight=self.model.weight * self.mask, bias=self.model.bias) 23 | 24 | class MADENet(nn.Module): 25 | def __init__(self, in_dim, hidden_dim, n_class, n_layer): 26 | """ 27 | in_dim: vector length of input data 28 | hidden_dim: hidden_dim of vectors 29 | """ 30 | super().__init__() 31 | self.in_dim = in_dim 32 | self.n_layer = n_layer 33 | self.hidden_dim = hidden_dim 34 | self.n_class = n_class 35 | 36 | dims = [in_dim] + [hidden_dim] * n_layer + [in_dim*n_class] 37 | self.layers = [] 38 | for in_feature, out_feature in zip(dims[:-1], dims[1:]): 39 | 
self.layers.append(MaskedLinear(in_feature, out_feature)) 40 | self.model = nn.Sequential(*self.layers) 41 | self.reset_mask() 42 | 43 | def reset_mask(self): 44 | low = 0 45 | high = self.in_dim 46 | units = [] 47 | data_unit = torch.arange(0, high) 48 | # data_unit = torch.randperm(high) 49 | units.append(data_unit) 50 | 51 | for _ in range(self.n_layer): 52 | hidden_unit = torch.randint(low=low, high=high, size=(self.hidden_dim, )) 53 | units.append(hidden_unit) 54 | low = min(hidden_unit) 55 | units.append(data_unit.unsqueeze(1).repeat(1, self.n_class).reshape(-1) - 1) 56 | 57 | for layer, in_unit, out_unit in zip(self.layers, units[:-1], units[1:]): 58 | mask = out_unit.unsqueeze(1) >= in_unit # (out_features, int_features) 59 | layer.set_mask(mask) 60 | 61 | def forward(self, x): 62 | n, c, h, w = x.shape 63 | x = rearrange(x, "n c h w -> n (c h w)") 64 | for layer in self.layers[:-1]: 65 | x = layer(x) 66 | x = torch.sigmoid(x) 67 | x = self.layers[-1](x) 68 | x = rearrange(x, "n (c h w a) -> n a c h w", n=n, c=c, h=h, w=w, a=self.n_class) 69 | return x 70 | 71 | 72 | class MADE(BaseModel): 73 | def __init__( 74 | self, 75 | datamodule, 76 | hidden_dim, 77 | n_layer, 78 | lr=1e-3 79 | ): 80 | super().__init__(datamodule) 81 | self.save_hyperparameters() 82 | self.model = MADENet(self.width*self.height*self.channels, hidden_dim, n_class=256, n_layer=n_layer) 83 | 84 | self.register_buffer("log2", torch.log(torch.tensor(2, dtype=torch.float32, device=self.device))) 85 | 86 | def forward(self, x, y=None): 87 | """ 88 | Forward image through model and return logits for each pixel. 89 | Inputs: 90 | x - Image tensor in range (0, 1). 91 | y - one-hot vector indicating class label. 92 | """ 93 | logits = self.model(x) 94 | return logits 95 | 96 | def calc_likelihood(self, x, label=None): 97 | # Forward pass with bpd likelihood calculation 98 | pred = self.forward(x, label) 99 | if self.input_normalize: 100 | target = ((x + 1) / 2 * 255).to(torch.long) 101 | else: 102 | target = (x * 255).to(torch.long) 103 | nll = F.cross_entropy(pred, target, reduction="none") # (N, C, H, W) 104 | bpd = nll.mean(dim=[1, 2, 3]) / self.log2 105 | return bpd.mean() 106 | 107 | @torch.no_grad() 108 | def sample(self, img_shape, cond=None, img=None): 109 | """ 110 | Sampling function for the autoregressive model. 111 | Inputs: 112 | img_shape - Shape of the image to generate (B,C,H,W) 113 | img (optional) - If given, this tensor will be used as 114 | a starting image. The pixels to fill 115 | should be -1 in the input tensor. 
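Note: generation proceeds in raster-scan order, and every still-empty pixel triggers a full forward pass, so sampling costs up to H*W forward passes per call.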
116 | """ 117 | # Create empty image 118 | if img is None: 119 | img = torch.zeros(img_shape, dtype=torch.float32).to(self.device) - 1 120 | # Generation loop 121 | N, C, H, W = img_shape 122 | for h in tqdm(range(H), leave=False): 123 | for w in range(W): 124 | # Skip if not to be filled (-1) 125 | if (img[:, :, h, w] != -1).all().item(): 126 | continue 127 | # For efficiency, we only have to input the upper part of the image 128 | # as all other parts will be skipped by the masked convolutions anyways 129 | pred = self.forward(img, cond) # (N, classes, C, H, W) 130 | probs = F.softmax(rearrange(pred[:, :, :, h, w], "n a c -> n c a"), dim=-1).reshape(N*C, 256) # (NC, n_classes) 131 | new_pred = torch.multinomial(probs, num_samples=1).squeeze(dim=-1).to(torch.float32) / 255 132 | if self.input_normalize: 133 | new_pred = new_pred*2 - 1 134 | img[:, :, h, w] = new_pred.reshape(N, C) 135 | return img 136 | 137 | def configure_optimizers(self): 138 | optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr) 139 | scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99) 140 | return [optimizer], [scheduler] 141 | 142 | def training_step(self, batch, batch_idx): 143 | img, label = batch 144 | loss = self.calc_likelihood(img) 145 | self.log("train_bpd", loss) # bpd: bits per dim, by entropy encoding 146 | return loss 147 | 148 | def validation_step(self, batch, batch_idx): 149 | img, label = batch 150 | N, C, H, W = img.shape 151 | loss = self.calc_likelihood(img) 152 | self.log("val_bpd", loss) 153 | 154 | sample_img = None 155 | if batch_idx == 0: 156 | sample_img = self.sample(img.shape) 157 | return ValidationResult(real_image=img, fake_image=sample_img) 158 | -------------------------------------------------------------------------------- /src/models/pixelcnn.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from types import new_class 3 | from torch import nn 4 | import torch 5 | from torch import optim 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | from src.models.base import BaseModel, ValidationResult 9 | 10 | # TODO: This implementation does not handle color dependency problem(which is nontrival with mask handling), thus has worse results on color images. 
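# Mask layout sketch for kernel_size=3 (see the two stack convolutions below): the vertical
# stack keeps rows at or above the centre (mask_center=True additionally removes the centre
# row), while the horizontal stack keeps centre-row pixels to the left of, and optionally
# including, the centre:
#   vertical, mask_center=True     horizontal, mask_center=True
#     1 1 1                          1 0 0
#     0 0 0
#     0 0 0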
11 | 12 | class MaskedConvolution(nn.Module): 13 | def __init__(self, c_in, c_out, mask, **kwargs): 14 | super().__init__() 15 | self.register_buffer("mask", mask) 16 | kernel_size = mask.shape 17 | dilation = 1 if "dilation" not in kwargs else kwargs["dilation"] 18 | padding = tuple([dilation * (kernel_size[i] - 1) // 2 for i in range(2)]) 19 | # Actual convolution 20 | self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs) 21 | 22 | def forward(self, x): 23 | self.conv.weight.data *= self.mask 24 | return self.conv(x) 25 | 26 | 27 | class VerticalStackConvolution(MaskedConvolution): 28 | def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs): 29 | mask = torch.ones(kernel_size, kernel_size) 30 | mask[kernel_size // 2 + 1 :, :] = 0 31 | if mask_center: 32 | mask[kernel_size // 2] = 0 33 | super().__init__(c_in, c_out, mask, **kwargs) 34 | 35 | 36 | class HorizontalStackConvolution(MaskedConvolution): 37 | def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs): 38 | mask = torch.ones(1, kernel_size) 39 | mask[0, kernel_size // 2 + 1 :] = 0 40 | if mask_center: 41 | mask[0, kernel_size // 2] = 0 42 | super().__init__(c_in, c_out, mask, **kwargs) 43 | 44 | 45 | class GatedMaskedConv(nn.Module): 46 | def __init__(self, channels, kernel_size=3, cond_channel=None, **kargs): 47 | super().__init__() 48 | self.horiz_conv = HorizontalStackConvolution( 49 | channels, 2 * channels, kernel_size, mask_center=False, **kargs 50 | ) 51 | self.vert_conv = VerticalStackConvolution( 52 | channels, 2 * channels, kernel_size, mask_center=False, **kargs 53 | ) 54 | self.conv1x1_1 = nn.Conv2d(2 * channels, 2 * channels, 1) 55 | self.conv1x1_2 = nn.Conv2d(channels, channels, 1) 56 | self.output_channels = channels 57 | self.input_channels = channels 58 | if cond_channel is not None: 59 | self.cond_proj_vert1 = nn.Conv2d(cond_channel, channels, kernel_size=1, bias=False) 60 | self.cond_proj_vert2 = nn.Conv2d(cond_channel, channels, kernel_size=1, bias=False) 61 | self.cond_proj_horiz1 = nn.Conv2d(cond_channel, channels, kernel_size=1, bias=False) 62 | self.cond_proj_horiz2 = nn.Conv2d(cond_channel, channels, kernel_size=1, bias=False) 63 | 64 | def forward(self, vert_x, horiz_x, cond=None): 65 | # The horizontal branch can take information from vertical branch, while not vice versa. Think it carefully. 
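# Reasoning sketch: because the first vertical convolution masks out the centre row, the
# vertical stack at (h, w) only ever depends on input rows strictly above h, all of which
# precede (h, w) in raster-scan order, so mixing it into the horizontal stack through the
# 1x1 convolution below keeps the factorization autoregressive. The reverse connection
# would push current-row pixels into the vertical stack, whose kernel also spans columns
# to the right, letting information from not-yet-generated pixels leak into deeper layers.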
66 | vert_conv_x = self.vert_conv(vert_x) 67 | vert_x1, vert_x2 = torch.chunk(vert_conv_x, 2, dim=1) 68 | if cond is None: 69 | out_vert_x = torch.tanh(vert_x1) * torch.sigmoid(vert_x2) 70 | else: 71 | out_vert_x = torch.tanh(vert_x1 + self.cond_proj_vert1(cond).expand_as(vert_x1)) * torch.sigmoid(vert_x2 + self.cond_proj_vert2(cond).expand_as(vert_x2)) 72 | 73 | horiz_x1, horiz_x2 = torch.chunk( 74 | self.horiz_conv(horiz_x) + self.conv1x1_1(vert_conv_x), 2, 1 75 | ) 76 | if cond is None: 77 | out_horiz_x = torch.tanh(horiz_x1) * torch.tanh(horiz_x2) 78 | else: 79 | out_horiz_x = torch.tanh(horiz_x1 + self.cond_proj_horiz1(cond).expand_as(horiz_x1)) * torch.tanh(horiz_x2 + self.cond_proj_horiz2(cond).expand_as(horiz_x2)) 80 | out_horiz_x = self.conv1x1_2(out_horiz_x) + horiz_x 81 | 82 | return out_vert_x, out_horiz_x 83 | 84 | 85 | class PixelCNN(BaseModel): 86 | def __init__( 87 | self, 88 | datamodule, 89 | hidden_dim, 90 | class_condition=False, 91 | n_classes=None, 92 | lr=1e-3 93 | ): 94 | super().__init__(datamodule) 95 | self.save_hyperparameters() 96 | 97 | # Initial convolutions skipping the center pixel 98 | self.conv_vstack = VerticalStackConvolution( 99 | self.channels, hidden_dim, 5, mask_center=True 100 | ) 101 | self.conv_hstack = HorizontalStackConvolution( 102 | self.channels, hidden_dim, 5, mask_center=True 103 | ) 104 | # Convolution block of PixelCNN. We use dilation instead of downscaling 105 | if class_condition: 106 | conv_layer = partial(GatedMaskedConv, cond_channel=n_classes) 107 | else: 108 | conv_layer = GatedMaskedConv 109 | self.conv_layers = nn.ModuleList( 110 | [ 111 | conv_layer(hidden_dim), 112 | conv_layer(hidden_dim, dilation=2), 113 | conv_layer(hidden_dim), 114 | conv_layer(hidden_dim, dilation=4), 115 | conv_layer(hidden_dim), 116 | conv_layer(hidden_dim, dilation=2), 117 | conv_layer(hidden_dim), 118 | conv_layer(hidden_dim, dilation=4), 119 | conv_layer(hidden_dim), 120 | conv_layer(hidden_dim, dilation=2), 121 | conv_layer(hidden_dim), 122 | ] 123 | ) 124 | # Output classification convolution (1x1) 125 | self.conv_out = nn.Conv2d(hidden_dim, self.channels * 256, kernel_size=1, padding=0) 126 | self.register_buffer("log2", torch.log(torch.tensor(2, dtype=torch.float32, device=self.device))) 127 | 128 | def forward(self, x, y=None): 129 | """ 130 | Forward image through model and return logits for each pixel. 131 | Inputs: 132 | x - Image tensor in range (0, 1). 133 | y - one-hot vector indicating class label. 
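Returns: logits of shape (N, 256, C, H, W), i.e. a 256-way categorical distribution over intensity values for every sub-pixel.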
134 | """ 135 | N = x.shape[0] 136 | # Initial convolutions 137 | v_stack = self.conv_vstack(x) 138 | h_stack = self.conv_hstack(x) 139 | # Gated Convolutions 140 | for layer in self.conv_layers: 141 | if y is not None: 142 | y = y.reshape(N, self.hparams.n_classes, 1, 1) 143 | v_stack, h_stack = layer(v_stack, h_stack, y) 144 | else: 145 | v_stack, h_stack = layer(v_stack, h_stack) 146 | # 1x1 classification convolution 147 | # Apply ELU before 1x1 convolution for non-linearity on residual connection 148 | out = self.conv_out(F.elu(h_stack)) 149 | 150 | # Output dimensions: [Batch, Classes, Channels, Height, Width] 151 | out = out.reshape( 152 | out.shape[0], 256, out.shape[1] // 256, out.shape[2], out.shape[3] 153 | ) 154 | return out 155 | 156 | def calc_likelihood(self, x, label=None): 157 | # Forward pass with bpd likelihood calculation 158 | pred = self.forward(x, label) 159 | if self.input_normalize: 160 | target = ((x + 1) / 2 * 255).to(torch.long) 161 | else: 162 | target = (x * 255).to(torch.long) 163 | nll = F.cross_entropy(pred, target, reduction="none") # (N, C, H, W) 164 | bpd = nll.mean(dim=[1, 2, 3]) / self.log2 165 | return bpd.mean() 166 | 167 | @torch.no_grad() 168 | def sample(self, img_shape, cond=None, img=None): 169 | """ 170 | Sampling function for the autoregressive model. 171 | Inputs: 172 | img_shape - Shape of the image to generate (B,C,H,W) 173 | img (optional) - If given, this tensor will be used as 174 | a starting image. The pixels to fill 175 | should be -1 in the input tensor. 176 | """ 177 | # Create empty image 178 | if img is None: 179 | img = torch.zeros(img_shape, dtype=torch.float32).to(self.device) - 1 180 | # Generation loop 181 | N, C, H, W = img_shape 182 | for h in tqdm(range(H), leave=False): 183 | for w in range(W): 184 | # Skip if not to be filled (-1) 185 | if (img[:, :, h, w] != -1).all().item(): 186 | continue 187 | # For efficiency, we only have to input the upper part of the image 188 | # as all other parts will be skipped by the masked convolutions anyways 189 | pred = self.forward(img[:, :, : h + 1, :], cond) # (N, classes, C) 190 | probs = F.softmax(pred[:, :, :, h, w].permute(0, 2, 1), dim=-1).reshape(N*C, -1) 191 | new_pred = torch.multinomial(probs, num_samples=1).squeeze(dim=-1).to(torch.float32) / 255 192 | if self.input_normalize: 193 | new_pred = new_pred*2 - 1 194 | img[:, :, h, w] = new_pred.reshape(N, C) 195 | return img 196 | 197 | def configure_optimizers(self): 198 | optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr) 199 | scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99) 200 | return [optimizer], [scheduler] 201 | 202 | def training_step(self, batch, batch_idx): 203 | img, label = batch 204 | if self.hparams.class_condition: 205 | label = F.one_hot(label, num_classes=self.hparams.n_classes).to(torch.float32) 206 | loss = self.calc_likelihood(img, label) 207 | else: 208 | loss = self.calc_likelihood(img) 209 | self.log("train_bpd", loss) # bpd: bits per dim, by entropy encoding 210 | return loss 211 | 212 | def validation_step(self, batch, batch_idx): 213 | img, label = batch 214 | N, C, H, W = img.shape 215 | if self.hparams.class_condition: 216 | label = F.one_hot(label, num_classes=self.hparams.n_classes).to(torch.float32) 217 | loss = self.calc_likelihood(img, label) 218 | else: 219 | loss = self.calc_likelihood(img) 220 | self.log("val_bpd", loss) 221 | 222 | sample_img = None 223 | if batch_idx == 0: 224 | if self.hparams.class_condition: 225 | sample_label = 
torch.arange(self.hparams.n_classes, device=self.device).reshape(self.hparams.n_classes, 1).repeat(1, 8) 226 | sample_label = F.one_hot(sample_label, num_classes=self.hparams.n_classes).to(torch.float32) 227 | sample_img = self.sample((self.hparams.n_classes*8, C, H, W), cond=sample_label) 228 | else: 229 | sample_img = self.sample(img.shape) 230 | return ValidationResult(real_image=img, fake_image=sample_img) 231 | -------------------------------------------------------------------------------- /src/models/speed_gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import OmegaConf 3 | from .base import BaseModel, ValidationResult 4 | from hydra.utils import instantiate 5 | from src.utils.losses import adversarial_loss 6 | 7 | class GAN(BaseModel): 8 | def __init__( 9 | self, 10 | datamodule: OmegaConf, 11 | netG: OmegaConf, 12 | netD: OmegaConf, 13 | latent_dim: int = 100, 14 | loss_mode: str = "vanilla", 15 | lrG: float = 0.0002, 16 | lrD: float = 0.0002, 17 | b1: float = 0.5, 18 | b2: float = 0.999, 19 | ): 20 | super().__init__(datamodule) 21 | self.save_hyperparameters() 22 | self.netG = instantiate(netG, input_channel=latent_dim, output_channel=self.channels) 23 | self.netD = instantiate(netD, input_channel=self.channels, output_channel=1) 24 | self.automatic_optimization = False 25 | 26 | def forward(self, z): 27 | output = self.netG(z) 28 | output = output.reshape(z.shape[0], self.channels, self.height, self.width) 29 | return output 30 | 31 | def configure_optimizers(self): 32 | lrG, lrD = self.hparams.lrG, self.hparams.lrD 33 | b1, b2 = self.hparams.b1, self.hparams.b2 34 | opt_g = torch.optim.Adam(self.netG.parameters(), lr=lrG, betas=(b1, b2)) 35 | opt_d = torch.optim.Adam(self.netD.parameters(), lr=lrD, betas=(b1, b2)) 36 | return [opt_g, opt_d] 37 | 38 | def training_step(self, batch, batch_idx, optimizer_idx): 39 | # speed gan 40 | # generator: forward 1 times, backward 1 times 41 | # discriminator: forward 2 times, backward 2 times 42 | 43 | # original gan 44 | # generator: forward 2 times, backward 1 times 45 | # discriminator: forward 3 times, backward 3 times 46 | opt_g, opt_d = self.optimizers() 47 | imgs, _ = batch # (N, C, H, W) 48 | N, C, H, W = imgs.shape 49 | z = torch.randn(N, self.hparams.latent_dim).to(self.device) 50 | 51 | fake_imgs = self.netG(z) 52 | pred_fake = self.netD(fake_imgs) 53 | pred_real = self.netD(imgs) 54 | 55 | real_loss = adversarial_loss(pred_real, True, loss_mode=self.hparams.loss_mode) 56 | fake_loss = adversarial_loss(pred_fake, False, loss_mode=self.hparams.loss_mode) 57 | 58 | g_loss = adversarial_loss(pred_fake, True, loss_mode=self.hparams.loss_mode) 59 | d_loss = (real_loss + fake_loss) / 2 60 | 61 | opt_g.zero_grad() 62 | self.manual_backward(g_loss, retain_graph=True) 63 | opt_g.step() 64 | 65 | opt_d.zero_grad() 66 | self.manual_backward(d_loss, inputs=list(self.netD.parameters()), retain_graph=True) 67 | opt_d.step() 68 | 69 | self.log("train_loss/d_loss", d_loss) 70 | self.log("train_loss/g_loss", g_loss) 71 | self.log("train_log/pred_real", pred_real.mean()) 72 | self.log("train_log/pred_fake", pred_fake.mean()) 73 | 74 | 75 | def validation_step(self, batch, batch_idx): 76 | img, _ = batch 77 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 78 | fake_imgs = self.forward(z) 79 | return ValidationResult(real_image=img, fake_image=fake_imgs) 80 | -------------------------------------------------------------------------------- 
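A minimal, self-contained sketch of the shared-forward trick used in speed_gan.py above: the fake batch and its discriminator logits are computed once and reused for both losses, and the discriminator backward is restricted to its own parameters. Module sizes, data, and optimizer settings here are illustrative only, not the repository's configuration.

import torch
from torch import nn
import torch.nn.functional as F

latent_dim, data_dim, batch = 8, 16, 32
netG = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, data_dim))
netD = nn.Sequential(nn.Linear(data_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt_g = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))

real = torch.randn(batch, data_dim)   # stand-in for a real data batch
z = torch.randn(batch, latent_dim)

fake = netG(z)                        # one generator forward
pred_fake = netD(fake)                # one discriminator forward on fakes
pred_real = netD(real)                # one discriminator forward on reals

# Vanilla GAN losses, reusing pred_fake for both players.
g_loss = F.binary_cross_entropy_with_logits(pred_fake, torch.ones_like(pred_fake))
d_loss = 0.5 * (F.binary_cross_entropy_with_logits(pred_real, torch.ones_like(pred_real))
                + F.binary_cross_entropy_with_logits(pred_fake, torch.zeros_like(pred_fake)))

opt_g.zero_grad()
g_loss.backward(retain_graph=True)    # keep the graph alive for the discriminator backward
opt_g.step()

opt_d.zero_grad()                     # also clears the gradients g_loss left on netD
d_loss.backward(inputs=list(netD.parameters()))
opt_d.step()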
/src/models/tar.py: -------------------------------------------------------------------------------- 1 | ## Vector quantized autoregressive model 2 | import torch 3 | from omegaconf import OmegaConf 4 | from torch import embedding, embedding_renorm_, nn 5 | from torch import Tensor 6 | import math 7 | from einops import rearrange 8 | from src.models.base import BaseModel, ValidationResult 9 | from src.utils.losses import normal_kld 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | 14 | class PositionalEncoding(nn.Module): 15 | def __init__(self, d_model: int, H: int, W: int) -> None: 16 | super().__init__() 17 | self.h_pe = torch.nn.Parameter(data=torch.randn(H, 1, d_model)) 18 | self.w_pe = torch.nn.Parameter(data=torch.randn(W, 1, d_model)) 19 | self.first_pe = torch.nn.Parameter(data=torch.randn(1, 1, d_model)) 20 | self.H = H 21 | self.W = W 22 | 23 | def forward(self, x): 24 | h_pe = self.h_pe.repeat(1, self.W, 1).reshape(self.H*self.W, 1, -1) 25 | h_pe = torch.cat([self.first_pe, h_pe], dim=0) 26 | 27 | w_pe = self.w_pe.repeat(self.H, 1, 1).reshape(self.H*self.W, 1, -1) 28 | w_pe = torch.cat([self.first_pe, w_pe], dim=0) 29 | 30 | x = x + h_pe[:x.size(0)] + w_pe[:x.size(0)] 31 | return x 32 | 33 | class PixelEncoding(nn.Module): 34 | def __init__(self, n_tokens, d_model, class_cond=False, n_classes=None) -> None: 35 | super().__init__() 36 | self.pixel_embed = nn.Embedding(num_embeddings=n_tokens, embedding_dim=d_model) 37 | if class_cond: 38 | self.cond_embed = nn.Embedding(num_embeddings=n_classes, embedding_dim=d_model) 39 | else: 40 | self.cond_embed = nn.Embedding(num_embeddings=1, embedding_dim=d_model) 41 | 42 | def forward(self, tokens): 43 | token1 = self.cond_embed(tokens[0:1]) 44 | token2 = self.pixel_embed(tokens[1:]) 45 | return torch.cat([token1, token2], dim=0) 46 | 47 | 48 | class TAR(BaseModel): 49 | def __init__( 50 | self, 51 | datamodule: OmegaConf = None, 52 | lr: float = 1e-4, 53 | b1: float = 0.9, 54 | b2: float = 0.999, 55 | d_model: int = 256, 56 | nhead: int = 4, 57 | num_layers: int = 4, 58 | class_cond: bool = False, 59 | n_classes: int = 10 60 | ): 61 | super().__init__(datamodule) 62 | self.save_hyperparameters() 63 | self.n_tokens = 2 # 0-255 and 64 | 65 | self.pos_embed = PositionalEncoding(d_model, H=self.height, W=self.width) 66 | self.pixel_embed = PixelEncoding(self.n_tokens, d_model, class_cond, n_classes=n_classes) 67 | 68 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=1024) 69 | self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers) 70 | self.proj = nn.Linear(d_model, self.n_tokens) 71 | 72 | def img2tokens(self, imgs, label): 73 | N = imgs.shape[0] 74 | # imgs = (imgs * 255 + 0.5).long().clamp(0, 255) 75 | imgs[imgs >= 0.5] = 1 76 | imgs[imgs < 0.5] = 0 77 | tokens = rearrange(imgs.long(), 'n c h w -> (h w c) n') 78 | # prepend to tokens 79 | if self.hparams.class_cond: 80 | sos = label.long().reshape(1, N) # start of sequence 81 | else: 82 | sos = torch.zeros(1, N, device=self.device, dtype=torch.long) # start of sequence 83 | tokens = torch.cat([sos, tokens], dim=0) # (seq_len+1, batch) 84 | return tokens 85 | 86 | def tokens2img(self, tokens, shape): 87 | N, C, H, W = shape 88 | imgs = rearrange(tokens[1:], "(h w c) n -> n c h w", n=N, c=C, h=H, w=W).float() 89 | return imgs 90 | 91 | def forward(self, tokens): 92 | # tokens: (seq_len, batch) consisting of long tensor 93 | # returns: (seq_len, batch, n_classes) 94 | S, N = tokens.shape 95 | 
mask = torch.tril(torch.ones(S, S, device=self.device)) == 0 # True indicates not to attend! 96 | # append start of sentence token 97 | embed = self.pos_embed(self.pixel_embed(tokens)) # s n d 98 | features = self.encoder.forward(embed, mask=mask) # s n d 99 | pred = self.proj(features) # s n num_class 100 | pred = rearrange(pred, "s n c -> s c n") 101 | return pred 102 | 103 | def cal_loss(self, tokens): 104 | # tokens: (S+1, N), including token 105 | pred = self.forward(tokens) # (S+1, n_class, N) 106 | loss = F.cross_entropy(pred[:-1], tokens[1:], reduction="none").sum(dim=0).mean() 107 | return loss 108 | 109 | def training_step(self, batch, batch_idx): 110 | imgs, labels = batch # (N, C, H, W) 111 | N, C, H, W = imgs.shape 112 | 113 | tokens = self.img2tokens(imgs, labels) 114 | 115 | loss = self.cal_loss(tokens) 116 | self.log("train_log/nll", loss) 117 | self.log("train_log/bpd", loss / (H*W*C) / np.log(2)) 118 | return loss 119 | 120 | def configure_optimizers(self): 121 | lr = self.hparams.lr 122 | b1 = self.hparams.b1 123 | b2 = self.hparams.b2 124 | opt = torch.optim.Adam(self.parameters(), lr=lr, betas=(b1, b2)) 125 | scheduler = torch.optim.lr_scheduler.StepLR(opt, 1, gamma=0.99) 126 | return [opt], [scheduler] 127 | 128 | def sample(self, shape, tokens=None, labels=None): 129 | if tokens == None: 130 | N, C, H, W = shape 131 | tokens = torch.zeros(1+H*W*C, N, device=self.device).long().fill_(-1) 132 | if self.hparams.class_cond: 133 | tokens[0] = labels 134 | else: 135 | tokens[0].fill_(0) # set to index 0 136 | 137 | for i in range(tokens.shape[0]-1): 138 | if (tokens[i+1, :] != -1).all().item(): 139 | continue 140 | pred = self.forward(tokens[:i+1]) # (S, n_class, N) 141 | prob = torch.softmax(pred[-1].T, dim=-1) 142 | sample = torch.multinomial(prob, num_samples=1).squeeze(-1) # (N) 143 | tokens[i+1] = sample 144 | imgs = self.tokens2img(tokens, shape) 145 | return imgs 146 | 147 | def validation_step(self, batch, batch_idx): 148 | imgs, labels = batch 149 | N, C, H, W = shape = imgs.shape 150 | 151 | tokens = self.img2tokens(imgs, labels) 152 | loss = self.cal_loss(tokens) 153 | 154 | random_tokens = torch.randint(0, 2, (C*H*W+1, N), device=self.device) 155 | random_tokens[0] = 0 156 | rand_loss = self.cal_loss(random_tokens) 157 | 158 | fake_imgs = None 159 | mask_image = None 160 | if batch_idx == 0: 161 | fake_labels = None 162 | if self.hparams.class_cond: 163 | fake_labels = torch.arange(0, self.hparams.n_classes).reshape(-1, 1).repeat(1, 8).reshape(-1) 164 | fake_imgs = self.sample((self.hparams.n_classes*8, C, H, W), labels=fake_labels).float() 165 | 166 | tokens[H*W*C // 2:] = -1 167 | mask_image = self.sample(shape, tokens=tokens) 168 | 169 | fake_tokens = self.img2tokens(mask_image, labels) 170 | fake_loss = self.cal_loss(fake_tokens) 171 | self.log("var_log/fake_bpg", fake_loss / (H*W*C) / np.log(2)) 172 | 173 | self.log("val_log/bpd", loss / (H*W*C) / np.log(2)) 174 | self.log("val_log/rand_bpd", rand_loss / (H*W*C) / np.log(2)) 175 | 176 | return ValidationResult(real_image=imgs, fake_image=fake_imgs, others={"mask_image": mask_image}) 177 | -------------------------------------------------------------------------------- /src/models/vae.py: -------------------------------------------------------------------------------- 1 | from hydra.utils import instantiate 2 | import torch 3 | from omegaconf import OmegaConf 4 | 5 | from src.models.base import BaseModel, ValidationResult 6 | from torch import distributions 7 | from src.utils.distributions import 
get_decode_dist 8 | from src.utils.losses import normal_kld 9 | 10 | 11 | class VAE(BaseModel): 12 | def __init__( 13 | self, 14 | datamodule: OmegaConf = None, 15 | encoder: OmegaConf = None, 16 | decoder: OmegaConf = None, 17 | latent_dim: int = 100, 18 | beta: float = 1.0, 19 | recon_weight: float = 1.0, 20 | lr: float = 1e-4, 21 | b1: float = 0.9, 22 | b2: float = 0.999, 23 | decoder_dist = "guassian" 24 | ): 25 | super().__init__(datamodule) 26 | self.save_hyperparameters() 27 | 28 | self.decoder = instantiate(decoder, input_channel=latent_dim, output_channel=self.channels, output_act=self.output_act) 29 | self.encoder = instantiate(encoder, input_channel=self.channels, output_channel=2 * latent_dim) 30 | self.decoder_dist = get_decode_dist(decoder_dist) 31 | 32 | def forward(self, z): 33 | """Generate images given latent code.""" 34 | output = self.decoder(z) 35 | output = self.decoder_dist.sample(output) 36 | output = output.reshape(output.shape[0], self.channels, self.height, self.width) 37 | return output 38 | 39 | def configure_optimizers(self): 40 | lr = self.hparams.lr 41 | b1 = self.hparams.b1 42 | b2 = self.hparams.b2 43 | opt = torch.optim.Adam(self.parameters(), lr=lr, betas=(b1, b2)) 44 | scheduler = torch.optim.lr_scheduler.StepLR(opt, 1, gamma=0.99) 45 | return [opt], [scheduler] 46 | 47 | def reparameterize(self, mu, log_sigma): 48 | post_dist = distributions.Normal(mu, torch.exp(log_sigma)) 49 | samples_z = post_dist.rsample() 50 | return samples_z 51 | 52 | def vae(self, imgs): 53 | z_ = self.encoder(imgs) # (N, latent_dim) 54 | mu, log_sigma = torch.chunk(z_, chunks=2, dim=1) 55 | z = self.reparameterize(mu, log_sigma) 56 | recon_imgs = self.decoder(z) 57 | return mu, log_sigma, z, recon_imgs 58 | 59 | 60 | def training_step(self, batch, batch_idx): 61 | imgs, labels = batch # (N, C, H, W) 62 | N = imgs.shape[0] 63 | 64 | mu, log_sigma, z, recon_imgs = self.vae(imgs) 65 | kld = normal_kld(mu, log_sigma) 66 | 67 | log_p_x_of_z = self.decoder_dist.prob(recon_imgs, imgs).mean(dim=0) 68 | elbo = -self.hparams.beta*kld + self.hparams.recon_weight * log_p_x_of_z 69 | 70 | self.log("train_log/elbo", elbo) 71 | self.log("train_log/kl_divergence", kld) 72 | self.log("train_log/log_p_x_of_z", log_p_x_of_z) 73 | return -elbo 74 | 75 | def validation_step(self, batch, batch_idx): 76 | imgs, labels = batch 77 | N = imgs.shape[0] 78 | mu, log_sigma, z, recon_imgs = self.vae(imgs) 79 | log_p_x_of_z = self.decoder_dist.prob(recon_imgs, imgs).mean(dim=0) 80 | 81 | fake_imgs = self.sample(N) 82 | self.log("val_log/log_p_x_of_z", log_p_x_of_z) 83 | return ValidationResult(real_image=imgs, fake_image=fake_imgs, 84 | recon_image=recon_imgs, label=labels, encode_latent=z) -------------------------------------------------------------------------------- /src/models/vae_gan.py: -------------------------------------------------------------------------------- 1 | """Autoencoding beyond pixels using a learned similarity metric""" 2 | import itertools 3 | import hydra 4 | import pytorch_lightning as pl 5 | import torchvision 6 | import torch 7 | import torch.nn.functional as F 8 | from pathlib import Path 9 | from omegaconf import OmegaConf 10 | from .base import BaseModel, ValidationResult 11 | from src.utils.losses import adversarial_loss, normal_kld 12 | from torch import distributions 13 | 14 | class VAEGAN(BaseModel): 15 | def __init__( 16 | self, 17 | datamodule, 18 | encoder: OmegaConf = None, 19 | decoder: OmegaConf = None, 20 | latent_dim=100, 21 | lr: float = 0.0002, 22 | b1: float = 
0.5, 23 | b2: float = 0.999, 24 | # reconstruction weight in discriminator feature space, first tune this parameter if performace is unsatifactory. 25 | recon_weight: float = 1e-4, 26 | loss_mode: str = "vanilla" 27 | ): 28 | super().__init__(datamodule) 29 | self.save_hyperparameters() 30 | 31 | self.decoder = hydra.utils.instantiate(decoder, input_channel=latent_dim, output_channel=self.channels) 32 | self.encoder = hydra.utils.instantiate(encoder, input_channel=self.channels, output_channel=2 * latent_dim) 33 | self.netD = hydra.utils.instantiate(encoder, input_channel=self.channels, output_channel=1, return_features=True) 34 | self.automatic_optimization = False 35 | 36 | def configure_optimizers(self): 37 | lr = self.hparams.lr 38 | b1 = self.hparams.b1 39 | b2 = self.hparams.b2 40 | optim_ae = torch.optim.Adam(itertools.chain(self.encoder.parameters(), 41 | self.decoder.parameters()), lr=lr, betas=(b1, b2)) 42 | optim_d = torch.optim.Adam(self.netD.parameters(), lr=lr, betas=(b1, b2)) 43 | 44 | return optim_ae, optim_d 45 | 46 | def forward(self, z): 47 | output = self.decoder(z) 48 | output = output.reshape(output.shape[0], self.channels, self.height, self.width) 49 | return output 50 | 51 | def reparameterize(self, mu, log_sigma): 52 | post_dist = distributions.Normal(mu, torch.exp(log_sigma)) 53 | samples_z = post_dist.rsample() 54 | return samples_z 55 | 56 | def vae(self, imgs): 57 | z_ = self.encoder(imgs) # (N, latent_dim) 58 | mu, log_sigma = torch.chunk(z_, chunks=2, dim=1) 59 | z = self.reparameterize(mu, log_sigma) 60 | recon_imgs = self.decoder(z) 61 | return mu, log_sigma, z, recon_imgs 62 | 63 | def training_step(self, batch, batch_idx): 64 | optim_ae, optim_d = self.optimizers() 65 | 66 | imgs, _ = batch 67 | N = imgs.shape[0] 68 | 69 | mu, log_sigma, infered_z, recon_imgs = self.vae(imgs) 70 | prior_z = torch.randn(N, self.hparams.latent_dim).to(self.device) 71 | fake_imgs = self.decoder(prior_z) 72 | 73 | reg_loss = normal_kld(mu, log_sigma) 74 | 75 | fake_logit, fake_features = self.netD(fake_imgs) 76 | real_logit, real_features = self.netD(imgs) 77 | recon_logit, recon_features = self.netD(recon_imgs) 78 | feature_recon_loss = F.mse_loss(real_features, recon_features, reduction="sum") / N 79 | # NOTE: this paper says also use recon samples as , 80 | # but the official code doesn't use recon images as negative samples 81 | g_adv_loss = adversarial_loss(fake_logit, True) 82 | 83 | optim_ae.zero_grad() 84 | self.manual_backward(reg_loss+feature_recon_loss, retain_graph=True) 85 | for p in self.decoder.parameters(): 86 | p.grad *= self.hparams.recon_weight 87 | # encoder is not optimized w.r.t. 
GAN loss 88 | self.manual_backward(g_adv_loss, inputs=list(self.decoder.parameters()), retain_graph=True) 89 | optim_ae.step() 90 | 91 | d_adv_loss = adversarial_loss(real_logit, True) + adversarial_loss(fake_logit, False) 92 | optim_d.zero_grad() 93 | self.manual_backward(d_adv_loss, inputs=list(self.netD.parameters())) 94 | optim_d.step() 95 | 96 | self.log("train_loss/reg_loss", reg_loss) 97 | self.log("train_loss/feature_recon_loss", feature_recon_loss) 98 | self.log("train_loss/g_adv_loss", g_adv_loss) 99 | self.log("train_loss/d_adv_loss", d_adv_loss) 100 | self.log("train_log/real_logit", torch.mean(real_logit)) 101 | self.log("train_log/fake_logit", torch.mean(fake_logit)) 102 | self.log("train_log/recon_logit", torch.mean(recon_logit)) 103 | 104 | def validation_step(self, batch, batch_idx): 105 | imgs, labels = batch 106 | N = imgs.shape[0] 107 | mu, log_sigma, z, recon_imgs = self.vae(imgs) 108 | fake_imgs = self.sample(N) 109 | val_mse = F.mse_loss(imgs, recon_imgs) 110 | self.log("val_log/van_mse", val_mse) 111 | 112 | return ValidationResult(real_image=imgs, fake_image=fake_imgs, 113 | recon_image=recon_imgs, label=labels, encode_latent=z) -------------------------------------------------------------------------------- /src/models/vqvae.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn.functional as F 5 | import itertools 6 | from omegaconf import OmegaConf 7 | from .base import BaseModel, ValidationResult 8 | from torch import nn 9 | 10 | # TODO: 11 | # 1. sampling implementation 12 | # 2. dive into the influence of num_embeddings and latent_dim 13 | class VectorQuantizer(nn.Module): 14 | def __init__(self, num_embeddings, latent_dim, commitment_weight) -> None: 15 | super().__init__() 16 | self.embedding = torch.nn.parameter.Parameter( 17 | torch.zeros(num_embeddings, latent_dim).uniform_( 18 | -1 / num_embeddings, 1 / num_embeddings 19 | ) 20 | ) # learned discete representation 21 | self.latent_dim = latent_dim 22 | self.commitment_weight = commitment_weight 23 | 24 | def forward(self, z): 25 | N, C, H, W = z.shape 26 | 27 | # (N, latent_dim, H, W) -> (N*latent_size, latent_dim) 28 | reshape_z = ( 29 | z.reshape(N, self.latent_dim, -1) 30 | .permute(0, 2, 1) 31 | .reshape(-1, self.latent_dim) 32 | ) 33 | # (N*latent_size, latent_dim), (K, latent_dim) -> (N*latent_size, K) 34 | dist = torch.cdist(reshape_z, self.embedding) 35 | 36 | z_index = torch.argmin(dist, dim=1) # (N*latent_size) 37 | quant_z = self.embedding[z_index] # (N*latent_size, latent_dim) 38 | vq_loss = F.mse_loss(reshape_z.detach(), quant_z) 39 | commit_loss = self.commitment_weight * F.mse_loss(reshape_z, quant_z.detach()) 40 | 41 | quant_z = quant_z.reshape(N, H, W, C).permute(0, 3, 1, 2) 42 | 43 | return quant_z, vq_loss, commit_loss 44 | 45 | 46 | class VQVAE(BaseModel): 47 | def __init__( 48 | self, 49 | datamodule, 50 | encoder: OmegaConf = None, 51 | decoder: OmegaConf = None, 52 | latent_dim=100, 53 | lr: float = 0.0002, 54 | b1: float = 0.5, 55 | b2: float = 0.999, 56 | num_embeddings: int = 512, 57 | beta: float = 0.25, 58 | optim="adam", 59 | **kwargs, 60 | ): 61 | super().__init__(datamodule) 62 | self.save_hyperparameters() 63 | 64 | self.decoder = hydra.utils.instantiate( 65 | decoder, input_channel=latent_dim, output_channel=self.channels 66 | ) 67 | self.encoder = hydra.utils.instantiate( 68 | encoder, input_channel=self.channels, output_channel=latent_dim 69 | ) 70 | 
self.vector_quntizer = VectorQuantizer(num_embeddings, latent_dim, beta) 71 | 72 | self.latent_w = self.width // 4 73 | self.latent_h = self.height // 4 74 | self.latent_size = self.latent_h * self.latent_w 75 | 76 | def forward(self, imgs): 77 | """ 78 | Directly sample from embeddings will not produce meaningful images. 79 | """ 80 | z = self.encoder(imgs) 81 | quant_z, _, _ = self.vector_quntizer(z) 82 | output = self.decoder(quant_z) 83 | output = output.reshape( 84 | output.shape[0], 85 | self.channels, 86 | self.height, 87 | self.width, 88 | ) 89 | return output 90 | 91 | def training_step(self, batch, batch_idx): 92 | imgs, _ = batch 93 | 94 | ## Encoding 95 | encoder_z = self.encoder(imgs) # (N, latent_dim, latent_w, latent_h) 96 | 97 | ## Vector Quantization 98 | quant_z, vq_loss, commit_loss = self.vector_quntizer(encoder_z) 99 | 100 | ## Decoding 101 | # this will feed value of z to decoder, and backward gradient to encoder_z instead of z, 102 | # such the encoder_z and encoder can be optimized 103 | decoder_z = encoder_z + (quant_z - encoder_z).detach() 104 | fake_imgs = self.decoder(decoder_z) 105 | fake_imgs = fake_imgs.reshape( 106 | -1, self.channels, self.height, self.width 107 | ) 108 | recon_loss = F.mse_loss(fake_imgs, imgs) 109 | 110 | total_loss = recon_loss + vq_loss + self.hparams.beta * commit_loss 111 | 112 | self.log("train_loss/vq_loss", vq_loss) 113 | self.log("train_loss/recon_loss", recon_loss) 114 | self.log("train_loss/commit_loss", commit_loss) 115 | 116 | 117 | return total_loss 118 | 119 | def configure_optimizers(self): 120 | lr = self.hparams.lr 121 | b1 = self.hparams.b1 122 | b2 = self.hparams.b2 123 | 124 | opt = torch.optim.Adam( 125 | itertools.chain( 126 | self.encoder.parameters(), 127 | self.decoder.parameters(), 128 | self.vector_quntizer.parameters(), 129 | ), 130 | lr=lr, 131 | betas=(b1, b2), 132 | ) 133 | return opt 134 | 135 | def validation_step(self, batch, batch_idx): 136 | imgs, labels = batch 137 | recon_imgs = self.forward(imgs) 138 | self.log("val/recon_loss", F.mse_loss(imgs, recon_imgs)) 139 | return ValidationResult(real_image=imgs, recon_image=recon_imgs) -------------------------------------------------------------------------------- /src/models/wgan.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | from .base import BaseModel, ValidationResult 4 | 5 | 6 | class WGAN(BaseModel): 7 | """Wassertain GAN. https://arxiv.org/abs/1701.07875 8 | Training tricks: 9 | 1. As paper said, momentum based optimizer like Adam performs worse, so we use RMSProp here. 
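2. The critic's weights are clamped to [-clip_weight, clip_weight] after every update as a crude way to keep it approximately 1-Lipschitz, which the Wasserstein-1 objective assumes; wgan_gp.py replaces this clipping with a gradient penalty.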
10 | """ 11 | def __init__( 12 | self, 13 | datamodule, 14 | netG, 15 | netD, 16 | latent_dim=100, 17 | n_critic=5, 18 | clip_weight=0.01, 19 | lrG: float = 5e-5, 20 | lrD: float = 5e-5, 21 | alpha: float = 0.99, 22 | eval_fid=False, 23 | ): 24 | super().__init__(datamodule) 25 | self.save_hyperparameters() 26 | self.automatic_optimization = False 27 | 28 | # networks 29 | self.generator = hydra.utils.instantiate( 30 | netG, input_channel=latent_dim, output_channel=self.channels 31 | ) 32 | self.console.info(f"Generator architecture: \n {self.generator}") 33 | self.discriminator = hydra.utils.instantiate( 34 | netD, input_channel=self.channels, output_channel=1 35 | ) 36 | self.console.info(f"Discriminator architecture: \n {self.discriminator}") 37 | 38 | def forward(self, z): 39 | output = self.generator(z) 40 | output = output.reshape( 41 | z.shape[0], self.channels, self.height, self.width 42 | ) 43 | return output 44 | 45 | def configure_optimizers(self): 46 | lrG = self.hparams.lrG 47 | lrD = self.hparams.lrD 48 | alpha = self.hparams.alpha 49 | 50 | opt_g = torch.optim.RMSprop( 51 | self.generator.parameters(), lr=lrG, alpha=alpha 52 | ) 53 | opt_d = torch.optim.RMSprop( 54 | self.discriminator.parameters(), lr=lrD, alpha=alpha 55 | ) 56 | return [opt_g, opt_d] 57 | 58 | def training_step(self, batch, batch_idx): 59 | imgs, _ = batch # (N, C, H, W) 60 | 61 | z = torch.randn(imgs.shape[0], self.hparams.latent_dim) # (N, latent_dim) 62 | z = z.type_as(imgs) 63 | 64 | opt_g, opt_d = self.optimizers() 65 | 66 | # clip discriminator weight for 1-Lipschitz constraint 67 | for p in self.discriminator.parameters(): 68 | p.data.clamp_(-self.hparams.clip_weight, self.hparams.clip_weight) 69 | 70 | if batch_idx % (self.hparams.n_critic+1) == 0: 71 | generated_imgs = self(z) 72 | g_loss = -torch.mean(self.discriminator(generated_imgs)) 73 | 74 | opt_g.zero_grad() 75 | self.manual_backward(g_loss) 76 | opt_g.step() 77 | 78 | self.log("train_loss/g_loss", g_loss, prog_bar=True) 79 | 80 | else: 81 | real_loss = -self.discriminator(imgs).mean() 82 | fake_loss = self.discriminator(self(z).detach()).mean() 83 | d_loss = real_loss + fake_loss 84 | opt_d.zero_grad() 85 | self.manual_backward(d_loss) 86 | opt_d.step() 87 | 88 | self.log("train_loss/d_loss", d_loss) 89 | self.log("train_log/real_logit", -real_loss) 90 | self.log("train_log/fake_logit", fake_loss) 91 | 92 | 93 | def validation_step(self, batch, batch_idx): 94 | img, _ = batch 95 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 96 | fake_imgs = self.forward(z) 97 | return ValidationResult(real_image=img, fake_image=fake_imgs) 98 | -------------------------------------------------------------------------------- /src/models/wgan_gp.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from src.utils import utils 9 | from .base import BaseModel, ValidationResult 10 | 11 | 12 | class WGAN(BaseModel): 13 | def __init__( 14 | self, 15 | datamodule, 16 | netG, 17 | netD, 18 | latent_dim=100, 19 | n_critic=5, 20 | lrG: float = 1e-4, 21 | lrD: float = 1e-4, 22 | b1: float = 0, 23 | b2: float = 0.9, 24 | gp_weight=10, 25 | ): 26 | super().__init__(datamodule) 27 | self.save_hyperparameters() 28 | self.automatic_optimization = False 29 | 30 | self.generator = hydra.utils.instantiate(netG, input_channel=latent_dim, 
output_channel=self.channels, norm_type="layer") 31 | self.discriminator = hydra.utils.instantiate(netD, input_channel=self.channels, output_channel=1, norm_type="layer") 32 | 33 | def forward(self, z): 34 | output = self.generator(z) 35 | output = output.reshape( 36 | z.shape[0], self.channels, self.height, self.width 37 | ) 38 | return output 39 | 40 | def configure_optimizers(self): 41 | lrG = self.hparams.lrG 42 | lrD = self.hparams.lrD 43 | b1 = self.hparams.b1 44 | b2 = self.hparams.b2 45 | 46 | opt_g = torch.optim.Adam( 47 | self.generator.parameters(), lr=lrG, betas=(b1, b2) 48 | ) 49 | opt_d = torch.optim.Adam( 50 | self.discriminator.parameters(), lr=lrD, betas=(b1, b2) 51 | ) 52 | return opt_g, opt_d 53 | 54 | def training_step(self, batch, batch_idx): 55 | imgs, _ = batch # (N, C, H, W) 56 | 57 | # sample noise 58 | z = torch.randn(imgs.shape[0], self.hparams.latent_dim) # (N, latent_dim) 59 | z = z.type_as(imgs) 60 | 61 | opt_g, opt_d = self.optimizers() 62 | 63 | if batch_idx % (self.hparams.n_critic+1) == self.hparams.n_critic: 64 | generated_imgs = self(z) 65 | 66 | g_loss = -torch.mean(self.discriminator(generated_imgs)) 67 | self.log("train_loss/g_loss", g_loss, prog_bar=True) 68 | 69 | opt_g.zero_grad() 70 | self.manual_backward(g_loss) 71 | opt_g.step() 72 | 73 | else: 74 | # real loss 75 | real_loss = -self.discriminator(imgs).mean() 76 | 77 | # fake loss 78 | fake_imgs = self(z) 79 | fake_loss = self.discriminator(fake_imgs.detach()).mean() 80 | 81 | # gradient penalty, evaluated at random interpolates between real and fake images 82 | N = imgs.shape[0] 83 | lerp = torch.zeros(N, 1, 1, 1).uniform_().to(self.device) 84 | inter_x = ( 85 | lerp * imgs + (1 - lerp) * fake_imgs.detach() 86 | ).requires_grad_(True) 87 | prob_inter = self.discriminator(inter_x) 88 | gradients = torch.autograd.grad( 89 | outputs=prob_inter, 90 | inputs=inter_x, 91 | grad_outputs=torch.ones_like(prob_inter).to(self.device), 92 | create_graph=True, 93 | retain_graph=True, 94 | )[0] 95 | gradient_penalty = torch.mean( 96 | (torch.linalg.vector_norm(gradients.reshape(N, -1), dim=1) - 1) ** 2 97 | ) 98 | 99 | # total critic loss: Wasserstein estimate plus the weighted gradient penalty 100 | d_loss = real_loss + fake_loss + self.hparams.gp_weight * gradient_penalty 101 | self.log("train_loss/d_loss", d_loss) 102 | self.log("train_log/real_logit", -real_loss) 103 | self.log("train_log/fake_logit", fake_loss) 104 | self.log("train_log/gradient_penalty", gradient_penalty) 105 | 106 | opt_d.zero_grad() 107 | self.manual_backward(d_loss) 108 | opt_d.step() 109 | 110 | def validation_step(self, batch, batch_idx): 111 | img, _ = batch 112 | z = torch.randn(img.shape[0], self.hparams.latent_dim).to(self.device) 113 | fake_imgs = self.forward(z) 114 | return ValidationResult(real_image=img, fake_image=fake_imgs) -------------------------------------------------------------------------------- /src/networks/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from abc import abstractmethod 4 | 5 | 6 | class ShapeChecker: 7 | def __init__(self, input_channel, output_channel) -> None: 8 | self.input_channel = input_channel 9 | self.output_channel = output_channel 10 | 11 | def __call__(self, module, input, output): 12 | assert input[0].shape[1] == self.input_channel 13 | assert output.shape[1] == self.output_channel 14 | 15 | 16 | class BaseNetwork(nn.Module): 17 | def __init__(self, input_channel, output_channel): 18 | super().__init__() 19 | self.input_channel = input_channel 20 | self.output_channel =
output_channel 21 | -------------------------------------------------------------------------------- /src/networks/basic.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from torch import nn 3 | import torch 4 | 5 | from .base import BaseNetwork 6 | from .utils import FeatureExtractor 7 | 8 | def get_act_function(act="relu"): 9 | if act == "relu": 10 | return nn.ReLU(inplace=True) 11 | elif act == "leaky_relu": 12 | return nn.LeakyReLU(0.2, inplace=True) 13 | elif act == "identity": 14 | return nn.Identity() 15 | elif act == "sigmoid": 16 | return nn.Sigmoid() 17 | elif act == "tanh": 18 | return nn.Tanh() 19 | else: 20 | raise NotImplementedError 21 | 22 | def get_norm_layer(norm_type="batch"): 23 | if norm_type == "batch": 24 | return nn.BatchNorm2d 25 | elif norm_type == "instance": 26 | return nn.InstanceNorm2d 27 | elif norm_type == "layer": 28 | # GroupNorm with num_groups=1 is equivalent to layer norm, 29 | # and with num_groups=num_channels it is equivalent to instance norm. 30 | # NOTE: instance norm has affine=False by default, while the other norms default to True 31 | return partial(nn.GroupNorm, 1) 32 | elif norm_type is None: 33 | return nn.Identity 34 | else: 35 | raise NotImplementedError(f"Norm type of {norm_type} is not implemented") 36 | 37 | def get_norm_layer_1d(norm_type="batch"): 38 | if norm_type == "batch": 39 | return nn.BatchNorm1d 40 | elif norm_type == "instance": 41 | return nn.InstanceNorm1d 42 | elif norm_type == "layer": 43 | return partial(nn.GroupNorm, 1) 44 | elif norm_type is None: 45 | return nn.Identity 46 | else: 47 | raise NotImplementedError(f"Norm type of {norm_type} is not implemented") 48 | 49 | class LinearAct(nn.Module): 50 | def __init__( 51 | self, input_channel, output_channel, act="relu", dropout=0, norm_type="batch" 52 | ): 53 | super().__init__() 54 | self.act = get_act_function(act) 55 | self.fc = nn.Linear(input_channel, output_channel) 56 | self.dropout = nn.Dropout(dropout) 57 | self.norm = get_norm_layer_1d(norm_type)(output_channel) 58 | 59 | def forward(self, x): 60 | # NOTE: the norm layer should be placed before the activation, otherwise netD will not converge 61 | return self.dropout(self.act(self.norm(self.fc(x)))) 62 | 63 | 64 | class MLPEncoder(BaseNetwork): 65 | def __init__( 66 | self, 67 | input_channel, 68 | output_channel, 69 | hidden_dims, 70 | width, 71 | height, 72 | dropout=0, 73 | norm_type="batch", 74 | return_features=False, 75 | output_act="identity", 76 | ): 77 | super().__init__(input_channel, output_channel) 78 | self.return_features = return_features 79 | if return_features: 80 | self.feature_extractor = FeatureExtractor() 81 | else: 82 | self.feature_extractor = lambda x: x 83 | self.model = nn.Sequential( 84 | # the first layer does not use batch norm 85 | LinearAct( 86 | input_channel * width * height, 87 | hidden_dims[0], 88 | "leaky_relu", 89 | dropout=dropout, 90 | norm_type="layer", 91 | ), 92 | *[ 93 | LinearAct(x, y, "leaky_relu", dropout=dropout, norm_type=norm_type) 94 | for x, y in zip(hidden_dims[:-1], hidden_dims[1:]) 95 | ], 96 | ) 97 | self.feature_extractor(self.model) 98 | self.classifier = LinearAct( 99 | hidden_dims[-1], output_channel, output_act, norm_type=None 100 | ) 101 | 102 | def forward(self, x): 103 | N = x.shape[0] 104 | x = x.reshape(N, -1) 105 | if self.return_features: 106 | self.feature_extractor.clean() 107 | output = self.classifier(self.model(x)) 108 | return output, torch.cat( 109 | [torch.ravel(x) for x in self.feature_extractor.features] 110 | ) 111 |
else: 112 | return self.classifier(self.model(x)) 113 | 114 | 115 | class MLPDecoder(BaseNetwork): 116 | def __init__( 117 | self, 118 | input_channel, 119 | output_channel, 120 | hidden_dims, 121 | width, 122 | height, 123 | output_act, 124 | norm_type="batch", 125 | ): 126 | super().__init__(input_channel, output_channel) 127 | self.width = width 128 | self.height = height 129 | 130 | dims = [input_channel, *hidden_dims] 131 | self.model = nn.Sequential( 132 | *[ 133 | LinearAct(x, y, "relu", norm_type=norm_type) 134 | for x, y in zip(dims[:-1], dims[1:]) 135 | ], 136 | LinearAct( 137 | hidden_dims[-1], 138 | output_channel * width * height, 139 | act=output_act, 140 | norm_type=False, 141 | ), 142 | ) 143 | 144 | def forward(self, x): 145 | return self.model(x).reshape(-1, self.output_channel, self.width, self.height) 146 | 147 | 148 | class ConvDecoder(BaseNetwork): 149 | def __init__(self, input_channel, output_channel, ngf, norm_type="batch", output_act="tanh"): 150 | super().__init__(input_channel, output_channel) 151 | # cause checkboard artifacts 152 | self.network = nn.Sequential( 153 | nn.ConvTranspose2d(input_channel, ngf * 4, 4, 1, 0), 154 | get_norm_layer(norm_type)(ngf * 4), 155 | nn.ReLU(True), 156 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1), 157 | get_norm_layer(norm_type)(ngf * 2), 158 | nn.ReLU(True), 159 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1), 160 | get_norm_layer(norm_type)(ngf), 161 | nn.ReLU(True), 162 | nn.ConvTranspose2d(ngf, output_channel, 4, 2, 1), 163 | get_act_function(output_act) 164 | ) 165 | 166 | def forward(self, x): 167 | N = x.shape[0] 168 | x = x.reshape(N, -1, 1, 1) 169 | output = self.network(x) 170 | return output 171 | 172 | 173 | class ConvEncoder(BaseNetwork): 174 | def __init__( 175 | self, input_channel, output_channel, ndf, norm_type="batch", return_features=False 176 | ): 177 | super().__init__(input_channel, output_channel) 178 | self.return_features = return_features 179 | if return_features: 180 | self.feature_extractor = FeatureExtractor() 181 | else: 182 | self.feature_extractor = lambda x: x 183 | self.output_channel = output_channel 184 | self.network = nn.Sequential( 185 | nn.Conv2d(input_channel, ndf, 4, 2, 1), 186 | nn.LeakyReLU(0.2, inplace=True), 187 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1), 188 | get_norm_layer(norm_type)(ndf * 2), 189 | nn.LeakyReLU(0.2), 190 | nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1), 191 | get_norm_layer(norm_type)(ndf * 4), 192 | self.feature_extractor(nn.LeakyReLU(0.2, inplace=True)), 193 | nn.Conv2d(ndf * 4, output_channel, 4, 1, 0), 194 | ) 195 | 196 | def forward(self, x): 197 | if self.return_features: 198 | self.feature_extractor.clean() 199 | output = self.network(x).reshape(-1, self.output_channel) 200 | return output, torch.cat( 201 | [torch.ravel(x) for x in self.feature_extractor.features] 202 | ) 203 | else: 204 | return self.network(x).reshape(-1, self.output_channel) 205 | -------------------------------------------------------------------------------- /src/networks/conv32.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .utils import FeatureExtractor 4 | from .base import BaseNetwork 5 | from .basic import get_act_function, get_norm_layer 6 | import torch 7 | 8 | 9 | class Decoder(BaseNetwork): 10 | def __init__(self, input_channel=1, output_channel=3, ngf=32, norm_type="batch", output_act='tanh'): 11 | super().__init__(input_channel, output_channel) 12 | self.main = nn.Sequential( 13 | # input is Z, going into a 
convolution 14 | nn.ConvTranspose2d(input_channel, ngf * 8, 2, 1, 0), 15 | get_norm_layer(norm_type)(ngf * 8), 16 | nn.ReLU(True), 17 | # state size. (ngf*8) x 2 x 2 18 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1), 19 | get_norm_layer(norm_type)(ngf * 4), 20 | nn.ReLU(True), 21 | # state size. (ngf*4) x 4 x 4 22 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1), 23 | get_norm_layer(norm_type)(ngf * 2), 24 | nn.ReLU(True), 25 | # state size. (ngf*2) x 8 x 8 26 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1), 27 | get_norm_layer(norm_type)(ngf), 28 | nn.ReLU(True), 29 | # state size. (ngf) x 16 x 16 30 | nn.ConvTranspose2d(ngf, output_channel, 4, 2, 1), 31 | get_act_function(output_act) 32 | # state size. (nc) x 32 x 32 33 | ) 34 | 35 | def forward(self, input): 36 | N = input.shape[0] 37 | input = input.reshape(N, -1, 1, 1) 38 | return self.main(input) 39 | 40 | 41 | class Encoder(BaseNetwork): 42 | def __init__( 43 | self, input_channel, output_channel, ndf, norm_type="batch", return_features=False 44 | ): 45 | super().__init__(input_channel, output_channel) 46 | self.return_features = return_features 47 | if return_features: 48 | self.feature_extractor = FeatureExtractor() 49 | else: 50 | self.feature_extractor = lambda x: x 51 | self.main = nn.Sequential( 52 | # input is (nc) x 32 x 32 53 | nn.Conv2d(input_channel, ndf, 4, 2, 1), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | # state size. (ndf) x 16 x 16 56 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1), 57 | get_norm_layer(norm_type)(ndf * 2), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | # state size. (ndf*2) x 8 x 8 60 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1), 61 | get_norm_layer(norm_type)(ndf * 4), 62 | self.feature_extractor(nn.LeakyReLU(0.2, inplace=True)), 63 | # state size. (ndf*4) x 4 x 4 64 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1), 65 | get_norm_layer(norm_type)(ndf * 8), 66 | nn.LeakyReLU(0.2, inplace=True), 67 | # state size. (ndf*8) x 2 x 2 68 | nn.Conv2d(ndf * 8, output_channel, 2, 1, 0), 69 | ) 70 | 71 | def forward(self, input): 72 | N = input.shape[0] 73 | if self.return_features: 74 | self.feature_extractor.clean() 75 | output = self.main(input).reshape(N, -1) 76 | features = torch.cat( 77 | [torch.ravel(x) for x in self.feature_extractor.features] 78 | ) 79 | return output, features 80 | else: 81 | output = self.main(input).reshape(N, -1) 82 | return output 83 | -------------------------------------------------------------------------------- /src/networks/conv64.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from .base import BaseNetwork 4 | from .utils import FeatureExtractor 5 | from .basic import get_act_function, get_norm_layer 6 | 7 | 8 | class Decoder(BaseNetwork): 9 | def __init__(self, input_channel=1, output_channel=3, ngf=32, norm_type="batch", output_act="tanh"): 10 | super().__init__(input_channel, output_channel) 11 | self.main = nn.Sequential( 12 | nn.ConvTranspose2d(input_channel, ngf * 8, 4, 1, 0), 13 | get_norm_layer(norm_type)(ngf * 8), 14 | nn.ReLU(True), 15 | # state size. (ngf*8) x 4 x 4 16 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1), 17 | get_norm_layer(norm_type)(ngf * 4), 18 | nn.ReLU(True), 19 | # state size. (ngf*4) x 8 x 8 20 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1), 21 | get_norm_layer(norm_type)(ngf * 2), 22 | nn.ReLU(True), 23 | # state size. (ngf*2) x 16 x 16 24 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1), 25 | get_norm_layer(norm_type)(ngf), 26 | nn.ReLU(True), 27 | # state size. 
(ngf) x 32 x 32 28 | nn.ConvTranspose2d(ngf, output_channel, 4, 2, 1), 29 | get_act_function(output_act) 30 | # state size. (nc) x 64 x 64 31 | ) 32 | 33 | def forward(self, input): 34 | N = input.shape[0] 35 | input = input.reshape(N, -1, 1, 1) 36 | return self.main(input) 37 | 38 | 39 | class Encoder(BaseNetwork): 40 | def __init__( 41 | self, input_channel, output_channel, ndf, norm_type="batch", return_features=False 42 | ): 43 | super().__init__(input_channel, output_channel) 44 | self.return_features = return_features 45 | 46 | if return_features: 47 | self.feature_extractor = FeatureExtractor() 48 | else: 49 | self.feature_extractor = lambda x: x 50 | self.main = nn.Sequential( 51 | # input is (nc) x 64 x 64 52 | nn.Conv2d(input_channel, ndf, 4, 2, 1), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | # state size. (ndf) x 32 x 32 55 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1), 56 | get_norm_layer(norm_type)(ndf * 2), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | # state size. (ndf*2) x 16 x 16 59 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1), 60 | get_norm_layer(norm_type)(ndf * 4), 61 | self.feature_extractor(nn.LeakyReLU(0.2, inplace=True)), # extract features 62 | # state size. (ndf*4) x 8 x 8 63 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1), 64 | get_norm_layer(norm_type)(ndf * 8), 65 | nn.LeakyReLU(0.2, inplace=True), 66 | # state size. (ndf*8) x 4 x 4 67 | nn.Conv2d(ndf * 8, output_channel, 4, 1, 0), 68 | ) 69 | 70 | def forward(self, input): 71 | N = input.shape[0] 72 | if self.return_features: 73 | self.feature_extractor.clean() 74 | output = self.main(input).reshape(N, -1) 75 | features = torch.cat( 76 | [torch.ravel(x) for x in self.feature_extractor.features] 77 | ) 78 | return output, features 79 | else: 80 | output = self.main(input).reshape(N, -1) 81 | return output 82 | -------------------------------------------------------------------------------- /src/networks/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class FeatureExtractor: 5 | def __init__(self) -> None: 6 | super().__init__() 7 | self.features = [] 8 | 9 | def __call__(self, module: nn.Module): 10 | module.register_forward_hook(self.forward_hook()) 11 | return module 12 | 13 | def forward_hook(self): 14 | def fn(module, input, output): 15 | self.features.append(output) 16 | 17 | return fn 18 | 19 | def clean(self): 20 | self.features = [] 21 | -------------------------------------------------------------------------------- /src/networks/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ResidualLayer(nn.Module): 6 | """ 7 | One residual layer inputs: 8 | - in_dim : the input dimension 9 | - h_dim : the hidden layer dimension 10 | - res_h_dim : the hidden dimension of the residual block 11 | """ 12 | 13 | def __init__(self, in_dim, h_dim, res_h_dim): 14 | super(ResidualLayer, self).__init__() 15 | self.res_block = nn.Sequential( 16 | nn.ReLU(True), 17 | nn.Conv2d( 18 | in_dim, res_h_dim, kernel_size=3, stride=1, padding=1, bias=False 19 | ), 20 | nn.ReLU(True), 21 | nn.Conv2d(res_h_dim, h_dim, kernel_size=1, stride=1, bias=False), 22 | ) 23 | 24 | def forward(self, x): 25 | x = x + self.res_block(x) 26 | return x 27 | 28 | 29 | class ResidualStack(nn.Module): 30 | """ 31 | A stack of residual layers inputs: 32 | - in_dim : the input dimension 33 | - h_dim : the hidden layer dimension 34 | - res_h_dim : the hidden dimension of the residual 
block 35 | - n_res_layers : number of layers to stack 36 | """ 37 | 38 | def __init__(self, in_dim, h_dim, res_h_dim, n_res_layers): 39 | super(ResidualStack, self).__init__() 40 | self.n_res_layers = n_res_layers 41 | self.stack = nn.ModuleList( 42 | [ResidualLayer(in_dim, h_dim, res_h_dim)] * n_res_layers 43 | ) 44 | 45 | def forward(self, x): 46 | for layer in self.stack: 47 | x = layer(x) 48 | x = F.relu(x) 49 | return x 50 | 51 | 52 | class Encoder(nn.Module): 53 | """ 54 | This is the q_theta (z|x) network. Given a data sample x q_theta 55 | maps to the latent space x -> z. 56 | For a VQ VAE, q_theta outputs parameters of a categorical distribution. 57 | Inputs: 58 | - in_dim : the input dimension 59 | - h_dim : the hidden layer dimension 60 | - res_h_dim : the hidden dimension of the residual block 61 | - n_res_layers : number of layers to stack 62 | """ 63 | 64 | def __init__(self, input_channel, output_channel, n_res_layers=3, res_h_dim=128): 65 | super(Encoder, self).__init__() 66 | kernel = 4 67 | stride = 2 68 | self.conv_stack = nn.Sequential( 69 | nn.Conv2d( 70 | input_channel, 71 | output_channel // 2, 72 | kernel_size=kernel, 73 | stride=stride, 74 | padding=1, 75 | ), 76 | nn.ReLU(), 77 | nn.Conv2d( 78 | output_channel // 2, 79 | output_channel, 80 | kernel_size=kernel, 81 | stride=stride, 82 | padding=1, 83 | ), 84 | nn.ReLU(), 85 | nn.Conv2d( 86 | output_channel, 87 | output_channel, 88 | kernel_size=kernel - 1, 89 | stride=stride - 1, 90 | padding=1, 91 | ), 92 | ResidualStack(output_channel, output_channel, res_h_dim, n_res_layers), 93 | ) 94 | 95 | def forward(self, x): 96 | return self.conv_stack(x) 97 | 98 | 99 | class Decoder(nn.Module): 100 | """ 101 | This is the p_phi (x|z) network. Given a latent sample z p_phi 102 | maps back to the original space z -> x. 
103 | Inputs: 104 | - in_dim : the input dimension 105 | - h_dim : the hidden layer dimension 106 | - res_h_dim : the hidden dimension of the residual block 107 | - n_res_layers : number of layers to stack 108 | """ 109 | 110 | def __init__( 111 | self, input_channel, output_channel, h_dim=128, n_res_layers=3, res_h_dim=128 112 | ): 113 | super(Decoder, self).__init__() 114 | kernel = 4 115 | stride = 2 116 | 117 | self.inverse_conv_stack = nn.Sequential( 118 | nn.ConvTranspose2d( 119 | input_channel, 120 | h_dim, 121 | kernel_size=kernel - 1, 122 | stride=stride - 1, 123 | padding=1, 124 | ), 125 | ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers), 126 | nn.ConvTranspose2d( 127 | h_dim, h_dim // 2, kernel_size=kernel, stride=stride, padding=1 128 | ), 129 | nn.ReLU(), 130 | nn.ConvTranspose2d( 131 | h_dim // 2, output_channel, kernel_size=kernel, stride=stride, padding=1 132 | ), 133 | ) 134 | 135 | def forward(self, x): 136 | return self.inverse_conv_stack(x) 137 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pytorch_lightning import ( 6 | Callback, 7 | LightningDataModule, 8 | LightningModule, 9 | Trainer, 10 | seed_everything, 11 | ) 12 | import GPUtil 13 | from src.utils import utils 14 | 15 | log = utils.get_logger(__name__) 16 | 17 | 18 | def train(config: DictConfig): 19 | if "seed" in config: 20 | seed_everything(config.seed) 21 | # Init lightning datamodule 22 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 23 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 24 | 25 | # Init lightning model 26 | log.info(f"Instantiating model <{config.model._target_}>") 27 | model: LightningModule = hydra.utils.instantiate(config.model, datamodule=config.datamodule, _recursive_=False) 28 | 29 | # Init lightning callbacks 30 | callbacks: List[Callback] = [] 31 | if "callbacks" in config: 32 | for _, cb_conf in config.callbacks.items(): 33 | if "_target_" in cb_conf: 34 | log.info(f"Instantiating callback <{cb_conf._target_}>") 35 | callbacks.append(hydra.utils.instantiate(cb_conf)) 36 | 37 | # Init lightning logger 38 | logger = hydra.utils.instantiate(config.logger) 39 | log.info(f"Instantiating logger <{config.logger._target_}>") 40 | 41 | # Init lightning trainer 42 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 43 | # Automatic assign free GPUs to trainer 44 | if isinstance(config.trainer.devices, int) and config.trainer.devices == 1: 45 | config.trainer.devices = GPUtil.getAvailable(limit=config.trainer.devices, maxMemory=0.5, order='random') 46 | trainer: Trainer = hydra.utils.instantiate( 47 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 48 | ) 49 | 50 | # Send some parameters from config to all lightning loggers 51 | log.info("Logging hyperparameters!") 52 | utils.log_hyperparameters( 53 | config=config, 54 | model=model, 55 | datamodule=datamodule, 56 | trainer=trainer, 57 | callbacks=callbacks, 58 | logger=logger, 59 | ) 60 | 61 | # Train the model 62 | log.info("Starting training!") 63 | trainer.fit(model=model, datamodule=datamodule) 64 | 65 | # Evaluate model on test set, using the best model achieved during training 66 | if config.get("test_after_training") and not config.trainer.get("fast_dev_run"): 67 | log.info("Starting testing!") 68 | 
trainer.test() 69 | 70 | # Make sure everything closed properly 71 | log.info("Finalizing!") 72 | 73 | # Print path to best checkpoint 74 | log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") 75 | 76 | # Return metric score for hyperparameter optimization 77 | optimized_metric = config.get("optimized_metric") 78 | if optimized_metric: 79 | return trainer.callback_metrics[optimized_metric] 80 | -------------------------------------------------------------------------------- /src/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.distributions as D 4 | import torch.nn.functional as F 5 | 6 | def get_decode_dist(name): 7 | if name == "gaussian": 8 | return GaussianDistribution() 9 | elif name == "bernoulli": 10 | return BernoulliDistribution() 11 | else: 12 | raise NotImplementedError 13 | 14 | class GaussianDistribution(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def prob(self, pred, target): 19 | dist = D.Normal(pred, torch.ones_like(pred)) 20 | p_x = dist.log_prob(target).sum(dim=[1,2,3]) 21 | return p_x 22 | 23 | def sample(self, pred): 24 | return pred 25 | 26 | class BernoulliDistribution(nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def prob(self, pred, target): 31 | # pred: probabilities between 0 and 1 32 | prob = -F.binary_cross_entropy(pred, target, reduction='none').sum([1, 2, 3]) 33 | return prob 34 | 35 | def sample(self, pred): 36 | return torch.bernoulli(pred) 37 | -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def adversarial_loss(pred, target_is_real=True, loss_mode="vanilla"): 5 | if loss_mode == "vanilla": 6 | if target_is_real: 7 | target = torch.ones_like(pred) 8 | else: 9 | target = torch.zeros_like(pred) 10 | return F.binary_cross_entropy_with_logits(pred, target) 11 | elif loss_mode == "lsgan": 12 | if target_is_real: 13 | target = torch.ones_like(pred) 14 | else: 15 | target = torch.zeros_like(pred) 16 | return F.mse_loss(pred, target) 17 | elif loss_mode == "hinge": 18 | if target_is_real: 19 | loss = torch.maximum(1-pred, torch.ones_like(pred)).mean() 20 | else: 21 | loss = torch.maximum(1+pred, torch.zeros_like(pred)).mean() 22 | return loss 23 | else: 24 | raise NotImplementedError 25 | 26 | def normal_kld(mu, log_sigma): 27 | kl_divergence = -0.5 * torch.sum(1 + 2 * log_sigma - mu ** 2 - torch.exp(2 * log_sigma), dim=-1).mean(dim=0) 28 | return kl_divergence 29 | 30 | def symmetry_contra_loss(feat1, feat2, temperature=0.07): 31 | logits = torch.einsum("ik,jk->ij", feat1, feat2) / temperature # (d, d) 32 | d = logits.shape[0] 33 | 34 | labels = torch.arange(d).to(feat1.device) # (d) 35 | loss_i = F.cross_entropy(logits, labels) 36 | loss_j = F.cross_entropy(logits.T, labels) 37 | contra_loss = (loss_i + loss_j) / 2 38 | return contra_loss 39 | -------------------------------------------------------------------------------- /src/utils/toy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import distributions as D 4 | import matplotlib.pyplot as plt 5 | from torch import nn 6 | from torch.nn.parameter import Parameter 7 | 8 | 9 | class GMM(nn.Module): 10 | def __init__(self, n=3, 
device=torch.cuda): 11 | super().__init__() 12 | self.device = device 13 | self.n = n 14 | self.p_s = Parameter(torch.tensor([1/n for _ in range(n)])).to(device) 15 | self.mu_s = Parameter(torch.stack([torch.randn(2)*2 for _ in range(n)], dim=0)).to(device) 16 | self.sigma_s = Parameter(torch.stack([torch.diag(torch.rand(2)) for _ in range(n)], dim=0)).to(device) 17 | self.update() 18 | 19 | def update(self): 20 | self.p_dist = D.Categorical(probs=self.p_s) 21 | self.dist_list = [ 22 | D.MultivariateNormal(loc=self.mu_s[i], 23 | covariance_matrix=self.sigma_s[i]) 24 | for i in range(self.n) 25 | ] 26 | 27 | def plot(self, samples=None, label=None, N=10000): 28 | self.update() 29 | if samples == None and label == None: 30 | samples, label = self.sample(N) 31 | plt.figure() 32 | samples = samples.detach().cpu().numpy() 33 | label = label.detach().cpu().numpy() 34 | plt.scatter(samples[:, 0], samples[:, 1], c=label, cmap="tab10") 35 | plt.show() 36 | 37 | def sample(self, N): 38 | self.update() 39 | samples = self.p_dist.sample([N]).reshape(N, 1, 1).repeat(1, 1, 2) # (N) -> (N, 1, 2) 40 | candidates = torch.stack([dist.sample([N]) for dist in self.dist_list], dim=1) 41 | out = torch.gather(candidates, dim=1, index=samples) # (100, 3, 2) -> (100, 1, 2) 42 | return out.squeeze(), samples.squeeze()[:, 0] # (N, 2), (N) 43 | 44 | def log_prob(self, samples): 45 | self.update() 46 | # (n, N) 47 | log_prob = torch.stack([self.p_dist.log_prob(torch.tensor(i, device=self.device))+self.dist_list[i].log_prob(samples) for i in range(self.n)], axis=0) 48 | return torch.logsumexp(log_prob, dim=0) 49 | 50 | def info(self): 51 | print("discrete p:", self.p_s) 52 | for i in range(self.n): 53 | print("-----------") 54 | print(self.mu_s[i]) 55 | print(self.sigma_s[i]) 56 | 57 | class ToyGMM(GMM): 58 | def __init__(self, n, device): 59 | super().__init__(n=n, device=device) 60 | angles = [2*i*np.pi/self.n for i in range(self.n)] 61 | def mean(theta): 62 | return torch.tensor([np.cos(theta), np.sin(theta)], dtype=torch.float32).to(device) 63 | 64 | def get_covariance(theta): 65 | v1 = mean(theta) 66 | v2 = mean(theta+np.pi/2) # get vector perpend to v1 67 | Q = torch.stack([v1, v2], axis=1).to(device) 68 | D = torch.diag(torch.tensor([0.35, 0.08], dtype=torch.float32)**2).to(device) 69 | return Q@D@Q.T 70 | 71 | self.mu_s = Parameter(torch.stack([mean(x) for x in angles], dim=0)).to(device) 72 | self.sigma_s = Parameter(torch.stack([get_covariance(x) for x in angles], dim=0)).to(device) 73 | self.update() -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import warnings 4 | from typing import List, Sequence 5 | 6 | import pytorch_lightning as pl 7 | import rich.syntax 8 | import rich.tree 9 | from omegaconf import DictConfig, OmegaConf 10 | from pytorch_lightning.utilities import rank_zero_only 11 | 12 | 13 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 14 | """Initializes multi-GPU-friendly python logger.""" 15 | 16 | logger = logging.getLogger(name) 17 | logger.setLevel(level) 18 | 19 | # this ensures all logging levels get marked with the rank zero decorator 20 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 21 | for level in ( 22 | "debug", 23 | "info", 24 | "warning", 25 | "error", 26 | "exception", 27 | "fatal", 28 | "critical", 29 | ): 30 | setattr(logger, level, 
rank_zero_only(getattr(logger, level))) 31 | 32 | return logger 33 | 34 | 35 | def extras(config: DictConfig) -> None: 36 | """A couple of optional utilities, controlled by main config file: 37 | - disabling warnings 38 | - easier access to debug mode 39 | - forcing debug friendly configuration 40 | 41 | Modifies DictConfig in place. 42 | 43 | Args: 44 | config (DictConfig): Configuration composed by Hydra. 45 | """ 46 | 47 | log = get_logger() 48 | 49 | # enable adding new keys to config 50 | OmegaConf.set_struct(config, False) 51 | 52 | # disable python warnings if 53 | if config.get("ignore_warnings"): 54 | log.info("Disabling python warnings! ") 55 | warnings.filterwarnings("ignore") 56 | 57 | # set if 58 | if config.get("debug"): 59 | log.info("Running in debug mode! ") 60 | config.trainer.fast_dev_run = True 61 | 62 | # force debugger friendly configuration if 63 | if config.trainer.get("fast_dev_run"): 64 | log.info( 65 | "Forcing debugger friendly configuration! " 66 | ) 67 | # Debuggers don't like GPUs or multiprocessing 68 | if config.trainer.get("gpus"): 69 | config.trainer.gpus = 0 70 | if config.datamodule.get("pin_memory"): 71 | config.datamodule.pin_memory = False 72 | if config.datamodule.get("num_workers"): 73 | config.datamodule.num_workers = 0 74 | 75 | # disable adding new keys to config 76 | OmegaConf.set_struct(config, True) 77 | 78 | 79 | @rank_zero_only 80 | def print_config( 81 | config: DictConfig, 82 | fields: Sequence[str] = ( 83 | "trainer", 84 | "model", 85 | "datamodule", 86 | "callbacks", 87 | "logger", 88 | "seed", 89 | "exp_name" 90 | ), 91 | resolve: bool = True, 92 | ) -> None: 93 | """Prints content of DictConfig using Rich library and its tree structure. 94 | 95 | Args: 96 | config (DictConfig): Configuration composed by Hydra. 97 | fields (Sequence[str], optional): Determines which main fields from config will 98 | be printed and in what order. 99 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 100 | """ 101 | 102 | style = "dim" 103 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 104 | 105 | for field in fields: 106 | branch = tree.add(field, style=style, guide_style=style) 107 | 108 | config_section = config.get(field) 109 | branch_content = str(config_section) 110 | if isinstance(config_section, DictConfig): 111 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 112 | 113 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 114 | 115 | rich.print(tree) 116 | 117 | with open("config_tree.txt", "w") as fp: 118 | rich.print(tree, file=fp) 119 | 120 | 121 | def empty(*args, **kwargs): 122 | pass 123 | 124 | 125 | @rank_zero_only 126 | def log_hyperparameters( 127 | config: DictConfig, 128 | model: pl.LightningModule, 129 | datamodule: pl.LightningDataModule, 130 | trainer: pl.Trainer, 131 | callbacks: List[pl.Callback], 132 | logger: List[pl.loggers.Logger], 133 | ) -> None: 134 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 
135 | 136 | Additionaly saves: 137 | - number of trainable model parameters 138 | """ 139 | 140 | hparams = {} 141 | 142 | # choose which parts of hydra config will be saved to loggers 143 | hparams["trainer"] = config["trainer"] 144 | hparams["model"] = config["model"] 145 | hparams["datamodule"] = config["datamodule"] 146 | if "seed" in config: 147 | hparams["seed"] = config["seed"] 148 | if "callbacks" in config: 149 | hparams["callbacks"] = config["callbacks"] 150 | 151 | # save number of model parameters 152 | hparams["model/params_total"] = sum(p.numel() for p in model.parameters()) 153 | hparams["model/params_trainable"] = sum( 154 | p.numel() for p in model.parameters() if p.requires_grad 155 | ) 156 | hparams["model/params_not_trainable"] = sum( 157 | p.numel() for p in model.parameters() if not p.requires_grad 158 | ) 159 | 160 | # send hparams to all loggers 161 | trainer.logger.log_hyperparams(hparams) 162 | 163 | # disable logging any more hyperparameters for all loggers 164 | # this is just a trick to prevent trainer from logging hparams of model, 165 | # since we already did that above 166 | trainer.logger.log_hyperparams = empty 167 | 168 | 169 | def finish( 170 | config: DictConfig, 171 | model: pl.LightningModule, 172 | datamodule: pl.LightningDataModule, 173 | trainer: pl.Trainer, 174 | callbacks: List[pl.Callback], 175 | logger: List[pl.loggers.Logger], 176 | ) -> None: 177 | """Makes sure everything closed properly.""" 178 | 179 | # without this sweeps with wandb logger might crash! 180 | for lg in logger: 181 | if isinstance(lg, pl.loggers.wandb.WandbLogger): 182 | import wandb 183 | 184 | wandb.finish() 185 | -------------------------------------------------------------------------------- /src/utils/visual.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import matplotlib 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from src.models.pixelcnn import HorizontalStackConvolution, VerticalStackConvolution 7 | 8 | # %% Show receptive field of convolution 9 | def show_center_recep_field(img, out): 10 | """ 11 | Calculates the gradients of the input with respect to the output center pixel, 12 | and visualizes the overall receptive field. 13 | Inputs: 14 | img - Input image for which we want to calculate the receptive field on. 15 | out - Output features/loss which is used for backpropagation, and should be 16 | the output of the network/computation graph. 
17 | """ 18 | # Determine gradients 19 | loss = out[0, :, img.shape[2] // 2, img.shape[3] // 2].sum() # L1 loss 20 | # Retain graph as we want to stack multiple layers and show the receptive field of all of them 21 | loss.backward(retain_graph=True) 22 | 23 | img_grads = img.grad.abs() 24 | img.grad.fill_(0) # Reset grads 25 | 26 | # Plot receptive field 27 | img = img_grads.squeeze().cpu().numpy() # (H, W) 28 | fig, ax = plt.subplots(1, 2) 29 | pos = ax[0].imshow(img) # weighted receptive field 30 | ax[1].imshow(img > 0) # binary receptive field 31 | # Mark the center pixel in red if it doesn't have any gradients (should be the case for standard autoregressive models) 32 | show_center = img[img.shape[0] // 2, img.shape[1] // 2] == 0 33 | if show_center: 34 | center_pixel = np.zeros(img.shape + (4,)) 35 | center_pixel[ 36 | center_pixel.shape[0] // 2, center_pixel.shape[1] // 2, : 37 | ] = np.array([1.0, 0.0, 0.0, 1.0]) 38 | for i in range(2): 39 | ax[i].axis("off") 40 | if show_center: 41 | ax[i].imshow(center_pixel) 42 | ax[0].set_title("Weighted receptive field") 43 | ax[1].set_title("Binary receptive field") 44 | plt.show() 45 | plt.close() 46 | 47 | 48 | # %% 49 | inp_img = torch.zeros(1, 1, 11, 11) 50 | inp_img.requires_grad_() 51 | show_center_recep_field(inp_img, inp_img) 52 | 53 | # %% show horizontal convolution ERF 54 | horiz_conv = HorizontalStackConvolution( 55 | c_in=1, c_out=1, kernel_size=3, mask_center=True 56 | ) 57 | horiz_conv.conv.weight.data.fill_(1) 58 | horiz_conv.conv.bias.data.fill_(0) 59 | horiz_img = horiz_conv(inp_img) 60 | show_center_recep_field(inp_img, horiz_img) 61 | 62 | # %% show vertical convolution ERF 63 | vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=True) 64 | vert_conv.conv.weight.data.fill_(1) 65 | vert_conv.conv.bias.data.fill_(0) 66 | vert_img = vert_conv(inp_img) 67 | show_center_recep_field(inp_img, vert_img) 68 | 69 | # %% 70 | # Initialize convolutions with equal weight to all input pixels 71 | horiz_conv = HorizontalStackConvolution( 72 | c_in=1, c_out=1, kernel_size=3, mask_center=False 73 | ) 74 | horiz_conv.conv.weight.data.fill_(1) 75 | horiz_conv.conv.bias.data.fill_(0) 76 | vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=False) 77 | vert_conv.conv.weight.data.fill_(1) 78 | vert_conv.conv.bias.data.fill_(0) 79 | 80 | # We reuse our convolutions for the 4 layers here. Note that in a standard network, 81 | # we don't do that, and instead learn 4 separate convolutions. As this cell is only for 82 | # visualization purposes, we reuse the convolutions for all layers. 83 | for l_idx in range(4): 84 | vert_img = vert_conv(vert_img) 85 | horiz_img = horiz_conv(horiz_img) + vert_img 86 | print(f"Layer {l_idx+2}") 87 | show_center_recep_field(inp_img, horiz_img) 88 | 89 | # %% 90 | -------------------------------------------------------------------------------- /wiki/wgan.adoc: -------------------------------------------------------------------------------- 1 | = Sensitivity of WGAN to the network architecture (one extra conv layer and WGAN refuses to converge no matter what!) 2 | 3 | Background: I had not updated this repo for a long time. After recently reading @苏剑林's blog posts on WGAN, I wanted to analyze some of the wgan experiment settings, so I re-ran the wgan experiments and, to my surprise, could not reproduce my earlier results. I carefully compared the lr, optimizer, clip_weight, n_critic and other settings against the old experiment records; even with everything matched I still could not reproduce the results and almost lost confidence. 4 | 5 | In the end I checked the code out at the original commit and compared the training optimizer, weight clip, dataloader (batch_size, transform), network architecture, etc. once more; after copying the old network code over, it finally worked! 6 | 7 | The only difference between the two networks is one extra conv layer that I had later added to increase model capacity. I assumed a single layer would not affect training stability, but the difference turned out to be huge: the original network generates meaningful images within 10 epochs, whereas the version with the extra conv layer needs on the order of 1000 epochs before it has a chance of producing meaningful images.
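For context, the quantity under discussion is the standard WGAN critic objective (written here in LaTeX notation; this is the textbook formulation, not anything specific to this repo):

----
K \cdot W(p_{data}, p_g) \;=\; \max_{\|f\|_L \le K} \; \mathbb{E}_{x \sim p_{data}}[f(x)] \;-\; \mathbb{E}_{z \sim p(z)}[f(G(z))]
----

The original WGAN enforces the Lipschitz constraint only crudely, by clamping every critic weight into [-clip_weight, clip_weight] after each update (see src/models/wgan.py); that clipping step is exactly what comes under suspicion in the analysis below.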
8 | 9 | == Comparison of the network architectures 10 | 11 | Generator and discriminator of the working version: 12 | ---- 13 | Generator( 14 | (main): Sequential( 15 | (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False) 16 | (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 17 | (2): ReLU(inplace=True) 18 | (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 19 | (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 20 | (5): ReLU(inplace=True) 21 | (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 22 | (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 23 | (8): ReLU(inplace=True) 24 | (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 25 | (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 26 | (11): ReLU(inplace=True) 27 | (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 28 | (13): Tanh() 29 | ) 30 | ) 31 | 32 | Discriminator( 33 | (main): Sequential( 34 | (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 35 | (1): LeakyReLU(negative_slope=0.2, inplace=True) 36 | (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 37 | (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 38 | (4): LeakyReLU(negative_slope=0.2, inplace=True) 39 | (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 40 | (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 41 | (7): LeakyReLU(negative_slope=0.2, inplace=True) 42 | (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False) 43 | (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 44 | (10): LeakyReLU(negative_slope=0.2, inplace=True) 45 | (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) 46 | ) 47 | ) 48 | 49 | ---- 50 | 51 | Generator and discriminator of the modified, failing version: 52 | ---- 53 | Generator( 54 | (main): Sequential( 55 | (0): ConvTranspose2d(100, 512, kernel_size=(1, 1), stride=(1, 1)) 56 | (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 57 | (2): ReLU(inplace=True) 58 | (3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(1, 1)) 59 | (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 60 | (5): ReLU(inplace=True) 61 | (6): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 62 | (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 63 | (8): ReLU(inplace=True) 64 | (9): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 65 | (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 66 | (11): ReLU(inplace=True) 67 | (12): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 68 | (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 69 | (14): ReLU(inplace=True) 70 | (15): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 71 | (16): Tanh() 72 | ) 73 | ) 74 | 75 | Discriminator( 76 | (main): Sequential( 77 | (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 78 | (1): LeakyReLU(negative_slope=0.2, inplace=True) 79 | (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 80 | (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 81 | (4): LeakyReLU(negative_slope=0.2, inplace=True) 82 | (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 83 | (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 84 | (7): LeakyReLU(negative_slope=0.2, inplace=True) 85 | (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) 86 | (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 87 | (10): LeakyReLU(negative_slope=0.2, inplace=True) 88 | (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1)) 89 | (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 90 | (13): LeakyReLU(negative_slope=0.2, inplace=True) 91 | (14): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1)) 92 | ) 93 | ) 94 | ---- 95 | 96 | 97 | Concretely, in the Generator the original network's 98 | ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False) 99 | was replaced with 100 | ConvTranspose2d(100, 512, kernel_size=(1, 1), stride=(1, 1)) 101 | ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(1, 1)) 102 | i.e. two layers. 103 | 104 | In the Discriminator, the last Conv of the original network 105 | (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) 106 | was replaced with 107 | (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1)) 108 | (14): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1)) 109 | i.e. two layers. 110 | 111 | 112 | I also replaced only the Generator and only the Discriminator: with just the Generator modified, sampling still works fine, but with the Discriminator modified the sample quality drops drastically. 113 | 114 | [%header, cols=2*] 115 | |=== 116 | | Modified Discriminator 117 | | Modified Generator only 118 | 119 | | image:../assets/wiki/wgan/changeD.jpg[ChangeD] 120 | | image:../assets/wiki/wgan/changeG.jpg[ChangeG] 121 | 122 | |=== 123 | 124 | 125 | Going further, 126 | I replaced the last Conv of the original Discriminator 127 | ---- 128 | (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False) 129 | ---- 130 | with 131 | 132 | ---- 133 | (11): Conv2d(512, 1024, kernel_size=(4, 4), stride=(1, 1)) 134 | (14): Conv2d(1024, 1, kernel_size=(1, 1), stride=(1, 1)) 135 | ---- 136 | 137 | i.e. two wider layers, and that does not help either. 138 | 139 | Guess: the root cause is probably related to weight clipping. 140 | 141 | Training the modified discriminator with WGAN-GP instead works fine, which confirms that weight clipping is the problem.
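For reference, the two critic updates being compared here look roughly like the following (a minimal PyTorch-style sketch; critic, opt, real, fake and the hyperparameter values are illustrative, not the repo's exact code — the real implementations live in src/models/wgan.py and src/models/wgan_gp.py):

----
import torch

def critic_step_clipping(critic, opt, real, fake, clip=0.01):
    # WGAN: crudely bound the Lipschitz constant by clamping every critic weight
    for p in critic.parameters():
        p.data.clamp_(-clip, clip)
    loss = critic(fake.detach()).mean() - critic(real).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss

def critic_step_gp(critic, opt, real, fake, gp_weight=10.0):
    # WGAN-GP: penalize the critic's gradient norm at random interpolates instead of clipping
    fake = fake.detach()  # do not backprop into the generator during the critic update
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    inter = (eps * real + (1 - eps) * fake).requires_grad_(True)
    grad = torch.autograd.grad(critic(inter).sum(), inter, create_graph=True)[0]
    gp = ((grad.flatten(1).norm(dim=1) - 1) ** 2).mean()
    loss = critic(fake).mean() - critic(real).mean() + gp_weight * gp
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss
----

Both variants estimate the same Wasserstein objective; the only difference is how the Lipschitz constraint is enforced, which is why swapping clipping for the gradient penalty isolates the clipping as the culprit.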
142 | 143 | Analyzing the distribution of the real and fake logits in the two cases: 144 | 145 | image::../assets/wiki/wgan/logits.png[] 146 | 147 | In the ChangeD case there is clear gradient vanishing. 148 | 149 | > Why does adding exactly this one layer lead to gradient vanishing? 150 | 151 | // TODO: the exact cause is still not clear -> ah, the interpretability of deep learning... 152 | 153 | 154 | == Sensitivity to other hyperparameters: 155 | 156 | . lr: between 1e-4 and 5e-5 the results do not differ much 157 | . optimizer: 158 | adam and rmsprop do not differ much either 159 | . --------------------------------------------------------------------------------