├── .gitignore
├── LICENSE.md
├── README.md
├── assets
│   └── mnist.gif
├── configs
│   ├── celeba_ae.yaml
│   ├── celeba_diin.yaml
│   ├── celeba_iin.yaml
│   ├── cifar_ae.yaml
│   ├── cifar_iin.yaml
│   ├── cmnist_ae.yaml
│   ├── cmnist_clf.yaml
│   ├── cmnist_clf_diin.yaml
│   ├── cmnist_diin.yaml
│   ├── cmnist_dimeval.yaml
│   ├── fashionmnist_ae.yaml
│   ├── fashionmnist_iin.yaml
│   ├── mnist_ae.yaml
│   ├── mnist_iin.yaml
│   └── project.yaml
├── environment.yaml
├── iin
│   ├── data.py
│   ├── iterators
│   │   ├── ae.py
│   │   ├── base.py
│   │   ├── clf.py
│   │   └── iin.py
│   ├── losses
│   │   ├── ae.py
│   │   └── iin.py
│   └── models
│       ├── ae.py
│       ├── clf.py
│       └── iin.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
/fid_stats
/logs
/wandb

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Patrick Esser and Robin Rombach

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A Disentangling Invertible Interpretation Network for Explaining Latent Representations

PyTorch code accompanying the [CVPR 2020](http://cvpr2020.thecvf.com/) paper

[**A Disentangling Invertible Interpretation Network for Explaining Latent Representations**](https://compvis.github.io/iin/)
[Patrick Esser](https://github.com/pesser)\*,
[Robin Rombach](https://github.com/rromb)\*,
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
\* equal contribution

![teaser](assets/mnist.gif)
[arXiv](https://arxiv.org/abs/2004.13166) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/iin/)


Table of Contents
=================

* [Requirements](#requirements)
* [Data](#data)
* [Training](#training)
  * [Autoencoders](#autoencoders)
  * [Classifiers](#classifiers)
  * [Invertible Interpretation Networks](#invertible-interpretation-networks)
    * [Unsupervised](#unsupervised-on-ae)
    * [Supervised](#supervised)
* [Evaluation](#evaluation)
* [Pretrained Models](#pretrained-models)
* [Results](#results)
* [BibTeX](#bibtex)


## Requirements
A suitable [conda](https://conda.io/) environment named `iin` can be created
and activated with:

```
conda env create -f environment.yaml
conda activate iin
```

Optionally, you can then also `conda install tensorflow-gpu=1.14` to speed up
FID evaluations.


## Data
`MNIST`, `FashionMNIST` and `CIFAR10` will be downloaded automatically the
first time they are used and `CelebA` will prompt you to download it. The
content of each dataset can be visualized with

```
edexplore --dataset iin.data.<dataset>
```

where `<dataset>` is one of `MNISTTrain`, `MNISTTest`, `FashionMNISTTrain`,
`FashionMNISTTest`, `CIFAR10Train`, `CIFAR10Test`, `CelebATrain`, `CelebATest`,
`FactorCelebATrain`, `FactorCelebATest`, `ColorfulMNISTTrain`,
`ColorfulMNISTTest`, `SingleColorfulMNISTTrain`, `SingleColorfulMNISTTest`.


## Training

### Autoencoders
To train autoencoders, run

```
edflow -b configs/<dataset>_ae.yaml -t
```

where `<dataset>` is one of `mnist`, `fashionmnist`, `cifar`, `celeba`,
`cmnist`. To enable logging to [wandb](https://wandb.ai), adjust
`configs/project.yaml` and add it to the above command:

```
edflow -b configs/<dataset>_ae.yaml configs/project.yaml -t
```

### Classifiers
To train a classifier on `ColorfulMNIST`, run

```
edflow -b configs/cmnist_clf.yaml -t
```

Once you have a checkpoint, you can estimate factor dimensionalities using

```
edflow -b configs/cmnist_clf.yaml configs/cmnist_dimeval.yaml -c <path to checkpoint>
```

For the pretrained classifier, this gives

```
[INFO] [dim_callback]: estimated factor dimensionalities: [22, 11, 31]
```

and to compare this to an autoencoder, run

```
edflow -b configs/cmnist_ae.yaml configs/cmnist_dimeval.yaml -c <path to checkpoint>
```

which gives

```
[INFO] [dim_callback]: estimated factor dimensionalities: [13, 17, 34]
```

### Invertible Interpretation Networks
#### Unsupervised on AE
To train unsupervised invertible interpretation networks, run

```
edflow -b configs/<dataset>_iin.yaml [configs/project.yaml] -t
```

where `<dataset>` is one of `mnist`, `fashionmnist`, `cifar`, `celeba`. If,
instead of using one of the [pretrained models](#pretrained-models), you
trained an autoencoder yourself, adjust the `first_stage` config section
accordingly.
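For example, to run `configs/celeba_iin.yaml` on top of your own CelebA autoencoder, point its `first_stage` section at your checkpoint. A minimal sketch, assuming the autoencoder was trained with the settings from `configs/celeba_ae.yaml`; the log folder name and checkpoint step below are placeholders for your own run:

```
first_stage:
  # placeholders: substitute the timestamped log folder and step of your run
  checkpoint: logs/<your_celeba_ae_run>/train/checkpoints/model-<step>.ckpt
  model: iin.models.ae.Model
  subconfig:
    Model:
      deterministic: false
      in_channels: 3
      in_size: 64
      n_down: 4
      norm: an
      z_dim: 64
```

The `subconfig/Model` entries have to match the settings the autoencoder was trained with, otherwise its checkpoint cannot be loaded.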

#### Supervised
For supervised, disentangling IINs (dIINs), run

```
edflow -b configs/<dataset>_diin.yaml [configs/project.yaml] -t
```

where `<dataset>` is one of `cmnist` or `celeba`, or run

```
edflow -b configs/cmnist_clf_diin.yaml [configs/project.yaml] -t
```

to train a dIIN on top of a classifier, with factor dimensionalities as
estimated above (dimensionalities of factors can be adjusted via the
`Transformer/factor_config` configuration entry).


## Evaluation

Evaluations run automatically after each epoch of training. To start an
evaluation manually, run

```
edflow -p logs/<log_folder>/configs/<config>.yaml
```

and, optionally, add `-c <path to checkpoint>` to evaluate a specific
checkpoint instead of the last one.


## Pretrained Models
Download [`logs.tar.gz`](https://heibox.uni-heidelberg.de/f/0c76b38bf4274448b8e9/)
(~2.2 GB) and extract the pretrained models:

```
tar xzf logs.tar.gz
```


## Results
Using spectral normalization for the discriminator, this code slightly improves
upon the values reported in Tab. 2 of the paper.

| Dataset      | Checkpoint | FID    |
|--------------|------------|--------|
| MNIST        | 105600     | 5.252  |
| FashionMNIST | 110400     | 9.663  |
| CelebA       | 84643      | 19.839 |
| CIFAR10      | 32000      | 38.697 |

Full training logs can be found on [Weights &
Biases](https://app.wandb.ai/trex/iin/reportlist).


## BibTeX

```
@inproceedings{esser2020invertible,
  title={A Disentangling Invertible Interpretation Network for Explaining Latent Representations},
  author={Esser, Patrick and Rombach, Robin and Ommer, Bj{\"o}rn},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2020}
}
```

--------------------------------------------------------------------------------
/assets/mnist.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/iin/fa0d2d1cfa6fc0b813b902cd3474145a045fbb34/assets/mnist.gif

--------------------------------------------------------------------------------
/configs/celeba_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: celeba_ae
2 | 
3 | datasets:
4 |   train: iin.data.CelebATrain
5 |   validation: iin.data.CelebATest
6 | 
7 | fid:
8 |   batch_size: 50
9 |   fid_stats:
10 |     pre_calc_stat_path: "fid_stats/celeba.npz"
11 | 
12 | model: iin.models.ae.Model
13 | Model:
14 |   deterministic: false
15 |   in_channels: 3
16 |   in_size: 64
17 |   n_down: 4
18 |   norm: an
19 |   z_dim: 64
20 | 
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 |   calibrate: true
24 |   disc_in_channels: 3
25 |   disc_start: 100001
26 |   logvar_init: 0.0
27 |   perceptual_weight: 1.0
28 |   d_lr_factor: 4.0
29 |   spectral_norm: True
30 | 
31 | 
32 | iterator: iin.iterators.ae.Trainer
33 | base_learning_rate: 4.5e-06
34 | decay_start: 505001
35 | batch_size: 25
36 | log_freq: 1000
37 | num_epochs: 100
38 | 
--------------------------------------------------------------------------------
/configs/celeba_diin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: celeba_diin
2 | 
3 | datasets:
4 |   train: iin.data.FactorCelebATrain
5 |   validation: iin.data.FactorCelebATest
6 | 
7 | fid:
8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/celeba.npz" 11 | 12 | 13 | first_stage: 14 | checkpoint: logs/2020-04-23T23-43-25_celeba_ae/train/checkpoints/model-572968.ckpt 15 | model: iin.models.ae.Model 16 | subconfig: 17 | Model: 18 | deterministic: false 19 | in_channels: 3 20 | in_size: 64 21 | n_down: 4 22 | norm: an 23 | z_dim: 64 24 | 25 | model: iin.models.iin.FactorTransformer 26 | Transformer: 27 | n_factors: 5 28 | in_channel: 64 29 | n_flow: 12 30 | hidden_depth: 2 31 | hidden_dim: 512 32 | 33 | 34 | loss: iin.losses.iin.FactorLoss 35 | Loss: 36 | rho: 0.975 37 | 38 | iterator: iin.iterators.iin.FactorTrainer 39 | base_learning_rate: 4.5e-06 40 | decay_start: 505001 41 | batch_size: 25 42 | log_freq: 1000 43 | num_epochs: 100 44 | -------------------------------------------------------------------------------- /configs/celeba_iin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: celeba_iin 2 | 3 | datasets: 4 | train: iin.data.CelebATrain 5 | validation: iin.data.CelebATest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/celeba.npz" 11 | 12 | first_stage: 13 | checkpoint: logs/2020-04-23T23-43-25_celeba_ae/train/checkpoints/model-572968.ckpt 14 | model: iin.models.ae.Model 15 | subconfig: 16 | Model: 17 | deterministic: false 18 | in_channels: 3 19 | in_size: 64 20 | n_down: 4 21 | norm: an 22 | z_dim: 64 23 | 24 | model: iin.models.iin.VectorTransformer 25 | Transformer: 26 | in_channel: 64 27 | n_flow: 12 28 | hidden_depth: 2 29 | hidden_dim: 512 30 | 31 | 32 | loss: iin.losses.iin.Loss 33 | 34 | iterator: iin.iterators.iin.Trainer 35 | base_learning_rate: 4.5e-06 36 | decay_start: 505001 37 | batch_size: 25 38 | log_freq: 1000 39 | num_epochs: 100 40 | -------------------------------------------------------------------------------- /configs/cifar_ae.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cifar_ae 2 | 3 | datasets: 4 | train: iin.data.CIFAR10Train 5 | validation: iin.data.CIFAR10Test 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/cifar.npz" 11 | 12 | model: iin.models.ae.Model 13 | Model: 14 | deterministic: false 15 | in_channels: 3 16 | in_size: 32 17 | n_down: 4 18 | norm: an 19 | z_dim: 64 20 | 21 | loss: iin.losses.ae.LossPND 22 | Loss: 23 | calibrate: true 24 | disc_in_channels: 3 25 | disc_start: 100001 26 | logvar_init: 0.0 27 | perceptual_weight: 1.0 28 | d_lr_factor: 4.0 29 | spectral_norm: True 30 | 31 | iterator: iin.iterators.ae.Trainer 32 | base_learning_rate: 4.5e-06 33 | batch_size: 25 34 | log_freq: 1000 35 | num_epochs: 300 36 | decay_start: 505001 37 | -------------------------------------------------------------------------------- /configs/cifar_iin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cifar_iin 2 | 3 | datasets: 4 | train: iin.data.CIFAR10Train 5 | validation: iin.data.CIFAR10Test 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/cifar.npz" 11 | 12 | 13 | first_stage: 14 | checkpoint: logs/2020-04-24T09-47-43_cifar_ae/train/checkpoints/model-534000.ckpt 15 | model: iin.models.ae.Model 16 | subconfig: 17 | Model: 18 | deterministic: false 19 | in_channels: 3 20 | in_size: 32 21 | n_down: 4 22 | norm: an 23 | z_dim: 64 24 | 25 | model: iin.models.iin.VectorTransformer 26 | Transformer: 27 | in_channel: 64 28 | n_flow: 12 29 | 
hidden_depth: 2 30 | hidden_dim: 512 31 | 32 | 33 | loss: iin.losses.iin.Loss 34 | 35 | iterator: iin.iterators.iin.Trainer 36 | base_learning_rate: 4.5e-06 37 | batch_size: 25 38 | log_freq: 1000 39 | num_epochs: 300 40 | decay_start: 505001 41 | -------------------------------------------------------------------------------- /configs/cmnist_ae.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cmnist_ae 2 | 3 | datasets: 4 | train: iin.data.SingleColorfulMNISTTrain 5 | validation: iin.data.SingleColorfulMNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/cmnist.npz" 11 | 12 | model: iin.models.ae.Model 13 | Model: 14 | deterministic: false 15 | in_channels: 3 16 | in_size: 32 17 | n_down: 4 18 | norm: an 19 | z_dim: 64 20 | 21 | loss: iin.losses.ae.LossPND 22 | Loss: 23 | calibrate: true 24 | disc_in_channels: 3 25 | disc_start: 25001 26 | logvar_init: 0.0 27 | perceptual_weight: 1.0 28 | discriminator_weight: 150.0 29 | 30 | iterator: iin.iterators.ae.Trainer 31 | base_learning_rate: 4.5e-06 32 | batch_size: 25 33 | log_freq: 1000 34 | num_epochs: 50 35 | decay_start: 100001 36 | -------------------------------------------------------------------------------- /configs/cmnist_clf.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cmnist_clf 2 | 3 | datasets: 4 | train: iin.data.SingleColorfulMNISTTrain 5 | validation: iin.data.SingleColorfulMNISTTest 6 | n_classes: 10 7 | 8 | model: iin.models.clf.Model 9 | Model: 10 | in_channels: 3 11 | in_size: 32 12 | n_down: 4 13 | norm: an 14 | z_dim: 64 15 | 16 | iterator: iin.iterators.clf.Trainer 17 | base_learning_rate: 4.5e-06 18 | batch_size: 25 19 | log_freq: 1000 20 | num_epochs: 50 21 | decay_start: 100001 22 | -------------------------------------------------------------------------------- /configs/cmnist_clf_diin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cmnist_diin 2 | 3 | datasets: 4 | train: iin.data.ColorfulMNISTTrain 5 | validation: iin.data.ColorfulMNISTTest 6 | 7 | first_stage: 8 | checkpoint: logs/2020-07-23T13-19-33_cmnist_clf/train/checkpoints/model-120000.ckpt 9 | model: iin.models.clf.Model 10 | subconfig: 11 | n_classes: 10 12 | Model: 13 | in_channels: 3 14 | in_size: 32 15 | n_down: 4 16 | norm: an 17 | z_dim: 64 18 | 19 | model: iin.models.iin.FactorTransformer 20 | Transformer: 21 | n_factors: 3 22 | in_channel: 64 23 | n_flow: 12 24 | hidden_depth: 2 25 | hidden_dim: 512 26 | factor_config: 27 | - 22 28 | - 11 29 | - 31 30 | 31 | loss: iin.losses.iin.FactorLoss 32 | Loss: 33 | rho: 0.975 34 | 35 | iterator: iin.iterators.iin.FactorTrainer 36 | base_learning_rate: 4.5e-06 37 | batch_size: 25 38 | log_freq: 1000 39 | num_epochs: 50 40 | decay_start: 100001 41 | -------------------------------------------------------------------------------- /configs/cmnist_diin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cmnist_diin 2 | 3 | datasets: 4 | train: iin.data.ColorfulMNISTTrain 5 | validation: iin.data.ColorfulMNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/cmnist.npz" 11 | 12 | first_stage: 13 | checkpoint: logs/2020-04-27T22-55-44_cmnist_ae/train/checkpoints/model-120000.ckpt 14 | model: iin.models.ae.Model 15 | subconfig: 16 | Model: 17 | deterministic: false 18 | in_channels: 3 19 | in_size: 32 20 | 
n_down: 4 21 | norm: an 22 | z_dim: 64 23 | 24 | model: iin.models.iin.FactorTransformer 25 | Transformer: 26 | n_factors: 3 27 | in_channel: 64 28 | n_flow: 12 29 | hidden_depth: 2 30 | hidden_dim: 512 31 | 32 | 33 | loss: iin.losses.iin.FactorLoss 34 | Loss: 35 | rho: 0.975 36 | 37 | iterator: iin.iterators.iin.FactorTrainer 38 | base_learning_rate: 4.5e-06 39 | batch_size: 25 40 | log_freq: 1000 41 | num_epochs: 50 42 | decay_start: 100001 43 | 44 | -------------------------------------------------------------------------------- /configs/cmnist_dimeval.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: cmnist_dimeval 2 | 3 | datasets: 4 | train: iin.data.ColorfulMNISTTrain 5 | validation: iin.data.ColorfulMNISTTest 6 | n_classes: 10 7 | 8 | iterator: iin.iterators.base.DimEvaluator 9 | batch_size: 25 10 | num_steps: 1 11 | -------------------------------------------------------------------------------- /configs/fashionmnist_ae.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: fashionmnist_ae 2 | 3 | datasets: 4 | train: iin.data.FashionMNISTTrain 5 | validation: iin.data.FashionMNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/fashionmnist.npz" 11 | 12 | model: iin.models.ae.Model 13 | Model: 14 | deterministic: false 15 | in_channels: 1 16 | in_size: 32 17 | n_down: 4 18 | norm: an 19 | z_dim: 64 20 | 21 | loss: iin.losses.ae.LossPND 22 | Loss: 23 | calibrate: true 24 | disc_in_channels: 1 25 | disc_start: 25001 26 | logvar_init: 0.0 27 | perceptual_weight: 1.0 28 | 29 | iterator: iin.iterators.ae.Trainer 30 | base_learning_rate: 4.5e-06 31 | batch_size: 25 32 | log_freq: 1000 33 | num_epochs: 50 34 | decay_start: 100001 35 | -------------------------------------------------------------------------------- /configs/fashionmnist_iin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: fashionmnist_iin 2 | 3 | datasets: 4 | train: iin.data.FashionMNISTTrain 5 | validation: iin.data.FashionMNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/fashionmnist.npz" 11 | 12 | first_stage: 13 | checkpoint: logs/2020-04-23T00-17-38_fashionmnist_ae/train/checkpoints/model-120000.ckpt 14 | model: iin.models.ae.Model 15 | subconfig: 16 | Model: 17 | deterministic: false 18 | in_channels: 1 19 | in_size: 32 20 | n_down: 4 21 | norm: an 22 | z_dim: 64 23 | 24 | model: iin.models.iin.VectorTransformer 25 | Transformer: 26 | in_channel: 64 27 | n_flow: 12 28 | hidden_depth: 2 29 | hidden_dim: 512 30 | 31 | 32 | loss: iin.losses.iin.Loss 33 | 34 | iterator: iin.iterators.iin.Trainer 35 | base_learning_rate: 4.5e-06 36 | batch_size: 25 37 | log_freq: 1000 38 | num_epochs: 50 39 | decay_start: 100001 40 | -------------------------------------------------------------------------------- /configs/mnist_ae.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: mnist_ae 2 | 3 | datasets: 4 | train: iin.data.MNISTTrain 5 | validation: iin.data.MNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/mnist.npz" 11 | 12 | model: iin.models.ae.Model 13 | Model: 14 | deterministic: false 15 | in_channels: 1 16 | in_size: 32 17 | n_down: 4 18 | norm: an 19 | z_dim: 64 20 | 21 | loss: iin.losses.ae.LossPND 22 | Loss: 23 | calibrate: true 24 | disc_in_channels: 1 25 | disc_start: 25001 
26 | logvar_init: 0.0 27 | perceptual_weight: 1.0 28 | discriminator_weight: 150.0 29 | 30 | iterator: iin.iterators.ae.Trainer 31 | base_learning_rate: 4.5e-06 32 | batch_size: 25 33 | log_freq: 1000 34 | num_epochs: 50 35 | decay_start: 100001 36 | -------------------------------------------------------------------------------- /configs/mnist_iin.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: mnist_iin 2 | 3 | datasets: 4 | train: iin.data.MNISTTrain 5 | validation: iin.data.MNISTTest 6 | 7 | fid: 8 | batch_size: 50 9 | fid_stats: 10 | pre_calc_stat_path: "fid_stats/mnist.npz" 11 | 12 | first_stage: 13 | checkpoint: logs/2020-04-23T00-17-34_mnist_ae/train/checkpoints/model-120000.ckpt 14 | model: iin.models.ae.Model 15 | subconfig: 16 | Model: 17 | deterministic: false 18 | in_channels: 1 19 | in_size: 32 20 | n_down: 4 21 | norm: an 22 | z_dim: 64 23 | 24 | model: iin.models.iin.VectorTransformer 25 | Transformer: 26 | in_channel: 64 27 | n_flow: 12 28 | hidden_depth: 2 29 | hidden_dim: 512 30 | 31 | 32 | loss: iin.losses.iin.Loss 33 | 34 | iterator: iin.iterators.iin.Trainer 35 | base_learning_rate: 4.5e-06 36 | batch_size: 25 37 | log_freq: 1000 38 | num_epochs: 50 39 | decay_start: 100001 40 | -------------------------------------------------------------------------------- /configs/project.yaml: -------------------------------------------------------------------------------- 1 | integrations: 2 | wandb: 3 | active: True 4 | entity: trex 5 | project: iin 6 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: iin 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.7 7 | - pip=19.3 8 | - cudatoolkit=10.1 9 | - pytorch=1.4 10 | - torchvision=0.5 11 | - numpy=1.17 12 | - pip: 13 | - git+https://github.com/pesser/edflow.git@711a728c7afc50a88017bc37213fbe485fc13efe#egg=edflow[full] 14 | - -e git+https://github.com/jhaux/PerceptualSimilarity.git@f752859b104702e52abbe49750c595b4fb980382#egg=perceptual_similarity 15 | - -e git+https://github.com/rromb/fid.git@e6deef55cdb56944a540c3d74739438111292ec8#egg=fid-callback 16 | - albumentations==0.4.3 17 | - opencv-python==4.1.2.30 18 | - pudb==2019.2 19 | -------------------------------------------------------------------------------- /iin/data.py: -------------------------------------------------------------------------------- 1 | import albumentations 2 | import numpy as np 3 | from edflow.util import retrieve 4 | from edflow.data.dataset import PRNGMixin, DatasetMixin, SubDataset 5 | from edflow.datasets.mnist import ( 6 | MNISTTrain as _MNISTTrain, 7 | MNISTTest as _MNISTTest) 8 | from edflow.datasets.fashionmnist import ( 9 | FashionMNISTTrain as _FashionMNISTTrain, 10 | FashionMNISTTest as _FashionMNISTTest) 11 | from edflow.datasets.cifar import ( 12 | CIFAR10Train as _CIFAR10Train, 13 | CIFAR10Test as _CIFAR10Test) 14 | from edflow.datasets.celeba import ( 15 | CelebATrain as _CelebATrain, 16 | CelebATest as _CelebATest) 17 | 18 | 19 | class Base32(DatasetMixin, PRNGMixin): 20 | """Add support for resizing, cropping and dequantization.""" 21 | def __init__(self, config): 22 | self.data = self.get_base_data(config) 23 | self.size = retrieve(config, "spatial_size", default=32) 24 | 25 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 26 | self.cropper = 
albumentations.CenterCrop(height=self.size,width=self.size) 27 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 28 | 29 | def preprocess_example(self, example): 30 | example["image"] = ((example["image"]+1)*127.5).astype(np.uint8) 31 | example["image"] = self.preprocessor(image=example["image"])["image"] 32 | example["image"] = (example["image"] + self.prng.random())/256. # dequantization 33 | example["image"] = (example["image"]*2.0-1.0).astype(np.float32) 34 | return example 35 | 36 | def get_example(self, i): 37 | example = super().get_example(i) 38 | example = self.preprocess_example(example) 39 | return example 40 | 41 | 42 | class MNISTTrain(Base32): 43 | def get_base_data(self, config): 44 | return _MNISTTrain(config) 45 | 46 | 47 | class MNISTTest(Base32): 48 | def get_base_data(self, config): 49 | return _MNISTTest(config) 50 | 51 | 52 | class FashionMNISTTrain(Base32): 53 | def get_base_data(self, config): 54 | return _FashionMNISTTrain(config) 55 | 56 | 57 | class FashionMNISTTest(Base32): 58 | def get_base_data(self, config): 59 | return _FashionMNISTTest(config) 60 | 61 | 62 | class CIFAR10Train(Base32): 63 | def get_base_data(self, config): 64 | return _CIFAR10Train(config) 65 | 66 | 67 | class CIFAR10Test(Base32): 68 | def get_base_data(self, config): 69 | return _CIFAR10Test(config) 70 | 71 | 72 | class BaseCelebA(DatasetMixin, PRNGMixin): 73 | """Add support for resizing, cropping and dequantization.""" 74 | def __init__(self, config): 75 | self.data = self.get_base_data(config) 76 | self.size = retrieve(config, "spatial_size", default=64) 77 | self.attribute_descriptions = [ 78 | "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", 79 | "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose", 80 | "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", 81 | "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", 82 | "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones", 83 | "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", 84 | "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose", 85 | "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", 86 | "Straight_Hair", "Wavy_Hair", "Wearing_Earrings", 87 | "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace", 88 | "Wearing_Necktie", "Young"] 89 | 90 | self.cropper = albumentations.CenterCrop(height=160,width=160) 91 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 92 | self.preprocessor = albumentations.Compose([self.cropper, self.rescaler]) 93 | 94 | def preprocess_example(self, example): 95 | example["image"] = ((example["image"]+1)*127.5).astype(np.uint8) 96 | example["image"] = self.preprocessor(image=example["image"])["image"] 97 | example["image"] = (example["image"] + self.prng.random())/256. 
# dequantization 98 | example["image"] = (example["image"]*2.0-1.0).astype(np.float32) 99 | return example 100 | 101 | def get_example(self, i): 102 | example = super().get_example(i) 103 | example = self.preprocess_example(example) 104 | return example 105 | 106 | 107 | class CelebATrain(BaseCelebA): 108 | def get_base_data(self, config): 109 | return _CelebATrain(config) 110 | 111 | 112 | class CelebATest(BaseCelebA): 113 | def get_base_data(self, config): 114 | data = _CelebATest(config) 115 | indices = np.random.RandomState(1).choice(len(data), size=10000) 116 | return SubDataset(data, indices) 117 | 118 | 119 | class BaseFactorCelebA(BaseCelebA): 120 | def __init__(self, config): 121 | super().__init__(config) 122 | self.attributes = retrieve(config, "CelebAFactors/attributes", 123 | default=["Eyeglasses", "Male", "No_Beard", "Smiling"]) 124 | self.attribute_indices = [self.attribute_descriptions.index(attr) 125 | for attr in self.attributes] 126 | self.pos_attribute_choices = [np.where(self.labels["attributes"][:,attridx]==1)[0] 127 | for attridx in self.attribute_indices] 128 | self.neg_attribute_choices = [np.where(self.labels["attributes"][:,attridx]==-1)[0] 129 | for attridx in self.attribute_indices] 130 | self.n_factors = len(self.attributes)+1 131 | self.residual_index = len(self.attributes) 132 | 133 | def get_factor_idx(self, i): 134 | factor = self.prng.choice(len(self.attributes)) 135 | attridx = self.attribute_indices[factor] 136 | attr = self.labels["attributes"][i,attridx] 137 | if attr == 1: 138 | i2 = self.prng.choice(self.pos_attribute_choices[factor]) 139 | else: 140 | i2 = self.prng.choice(self.neg_attribute_choices[factor]) 141 | return factor, i2 142 | 143 | def get_example(self, i): 144 | e1 = super().get_example(i) 145 | factor, i2 = self.get_factor_idx(i) 146 | e2 = super().get_example(i2) 147 | example = { 148 | "factor": factor, 149 | "example1": e1, 150 | "example2": e2} 151 | return example 152 | 153 | 154 | class FactorCelebATrain(BaseFactorCelebA): 155 | def get_base_data(self, config): 156 | return _CelebATrain(config) 157 | 158 | 159 | class FactorCelebATest(BaseFactorCelebA): 160 | def __init__(self, config): 161 | super().__init__(config) 162 | self.test_prng = np.random.RandomState(1) 163 | self.factor_idx = [BaseFactorCelebA.get_factor_idx(self, i) for i in range(len(self))] 164 | 165 | @property 166 | def prng(self): 167 | return self.test_prng 168 | 169 | def get_factor_idx(self, i): 170 | return self.factor_idx[i] 171 | 172 | def get_base_data(self, config): 173 | data = _CelebATest(config) 174 | indices = np.random.RandomState(1).choice(len(data), size=10000) 175 | return SubDataset(data, indices) 176 | 177 | 178 | class ColorfulMNISTBase(DatasetMixin): 179 | def get_factor(self, i): 180 | factor = self.prng.choice(2) 181 | return factor 182 | 183 | def get_same_idx(self, i): 184 | cls = self.labels["class"][i] 185 | others = np.where(self.labels["class"] == cls)[0] 186 | return self.prng.choice(others) 187 | 188 | def get_other_idx(self, i): 189 | return self.prng.choice(len(self)) 190 | 191 | def get_color(self, i): 192 | return self.prng.uniform(low=0,high=1,size=3).astype(np.float32) 193 | 194 | def get_example(self, i): 195 | example1 = super().get_example(i) 196 | factor = self.get_factor(i) 197 | example = {"factor": factor, "example1": example1} 198 | 199 | if factor == 0: 200 | # same digit, different color 201 | j = self.get_same_idx(i) 202 | color1 = self.get_color(i) 203 | color2 = self.get_color(j) 204 | else: 205 | # different 
digit, same color 206 | j = self.get_other_idx(i) 207 | color1 = self.get_color(i) 208 | color2 = color1 209 | 210 | example2 = super().get_example(j) 211 | example["example2"] = example2 212 | 213 | example["example1"]["image"] = example["example1"]["image"] * color1 214 | example["example2"]["image"] = example["example2"]["image"] * color2 215 | 216 | return example 217 | 218 | 219 | class ColorfulMNISTTrain(ColorfulMNISTBase, PRNGMixin): 220 | def __init__(self, config): 221 | self.data = MNISTTrain(config) 222 | self.n_factors = 3 223 | self.residual_index = 2 224 | 225 | 226 | class ColorfulMNISTTest(ColorfulMNISTBase): 227 | def __init__(self, config): 228 | self.data = MNISTTest(config) 229 | self.prng = np.random.RandomState(1) 230 | self.factor = [ColorfulMNISTBase.get_factor(self, i) for i in range(len(self))] 231 | self.same_idx = [ColorfulMNISTBase.get_same_idx(self, i) for i in range(len(self))] 232 | self.other_idx = [ColorfulMNISTBase.get_other_idx(self, i) for i in range(len(self))] 233 | self.color = [ColorfulMNISTBase.get_color(self, i) for i in range(len(self))] 234 | self.n_factors = 3 235 | self.residual_index = 2 236 | 237 | def get_factor(self, i): 238 | return self.factor[i] 239 | 240 | def get_same_idx(self, i): 241 | return self.same_idx[i] 242 | 243 | def get_other_idx(self, i): 244 | return self.other_idx[i] 245 | 246 | def get_color(self, i): 247 | return self.color[i] 248 | 249 | 250 | class SingleColorfulMNISTTrain(DatasetMixin): 251 | def __init__(self, config): 252 | self.data = ColorfulMNISTTrain(config) 253 | 254 | def get_example(self, i): 255 | return super().get_example(i)["example1"] 256 | 257 | 258 | class SingleColorfulMNISTTest(DatasetMixin): 259 | def __init__(self, config): 260 | self.data = ColorfulMNISTTest(config) 261 | 262 | def get_example(self, i): 263 | return super().get_example(i)["example1"] 264 | -------------------------------------------------------------------------------- /iin/iterators/ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from perceptual_similarity import PerceptualLoss 4 | from edflow.util import retrieve 5 | from fid import fid_callback 6 | 7 | from iin.iterators.base import Iterator 8 | 9 | 10 | def rec_fid_callback(*args, **kwargs): 11 | return fid_callback.fid(*args, **kwargs, 12 | im_in_key="image", 13 | im_in_support="-1->1", 14 | im_out_key="reconstructions", 15 | im_out_support="0->255", 16 | name="fid_recons") 17 | 18 | 19 | def sample_fid_callback(*args, **kwargs): 20 | return fid_callback.fid(*args, **kwargs, 21 | im_in_key="image", 22 | im_in_support="-1->1", 23 | im_out_key="samples", 24 | im_out_support="0->255", 25 | name="fid_samples") 26 | 27 | 28 | def reconstruction_callback(root, data_in, data_out, config): 29 | log = {"scalars": dict()} 30 | log["scalars"]["rec_loss"] = np.mean(data_out.labels["rec_loss"]) 31 | log["scalars"]["kl_loss"] = np.mean(data_out.labels["kl_loss"]) 32 | return log 33 | 34 | 35 | class Trainer(Iterator): 36 | """ 37 | AE Trainer. Expects `image` in batch, `encode -> Distribution` and `decode` 38 | methods on model. 
39 | """ 40 | def __init__(self, *args, **kwargs): 41 | super().__init__(*args, **kwargs) 42 | self.eval_loss = PerceptualLoss() 43 | 44 | def step_op(self, *args, **kwargs): 45 | inputs = kwargs["image"] 46 | inputs = self.totorch(inputs) 47 | 48 | posterior = self.model.encode(inputs) 49 | z = posterior.sample() 50 | reconstructions = self.model.decode(z) 51 | loss, log_dict, loss_train_op = self.loss( 52 | inputs, reconstructions, posterior, self.get_global_step()) 53 | log_dict.setdefault("scalars", dict()) 54 | log_dict.setdefault("images", dict()) 55 | 56 | def train_op(): 57 | loss_train_op() 58 | self.optimizer.zero_grad() 59 | loss.backward() 60 | self.optimizer.step() 61 | self.update_lr() 62 | 63 | def log_op(): 64 | log_dict["images"]["inputs"] = inputs 65 | log_dict["images"]["reconstructions"] = reconstructions 66 | 67 | if not hasattr(self, "fixed_examples"): 68 | self.fixed_examples = [ 69 | self.dataset[i]["image"] 70 | for i in np.random.RandomState(1).choice(len(self.dataset), 71 | self.config["batch_size"])] 72 | self.fixed_examples = np.stack(self.fixed_examples) 73 | self.fixed_examples = self.totorch(self.fixed_examples) 74 | 75 | with torch.no_grad(): 76 | log_dict["images"]["fixed_inputs"] = self.fixed_examples 77 | log_dict["images"]["fixed_reconstructions"] = self.model.decode( 78 | self.model.encode(self.fixed_examples).sample()) 79 | log_dict["images"]["decoded_sample"] = self.model.decode( 80 | torch.randn_like(posterior.mode())) 81 | 82 | for k in log_dict: 83 | for kk in log_dict[k]: 84 | log_dict[k][kk] = self.tonp(log_dict[k][kk]) 85 | 86 | for i, param_group in enumerate(self.optimizer.param_groups): 87 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr'] 88 | return log_dict 89 | 90 | def eval_op(): 91 | with torch.no_grad(): 92 | kl_loss = posterior.kl() 93 | rec_loss = self.eval_loss(reconstructions, inputs) 94 | samples = self.model.decode(torch.randn_like(posterior.mode())) 95 | return {"reconstructions": self.tonp(reconstructions), 96 | "samples": self.tonp(samples), 97 | "labels": {"rec_loss": self.tonp(rec_loss), 98 | "kl_loss": self.tonp(kl_loss)}} 99 | 100 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op} 101 | 102 | @property 103 | def callbacks(self): 104 | cbs = {"eval_op": {"reconstruction": reconstruction_callback}} 105 | cbs["eval_op"]["fid_reconstruction"] = rec_fid_callback 106 | cbs["eval_op"]["fid_samples"] = sample_fid_callback 107 | return cbs 108 | -------------------------------------------------------------------------------- /iin/iterators/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import numpy as np 4 | import edflow 5 | from edflow import TemplateIterator, get_obj_from_str 6 | from edflow.util import retrieve 7 | 8 | 9 | def totorch(x, guess_image=True, device=None): 10 | if x.dtype == np.float64: 11 | x = x.astype(np.float32) 12 | x = torch.tensor(x) 13 | if guess_image and len(x.size()) == 4: 14 | x = x.transpose(2, 3).transpose(1, 2) 15 | if device is None: 16 | if torch.cuda.is_available(): 17 | device = torch.device("cuda") 18 | else: 19 | device = torch.device("cpu") 20 | x = x.to(device) 21 | return x 22 | 23 | 24 | def tonp(x, guess_image=True): 25 | try: 26 | if guess_image and len(x.shape) == 4: 27 | x = x.transpose(1, 2).transpose(2, 3) 28 | return x.detach().cpu().numpy() 29 | except AttributeError: 30 | return x 31 | 32 | 33 | def get_learning_rate(config): 34 | if "learning_rate" in config: 35 | 
learning_rate = config["learning_rate"] 36 | elif "base_learning_rate" in config: 37 | learning_rate = config["base_learning_rate"]*config["batch_size"] 38 | else: 39 | raise KeyError() 40 | return learning_rate 41 | 42 | 43 | class Iterator(TemplateIterator): 44 | """ 45 | Base class to handle device and state. Adds optimizer and loss. 46 | Call update_lr() in train op for lr scheduling. 47 | 48 | Config parameters: 49 | - test_mode : boolean : Put model into .eval() mode. 50 | - no_restore_keys : string1,string2 : Submodels which should not be 51 | restored from checkpoint. 52 | - learning_rate : float : Learning rate of Adam 53 | - base_learning_rate : float : Learning_rate per example to adjust for 54 | batch size (ignored if learning_rate is present) 55 | - decay_start : float : Step after which learning rate is decayed to 56 | zero. 57 | - loss : string : Import path of loss. 58 | """ 59 | def __init__(self, *args, **kwargs): 60 | super().__init__(*args, **kwargs) 61 | if torch.cuda.is_available(): 62 | self.device = torch.device("cuda") 63 | else: 64 | self.device = torch.device("cpu") 65 | self.model.to(self.device) 66 | self.test_mode = self.config.get("test_mode", False) 67 | if self.test_mode: 68 | self.model.eval() 69 | self.submodules = ["model"] 70 | self.do_not_restore_keys = retrieve(self.config, 'no_restore_keys', default='').split(',') 71 | 72 | self.learning_rate = get_learning_rate(self.config) 73 | self.logger.info("learning_rate: {}".format(self.learning_rate)) 74 | params = self.model.parameters() 75 | if "loss" in self.config: 76 | self.loss = get_obj_from_str(self.config["loss"])(self.config) 77 | self.loss.to(self.device) 78 | self.submodules.append("loss") 79 | params = itertools.chain(params, self.loss.parameters()) 80 | self.optimizer = torch.optim.Adam( 81 | params, 82 | lr=self.learning_rate, 83 | betas=(0.5, 0.9)) 84 | self.submodules.append("optimizer") 85 | try: 86 | self.loss.set_last_layer(self.model.get_last_layer()) 87 | except Exception as e: 88 | self.logger.info(' Could not set last layer for calibration. 
Reason:\n {}'.format(e)) 89 | 90 | self.num_steps = retrieve(self.config, "num_steps") 91 | self.decay_start = retrieve(self.config, "decay_start", default=self.num_steps) 92 | 93 | def get_decay_factor(self): 94 | alpha = 1.0 95 | if self.num_steps > self.decay_start: 96 | alpha = 1.0 - np.clip( 97 | (self.get_global_step() - self.decay_start) / 98 | (self.num_steps - self.decay_start), 99 | 0.0, 1.0) 100 | return alpha 101 | 102 | def update_lr(self): 103 | for param_group in self.optimizer.param_groups: 104 | param_group['lr'] = self.get_decay_factor()*self.learning_rate 105 | 106 | def get_state(self): 107 | state = dict() 108 | for k in self.submodules: 109 | state[k] = getattr(self, k).state_dict() 110 | return state 111 | 112 | def save(self, checkpoint_path): 113 | torch.save(self.get_state(), checkpoint_path) 114 | 115 | def restore(self, checkpoint_path): 116 | state = torch.load(checkpoint_path) 117 | keys = list(state.keys()) 118 | for k in keys: 119 | if hasattr(self, k): 120 | if k not in self.do_not_restore_keys: 121 | try: 122 | missing, unexpected = getattr(self, k).load_state_dict(state[k], strict=False) 123 | if missing: 124 | self.logger.info("Missing keys for {}: {}".format(k, missing)) 125 | if unexpected: 126 | self.logger.info("Unexpected keys for {}: {}".format(k, unexpected)) 127 | except TypeError: 128 | self.logger.info(k) 129 | try: 130 | getattr(self, k).load_state_dict(state[k]) 131 | except ValueError: 132 | self.logger.info("Could not load state dict for key {}".format(k)) 133 | else: 134 | self.logger.info('Restored key `{}`'.format(k)) 135 | else: 136 | self.logger.info('Not restoring key `{}` (as specified)'.format(k)) 137 | del state[k] 138 | 139 | def totorch(self, x, guess_image=True): 140 | return totorch(x, guess_image=guess_image, device=self.device) 141 | 142 | def tonp(self, x, guess_image=True): 143 | return tonp(x, guess_image=guess_image) 144 | 145 | def interpolate_corners(self, x, num_side, permute=False): 146 | return interpolate_corners(x, side=num_side, permute=permute) 147 | 148 | 149 | 150 | class DimEvaluator(Iterator): 151 | """ 152 | Estimate dimensionalities for factors. 153 | AE Trainer. Expects `factor`, `example1` and `example2` with `image` in batch, 154 | `encode -> Distribution` methods on model. 
155 | """ 156 | def __init__(self, *args, **kwargs): 157 | super().__init__(*args, **kwargs) 158 | 159 | def step_op(self, *args, **kwargs): 160 | def eval_op(): 161 | with torch.no_grad(): 162 | hs = dict() 163 | factor=kwargs["factor"] 164 | with torch.no_grad(): 165 | for k in ["example1", "example2"]: 166 | inputs = self.totorch(kwargs[k]["image"]) 167 | hs[k] = self.model.encode(inputs).mode() 168 | log_dict = {} 169 | log_dict["labels"] = {"factor": factor} 170 | for k in ["example1", "example2"]: 171 | log_dict["labels"][k] = self.tonp(hs[k], guess_image=False) 172 | return log_dict 173 | 174 | return {"eval_op": eval_op} 175 | 176 | @property 177 | def callbacks(self): 178 | return {"eval_op": {"dim_callback": dim_callback}} 179 | 180 | 181 | def dim_callback(root, data_in, data_out, config): 182 | logger = edflow.get_logger("dim_callback") 183 | 184 | factors = data_out.labels["factor"] 185 | za = data_out.labels["example1"].squeeze() 186 | zb = data_out.labels["example2"].squeeze() 187 | za_by_factor = dict() 188 | zb_by_factor = dict() 189 | mean_by_factor = dict() 190 | score_by_factor = dict() 191 | 192 | zall = np.concatenate([za,zb], 0) 193 | mean = np.mean(zall, 0, keepdims=True) 194 | var = np.sum(np.mean((zall-mean)*(zall-mean), 0)) 195 | for f in range(data_in.n_factors): 196 | if f != data_in.residual_index: 197 | indices = np.where(factors==f)[0] 198 | za_by_factor[f] = za[indices] 199 | zb_by_factor[f] = zb[indices] 200 | mean_by_factor[f] = 0.5*( 201 | np.mean(za_by_factor[f], 0, keepdims=True)+ 202 | np.mean(zb_by_factor[f], 0, keepdims=True)) 203 | score_by_factor[f] = np.sum( 204 | np.mean( 205 | (za_by_factor[f]-mean_by_factor[f])*(zb_by_factor[f]-mean_by_factor[f]), 0)) 206 | score_by_factor[f] = score_by_factor[f]/var 207 | else: 208 | score_by_factor[f] = 1.0 209 | scores = np.array([score_by_factor[f] for f in range(data_in.n_factors)]) 210 | 211 | m = np.max(scores) 212 | e = np.exp(scores-m) 213 | softmaxed = e / np.sum(e) 214 | 215 | dim = za.shape[1] 216 | dims = [int(s*dim) for s in softmaxed] 217 | dims[-1] = dim - sum(dims[:-1]) 218 | logger.info("estimated factor dimensionalities: {}".format(dims)) 219 | -------------------------------------------------------------------------------- /iin/iterators/clf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from edflow.util import retrieve 4 | 5 | from iin.iterators.base import Iterator 6 | 7 | 8 | class Trainer(Iterator): 9 | """ 10 | Classification Trainer. Expects `image` and `class` in batch, 11 | and a model returning logits. 
12 | """ 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.n_classes = retrieve(self.config, "n_classes") 16 | 17 | def step_op(self, *args, **kwargs): 18 | inputs = kwargs["image"] 19 | inputs = self.totorch(inputs) 20 | labels = kwargs["class"] 21 | labels = self.totorch(labels).to(torch.int64) 22 | onehot = torch.nn.functional.one_hot(labels, 23 | num_classes=self.n_classes).float() 24 | 25 | logits = self.model(inputs) 26 | loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, 27 | onehot, 28 | reduction="none") 29 | mean_loss = loss.mean() 30 | 31 | def train_op(): 32 | self.optimizer.zero_grad() 33 | mean_loss.backward() 34 | self.optimizer.step() 35 | self.update_lr() 36 | 37 | def log_op(): 38 | with torch.no_grad(): 39 | prediction = torch.argmax(logits, dim=1) 40 | accuracy = (prediction==labels).float().mean() 41 | 42 | log_dict = {"scalars": { 43 | "loss": mean_loss, 44 | "acc": accuracy 45 | }} 46 | 47 | for k in log_dict: 48 | for kk in log_dict[k]: 49 | log_dict[k][kk] = self.tonp(log_dict[k][kk]) 50 | 51 | for i, param_group in enumerate(self.optimizer.param_groups): 52 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr'] 53 | return log_dict 54 | 55 | def eval_op(): 56 | return { 57 | "labels": {"loss": self.tonp(loss), "logits": self.tonp(logits)}, 58 | } 59 | 60 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op} 61 | 62 | @property 63 | def callbacks(self): 64 | return {"eval_op": {"clf_callback": clf_callback}} 65 | 66 | 67 | def clf_callback(root, data_in, data_out, config): 68 | loss = data_out.labels["loss"].mean() 69 | prediction = data_out.labels["logits"].argmax(1) 70 | accuracy = (prediction == data_in.labels["class"][:prediction.shape[0]]).mean() 71 | return {"scalars": {"loss": loss, "accuracy": accuracy}} 72 | -------------------------------------------------------------------------------- /iin/iterators/iin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from edflow.util import retrieve, get_obj_from_str 4 | 5 | from iin.iterators.base import Iterator 6 | from iin.iterators.ae import sample_fid_callback 7 | 8 | 9 | def loss_callback(root, data_in, data_out, config): 10 | log = {"scalars": dict()} 11 | log["scalars"]["loss"] = np.mean(data_out.labels["loss"]) 12 | return log 13 | 14 | 15 | class Trainer(Iterator): 16 | """ 17 | Unsupervised IIN Trainer. Expects `image` in batch, 18 | `encode -> Distribution` and `decode` methods on first stage model. 
19 | """ 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | first_stage_config = self.config["first_stage"] 23 | self.init_first_stage(first_stage_config) 24 | 25 | def init_first_stage(self, config): 26 | subconfig = config["subconfig"] 27 | self.first_stage = get_obj_from_str(config["model"])(subconfig) 28 | if "checkpoint" in config: 29 | checkpoint = config["checkpoint"] 30 | state = torch.load(checkpoint)["model"] 31 | self.first_stage.load_state_dict(state) 32 | self.logger.info("Restored first stage from {}".format(checkpoint)) 33 | self.first_stage.to(self.device) 34 | self.first_stage.eval() 35 | 36 | def step_op(self, *args, **kwargs): 37 | inputs = kwargs["image"] 38 | inputs = self.totorch(inputs) 39 | 40 | with torch.no_grad(): 41 | posterior = self.first_stage.encode(inputs) 42 | z = posterior.sample() 43 | z_ss, logdet = self.model(z) 44 | loss, log_dict, loss_train_op = self.loss(z_ss, logdet, self.get_global_step()) 45 | log_dict.setdefault("scalars", dict()) 46 | log_dict.setdefault("images", dict()) 47 | 48 | def train_op(): 49 | loss_train_op() 50 | self.optimizer.zero_grad() 51 | loss.backward() 52 | self.optimizer.step() 53 | self.update_lr() 54 | 55 | def log_op(): 56 | if not hasattr(self, "fixed_examples"): 57 | self.fixed_examples = np.random.RandomState(1).randn(*posterior.mode().shape) 58 | self.fixed_examples = self.totorch(self.fixed_examples) 59 | 60 | with torch.no_grad(): 61 | reconstructions = self.first_stage.decode( 62 | self.model.reverse(z_ss)) 63 | samples = self.first_stage.decode( 64 | self.model.reverse(torch.randn_like(posterior.mode()))) 65 | fixed_samples = self.first_stage.decode( 66 | self.model.reverse(self.fixed_examples)) 67 | 68 | log_dict["images"]["inputs"] = inputs 69 | log_dict["images"]["reconstructions"] = reconstructions 70 | log_dict["images"]["samples"] = samples 71 | log_dict["images"]["fixed_samples"] = fixed_samples 72 | 73 | for k in log_dict: 74 | for kk in log_dict[k]: 75 | log_dict[k][kk] = self.tonp(log_dict[k][kk]) 76 | 77 | for i, param_group in enumerate(self.optimizer.param_groups): 78 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr'] 79 | return log_dict 80 | 81 | def eval_op(): 82 | with torch.no_grad(): 83 | loss_ = torch.ones(inputs.shape[0])*loss 84 | samples = self.first_stage.decode(self.model.reverse(torch.randn_like(posterior.mode()))) 85 | return {"samples": self.tonp(samples), 86 | "labels": {"loss": self.tonp(loss_)}} 87 | 88 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op} 89 | 90 | @property 91 | def callbacks(self): 92 | cbs = {"eval_op": {"loss_cb": loss_callback}} 93 | cbs["eval_op"]["fid_samples"] = sample_fid_callback 94 | return cbs 95 | 96 | 97 | class FactorTrainer(Trainer): 98 | def step_op(self, *args, **kwargs): 99 | # get inputs 100 | inputs_factor = dict() 101 | z_factor = dict() 102 | z_ss_factor = dict() 103 | logdet_factor = dict() 104 | factors = kwargs["factor"] 105 | for k in ["example1", "example2"]: 106 | inputs = kwargs[k]["image"] 107 | inputs = self.totorch(inputs) 108 | inputs_factor[k] = inputs 109 | 110 | with torch.no_grad(): 111 | posterior = self.first_stage.encode(inputs) 112 | z = posterior.sample() 113 | z_factor[k] = z 114 | z_ss, logdet = self.model(z) 115 | z_ss_factor[k] = z_ss 116 | logdet_factor[k] = logdet 117 | 118 | loss, log_dict, loss_train_op = self.loss( 119 | z_ss_factor, logdet_factor, factors, self.get_global_step()) 120 | 121 | def train_op(): 122 | loss_train_op() 123 | 
self.optimizer.zero_grad() 124 | loss.backward() 125 | self.optimizer.step() 126 | self.update_lr() 127 | 128 | def log_op(): 129 | if not hasattr(self.model, "decode"): 130 | return None 131 | with torch.no_grad(): 132 | for k in ["example1", "example2"]: 133 | log_dict["images"][k] = inputs_factor[k] 134 | # reencode after update of model 135 | ss_z1 = self.model(z_factor["example1"])[0] 136 | ss_z2 = self.model(z_factor["example2"])[0] 137 | 138 | log_dict["images"]["reconstruction"] = self.first_stage.decode( 139 | self.model.reverse(ss_z1)) 140 | 141 | factor_mask = [ 142 | torch.tensor(((factors==i) | ((factors<0) & (factors!=-i)))[:,None,None,None]).to( 143 | ss_z1[i]) for i in range(len(ss_z1))] 144 | ss_z_swap = [ 145 | ss_z1[i] + 146 | factor_mask[i]*( 147 | ss_z2[i] - ss_z1[i]) 148 | for i in range(len(factor_mask))] 149 | log_dict["images"]["decoded_swap"] = self.first_stage.decode( 150 | self.model.reverse(ss_z_swap)) 151 | 152 | N_cross = 6 153 | z_cross = z_factor["example1"][:N_cross,...] 154 | shape = tuple(z_cross.shape) 155 | z_cross1 = z_cross[None,...][N_cross*[0],...].reshape( 156 | N_cross*N_cross, *shape[1:]) 157 | z_cross2 = z_cross[:,None,...][:,N_cross*[0],...].reshape( 158 | N_cross*N_cross, *shape[1:]) 159 | ss_z_cross1 = self.model(z_cross1)[0] 160 | ss_z_cross2 = self.model(z_cross2)[0] 161 | for i in range(len(ss_z1)): 162 | ss_z_cross = list(ss_z_cross2) 163 | ss_z_cross[i] = ss_z_cross1[i] 164 | log_dict["images"]["decoded_cross_{}".format(i)] = self.first_stage.decode( 165 | self.model.reverse(ss_z_cross)) 166 | 167 | N_fixed = 6 168 | if not hasattr(self, "fixed_examples"): 169 | self.fixed_examples = [ 170 | self.dataset[i]["example1"]["image"] 171 | for i in np.random.RandomState(1).choice(len(self.dataset), 172 | N_fixed)] 173 | self.fixed_examples = np.stack(self.fixed_examples) 174 | self.fixed_examples = self.totorch(self.fixed_examples) 175 | self.fixed_examples = self.first_stage.encode(self.fixed_examples).mode() 176 | shape = tuple(self.fixed_examples.shape) 177 | self.fixed1 = self.fixed_examples[None,...][N_fixed*[0],...].reshape( 178 | N_fixed*N_fixed, *shape[1:]) 179 | self.fixed2 = self.fixed_examples[:,None,...][:,N_fixed*[0],...].reshape( 180 | N_fixed*N_fixed, *shape[1:]) 181 | 182 | ss_z_fixed1 = self.model(self.fixed1)[0] 183 | ss_z_fixed2 = self.model(self.fixed2)[0] 184 | for i in range(len(ss_z_fixed1)): 185 | ss_z_cross = list(ss_z_fixed2) 186 | ss_z_cross[i] = ss_z_fixed1[i] 187 | log_dict["images"]["fixed_cross_{}".format(i)] = self.first_stage.decode( 188 | self.model.reverse(ss_z_cross)) 189 | 190 | for k in log_dict: 191 | for kk in log_dict[k]: 192 | log_dict[k][kk] = self.tonp(log_dict[k][kk]) 193 | return log_dict 194 | 195 | return {"train_op": train_op, "log_op": log_op} 196 | 197 | @property 198 | def callbacks(self): 199 | return dict() 200 | -------------------------------------------------------------------------------- /iin/losses/ae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from edflow.util import retrieve 3 | import torch.nn as nn 4 | import torch 5 | from perceptual_similarity import PerceptualLoss 6 | import functools 7 | from torch.nn.utils.spectral_norm import spectral_norm 8 | 9 | 10 | def do_spectral_norm(m): 11 | classname = m.__class__.__name__ 12 | if classname.find('Conv') != -1: 13 | spectral_norm(m) 14 | 15 | 16 | def weights_init(m): 17 | classname = m.__class__.__name__ 18 | if classname.find('Conv') != -1: 19 | 
nn.init.normal_(m.weight.data, 0.0, 0.02) 20 | elif classname.find('BatchNorm') != -1: 21 | nn.init.normal_(m.weight.data, 1.0, 0.02) 22 | nn.init.constant_(m.bias.data, 0) 23 | 24 | 25 | def adopt_weight(weight, global_step, threshold=0, value=0.): 26 | if global_step < threshold: 27 | weight = value 28 | return weight 29 | 30 | 31 | def l1(input, target): 32 | return torch.abs(input-target) 33 | 34 | 35 | def l2(input, target): 36 | return torch.pow((input-target), 2) 37 | 38 | 39 | class NLayerDiscriminator(nn.Module): 40 | """Defines a PatchGAN discriminator 41 | --> from 42 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 43 | """ 44 | 45 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 46 | """Construct a PatchGAN discriminator 47 | Parameters: 48 | input_nc (int) -- the number of channels in input images 49 | ndf (int) -- the number of filters in the last conv layer 50 | n_layers (int) -- the number of conv layers in the discriminator 51 | norm_layer -- normalization layer 52 | """ 53 | super(NLayerDiscriminator, self).__init__() 54 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 55 | use_bias = norm_layer.func != nn.BatchNorm2d 56 | else: 57 | use_bias = norm_layer != nn.BatchNorm2d 58 | 59 | kw = 4 60 | padw = 1 61 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 62 | nf_mult = 1 63 | nf_mult_prev = 1 64 | for n in range(1, n_layers): # gradually increase the number of filters 65 | nf_mult_prev = nf_mult 66 | nf_mult = min(2 ** n, 8) 67 | sequence += [ 68 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 69 | norm_layer(ndf * nf_mult), 70 | nn.LeakyReLU(0.2, True) 71 | ] 72 | 73 | nf_mult_prev = nf_mult 74 | nf_mult = min(2 ** n_layers, 8) 75 | sequence += [ 76 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True) 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 83 | self.main = nn.Sequential(*sequence) 84 | 85 | def forward(self, input): 86 | """Standard forward.""" 87 | return self.main(input) 88 | 89 | 90 | class LossPND(nn.Module): 91 | """Using LPIPS Perceptual loss""" 92 | def __init__(self, config): 93 | super().__init__() 94 | __pixel_loss_opt = {"l1": l1, 95 | "l2": l2 } 96 | self.config = config 97 | self.discriminator_iter_start = retrieve(config, "Loss/disc_start") 98 | self.disc_factor = retrieve(config, "Loss/disc_factor", default=1.0) 99 | self.kl_weight = retrieve(config, "Loss/kl_weight", default=1.0) 100 | self.perceptual_weight = retrieve(config, "Loss/perceptual_weight", default=1.0) 101 | if self.perceptual_weight > 0: 102 | self.perceptual_loss = PerceptualLoss() 103 | self.calibrate = retrieve(config, "Loss/calibrate", default=False) 104 | # output log variance 105 | self.logvar_init = retrieve(config, "Loss/logvar_init", default=0.0) 106 | self.logvar = nn.Parameter(torch.ones(size=())*self.logvar_init) 107 | # discriminator 108 | self.discriminator_weight = retrieve(config, "Loss/discriminator_weight", default=1.0) 109 | disc_nc_in = retrieve(config, "Loss/disc_in_channels", default=3) 110 | disc_layers = retrieve(config, "Loss/disc_num_layers", default=3) 111 | self.discriminator = NLayerDiscriminator(input_nc=disc_nc_in, 
n_layers=disc_layers).apply(weights_init) 112 | self.pixel_loss = __pixel_loss_opt[retrieve(config, "Loss/pixelloss", default="l1")] 113 | if retrieve(config, "Loss/spectral_norm", default=False): 114 | self.discriminator.apply(do_spectral_norm) 115 | if torch.cuda.is_available(): 116 | self.discriminator.cuda() 117 | if "learning_rate" in self.config: 118 | learning_rate = self.config["learning_rate"] 119 | elif "base_learning_rate" in self.config: 120 | learning_rate = self.config["base_learning_rate"]*self.config["batch_size"] 121 | else: 122 | learning_rate = 0.001 123 | self.learning_rate = retrieve(config, "Loss/d_lr_factor", default=1.0)*learning_rate 124 | self.num_steps = retrieve(self.config, "num_steps") 125 | self.decay_start = retrieve(self.config, "decay_start", default=self.num_steps) 126 | self.d_optimizer = torch.optim.Adam( 127 | self.discriminator.parameters(), 128 | lr=learning_rate, 129 | betas=(0.5, 0.9)) 130 | 131 | def get_decay_factor(self, global_step): 132 | alpha = 1.0 133 | if self.num_steps > self.decay_start: 134 | alpha = 1.0 - np.clip( 135 | (global_step - self.decay_start) / 136 | (self.num_steps - self.decay_start), 137 | 0.0, 1.0) 138 | return alpha 139 | 140 | def update_lr(self, global_step): 141 | for param_group in self.d_optimizer.param_groups: 142 | param_group['lr'] = self.get_decay_factor(global_step)*self.learning_rate 143 | 144 | def set_last_layer(self, last_layer): 145 | self.last_layer = [last_layer] 146 | 147 | def parameters(self): 148 | """Exclude discriminator from parameters.""" 149 | ps = super().parameters() 150 | exclude = set(self.discriminator.parameters()) 151 | ps = (p for p in ps if not p in exclude) 152 | return ps 153 | 154 | def forward(self, inputs, reconstructions, posteriors, global_step): 155 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 156 | rec_loss = self.pixel_loss(inputs, reconstructions) # l1 or l2 157 | if self.perceptual_weight > 0: 158 | p_loss = self.perceptual_loss(inputs, reconstructions) 159 | rec_loss = rec_loss + self.perceptual_weight*p_loss 160 | 161 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 162 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 163 | 164 | calibration = rec_loss / torch.exp(self.logvar) - 1.0 165 | calibration = torch.sum(calibration) / calibration.shape[0] 166 | 167 | kl_loss = posteriors.kl() 168 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 169 | 170 | logits_real = self.discriminator(inputs) 171 | logits_fake = self.discriminator(reconstructions) 172 | d_loss = 0.5*( 173 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 174 | torch.mean(torch.nn.functional.softplus( logits_fake))) * disc_factor 175 | def train_op(): 176 | self.d_optimizer.zero_grad() 177 | d_loss.backward(retain_graph=True) 178 | self.d_optimizer.step() 179 | self.update_lr(global_step) 180 | 181 | g_loss = -torch.mean(logits_fake) 182 | 183 | if not self.calibrate: 184 | loss = nll_loss + self.kl_weight*kl_loss + self.discriminator_weight*g_loss*disc_factor 185 | else: 186 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 187 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 188 | d_weight = torch.norm(nll_grads)/(torch.norm(g_grads)+1e-4) 189 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 190 | loss = nll_loss + self.kl_weight*kl_loss + d_weight*disc_factor*g_loss 191 | log = {"scalars": {"loss": loss, "logvar": self.logvar, 192 | "kl_loss": kl_loss, 
"nll_loss": nll_loss, 193 | "rec_loss": rec_loss.mean(), 194 | "calibration": calibration, 195 | "g_loss": g_loss, "d_loss": d_loss, 196 | "logits_real": torch.mean(logits_real), 197 | "logits_fake": torch.mean(logits_fake), 198 | }} 199 | if self.calibrate: 200 | log["scalars"]["d_weight"] = d_weight 201 | for i, param_group in enumerate(self.d_optimizer.param_groups): 202 | log["scalars"]["d_lr_{:02}".format(i)] = param_group['lr'] 203 | return loss, log, train_op 204 | -------------------------------------------------------------------------------- /iin/losses/iin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | 5 | from edflow.util import retrieve 6 | 7 | 8 | def nll(sample): 9 | return 0.5*torch.sum(torch.pow(sample, 2), dim=[1,2,3]) 10 | 11 | 12 | class Loss(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | self.config = config 16 | 17 | def forward(self, sample, logdet, global_step): 18 | nll_loss = torch.mean(nll(sample)) 19 | assert len(logdet.shape) == 1 20 | nlogdet_loss = -torch.mean(logdet) 21 | loss = nll_loss + nlogdet_loss 22 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample))) 23 | log = {"images": {}, 24 | "scalars": { 25 | "loss": loss, "reference_nll_loss": reference_nll_loss, 26 | "nlogdet_loss": nlogdet_loss, "nll_loss": nll_loss, 27 | }} 28 | def train_op(): 29 | pass 30 | return loss, log, train_op 31 | 32 | 33 | class FactorLoss(nn.Module): 34 | def __init__(self, config): 35 | super().__init__() 36 | self.config = config 37 | self.rho = retrieve(config, "Loss/rho", default=0.975) 38 | 39 | def forward(self, samples, logdets, factors, global_step): 40 | sample1 = samples["example1"] 41 | logdet1 = logdets["example1"] 42 | nll_loss1 = torch.mean(nll(torch.cat(sample1, dim=1))) 43 | assert len(logdet1.shape) == 1 44 | nlogdet_loss1 = -torch.mean(logdet1) 45 | loss1 = nll_loss1 + nlogdet_loss1 46 | 47 | sample2 = samples["example2"] 48 | logdet2 = logdets["example2"] 49 | factor_mask = [ 50 | torch.tensor(((factors==i) | ((factors<0) & (factors!=-i)))[:,None,None,None]).to( 51 | sample2[i]) for i in range(len(sample2))] 52 | sample2_cond = [ 53 | sample2[i] - factor_mask[i]*self.rho*sample1[i] 54 | for i in range(len(sample2))] 55 | nll_loss2 = [nll(sample2_cond[i]) for i in range(len(sample2_cond))] 56 | nll_loss2 = [nll_loss2[i]/(1.0-factor_mask[i][:,0,0,0]*self.rho**2) 57 | for i in range(len(sample2_cond))] 58 | nll_loss2 = [torch.mean(nll_loss2[i]) 59 | for i in range(len(sample2_cond))] 60 | nll_loss2 = sum(nll_loss2) 61 | assert len(logdet2.shape) == 1 62 | nlogdet_loss2 = -torch.mean(logdet2) 63 | loss2 = nll_loss2 + nlogdet_loss2 64 | 65 | loss = loss1 + loss2 66 | 67 | log = {"images": {}, 68 | "scalars": { 69 | "loss": loss, "loss1": loss1, "loss2": loss2, 70 | }} 71 | def train_op(): 72 | pass 73 | return loss, log, train_op 74 | -------------------------------------------------------------------------------- /iin/models/ae.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | from edflow.util import retrieve 6 | 7 | 8 | class ActNorm(nn.Module): 9 | def __init__(self, num_features, affine=True, logdet=False): 10 | super().__init__() 11 | assert affine 12 | self.logdet = logdet 13 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 14 | self.scale = nn.Parameter(torch.ones(1, 
num_features, 1, 1)) 15 | 16 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 17 | 18 | def initialize(self, input): 19 | with torch.no_grad(): 20 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 21 | mean = ( 22 | flatten.mean(1) 23 | .unsqueeze(1) 24 | .unsqueeze(2) 25 | .unsqueeze(3) 26 | .permute(1, 0, 2, 3) 27 | ) 28 | std = ( 29 | flatten.std(1) 30 | .unsqueeze(1) 31 | .unsqueeze(2) 32 | .unsqueeze(3) 33 | .permute(1, 0, 2, 3) 34 | ) 35 | 36 | self.loc.data.copy_(-mean) 37 | self.scale.data.copy_(1 / (std + 1e-6)) 38 | 39 | def forward(self, input, reverse=False): 40 | if reverse: 41 | return self.reverse(input) 42 | _, _, height, width = input.shape 43 | 44 | if self.initialized.item() == 0: 45 | self.initialize(input) 46 | self.initialized.fill_(1) 47 | 48 | h = self.scale * (input + self.loc) 49 | 50 | if self.logdet: 51 | log_abs = torch.log(torch.abs(self.scale)) 52 | logdet = height*width*torch.sum(log_abs) 53 | logdet = logdet * torch.ones(input.shape[0]).to(input) 54 | return h, logdet 55 | 56 | return h 57 | 58 | def reverse(self, output): 59 | return output / self.scale - self.loc 60 | 61 | 62 | _norm_options = { 63 | "in": nn.InstanceNorm2d, 64 | "bn": nn.BatchNorm2d, 65 | "an": ActNorm} 66 | 67 | 68 | def weights_init(m): 69 | classname = m.__class__.__name__ 70 | if classname.find('Conv') != -1: 71 | nn.init.normal_(m.weight.data, 0.0, 0.02) 72 | elif classname.find('BatchNorm') != -1: 73 | nn.init.normal_(m.weight.data, 1.0, 0.02) 74 | nn.init.constant_(m.bias.data, 0) 75 | 76 | 77 | class FeatureLayer(nn.Module): 78 | def __init__(self, scale, in_channels=None, norm='IN'): 79 | super().__init__() 80 | self.scale = scale 81 | self.norm = _norm_options[norm.lower()] 82 | if in_channels is None: 83 | self.in_channels = 64*min(2**(self.scale-1), 16) 84 | else: 85 | self.in_channels = in_channels 86 | self.build() 87 | 88 | def forward(self, input): 89 | x = input 90 | for layer in self.sub_layers: 91 | x = layer(x) 92 | return x 93 | 94 | def build(self): 95 | Norm = functools.partial(self.norm, affine=True) 96 | Activate = lambda: nn.LeakyReLU(0.2) 97 | self.sub_layers = nn.ModuleList([ 98 | nn.Conv2d( 99 | in_channels=self.in_channels, 100 | out_channels=64*min(2**self.scale, 16), 101 | kernel_size=4, 102 | stride=2, 103 | padding=1, 104 | bias=False), 105 | Norm(num_features=64*min(2**self.scale, 16)), 106 | Activate()]) 107 | 108 | 109 | class LatentLayer(nn.Module): 110 | def __init__(self, in_channels, out_channels): 111 | super(LatentLayer, self).__init__() 112 | self.in_channels = in_channels 113 | self.out_channels = out_channels 114 | self.build() 115 | 116 | def forward(self, input): 117 | x = input 118 | for layer in self.sub_layers: 119 | x = layer(x) 120 | return x 121 | 122 | def build(self): 123 | self.sub_layers = nn.ModuleList([ 124 | nn.Conv2d( 125 | in_channels=self.in_channels, 126 | out_channels=self.out_channels, 127 | kernel_size=1, 128 | stride=1, 129 | padding=0, 130 | bias=True) 131 | ]) 132 | 133 | 134 | class DecoderLayer(nn.Module): 135 | def __init__(self, scale, in_channels=None, norm='IN'): 136 | super().__init__() 137 | self.scale = scale 138 | self.norm = _norm_options[norm.lower()] 139 | if in_channels is not None: 140 | self.in_channels = in_channels 141 | else: 142 | self.in_channels = 64*min(2**(self.scale+1), 16) 143 | self.build() 144 | 145 | def forward(self, input): 146 | d = input 147 | for layer in self.sub_layers: 148 | d = layer(d) 149 | return d 150 | 151 | def build(self): 
152 | Norm = functools.partial(self.norm, affine=True) 153 | Activate = lambda: nn.LeakyReLU(0.2) 154 | self.sub_layers = nn.ModuleList([ 155 | nn.ConvTranspose2d( 156 | in_channels=self.in_channels, 157 | out_channels=64*min(2**self.scale, 16), 158 | kernel_size=4, 159 | stride=2, 160 | padding=1, 161 | bias=False), 162 | Norm(num_features=64*min(2**self.scale, 16)), 163 | Activate()]) 164 | 165 | 166 | class DenseEncoderLayer(nn.Module): 167 | def __init__(self, scale, spatial_size, out_size, in_channels=None): 168 | super().__init__() 169 | self.scale = scale 170 | self.in_channels = 64*min(2**(self.scale-1), 16) 171 | if in_channels is not None: 172 | self.in_channels = in_channels 173 | self.out_channels = out_size 174 | self.kernel_size = spatial_size 175 | self.build() 176 | 177 | def forward(self, input): 178 | x = input 179 | for layer in self.sub_layers: 180 | x = layer(x) 181 | return x 182 | 183 | def build(self): 184 | self.sub_layers = nn.ModuleList([ 185 | nn.Conv2d( 186 | in_channels=self.in_channels, 187 | out_channels=self.out_channels, 188 | kernel_size=self.kernel_size, 189 | stride=1, 190 | padding=0, 191 | bias=True)]) 192 | 193 | 194 | class DenseDecoderLayer(nn.Module): 195 | def __init__(self, scale, spatial_size, in_size): 196 | super().__init__() 197 | self.scale = scale 198 | self.in_channels = in_size 199 | self.out_channels = 64*min(2**self.scale, 16) 200 | self.kernel_size = spatial_size 201 | self.build() 202 | 203 | def forward(self, input): 204 | x = input 205 | for layer in self.sub_layers: 206 | x = layer(x) 207 | return x 208 | 209 | def build(self): 210 | self.sub_layers = nn.ModuleList([ 211 | nn.ConvTranspose2d( 212 | in_channels=self.in_channels, 213 | out_channels=self.out_channels, 214 | kernel_size=self.kernel_size, 215 | stride=1, 216 | padding=0, 217 | bias=True)]) 218 | 219 | 220 | class ImageLayer(nn.Module): 221 | def __init__(self, out_channels=3, in_channels=64): 222 | super().__init__() 223 | self.in_channels = in_channels 224 | self.out_channels = out_channels 225 | self.build() 226 | 227 | def forward(self, input): 228 | x = input 229 | for layer in self.sub_layers: 230 | x = layer(x) 231 | return x 232 | 233 | def build(self): 234 | FinalActivate = lambda: torch.nn.Tanh() 235 | self.sub_layers = nn.ModuleList([ 236 | nn.ConvTranspose2d( 237 | in_channels=self.in_channels, 238 | out_channels=self.out_channels, 239 | kernel_size=4, 240 | stride=2, 241 | padding=1, 242 | bias=False), 243 | FinalActivate() 244 | ]) 245 | 246 | 247 | class Distribution(object): 248 | def __init__(self, parameters, deterministic=False): 249 | self.parameters = parameters 250 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 251 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0) 252 | self.deterministic = deterministic 253 | self.std = torch.exp(0.5*self.logvar) 254 | self.var = torch.exp(self.logvar) 255 | if self.deterministic: 256 | self.var = self.std = torch.zeros_like(self.mean).to( 257 | torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 258 | 259 | def sample(self): 260 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 261 | x = self.mean + self.std*torch.randn(self.mean.shape).to(device) 262 | return x 263 | 264 | def kl(self, other=None): 265 | if self.deterministic: 266 | return torch.Tensor([0.]) 267 | else: 268 | if other is None: 269 | return 0.5*torch.sum(torch.pow(self.mean, 2) 270 | + self.var - 1.0 - self.logvar, 271 | dim=[1,2,3]) 272 | else: 273 | return 0.5*torch.sum( 274 | 
torch.pow(self.mean - other.mean, 2) / other.var 275 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 276 | dim=[1,2,3]) 277 | 278 | def nll(self, sample): 279 | if self.deterministic: 280 | return torch.Tensor([0.]) 281 | logtwopi = np.log(2.0*np.pi) 282 | return 0.5*torch.sum( 283 | logtwopi+self.logvar+torch.pow(sample-self.mean, 2) / self.var, 284 | dim=[1,2,3]) 285 | 286 | def mode(self): 287 | return self.mean 288 | 289 | 290 | class Model(nn.Module): 291 | def __init__(self, config): 292 | super().__init__() 293 | import torch.backends.cudnn as cudnn 294 | cudnn.benchmark = True 295 | n_down = retrieve(config, "Model/n_down") 296 | z_dim = retrieve(config, "Model/z_dim") 297 | in_size = retrieve(config, "Model/in_size") 298 | bottleneck_size = in_size // 2**n_down 299 | in_channels = retrieve(config, "Model/in_channels") 300 | norm = retrieve(config, "Model/norm") 301 | self.be_deterministic = retrieve(config, "Model/deterministic") 302 | 303 | self.feature_layers = nn.ModuleList() 304 | self.decoder_layers = nn.ModuleList() 305 | 306 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm)) 307 | for scale in range(1, n_down): 308 | self.feature_layers.append(FeatureLayer(scale, norm=norm)) 309 | 310 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, 2*z_dim) 311 | self.dense_decode = DenseDecoderLayer(n_down-1, bottleneck_size, z_dim) 312 | 313 | for scale in range(n_down-1): 314 | self.decoder_layers.append(DecoderLayer(scale, norm=norm)) 315 | self.image_layer = ImageLayer(out_channels=in_channels) 316 | 317 | self.apply(weights_init) 318 | 319 | self.n_down = n_down 320 | self.z_dim = z_dim 321 | self.bottleneck_size = bottleneck_size 322 | 323 | def encode(self, input): 324 | h = input 325 | for layer in self.feature_layers: 326 | h = layer(h) 327 | h = self.dense_encode(h) 328 | return Distribution(h, deterministic=self.be_deterministic) 329 | 330 | def decode(self, input): 331 | h = input 332 | h = self.dense_decode(h) 333 | for layer in reversed(self.decoder_layers): 334 | h = layer(h) 335 | h = self.image_layer(h) 336 | return h 337 | 338 | def get_last_layer(self): 339 | return self.image_layer.sub_layers[0].weight 340 | -------------------------------------------------------------------------------- /iin/models/clf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from edflow.util import retrieve 5 | 6 | from iin.models.ae import FeatureLayer, DenseEncoderLayer, weights_init 7 | 8 | 9 | class Distribution(object): 10 | def __init__(self, value): 11 | self.value = value 12 | 13 | def sample(self): 14 | return self.value 15 | 16 | def mode(self): 17 | return self.value 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, config): 22 | super().__init__() 23 | import torch.backends.cudnn as cudnn 24 | cudnn.benchmark = True 25 | n_down = retrieve(config, "Model/n_down") 26 | z_dim = retrieve(config, "Model/z_dim") 27 | in_size = retrieve(config, "Model/in_size") 28 | z_dim = retrieve(config, "Model/z_dim") 29 | bottleneck_size = in_size // 2**n_down 30 | in_channels = retrieve(config, "Model/in_channels") 31 | norm = retrieve(config, "Model/norm") 32 | n_classes = retrieve(config, "n_classes") 33 | 34 | self.feature_layers = nn.ModuleList() 35 | 36 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm)) 37 | for scale in range(1, n_down): 38 | 
self.feature_layers.append(FeatureLayer(scale, norm=norm)) 39 | 40 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, z_dim) 41 | self.classifier = torch.nn.Linear(z_dim, n_classes) 42 | 43 | self.apply(weights_init) 44 | 45 | self.n_down = n_down 46 | self.bottleneck_size = bottleneck_size 47 | 48 | def forward(self, input): 49 | h = self.encode(input).mode() 50 | assert h.shape[2] == h.shape[3] == 1 51 | h = h[:,:,0,0] 52 | h = self.classifier(h) 53 | return h 54 | 55 | def encode(self, input): 56 | h = input 57 | for layer in self.feature_layers: 58 | h = layer(h) 59 | h = self.dense_encode(h) 60 | return Distribution(h) 61 | -------------------------------------------------------------------------------- /iin/models/iin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from edflow.util import retrieve, get_obj_from_str 5 | 6 | 7 | class Shuffle(nn.Module): 8 | def __init__(self, in_channels, **kwargs): 9 | super(Shuffle, self).__init__() 10 | self.in_channels = in_channels 11 | idx = torch.randperm(in_channels) 12 | self.register_buffer('forward_shuffle_idx', nn.Parameter(idx, requires_grad=False)) 13 | self.register_buffer('backward_shuffle_idx', nn.Parameter(torch.argsort(idx), requires_grad=False)) 14 | 15 | def forward(self, x, reverse=False, conditioning=None): 16 | if not reverse: 17 | return x[:, self.forward_shuffle_idx, ...], 0 18 | else: 19 | return x[:, self.backward_shuffle_idx, ...] 20 | 21 | 22 | class BasicFullyConnectedNet(nn.Module): 23 | def __init__(self, dim, depth, hidden_dim=256, use_tanh=False, use_bn=False, out_dim=None): 24 | super(BasicFullyConnectedNet, self).__init__() 25 | layers = [] 26 | layers.append(nn.Linear(dim, hidden_dim)) 27 | if use_bn: 28 | layers.append(nn.BatchNorm1d(hidden_dim)) 29 | layers.append(nn.LeakyReLU()) 30 | for d in range(depth): 31 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 32 | if use_bn: 33 | layers.append(nn.BatchNorm1d(hidden_dim)) 34 | layers.append(nn.LeakyReLU()) 35 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim)) 36 | if use_tanh: 37 | layers.append(nn.Tanh()) 38 | self.main = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | return self.main(x) 42 | 43 | 44 | class DoubleVectorCouplingBlock(nn.Module): 45 | """In contrast to VectorCouplingBlock, this module assures alternating chunking in upper and lower half.""" 46 | def __init__(self, in_channels, hidden_dim, depth=2, use_hidden_bn=False, n_blocks=2): 47 | super(DoubleVectorCouplingBlock, self).__init__() 48 | assert in_channels % 2 == 0 49 | self.s = nn.ModuleList([BasicFullyConnectedNet(dim=in_channels // 2, depth=depth, hidden_dim=hidden_dim, 50 | use_tanh=True) for _ in range(n_blocks)]) 51 | self.t = nn.ModuleList([BasicFullyConnectedNet(dim=in_channels // 2, depth=depth, hidden_dim=hidden_dim, 52 | use_tanh=False) for _ in range(n_blocks)]) 53 | 54 | def forward(self, x, reverse=False): 55 | if not reverse: 56 | logdet = 0 57 | for i in range(len(self.s)): 58 | idx_apply, idx_keep = 0, 1 59 | if i % 2 != 0: 60 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 61 | x = torch.chunk(x, 2, dim=1) 62 | scale = self.s[i](x[idx_apply]) 63 | x_ = x[idx_keep] * (scale.exp()) + self.t[i](x[idx_apply]) 64 | x = torch.cat((x[idx_apply], x_), dim=1) 65 | logdet_ = torch.sum(scale.view(x.size(0), -1), dim=1) 66 | logdet = logdet + logdet_ 67 | return x, logdet 68 | else: 69 | idx_apply, idx_keep = 0, 1 70 | for i 
in reversed(range(len(self.s))): 71 | if i % 2 == 0: 72 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 73 | x = torch.chunk(x, 2, dim=1) 74 | x_ = (x[idx_keep] - self.t[i](x[idx_apply])) * (self.s[i](x[idx_apply]).neg().exp()) 75 | x = torch.cat((x[idx_apply], x_), dim=1) 76 | return x 77 | 78 | 79 | class VectorActNorm(nn.Module): 80 | def __init__(self, in_channel, logdet=True, **kwargs): 81 | super().__init__() 82 | self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) 83 | self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) 84 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 85 | self.logdet = logdet 86 | 87 | def initialize(self, input): 88 | if len(input.shape) == 2: 89 | input = input[:, :, None, None] 90 | with torch.no_grad(): 91 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 92 | mean = ( 93 | flatten.mean(1) 94 | .unsqueeze(1) 95 | .unsqueeze(2) 96 | .unsqueeze(3) 97 | .permute(1, 0, 2, 3) 98 | ) 99 | std = ( 100 | flatten.std(1) 101 | .unsqueeze(1) 102 | .unsqueeze(2) 103 | .unsqueeze(3) 104 | .permute(1, 0, 2, 3) 105 | ) 106 | 107 | self.loc.data.copy_(-mean) 108 | self.scale.data.copy_(1 / (std + 1e-6)) 109 | 110 | def forward(self, input, reverse=False, conditioning=None): 111 | if len(input.shape) == 2: 112 | input = input[:, :, None, None] 113 | if not reverse: 114 | _, _, height, width = input.shape 115 | if self.initialized.item() == 0: 116 | self.initialize(input) 117 | self.initialized.fill_(1) 118 | log_abs = torch.log(torch.abs(self.scale)) 119 | logdet = height * width * torch.sum(log_abs) 120 | logdet = logdet * torch.ones(input.shape[0]).to(input) 121 | if not self.logdet: 122 | return (self.scale * (input + self.loc)).squeeze() 123 | return (self.scale * (input + self.loc)).squeeze(), logdet 124 | else: 125 | return self.reverse(input) 126 | 127 | def reverse(self, output, conditioning=None): 128 | return (output / self.scale - self.loc).squeeze() 129 | 130 | 131 | class Flow(nn.Module): 132 | def __init__(self, module_list, in_channels, hidden_dim, hidden_depth): 133 | super(Flow, self).__init__() 134 | self.in_channels = in_channels 135 | self.flow = nn.ModuleList( 136 | [module(in_channels, hidden_dim=hidden_dim, depth=hidden_depth) for module in module_list]) 137 | 138 | def forward(self, x, condition=None, reverse=False): 139 | if not reverse: 140 | logdet = 0 141 | for i in range(len(self.flow)): 142 | x, logdet_ = self.flow[i](x) 143 | logdet = logdet + logdet_ 144 | return x, logdet 145 | else: 146 | for i in reversed(range(len(self.flow))): 147 | x = self.flow[i](x, reverse=True) 148 | return x 149 | 150 | 151 | class EfficientVRNVP(nn.Module): 152 | def __init__(self, module_list, in_channels, n_flow, hidden_dim, hidden_depth): 153 | super().__init__() 154 | assert in_channels % 2 == 0 155 | self.flow = nn.ModuleList([Flow(module_list, in_channels, hidden_dim, hidden_depth) for n in range(n_flow)]) 156 | 157 | def forward(self, x, condition=None, reverse=False): 158 | if not reverse: 159 | logdet = 0 160 | for i in range(len(self.flow)): 161 | x, logdet_ = self.flow[i](x, condition=condition) 162 | logdet = logdet + logdet_ 163 | return x, logdet 164 | else: 165 | for i in reversed(range(len(self.flow))): 166 | x = self.flow[i](x, condition=condition, reverse=True) 167 | return x, None 168 | 169 | def reverse(self, x, condition=None): 170 | return self.flow(x, condition=condition, reverse=True) 171 | 172 | 173 | class VectorTransformer(nn.Module): 174 | def __init__(self, config): 
175 | super().__init__() 176 | import torch.backends.cudnn as cudnn 177 | cudnn.benchmark = True 178 | self.config = config 179 | 180 | self.in_channel = retrieve(config, "Transformer/in_channel") 181 | self.n_flow = retrieve(config, "Transformer/n_flow") 182 | self.depth_submodules = retrieve(config, "Transformer/hidden_depth") 183 | self.hidden_dim = retrieve(config, "Transformer/hidden_dim") 184 | modules = [VectorActNorm, DoubleVectorCouplingBlock, Shuffle] 185 | self.realnvp = EfficientVRNVP(modules, self.in_channel, self.n_flow, self.hidden_dim, 186 | hidden_depth=self.depth_submodules) 187 | 188 | def forward(self, input, reverse=False): 189 | if reverse: 190 | return self.reverse(input) 191 | input = input.squeeze() 192 | out, logdet = self.realnvp(input) 193 | return out[:, :, None, None], logdet 194 | 195 | def reverse(self, out): 196 | out = out.squeeze() 197 | return self.realnvp(out, reverse=True)[0][:, :, None, None] 198 | 199 | 200 | class FactorTransformer(VectorTransformer): 201 | def __init__(self, config): 202 | super().__init__(config) 203 | self.n_factors = retrieve(config, "Transformer/n_factors", default=2) 204 | self.factor_config = retrieve(config, "Transformer/factor_config", default=list()) 205 | 206 | def forward(self, input): 207 | out, logdet = super().forward(input) 208 | if self.factor_config: 209 | out = torch.split(out, self.factor_config, dim=1) 210 | else: 211 | out = torch.chunk(out, self.n_factors, dim=1) 212 | return out, logdet 213 | 214 | def reverse(self, out): 215 | out = torch.cat(out, dim=1) 216 | return super().reverse(out) 217 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='iin', 4 | version='0.1', 5 | description='Code accompanying the paper ' 6 | 'A Disentangling Invertible Interpretation Network for Explaining Latent Representations' 7 | 'https://arxiv.org/abs/2004.13166', 8 | author='Esser, Patrick and Rombach, Robin and Ommer, Bjoern ', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'torch>=1.4.0', 12 | 'torchvision>=0.5', 13 | 'numpy>=1.17', 14 | 'scipy>=1.0.1' 15 | ], 16 | zip_safe=False) 17 | --------------------------------------------------------------------------------
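
A minimal usage sketch (not one of the repository files above) for the invertible transformer defined in `iin/models/iin.py`. It assumes the `iin` package and PyTorch are installed; the config values below are made up for illustration and would normally come from the corresponding `configs/*_iin.yaml`. The snippet pushes a random first-stage latent through `VectorTransformer` and checks that `reverse` recovers it, the invertibility that the swap and cross visualizations in `iin/iterators/iin.py` rely on.

```
import torch
from iin.models.iin import VectorTransformer

# Hypothetical config for illustration only; real values live in configs/*_iin.yaml.
config = {
    "Transformer": {
        "in_channel": 64,   # must be even: the coupling blocks chunk channels in half
        "n_flow": 4,
        "hidden_depth": 2,
        "hidden_dim": 256,
    }
}
model = VectorTransformer(config).eval()

# Fake batch of first-stage latents, shaped like Model.encode(x).sample() output.
z = torch.randn(8, 64, 1, 1)
with torch.no_grad():
    zz, logdet = model(z)      # forward: transformed latent and per-sample log-determinant
    z_rec = model.reverse(zz)  # inverse pass

print(zz.shape, logdet.shape)               # torch.Size([8, 64, 1, 1]) torch.Size([8])
print(torch.allclose(z, z_rec, atol=1e-4))  # True up to numerical error
```

`FactorTransformer` works the same way, except that its `forward` additionally splits the transformed latent into per-factor chunks (controlled by `Transformer/n_factors` or `Transformer/factor_config`) and its `reverse` concatenates them again before inverting the flow.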