├── .gitignore
├── LICENSE.md
├── README.md
├── assets
│   └── mnist.gif
├── configs
│   ├── celeba_ae.yaml
│   ├── celeba_diin.yaml
│   ├── celeba_iin.yaml
│   ├── cifar_ae.yaml
│   ├── cifar_iin.yaml
│   ├── cmnist_ae.yaml
│   ├── cmnist_clf.yaml
│   ├── cmnist_clf_diin.yaml
│   ├── cmnist_diin.yaml
│   ├── cmnist_dimeval.yaml
│   ├── fashionmnist_ae.yaml
│   ├── fashionmnist_iin.yaml
│   ├── mnist_ae.yaml
│   ├── mnist_iin.yaml
│   └── project.yaml
├── environment.yaml
├── iin
│   ├── data.py
│   ├── iterators
│   │   ├── ae.py
│   │   ├── base.py
│   │   ├── clf.py
│   │   └── iin.py
│   ├── losses
│   │   ├── ae.py
│   │   └── iin.py
│   └── models
│       ├── ae.py
│       ├── clf.py
│       └── iin.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /fid_stats
2 | /logs
3 | /wandb
4 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Patrick Esser and Robin Rombach
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Disentangling Invertible Interpretation Network for Explaining Latent Representations
2 |
3 | PyTorch code accompanying the [CVPR 2020](http://cvpr2020.thecvf.com/) paper
4 |
5 | [**A Disentangling Invertible Interpretation Network for Explaining Latent Representations**](https://compvis.github.io/iin/)
6 | [Patrick Esser](https://github.com/pesser)\*,
7 | [Robin Rombach](https://github.com/rromb)\*,
8 | [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
9 | \* equal contribution
10 |
11 | ![teaser](assets/mnist.gif)
12 | [arXiv](https://arxiv.org/abs/2004.13166) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/iin/)
13 |
14 |
15 | Table of Contents
16 | =================
17 |
18 | * [Requirements](#requirements)
19 | * [Data](#data)
20 | * [Training](#training)
21 | * [Autoencoders](#autoencoders)
22 | * [Classifiers](#classifiers)
23 | * [Invertible Interpretation Networks](#invertible-interpretation-networks)
24 | * [Unsupervised](#unsupervised-on-ae)
25 | * [Supervised](#supervised)
26 | * [Evaluation](#evaluation)
27 | * [Pretrained Models](#pretrained-models)
28 | * [Results](#results)
29 | * [BibTeX](#bibtex)
30 |
31 |
32 | ## Requirements
33 | A suitable [conda](https://conda.io/) environment named `iin` can be created
34 | and activated with:
35 |
36 | ```
37 | conda env create -f environment.yaml
38 | conda activate iin
39 | ```
40 |
41 | Optionally, you can then also `conda install tensorflow-gpu=1.14` to speed up
42 | FID evaluations.
43 |
44 |
45 | ## Data
46 | `MNIST`, `FashionMNIST` and `CIFAR10` will be downloaded automatically the
47 | first time they are used and `CelebA` will prompt you to download it. The
48 | content of each dataset can be visualized with
49 |
50 | ```
51 | edexplore --dataset iin.data.<dataset>
52 | ```
53 |
54 | where `<dataset>` is one of `MNISTTrain`, `MNISTTest`, `FashionMNISTTrain`,
55 | `FashionMNISTTest`, `CIFAR10Train`, `CIFAR10Test`, `CelebATrain`, `CelebATest`,
56 | `FactorCelebATrain`, `FactorCelebATest`, `ColorfulMNISTTrain`,
57 | `ColorfulMNISTTest`, `SingleColorfulMNISTTrain`, `SingleColorfulMNISTTest`.
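
These class paths are the same ones the training configs reference in their
`datasets` section, e.g. (mirroring the dataset block of `configs/mnist_ae.yaml`):

```
datasets:
  train: iin.data.MNISTTrain
  validation: iin.data.MNISTTest
```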
58 |
59 |
60 | ## Training
61 |
62 | ### Autoencoders
63 | To train autoencoders, run
64 |
65 | ```
66 | edflow -b configs/<dataset>_ae.yaml -t
67 | ```
68 |
69 | where `<dataset>` is one of `mnist`, `fashionmnist`, `cifar`, `celeba`,
70 | `cmnist`. To enable logging to [wandb](https://wandb.ai), adjust
71 | `configs/project.yaml` and add it to the above command:
72 |
73 | ```
74 | edflow -b configs/<dataset>_ae.yaml configs/project.yaml -t
75 | ```
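
The shipped `configs/project.yaml` points at the authors' wandb entity, so a
minimal sketch of the adjustment (the entity value below is a placeholder to
replace with your own wandb username or team) looks like:

```
integrations:
  wandb:
    active: True
    entity: <your-wandb-entity>   # placeholder: use your own wandb entity
    project: iin
```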
76 |
77 | ### Classifiers
78 | To train a classifier on `ColorfulMNIST`, run
79 |
80 | ```
81 | edflow -b configs/cmnist_clf.yaml -t
82 | ```
83 |
84 | Once you have a checkpoint, you can estimate factor dimensionalities using
85 |
86 | ```
87 | edflow -b configs/cmnist_clf.yaml configs/cmnist_dimeval.yaml -c
88 | ```
89 |
90 | For the pretrained classifier, this gives
91 |
92 | ```
93 | [INFO] [dim_callback]: estimated factor dimensionalities: [22, 11, 31]
94 | ```
95 |
96 | and to compare this to an autoencoder, run
97 |
98 | ```
99 | edflow -b configs/cmnist_ae.yaml configs/cmnist_dimeval.yaml -c
100 | ```
101 |
102 | which gives
103 |
104 | ```
105 | [INFO] [dim_callback]: estimated factor dimensionalities: [13, 17, 34]
106 | ```
107 |
108 | ### Invertible Interpretation Networks
109 | #### Unsupervised on AE
110 | To train unsupervised invertible interpretation networks, run
111 |
112 | ```
113 | edflow -b configs/<dataset>_iin.yaml [configs/project.yaml] -t
114 | ```
115 |
116 | where `<dataset>` is one of `mnist`, `fashionmnist`, `cifar`, `celeba`. If,
117 | instead of using one of the [pretrained models](#pretrained-models), you
118 | trained an autoencoder yourself, adjust the `first_stage` config section
119 | accordingly (see the sketch below).
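
As a minimal sketch, assuming you trained the MNIST autoencoder yourself, the
`first_stage` section of `configs/mnist_iin.yaml` would point at your own run's
checkpoint (the path below is a placeholder; the remaining entries mirror the
autoencoder config):

```
first_stage:
  checkpoint: logs/<your_mnist_ae_run>/train/checkpoints/<step>.ckpt  # placeholder path
  model: iin.models.ae.Model
  subconfig:
    Model:
      deterministic: false
      in_channels: 1
      in_size: 32
      n_down: 4
      norm: an
      z_dim: 64
```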
120 |
121 | #### Supervised
122 | For supervised, disentangling IINs, run
123 |
124 | ```
125 | edflow -b configs/<dataset>_diin.yaml [configs/project.yaml] -t
126 | ```
127 |
128 | where `<dataset>` is one of `cmnist` or `celeba`, or run
129 |
130 | ```
131 | edflow -b configs/cmnist_clf_diin.yaml [configs/project.yaml] -t
132 | ```
133 |
134 | to train a dIIN on top of a classifier, with factor dimensionalities as
135 | estimated above (the dimensionality of each factor can be adjusted via the
136 | `Transformer/factor_config` configuration entry, as shown below).
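
For reference, this is how the entry appears in `configs/cmnist_clf_diin.yaml`;
to use your own estimates, replace the list entries (the values shown are the
ones printed above and sum to the classifier's `z_dim` of 64):

```
Transformer:
  n_factors: 3
  in_channel: 64
  n_flow: 12
  hidden_depth: 2
  hidden_dim: 512
  factor_config:  # one entry per factor; 22 + 11 + 31 = 64
    - 22
    - 11
    - 31
```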
137 |
138 |
139 | ## Evaluation
140 |
141 | Evaluations run automatically after each epoch of training. To start an
142 | evaluation manually, run
143 |
144 | ```
145 | edflow -p logs/<log_folder>/configs/<config>.yaml
146 | ```
147 |
148 | and, optionally, add `-c <path to checkpoint>` to evaluate a specific
149 | checkpoint instead of the last one.
150 |
151 |
152 | ## Pretrained Models
153 | Download [`logs.tar.gz`](https://heibox.uni-heidelberg.de/f/0c76b38bf4274448b8e9/)
154 | (~2.2 GB) and extract the pretrained models:
155 |
156 | ```
157 | tar xzf logs.tar.gz
158 | ```
159 |
160 |
161 | ## Results
162 | Using spectral normalization for the discriminator, this code slightly improves
163 | upon the values reported in Tab. 2 of the paper.
164 |
165 | | Dataset | Checkpoint | FID |
166 | |--------------|------------|--------|
167 | | MNIST | 105600 | 5.252 |
168 | | FashionMNIST | 110400 | 9.663 |
169 | | CelebA | 84643 | 19.839 |
170 | | CIFAR10 | 32000 | 38.697 |
171 |
172 | Full training logs can be found on [Weights &
173 | Biases](https://app.wandb.ai/trex/iin/reportlist).
174 |
175 |
176 | ## BibTeX
177 |
178 | ```
179 | @inproceedings{esser2020invertible,
180 | title={A Disentangling Invertible Interpretation Network for Explaining Latent Representations},
181 | author={Esser, Patrick and Rombach, Robin and Ommer, Bj{\"o}rn},
182 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
183 | year={2020}
184 | }
185 | ```
186 |
--------------------------------------------------------------------------------
/assets/mnist.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/iin/fa0d2d1cfa6fc0b813b902cd3474145a045fbb34/assets/mnist.gif
--------------------------------------------------------------------------------
/configs/celeba_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: celeba_ae
2 |
3 | datasets:
4 | train: iin.data.CelebATrain
5 | validation: iin.data.CelebATest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/celeba.npz"
11 |
12 | model: iin.models.ae.Model
13 | Model:
14 | deterministic: false
15 | in_channels: 3
16 | in_size: 64
17 | n_down: 4
18 | norm: an
19 | z_dim: 64
20 |
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 | calibrate: true
24 | disc_in_channels: 3
25 | disc_start: 100001
26 | logvar_init: 0.0
27 | perceptual_weight: 1.0
28 | d_lr_factor: 4.0
29 | spectral_norm: True
30 |
31 |
32 | iterator: iin.iterators.ae.Trainer
33 | base_learning_rate: 4.5e-06
34 | decay_start: 505001
35 | batch_size: 25
36 | log_freq: 1000
37 | num_epochs: 100
38 |
--------------------------------------------------------------------------------
/configs/celeba_diin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: celeba_diin
2 |
3 | datasets:
4 | train: iin.data.FactorCelebATrain
5 | validation: iin.data.FactorCelebATest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/celeba.npz"
11 |
12 |
13 | first_stage:
14 | checkpoint: logs/2020-04-23T23-43-25_celeba_ae/train/checkpoints/model-572968.ckpt
15 | model: iin.models.ae.Model
16 | subconfig:
17 | Model:
18 | deterministic: false
19 | in_channels: 3
20 | in_size: 64
21 | n_down: 4
22 | norm: an
23 | z_dim: 64
24 |
25 | model: iin.models.iin.FactorTransformer
26 | Transformer:
27 | n_factors: 5
28 | in_channel: 64
29 | n_flow: 12
30 | hidden_depth: 2
31 | hidden_dim: 512
32 |
33 |
34 | loss: iin.losses.iin.FactorLoss
35 | Loss:
36 | rho: 0.975
37 |
38 | iterator: iin.iterators.iin.FactorTrainer
39 | base_learning_rate: 4.5e-06
40 | decay_start: 505001
41 | batch_size: 25
42 | log_freq: 1000
43 | num_epochs: 100
44 |
--------------------------------------------------------------------------------
/configs/celeba_iin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: celeba_iin
2 |
3 | datasets:
4 | train: iin.data.CelebATrain
5 | validation: iin.data.CelebATest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/celeba.npz"
11 |
12 | first_stage:
13 | checkpoint: logs/2020-04-23T23-43-25_celeba_ae/train/checkpoints/model-572968.ckpt
14 | model: iin.models.ae.Model
15 | subconfig:
16 | Model:
17 | deterministic: false
18 | in_channels: 3
19 | in_size: 64
20 | n_down: 4
21 | norm: an
22 | z_dim: 64
23 |
24 | model: iin.models.iin.VectorTransformer
25 | Transformer:
26 | in_channel: 64
27 | n_flow: 12
28 | hidden_depth: 2
29 | hidden_dim: 512
30 |
31 |
32 | loss: iin.losses.iin.Loss
33 |
34 | iterator: iin.iterators.iin.Trainer
35 | base_learning_rate: 4.5e-06
36 | decay_start: 505001
37 | batch_size: 25
38 | log_freq: 1000
39 | num_epochs: 100
40 |
--------------------------------------------------------------------------------
/configs/cifar_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cifar_ae
2 |
3 | datasets:
4 | train: iin.data.CIFAR10Train
5 | validation: iin.data.CIFAR10Test
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/cifar.npz"
11 |
12 | model: iin.models.ae.Model
13 | Model:
14 | deterministic: false
15 | in_channels: 3
16 | in_size: 32
17 | n_down: 4
18 | norm: an
19 | z_dim: 64
20 |
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 | calibrate: true
24 | disc_in_channels: 3
25 | disc_start: 100001
26 | logvar_init: 0.0
27 | perceptual_weight: 1.0
28 | d_lr_factor: 4.0
29 | spectral_norm: True
30 |
31 | iterator: iin.iterators.ae.Trainer
32 | base_learning_rate: 4.5e-06
33 | batch_size: 25
34 | log_freq: 1000
35 | num_epochs: 300
36 | decay_start: 505001
37 |
--------------------------------------------------------------------------------
/configs/cifar_iin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cifar_iin
2 |
3 | datasets:
4 | train: iin.data.CIFAR10Train
5 | validation: iin.data.CIFAR10Test
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/cifar.npz"
11 |
12 |
13 | first_stage:
14 | checkpoint: logs/2020-04-24T09-47-43_cifar_ae/train/checkpoints/model-534000.ckpt
15 | model: iin.models.ae.Model
16 | subconfig:
17 | Model:
18 | deterministic: false
19 | in_channels: 3
20 | in_size: 32
21 | n_down: 4
22 | norm: an
23 | z_dim: 64
24 |
25 | model: iin.models.iin.VectorTransformer
26 | Transformer:
27 | in_channel: 64
28 | n_flow: 12
29 | hidden_depth: 2
30 | hidden_dim: 512
31 |
32 |
33 | loss: iin.losses.iin.Loss
34 |
35 | iterator: iin.iterators.iin.Trainer
36 | base_learning_rate: 4.5e-06
37 | batch_size: 25
38 | log_freq: 1000
39 | num_epochs: 300
40 | decay_start: 505001
41 |
--------------------------------------------------------------------------------
/configs/cmnist_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cmnist_ae
2 |
3 | datasets:
4 | train: iin.data.SingleColorfulMNISTTrain
5 | validation: iin.data.SingleColorfulMNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/cmnist.npz"
11 |
12 | model: iin.models.ae.Model
13 | Model:
14 | deterministic: false
15 | in_channels: 3
16 | in_size: 32
17 | n_down: 4
18 | norm: an
19 | z_dim: 64
20 |
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 | calibrate: true
24 | disc_in_channels: 3
25 | disc_start: 25001
26 | logvar_init: 0.0
27 | perceptual_weight: 1.0
28 | discriminator_weight: 150.0
29 |
30 | iterator: iin.iterators.ae.Trainer
31 | base_learning_rate: 4.5e-06
32 | batch_size: 25
33 | log_freq: 1000
34 | num_epochs: 50
35 | decay_start: 100001
36 |
--------------------------------------------------------------------------------
/configs/cmnist_clf.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cmnist_clf
2 |
3 | datasets:
4 | train: iin.data.SingleColorfulMNISTTrain
5 | validation: iin.data.SingleColorfulMNISTTest
6 | n_classes: 10
7 |
8 | model: iin.models.clf.Model
9 | Model:
10 | in_channels: 3
11 | in_size: 32
12 | n_down: 4
13 | norm: an
14 | z_dim: 64
15 |
16 | iterator: iin.iterators.clf.Trainer
17 | base_learning_rate: 4.5e-06
18 | batch_size: 25
19 | log_freq: 1000
20 | num_epochs: 50
21 | decay_start: 100001
22 |
--------------------------------------------------------------------------------
/configs/cmnist_clf_diin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cmnist_diin
2 |
3 | datasets:
4 | train: iin.data.ColorfulMNISTTrain
5 | validation: iin.data.ColorfulMNISTTest
6 |
7 | first_stage:
8 | checkpoint: logs/2020-07-23T13-19-33_cmnist_clf/train/checkpoints/model-120000.ckpt
9 | model: iin.models.clf.Model
10 | subconfig:
11 | n_classes: 10
12 | Model:
13 | in_channels: 3
14 | in_size: 32
15 | n_down: 4
16 | norm: an
17 | z_dim: 64
18 |
19 | model: iin.models.iin.FactorTransformer
20 | Transformer:
21 | n_factors: 3
22 | in_channel: 64
23 | n_flow: 12
24 | hidden_depth: 2
25 | hidden_dim: 512
26 | factor_config:
27 | - 22
28 | - 11
29 | - 31
30 |
31 | loss: iin.losses.iin.FactorLoss
32 | Loss:
33 | rho: 0.975
34 |
35 | iterator: iin.iterators.iin.FactorTrainer
36 | base_learning_rate: 4.5e-06
37 | batch_size: 25
38 | log_freq: 1000
39 | num_epochs: 50
40 | decay_start: 100001
41 |
--------------------------------------------------------------------------------
/configs/cmnist_diin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cmnist_diin
2 |
3 | datasets:
4 | train: iin.data.ColorfulMNISTTrain
5 | validation: iin.data.ColorfulMNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/cmnist.npz"
11 |
12 | first_stage:
13 | checkpoint: logs/2020-04-27T22-55-44_cmnist_ae/train/checkpoints/model-120000.ckpt
14 | model: iin.models.ae.Model
15 | subconfig:
16 | Model:
17 | deterministic: false
18 | in_channels: 3
19 | in_size: 32
20 | n_down: 4
21 | norm: an
22 | z_dim: 64
23 |
24 | model: iin.models.iin.FactorTransformer
25 | Transformer:
26 | n_factors: 3
27 | in_channel: 64
28 | n_flow: 12
29 | hidden_depth: 2
30 | hidden_dim: 512
31 |
32 |
33 | loss: iin.losses.iin.FactorLoss
34 | Loss:
35 | rho: 0.975
36 |
37 | iterator: iin.iterators.iin.FactorTrainer
38 | base_learning_rate: 4.5e-06
39 | batch_size: 25
40 | log_freq: 1000
41 | num_epochs: 50
42 | decay_start: 100001
43 |
44 |
--------------------------------------------------------------------------------
/configs/cmnist_dimeval.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: cmnist_dimeval
2 |
3 | datasets:
4 | train: iin.data.ColorfulMNISTTrain
5 | validation: iin.data.ColorfulMNISTTest
6 | n_classes: 10
7 |
8 | iterator: iin.iterators.base.DimEvaluator
9 | batch_size: 25
10 | num_steps: 1
11 |
--------------------------------------------------------------------------------
/configs/fashionmnist_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: fashionmnist_ae
2 |
3 | datasets:
4 | train: iin.data.FashionMNISTTrain
5 | validation: iin.data.FashionMNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/fashionmnist.npz"
11 |
12 | model: iin.models.ae.Model
13 | Model:
14 | deterministic: false
15 | in_channels: 1
16 | in_size: 32
17 | n_down: 4
18 | norm: an
19 | z_dim: 64
20 |
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 | calibrate: true
24 | disc_in_channels: 1
25 | disc_start: 25001
26 | logvar_init: 0.0
27 | perceptual_weight: 1.0
28 |
29 | iterator: iin.iterators.ae.Trainer
30 | base_learning_rate: 4.5e-06
31 | batch_size: 25
32 | log_freq: 1000
33 | num_epochs: 50
34 | decay_start: 100001
35 |
--------------------------------------------------------------------------------
/configs/fashionmnist_iin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: fashionmnist_iin
2 |
3 | datasets:
4 | train: iin.data.FashionMNISTTrain
5 | validation: iin.data.FashionMNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/fashionmnist.npz"
11 |
12 | first_stage:
13 | checkpoint: logs/2020-04-23T00-17-38_fashionmnist_ae/train/checkpoints/model-120000.ckpt
14 | model: iin.models.ae.Model
15 | subconfig:
16 | Model:
17 | deterministic: false
18 | in_channels: 1
19 | in_size: 32
20 | n_down: 4
21 | norm: an
22 | z_dim: 64
23 |
24 | model: iin.models.iin.VectorTransformer
25 | Transformer:
26 | in_channel: 64
27 | n_flow: 12
28 | hidden_depth: 2
29 | hidden_dim: 512
30 |
31 |
32 | loss: iin.losses.iin.Loss
33 |
34 | iterator: iin.iterators.iin.Trainer
35 | base_learning_rate: 4.5e-06
36 | batch_size: 25
37 | log_freq: 1000
38 | num_epochs: 50
39 | decay_start: 100001
40 |
--------------------------------------------------------------------------------
/configs/mnist_ae.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: mnist_ae
2 |
3 | datasets:
4 | train: iin.data.MNISTTrain
5 | validation: iin.data.MNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/mnist.npz"
11 |
12 | model: iin.models.ae.Model
13 | Model:
14 | deterministic: false
15 | in_channels: 1
16 | in_size: 32
17 | n_down: 4
18 | norm: an
19 | z_dim: 64
20 |
21 | loss: iin.losses.ae.LossPND
22 | Loss:
23 | calibrate: true
24 | disc_in_channels: 1
25 | disc_start: 25001
26 | logvar_init: 0.0
27 | perceptual_weight: 1.0
28 | discriminator_weight: 150.0
29 |
30 | iterator: iin.iterators.ae.Trainer
31 | base_learning_rate: 4.5e-06
32 | batch_size: 25
33 | log_freq: 1000
34 | num_epochs: 50
35 | decay_start: 100001
36 |
--------------------------------------------------------------------------------
/configs/mnist_iin.yaml:
--------------------------------------------------------------------------------
1 | experiment_name: mnist_iin
2 |
3 | datasets:
4 | train: iin.data.MNISTTrain
5 | validation: iin.data.MNISTTest
6 |
7 | fid:
8 | batch_size: 50
9 | fid_stats:
10 | pre_calc_stat_path: "fid_stats/mnist.npz"
11 |
12 | first_stage:
13 | checkpoint: logs/2020-04-23T00-17-34_mnist_ae/train/checkpoints/model-120000.ckpt
14 | model: iin.models.ae.Model
15 | subconfig:
16 | Model:
17 | deterministic: false
18 | in_channels: 1
19 | in_size: 32
20 | n_down: 4
21 | norm: an
22 | z_dim: 64
23 |
24 | model: iin.models.iin.VectorTransformer
25 | Transformer:
26 | in_channel: 64
27 | n_flow: 12
28 | hidden_depth: 2
29 | hidden_dim: 512
30 |
31 |
32 | loss: iin.losses.iin.Loss
33 |
34 | iterator: iin.iterators.iin.Trainer
35 | base_learning_rate: 4.5e-06
36 | batch_size: 25
37 | log_freq: 1000
38 | num_epochs: 50
39 | decay_start: 100001
40 |
--------------------------------------------------------------------------------
/configs/project.yaml:
--------------------------------------------------------------------------------
1 | integrations:
2 | wandb:
3 | active: True
4 | entity: trex
5 | project: iin
6 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: iin
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.7
7 | - pip=19.3
8 | - cudatoolkit=10.1
9 | - pytorch=1.4
10 | - torchvision=0.5
11 | - numpy=1.17
12 | - pip:
13 | - git+https://github.com/pesser/edflow.git@711a728c7afc50a88017bc37213fbe485fc13efe#egg=edflow[full]
14 | - -e git+https://github.com/jhaux/PerceptualSimilarity.git@f752859b104702e52abbe49750c595b4fb980382#egg=perceptual_similarity
15 | - -e git+https://github.com/rromb/fid.git@e6deef55cdb56944a540c3d74739438111292ec8#egg=fid-callback
16 | - albumentations==0.4.3
17 | - opencv-python==4.1.2.30
18 | - pudb==2019.2
19 |
--------------------------------------------------------------------------------
/iin/data.py:
--------------------------------------------------------------------------------
1 | import albumentations
2 | import numpy as np
3 | from edflow.util import retrieve
4 | from edflow.data.dataset import PRNGMixin, DatasetMixin, SubDataset
5 | from edflow.datasets.mnist import (
6 | MNISTTrain as _MNISTTrain,
7 | MNISTTest as _MNISTTest)
8 | from edflow.datasets.fashionmnist import (
9 | FashionMNISTTrain as _FashionMNISTTrain,
10 | FashionMNISTTest as _FashionMNISTTest)
11 | from edflow.datasets.cifar import (
12 | CIFAR10Train as _CIFAR10Train,
13 | CIFAR10Test as _CIFAR10Test)
14 | from edflow.datasets.celeba import (
15 | CelebATrain as _CelebATrain,
16 | CelebATest as _CelebATest)
17 |
18 |
19 | class Base32(DatasetMixin, PRNGMixin):
20 | """Add support for resizing, cropping and dequantization."""
21 | def __init__(self, config):
22 | self.data = self.get_base_data(config)
23 | self.size = retrieve(config, "spatial_size", default=32)
24 |
25 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
26 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
27 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
28 |
29 | def preprocess_example(self, example):
30 | example["image"] = ((example["image"]+1)*127.5).astype(np.uint8)
31 | example["image"] = self.preprocessor(image=example["image"])["image"]
32 | example["image"] = (example["image"] + self.prng.random())/256. # dequantization
33 | example["image"] = (example["image"]*2.0-1.0).astype(np.float32)
34 | return example
35 |
36 | def get_example(self, i):
37 | example = super().get_example(i)
38 | example = self.preprocess_example(example)
39 | return example
40 |
41 |
42 | class MNISTTrain(Base32):
43 | def get_base_data(self, config):
44 | return _MNISTTrain(config)
45 |
46 |
47 | class MNISTTest(Base32):
48 | def get_base_data(self, config):
49 | return _MNISTTest(config)
50 |
51 |
52 | class FashionMNISTTrain(Base32):
53 | def get_base_data(self, config):
54 | return _FashionMNISTTrain(config)
55 |
56 |
57 | class FashionMNISTTest(Base32):
58 | def get_base_data(self, config):
59 | return _FashionMNISTTest(config)
60 |
61 |
62 | class CIFAR10Train(Base32):
63 | def get_base_data(self, config):
64 | return _CIFAR10Train(config)
65 |
66 |
67 | class CIFAR10Test(Base32):
68 | def get_base_data(self, config):
69 | return _CIFAR10Test(config)
70 |
71 |
72 | class BaseCelebA(DatasetMixin, PRNGMixin):
73 | """Add support for resizing, cropping and dequantization."""
74 | def __init__(self, config):
75 | self.data = self.get_base_data(config)
76 | self.size = retrieve(config, "spatial_size", default=64)
77 | self.attribute_descriptions = [
78 | "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive",
79 | "Bags_Under_Eyes", "Bald", "Bangs", "Big_Lips", "Big_Nose",
80 | "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair",
81 | "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses",
82 | "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
83 | "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes",
84 | "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose",
85 | "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling",
86 | "Straight_Hair", "Wavy_Hair", "Wearing_Earrings",
87 | "Wearing_Hat", "Wearing_Lipstick", "Wearing_Necklace",
88 | "Wearing_Necktie", "Young"]
89 |
90 | self.cropper = albumentations.CenterCrop(height=160,width=160)
91 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
92 | self.preprocessor = albumentations.Compose([self.cropper, self.rescaler])
93 |
94 | def preprocess_example(self, example):
95 | example["image"] = ((example["image"]+1)*127.5).astype(np.uint8)
96 | example["image"] = self.preprocessor(image=example["image"])["image"]
97 | example["image"] = (example["image"] + self.prng.random())/256. # dequantization
98 | example["image"] = (example["image"]*2.0-1.0).astype(np.float32)
99 | return example
100 |
101 | def get_example(self, i):
102 | example = super().get_example(i)
103 | example = self.preprocess_example(example)
104 | return example
105 |
106 |
107 | class CelebATrain(BaseCelebA):
108 | def get_base_data(self, config):
109 | return _CelebATrain(config)
110 |
111 |
112 | class CelebATest(BaseCelebA):
113 | def get_base_data(self, config):
114 | data = _CelebATest(config)
115 | indices = np.random.RandomState(1).choice(len(data), size=10000)
116 | return SubDataset(data, indices)
117 |
118 |
119 | class BaseFactorCelebA(BaseCelebA):
120 | def __init__(self, config):
121 | super().__init__(config)
122 | self.attributes = retrieve(config, "CelebAFactors/attributes",
123 | default=["Eyeglasses", "Male", "No_Beard", "Smiling"])
124 | self.attribute_indices = [self.attribute_descriptions.index(attr)
125 | for attr in self.attributes]
126 | self.pos_attribute_choices = [np.where(self.labels["attributes"][:,attridx]==1)[0]
127 | for attridx in self.attribute_indices]
128 | self.neg_attribute_choices = [np.where(self.labels["attributes"][:,attridx]==-1)[0]
129 | for attridx in self.attribute_indices]
130 | self.n_factors = len(self.attributes)+1
131 | self.residual_index = len(self.attributes)
132 |
133 | def get_factor_idx(self, i):
134 | factor = self.prng.choice(len(self.attributes))
135 | attridx = self.attribute_indices[factor]
136 | attr = self.labels["attributes"][i,attridx]
137 | if attr == 1:
138 | i2 = self.prng.choice(self.pos_attribute_choices[factor])
139 | else:
140 | i2 = self.prng.choice(self.neg_attribute_choices[factor])
141 | return factor, i2
142 |
143 | def get_example(self, i):
144 | e1 = super().get_example(i)
145 | factor, i2 = self.get_factor_idx(i)
146 | e2 = super().get_example(i2)
147 | example = {
148 | "factor": factor,
149 | "example1": e1,
150 | "example2": e2}
151 | return example
152 |
153 |
154 | class FactorCelebATrain(BaseFactorCelebA):
155 | def get_base_data(self, config):
156 | return _CelebATrain(config)
157 |
158 |
159 | class FactorCelebATest(BaseFactorCelebA):
160 | def __init__(self, config):
161 | super().__init__(config)
162 | self.test_prng = np.random.RandomState(1)
163 | self.factor_idx = [BaseFactorCelebA.get_factor_idx(self, i) for i in range(len(self))]
164 |
165 | @property
166 | def prng(self):
167 | return self.test_prng
168 |
169 | def get_factor_idx(self, i):
170 | return self.factor_idx[i]
171 |
172 | def get_base_data(self, config):
173 | data = _CelebATest(config)
174 | indices = np.random.RandomState(1).choice(len(data), size=10000)
175 | return SubDataset(data, indices)
176 |
177 |
178 | class ColorfulMNISTBase(DatasetMixin):
179 | def get_factor(self, i):
180 | factor = self.prng.choice(2)
181 | return factor
182 |
183 | def get_same_idx(self, i):
184 | cls = self.labels["class"][i]
185 | others = np.where(self.labels["class"] == cls)[0]
186 | return self.prng.choice(others)
187 |
188 | def get_other_idx(self, i):
189 | return self.prng.choice(len(self))
190 |
191 | def get_color(self, i):
192 | return self.prng.uniform(low=0,high=1,size=3).astype(np.float32)
193 |
194 | def get_example(self, i):
195 | example1 = super().get_example(i)
196 | factor = self.get_factor(i)
197 | example = {"factor": factor, "example1": example1}
198 |
199 | if factor == 0:
200 | # same digit, different color
201 | j = self.get_same_idx(i)
202 | color1 = self.get_color(i)
203 | color2 = self.get_color(j)
204 | else:
205 | # different digit, same color
206 | j = self.get_other_idx(i)
207 | color1 = self.get_color(i)
208 | color2 = color1
209 |
210 | example2 = super().get_example(j)
211 | example["example2"] = example2
212 |
213 | example["example1"]["image"] = example["example1"]["image"] * color1
214 | example["example2"]["image"] = example["example2"]["image"] * color2
215 |
216 | return example
217 |
218 |
219 | class ColorfulMNISTTrain(ColorfulMNISTBase, PRNGMixin):
220 | def __init__(self, config):
221 | self.data = MNISTTrain(config)
222 | self.n_factors = 3
223 | self.residual_index = 2
224 |
225 |
226 | class ColorfulMNISTTest(ColorfulMNISTBase):
227 | def __init__(self, config):
228 | self.data = MNISTTest(config)
229 | self.prng = np.random.RandomState(1)
230 | self.factor = [ColorfulMNISTBase.get_factor(self, i) for i in range(len(self))]
231 | self.same_idx = [ColorfulMNISTBase.get_same_idx(self, i) for i in range(len(self))]
232 | self.other_idx = [ColorfulMNISTBase.get_other_idx(self, i) for i in range(len(self))]
233 | self.color = [ColorfulMNISTBase.get_color(self, i) for i in range(len(self))]
234 | self.n_factors = 3
235 | self.residual_index = 2
236 |
237 | def get_factor(self, i):
238 | return self.factor[i]
239 |
240 | def get_same_idx(self, i):
241 | return self.same_idx[i]
242 |
243 | def get_other_idx(self, i):
244 | return self.other_idx[i]
245 |
246 | def get_color(self, i):
247 | return self.color[i]
248 |
249 |
250 | class SingleColorfulMNISTTrain(DatasetMixin):
251 | def __init__(self, config):
252 | self.data = ColorfulMNISTTrain(config)
253 |
254 | def get_example(self, i):
255 | return super().get_example(i)["example1"]
256 |
257 |
258 | class SingleColorfulMNISTTest(DatasetMixin):
259 | def __init__(self, config):
260 | self.data = ColorfulMNISTTest(config)
261 |
262 | def get_example(self, i):
263 | return super().get_example(i)["example1"]
264 |
--------------------------------------------------------------------------------
/iin/iterators/ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from perceptual_similarity import PerceptualLoss
4 | from edflow.util import retrieve
5 | from fid import fid_callback
6 |
7 | from iin.iterators.base import Iterator
8 |
9 |
10 | def rec_fid_callback(*args, **kwargs):
11 | return fid_callback.fid(*args, **kwargs,
12 | im_in_key="image",
13 | im_in_support="-1->1",
14 | im_out_key="reconstructions",
15 | im_out_support="0->255",
16 | name="fid_recons")
17 |
18 |
19 | def sample_fid_callback(*args, **kwargs):
20 | return fid_callback.fid(*args, **kwargs,
21 | im_in_key="image",
22 | im_in_support="-1->1",
23 | im_out_key="samples",
24 | im_out_support="0->255",
25 | name="fid_samples")
26 |
27 |
28 | def reconstruction_callback(root, data_in, data_out, config):
29 | log = {"scalars": dict()}
30 | log["scalars"]["rec_loss"] = np.mean(data_out.labels["rec_loss"])
31 | log["scalars"]["kl_loss"] = np.mean(data_out.labels["kl_loss"])
32 | return log
33 |
34 |
35 | class Trainer(Iterator):
36 | """
37 | AE Trainer. Expects `image` in batch, `encode -> Distribution` and `decode`
38 | methods on model.
39 | """
40 | def __init__(self, *args, **kwargs):
41 | super().__init__(*args, **kwargs)
42 | self.eval_loss = PerceptualLoss()
43 |
44 | def step_op(self, *args, **kwargs):
45 | inputs = kwargs["image"]
46 | inputs = self.totorch(inputs)
47 |
48 | posterior = self.model.encode(inputs)
49 | z = posterior.sample()
50 | reconstructions = self.model.decode(z)
51 | loss, log_dict, loss_train_op = self.loss(
52 | inputs, reconstructions, posterior, self.get_global_step())
53 | log_dict.setdefault("scalars", dict())
54 | log_dict.setdefault("images", dict())
55 |
56 | def train_op():
57 | loss_train_op()
58 | self.optimizer.zero_grad()
59 | loss.backward()
60 | self.optimizer.step()
61 | self.update_lr()
62 |
63 | def log_op():
64 | log_dict["images"]["inputs"] = inputs
65 | log_dict["images"]["reconstructions"] = reconstructions
66 |
67 | if not hasattr(self, "fixed_examples"):
68 | self.fixed_examples = [
69 | self.dataset[i]["image"]
70 | for i in np.random.RandomState(1).choice(len(self.dataset),
71 | self.config["batch_size"])]
72 | self.fixed_examples = np.stack(self.fixed_examples)
73 | self.fixed_examples = self.totorch(self.fixed_examples)
74 |
75 | with torch.no_grad():
76 | log_dict["images"]["fixed_inputs"] = self.fixed_examples
77 | log_dict["images"]["fixed_reconstructions"] = self.model.decode(
78 | self.model.encode(self.fixed_examples).sample())
79 | log_dict["images"]["decoded_sample"] = self.model.decode(
80 | torch.randn_like(posterior.mode()))
81 |
82 | for k in log_dict:
83 | for kk in log_dict[k]:
84 | log_dict[k][kk] = self.tonp(log_dict[k][kk])
85 |
86 | for i, param_group in enumerate(self.optimizer.param_groups):
87 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr']
88 | return log_dict
89 |
90 | def eval_op():
91 | with torch.no_grad():
92 | kl_loss = posterior.kl()
93 | rec_loss = self.eval_loss(reconstructions, inputs)
94 | samples = self.model.decode(torch.randn_like(posterior.mode()))
95 | return {"reconstructions": self.tonp(reconstructions),
96 | "samples": self.tonp(samples),
97 | "labels": {"rec_loss": self.tonp(rec_loss),
98 | "kl_loss": self.tonp(kl_loss)}}
99 |
100 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
101 |
102 | @property
103 | def callbacks(self):
104 | cbs = {"eval_op": {"reconstruction": reconstruction_callback}}
105 | cbs["eval_op"]["fid_reconstruction"] = rec_fid_callback
106 | cbs["eval_op"]["fid_samples"] = sample_fid_callback
107 | return cbs
108 |
--------------------------------------------------------------------------------
/iin/iterators/base.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import torch
3 | import numpy as np
4 | import edflow
5 | from edflow import TemplateIterator, get_obj_from_str
6 | from edflow.util import retrieve
7 |
8 |
9 | def totorch(x, guess_image=True, device=None):
10 | if x.dtype == np.float64:
11 | x = x.astype(np.float32)
12 | x = torch.tensor(x)
13 | if guess_image and len(x.size()) == 4:
14 | x = x.transpose(2, 3).transpose(1, 2)
15 | if device is None:
16 | if torch.cuda.is_available():
17 | device = torch.device("cuda")
18 | else:
19 | device = torch.device("cpu")
20 | x = x.to(device)
21 | return x
22 |
23 |
24 | def tonp(x, guess_image=True):
25 | try:
26 | if guess_image and len(x.shape) == 4:
27 | x = x.transpose(1, 2).transpose(2, 3)
28 | return x.detach().cpu().numpy()
29 | except AttributeError:
30 | return x
31 |
32 |
33 | def get_learning_rate(config):
34 | if "learning_rate" in config:
35 | learning_rate = config["learning_rate"]
36 | elif "base_learning_rate" in config:
37 | learning_rate = config["base_learning_rate"]*config["batch_size"]
38 | else:
39 | raise KeyError()
40 | return learning_rate
41 |
42 |
43 | class Iterator(TemplateIterator):
44 | """
45 | Base class to handle device and state. Adds optimizer and loss.
46 | Call update_lr() in train op for lr scheduling.
47 |
48 | Config parameters:
49 | - test_mode : boolean : Put model into .eval() mode.
50 | - no_restore_keys : string1,string2 : Submodels which should not be
51 | restored from checkpoint.
52 | - learning_rate : float : Learning rate of Adam
53 | - base_learning_rate : float : Learning_rate per example to adjust for
54 | batch size (ignored if learning_rate is present)
55 | - decay_start : float : Step after which learning rate is decayed to
56 | zero.
57 | - loss : string : Import path of loss.
58 | """
59 | def __init__(self, *args, **kwargs):
60 | super().__init__(*args, **kwargs)
61 | if torch.cuda.is_available():
62 | self.device = torch.device("cuda")
63 | else:
64 | self.device = torch.device("cpu")
65 | self.model.to(self.device)
66 | self.test_mode = self.config.get("test_mode", False)
67 | if self.test_mode:
68 | self.model.eval()
69 | self.submodules = ["model"]
70 | self.do_not_restore_keys = retrieve(self.config, 'no_restore_keys', default='').split(',')
71 |
72 | self.learning_rate = get_learning_rate(self.config)
73 | self.logger.info("learning_rate: {}".format(self.learning_rate))
74 | params = self.model.parameters()
75 | if "loss" in self.config:
76 | self.loss = get_obj_from_str(self.config["loss"])(self.config)
77 | self.loss.to(self.device)
78 | self.submodules.append("loss")
79 | params = itertools.chain(params, self.loss.parameters())
80 | self.optimizer = torch.optim.Adam(
81 | params,
82 | lr=self.learning_rate,
83 | betas=(0.5, 0.9))
84 | self.submodules.append("optimizer")
85 | try:
86 | self.loss.set_last_layer(self.model.get_last_layer())
87 | except Exception as e:
88 | self.logger.info(' Could not set last layer for calibration. Reason:\n {}'.format(e))
89 |
90 | self.num_steps = retrieve(self.config, "num_steps")
91 | self.decay_start = retrieve(self.config, "decay_start", default=self.num_steps)
92 |
93 | def get_decay_factor(self):
94 | alpha = 1.0
95 | if self.num_steps > self.decay_start:
96 | alpha = 1.0 - np.clip(
97 | (self.get_global_step() - self.decay_start) /
98 | (self.num_steps - self.decay_start),
99 | 0.0, 1.0)
100 | return alpha
101 |
102 | def update_lr(self):
103 | for param_group in self.optimizer.param_groups:
104 | param_group['lr'] = self.get_decay_factor()*self.learning_rate
105 |
106 | def get_state(self):
107 | state = dict()
108 | for k in self.submodules:
109 | state[k] = getattr(self, k).state_dict()
110 | return state
111 |
112 | def save(self, checkpoint_path):
113 | torch.save(self.get_state(), checkpoint_path)
114 |
115 | def restore(self, checkpoint_path):
116 | state = torch.load(checkpoint_path)
117 | keys = list(state.keys())
118 | for k in keys:
119 | if hasattr(self, k):
120 | if k not in self.do_not_restore_keys:
121 | try:
122 | missing, unexpected = getattr(self, k).load_state_dict(state[k], strict=False)
123 | if missing:
124 | self.logger.info("Missing keys for {}: {}".format(k, missing))
125 | if unexpected:
126 | self.logger.info("Unexpected keys for {}: {}".format(k, unexpected))
127 | except TypeError:
128 | self.logger.info(k)
129 | try:
130 | getattr(self, k).load_state_dict(state[k])
131 | except ValueError:
132 | self.logger.info("Could not load state dict for key {}".format(k))
133 | else:
134 | self.logger.info('Restored key `{}`'.format(k))
135 | else:
136 | self.logger.info('Not restoring key `{}` (as specified)'.format(k))
137 | del state[k]
138 |
139 | def totorch(self, x, guess_image=True):
140 | return totorch(x, guess_image=guess_image, device=self.device)
141 |
142 | def tonp(self, x, guess_image=True):
143 | return tonp(x, guess_image=guess_image)
144 |
145 | def interpolate_corners(self, x, num_side, permute=False):
146 | return interpolate_corners(x, side=num_side, permute=permute)
147 |
148 |
149 |
150 | class DimEvaluator(Iterator):
151 | """
152 | Estimate dimensionalities for factors.
153 |     Expects `factor`, `example1` and `example2` with `image` in batch, and an
154 |     `encode -> Distribution` method on the model.
155 | """
156 | def __init__(self, *args, **kwargs):
157 | super().__init__(*args, **kwargs)
158 |
159 | def step_op(self, *args, **kwargs):
160 | def eval_op():
161 | with torch.no_grad():
162 | hs = dict()
163 | factor=kwargs["factor"]
164 | with torch.no_grad():
165 | for k in ["example1", "example2"]:
166 | inputs = self.totorch(kwargs[k]["image"])
167 | hs[k] = self.model.encode(inputs).mode()
168 | log_dict = {}
169 | log_dict["labels"] = {"factor": factor}
170 | for k in ["example1", "example2"]:
171 | log_dict["labels"][k] = self.tonp(hs[k], guess_image=False)
172 | return log_dict
173 |
174 | return {"eval_op": eval_op}
175 |
176 | @property
177 | def callbacks(self):
178 | return {"eval_op": {"dim_callback": dim_callback}}
179 |
180 |
181 | def dim_callback(root, data_in, data_out, config):
182 | logger = edflow.get_logger("dim_callback")
183 |
184 | factors = data_out.labels["factor"]
185 | za = data_out.labels["example1"].squeeze()
186 | zb = data_out.labels["example2"].squeeze()
187 | za_by_factor = dict()
188 | zb_by_factor = dict()
189 | mean_by_factor = dict()
190 | score_by_factor = dict()
191 |
192 | zall = np.concatenate([za,zb], 0)
193 | mean = np.mean(zall, 0, keepdims=True)
194 | var = np.sum(np.mean((zall-mean)*(zall-mean), 0))
195 | for f in range(data_in.n_factors):
196 | if f != data_in.residual_index:
197 | indices = np.where(factors==f)[0]
198 | za_by_factor[f] = za[indices]
199 | zb_by_factor[f] = zb[indices]
200 | mean_by_factor[f] = 0.5*(
201 | np.mean(za_by_factor[f], 0, keepdims=True)+
202 | np.mean(zb_by_factor[f], 0, keepdims=True))
203 | score_by_factor[f] = np.sum(
204 | np.mean(
205 | (za_by_factor[f]-mean_by_factor[f])*(zb_by_factor[f]-mean_by_factor[f]), 0))
206 | score_by_factor[f] = score_by_factor[f]/var
207 | else:
208 | score_by_factor[f] = 1.0
209 | scores = np.array([score_by_factor[f] for f in range(data_in.n_factors)])
210 |
211 | m = np.max(scores)
212 | e = np.exp(scores-m)
213 | softmaxed = e / np.sum(e)
214 |
215 | dim = za.shape[1]
216 | dims = [int(s*dim) for s in softmaxed]
217 | dims[-1] = dim - sum(dims[:-1])
218 | logger.info("estimated factor dimensionalities: {}".format(dims))
219 |
--------------------------------------------------------------------------------
/iin/iterators/clf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from edflow.util import retrieve
4 |
5 | from iin.iterators.base import Iterator
6 |
7 |
8 | class Trainer(Iterator):
9 | """
10 | Classification Trainer. Expects `image` and `class` in batch,
11 | and a model returning logits.
12 | """
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | self.n_classes = retrieve(self.config, "n_classes")
16 |
17 | def step_op(self, *args, **kwargs):
18 | inputs = kwargs["image"]
19 | inputs = self.totorch(inputs)
20 | labels = kwargs["class"]
21 | labels = self.totorch(labels).to(torch.int64)
22 | onehot = torch.nn.functional.one_hot(labels,
23 | num_classes=self.n_classes).float()
24 |
25 | logits = self.model(inputs)
26 | loss = torch.nn.functional.binary_cross_entropy_with_logits(logits,
27 | onehot,
28 | reduction="none")
29 | mean_loss = loss.mean()
30 |
31 | def train_op():
32 | self.optimizer.zero_grad()
33 | mean_loss.backward()
34 | self.optimizer.step()
35 | self.update_lr()
36 |
37 | def log_op():
38 | with torch.no_grad():
39 | prediction = torch.argmax(logits, dim=1)
40 | accuracy = (prediction==labels).float().mean()
41 |
42 | log_dict = {"scalars": {
43 | "loss": mean_loss,
44 | "acc": accuracy
45 | }}
46 |
47 | for k in log_dict:
48 | for kk in log_dict[k]:
49 | log_dict[k][kk] = self.tonp(log_dict[k][kk])
50 |
51 | for i, param_group in enumerate(self.optimizer.param_groups):
52 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr']
53 | return log_dict
54 |
55 | def eval_op():
56 | return {
57 | "labels": {"loss": self.tonp(loss), "logits": self.tonp(logits)},
58 | }
59 |
60 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
61 |
62 | @property
63 | def callbacks(self):
64 | return {"eval_op": {"clf_callback": clf_callback}}
65 |
66 |
67 | def clf_callback(root, data_in, data_out, config):
68 | loss = data_out.labels["loss"].mean()
69 | prediction = data_out.labels["logits"].argmax(1)
70 | accuracy = (prediction == data_in.labels["class"][:prediction.shape[0]]).mean()
71 | return {"scalars": {"loss": loss, "accuracy": accuracy}}
72 |
--------------------------------------------------------------------------------
/iin/iterators/iin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from edflow.util import retrieve, get_obj_from_str
4 |
5 | from iin.iterators.base import Iterator
6 | from iin.iterators.ae import sample_fid_callback
7 |
8 |
9 | def loss_callback(root, data_in, data_out, config):
10 | log = {"scalars": dict()}
11 | log["scalars"]["loss"] = np.mean(data_out.labels["loss"])
12 | return log
13 |
14 |
15 | class Trainer(Iterator):
16 | """
17 | Unsupervised IIN Trainer. Expects `image` in batch,
18 | `encode -> Distribution` and `decode` methods on first stage model.
19 | """
20 | def __init__(self, *args, **kwargs):
21 | super().__init__(*args, **kwargs)
22 | first_stage_config = self.config["first_stage"]
23 | self.init_first_stage(first_stage_config)
24 |
25 | def init_first_stage(self, config):
26 | subconfig = config["subconfig"]
27 | self.first_stage = get_obj_from_str(config["model"])(subconfig)
28 | if "checkpoint" in config:
29 | checkpoint = config["checkpoint"]
30 | state = torch.load(checkpoint)["model"]
31 | self.first_stage.load_state_dict(state)
32 | self.logger.info("Restored first stage from {}".format(checkpoint))
33 | self.first_stage.to(self.device)
34 | self.first_stage.eval()
35 |
36 | def step_op(self, *args, **kwargs):
37 | inputs = kwargs["image"]
38 | inputs = self.totorch(inputs)
39 |
40 | with torch.no_grad():
41 | posterior = self.first_stage.encode(inputs)
42 | z = posterior.sample()
43 | z_ss, logdet = self.model(z)
44 | loss, log_dict, loss_train_op = self.loss(z_ss, logdet, self.get_global_step())
45 | log_dict.setdefault("scalars", dict())
46 | log_dict.setdefault("images", dict())
47 |
48 | def train_op():
49 | loss_train_op()
50 | self.optimizer.zero_grad()
51 | loss.backward()
52 | self.optimizer.step()
53 | self.update_lr()
54 |
55 | def log_op():
56 | if not hasattr(self, "fixed_examples"):
57 | self.fixed_examples = np.random.RandomState(1).randn(*posterior.mode().shape)
58 | self.fixed_examples = self.totorch(self.fixed_examples)
59 |
60 | with torch.no_grad():
61 | reconstructions = self.first_stage.decode(
62 | self.model.reverse(z_ss))
63 | samples = self.first_stage.decode(
64 | self.model.reverse(torch.randn_like(posterior.mode())))
65 | fixed_samples = self.first_stage.decode(
66 | self.model.reverse(self.fixed_examples))
67 |
68 | log_dict["images"]["inputs"] = inputs
69 | log_dict["images"]["reconstructions"] = reconstructions
70 | log_dict["images"]["samples"] = samples
71 | log_dict["images"]["fixed_samples"] = fixed_samples
72 |
73 | for k in log_dict:
74 | for kk in log_dict[k]:
75 | log_dict[k][kk] = self.tonp(log_dict[k][kk])
76 |
77 | for i, param_group in enumerate(self.optimizer.param_groups):
78 | log_dict["scalars"]["lr_{:02}".format(i)] = param_group['lr']
79 | return log_dict
80 |
81 | def eval_op():
82 | with torch.no_grad():
83 | loss_ = torch.ones(inputs.shape[0])*loss
84 | samples = self.first_stage.decode(self.model.reverse(torch.randn_like(posterior.mode())))
85 | return {"samples": self.tonp(samples),
86 | "labels": {"loss": self.tonp(loss_)}}
87 |
88 | return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
89 |
90 | @property
91 | def callbacks(self):
92 | cbs = {"eval_op": {"loss_cb": loss_callback}}
93 | cbs["eval_op"]["fid_samples"] = sample_fid_callback
94 | return cbs
95 |
96 |
97 | class FactorTrainer(Trainer):
98 | def step_op(self, *args, **kwargs):
99 | # get inputs
100 | inputs_factor = dict()
101 | z_factor = dict()
102 | z_ss_factor = dict()
103 | logdet_factor = dict()
104 | factors = kwargs["factor"]
105 | for k in ["example1", "example2"]:
106 | inputs = kwargs[k]["image"]
107 | inputs = self.totorch(inputs)
108 | inputs_factor[k] = inputs
109 |
110 | with torch.no_grad():
111 | posterior = self.first_stage.encode(inputs)
112 | z = posterior.sample()
113 | z_factor[k] = z
114 | z_ss, logdet = self.model(z)
115 | z_ss_factor[k] = z_ss
116 | logdet_factor[k] = logdet
117 |
118 | loss, log_dict, loss_train_op = self.loss(
119 | z_ss_factor, logdet_factor, factors, self.get_global_step())
120 |
121 | def train_op():
122 | loss_train_op()
123 | self.optimizer.zero_grad()
124 | loss.backward()
125 | self.optimizer.step()
126 | self.update_lr()
127 |
128 | def log_op():
129 | if not hasattr(self.model, "decode"):
130 | return None
131 | with torch.no_grad():
132 | for k in ["example1", "example2"]:
133 | log_dict["images"][k] = inputs_factor[k]
134 | # reencode after update of model
135 | ss_z1 = self.model(z_factor["example1"])[0]
136 | ss_z2 = self.model(z_factor["example2"])[0]
137 |
138 | log_dict["images"]["reconstruction"] = self.first_stage.decode(
139 | self.model.reverse(ss_z1))
140 |
141 | factor_mask = [
142 | torch.tensor(((factors==i) | ((factors<0) & (factors!=-i)))[:,None,None,None]).to(
143 | ss_z1[i]) for i in range(len(ss_z1))]
144 | ss_z_swap = [
145 | ss_z1[i] +
146 | factor_mask[i]*(
147 | ss_z2[i] - ss_z1[i])
148 | for i in range(len(factor_mask))]
149 | log_dict["images"]["decoded_swap"] = self.first_stage.decode(
150 | self.model.reverse(ss_z_swap))
151 |
152 | N_cross = 6
153 | z_cross = z_factor["example1"][:N_cross,...]
154 | shape = tuple(z_cross.shape)
155 | z_cross1 = z_cross[None,...][N_cross*[0],...].reshape(
156 | N_cross*N_cross, *shape[1:])
157 | z_cross2 = z_cross[:,None,...][:,N_cross*[0],...].reshape(
158 | N_cross*N_cross, *shape[1:])
159 | ss_z_cross1 = self.model(z_cross1)[0]
160 | ss_z_cross2 = self.model(z_cross2)[0]
161 | for i in range(len(ss_z1)):
162 | ss_z_cross = list(ss_z_cross2)
163 | ss_z_cross[i] = ss_z_cross1[i]
164 | log_dict["images"]["decoded_cross_{}".format(i)] = self.first_stage.decode(
165 | self.model.reverse(ss_z_cross))
166 |
167 | N_fixed = 6
168 | if not hasattr(self, "fixed_examples"):
169 | self.fixed_examples = [
170 | self.dataset[i]["example1"]["image"]
171 | for i in np.random.RandomState(1).choice(len(self.dataset),
172 | N_fixed)]
173 | self.fixed_examples = np.stack(self.fixed_examples)
174 | self.fixed_examples = self.totorch(self.fixed_examples)
175 | self.fixed_examples = self.first_stage.encode(self.fixed_examples).mode()
176 | shape = tuple(self.fixed_examples.shape)
177 | self.fixed1 = self.fixed_examples[None,...][N_fixed*[0],...].reshape(
178 | N_fixed*N_fixed, *shape[1:])
179 | self.fixed2 = self.fixed_examples[:,None,...][:,N_fixed*[0],...].reshape(
180 | N_fixed*N_fixed, *shape[1:])
181 |
182 | ss_z_fixed1 = self.model(self.fixed1)[0]
183 | ss_z_fixed2 = self.model(self.fixed2)[0]
184 | for i in range(len(ss_z_fixed1)):
185 | ss_z_cross = list(ss_z_fixed2)
186 | ss_z_cross[i] = ss_z_fixed1[i]
187 | log_dict["images"]["fixed_cross_{}".format(i)] = self.first_stage.decode(
188 | self.model.reverse(ss_z_cross))
189 |
190 | for k in log_dict:
191 | for kk in log_dict[k]:
192 | log_dict[k][kk] = self.tonp(log_dict[k][kk])
193 | return log_dict
194 |
195 | return {"train_op": train_op, "log_op": log_op}
196 |
197 | @property
198 | def callbacks(self):
199 | return dict()
200 |
--------------------------------------------------------------------------------
/iin/losses/ae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from edflow.util import retrieve
3 | import torch.nn as nn
4 | import torch
5 | from perceptual_similarity import PerceptualLoss
6 | import functools
7 | from torch.nn.utils.spectral_norm import spectral_norm
8 |
9 |
10 | def do_spectral_norm(m):
11 | classname = m.__class__.__name__
12 | if classname.find('Conv') != -1:
13 | spectral_norm(m)
14 |
15 |
16 | def weights_init(m):
17 | classname = m.__class__.__name__
18 | if classname.find('Conv') != -1:
19 | nn.init.normal_(m.weight.data, 0.0, 0.02)
20 | elif classname.find('BatchNorm') != -1:
21 | nn.init.normal_(m.weight.data, 1.0, 0.02)
22 | nn.init.constant_(m.bias.data, 0)
23 |
24 |
25 | def adopt_weight(weight, global_step, threshold=0, value=0.):
26 | if global_step < threshold:
27 | weight = value
28 | return weight
29 |
30 |
31 | def l1(input, target):
32 | return torch.abs(input-target)
33 |
34 |
35 | def l2(input, target):
36 | return torch.pow((input-target), 2)
37 |
38 |
39 | class NLayerDiscriminator(nn.Module):
40 | """Defines a PatchGAN discriminator
41 | --> from
42 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
43 | """
44 |
45 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
46 | """Construct a PatchGAN discriminator
47 | Parameters:
48 | input_nc (int) -- the number of channels in input images
49 | ndf (int) -- the number of filters in the last conv layer
50 | n_layers (int) -- the number of conv layers in the discriminator
51 | norm_layer -- normalization layer
52 | """
53 | super(NLayerDiscriminator, self).__init__()
54 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
55 | use_bias = norm_layer.func != nn.BatchNorm2d
56 | else:
57 | use_bias = norm_layer != nn.BatchNorm2d
58 |
59 | kw = 4
60 | padw = 1
61 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
62 | nf_mult = 1
63 | nf_mult_prev = 1
64 | for n in range(1, n_layers): # gradually increase the number of filters
65 | nf_mult_prev = nf_mult
66 | nf_mult = min(2 ** n, 8)
67 | sequence += [
68 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
69 | norm_layer(ndf * nf_mult),
70 | nn.LeakyReLU(0.2, True)
71 | ]
72 |
73 | nf_mult_prev = nf_mult
74 | nf_mult = min(2 ** n_layers, 8)
75 | sequence += [
76 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
77 | norm_layer(ndf * nf_mult),
78 | nn.LeakyReLU(0.2, True)
79 | ]
80 |
81 | sequence += [
82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
83 | self.main = nn.Sequential(*sequence)
84 |
85 | def forward(self, input):
86 | """Standard forward."""
87 | return self.main(input)
88 |
89 |
90 | class LossPND(nn.Module):
91 | """Using LPIPS Perceptual loss"""
92 | def __init__(self, config):
93 | super().__init__()
94 | __pixel_loss_opt = {"l1": l1,
95 | "l2": l2 }
96 | self.config = config
97 | self.discriminator_iter_start = retrieve(config, "Loss/disc_start")
98 | self.disc_factor = retrieve(config, "Loss/disc_factor", default=1.0)
99 | self.kl_weight = retrieve(config, "Loss/kl_weight", default=1.0)
100 | self.perceptual_weight = retrieve(config, "Loss/perceptual_weight", default=1.0)
101 | if self.perceptual_weight > 0:
102 | self.perceptual_loss = PerceptualLoss()
103 | self.calibrate = retrieve(config, "Loss/calibrate", default=False)
104 | # output log variance
105 | self.logvar_init = retrieve(config, "Loss/logvar_init", default=0.0)
106 | self.logvar = nn.Parameter(torch.ones(size=())*self.logvar_init)
107 | # discriminator
108 | self.discriminator_weight = retrieve(config, "Loss/discriminator_weight", default=1.0)
109 | disc_nc_in = retrieve(config, "Loss/disc_in_channels", default=3)
110 | disc_layers = retrieve(config, "Loss/disc_num_layers", default=3)
111 | self.discriminator = NLayerDiscriminator(input_nc=disc_nc_in, n_layers=disc_layers).apply(weights_init)
112 | self.pixel_loss = __pixel_loss_opt[retrieve(config, "Loss/pixelloss", default="l1")]
113 | if retrieve(config, "Loss/spectral_norm", default=False):
114 | self.discriminator.apply(do_spectral_norm)
115 | if torch.cuda.is_available():
116 | self.discriminator.cuda()
117 | if "learning_rate" in self.config:
118 | learning_rate = self.config["learning_rate"]
119 | elif "base_learning_rate" in self.config:
120 | learning_rate = self.config["base_learning_rate"]*self.config["batch_size"]
121 | else:
122 | learning_rate = 0.001
123 | self.learning_rate = retrieve(config, "Loss/d_lr_factor", default=1.0)*learning_rate
124 | self.num_steps = retrieve(self.config, "num_steps")
125 | self.decay_start = retrieve(self.config, "decay_start", default=self.num_steps)
126 | self.d_optimizer = torch.optim.Adam(
127 | self.discriminator.parameters(),
128 | lr=self.learning_rate,  # apply Loss/d_lr_factor from the first step, matching update_lr()
129 | betas=(0.5, 0.9))
130 |
131 | def get_decay_factor(self, global_step):
132 | alpha = 1.0
133 | if self.num_steps > self.decay_start:
134 | alpha = 1.0 - np.clip(
135 | (global_step - self.decay_start) /
136 | (self.num_steps - self.decay_start),
137 | 0.0, 1.0)
138 | return alpha
139 |
140 | def update_lr(self, global_step):
141 | for param_group in self.d_optimizer.param_groups:
142 | param_group['lr'] = self.get_decay_factor(global_step)*self.learning_rate
143 |
144 | def set_last_layer(self, last_layer):
145 | self.last_layer = [last_layer]
146 |
147 | def parameters(self):
148 | """Exclude discriminator from parameters."""
149 | ps = super().parameters()
150 | exclude = set(self.discriminator.parameters())
151 | ps = (p for p in ps if p not in exclude)
152 | return ps
153 |
154 | def forward(self, inputs, reconstructions, posteriors, global_step):
155 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
156 | rec_loss = self.pixel_loss(inputs, reconstructions) # l1 or l2
157 | if self.perceptual_weight > 0:
158 | p_loss = self.perceptual_loss(inputs, reconstructions)
159 | rec_loss = rec_loss + self.perceptual_weight*p_loss
160 |
161 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
162 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
163 |
164 | calibration = rec_loss / torch.exp(self.logvar) - 1.0
165 | calibration = torch.sum(calibration) / calibration.shape[0]
166 |
167 | kl_loss = posteriors.kl()
168 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
169 |
170 | logits_real = self.discriminator(inputs)
171 | logits_fake = self.discriminator(reconstructions)
172 | d_loss = 0.5*(
173 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
174 | torch.mean(torch.nn.functional.softplus( logits_fake))) * disc_factor
175 | def train_op():
176 | self.d_optimizer.zero_grad()
177 | d_loss.backward(retain_graph=True)
178 | self.d_optimizer.step()
179 | self.update_lr(global_step)
180 |
181 | g_loss = -torch.mean(logits_fake)
182 |
183 | if not self.calibrate:
184 | loss = nll_loss + self.kl_weight*kl_loss + self.discriminator_weight*g_loss*disc_factor
185 | else:
186 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
187 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
188 | d_weight = torch.norm(nll_grads)/(torch.norm(g_grads)+1e-4)
189 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
190 | loss = nll_loss + self.kl_weight*kl_loss + d_weight*disc_factor*g_loss
191 | log = {"scalars": {"loss": loss, "logvar": self.logvar,
192 | "kl_loss": kl_loss, "nll_loss": nll_loss,
193 | "rec_loss": rec_loss.mean(),
194 | "calibration": calibration,
195 | "g_loss": g_loss, "d_loss": d_loss,
196 | "logits_real": torch.mean(logits_real),
197 | "logits_fake": torch.mean(logits_fake),
198 | }}
199 | if self.calibrate:
200 | log["scalars"]["d_weight"] = d_weight
201 | for i, param_group in enumerate(self.d_optimizer.param_groups):
202 | log["scalars"]["d_lr_{:02}".format(i)] = param_group['lr']
203 | return loss, log, train_op
204 |
--------------------------------------------------------------------------------
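
A minimal smoke test of the loss interface above (not part of the repository): it builds `LossPND` from a hand-written config, feeds dummy tensors, and inspects the returned `(loss, log, train_op)` triple. The config values and tensor shapes are illustrative assumptions; the real settings live in `configs/*_ae.yaml`, and the actual training loop is implemented by the iterators under `iin/iterators/`.

```
import torch
from iin.losses.ae import LossPND
from iin.models.ae import Distribution

# Illustrative config values only; the real ones live in configs/*_ae.yaml.
config = {
    "Loss": {"disc_start": 5000, "perceptual_weight": 0.0},  # 0.0 skips the LPIPS term for this smoke test
    "num_steps": 100000,
}
loss_fn = LossPND(config)
# with "Loss/calibrate": True one would also call loss_fn.set_last_layer(model.get_last_layer())

device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = torch.randn(4, 3, 64, 64, device=device)
reconstructions = torch.randn(4, 3, 64, 64, device=device)
posterior = Distribution(torch.zeros(4, 128, 1, 1, device=device))  # 64-dim mean/logvar chunks

loss, log, train_op = loss_fn(inputs, reconstructions, posterior, global_step=0)
# loss     -> scalar driving the autoencoder update (backward is the iterator's job)
# log      -> {"scalars": {...}} with the nll/kl/adversarial terms for logging
# train_op -> closure performing the discriminator's own Adam step
train_op()
print({k: float(v) for k, v in log["scalars"].items()})
```
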
/iin/losses/iin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch
4 |
5 | from edflow.util import retrieve
6 |
7 |
8 | def nll(sample):
9 | return 0.5*torch.sum(torch.pow(sample, 2), dim=[1,2,3])
10 |
11 |
12 | class Loss(nn.Module):
13 | def __init__(self, config):
14 | super().__init__()
15 | self.config = config
16 |
17 | def forward(self, sample, logdet, global_step):
18 | nll_loss = torch.mean(nll(sample))
19 | assert len(logdet.shape) == 1
20 | nlogdet_loss = -torch.mean(logdet)
21 | loss = nll_loss + nlogdet_loss
22 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample)))
23 | log = {"images": {},
24 | "scalars": {
25 | "loss": loss, "reference_nll_loss": reference_nll_loss,
26 | "nlogdet_loss": nlogdet_loss, "nll_loss": nll_loss,
27 | }}
28 | def train_op():
29 | pass
30 | return loss, log, train_op
31 |
32 |
33 | class FactorLoss(nn.Module):
34 | def __init__(self, config):
35 | super().__init__()
36 | self.config = config
37 | self.rho = retrieve(config, "Loss/rho", default=0.975)
38 |
39 | def forward(self, samples, logdets, factors, global_step):
40 | sample1 = samples["example1"]
41 | logdet1 = logdets["example1"]
42 | nll_loss1 = torch.mean(nll(torch.cat(sample1, dim=1)))
43 | assert len(logdet1.shape) == 1
44 | nlogdet_loss1 = -torch.mean(logdet1)
45 | loss1 = nll_loss1 + nlogdet_loss1
46 |
47 | sample2 = samples["example2"]
48 | logdet2 = logdets["example2"]
49 | factor_mask = [
50 | torch.tensor(((factors==i) | ((factors<0) & (factors!=-i)))[:,None,None,None]).to(
51 | sample2[i]) for i in range(len(sample2))]
52 | sample2_cond = [
53 | sample2[i] - factor_mask[i]*self.rho*sample1[i]
54 | for i in range(len(sample2))]
55 | nll_loss2 = [nll(sample2_cond[i]) for i in range(len(sample2_cond))]
56 | nll_loss2 = [nll_loss2[i]/(1.0-factor_mask[i][:,0,0,0]*self.rho**2)
57 | for i in range(len(sample2_cond))]
58 | nll_loss2 = [torch.mean(nll_loss2[i])
59 | for i in range(len(sample2_cond))]
60 | nll_loss2 = sum(nll_loss2)
61 | assert len(logdet2.shape) == 1
62 | nlogdet_loss2 = -torch.mean(logdet2)
63 | loss2 = nll_loss2 + nlogdet_loss2
64 |
65 | loss = loss1 + loss2
66 |
67 | log = {"images": {},
68 | "scalars": {
69 | "loss": loss, "loss1": loss1, "loss2": loss2,
70 | }}
71 | def train_op():
72 | pass
73 | return loss, log, train_op
74 |
--------------------------------------------------------------------------------
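
`Loss` above is the change-of-variables maximum-likelihood objective: per sample it charges `0.5*||z||^2` (the negative log-density of a standard normal, up to an additive constant) minus the flow's log-determinant. A minimal sketch with dummy tensors (shapes are illustrative, not taken from any config):

```
import torch
from iin.losses.iin import Loss

loss_fn = Loss(config={})

sample = torch.randn(8, 64, 1, 1)  # transformed latents in NCHW, as nll() expects
logdet = torch.zeros(8)            # one log|det| per batch element

loss, log, train_op = loss_fn(sample, logdet, global_step=0)
# With standard-normal samples and zero logdet, loss should be close to
# reference_nll_loss, i.e. roughly 0.5 * 64 = 32 on average.
print(float(loss), float(log["scalars"]["reference_nll_loss"]))
train_op()  # no-op: the iterator backpropagates `loss` itself
```
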
/iin/models/ae.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | from edflow.util import retrieve
6 |
7 |
8 | class ActNorm(nn.Module):
9 | def __init__(self, num_features, affine=True, logdet=False):
10 | super().__init__()
11 | assert affine
12 | self.logdet = logdet
13 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
14 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
15 |
16 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
17 |
18 | def initialize(self, input):
19 | with torch.no_grad():
20 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
21 | mean = (
22 | flatten.mean(1)
23 | .unsqueeze(1)
24 | .unsqueeze(2)
25 | .unsqueeze(3)
26 | .permute(1, 0, 2, 3)
27 | )
28 | std = (
29 | flatten.std(1)
30 | .unsqueeze(1)
31 | .unsqueeze(2)
32 | .unsqueeze(3)
33 | .permute(1, 0, 2, 3)
34 | )
35 |
36 | self.loc.data.copy_(-mean)
37 | self.scale.data.copy_(1 / (std + 1e-6))
38 |
39 | def forward(self, input, reverse=False):
40 | if reverse:
41 | return self.reverse(input)
42 | _, _, height, width = input.shape
43 |
44 | if self.initialized.item() == 0:
45 | self.initialize(input)
46 | self.initialized.fill_(1)
47 |
48 | h = self.scale * (input + self.loc)
49 |
50 | if self.logdet:
51 | log_abs = torch.log(torch.abs(self.scale))
52 | logdet = height*width*torch.sum(log_abs)
53 | logdet = logdet * torch.ones(input.shape[0]).to(input)
54 | return h, logdet
55 |
56 | return h
57 |
58 | def reverse(self, output):
59 | return output / self.scale - self.loc
60 |
61 |
62 | _norm_options = {
63 | "in": nn.InstanceNorm2d,
64 | "bn": nn.BatchNorm2d,
65 | "an": ActNorm}
66 |
67 |
68 | def weights_init(m):
69 | classname = m.__class__.__name__
70 | if classname.find('Conv') != -1:
71 | nn.init.normal_(m.weight.data, 0.0, 0.02)
72 | elif classname.find('BatchNorm') != -1:
73 | nn.init.normal_(m.weight.data, 1.0, 0.02)
74 | nn.init.constant_(m.bias.data, 0)
75 |
76 |
77 | class FeatureLayer(nn.Module):
78 | def __init__(self, scale, in_channels=None, norm='IN'):
79 | super().__init__()
80 | self.scale = scale
81 | self.norm = _norm_options[norm.lower()]
82 | if in_channels is None:
83 | self.in_channels = 64*min(2**(self.scale-1), 16)
84 | else:
85 | self.in_channels = in_channels
86 | self.build()
87 |
88 | def forward(self, input):
89 | x = input
90 | for layer in self.sub_layers:
91 | x = layer(x)
92 | return x
93 |
94 | def build(self):
95 | Norm = functools.partial(self.norm, affine=True)
96 | Activate = lambda: nn.LeakyReLU(0.2)
97 | self.sub_layers = nn.ModuleList([
98 | nn.Conv2d(
99 | in_channels=self.in_channels,
100 | out_channels=64*min(2**self.scale, 16),
101 | kernel_size=4,
102 | stride=2,
103 | padding=1,
104 | bias=False),
105 | Norm(num_features=64*min(2**self.scale, 16)),
106 | Activate()])
107 |
108 |
109 | class LatentLayer(nn.Module):
110 | def __init__(self, in_channels, out_channels):
111 | super(LatentLayer, self).__init__()
112 | self.in_channels = in_channels
113 | self.out_channels = out_channels
114 | self.build()
115 |
116 | def forward(self, input):
117 | x = input
118 | for layer in self.sub_layers:
119 | x = layer(x)
120 | return x
121 |
122 | def build(self):
123 | self.sub_layers = nn.ModuleList([
124 | nn.Conv2d(
125 | in_channels=self.in_channels,
126 | out_channels=self.out_channels,
127 | kernel_size=1,
128 | stride=1,
129 | padding=0,
130 | bias=True)
131 | ])
132 |
133 |
134 | class DecoderLayer(nn.Module):
135 | def __init__(self, scale, in_channels=None, norm='IN'):
136 | super().__init__()
137 | self.scale = scale
138 | self.norm = _norm_options[norm.lower()]
139 | if in_channels is not None:
140 | self.in_channels = in_channels
141 | else:
142 | self.in_channels = 64*min(2**(self.scale+1), 16)
143 | self.build()
144 |
145 | def forward(self, input):
146 | d = input
147 | for layer in self.sub_layers:
148 | d = layer(d)
149 | return d
150 |
151 | def build(self):
152 | Norm = functools.partial(self.norm, affine=True)
153 | Activate = lambda: nn.LeakyReLU(0.2)
154 | self.sub_layers = nn.ModuleList([
155 | nn.ConvTranspose2d(
156 | in_channels=self.in_channels,
157 | out_channels=64*min(2**self.scale, 16),
158 | kernel_size=4,
159 | stride=2,
160 | padding=1,
161 | bias=False),
162 | Norm(num_features=64*min(2**self.scale, 16)),
163 | Activate()])
164 |
165 |
166 | class DenseEncoderLayer(nn.Module):
167 | def __init__(self, scale, spatial_size, out_size, in_channels=None):
168 | super().__init__()
169 | self.scale = scale
170 | self.in_channels = 64*min(2**(self.scale-1), 16)
171 | if in_channels is not None:
172 | self.in_channels = in_channels
173 | self.out_channels = out_size
174 | self.kernel_size = spatial_size
175 | self.build()
176 |
177 | def forward(self, input):
178 | x = input
179 | for layer in self.sub_layers:
180 | x = layer(x)
181 | return x
182 |
183 | def build(self):
184 | self.sub_layers = nn.ModuleList([
185 | nn.Conv2d(
186 | in_channels=self.in_channels,
187 | out_channels=self.out_channels,
188 | kernel_size=self.kernel_size,
189 | stride=1,
190 | padding=0,
191 | bias=True)])
192 |
193 |
194 | class DenseDecoderLayer(nn.Module):
195 | def __init__(self, scale, spatial_size, in_size):
196 | super().__init__()
197 | self.scale = scale
198 | self.in_channels = in_size
199 | self.out_channels = 64*min(2**self.scale, 16)
200 | self.kernel_size = spatial_size
201 | self.build()
202 |
203 | def forward(self, input):
204 | x = input
205 | for layer in self.sub_layers:
206 | x = layer(x)
207 | return x
208 |
209 | def build(self):
210 | self.sub_layers = nn.ModuleList([
211 | nn.ConvTranspose2d(
212 | in_channels=self.in_channels,
213 | out_channels=self.out_channels,
214 | kernel_size=self.kernel_size,
215 | stride=1,
216 | padding=0,
217 | bias=True)])
218 |
219 |
220 | class ImageLayer(nn.Module):
221 | def __init__(self, out_channels=3, in_channels=64):
222 | super().__init__()
223 | self.in_channels = in_channels
224 | self.out_channels = out_channels
225 | self.build()
226 |
227 | def forward(self, input):
228 | x = input
229 | for layer in self.sub_layers:
230 | x = layer(x)
231 | return x
232 |
233 | def build(self):
234 | FinalActivate = lambda: torch.nn.Tanh()
235 | self.sub_layers = nn.ModuleList([
236 | nn.ConvTranspose2d(
237 | in_channels=self.in_channels,
238 | out_channels=self.out_channels,
239 | kernel_size=4,
240 | stride=2,
241 | padding=1,
242 | bias=False),
243 | FinalActivate()
244 | ])
245 |
246 |
247 | class Distribution(object):
248 | def __init__(self, parameters, deterministic=False):
249 | self.parameters = parameters
250 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
251 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0)
252 | self.deterministic = deterministic
253 | self.std = torch.exp(0.5*self.logvar)
254 | self.var = torch.exp(self.logvar)
255 | if self.deterministic:
256 | self.var = self.std = torch.zeros_like(self.mean).to(
257 | torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
258 |
259 | def sample(self):
260 | # draw the noise on the same device and dtype as the parameters
261 | x = self.mean + self.std*torch.randn_like(self.mean)
262 | return x
263 |
264 | def kl(self, other=None):
265 | if self.deterministic:
266 | return torch.Tensor([0.])
267 | else:
268 | if other is None:
269 | return 0.5*torch.sum(torch.pow(self.mean, 2)
270 | + self.var - 1.0 - self.logvar,
271 | dim=[1,2,3])
272 | else:
273 | return 0.5*torch.sum(
274 | torch.pow(self.mean - other.mean, 2) / other.var
275 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
276 | dim=[1,2,3])
277 |
278 | def nll(self, sample):
279 | if self.deterministic:
280 | return torch.Tensor([0.])
281 | logtwopi = np.log(2.0*np.pi)
282 | return 0.5*torch.sum(
283 | logtwopi+self.logvar+torch.pow(sample-self.mean, 2) / self.var,
284 | dim=[1,2,3])
285 |
286 | def mode(self):
287 | return self.mean
288 |
289 |
290 | class Model(nn.Module):
291 | def __init__(self, config):
292 | super().__init__()
293 | import torch.backends.cudnn as cudnn
294 | cudnn.benchmark = True
295 | n_down = retrieve(config, "Model/n_down")
296 | z_dim = retrieve(config, "Model/z_dim")
297 | in_size = retrieve(config, "Model/in_size")
298 | bottleneck_size = in_size // 2**n_down
299 | in_channels = retrieve(config, "Model/in_channels")
300 | norm = retrieve(config, "Model/norm")
301 | self.be_deterministic = retrieve(config, "Model/deterministic")
302 |
303 | self.feature_layers = nn.ModuleList()
304 | self.decoder_layers = nn.ModuleList()
305 |
306 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm))
307 | for scale in range(1, n_down):
308 | self.feature_layers.append(FeatureLayer(scale, norm=norm))
309 |
310 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, 2*z_dim)
311 | self.dense_decode = DenseDecoderLayer(n_down-1, bottleneck_size, z_dim)
312 |
313 | for scale in range(n_down-1):
314 | self.decoder_layers.append(DecoderLayer(scale, norm=norm))
315 | self.image_layer = ImageLayer(out_channels=in_channels)
316 |
317 | self.apply(weights_init)
318 |
319 | self.n_down = n_down
320 | self.z_dim = z_dim
321 | self.bottleneck_size = bottleneck_size
322 |
323 | def encode(self, input):
324 | h = input
325 | for layer in self.feature_layers:
326 | h = layer(h)
327 | h = self.dense_encode(h)
328 | return Distribution(h, deterministic=self.be_deterministic)
329 |
330 | def decode(self, input):
331 | h = input
332 | h = self.dense_decode(h)
333 | for layer in reversed(self.decoder_layers):
334 | h = layer(h)
335 | h = self.image_layer(h)
336 | return h
337 |
338 | def get_last_layer(self):
339 | return self.image_layer.sub_layers[0].weight
340 |
--------------------------------------------------------------------------------
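
A minimal sketch of the autoencoder interface above: encode a batch to a `Distribution`, sample a latent code, and decode it back. The hyperparameters are illustrative placeholders, not the values from `configs/*_ae.yaml`:

```
import torch
from iin.models.ae import Model

config = {"Model": {"n_down": 4, "z_dim": 64, "in_size": 32,
                    "in_channels": 3, "norm": "an", "deterministic": False}}
device = "cuda" if torch.cuda.is_available() else "cpu"
ae = Model(config).to(device)

x = torch.randn(8, 3, 32, 32, device=device)
posterior = ae.encode(x)           # Distribution with 64-dim mean/logvar at 1x1 resolution
z = posterior.sample()             # posterior.mode() gives the deterministic code
x_rec = ae.decode(z)
print(z.shape, x_rec.shape, posterior.kl().shape)  # (8, 64, 1, 1), (8, 3, 32, 32), (8,)
```
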
/iin/models/clf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from edflow.util import retrieve
5 |
6 | from iin.models.ae import FeatureLayer, DenseEncoderLayer, weights_init
7 |
8 |
9 | class Distribution(object):
10 | def __init__(self, value):
11 | self.value = value
12 |
13 | def sample(self):
14 | return self.value
15 |
16 | def mode(self):
17 | return self.value
18 |
19 |
20 | class Model(nn.Module):
21 | def __init__(self, config):
22 | super().__init__()
23 | import torch.backends.cudnn as cudnn
24 | cudnn.benchmark = True
25 | n_down = retrieve(config, "Model/n_down")
26 | z_dim = retrieve(config, "Model/z_dim")
27 | in_size = retrieve(config, "Model/in_size")
28 | z_dim = retrieve(config, "Model/z_dim")
29 | bottleneck_size = in_size // 2**n_down
30 | in_channels = retrieve(config, "Model/in_channels")
31 | norm = retrieve(config, "Model/norm")
32 | n_classes = retrieve(config, "n_classes")
33 |
34 | self.feature_layers = nn.ModuleList()
35 |
36 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm))
37 | for scale in range(1, n_down):
38 | self.feature_layers.append(FeatureLayer(scale, norm=norm))
39 |
40 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, z_dim)
41 | self.classifier = torch.nn.Linear(z_dim, n_classes)
42 |
43 | self.apply(weights_init)
44 |
45 | self.n_down = n_down
46 | self.bottleneck_size = bottleneck_size
47 |
48 | def forward(self, input):
49 | h = self.encode(input).mode()
50 | assert h.shape[2] == h.shape[3] == 1
51 | h = h[:,:,0,0]
52 | h = self.classifier(h)
53 | return h
54 |
55 | def encode(self, input):
56 | h = input
57 | for layer in self.feature_layers:
58 | h = layer(h)
59 | h = self.dense_encode(h)
60 | return Distribution(h)
61 |
--------------------------------------------------------------------------------
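
A corresponding sketch for the classifier model above, again with illustrative placeholder values rather than those from `configs/cmnist_clf.yaml`:

```
import torch
from iin.models.clf import Model

config = {"Model": {"n_down": 4, "z_dim": 64, "in_size": 32,
                    "in_channels": 3, "norm": "an"},
          "n_classes": 10}
clf = Model(config)
logits = clf(torch.randn(8, 3, 32, 32))
print(logits.shape)  # (8, 10)
```
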
/iin/models/iin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from edflow.util import retrieve, get_obj_from_str
5 |
6 |
7 | class Shuffle(nn.Module):
8 | def __init__(self, in_channels, **kwargs):
9 | super(Shuffle, self).__init__()
10 | self.in_channels = in_channels
11 | idx = torch.randperm(in_channels)
12 | self.register_buffer('forward_shuffle_idx', nn.Parameter(idx, requires_grad=False))
13 | self.register_buffer('backward_shuffle_idx', nn.Parameter(torch.argsort(idx), requires_grad=False))
14 |
15 | def forward(self, x, reverse=False, conditioning=None):
16 | if not reverse:
17 | return x[:, self.forward_shuffle_idx, ...], 0
18 | else:
19 | return x[:, self.backward_shuffle_idx, ...]
20 |
21 |
22 | class BasicFullyConnectedNet(nn.Module):
23 | def __init__(self, dim, depth, hidden_dim=256, use_tanh=False, use_bn=False, out_dim=None):
24 | super(BasicFullyConnectedNet, self).__init__()
25 | layers = []
26 | layers.append(nn.Linear(dim, hidden_dim))
27 | if use_bn:
28 | layers.append(nn.BatchNorm1d(hidden_dim))
29 | layers.append(nn.LeakyReLU())
30 | for d in range(depth):
31 | layers.append(nn.Linear(hidden_dim, hidden_dim))
32 | if use_bn:
33 | layers.append(nn.BatchNorm1d(hidden_dim))
34 | layers.append(nn.LeakyReLU())
35 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim))
36 | if use_tanh:
37 | layers.append(nn.Tanh())
38 | self.main = nn.Sequential(*layers)
39 |
40 | def forward(self, x):
41 | return self.main(x)
42 |
43 |
44 | class DoubleVectorCouplingBlock(nn.Module):
45 | """In contrast to VectorCouplingBlock, this module assures alternating chunking in upper and lower half."""
46 | def __init__(self, in_channels, hidden_dim, depth=2, use_hidden_bn=False, n_blocks=2):
47 | super(DoubleVectorCouplingBlock, self).__init__()
48 | assert in_channels % 2 == 0
49 | self.s = nn.ModuleList([BasicFullyConnectedNet(dim=in_channels // 2, depth=depth, hidden_dim=hidden_dim,
50 | use_tanh=True) for _ in range(n_blocks)])
51 | self.t = nn.ModuleList([BasicFullyConnectedNet(dim=in_channels // 2, depth=depth, hidden_dim=hidden_dim,
52 | use_tanh=False) for _ in range(n_blocks)])
53 |
54 | def forward(self, x, reverse=False):
55 | if not reverse:
56 | logdet = 0
57 | for i in range(len(self.s)):
58 | idx_apply, idx_keep = 0, 1
59 | if i % 2 != 0:
60 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
61 | x = torch.chunk(x, 2, dim=1)
62 | scale = self.s[i](x[idx_apply])
63 | x_ = x[idx_keep] * (scale.exp()) + self.t[i](x[idx_apply])
64 | x = torch.cat((x[idx_apply], x_), dim=1)
65 | logdet_ = torch.sum(scale.view(x.size(0), -1), dim=1)
66 | logdet = logdet + logdet_
67 | return x, logdet
68 | else:
69 | idx_apply, idx_keep = 0, 1
70 | for i in reversed(range(len(self.s))):
71 | if i % 2 == 0:
72 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
73 | x = torch.chunk(x, 2, dim=1)
74 | x_ = (x[idx_keep] - self.t[i](x[idx_apply])) * (self.s[i](x[idx_apply]).neg().exp())
75 | x = torch.cat((x[idx_apply], x_), dim=1)
76 | return x
77 |
78 |
79 | class VectorActNorm(nn.Module):
80 | def __init__(self, in_channel, logdet=True, **kwargs):
81 | super().__init__()
82 | self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))
83 | self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))
84 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
85 | self.logdet = logdet
86 |
87 | def initialize(self, input):
88 | if len(input.shape) == 2:
89 | input = input[:, :, None, None]
90 | with torch.no_grad():
91 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
92 | mean = (
93 | flatten.mean(1)
94 | .unsqueeze(1)
95 | .unsqueeze(2)
96 | .unsqueeze(3)
97 | .permute(1, 0, 2, 3)
98 | )
99 | std = (
100 | flatten.std(1)
101 | .unsqueeze(1)
102 | .unsqueeze(2)
103 | .unsqueeze(3)
104 | .permute(1, 0, 2, 3)
105 | )
106 |
107 | self.loc.data.copy_(-mean)
108 | self.scale.data.copy_(1 / (std + 1e-6))
109 |
110 | def forward(self, input, reverse=False, conditioning=None):
111 | if len(input.shape) == 2:
112 | input = input[:, :, None, None]
113 | if not reverse:
114 | _, _, height, width = input.shape
115 | if self.initialized.item() == 0:
116 | self.initialize(input)
117 | self.initialized.fill_(1)
118 | log_abs = torch.log(torch.abs(self.scale))
119 | logdet = height * width * torch.sum(log_abs)
120 | logdet = logdet * torch.ones(input.shape[0]).to(input)
121 | if not self.logdet:
122 | return (self.scale * (input + self.loc)).squeeze()
123 | return (self.scale * (input + self.loc)).squeeze(), logdet
124 | else:
125 | return self.reverse(input)
126 |
127 | def reverse(self, output, conditioning=None):
128 | return (output / self.scale - self.loc).squeeze()
129 |
130 |
131 | class Flow(nn.Module):
132 | def __init__(self, module_list, in_channels, hidden_dim, hidden_depth):
133 | super(Flow, self).__init__()
134 | self.in_channels = in_channels
135 | self.flow = nn.ModuleList(
136 | [module(in_channels, hidden_dim=hidden_dim, depth=hidden_depth) for module in module_list])
137 |
138 | def forward(self, x, condition=None, reverse=False):
139 | if not reverse:
140 | logdet = 0
141 | for i in range(len(self.flow)):
142 | x, logdet_ = self.flow[i](x)
143 | logdet = logdet + logdet_
144 | return x, logdet
145 | else:
146 | for i in reversed(range(len(self.flow))):
147 | x = self.flow[i](x, reverse=True)
148 | return x
149 |
150 |
151 | class EfficientVRNVP(nn.Module):
152 | def __init__(self, module_list, in_channels, n_flow, hidden_dim, hidden_depth):
153 | super().__init__()
154 | assert in_channels % 2 == 0
155 | self.flow = nn.ModuleList([Flow(module_list, in_channels, hidden_dim, hidden_depth) for n in range(n_flow)])
156 |
157 | def forward(self, x, condition=None, reverse=False):
158 | if not reverse:
159 | logdet = 0
160 | for i in range(len(self.flow)):
161 | x, logdet_ = self.flow[i](x, condition=condition)
162 | logdet = logdet + logdet_
163 | return x, logdet
164 | else:
165 | for i in reversed(range(len(self.flow))):
166 | x = self.flow[i](x, condition=condition, reverse=True)
167 | return x, None
168 |
169 | def reverse(self, x, condition=None):
170 | return self(x, condition=condition, reverse=True)  # an nn.ModuleList is not callable; dispatch through forward()
171 |
172 |
173 | class VectorTransformer(nn.Module):
174 | def __init__(self, config):
175 | super().__init__()
176 | import torch.backends.cudnn as cudnn
177 | cudnn.benchmark = True
178 | self.config = config
179 |
180 | self.in_channel = retrieve(config, "Transformer/in_channel")
181 | self.n_flow = retrieve(config, "Transformer/n_flow")
182 | self.depth_submodules = retrieve(config, "Transformer/hidden_depth")
183 | self.hidden_dim = retrieve(config, "Transformer/hidden_dim")
184 | modules = [VectorActNorm, DoubleVectorCouplingBlock, Shuffle]
185 | self.realnvp = EfficientVRNVP(modules, self.in_channel, self.n_flow, self.hidden_dim,
186 | hidden_depth=self.depth_submodules)
187 |
188 | def forward(self, input, reverse=False):
189 | if reverse:
190 | return self.reverse(input)
191 | input = input.squeeze()
192 | out, logdet = self.realnvp(input)
193 | return out[:, :, None, None], logdet
194 |
195 | def reverse(self, out):
196 | out = out.squeeze()
197 | return self.realnvp(out, reverse=True)[0][:, :, None, None]
198 |
199 |
200 | class FactorTransformer(VectorTransformer):
201 | def __init__(self, config):
202 | super().__init__(config)
203 | self.n_factors = retrieve(config, "Transformer/n_factors", default=2)
204 | self.factor_config = retrieve(config, "Transformer/factor_config", default=list())
205 |
206 | def forward(self, input):
207 | out, logdet = super().forward(input)
208 | if self.factor_config:
209 | out = torch.split(out, self.factor_config, dim=1)
210 | else:
211 | out = torch.chunk(out, self.n_factors, dim=1)
212 | return out, logdet
213 |
214 | def reverse(self, out):
215 | out = torch.cat(out, dim=1)
216 | return super().reverse(out)
217 |
--------------------------------------------------------------------------------
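
A minimal sketch exercising the invertible transformers above: the forward pass returns the transformed latents together with a per-sample log-determinant, `reverse` inverts the flow, and `FactorTransformer` additionally splits the transformed code into factor chunks. Hyperparameters are illustrative placeholders; the real ones are in the `*_iin.yaml` and `*_diin.yaml` configs:

```
import torch
from iin.models.iin import VectorTransformer, FactorTransformer

config = {"Transformer": {"in_channel": 64, "n_flow": 6,
                          "hidden_depth": 2, "hidden_dim": 256}}
transformer = VectorTransformer(config)

z = torch.randn(8, 64, 1, 1)            # autoencoder latents, shape (batch, z_dim, 1, 1)
z2, logdet = transformer(z)             # forward: transformed latents and per-sample log|det|
z_back = transformer.reverse(z2)        # invert the flow
print(torch.allclose(z, z_back, atol=1e-4), logdet.shape)  # True, torch.Size([8])

# The disentangling variant splits the transformed code into factors:
config["Transformer"]["n_factors"] = 4
factor_transformer = FactorTransformer(config)
factors, logdet = factor_transformer(z)
print([tuple(f.shape) for f in factors])  # four chunks of shape (8, 16, 1, 1)
```
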
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='iin',
4 | version='0.1',
5 | description='Code accompanying the paper '
6 | 'A Disentangling Invertible Interpretation Network for Explaining Latent Representations, '
7 | 'https://arxiv.org/abs/2004.13166',
8 | author='Esser, Patrick and Rombach, Robin and Ommer, Bjoern',
9 | packages=find_packages(),
10 | install_requires=[
11 | 'torch>=1.4.0',
12 | 'torchvision>=0.5',
13 | 'numpy>=1.17',
14 | 'scipy>=1.0.1'
15 | ],
16 | zip_safe=False)
17 |
--------------------------------------------------------------------------------