├── README.md
├── assets
│   ├── superresolutionfigure.png
│   ├── teaser.png
│   ├── texttoimage.jpg
│   └── unpairedtranslationfigure.png
├── configs
│   ├── autoencoder
│   │   ├── anime_photography_256.yaml
│   │   ├── celeba_celebahq_ffhq_256.yaml
│   │   ├── coco256.yaml
│   │   ├── faces256.yaml
│   │   ├── faces32.yaml
│   │   └── portraits_photography_256.yaml
│   ├── creativity
│   │   ├── anime_photography_256.yaml
│   │   ├── celeba_celebahq_ffhq_256.yaml
│   │   └── portraits_photography_256.yaml
│   └── translation
│       ├── faces32-to-faces256.yaml
│       ├── sbert-to-ae-coco256.yaml
│       ├── sbert-to-bigbigan.yaml
│       └── sbert-to-biggan256.yaml
├── data
│   ├── WORDMAP_coco_5_cap_per_img_5_min_word_freq.json
│   ├── animegwerncroptrain.txt
│   ├── animegwerncropvalidation.txt
│   ├── celebahqtrain.txt
│   ├── celebahqvalidation.txt
│   ├── coco_imagenet_overlap_idx.txt
│   ├── examples
│   │   ├── anime
│   │   │   ├── 176890.jpg
│   │   │   ├── 1931960.jpg
│   │   │   ├── 2775531.jpg
│   │   │   ├── 3007790.jpg
│   │   │   ├── 331480.jpg
│   │   │   ├── 348930.jpg
│   │   │   ├── 4460500.jpg
│   │   │   ├── 499130.jpg
│   │   │   ├── 519280.jpg
│   │   │   ├── 656160.jpg
│   │   │   ├── 708030.jpg
│   │   │   ├── 881700.jpg
│   │   │   ├── 903790.jpg
│   │   │   ├── 910920.jpg
│   │   │   └── 949780.jpg
│   │   ├── humanface
│   │   │   ├── 00010.png
│   │   │   ├── 00012.png
│   │   │   ├── 00018.png
│   │   │   ├── 00048.png
│   │   │   ├── 00059.png
│   │   │   ├── 65710.png
│   │   │   ├── 65719.png
│   │   │   ├── 65732.png
│   │   │   ├── 65831.png
│   │   │   ├── 65843.png
│   │   │   ├── 65855.png
│   │   │   ├── 65946.png
│   │   │   ├── 65949.png
│   │   │   ├── 65990.png
│   │   │   ├── 65998.png
│   │   │   ├── 65999.png
│   │   │   ├── 66040.png
│   │   │   ├── 66048.png
│   │   │   ├── 66056.png
│   │   │   ├── 66090.png
│   │   │   ├── 66100.png
│   │   │   ├── 66118.png
│   │   │   └── 66157.png
│   │   └── oilportrait
│   │       ├── beethoven.jpeg
│   │       ├── descartes.jpeg
│   │       ├── fermat.jpg
│   │       ├── galileo.jpeg
│   │       ├── jeanjaques.jpeg
│   │       ├── kant.jpeg
│   │       ├── monalisa.jpeg
│   │       └── newton.jpeg
│   ├── ffhqtrain.txt
│   ├── ffhqvalidation.txt
│   ├── portraitstrain.txt
│   └── portraitsvalidation.txt
├── env_bigbigan.yaml
├── environment.yaml
├── ml4cad.py
├── net2net
│   ├── __init__.py
│   ├── ckpt_util.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── coco.py
│   │   ├── faces.py
│   │   ├── utils.py
│   │   └── zcodes.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   └── flows
│   │       ├── __init__.py
│   │       ├── flow.py
│   │       └── util.py
│   └── modules
│       ├── __init__.py
│       ├── autoencoder
│       │   ├── __init__.py
│       │   ├── basic.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   ├── loss.py
│       │   └── lpips.py
│       ├── captions
│       │   ├── __init__.py
│       │   ├── model.py
│       │   └── models.py
│       ├── discriminator
│       │   ├── __init__.py
│       │   └── model.py
│       ├── distributions
│       │   ├── __init__.py
│       │   └── distributions.py
│       ├── facenet
│       │   ├── __init__.py
│       │   ├── inception_resnet_v1.py
│       │   └── model.py
│       ├── flow
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── blocks.py
│       │   ├── flatflow.py
│       │   └── loss.py
│       ├── gan
│       │   ├── __init__.py
│       │   ├── bigbigan.py
│       │   └── biggan.py
│       ├── labels
│       │   ├── __init__.py
│       │   └── model.py
│       ├── mlp
│       │   ├── __init__.py
│       │   └── models.py
│       ├── sbert
│       │   ├── __init__.py
│       │   └── model.py
│       └── util.py
├── setup.py
└── translation.py
/README.md:
--------------------------------------------------------------------------------
1 | # Net2Net
2 | Code accompanying the NeurIPS 2020 oral paper
3 |
4 | [**Network-to-Network Translation with Conditional Invertible Neural Networks**](https://compvis.github.io/net2net/)
5 | [Robin Rombach](https://github.com/rromb)\*,
6 | [Patrick Esser](https://github.com/pesser)\*,
7 | [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
8 | \* equal contribution
9 |
10 | **tl;dr** Our approach distills the residual information of one model with respect to
11 | another's and thereby enables translation between fixed off-the-shelf expert
12 | models such as BERT and BigGAN without having to modify or finetune them.
13 |
14 | ![teaser](assets/teaser.png)
15 | [arXiv](https://arxiv.org/abs/2005.13580) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/net2net/)
16 |
17 | **News Dec 19th, 2020**: added SBERT-to-BigGAN, SBERT-to-BigBiGAN and SBERT-to-AE (COCO)
18 | ## Requirements
19 | A suitable [conda](https://conda.io/) environment named `net2net` can be created
20 | and activated with:
21 |
22 | ```
23 | conda env create -f environment.yaml
24 | conda activate net2net
25 | ```
26 |
27 | ## Datasets
28 | - **CelebA**: Create a symlink `data/CelebA` pointing to a folder which contains the following files:
29 | ```
30 | .
31 | ├── identity_CelebA.txt
32 | ├── img_align_celeba
33 | ├── list_attr_celeba.txt
34 | └── list_eval_partition.txt
35 | ```
36 | These files can be obtained [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
37 | - **CelebA-HQ**: Create a symlink `data/celebahq` pointing to a folder containing
38 | the `.npy` files of CelebA-HQ (instructions to obtain them can be found in
39 | the [PGGAN repository](https://github.com/tkarras/progressive_growing_of_gans)).
40 | - **FFHQ**: Create a symlink `data/ffhq` pointing to the `images1024x1024` folder
41 | obtained from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
42 | - **Anime Faces**: First download the face images from the [Anime Crop dataset](https://www.gwern.net/Crops) and then apply
43 | the preprocessing of [FFHQ](https://github.com/NVlabs/ffhq-dataset) to those images. We only keep images
44 | where the underlying [dlib face recognition model](http://dlib.net/face_landmark_detection.py.html) recognizes
45 | a face. Finally, create a symlink `data/anime` pointing to the processed anime face images.
46 | - **Oil Portraits**: [Download here.](https://heibox.uni-heidelberg.de/f/4f35bdc16eea4158aa47/?dl=1)
47 | Unpack the content and place the files in `data/portraits`. It consists of
48 | 18k oil portraits, which were obtained by running [dlib](http://dlib.net/face_landmark_detection.py.html) on a subset of the [WikiArt](https://www.wikiart.org/)
49 | dataset, kindly provided by [A Style-Aware Content Loss for Real-time HD Style Transfer](https://github.com/CompVis/adaptive-style-transfer).
50 | - **COCO**: Create a symlink `data/coco` containing the images from the 2017
51 | split in `train2017` and `val2017`, and their annotations in `annotations`.
52 | Files can be obtained from the [COCO webpage](https://cocodataset.org).
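
For the symlinks, something along these lines works (the source paths are placeholders for wherever you stored the downloads):

```
import os

os.makedirs("data", exist_ok=True)
# Adjust the source paths; each symlink points at the dataset described above.
os.symlink("/path/to/CelebA", "data/CelebA")
os.symlink("/path/to/celebahq_npy_files", "data/celebahq")
os.symlink("/path/to/ffhq/images1024x1024", "data/ffhq")
```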
53 |
54 | ## ML4Creativity Demo
55 | We include a [streamlit](https://www.streamlit.io/) demo, which utilizes our
56 | approach to demonstrate biases of datasets and their creative applications.
57 | More information can be found in our paper [A Note on Data Biases in Generative
58 | Models](https://drive.google.com/file/d/1PGhBTEAgj2A_FnYMk_1VU-uOxcWY076B/view?usp=sharing) from the [Machine Learning for Creativity and Design](https://neurips2020creativity.github.io/) workshop at [NeurIPS 2020](https://nips.cc/Conferences/2020). Download the models from
59 |
60 | - [2020-11-30T23-32-28_celeba_celebahq_ffhq_256](https://k00.fr/lro927bu)
61 | - [2020-12-02T13-58-19_anime_photography_256](https://heibox.uni-heidelberg.de/d/075e81e16de948aea7a1/)
62 | - [2020-12-02T16-19-39_portraits_photography_256](https://k00.fr/y3rvnl3j)
63 |
64 |
65 | and place them into `logs`. Run the demo with
66 |
67 | ```
68 | streamlit run ml4cad.py
69 | ```
70 |
71 | ## Training
72 | Our code uses [Pytorch-Lightning](https://www.pytorchlightning.ai/) and thus natively supports
73 | things like 16-bit precision, multi-GPU training and gradient accumulation. Training details for any model need to be specified in a dedicated `.yaml` file.
74 | In general, such a config file is structured as follows:
75 | ```
76 | model:
77 |   base_learning_rate: 4.5e-6
78 |   target:
79 |   params:
80 |     ...
81 | data:
82 |   target: translation.DataModuleFromConfig
83 |   params:
84 |     batch_size: ...
85 |     num_workers: ...
86 |     train:
87 |       target:
88 |       params:
89 |         ...
90 |     validation:
91 |       target:
92 |       params:
93 |         ...
94 | ```
95 | Any Pytorch-Lightning model specified under `model.target` is then trained on the specified data
96 | by running the command:
97 | ```
98 | python translation.py --base <config.yaml> -t --gpus 0,
99 | ```
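
Under the hood, a `target` string is imported and instantiated with the corresponding `params`. A minimal sketch of this mechanism (illustrative only, not necessarily the exact helper used in `translation.py`):

```
import importlib

def instantiate_from_config(config: dict):
    """Import the class named in config["target"] and build it from config["params"]."""
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

# e.g. {"target": "net2net.models.autoencoder.BasicAE", "params": {...}}
# yields BasicAE(...); the same pattern applies to the entries in the data section.
```
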
100 | All available Pytorch-Lightning [trainer](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html) arguments can be added via the command line, e.g. run
101 | ```
102 | python translation.py --base <config.yaml> -t --gpus 0,1,2,3 --precision 16 --accumulate_grad_batches 2
103 | ```
104 | to train a model on 4 GPUs using 16-bit precision and a 2-step gradient accumulation.
105 | More details are provided in the examples below.
106 |
107 | ### Training a cINN
108 | Training a cINN for network-to-network translation usually utilizes the Lightning module `net2net.models.flows.flow.Net2NetFlow`
109 | and makes a few further assumptions on the configuration file and model interface:
110 | ```
111 | model:
112 |   base_learning_rate: 4.5e-6
113 |   target: net2net.models.flows.flow.Net2NetFlow
114 |   params:
115 |     flow_config:
116 |       target:
117 |       params:
118 |         ...
119 |
120 |     cond_stage_config:
121 |       target:
122 |       params:
123 |         ...
124 |
125 |     first_stage_config:
126 |       target:
127 |       params:
128 |         ...
129 | ```
130 | Here, the entries under `flow_config` specify the architecture and parameters of the conditional INN;
131 | `cond_stage_config` specifies the first network whose representation is to be translated into another network
132 | specified by `first_stage_config`. Our model `net2net.models.flows.flow.Net2NetFlow` expects that the first
133 | network has a `.encode()` method which produces the representation of interest, while the second network should
134 | have an `encode()` and a `decode()` method, such that both applied sequentially reproduce the network's output. This allows for a modular combination of arbitrary models of interest. For more details, see the examples below.
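
In code, the assumed expert interface looks roughly like this (an illustrative sketch with placeholder classes, not models shipped with this repository):

```
import torch
import torch.nn as nn

class ConditioningExpert(nn.Module):
    """First network (cond_stage): only encode() is needed."""
    def encode(self, c) -> torch.Tensor:
        raise NotImplementedError  # return the representation to condition on

class GenerativeExpert(nn.Module):
    """Second network (first_stage): encode() and decode() applied in
    sequence should reproduce the network's output."""
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # latent the cINN is trained to model
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # map a (translated) latent back to an output
```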
135 |
136 | ### Training a cINN - Superresolution
137 | ![superresolution](assets/superresolutionfigure.png)
138 | Training details for a cINN to concatenate two autoencoders from different image scales for stochastic
139 | superresolution are specified in `configs/translation/faces32-to-faces256.yaml`.
140 |
141 | To train a model for translating from 32 x 32 images to 256 x 256 images on GPU 0, run
142 | ```
143 | python translation.py --base configs/translation/faces32-to-faces256.yaml -t --gpus 0,
144 | ```
145 | and specify any additional training commands as described above. Note that this setup requires two
146 | pretrained autoencoder models, one on 32 x 32 images and the other on 256 x 256. If you want to
147 | train them yourself on a combination of FFHQ and CelebA-HQ, run
148 | ```
149 | python translation.py --base configs/autoencoder/faces32.yaml -t --gpus 0,
150 | ```
151 | for the 32 x 32 images; and
152 | ```
153 | python translation.py --base configs/autoencoder/faces256.yaml -t --gpus 0,
154 | ```
155 | for the model on 256 x 256 images. After training, adapt the corresponding model paths in `configs/translation/faces32-to-faces256.yaml`. Additionally, we provide weights of pretrained autoencoders for both settings:
156 | [Weights 32x32](https://heibox.uni-heidelberg.de/f/b0b103af8406467abe48/); [Weights 256x256](https://k00.fr/94lw2vlg).
157 | To run the training as described above, put them into
158 | `logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt` and
159 | `logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt`, respectively.
160 |
161 | ### Training a cINN - Unpaired Translation
162 | ![unpaired translation](assets/unpairedtranslationfigure.png)
163 | All training scenarios for unpaired translation are specified in the configs in `configs/creativity`.
164 | We provide code and pretrained autoencoder models for three different translation tasks:
165 | - **Anime** ⟷ **Photography**; see `configs/creativity/anime_photography_256.yaml`.
166 | Download autoencoder checkpoint ([Download Anime+Photography](https://heibox.uni-heidelberg.de/f/315c628c8b0e40238132/)) and place into `logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt`.
167 | - **Oil-Portrait** ⟷ **Photography**; see `configs/creativity/portraits_photography_256.yaml`
168 | Download autoencoder checkpoint ([Download Portrait+Photography](https://heibox.uni-heidelberg.de/f/4f9449418a2e4025bb5f/)) and place into `logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt`.
169 | - **FFHQ** ⟷ **CelebA-HQ** ⟷ **CelebA**; see `configs/creativity/celeba_celebahq_ffhq_256.yaml`
170 | Download autoencoder checkpoint ([Download FFHQ+CelebAHQ+CelebA](https://k00.fr/94lw2vlg)) and place into `logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt`.
171 | Note that this is the same autoencoder checkpoint as for the stochastic superresolution experiment.
172 |
173 | To train a cINN on one of these unpaired transfer tasks using the first GPU, simply run
174 | ```
175 | python translation.py --base configs/creativity/<config>.yaml -t --gpus 0,
176 | ```
177 | where `<config>.yaml` is one of `portraits_photography_256.yaml`, `celeba_celebahq_ffhq_256.yaml`
178 | or `anime_photography_256.yaml`. Providing additional arguments to the pytorch-lightning
179 | trainer object is also possible as described above.
180 |
181 | In our framework, unpaired translation between domains is formulated as a
182 | translation between expert 1, a model which can infer the domain a given image
183 | belongs to, and expert 2, a model which can synthesize images of each domain.
184 | In the examples provided, we assume that the domain label comes with the
185 | dataset and provide the `net2net.modules.labels.model.Labelator` module, which
186 | simply returns a one-hot encoding of this label. However, one could also use a
187 | classification model which infers the domain label from the image itself.
188 | For expert 2, our examples use an autoencoder trained jointly on all domains,
189 | which is easily achieved by concatenating datasets together. The provided
190 | `net2net.data.base.ConcatDatasetWithIndex` concatenates datasets and returns
191 | the corresponding dataset label for each example, which can then be used by the
192 | `Labelator` class for the translation. The training configurations for the
193 | autoencoders used in the creativity experiments are included in
194 | `configs/autoencoder/anime_photography_256.yaml`,
195 | `configs/autoencoder/celeba_celebahq_ffhq_256.yaml` and
196 | `configs/autoencoder/portraits_photography_256.yaml`.
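
Putting these pieces together, the conditioning for the cINN is assumed to be produced roughly as follows (a minimal sketch; the actual `Labelator` implementation lives in `net2net/modules/labels/model.py`):

```
import torch

def one_hot(dataset_idx: int, num_classes: int) -> torch.Tensor:
    """What the Labelator expert is described to do: dataset index -> one-hot code."""
    code = torch.zeros(num_classes)
    code[dataset_idx] = 1.0
    return code

# ConcatDatasetWithIndex([anime, photos]) yields (example, 0) for anime samples and
# (example, 1) for photos; one_hot(1, 2) == tensor([0., 1.]) then conditions the cINN.
```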
197 |
198 | #### Unpaired Translation on Custom Datasets
199 | To run unpaired translation on custom datasets:
200 | 1. Create pytorch datasets for each of your domains and combine them with
201 |    `ConcatDatasetWithIndex` (follow the example in `net2net.data.faces.CCFQTrain`; see the sketch below).
202 | 2. Train an autoencoder on the concatenated dataset (adjust the `data` section in
203 |    `configs/autoencoder/celeba_celebahq_ffhq_256.yaml`).
204 | 3. Train a net2net translation model between a `Labelator` and your autoencoder
205 |    (adjust the sections `data` and `first_stage_config` in
206 |    `configs/creativity/celeba_celebahq_ffhq_256.yaml`).
207 | You can then also add your new model to the available modes in the `ml4cad.py`
208 | demo to visualize the results.
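
A minimal sketch of step 1 (class names and paths are hypothetical; `ImagePaths` and `ConcatDatasetWithIndex` come from `net2net/data/base.py`, and the `"class"` key mirrors the `cond_stage_key` used in the creativity configs):

```
import glob
from torch.utils.data import Dataset
from net2net.data.base import ImagePaths, ConcatDatasetWithIndex

class DomainA(Dataset):
    def __init__(self, size=256):
        paths = sorted(glob.glob("data/domain_a/*.png"))
        self.data = ImagePaths(paths=paths, size=size, random_crop=True)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]

class DomainB(DomainA):
    def __init__(self, size=256):
        paths = sorted(glob.glob("data/domain_b/*.png"))
        self.data = ImagePaths(paths=paths, size=size, random_crop=True)

class CustomConcatTrain(Dataset):
    """Concatenates both domains and exposes the domain index as a label."""
    def __init__(self, size=256):
        self.data = ConcatDatasetWithIndex([DomainA(size), DomainB(size)])
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        example, label = self.data[i]  # label is 0 or 1, depending on the domain
        example["class"] = label       # consumed by the Labelator expert
        return example
```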
209 |
210 |
211 | ### Training a cINN - Text-to-Image
212 | ![text-to-image](assets/texttoimage.jpg)
213 | We provide code to obtain a text-to-image model by translating between a text
214 | model ([SBERT](https://www.sbert.net/)) and an image decoder. To show the
215 | flexibility of our approach, we include code for three different
216 | decoders: BigGAN, as described in the paper,
217 | [BigBiGAN](https://deepmind.com/research/open-source/BigBiGAN-Large-Scale-Adversarial-Representation-Learning),
218 | which is only available as a [tensorflow](https://www.tensorflow.org/) model
219 | and thus nicely shows how our approach can work with black-box experts, and an
220 | autoencoder.
221 |
222 | #### SBERT-to-BigGAN
223 | Train with
224 | ```
225 | python translation.py --base configs/translation/sbert-to-biggan256.yaml -t --gpus 0,
226 | ```
227 | When running it for the first time, the required models will be downloaded
228 | automatically.
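
The download helpers live in `net2net/ckpt_util.py`; if you prefer to pre-fetch a checkpoint yourself, the same utility can be called directly (the cache directory below is just an example):

```
from net2net.ckpt_util import get_ckpt_path

# Fetch the BigGAN-256 weights into a local cache directory ("checkpoints" is an
# arbitrary example path); check=True verifies the md5 sum of an existing file.
path = get_ckpt_path("biggan_256", "checkpoints", check=True)
print(path)  # checkpoints/biggan-256.pth
```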
229 |
230 | #### SBERT-to-BigBiGAN
231 | Since BigBiGAN is only available on
232 | [tensorflow-hub](https://tfhub.dev/s?q=bigbigan), this example has an
233 | additional dependency on tensorflow. A suitable environment is provided in
234 | `env_bigbigan.yaml`, and you will need COCO for training. You can then start
235 | training with
236 | ```
237 | python translation.py --base configs/translation/sbert-to-bigbigan.yaml -t --gpus 0,
238 | ```
239 | Note that the `BigBiGAN` class is just a naive wrapper, which converts pytorch
240 | tensors to numpy arrays, feeds them to the tensorflow graph and again converts
241 | the result to pytorch tensors. It does not require gradients of the expert
242 | model and serves as a good example of how to use black-box experts.
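
Reduced to a sketch, the bridging idea looks like this (the real wrapper is `net2net/modules/gan/bigbigan.py`; `tf_generate_fn` stands in for the TF-Hub call and is an assumption here):

```
import numpy as np
import torch

def blackbox_decode(tf_generate_fn, z: torch.Tensor) -> torch.Tensor:
    """Bridge a pytorch latent to a tensorflow expert and back."""
    z_np = z.detach().cpu().numpy()             # leave the autograd graph
    img_np = np.asarray(tf_generate_fn(z_np))   # run the frozen TF expert as-is
    return torch.from_numpy(img_np).to(z.device)  # no gradients through the expert
```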
243 |
244 | #### SBERT-to-AE
245 | Similarly to the other examples, you can also train your own autoencoder on
246 | COCO with
247 | ```
248 | python translation.py --base configs/autoencoder/coco256.yaml -t --gpus 0,
249 | ```
250 | or [download a pre-trained
251 | one](https://k00.fr/fbti4058), and translate
252 | to it by running
253 | ```
254 | python translation.py --base configs/translation/sbert-to-ae-coco256.yaml -t --gpus 0,
255 | ```
256 |
257 | ## Shout-outs
258 | Thanks to everyone who makes their code and models available.
259 |
260 | - BigGAN code and weights from: [LoreGoetschalckx/GANalyze](https://github.com/LoreGoetschalckx/GANalyze)
261 | - Code and weights for the captioning model: [https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning)
262 |
263 |
264 | ## BibTeX
265 |
266 | ```
267 | @misc{rombach2020networktonetwork,
268 | title={Network-to-Network Translation with Conditional Invertible Neural Networks},
269 | author={Robin Rombach and Patrick Esser and Björn Ommer},
270 | year={2020},
271 | eprint={2005.13580},
272 | archivePrefix={arXiv},
273 | primaryClass={cs.CV}
274 | }
275 | ```
276 |
277 | ```
278 | @misc{esser2020note,
279 | title={A Note on Data Biases in Generative Models},
280 | author={Patrick Esser and Robin Rombach and Björn Ommer},
281 | year={2020},
282 | eprint={2012.02516},
283 | archivePrefix={arXiv},
284 | primaryClass={cs.CV}
285 | }
286 | ```
287 |
--------------------------------------------------------------------------------
/assets/superresolutionfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/superresolutionfigure.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/teaser.png
--------------------------------------------------------------------------------
/assets/texttoimage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/texttoimage.jpg
--------------------------------------------------------------------------------
/assets/unpairedtranslationfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/unpairedtranslationfigure.png
--------------------------------------------------------------------------------
/configs/autoencoder/anime_photography_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.autoencoder.BigAE
4 | params:
5 | loss_config:
6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
7 | params:
8 | disc_start: 75001
9 | kl_weight: 0.000001
10 | disc_weight: 0.5
11 |
12 | encoder_config:
13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
14 | params:
15 | in_channels: 3
16 | in_size: 256
17 | pretrained: false
18 | type: resnet101
19 | z_dim: 128
20 |
21 | decoder_config:
22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
23 | params:
24 | z_dim: 128
25 | in_size: 256
26 | use_actnorm_in_dec: true
27 |
28 | data:
29 | target: translation.DataModuleFromConfig
30 | params:
31 | batch_size: 3
32 | train:
33 | target: net2net.data.faces.FacesHQAndAnimeTrain
34 | params:
35 | size: 256
36 | validation:
37 | target: net2net.data.faces.FacesHQAndAnimeValidation
38 | params:
39 | size: 256
40 |
--------------------------------------------------------------------------------
/configs/autoencoder/celeba_celebahq_ffhq_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.autoencoder.BigAE
4 | params:
5 | loss_config:
6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
7 | params:
8 | disc_start: 75001
9 | kl_weight: 0.000001
10 | disc_weight: 0.5
11 |
12 | encoder_config:
13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
14 | params:
15 | in_channels: 3
16 | in_size: 256
17 | pretrained: false
18 | type: resnet101
19 | z_dim: 128
20 |
21 | decoder_config:
22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
23 | params:
24 | z_dim: 128
25 | in_size: 256
26 | use_actnorm_in_dec: true
27 |
28 |
29 | data:
30 | target: translation.DataModuleFromConfig
31 | params:
32 | batch_size: 3
33 | train:
34 | target: net2net.data.faces.CCFQTrain
35 | params:
36 | size: 256
37 | validation:
38 | target: net2net.data.faces.CCFQValidation
39 | params:
40 | size: 256
41 |
--------------------------------------------------------------------------------
/configs/autoencoder/coco256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: net2net.models.autoencoder.BigAE
4 | params:
5 | loss_config:
6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
7 | params:
8 | disc_start: 250001
9 | kl_weight: 0.000001
10 | disc_weight: 0.5
11 |
12 | encoder_config:
13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
14 | params:
15 | in_channels: 3
16 | in_size: 256
17 | pretrained: false
18 | type: resnet101
19 | z_dim: 256
20 |
21 | decoder_config:
22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
23 | params:
24 | z_dim: 256
25 | in_size: 256
26 | use_actnorm_in_dec: true
27 |
28 | data:
29 | target: translation.DataModuleFromConfig
30 | params:
31 | batch_size: 3
32 | train:
33 | target: net2net.data.coco.CocoImagesAndCaptionsTrain
34 | params:
35 | size: 256
36 | validation:
37 | target: net2net.data.coco.CocoImagesAndCaptionsValidation
38 | params:
39 | size: 256
40 |
--------------------------------------------------------------------------------
/configs/autoencoder/faces256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.autoencoder.BigAE
4 | params:
5 | loss_config:
6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
7 | params:
8 | disc_start: 75001
9 | kl_weight: 0.000001
10 | disc_weight: 0.5
11 |
12 | encoder_config:
13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
14 | params:
15 | in_channels: 3
16 | in_size: 256
17 | pretrained: false
18 | type: resnet101
19 | z_dim: 128
20 |
21 | decoder_config:
22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
23 | params:
24 | z_dim: 128
25 | in_size: 256
26 | use_actnorm_in_dec: true
27 |
28 |
29 | data:
30 | target: translation.DataModuleFromConfig
31 | params:
32 | batch_size: 3
33 | train:
34 | target: net2net.data.faces.CelebFQTrain
35 | params:
36 | size: 256
37 | validation:
38 | target: net2net.data.faces.CelebFQValidation
39 | params:
40 | size: 256
41 |
--------------------------------------------------------------------------------
/configs/autoencoder/faces32.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.autoencoder.BasicAE
4 | params:
5 | ae_config:
6 | target: net2net.modules.autoencoder.basic.BasicAEModel
7 | params:
8 | in_size: 32
9 | n_down: 4
10 | z_dim: 128
11 | in_channels: 3
12 | deterministic: False
13 |
14 | loss_config:
15 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
16 | params:
17 | disc_start: 10001
18 | kl_weight: 0.000001
19 | disc_weight: 0.5
20 |
21 |
22 | data:
23 | target: translation.DataModuleFromConfig
24 | params:
25 | batch_size: 12
26 | train:
27 | target: net2net.data.faces.CelebFQTrain
28 | params:
29 | size: 32
30 | validation:
31 | target: net2net.data.faces.CelebFQValidation
32 | params:
33 | size: 32
34 |
--------------------------------------------------------------------------------
/configs/autoencoder/portraits_photography_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.autoencoder.BigAE
4 | params:
5 | loss_config:
6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator
7 | params:
8 | disc_start: 75001
9 | kl_weight: 0.000001
10 | disc_weight: 0.5
11 |
12 | encoder_config:
13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
14 | params:
15 | in_channels: 3
16 | in_size: 256
17 | pretrained: false
18 | type: resnet101
19 | z_dim: 128
20 |
21 | decoder_config:
22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
23 | params:
24 | z_dim: 128
25 | in_size: 256
26 | use_actnorm_in_dec: true
27 |
28 |
29 | data:
30 | target: translation.DataModuleFromConfig
31 | params:
32 | batch_size: 3
33 | train:
34 | target: net2net.data.faces.FFHQAndPortraitsTrain
35 | params:
36 | size: 256
37 | validation:
38 | target: net2net.data.faces.FFHQAndPortraitsValidation
39 | params:
40 | size: 256
41 |
--------------------------------------------------------------------------------
/configs/creativity/anime_photography_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "class"
7 | flow_config:
8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
9 | params:
10 | conditioning_dim: 2
11 | embedding_dim: 10
12 | conditioning_depth: 2
13 | n_flows: 20
14 | in_channels: 128
15 | hidden_dim: 1024
16 | hidden_depth: 2
17 | activation: "none"
18 | conditioner_use_bn: True
19 |
20 | cond_stage_config:
21 | target: net2net.modules.labels.model.Labelator
22 | params:
23 | num_classes: 2
24 | as_one_hot: True
25 |
26 | first_stage_config:
27 | target: net2net.models.autoencoder.BigAE
28 | params:
29 | ckpt_path: "logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt"
30 | encoder_config:
31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
32 | params:
33 | in_channels: 3
34 | in_size: 256
35 | pretrained: false
36 | type: resnet101
37 | z_dim: 128
38 |
39 | decoder_config:
40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
41 | params:
42 | z_dim: 128
43 | in_size: 256
44 | use_actnorm_in_dec: true
45 |
46 | loss_config:
47 | target: net2net.modules.autoencoder.loss.DummyLoss
48 |
49 | data:
50 | target: translation.DataModuleFromConfig
51 | params:
52 | batch_size: 15
53 | train:
54 | target: net2net.data.faces.FacesHQAndAnimeTrain
55 | params:
56 | size: 256
57 | validation:
58 | target: net2net.data.faces.FacesHQAndAnimeValidation
59 | params:
60 | size: 256
61 |
--------------------------------------------------------------------------------
/configs/creativity/celeba_celebahq_ffhq_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "class"
7 | flow_config:
8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
9 | params:
10 | conditioning_dim: 3
11 | embedding_dim: 10
12 | conditioning_depth: 2
13 | n_flows: 20
14 | in_channels: 128
15 | hidden_dim: 1024
16 | hidden_depth: 2
17 | activation: "none"
18 | conditioner_use_bn: True
19 |
20 | cond_stage_config:
21 | target: net2net.modules.labels.model.Labelator
22 | params:
23 | num_classes: 3
24 | as_one_hot: True
25 |
26 | first_stage_config:
27 | target: net2net.models.autoencoder.BigAE
28 | params:
29 | ckpt_path: "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt"
30 | encoder_config:
31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
32 | params:
33 | in_channels: 3
34 | in_size: 256
35 | pretrained: false
36 | type: resnet101
37 | z_dim: 128
38 |
39 | decoder_config:
40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
41 | params:
42 | z_dim: 128
43 | in_size: 256
44 | use_actnorm_in_dec: true
45 |
46 | loss_config:
47 | target: net2net.modules.autoencoder.loss.DummyLoss
48 |
49 | data:
50 | target: translation.DataModuleFromConfig
51 | params:
52 | batch_size: 15
53 | train:
54 | target: net2net.data.faces.CCFQTrain
55 | params:
56 | size: 256
57 | validation:
58 | target: net2net.data.faces.CCFQValidation
59 | params:
60 | size: 256
61 |
--------------------------------------------------------------------------------
/configs/creativity/portraits_photography_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "class"
7 | flow_config:
8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
9 | params:
10 | conditioning_dim: 2
11 | embedding_dim: 10
12 | conditioning_depth: 2
13 | n_flows: 20
14 | in_channels: 128
15 | hidden_dim: 1024
16 | hidden_depth: 2
17 | activation: "none"
18 | conditioner_use_bn: True
19 |
20 | cond_stage_config:
21 | target: net2net.modules.labels.model.Labelator
22 | params:
23 | num_classes: 2
24 | as_one_hot: True
25 |
26 | first_stage_config:
27 | target: net2net.models.autoencoder.BigAE
28 | params:
29 | ckpt_path: "logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt"
30 | encoder_config:
31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
32 | params:
33 | in_channels: 3
34 | in_size: 256
35 | pretrained: false
36 | type: resnet101
37 | z_dim: 128
38 |
39 | decoder_config:
40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
41 | params:
42 | z_dim: 128
43 | in_size: 256
44 | use_actnorm_in_dec: true
45 |
46 | loss_config:
47 | target: net2net.modules.autoencoder.loss.DummyLoss
48 |
49 | data:
50 | target: translation.DataModuleFromConfig
51 | params:
52 | batch_size: 15
53 | train:
54 | target: net2net.data.faces.FFHQAndPortraitsTrain
55 | params:
56 | size: 256
57 | validation:
58 | target: net2net.data.faces.FFHQAndPortraitsValidation
59 | params:
60 | size: 256
61 |
--------------------------------------------------------------------------------
/configs/translation/faces32-to-faces256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "image"
7 | interpolate_cond_size: 32
8 | flow_config:
9 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
10 | params:
11 | conditioning_dim: 128
12 | embedding_dim: 128
13 | conditioning_depth: 2
14 | n_flows: 20
15 | in_channels: 128
16 | hidden_dim: 1024
17 | hidden_depth: 2
18 | activation: "none"
19 | conditioner_use_bn: True
20 |
21 | cond_stage_config:
22 | target: net2net.models.autoencoder.BasicAE
23 | params:
24 | ckpt_path: "logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt"
25 | ae_config:
26 | target: net2net.modules.autoencoder.basic.BasicAEModel
27 | params:
28 | in_size: 32
29 | n_down: 4
30 | z_dim: 128
31 | in_channels: 3
32 | deterministic: False
33 |
34 | loss_config:
35 | target: net2net.modules.autoencoder.loss.DummyLoss # dummy
36 |
37 | first_stage_config:
38 | target: net2net.models.autoencoder.BigAE
39 | params:
40 | ckpt_path: "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt"
41 |
42 | encoder_config:
43 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
44 | params:
45 | in_channels: 3
46 | in_size: 256
47 | pretrained: false
48 | type: resnet101
49 | z_dim: 128
50 |
51 | decoder_config:
52 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
53 | params:
54 | z_dim: 128
55 | in_size: 256
56 | use_actnorm_in_dec: true
57 |
58 | loss_config:
59 | target: net2net.modules.autoencoder.loss.DummyLoss
60 |
61 | data:
62 | target: translation.DataModuleFromConfig
63 | params:
64 | batch_size: 6
65 | train:
66 | target: net2net.data.faces.CelebFQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: net2net.data.faces.CelebFQValidation
71 | params:
72 | size: 256
73 |
--------------------------------------------------------------------------------
/configs/translation/sbert-to-ae-coco256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "caption"
7 | flow_config:
8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
9 | params:
10 | conditioning_dim: 1024
11 | embedding_dim: 256
12 | conditioning_depth: 2
13 | n_flows: 24
14 | in_channels: 256
15 | hidden_dim: 1024
16 | hidden_depth: 2
17 | activation: "none"
18 | conditioner_use_bn: True
19 |
20 | cond_stage_config:
21 | target: net2net.modules.sbert.model.SentenceEmbedder
22 | params:
23 | version: "bert-large-nli-stsb-mean-tokens"
24 |
25 | first_stage_config:
26 | target: net2net.models.autoencoder.BigAE
27 | params:
28 | ckpt_path: "logs/2020-12-18T22-49-43_coco256/checkpoints/last.ckpt"
29 |
30 | encoder_config:
31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder
32 | params:
33 | in_channels: 3
34 | in_size: 256
35 | pretrained: false
36 | type: resnet101
37 | z_dim: 256
38 |
39 | decoder_config:
40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper
41 | params:
42 | z_dim: 256
43 | in_size: 256
44 | use_actnorm_in_dec: true
45 |
46 | loss_config:
47 | target: net2net.modules.autoencoder.loss.DummyLoss
48 |
49 | data:
50 | target: translation.DataModuleFromConfig
51 | params:
52 | batch_size: 16
53 | train:
54 | target: net2net.data.coco.CocoImagesAndCaptionsTrain
55 | params:
56 | size: 256
57 | validation:
58 | target: net2net.data.coco.CocoImagesAndCaptionsValidation
59 | params:
60 | size: 256
61 |
--------------------------------------------------------------------------------
/configs/translation/sbert-to-bigbigan.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2NetFlow
4 | params:
5 | first_stage_key: "image"
6 | cond_stage_key: "caption"
7 | flow_config:
8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
9 | params:
10 | conditioning_dim: 1024
11 | embedding_dim: 256
12 | conditioning_depth: 2
13 | n_flows: 24
14 | in_channels: 120
15 | hidden_dim: 1024
16 | hidden_depth: 2
17 | activation: "none"
18 | conditioner_use_bn: True
19 |
20 | cond_stage_config:
21 | target: net2net.modules.sbert.model.SentenceEmbedder
22 | params:
23 | version: "bert-large-nli-stsb-mean-tokens"
24 |
25 | first_stage_config:
26 | target: net2net.modules.gan.bigbigan.BigBiGAN
27 |
28 | data:
29 | target: translation.DataModuleFromConfig
30 | params:
31 | batch_size: 16
32 | train:
33 | target: net2net.data.coco.CocoImagesAndCaptionsTrain
34 | params:
35 | size: 256
36 | validation:
37 | target: net2net.data.coco.CocoImagesAndCaptionsValidation
38 | params:
39 | size: 256
40 |
--------------------------------------------------------------------------------
/configs/translation/sbert-to-biggan256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: net2net.models.flows.flow.Net2BigGANFlow
4 | params:
5 | flow_config:
6 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow
7 | params:
8 | conditioning_dim: 768
9 | embedding_dim: 268
10 | conditioning_depth: 2
11 | n_flows: 24
12 | in_channels: 268
13 | hidden_dim: 1024
14 | hidden_depth: 2
15 | activation: "none"
16 | conditioner_use_bn: True
17 |
18 | cond_stage_config:
19 | target: net2net.modules.sbert.model.SentenceEmbedder
20 | params:
21 | version: "bert-base-nli-stsb-mean-tokens"
22 |
23 | gan_config:
24 | target: "net2net.modules.gan.biggan.BigGANWrapper"
25 | params:
26 | image_size: 256
27 |
28 | make_cond_config:
29 | # Takes an image produced by BigGAN and produces a textual description
30 | target: "net2net.modules.captions.model.Img2Text"
31 |
32 | data:
33 | target: translation.DataModuleFromConfig
34 | params:
35 | batch_size: 16
36 | train:
37 | target: net2net.data.zcodes.RestrictedTrainSamples
38 | params:
39 | n_samples: 100000
40 | z_shape:
41 | - 140
42 |
43 | truncation: 2.5
44 | validation:
45 | target: net2net.data.zcodes.RestrictedTestSamples
46 | params:
47 | n_samples: 10000
48 | z_shape:
49 | - 140
50 |
51 | truncation: 2.5
52 |
--------------------------------------------------------------------------------
/data/coco_imagenet_overlap_idx.txt:
--------------------------------------------------------------------------------
1 | 11
2 | 14
3 | 15
4 | 17
5 | 21
6 | 22
7 | 77
8 | 105
9 | 153
10 | 200
11 | 230
12 | 235
13 | 238
14 | 239
15 | 248
16 | 251
17 | 252
18 | 256
19 | 275
20 | 281
21 | 282
22 | 283
23 | 284
24 | 285
25 | 294
26 | 295
27 | 296
28 | 297
29 | 340
30 | 344
31 | 349
32 | 383
33 | 385
34 | 386
35 | 387
36 | 388
37 | 409
38 | 414
39 | 423
40 | 436
41 | 440
42 | 444
43 | 451
44 | 457
45 | 466
46 | 475
47 | 479
48 | 508
49 | 512
50 | 515
51 | 526
52 | 530
53 | 537
54 | 544
55 | 555
56 | 559
57 | 565
58 | 569
59 | 603
60 | 620
61 | 623
62 | 647
63 | 651
64 | 659
65 | 672
66 | 693
67 | 703
68 | 705
69 | 717
70 | 720
71 | 729
72 | 737
73 | 751
74 | 760
75 | 761
76 | 765
77 | 770
78 | 779
79 | 788
80 | 795
81 | 799
82 | 809
83 | 817
84 | 829
85 | 831
86 | 850
87 | 859
88 | 861
89 | 864
90 | 867
91 | 878
92 | 879
93 | 883
94 | 892
95 | 895
96 | 898
97 | 904
98 | 905
99 | 906
100 | 907
101 | 910
102 | 917
103 | 921
104 | 923
105 | 934
106 | 937
107 | 950
108 | 954
109 | 956
110 | 963
111 | 966
112 | 968
113 | 990
114 | 999
115 |
--------------------------------------------------------------------------------
/data/examples/anime/176890.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/176890.jpg
--------------------------------------------------------------------------------
/data/examples/anime/1931960.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/1931960.jpg
--------------------------------------------------------------------------------
/data/examples/anime/2775531.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/2775531.jpg
--------------------------------------------------------------------------------
/data/examples/anime/3007790.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/3007790.jpg
--------------------------------------------------------------------------------
/data/examples/anime/331480.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/331480.jpg
--------------------------------------------------------------------------------
/data/examples/anime/348930.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/348930.jpg
--------------------------------------------------------------------------------
/data/examples/anime/4460500.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/4460500.jpg
--------------------------------------------------------------------------------
/data/examples/anime/499130.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/499130.jpg
--------------------------------------------------------------------------------
/data/examples/anime/519280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/519280.jpg
--------------------------------------------------------------------------------
/data/examples/anime/656160.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/656160.jpg
--------------------------------------------------------------------------------
/data/examples/anime/708030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/708030.jpg
--------------------------------------------------------------------------------
/data/examples/anime/881700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/881700.jpg
--------------------------------------------------------------------------------
/data/examples/anime/903790.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/903790.jpg
--------------------------------------------------------------------------------
/data/examples/anime/910920.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/910920.jpg
--------------------------------------------------------------------------------
/data/examples/anime/949780.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/949780.jpg
--------------------------------------------------------------------------------
/data/examples/humanface/00010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00010.png
--------------------------------------------------------------------------------
/data/examples/humanface/00012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00012.png
--------------------------------------------------------------------------------
/data/examples/humanface/00018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00018.png
--------------------------------------------------------------------------------
/data/examples/humanface/00048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00048.png
--------------------------------------------------------------------------------
/data/examples/humanface/00059.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00059.png
--------------------------------------------------------------------------------
/data/examples/humanface/65710.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65710.png
--------------------------------------------------------------------------------
/data/examples/humanface/65719.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65719.png
--------------------------------------------------------------------------------
/data/examples/humanface/65732.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65732.png
--------------------------------------------------------------------------------
/data/examples/humanface/65831.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65831.png
--------------------------------------------------------------------------------
/data/examples/humanface/65843.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65843.png
--------------------------------------------------------------------------------
/data/examples/humanface/65855.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65855.png
--------------------------------------------------------------------------------
/data/examples/humanface/65946.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65946.png
--------------------------------------------------------------------------------
/data/examples/humanface/65949.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65949.png
--------------------------------------------------------------------------------
/data/examples/humanface/65990.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65990.png
--------------------------------------------------------------------------------
/data/examples/humanface/65998.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65998.png
--------------------------------------------------------------------------------
/data/examples/humanface/65999.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65999.png
--------------------------------------------------------------------------------
/data/examples/humanface/66040.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66040.png
--------------------------------------------------------------------------------
/data/examples/humanface/66048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66048.png
--------------------------------------------------------------------------------
/data/examples/humanface/66056.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66056.png
--------------------------------------------------------------------------------
/data/examples/humanface/66090.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66090.png
--------------------------------------------------------------------------------
/data/examples/humanface/66100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66100.png
--------------------------------------------------------------------------------
/data/examples/humanface/66118.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66118.png
--------------------------------------------------------------------------------
/data/examples/humanface/66157.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66157.png
--------------------------------------------------------------------------------
/data/examples/oilportrait/beethoven.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/beethoven.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/descartes.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/descartes.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/fermat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/fermat.jpg
--------------------------------------------------------------------------------
/data/examples/oilportrait/galileo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/galileo.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/jeanjaques.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/jeanjaques.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/kant.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/kant.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/monalisa.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/monalisa.jpeg
--------------------------------------------------------------------------------
/data/examples/oilportrait/newton.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/newton.jpeg
--------------------------------------------------------------------------------
/env_bigbigan.yaml:
--------------------------------------------------------------------------------
1 | name: net2net_bigbigan
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.7
7 |   - pip=19.3
8 |   - cudatoolkit=10.1
9 |   - cudnn=7.6.5
10 |   - pytorch=1.6
11 |   - torchvision=0.7
12 |   - numpy=1.18
13 |   - pip:
14 |     - albumentations==0.4.3
15 |     - opencv-python==4.1.2.30
16 |     - pudb==2019.2
17 |     - imageio==2.9.0
18 |     - imageio-ffmpeg==0.4.2
19 |     - plotly==4.9.0
20 |     - pytorch-lightning==0.9.0
21 |     - omegaconf==2.0.0
22 |     - streamlit==0.71.0
23 |     - test-tube>=0.7.5
24 |     - sentence-transformers>=0.3.8
25 |     - tensorflow==2.3.1
26 |     - tensorflow-hub==0.10.0
27 |     - -e .
28 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: net2net
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.7
7 |   - pip=19.3
8 |   - cudatoolkit=10.2
9 |   - pytorch=1.6
10 |   - torchvision=0.7
11 |   - numpy=1.18
12 |   - pip:
13 |     - albumentations==0.4.3
14 |     - opencv-python==4.1.2.30
15 |     - pudb==2019.2
16 |     - imageio==2.9.0
17 |     - imageio-ffmpeg==0.4.2
18 |     - plotly==4.9.0
19 |     - pytorch-lightning==0.9.0
20 |     - omegaconf==2.0.0
21 |     - streamlit==0.71.0
22 |     - test-tube>=0.7.5
23 |     - sentence-transformers>=0.3.8
24 |     - -e .
25 |
--------------------------------------------------------------------------------
/net2net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/__init__.py
--------------------------------------------------------------------------------
/net2net/ckpt_util.py:
--------------------------------------------------------------------------------
1 | import os, hashlib
2 | import requests
3 | from tqdm import tqdm
4 |
5 | URL_MAP = {
6 | "biggan_128": "https://heibox.uni-heidelberg.de/f/56ed256209fd40968864/?dl=1",
7 | "biggan_256": "https://heibox.uni-heidelberg.de/f/437b501944874bcc92a4/?dl=1",
8 | "dequant_vae": "https://heibox.uni-heidelberg.de/f/e7c8959b50a64f40826e/?dl=1",
9 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
10 | "coco_captioner": "https://heibox.uni-heidelberg.de/f/b03aae864a0f42f1a2c3/?dl=1",
11 | "coco_word_map": "https://heibox.uni-heidelberg.de/f/1518aa8461d94e0cb3eb/?dl=1"
12 | }
13 |
14 | CKPT_MAP = {
15 | "biggan_128": "biggan-128.pth",
16 | "biggan_256": "biggan-256.pth",
17 | "dequant_vae": "dequantvae-20000.ckpt",
18 | "vgg_lpips": "autoencoders/lpips/vgg.pth",
19 | "coco_captioner": "captioning_model_pt16.ckpt",
20 | }
21 |
22 | MD5_MAP = {
23 | "biggan_128": "a2148cf64807444113fac5eede060d28",
24 | "biggan_256": "e23db3caa34ac4c4ae922a75258dcb8d",
25 | "dequant_vae": "5c2a6fe765142cbdd9f10f15d65a68b6",
26 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a",
27 | "coco_captioner": "db185e0f6791e60d27c00de0f40c376c",
28 | }
29 |
30 |
31 | def download(url, local_path, chunk_size=1024):
32 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
33 | with requests.get(url, stream=True) as r:
34 | total_size = int(r.headers.get("content-length", 0))
35 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
36 | with open(local_path, "wb") as f:
37 | for data in r.iter_content(chunk_size=chunk_size):
38 | if data:
39 | f.write(data)
40 |                         pbar.update(len(data))  # count the bytes actually written (last chunk may be shorter)
41 |
42 |
43 | def md5_hash(path):
44 | with open(path, "rb") as f:
45 | content = f.read()
46 | return hashlib.md5(content).hexdigest()
47 |
48 |
49 | def get_ckpt_path(name, root, check=False):
50 | assert name in URL_MAP
51 | path = os.path.join(root, CKPT_MAP[name])
52 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
53 | print("Downloading {} from {} to {}".format(name, URL_MAP[name], path))
54 | download(URL_MAP[name], path)
55 | md5 = md5_hash(path)
56 | assert md5 == MD5_MAP[name], md5
57 | return path
58 |
--------------------------------------------------------------------------------
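A minimal usage sketch for this helper (illustrative, not part of the checked-in sources); the call mirrors the one made in `Net2BigGANFlow.init_preprocessing` further below, and the target directory is only an example:

    from net2net.ckpt_util import get_ckpt_path

    # downloads the checkpoint on first use, verifies its md5, and returns the local path
    ckpt = get_ckpt_path("dequant_vae", "net2net/modules/autoencoder/dequant_vae", check=True)
    print(ckpt)  # net2net/modules/autoencoder/dequant_vae/dequantvae-20000.ckpt
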
/net2net/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/data/__init__.py
--------------------------------------------------------------------------------
/net2net/data/base.py:
--------------------------------------------------------------------------------
1 | import os, bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict()
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
--------------------------------------------------------------------------------
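A minimal sketch (illustrative only) of wrapping `ImagePaths` in a `DataLoader`; the two example files ship with the repository under `data/examples/humanface`:

    from torch.utils.data import DataLoader
    from net2net.data.base import ImagePaths

    paths = ["data/examples/humanface/00010.png",
             "data/examples/humanface/00012.png"]
    dataset = ImagePaths(paths, size=256, random_crop=False)  # rescale, then center-crop to 256x256
    loader = DataLoader(dataset, batch_size=2)
    batch = next(iter(loader))
    print(batch["image"].shape)  # torch.Size([2, 256, 256, 3]), values in [-1, 1]
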
/net2net/data/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import albumentations
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class CocoBase(Dataset):
11 | """needed for (image, caption, segmentation) pairs"""
12 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
13 | crop_size=None, force_no_crop=False):
14 | self.split = self.get_split()
15 | self.size = size
16 | if crop_size is None:
17 | self.crop_size = size
18 | else:
19 | self.crop_size = crop_size
20 |
21 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot
22 | self.stuffthing = use_stuffthing # include thing in segmentation
23 | if self.onehot and not self.stuffthing:
24 |             raise NotImplementedError("One hot mode is only supported for the "
25 |                                       "stuffthings version because labels are stored "
26 |                                       "a bit differently.")
27 |
28 | data_json = datajson
29 | with open(data_json) as json_file:
30 | self.json_data = json.load(json_file)
31 | self.img_id_to_captions = dict()
32 | self.img_id_to_filepath = dict()
33 | self.img_id_to_segmentation_filepath = dict()
34 |
35 | assert data_json.split("/")[-1] in ["captions_train2017.json",
36 | "captions_val2017.json"]
37 |
38 | if self.stuffthing:
39 | self.segmentation_prefix = (
40 | "data/cocostuffthings/val2017" if
41 | data_json.endswith("captions_val2017.json") else
42 | "data/cocostuffthings/train2017")
43 | else:
44 | self.segmentation_prefix = (
45 | "data/coco/annotations/stuff_val2017_pixelmaps" if
46 | data_json.endswith("captions_val2017.json") else
47 | "data/coco/annotations/stuff_train2017_pixelmaps")
48 |
49 | imagedirs = self.json_data["images"]
50 | self.labels = {"image_ids": list()}
51 | for imgdir in tqdm(imagedirs, desc="ImgToPath"):
52 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
53 | self.img_id_to_captions[imgdir["id"]] = list()
54 | pngfilename = imgdir["file_name"].replace("jpg", "png")
55 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
56 | self.segmentation_prefix, pngfilename)
57 | self.labels["image_ids"].append(imgdir["id"])
58 |
59 | capdirs = self.json_data["annotations"]
60 | for capdir in tqdm(capdirs, desc="ImgToCaptions"):
61 |             # there are on average 5 captions per image
62 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
63 |
64 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
65 | if self.split=="validation":
66 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
67 | else:
68 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
69 | self.preprocessor = albumentations.Compose(
70 | [self.rescaler, self.cropper],
71 | additional_targets={"segmentation": "image"})
72 | if force_no_crop:
73 | self.rescaler = albumentations.Resize(height=self.size, width=self.size)
74 | self.preprocessor = albumentations.Compose(
75 | [self.rescaler],
76 | additional_targets={"segmentation": "image"})
77 |
78 | def __len__(self):
79 | return len(self.labels["image_ids"])
80 |
81 | def preprocess_image(self, image_path, segmentation_path):
82 | image = Image.open(image_path)
83 | if not image.mode == "RGB":
84 | image = image.convert("RGB")
85 | image = np.array(image).astype(np.uint8)
86 |
87 | segmentation = Image.open(segmentation_path)
88 | if not self.onehot and not segmentation.mode == "RGB":
89 | segmentation = segmentation.convert("RGB")
90 | segmentation = np.array(segmentation).astype(np.uint8)
91 | if self.onehot:
92 | assert self.stuffthing
93 |             # stored in caffe format: unlabeled==255, stuff and thing classes
94 |             # from 0-181. To be compatible with the labels in
95 |             # https://github.com/nightrome/cocostuff/blob/master/labels.txt
96 |             # we shift the stuffthing ids one to the right and map unlabeled to
97 |             # zero; since segmentation is uint8, 255 + 1 overflows to 0, which
98 |             # handles the unlabeled case automatically.
99 | assert segmentation.dtype == np.uint8
100 | segmentation = segmentation + 1
101 |
102 | processed = self.preprocessor(image=image, segmentation=segmentation)
103 | image, segmentation = processed["image"], processed["segmentation"]
104 | image = (image / 127.5 - 1.0).astype(np.float32)
105 |
106 | if self.onehot:
107 | assert segmentation.dtype == np.uint8
108 | # make it one hot
109 | n_labels = 183
110 | flatseg = np.ravel(segmentation)
111 |             onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
112 | onehot[np.arange(flatseg.size), flatseg] = True
113 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
114 | segmentation = onehot
115 | else:
116 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
117 | return image, segmentation
118 |
119 | def __getitem__(self, i):
120 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
121 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
122 | image, segmentation = self.preprocess_image(img_path, seg_path)
123 | captions = self.img_id_to_captions[self.labels["image_ids"][i]]
124 | # randomly draw one of all available captions per image
125 | caption = captions[np.random.randint(0, len(captions))]
126 | example = {"image": image,
127 | "caption": [str(caption[0])],
128 | "segmentation": segmentation,
129 | "img_path": img_path,
130 | "seg_path": seg_path
131 | }
132 | return example
133 |
134 |
135 | class CocoImagesAndCaptionsTrain(CocoBase):
136 | """returns a pair of (image, caption)"""
137 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
138 | super().__init__(size=size,
139 | dataroot="data/coco/train2017",
140 | datajson="data/coco/annotations/captions_train2017.json",
141 | onehot_segmentation=onehot_segmentation,
142 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
143 |
144 | def get_split(self):
145 | return "train"
146 |
147 |
148 | class CocoImagesAndCaptionsValidation(CocoBase):
149 | """returns a pair of (image, caption)"""
150 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
151 | super().__init__(size=size,
152 | dataroot="data/coco/val2017",
153 | datajson="data/coco/annotations/captions_val2017.json",
154 | onehot_segmentation=onehot_segmentation,
155 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
156 |
157 | def get_split(self):
158 | return "validation"
159 |
160 |
--------------------------------------------------------------------------------
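A minimal sketch (illustrative only), assuming the COCO 2017 images, caption annotations and stuff segmentation maps sit under `data/coco/` exactly as hard-coded above:

    from net2net.data.coco import CocoImagesAndCaptionsValidation

    dataset = CocoImagesAndCaptionsValidation(size=256)
    example = dataset[0]
    print(example["image"].shape)  # (256, 256, 3), float32 in [-1, 1]
    print(example["caption"])      # one randomly drawn caption, wrapped in a list
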
/net2net/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import urllib.request
4 | import tarfile, zipfile
5 | from pathlib import Path
6 | from tqdm import tqdm
7 |
8 |
9 | def unpack(path):
10 | if path.endswith("tar.gz"):
11 | with tarfile.open(path, "r:gz") as tar:
12 | tar.extractall(path=os.path.split(path)[0])
13 | elif path.endswith("tar"):
14 | with tarfile.open(path, "r:") as tar:
15 | tar.extractall(path=os.path.split(path)[0])
16 | elif path.endswith("zip"):
17 | with zipfile.ZipFile(path, "r") as f:
18 | f.extractall(path=os.path.split(path)[0])
19 | else:
20 | raise NotImplementedError(
21 | "Unknown file extension: {}".format(os.path.splitext(path)[1])
22 | )
23 |
24 |
25 | def reporthook(bar):
26 | """tqdm progress bar for downloads."""
27 |
28 | def hook(b=1, bsize=1, tsize=None):
29 | if tsize is not None:
30 | bar.total = tsize
31 | bar.update(b * bsize - bar.n)
32 |
33 | return hook
34 |
35 |
36 | def get_root(name):
37 | base = "data/"
38 | root = os.path.join(base, name)
39 | os.makedirs(root, exist_ok=True)
40 | return root
41 |
42 |
43 | def is_prepared(root):
44 | return Path(root).joinpath(".ready").exists()
45 |
46 |
47 | def mark_prepared(root):
48 | Path(root).joinpath(".ready").touch()
49 |
50 |
51 | def prompt_download(file_, source, target_dir, content_dir=None):
52 | targetpath = os.path.join(target_dir, file_)
53 | while not os.path.exists(targetpath):
54 | if content_dir is not None and os.path.exists(
55 | os.path.join(target_dir, content_dir)
56 | ):
57 | break
58 | print(
59 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
60 | )
61 | if content_dir is not None:
62 | print(
63 | "Or place its content into '{}'.".format(
64 | os.path.join(target_dir, content_dir)
65 | )
66 | )
67 | input("Press Enter when done...")
68 | return targetpath
69 |
70 |
71 | def download_url(file_, url, target_dir):
72 | targetpath = os.path.join(target_dir, file_)
73 | os.makedirs(target_dir, exist_ok=True)
74 | with tqdm(
75 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
76 | ) as bar:
77 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
78 | return targetpath
79 |
80 |
81 | def download_urls(urls, target_dir):
82 | paths = dict()
83 | for fname, url in urls.items():
84 | outpath = download_url(fname, url, target_dir)
85 | paths[fname] = outpath
86 | return paths
87 |
88 |
89 | def quadratic_crop(x, bbox, alpha=1.0):
90 | """bbox is xmin, ymin, xmax, ymax"""
91 | im_h, im_w = x.shape[:2]
92 | bbox = np.array(bbox, dtype=np.float32)
93 | bbox = np.clip(bbox, 0, max(im_h, im_w))
94 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
95 | w = bbox[2] - bbox[0]
96 | h = bbox[3] - bbox[1]
97 | l = int(alpha * max(w, h))
98 | l = max(l, 2)
99 |
100 | required_padding = -1 * min(
101 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
102 | )
103 | required_padding = int(np.ceil(required_padding))
104 | if required_padding > 0:
105 | padding = [
106 | [required_padding, required_padding],
107 | [required_padding, required_padding],
108 | ]
109 | padding += [[0, 0]] * (len(x.shape) - 2)
110 | x = np.pad(x, padding, "reflect")
111 | center = center[0] + required_padding, center[1] + required_padding
112 | xmin = int(center[0] - l / 2)
113 | ymin = int(center[1] - l / 2)
114 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
115 |
--------------------------------------------------------------------------------
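A small sketch of `quadratic_crop` (illustrative only): it cuts a square patch of side `int(alpha * max(w, h))` around a bounding box and reflect-pads the image if the square would leave it; the numbers below are made up:

    import numpy as np
    from net2net.data.utils import quadratic_crop

    img = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HxWxC image
    bbox = (100, 120, 220, 260)                    # xmin, ymin, xmax, ymax
    crop = quadratic_crop(img, bbox, alpha=1.2)
    print(crop.shape)                              # (168, 168, 3), since int(1.2 * max(120, 140)) == 168
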
/net2net/data/zcodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class PRNGMixin(object):
8 | """Adds a prng property which is a numpy RandomState which gets
9 | reinitialized whenever the pid changes to avoid synchronized sampling
10 | behavior when used in conjunction with multiprocessing."""
11 |
12 | @property
13 | def prng(self):
14 | currentpid = os.getpid()
15 | if getattr(self, "_initpid", None) != currentpid:
16 | self._initpid = currentpid
17 | self._prng = np.random.RandomState()
18 | return self._prng
19 |
20 |
21 | class TrainSamples(Dataset, PRNGMixin):
22 | def __init__(self, n_samples, z_shape, n_classes, truncation=0):
23 | self.n_samples = n_samples
24 | self.z_shape = z_shape
25 | self.n_classes = n_classes
26 | self.truncation_threshold = truncation
27 | if self.truncation_threshold > 0:
28 | print("Applying truncation at level {}".format(self.truncation_threshold))
29 |
30 | def __len__(self):
31 | return self.n_samples
32 |
33 | def __getitem__(self, i):
34 | z = self.prng.randn(*self.z_shape)
35 | if self.truncation_threshold > 0:
36 | for k, zi in enumerate(z):
37 | while abs(zi) > self.truncation_threshold:
38 | zi = self.prng.randn(1)
39 | z[k] = zi
40 | cls = self.prng.randint(self.n_classes)
41 | return {"z": z.astype(np.float32), "class": cls}
42 |
43 |
44 | class TestSamples(Dataset):
45 | def __init__(self, n_samples, z_shape, n_classes, truncation=0):
46 | self.prng = np.random.RandomState(1)
47 | self.n_samples = n_samples
48 | self.z_shape = z_shape
49 | self.n_classes = n_classes
50 | self.truncation_threshold = truncation
51 | if self.truncation_threshold > 0:
52 | print("Applying truncation at level {}".format(self.truncation_threshold))
53 | self.zs = self.prng.randn(self.n_samples, *self.z_shape)
54 | if self.truncation_threshold > 0:
55 | print("Applying truncation at level {}".format(self.truncation_threshold))
56 | ix = 0
57 | for z in tqdm(self.zs, desc="Truncation:"):
58 | for k, zi in enumerate(z):
59 | while abs(zi) > self.truncation_threshold:
60 | zi = self.prng.randn(1)
61 | z[k] = zi
62 | self.zs[ix] = z
63 | ix += 1
64 | print("Created truncated test data.")
65 | self.clss = self.prng.randint(self.n_classes, size=(self.n_samples,))
66 |
67 | def __len__(self):
68 | return self.n_samples
69 |
70 | def __getitem__(self, i):
71 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]}
72 |
73 |
74 | class RestrictedTrainSamples(Dataset, PRNGMixin):
75 | def __init__(self, n_samples, z_shape, truncation=0):
76 | index_path = "data/coco_imagenet_overlap_idx.txt"
77 | self.n_samples = n_samples
78 | self.z_shape = z_shape
79 | self.classes = np.loadtxt(index_path).astype(int)
80 | self.truncation_threshold = truncation
81 | if self.truncation_threshold > 0:
82 | print("Applying truncation at level {}".format(self.truncation_threshold))
83 |
84 | def __len__(self):
85 | return self.n_samples
86 |
87 | def __getitem__(self, i):
88 | z = self.prng.randn(*self.z_shape)
89 | if self.truncation_threshold > 0:
90 | for k, zi in enumerate(z):
91 | while abs(zi) > self.truncation_threshold:
92 | zi = self.prng.randn(1)
93 | z[k] = zi
94 | cls = self.prng.choice(self.classes)
95 | return {"z": z.astype(np.float32), "class": cls}
96 |
97 |
98 | class RestrictedTestSamples(Dataset):
99 | def __init__(self, n_samples, z_shape, truncation=0):
100 | index_path = "data/coco_imagenet_overlap_idx.txt"
101 |
102 | self.prng = np.random.RandomState(1)
103 | self.n_samples = n_samples
104 | self.z_shape = z_shape
105 |
106 | self.classes = np.loadtxt(index_path).astype(int)
107 | self.clss = self.prng.choice(self.classes, size=(self.n_samples,), replace=True)
108 | self.truncation_threshold = truncation
109 | self.zs = self.prng.randn(self.n_samples, *self.z_shape)
110 | if self.truncation_threshold > 0:
111 | print("Applying truncation at level {}".format(self.truncation_threshold))
112 | ix = 0
113 | for z in tqdm(self.zs, desc="Truncation:"):
114 | for k, zi in enumerate(z):
115 | while abs(zi) > self.truncation_threshold:
116 | zi = self.prng.randn(1)
117 | z[k] = zi
118 | self.zs[ix] = z
119 | ix += 1
120 | print("Created truncated test data.")
121 |
122 | def __len__(self):
123 | return self.n_samples
124 |
125 | def __getitem__(self, i):
126 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]}
127 |
128 |
129 |
--------------------------------------------------------------------------------
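A minimal sketch (illustrative only) of drawing truncated (z, class) pairs with `TrainSamples`; the latent shape and class count are placeholders, not values prescribed by the repository:

    from torch.utils.data import DataLoader
    from net2net.data.zcodes import TrainSamples

    dataset = TrainSamples(n_samples=1000, z_shape=(128,), n_classes=1000, truncation=2.0)
    loader = DataLoader(dataset, batch_size=8)
    batch = next(iter(loader))
    print(batch["z"].shape, batch["class"].shape)  # torch.Size([8, 128]) torch.Size([8])
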
/net2net/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/models/__init__.py
--------------------------------------------------------------------------------
/net2net/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 |
4 | from net2net.modules.distributions.distributions import DiagonalGaussianDistribution
5 | from translation import instantiate_from_config
6 |
7 |
8 | class BigAE(pl.LightningModule):
9 | def __init__(self,
10 | encoder_config,
11 | decoder_config,
12 | loss_config,
13 | ckpt_path=None,
14 | ignore_keys=[]
15 | ):
16 | super().__init__()
17 | self.encoder = instantiate_from_config(encoder_config)
18 | self.decoder = instantiate_from_config(decoder_config)
19 | self.loss = instantiate_from_config(loss_config)
20 |
21 | if ckpt_path is not None:
22 | print("Loading model from {}".format(ckpt_path))
23 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
24 |
25 | def init_from_ckpt(self, path, ignore_keys=list()):
26 | try:
27 | sd = torch.load(path, map_location="cpu")["state_dict"]
28 | except KeyError:
29 | sd = torch.load(path, map_location="cpu")
30 |
31 | keys = list(sd.keys())
32 | for k in keys:
33 | for ik in ignore_keys:
34 | if k.startswith(ik):
35 | print("Deleting key {} from state_dict.".format(k))
36 | del sd[k]
37 | missing, unexpected = self.load_state_dict(sd, strict=False)
38 | if len(missing) > 0:
39 | print(f"Missing keys in state dict: {missing}")
40 | if len(unexpected) > 0:
41 | print(f"Unexpected keys in state dict: {unexpected}")
42 |
43 | def encode(self, x, return_mode=False):
44 | moments = self.encoder(x)
45 | posterior = DiagonalGaussianDistribution(moments, deterministic=False)
46 | if return_mode:
47 | return posterior.mode()
48 | return posterior.sample()
49 |
50 | def decode(self, z):
51 | if len(z.shape) == 4:
52 | z = z.squeeze(-1).squeeze(-1)
53 | return self.decoder(z)
54 |
55 | def forward(self, x):
56 | moments = self.encoder(x)
57 | posterior = DiagonalGaussianDistribution(moments)
58 | h = posterior.sample()
59 | reconstructions = self.decoder(h.squeeze(-1).squeeze(-1))
60 | return reconstructions, posterior
61 |
62 | def get_last_layer(self):
63 | return getattr(self.decoder.decoder.colorize.module, 'weight_bar')
64 |
65 | def log_images(self, batch, split=""):
66 | log = dict()
67 | inputs = batch["image"].permute(0, 3, 1, 2)
68 | inputs = inputs.to(self.device)
69 | reconstructions, posterior = self(inputs)
70 | log["inputs"] = inputs
71 | log["reconstructions"] = reconstructions
72 | return log
73 |
74 | def training_step(self, batch, batch_idx, optimizer_idx):
75 | inputs = batch["image"].permute(0, 3, 1, 2)
76 | reconstructions, posterior = self(inputs)
77 |
78 | if optimizer_idx == 0:
79 | # train encoder+decoder+logvar
80 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
81 | last_layer=self.get_last_layer(), split="train")
82 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss)
83 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
84 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
85 | return output
86 |
87 | if optimizer_idx == 1:
88 | # train the discriminator
89 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
90 | last_layer=self.get_last_layer(), split="train")
91 | output = pl.TrainResult(minimize=discloss)
92 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
93 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
94 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO
95 | return output
96 |
97 | def validation_step(self, batch, batch_idx):
98 | inputs = batch["image"].permute(0, 3, 1, 2)
99 | reconstructions, posterior = self(inputs)
100 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
101 | last_layer=self.get_last_layer(), split="val")
102 |
103 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
104 | last_layer=self.get_last_layer(), split="val")
105 | output = pl.EvalResult(checkpoint_on=aeloss)
106 | output.log_dict(log_dict_ae)
107 | output.log_dict(log_dict_disc)
108 | return output
109 |
110 | def configure_optimizers(self):
111 | lr = self.learning_rate
112 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+list(self.decoder.parameters())+[self.loss.logvar],
113 | lr=lr, betas=(0.5, 0.9))
114 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
115 | return [opt_ae, opt_disc], []
116 |
117 | def on_epoch_end(self):
118 | pass
119 |
120 |
121 | class BasicAE(pl.LightningModule):
122 | def __init__(self, ae_config, loss_config, ckpt_path=None, ignore_keys=[]):
123 | super().__init__()
124 | self.autoencoder = instantiate_from_config(ae_config)
125 | self.loss = instantiate_from_config(loss_config)
126 | if ckpt_path is not None:
127 | print("Loading model from {}".format(ckpt_path))
128 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
129 |
130 | def init_from_ckpt(self, path, ignore_keys=list()):
131 | try:
132 | sd = torch.load(path, map_location="cpu")["state_dict"]
133 | except KeyError:
134 | sd = torch.load(path, map_location="cpu")
135 |
136 | keys = list(sd.keys())
137 | for k in keys:
138 | for ik in ignore_keys:
139 | if k.startswith(ik):
140 | print("Deleting key {} from state_dict.".format(k))
141 | del sd[k]
142 | self.load_state_dict(sd, strict=False)
143 |
144 | def forward(self, x):
145 | posterior = self.autoencoder.encode(x)
146 | h = posterior.sample()
147 | reconstructions = self.autoencoder.decode(h)
148 | return reconstructions, posterior
149 |
150 | def encode(self, x):
151 | posterior = self.autoencoder.encode(x)
152 | h = posterior.sample()
153 | return h
154 |
155 | def get_last_layer(self):
156 | return self.autoencoder.get_last_layer()
157 |
158 | def log_images(self, batch, split=""):
159 | log = dict()
160 | inputs = batch["image"].permute(0, 3, 1, 2)
161 | inputs = inputs.to(self.device)
162 | reconstructions, posterior = self(inputs)
163 | log["inputs"] = inputs
164 | log["reconstructions"] = reconstructions
165 | return log
166 |
167 | def training_step(self, batch, batch_idx, optimizer_idx):
168 | inputs = batch["image"].permute(0, 3, 1, 2)
169 | reconstructions, posterior = self(inputs)
170 |
171 | if optimizer_idx == 0:
172 | # train encoder+decoder+logvar
173 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
174 | last_layer=self.get_last_layer(), split="train")
175 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss)
176 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
177 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
178 | return output
179 |
180 | if optimizer_idx == 1:
181 | # train the discriminator
182 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
183 | last_layer=self.get_last_layer(), split="train")
184 | output = pl.TrainResult(minimize=discloss)
185 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
186 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
187 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO
188 | return output
189 |
190 | def validation_step(self, batch, batch_idx):
191 | inputs = batch["image"].permute(0, 3, 1, 2)
192 | reconstructions, posterior = self(inputs)
193 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
194 | last_layer=self.get_last_layer(), split="val")
195 |
196 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
197 | last_layer=self.get_last_layer(), split="val")
198 | output = pl.EvalResult(checkpoint_on=aeloss)
199 | output.log_dict(log_dict_ae)
200 | output.log_dict(log_dict_disc)
201 | return output
202 |
203 | def configure_optimizers(self):
204 | lr = self.learning_rate
205 | opt_ae = torch.optim.Adam(list(self.autoencoder.parameters())+[self.loss.logvar],
206 | lr=lr, betas=(0.5, 0.9))
207 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
208 | return [opt_ae, opt_disc], []
209 |
--------------------------------------------------------------------------------
/net2net/models/flows/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/models/flows/__init__.py
--------------------------------------------------------------------------------
/net2net/models/flows/flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pytorch_lightning as pl
4 |
5 | from translation import instantiate_from_config
6 | from net2net.modules.flow.loss import NLL
7 | from net2net.ckpt_util import get_ckpt_path
8 | from net2net.modules.util import log_txt_as_img
9 |
10 |
11 | def disabled_train(self, mode=True):
12 | """Overwrite model.train with this function to make sure train/eval mode
13 | does not change anymore."""
14 | return self
15 |
16 |
17 | class Flow(pl.LightningModule):
18 | def __init__(self, flow_config):
19 | super().__init__()
20 | self.flow = instantiate_from_config(config=flow_config)
21 | self.loss = NLL()
22 |
23 | def forward(self, x):
24 | zz, logdet = self.flow(x)
25 | return zz, logdet
26 |
27 | def sample_like(self, query):
28 | z = self.flow.sample(query.shape[0], device=query.device).float()
29 | return z
30 |
31 | def shared_step(self, batch, batch_idx):
32 | x, labels = batch
33 | x = x.float()
34 | zz, logdet = self(x)
35 | loss, log_dict = self.loss(zz, logdet)
36 | return loss, log_dict
37 |
38 | def training_step(self, batch, batch_idx):
39 | loss, log_dict = self.shared_step(batch, batch_idx)
40 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss)
41 | output.log_dict(log_dict, prog_bar=False, on_epoch=True)
42 | return output
43 |
44 | def validation_step(self, batch, batch_idx):
45 | loss, log_dict = self.shared_step(batch, batch_idx)
46 | output = pl.EvalResult(checkpoint_on=loss)
47 | output.log_dict(log_dict, prog_bar=False)
48 |
49 | x, _ = batch
50 | x = x.float()
51 | sample = self.sample_like(x)
52 | output.sample_like = sample
53 | output.input = x.clone()
54 |
55 | return output
56 |
57 | def configure_optimizers(self):
58 | opt = torch.optim.Adam((self.flow.parameters()),lr=self.learning_rate, betas=(0.5, 0.9))
59 | return [opt], []
60 |
61 |
62 | class Net2NetFlow(pl.LightningModule):
63 | def __init__(self,
64 | flow_config,
65 | first_stage_config,
66 | cond_stage_config,
67 | ckpt_path=None,
68 | ignore_keys=[],
69 | first_stage_key="image",
70 | cond_stage_key="image",
71 | interpolate_cond_size=-1
72 | ):
73 | super().__init__()
74 | self.init_first_stage_from_ckpt(first_stage_config)
75 | self.init_cond_stage_from_ckpt(cond_stage_config)
76 | self.flow = instantiate_from_config(config=flow_config)
77 | self.loss = NLL()
78 | self.first_stage_key = first_stage_key
79 | self.cond_stage_key = cond_stage_key
80 | self.interpolate_cond_size = interpolate_cond_size
81 | if ckpt_path is not None:
82 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
83 |
84 | def init_from_ckpt(self, path, ignore_keys=list()):
85 | sd = torch.load(path, map_location="cpu")["state_dict"]
86 |         for k in list(sd.keys()):
87 | for ik in ignore_keys:
88 | if k.startswith(ik):
89 | self.print("Deleting key {} from state_dict.".format(k))
90 | del sd[k]
91 | self.load_state_dict(sd, strict=False)
92 | print(f"Restored from {path}")
93 |
94 | def init_first_stage_from_ckpt(self, config):
95 | model = instantiate_from_config(config)
96 | model = model.eval()
97 | model.train = disabled_train
98 | self.first_stage_model = model
99 |
100 | def init_cond_stage_from_ckpt(self, config):
101 | model = instantiate_from_config(config)
102 | model = model.eval()
103 | model.train = disabled_train
104 | self.cond_stage_model = model
105 |
106 | def forward(self, x, c):
107 | c = self.encode_to_c(c)
108 | q = self.encode_to_z(x)
109 | zz, logdet = self.flow(q, c)
110 | return zz, logdet
111 |
112 | @torch.no_grad()
113 | def sample_conditional(self, c):
114 | z = self.flow.sample(c)
115 | return z
116 |
117 | @torch.no_grad()
118 | def encode_to_z(self, x):
119 | z = self.first_stage_model.encode(x).detach()
120 | return z
121 |
122 | @torch.no_grad()
123 | def encode_to_c(self, c):
124 | c = self.cond_stage_model.encode(c).detach()
125 | return c
126 |
127 | @torch.no_grad()
128 | def decode_to_img(self, z):
129 | x = self.first_stage_model.decode(z.detach())
130 | return x
131 |
132 | @torch.no_grad()
133 | def log_images(self, batch, split=""):
134 | log = dict()
135 | x = self.get_input(self.first_stage_key, batch).to(self.device)
136 | xc = self.get_input(self.cond_stage_key, batch, is_conditioning=True)
137 | if self.cond_stage_key not in ["text", "caption"]:
138 | xc = xc.to(self.device)
139 |
140 | z = self.encode_to_z(x)
141 | c = self.encode_to_c(xc)
142 |
143 | zz, _ = self.flow(z, c)
144 | zrec = self.flow.reverse(zz, c)
145 | xrec = self.decode_to_img(zrec)
146 | z_sample = self.sample_conditional(c)
147 | xsample = self.decode_to_img(z_sample)
148 |
149 | cshift = torch.cat((c[1:],c[:1]),dim=0)
150 | zshift = self.flow.reverse(zz, cshift)
151 | xshift = self.decode_to_img(zshift)
152 |
153 | log["inputs"] = x
154 | if self.cond_stage_key not in ["text", "caption", "class"]:
155 | log["conditioning"] = xc
156 | else:
157 | _,_,h,w = x.shape
158 | log["conditioning"] = log_txt_as_img((w,h), xc)
159 |
160 | log["reconstructions"] = xrec
161 | log["shift"] = xshift
162 | log["samples"] = xsample
163 | return log
164 |
165 | def get_input(self, key, batch, is_conditioning=False):
166 | x = batch[key]
167 | if key in ["caption", "text"]:
168 | x = list(x[0])
169 | elif key in ["class"]:
170 | pass
171 | else:
172 | if len(x.shape) == 3:
173 | x = x[..., None]
174 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
175 | if is_conditioning:
176 | if self.interpolate_cond_size > -1:
177 | x = F.interpolate(x, size=(self.interpolate_cond_size, self.interpolate_cond_size))
178 | return x
179 |
180 | def shared_step(self, batch, batch_idx, split="train"):
181 | x = self.get_input(self.first_stage_key, batch)
182 | c = self.get_input(self.cond_stage_key, batch, is_conditioning=True)
183 | zz, logdet = self(x, c)
184 | loss, log_dict = self.loss(zz, logdet, split=split)
185 | return loss, log_dict
186 |
187 | def training_step(self, batch, batch_idx):
188 | loss, log_dict = self.shared_step(batch, batch_idx, split="train")
189 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss)
190 | output.log_dict(log_dict, prog_bar=False, on_epoch=True, logger=True, on_step=True)
191 | return output
192 |
193 | def validation_step(self, batch, batch_idx):
194 | loss, log_dict = self.shared_step(batch, batch_idx, split="val")
195 | output = pl.EvalResult(checkpoint_on=loss)
196 | output.log_dict(log_dict, prog_bar=False, logger=True)
197 | return output
198 |
199 | def configure_optimizers(self):
200 | opt = torch.optim.Adam((self.flow.parameters()),
201 | lr=self.learning_rate,
202 | betas=(0.5, 0.9),
203 | amsgrad=True)
204 | return [opt], []
205 |
206 |
207 | class Net2BigGANFlow(Net2NetFlow):
208 | def __init__(self,
209 | flow_config,
210 | gan_config,
211 | cond_stage_config,
212 | make_cond_config,
213 | ckpt_path=None,
214 | ignore_keys=[],
215 | cond_stage_key="caption"
216 | ):
217 | super().__init__(flow_config=flow_config,
218 | first_stage_config=gan_config, cond_stage_config=cond_stage_config,
219 | ckpt_path=ckpt_path, ignore_keys=ignore_keys, cond_stage_key=cond_stage_key
220 | )
221 |
222 | self.init_to_c_model(make_cond_config)
223 | self.init_preprocessing()
224 |
225 | @torch.no_grad()
226 | def get_input(self, batch, move_to_device=False):
227 | zin = batch["z"]
228 | cin = batch["class"]
229 | if move_to_device:
230 | zin, cin = zin.to(self.device), cin.to(self.device)
231 | # dequantize the discrete class code
232 | cin = self.first_stage_model.embed_labels(cin, labels_are_one_hot=False)
233 | split_sizes = [zin.shape[1], cin.shape[1]]
234 | xin = self.first_stage_model.generate_from_embedding(zin, cin)
235 | cin = self.dequantizer(cin)
236 | xc = self.to_c_model(xin)
237 | zflow = torch.cat([zin, cin.detach()], dim=1)[:, :, None, None] # this will be flowed
238 | return {"zcode": zflow,
239 | "xgen": xin,
240 | "xcon": xc,
241 | "split_sizes": split_sizes
242 | }
243 |
244 | def init_to_c_model(self, config):
245 | model = instantiate_from_config(config)
246 | model = model.eval()
247 | model.train = disabled_train
248 | self.to_c_model = model
249 |
250 | def init_preprocessing(self):
251 | dqcfg = {"target": "net2net.modules.autoencoder.basic.BasicFullyConnectedVAE"}
252 | self.dequantizer = instantiate_from_config(dqcfg)
253 | ckpt = get_ckpt_path("dequant_vae", "net2net/modules/autoencoder/dequant_vae")
254 | self.dequantizer.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
255 | self.dequantizer.eval()
256 | self.dequantizer.train = disabled_train
257 |
258 | def shared_step(self, batch, batch_idx, split="train"):
259 | data = self.get_input(batch)
260 | z, c = data["zcode"], data["xcon"]
261 | zz, logdet = self(z, c)
262 | loss, log_dict = self.loss(zz, logdet, split=split)
263 | return loss, log_dict
264 |
265 | def forward(self, z, c):
266 | c = self.encode_to_c(c)
267 | zz, logdet = self.flow(z, c)
268 | return zz, logdet
269 |
270 | @torch.no_grad()
271 | def log_images(self, batch, split=""):
272 | log = dict()
273 | data = self.get_input(batch, move_to_device=True)
274 | z, xc, x = data["zcode"], data["xcon"], data["xgen"]
275 | c = self.encode_to_c(xc)
276 | zz, _ = self.flow(z, c)
277 |
278 | z_sample = self.sample_conditional(c)
279 | zdec, cdec = torch.split(z_sample, data["split_sizes"], dim=1)
280 | xsample = self.first_stage_model.generate_from_embedding(zdec.squeeze(-1).squeeze(-1),
281 | cdec.squeeze(-1).squeeze(-1))
282 |
283 | cshift = torch.cat((c[1:],c[:1]),dim=0)
284 | zshift = self.flow.reverse(zz, cshift)
285 | zshift, cshift = torch.split(zshift, data["split_sizes"], dim=1)
286 | xshift = self.first_stage_model.generate_from_embedding(zshift.squeeze(-1).squeeze(-1),
287 | cshift.squeeze(-1).squeeze(-1))
288 |
289 | log["inputs"] = x
290 | if self.cond_stage_key not in ["text", "caption", "class"]:
291 | log["conditioning"] = xc
292 | else:
293 | _,_,h,w = x.shape
294 | log["conditioning"] = log_txt_as_img((w,h), xc)
295 |
296 | log["shift"] = xshift
297 | log["samples"] = xsample
298 | return log
299 |
300 | @torch.no_grad()
301 | def sample_conditional(self, c):
302 | z = self.flow.sample(c)
303 | return z
304 |
--------------------------------------------------------------------------------
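For orientation, the `NLL` objective minimized by `Flow`, `Net2NetFlow` and `Net2BigGANFlow` is the usual change-of-variables negative log-likelihood of a normalizing flow. The sketch below states it generically under a standard-normal prior; the repository's actual implementation lives in `net2net/modules/flow/loss.py` and may differ in normalization:

    import math
    import torch

    def flow_nll(zz, logdet):
        # zz: transformed codes of shape (batch, ...); logdet: log|det Jacobian| of shape (batch,)
        log_pz = -0.5 * (zz ** 2 + math.log(2 * math.pi)).flatten(1).sum(dim=1)
        return -(log_pz + logdet).mean()
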
/net2net/models/flows/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from sklearn.neighbors import KernelDensity  # required by kde2D below
6 |
7 |
8 | def kde2D(x, y, bandwidth, xbins=250j, ybins=250j, **kwargs):
9 | """Build 2D kernel density estimate (KDE)."""
10 |
11 | # create grid of sample locations (default: 100x100)
12 | xx, yy = np.mgrid[x.min():x.max():xbins,
13 | y.min():y.max():ybins]
14 |
15 | xy_sample = np.vstack([yy.ravel(), xx.ravel()]).T
16 | xy_train = np.vstack([y, x]).T
17 |
18 | kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs)
19 | kde_skl.fit(xy_train)
20 |
21 | # score_samples() returns the log-likelihood of the samples
22 | z = np.exp(kde_skl.score_samples(xy_sample))
23 | return xx, yy, np.reshape(z, xx.shape)
24 |
25 |
26 | def plot2d(x, savepath=None):
27 | """make a scatter plot of x and return an Image of it"""
28 | x = x.cpu().numpy().squeeze()
29 | fig = plt.figure(dpi=300)
30 | xx, yy, zz = kde2D(x[:,0], x[:,1], 0.1)
31 | plt.pcolormesh(xx, yy, zz)
32 | plt.scatter(x[:,0], x[:, 1], s=0.1, c='mistyrose')
33 | if savepath is not None:
34 | plt.savefig(savepath, dpi=300)
35 | return fig
36 |
37 |
38 | def reshape_to_grid(x, num_samples=16, iw=28, ih=28, nc=1):
39 | x = x[:num_samples]
40 | x = x.detach().cpu()
41 | x = torch.reshape(x, (x.shape[0], nc, iw, ih))
42 | xgrid = torchvision.utils.make_grid(x)
43 | return xgrid
--------------------------------------------------------------------------------
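A small self-contained sketch of `reshape_to_grid` (illustrative only), which turns flat sample vectors into an image grid for logging:

    import torch
    from net2net.models.flows.util import reshape_to_grid

    x = torch.randn(64, 784)  # 64 flat 28x28 "samples"
    grid = reshape_to_grid(x, num_samples=16, iw=28, ih=28, nc=1)
    print(grid.shape)         # (3, H, W) grid produced by torchvision.utils.make_grid
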
/net2net/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/autoencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/autoencoder/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/autoencoder/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 |
5 | from net2net.modules.distributions.distributions import DiagonalGaussianDistribution
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class ActNorm(nn.Module):
18 | def __init__(self, num_features, logdet=False, affine=True,
19 | allow_reverse_init=False):
20 | assert affine
21 | super().__init__()
22 | self.logdet = logdet
23 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
24 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
25 | self.allow_reverse_init = allow_reverse_init
26 |
27 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
28 |
29 | def initialize(self, input):
30 | with torch.no_grad():
31 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
32 | mean = (
33 | flatten.mean(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 | std = (
40 | flatten.std(1)
41 | .unsqueeze(1)
42 | .unsqueeze(2)
43 | .unsqueeze(3)
44 | .permute(1, 0, 2, 3)
45 | )
46 |
47 | self.loc.data.copy_(-mean)
48 | self.scale.data.copy_(1 / (std + 1e-6))
49 |
50 | def forward(self, input, reverse=False):
51 | if reverse:
52 | return self.reverse(input)
53 | if len(input.shape) == 2:
54 | input = input[:,:,None,None]
55 | squeeze = True
56 | else:
57 | squeeze = False
58 |
59 | _, _, height, width = input.shape
60 |
61 | if self.training and self.initialized.item() == 0:
62 | self.initialize(input)
63 | self.initialized.fill_(1)
64 |
65 | h = self.scale * (input + self.loc)
66 |
67 | if squeeze:
68 | h = h.squeeze(-1).squeeze(-1)
69 |
70 | if self.logdet:
71 | log_abs = torch.log(torch.abs(self.scale))
72 | logdet = height*width*torch.sum(log_abs)
73 | logdet = logdet * torch.ones(input.shape[0]).to(input)
74 | return h, logdet
75 |
76 | return h
77 |
78 | def reverse(self, output):
79 | if self.training and self.initialized.item() == 0:
80 | if not self.allow_reverse_init:
81 | raise RuntimeError(
82 | "Initializing ActNorm in reverse direction is "
83 | "disabled by default. Use allow_reverse_init=True to enable."
84 | )
85 | else:
86 | self.initialize(output)
87 | self.initialized.fill_(1)
88 |
89 | if len(output.shape) == 2:
90 | output = output[:,:,None,None]
91 | squeeze = True
92 | else:
93 | squeeze = False
94 |
95 | h = output / self.scale - self.loc
96 |
97 | if squeeze:
98 | h = h.squeeze(-1).squeeze(-1)
99 | return h
100 |
101 |
102 | class BasicFullyConnectedNet(nn.Module):
103 | def __init__(self, dim, depth, hidden_dim=256, use_tanh=False, use_bn=False, out_dim=None, use_an=False):
104 | super(BasicFullyConnectedNet, self).__init__()
105 | layers = []
106 | layers.append(nn.Linear(dim, hidden_dim))
107 | if use_bn:
108 | assert not use_an
109 | layers.append(nn.BatchNorm1d(hidden_dim))
110 | if use_an:
111 | assert not use_bn
112 | layers.append(ActNorm(hidden_dim))
113 | layers.append(nn.LeakyReLU())
114 | for d in range(depth):
115 | layers.append(nn.Linear(hidden_dim, hidden_dim))
116 | if use_bn:
117 | layers.append(nn.BatchNorm1d(hidden_dim))
118 | layers.append(nn.LeakyReLU())
119 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim))
120 | if use_tanh:
121 | layers.append(nn.Tanh())
122 | self.main = nn.Sequential(*layers)
123 |
124 | def forward(self, x):
125 | return self.main(x)
126 |
127 |
128 | _norm_options = {
129 | "in": nn.InstanceNorm2d,
130 | "bn": nn.BatchNorm2d,
131 | "an": ActNorm}
132 |
133 |
134 | class BasicAEModel(nn.Module):
135 | def __init__(self, n_down, z_dim, in_size, in_channels, deterministic=False):
136 | super().__init__()
137 | bottleneck_size = in_size // 2**n_down
138 | norm = "an"
139 | self.be_deterministic = deterministic
140 |
141 | self.feature_layers = nn.ModuleList()
142 | self.decoder_layers = nn.ModuleList()
143 |
144 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm))
145 | for scale in range(1, n_down):
146 | self.feature_layers.append(FeatureLayer(scale, norm=norm))
147 |
148 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, 2*z_dim)
149 | self.dense_decode = DenseDecoderLayer(n_down-1, bottleneck_size, z_dim)
150 |
151 | for scale in range(n_down-1):
152 | self.decoder_layers.append(DecoderLayer(scale, norm=norm))
153 | self.image_layer = ImageLayer(out_channels=in_channels)
154 |
155 | self.apply(weights_init)
156 |
157 | self.n_down = n_down
158 | self.z_dim = z_dim
159 | self.bottleneck_size = bottleneck_size
160 |
161 | def encode(self, input):
162 | h = input
163 | for layer in self.feature_layers:
164 | h = layer(h)
165 | h = self.dense_encode(h)
166 | return DiagonalGaussianDistribution(h, deterministic=self.be_deterministic)
167 |
168 | def decode(self, input):
169 | h = input
170 | h = self.dense_decode(h)
171 | for layer in reversed(self.decoder_layers):
172 | h = layer(h)
173 | h = self.image_layer(h)
174 | return h
175 |
176 | def get_last_layer(self):
177 | return self.image_layer.sub_layers[0].weight
178 |
179 |
180 | class FeatureLayer(nn.Module):
181 | def __init__(self, scale, in_channels=None, norm='IN'):
182 | super().__init__()
183 | self.scale = scale
184 | self.norm = _norm_options[norm.lower()]
185 | if in_channels is None:
186 | self.in_channels = 64*min(2**(self.scale-1), 16)
187 | else:
188 | self.in_channels = in_channels
189 | self.build()
190 |
191 | def forward(self, input):
192 | x = input
193 | for layer in self.sub_layers:
194 | x = layer(x)
195 | return x
196 |
197 | def build(self):
198 | Norm = functools.partial(self.norm, affine=True)
199 | Activate = lambda: nn.LeakyReLU(0.2)
200 | self.sub_layers = nn.ModuleList([
201 | nn.Conv2d(
202 | in_channels=self.in_channels,
203 | out_channels=64*min(2**self.scale, 16),
204 | kernel_size=4,
205 | stride=2,
206 | padding=1,
207 | bias=False),
208 | Norm(num_features=64*min(2**self.scale, 16)),
209 | Activate()])
210 |
211 |
212 | class LatentLayer(nn.Module):
213 | def __init__(self, in_channels, out_channels):
214 | super(LatentLayer, self).__init__()
215 | self.in_channels = in_channels
216 | self.out_channels = out_channels
217 | self.build()
218 |
219 | def forward(self, input):
220 | x = input
221 | for layer in self.sub_layers:
222 | x = layer(x)
223 | return x
224 |
225 | def build(self):
226 | self.sub_layers = nn.ModuleList([
227 | nn.Conv2d(
228 | in_channels=self.in_channels,
229 | out_channels=self.out_channels,
230 | kernel_size=1,
231 | stride=1,
232 | padding=0,
233 | bias=True)
234 | ])
235 |
236 |
237 | class DecoderLayer(nn.Module):
238 | def __init__(self, scale, in_channels=None, norm='IN'):
239 | super().__init__()
240 | self.scale = scale
241 | self.norm = _norm_options[norm.lower()]
242 | if in_channels is not None:
243 | self.in_channels = in_channels
244 | else:
245 | self.in_channels = 64*min(2**(self.scale+1), 16)
246 | self.build()
247 |
248 | def forward(self, input):
249 | d = input
250 | for layer in self.sub_layers:
251 | d = layer(d)
252 | return d
253 |
254 | def build(self):
255 | Norm = functools.partial(self.norm, affine=True)
256 | Activate = lambda: nn.LeakyReLU(0.2)
257 | self.sub_layers = nn.ModuleList([
258 | nn.ConvTranspose2d(
259 | in_channels=self.in_channels,
260 | out_channels=64*min(2**self.scale, 16),
261 | kernel_size=4,
262 | stride=2,
263 | padding=1,
264 | bias=False),
265 | Norm(num_features=64*min(2**self.scale, 16)),
266 | Activate()])
267 |
268 |
269 | class DenseEncoderLayer(nn.Module):
270 | def __init__(self, scale, spatial_size, out_size, in_channels=None):
271 | super().__init__()
272 | self.scale = scale
273 | self.in_channels = 64*min(2**(self.scale-1), 16)
274 | if in_channels is not None:
275 | self.in_channels = in_channels
276 | self.out_channels = out_size
277 | self.kernel_size = spatial_size
278 | self.build()
279 |
280 | def forward(self, input):
281 | x = input
282 | for layer in self.sub_layers:
283 | x = layer(x)
284 | return x
285 |
286 | def build(self):
287 | self.sub_layers = nn.ModuleList([
288 | nn.Conv2d(
289 | in_channels=self.in_channels,
290 | out_channels=self.out_channels,
291 | kernel_size=self.kernel_size,
292 | stride=1,
293 | padding=0,
294 | bias=True)])
295 |
296 |
297 | class DenseDecoderLayer(nn.Module):
298 | def __init__(self, scale, spatial_size, in_size):
299 | super().__init__()
300 | self.scale = scale
301 | self.in_channels = in_size
302 | self.out_channels = 64*min(2**self.scale, 16)
303 | self.kernel_size = spatial_size
304 | self.build()
305 |
306 | def forward(self, input):
307 | x = input
308 | for layer in self.sub_layers:
309 | x = layer(x)
310 | return x
311 |
312 | def build(self):
313 | self.sub_layers = nn.ModuleList([
314 | nn.ConvTranspose2d(
315 | in_channels=self.in_channels,
316 | out_channels=self.out_channels,
317 | kernel_size=self.kernel_size,
318 | stride=1,
319 | padding=0,
320 | bias=True)])
321 |
322 |
323 | class ImageLayer(nn.Module):
324 | def __init__(self, out_channels=3, in_channels=64):
325 | super().__init__()
326 | self.in_channels = in_channels
327 | self.out_channels = out_channels
328 | self.build()
329 |
330 | def forward(self, input):
331 | x = input
332 | for layer in self.sub_layers:
333 | x = layer(x)
334 | return x
335 |
336 | def build(self):
337 | FinalActivate = lambda: torch.nn.Tanh()
338 | self.sub_layers = nn.ModuleList([
339 | nn.ConvTranspose2d(
340 | in_channels=self.in_channels,
341 | out_channels=self.out_channels,
342 | kernel_size=4,
343 | stride=2,
344 | padding=1,
345 | bias=False),
346 | FinalActivate()
347 | ])
348 |
349 |
350 | class BasicFullyConnectedVAE(nn.Module):
351 | def __init__(self, n_down=2, z_dim=128, in_channels=128, mid_channels=4096, use_bn=False, deterministic=False):
352 | super().__init__()
353 |
354 | self.be_deterministic = deterministic
355 | self.encoder = BasicFullyConnectedNet(dim=in_channels, depth=n_down,
356 | hidden_dim=mid_channels,
357 | out_dim=in_channels,
358 | use_bn=use_bn)
359 | self.mu_layer = BasicFullyConnectedNet(in_channels, depth=n_down,
360 | hidden_dim=mid_channels,
361 | out_dim=z_dim,
362 | use_bn=use_bn)
363 | self.logvar_layer = BasicFullyConnectedNet(in_channels, depth=n_down,
364 | hidden_dim=mid_channels,
365 | out_dim=z_dim,
366 | use_bn=use_bn)
367 | self.decoder = BasicFullyConnectedNet(dim=z_dim, depth=n_down+1,
368 | hidden_dim=mid_channels,
369 | out_dim=in_channels,
370 | use_bn=use_bn)
371 |
372 | def encode(self, x):
373 | h = self.encoder(x)
374 | mu = self.mu_layer(h)
375 | logvar = self.logvar_layer(h)
376 | return DiagonalGaussianDistribution(torch.cat((mu, logvar), dim=1), deterministic=self.be_deterministic)
377 |
378 | def decode(self, x):
379 | x = self.decoder(x)
380 | return x
381 |
382 | def forward(self, x):
383 | x = self.encode(x).sample()
384 | x = self.decoder(x)
385 | return x
386 |
387 | def get_last_layer(self):
388 | return self.decoder.main[-1].weight
389 |
--------------------------------------------------------------------------------
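A minimal sketch (illustrative only) of the default `BasicFullyConnectedVAE`, i.e. the configuration instantiated as the dequantizer in `Net2BigGANFlow.init_preprocessing`; the random input stands in for 128-dimensional class-embedding codes:

    import torch
    from net2net.modules.autoencoder.basic import BasicFullyConnectedVAE

    vae = BasicFullyConnectedVAE()   # defaults: in_channels=128, z_dim=128, mid_channels=4096
    c = torch.randn(4, 128)
    posterior = vae.encode(c)        # diagonal Gaussian over a 128-d latent
    c_smooth = posterior.sample()    # (4, 128)
    c_rec = vae.decode(c_smooth)     # (4, 128)
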
/net2net/modules/autoencoder/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from net2net.modules.gan.biggan import load_variable_latsize_generator
5 |
6 | class ClassUp(nn.Module):
7 | def __init__(self, dim, depth, hidden_dim=256, use_sigmoid=False, out_dim=None):
8 | super().__init__()
9 | layers = []
10 | layers.append(nn.Linear(dim, hidden_dim))
11 | layers.append(nn.LeakyReLU())
12 | for d in range(depth):
13 | layers.append(nn.Linear(hidden_dim, hidden_dim))
14 | layers.append(nn.LeakyReLU())
15 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim))
16 | if use_sigmoid:
17 | layers.append(nn.Sigmoid())
18 | self.main = nn.Sequential(*layers)
19 |
20 | def forward(self, x):
21 | x = self.main(x.squeeze(-1).squeeze(-1))
22 | x = torch.nn.functional.softmax(x, dim=1)
23 | return x
24 |
25 |
26 | class BigGANDecoderWrapper(nn.Module):
27 | """Wraps a BigGAN into our autoencoding framework"""
28 | def __init__(self, z_dim, in_size=128, use_actnorm_in_dec=False, extra_z_dims=list()):
29 | super().__init__()
30 | self.z_dim = z_dim
31 | class_embedding_dim = 1000
32 | self.extra_z_dims = extra_z_dims
33 | self.map_to_class_embedding = ClassUp(z_dim, depth=2, hidden_dim=2*class_embedding_dim,
34 | use_sigmoid=False, out_dim=class_embedding_dim)
35 | self.decoder = load_variable_latsize_generator(in_size, z_dim,
36 | use_actnorm=use_actnorm_in_dec,
37 | n_class=class_embedding_dim,
38 | extra_z_dims=self.extra_z_dims)
39 |
40 | def forward(self, x, labels=None):
41 | emb = self.map_to_class_embedding(x[:,:self.z_dim,...])
42 | x = self.decoder(x, emb)
43 | return x
--------------------------------------------------------------------------------
/net2net/modules/autoencoder/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | from torchvision import models
5 |
6 | from net2net.modules.autoencoder.basic import ActNorm, DenseEncoderLayer
7 |
8 |
9 | class ResnetEncoder(nn.Module):
10 | def __init__(self, z_dim, in_size, in_channels=3,
11 | pretrained=False, type="resnet50",
12 | double_z=True, pre_process=True,
13 | ):
14 | super().__init__()
15 | __possible_resnets = {
16 | 'resnet18': models.resnet18,
17 | 'resnet34': models.resnet34,
18 | 'resnet50': models.resnet50,
19 | 'resnet101': models.resnet101
20 | }
21 | self.use_preprocess = pre_process
22 | self.in_channels = in_channels
23 | norm_layer = ActNorm
24 | self.z_dim = z_dim
25 | self.model = __possible_resnets[type](pretrained=pretrained, norm_layer=norm_layer)
26 |
27 | self.image_transform = torchvision.transforms.Compose(
28 | [torchvision.transforms.Lambda(self.normscale)]
29 | )
30 |
31 | size_pre_fc = self.get_spatial_size(in_size)
32 |         assert size_pre_fc[2]==size_pre_fc[3], 'Output spatial size is not square'
33 | spatial_size = size_pre_fc[2]
34 | num_channels_pre_fc = size_pre_fc[1]
35 | # replace last fc
36 | self.model.fc = DenseEncoderLayer(0,
37 | spatial_size=spatial_size,
38 | out_size=2*z_dim if double_z else z_dim,
39 | in_channels=num_channels_pre_fc)
40 | if self.in_channels != 3:
41 | self.model.in_ch_match = nn.Conv2d(self.in_channels, 3, 3, 1)
42 |
43 | def forward(self, x):
44 | if self.use_preprocess:
45 | x = self.pre_process(x)
46 | if self.in_channels != 3:
47 | assert not self.use_preprocess
48 | x = self.model.in_ch_match(x)
49 | features = self.features(x)
50 | encoding = self.model.fc(features)
51 | return encoding
52 |
53 | def rescale(self, x):
54 | return 0.5 * (x + 1)
55 |
56 | def normscale(self, image):
57 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
58 | return torch.stack([normalize(self.rescale(x)) for x in image])
59 |
60 | def features(self, x):
61 | if self.use_preprocess:
62 |             x = self.pre_process(x)  # note: forward() already applies pre_process before calling features()
63 | x = self.model.conv1(x)
64 | x = self.model.bn1(x)
65 | x = self.model.relu(x)
66 | x = self.model.maxpool(x)
67 | x = self.model.layer1(x)
68 | x = self.model.layer2(x)
69 | x = self.model.layer3(x)
70 | x = self.model.layer4(x)
71 | x = self.model.avgpool(x)
72 | return x
73 |
74 | def post_features(self, x):
75 | x = self.model.fc(x)
76 | return x
77 |
78 | def pre_process(self, x):
79 | x = self.image_transform(x)
80 | return x
81 |
82 | def get_spatial_size(self, ipt_size):
83 | x = torch.randn(1, 3, ipt_size, ipt_size)
84 | return self.features(x).size()
85 |
86 | @property
87 | def mean(self):
88 | return [0.485, 0.456, 0.406]
89 |
90 | @property
91 | def std(self):
92 | return [0.229, 0.224, 0.225]
93 |
94 | @property
95 | def input_size(self):
96 | return [3, 224, 224]
97 |
--------------------------------------------------------------------------------
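A minimal sketch of `ResnetEncoder` on 256x256 inputs (illustrative only; `z_dim=256` is just an example value). The output stacks mean and log-variance along the channel axis because `double_z=True` by default:

    import torch
    from net2net.modules.autoencoder.encoder import ResnetEncoder

    encoder = ResnetEncoder(z_dim=256, in_size=256, pretrained=False)  # resnet50 backbone with ActNorm
    x = torch.randn(2, 3, 256, 256)                                    # images scaled to [-1, 1]
    moments = encoder(x)
    print(moments.shape)  # torch.Size([2, 512, 1, 1]) -> 2*z_dim moments of a diagonal Gaussian
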
/net2net/modules/autoencoder/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from net2net.modules.autoencoder.lpips import LPIPS # LPIPS loss
6 | from net2net.modules.discriminator.model import NLayerDiscriminator, weights_init
7 |
8 |
9 | def adopt_weight(weight, global_step, threshold=0, value=0.):
10 | if global_step < threshold:
11 | weight = value
12 | return weight
13 |
14 |
15 | def hinge_d_loss(logits_real, logits_fake):
16 | loss_real = torch.mean(F.relu(1. - logits_real))
17 | loss_fake = torch.mean(F.relu(1. + logits_fake))
18 | d_loss = 0.5 * (loss_real + loss_fake)
19 | return d_loss
20 |
21 |
22 | def vanilla_d_loss(logits_real, logits_fake):
23 | d_loss = 0.5 * (
24 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
25 | torch.mean(torch.nn.functional.softplus(logits_fake)))
26 | return d_loss
27 |
28 |
29 | class LPIPSWithDiscriminator(nn.Module):
30 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
31 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
32 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
33 | evo_disc=False, disc_loss="hinge"):
34 |
35 | super().__init__()
36 | assert disc_loss in ["hinge", "vanilla"]
37 | self.kl_weight = kl_weight
38 | self.pixel_weight = pixelloss_weight
39 | self.perceptual_loss = LPIPS().eval()
40 | self.perceptual_weight = perceptual_weight
41 | # output log variance
42 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
43 |
44 |         if evo_disc:
45 |             # NLayerDiscriminatorEvoNorm is not included in this repository;
46 |             # fail explicitly instead of with a NameError at construction time
47 |             raise NotImplementedError("evo_disc=True requires NLayerDiscriminatorEvoNorm")
48 |         else:
49 |             self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
50 |                                                      n_layers=disc_num_layers,
51 |                                                      use_actnorm=use_actnorm
52 |                                                      ).apply(weights_init)
53 | self.discriminator_iter_start = disc_start
54 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
55 | self.disc_factor = disc_factor
56 | self.discriminator_weight = disc_weight
57 | self.disc_conditional = disc_conditional
58 |
59 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
60 | if last_layer is not None:
61 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
62 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
63 | else:
64 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
65 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
66 |
67 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
68 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
69 | d_weight = d_weight * self.discriminator_weight
70 | return d_weight
71 |
72 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
73 | global_step, last_layer=None, cond=None, split="train",
74 | side_outputs=None, weights=None):
75 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
76 | if self.perceptual_weight > 0:
77 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
78 | rec_loss = rec_loss + self.perceptual_weight * p_loss
79 |
80 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
81 | weighted_nll_loss = nll_loss
82 | if weights is not None:
83 | weighted_nll_loss = weights*nll_loss
84 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
85 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
86 | kl_loss = posteriors.kl()
87 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
88 |
89 | # now the GAN part
90 | if optimizer_idx == 0:
91 | # generator update
92 | if cond is None:
93 | assert not self.disc_conditional
94 | logits_fake = self.discriminator(reconstructions.contiguous())
95 | else:
96 | assert self.disc_conditional
97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
98 | g_loss = -torch.mean(logits_fake)
99 |
100 | if self.disc_factor > 0.0:
101 | try:
102 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
103 | except RuntimeError:
104 | assert not self.training
105 | d_weight = torch.tensor(0.0)
106 | else:
107 | d_weight = torch.tensor(0.0)
108 |
109 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
110 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
111 |
112 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
113 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
114 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
115 | "{}/d_weight".format(split): d_weight.detach(),
116 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
117 | "{}/g_loss".format(split): g_loss.detach().mean(),
118 | }
119 | return loss, log
120 |
121 | if optimizer_idx == 1:
122 | # second pass for discriminator update
123 | if cond is None:
124 | logits_real = self.discriminator(inputs.contiguous().detach())
125 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
126 | else:
127 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
128 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
129 |
130 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
131 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
132 |
133 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
134 | "{}/logits_real".format(split): logits_real.detach().mean(),
135 | "{}/logits_fake".format(split): logits_fake.detach().mean()
136 | }
137 | return d_loss, log
138 |
139 | class DummyLoss:
140 | pass
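A rough sketch of how LPIPSWithDiscriminator is meant to be called from a two-optimizer training loop; the step function, its arguments and the disc_start value are illustrative placeholders, not the repository's training code:

from net2net.modules.autoencoder.loss import LPIPSWithDiscriminator

loss_fn = LPIPSWithDiscriminator(disc_start=10000)  # illustrative warm-up step

def training_step(inputs, reconstructions, posterior, optimizer_idx, global_step, last_layer):
    # posterior must expose .kl() (e.g. DiagonalGaussianDistribution from
    # net2net/modules/distributions/distributions.py); last_layer is typically the
    # decoder's final weight tensor, used for the adaptive discriminator weight.
    # optimizer_idx=0 returns the autoencoder/generator loss, optimizer_idx=1 the
    # discriminator loss; the same call is made once per optimizer.
    loss, log = loss_fn(inputs, reconstructions, posterior, optimizer_idx,
                        global_step, last_layer=last_layer, split="train")
    return loss, log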
--------------------------------------------------------------------------------
/net2net/modules/autoencoder/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 |
8 | from net2net.ckpt_util import get_ckpt_path
9 |
10 |
11 | class LPIPS(nn.Module):
12 | # Learned perceptual metric
13 | def __init__(self, use_dropout=True):
14 | super().__init__()
15 | self.scaling_layer = ScalingLayer()
16 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
17 | self.net = vgg16(pretrained=True, requires_grad=False)
18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
23 | self.load_from_pretrained()
24 | for param in self.parameters():
25 | param.requires_grad = False
26 |
27 | def load_from_pretrained(self, name="vgg_lpips"):
28 | ckpt = get_ckpt_path(name, "net2net/modules/autoencoder/lpips")
29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
30 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
31 |
32 | @classmethod
33 | def from_pretrained(cls, name="vgg_lpips"):
34 |         if name != "vgg_lpips":
35 | raise NotImplementedError
36 | model = cls()
37 | ckpt = get_ckpt_path(name)
38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
39 | return model
40 |
41 | def forward(self, input, target):
42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
43 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
44 | feats0, feats1, diffs = {}, {}, {}
45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
46 | for kk in range(len(self.chns)):
47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
49 |
50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
51 | val = res[0]
52 | for l in range(1, len(self.chns)):
53 | val += res[l]
54 | return val
55 |
56 |
57 | class ScalingLayer(nn.Module):
58 | def __init__(self):
59 | super(ScalingLayer, self).__init__()
60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
62 |
63 | def forward(self, inp):
64 | return (inp - self.shift) / self.scale
65 |
66 |
67 | class NetLinLayer(nn.Module):
68 | """ A single linear layer which does a 1x1 conv """
69 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
70 | super(NetLinLayer, self).__init__()
71 | layers = [nn.Dropout(), ] if (use_dropout) else []
72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
73 | self.model = nn.Sequential(*layers)
74 |
75 |
76 | class vgg16(torch.nn.Module):
77 | def __init__(self, requires_grad=False, pretrained=True):
78 | super(vgg16, self).__init__()
79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
80 | self.slice1 = torch.nn.Sequential()
81 | self.slice2 = torch.nn.Sequential()
82 | self.slice3 = torch.nn.Sequential()
83 | self.slice4 = torch.nn.Sequential()
84 | self.slice5 = torch.nn.Sequential()
85 | self.N_slices = 5
86 | for x in range(4):
87 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
88 | for x in range(4, 9):
89 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
90 | for x in range(9, 16):
91 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
92 | for x in range(16, 23):
93 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
94 | for x in range(23, 30):
95 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
96 | if not requires_grad:
97 | for param in self.parameters():
98 | param.requires_grad = False
99 |
100 | def forward(self, X):
101 | h = self.slice1(X)
102 | h_relu1_2 = h
103 | h = self.slice2(h)
104 | h_relu2_2 = h
105 | h = self.slice3(h)
106 | h_relu3_3 = h
107 | h = self.slice4(h)
108 | h_relu4_3 = h
109 | h = self.slice5(h)
110 | h_relu5_3 = h
111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
113 | return out
114 |
115 |
116 | def normalize_tensor(x,eps=1e-10):
117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
118 | return x/(norm_factor+eps)
119 |
120 |
121 | def spatial_average(x, keepdim=True):
122 | return x.mean([2,3],keepdim=keepdim)
123 |
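A small sketch of using LPIPS as a standalone perceptual distance; it downloads the pretrained VGG backbone and linear weights on first use, and the random tensors stand in for images in [-1, 1]:

import torch
from net2net.modules.autoencoder.lpips import LPIPS

lpips = LPIPS().eval()
x = torch.rand(2, 3, 256, 256) * 2 - 1   # "real" images
y = torch.rand(2, 3, 256, 256) * 2 - 1   # "reconstructed" images
with torch.no_grad():
    dist = lpips(x, y)                   # per-example distances, shape (2, 1, 1, 1)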
--------------------------------------------------------------------------------
/net2net/modules/captions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/captions/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/captions/model.py:
--------------------------------------------------------------------------------
1 | """Code is based on https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning"""
2 |
3 | import os, sys
4 | import json
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torchvision
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from PIL import Image
13 |
14 | from net2net.ckpt_util import get_ckpt_path
15 |
16 | #import warnings
17 | #warnings.filterwarnings("ignore")
18 |
19 | from net2net.modules.captions.models import Encoder, DecoderWithAttention
20 |
21 |
22 | rescale = lambda x: 0.5*(x+1)
23 |
24 |
25 | def imresize(img, size):
26 | return np.array(Image.fromarray(img).resize(size))
27 |
28 |
29 | class Img2Text(nn.Module):
30 | def __init__(self):
31 | super().__init__()
32 | model_path = get_ckpt_path("coco_captioner", "net2net/modules/captions")
33 | word_map_path = "data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json"
34 |
35 | # Load word map (word2ix)
36 | with open(word_map_path, 'r') as j:
37 | word_map = json.load(j)
38 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word
39 | self.word_map = word_map
40 | self.rev_word_map = rev_word_map
41 |
42 | checkpoint = torch.load(model_path)
43 |
44 | self.encoder = Encoder()
45 | self.decoder = DecoderWithAttention(embed_dim=512, decoder_dim=512, attention_dim=512, vocab_size=9490)
46 | missing, unexpected = self.load_state_dict(checkpoint, strict=False)
47 | if len(missing) > 0:
48 | print(f"Missing keys in state-dict: {missing}")
49 | if len(unexpected) > 0:
50 | print(f"Unexpected keys in state-dict: {unexpected}")
51 | self.encoder.eval()
52 | self.decoder.eval()
53 |
54 | resize = transforms.Lambda(lambda image: F.interpolate(image, size=(256, 256), mode="bilinear"))
55 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
56 | norm = torchvision.transforms.Lambda(lambda image: torch.stack([normalize(rescale(x)) for x in image]))
57 | self.img_transform = transforms.Compose([resize, norm])
58 | self.device = "cuda"
59 |
60 | def _pre_process(self, x):
61 | x = self.img_transform(x)
62 | return x
63 |
64 | @property
65 | def mean(self):
66 | return [0.485, 0.456, 0.406]
67 |
68 | @property
69 | def std(self):
70 | return [0.229, 0.224, 0.225]
71 |
72 | def forward(self, x):
73 | captions = list()
74 | for subx in x:
75 | subx = subx.unsqueeze(0)
76 | captions.append(self.make_single_caption(subx))
77 | return captions
78 |
79 | def make_single_caption(self, x):
80 | seq = self.caption_image_beam_search(x)[0][0]
81 | words = [self.rev_word_map[ind] for ind in seq]
82 | words = words[:50]
83 | #if len(words) > 50:
84 | # return np.array([''])
85 | text = ''
86 | for word in words:
87 | text += word + ' '
88 | return text
89 |
90 | def caption_image_beam_search(self, image, beam_size=3):
91 | """
92 |         Reads a batch of images and captions each of them with beam search.
93 | :param image: batch of pytorch images
94 | :param beam_size: number of sequences to consider at each decode-step
95 | :return: caption, weights for visualization
96 | """
97 |
98 | k = beam_size
99 | vocab_size = len(self.word_map)
100 |
101 | # Encode
102 | # image is a batch of images
103 | encoder_out_ = self.encoder(image) # (b, enc_image_size, enc_image_size, encoder_dim)
104 | enc_image_size = encoder_out_.size(1)
105 | encoder_dim = encoder_out_.size(3)
106 | batch_size = encoder_out_.size(0)
107 |
108 | # Flatten encoding
109 | encoder_out_ = encoder_out_.view(batch_size, -1, encoder_dim) # (1, num_pixels, encoder_dim)
110 | num_pixels = encoder_out_.size(1)
111 |
112 | sequences = list()
113 | alphas_ = list()
114 | # We'll treat the problem as having a batch size of k per example
115 |         for single_example in encoder_out_:
116 |             k = beam_size  # reset the beam width for every example in the batch
117 |             encoder_out = single_example[None, ...].expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)
118 |
119 |             # Tensor to store top k previous words at each step; now they're just <start>
120 |             k_prev_words = torch.LongTensor([[self.word_map['<start>']]] * k).to(self.device)  # (k, 1)
121 |
122 |             # Tensor to store top k sequences; now they're just <start>
123 | seqs = k_prev_words # (k, 1)
124 |
125 | # Tensor to store top k sequences' scores; now they're just 0
126 | top_k_scores = torch.zeros(k, 1).to(self.device) # (k, 1)
127 |
128 | # Tensor to store top k sequences' alphas; now they're just 1s
129 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(self.device) # (k, 1, enc_image_size, enc_image_size)
130 |
131 | # Lists to store completed sequences, their alphas and scores
132 | complete_seqs = list()
133 | complete_seqs_alpha = list()
134 | complete_seqs_scores = list()
135 |
136 | # Start decoding
137 | step = 1
138 | h, c = self.decoder.init_hidden_state(encoder_out)
139 |
140 |             # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
141 | while True:
142 | embeddings = self.decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
143 | awe, alpha = self.decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
144 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
145 | gate = self.decoder.sigmoid(self.decoder.f_beta(h)) # gating scalar, (s, encoder_dim)
146 | awe = gate * awe
147 | h, c = self.decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
148 | scores = self.decoder.fc(h) # (s, vocab_size)
149 | scores = F.log_softmax(scores, dim=1)
150 | # Add
151 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
152 |
153 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
154 | if step == 1:
155 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
156 | else:
157 | # Unroll and find top scores, and their unrolled indices
158 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
159 |
160 | # Convert unrolled indices to actual indices of scores
161 | prev_word_inds = top_k_words // vocab_size # (s)
162 | next_word_inds = top_k_words % vocab_size # (s)
163 |
164 | # Add new words to sequences, alphas
165 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
166 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
167 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
168 |
169 |                 # Which sequences are incomplete (didn't reach <end>)?
170 |                 incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
171 |                                    next_word != self.word_map['<end>']]
172 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))
173 |
174 | # Set aside complete sequences
175 | if len(complete_inds) > 0:
176 | complete_seqs.extend(seqs[complete_inds].tolist())
177 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
178 | complete_seqs_scores.extend(top_k_scores[complete_inds])
179 | k -= len(complete_inds) # reduce beam length accordingly
180 |
181 | # Proceed with incomplete sequences
182 | if k == 0:
183 | break
184 | seqs = seqs[incomplete_inds]
185 | seqs_alpha = seqs_alpha[incomplete_inds]
186 | h = h[prev_word_inds[incomplete_inds]]
187 | c = c[prev_word_inds[incomplete_inds]]
188 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
189 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
190 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
191 |
192 | # Break if things have been going on too long
193 | if step > 50:
194 | break
195 | step += 1
196 |
197 | try:
198 | i = complete_seqs_scores.index(max(complete_seqs_scores))
199 | seq = complete_seqs[i]
200 | alphas = complete_seqs_alpha[i]
201 | except ValueError:
202 | print("Catching an empty sequence.")
203 | try:
204 | len_ = len(sequences[-1])
205 | seq = [0]*len_
206 | alphas = None
207 |                 except IndexError:  # sequences is still empty, fall back to a fixed length
208 | seq = [0]*9
209 | alphas = None
210 |
211 | sequences.append(seq)
212 | alphas_.append(alphas)
213 |
214 | return sequences, alphas_
215 |
216 | def visualize_text(self, root, images, sequences, n_row=5, img_name='examples'):
217 | """
218 | plot the text corresponding to the given images in a matplotlib figure.
219 | images are a batch of pytorch images
220 | """
221 |
222 | n_img = images.size(0)
223 | n_col = max(n_img // n_row + 1, 2)
224 |
225 | fig, ax = plt.subplots(n_row, n_col)
226 |
227 | i = 0
228 | j = 0
229 | for image, seq in zip(images, sequences):
230 | if i == n_row:
231 | i = 0
232 | j += 1
233 | image = image.cpu().numpy().transpose(1, 2, 0)
234 | image = 255*(0.5*(image+1))
235 | image = Image.fromarray(image.astype('uint8'))
236 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS)
237 | words = [self.rev_word_map[ind] for ind in seq]
238 | if len(words) > 50:
239 | return
240 | text = ''
241 | for word in words:
242 | text += word + ' '
243 |
244 | ax[i, j].text(0, 1, '%s' % (text), color='black', backgroundcolor='white', fontsize=12)
245 | ax[i, j].imshow(image)
246 | ax[i, j].axis('off')
247 |
248 | plt.savefig(os.path.join(root, img_name + '.png'))
249 |
250 |
251 | if __name__ == '__main__':
252 | model = Img2Text()
253 | print("done.")
254 |
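A rough usage sketch for Img2Text; it assumes the captioning checkpoint, the word map under data/, and a CUDA device are available (the module hard-codes self.device = "cuda"), and the random tensor stands in for images in [-1, 1]:

import torch
from net2net.modules.captions.model import Img2Text

captioner = Img2Text().cuda().eval()
images = torch.rand(2, 3, 256, 256, device="cuda") * 2 - 1
with torch.no_grad():
    captions = captioner(images)   # list with one caption string per image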
--------------------------------------------------------------------------------
/net2net/modules/captions/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchvision
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class Encoder(nn.Module):
9 | """
10 | Encoder.
11 | """
12 |
13 | def __init__(self, encoded_image_size=14):
14 | super(Encoder, self).__init__()
15 | self.enc_image_size = encoded_image_size
16 |
17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101
18 |
19 | # Remove linear and pool layers (since we're not doing classification)
20 | modules = list(resnet.children())[:-2]
21 | self.resnet = nn.Sequential(*modules)
22 |
23 | # Resize image to fixed size to allow input images of variable size
24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
25 |
26 | self.fine_tune()
27 |
28 | def forward(self, images):
29 | """
30 | Forward propagation.
31 |
32 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
33 | :return: encoded images
34 | """
35 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32)
36 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size)
37 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048)
38 | return out
39 |
40 | def fine_tune(self, fine_tune=True):
41 | """
42 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
43 |
44 | :param fine_tune: Allow?
45 | """
46 | for p in self.resnet.parameters():
47 | p.requires_grad = False
48 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4
49 | for c in list(self.resnet.children())[5:]:
50 | for p in c.parameters():
51 | p.requires_grad = fine_tune
52 |
53 |
54 | class Attention(nn.Module):
55 | """
56 | Attention Network.
57 | """
58 |
59 | def __init__(self, encoder_dim, decoder_dim, attention_dim):
60 | """
61 | :param encoder_dim: feature size of encoded images
62 | :param decoder_dim: size of decoder's RNN
63 | :param attention_dim: size of the attention network
64 | """
65 | super(Attention, self).__init__()
66 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
67 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output
68 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
69 | self.relu = nn.ReLU()
70 | #self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
71 |
72 | def forward(self, encoder_out, decoder_hidden):
73 | """
74 | Forward propagation.
75 |
76 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
77 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
78 | :return: attention weighted encoding, weights
79 | """
80 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
81 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
82 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels)
83 | #alpha = self.softmax(att) # (batch_size, num_pixels)
84 | alpha = torch.nn.functional.softmax(att, dim=1)
85 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim)
86 |
87 | return attention_weighted_encoding, alpha
88 |
89 |
90 | class DecoderWithAttention(nn.Module):
91 | """
92 | Decoder.
93 | """
94 |
95 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
96 | """
97 | :param attention_dim: size of attention network
98 | :param embed_dim: embedding size
99 | :param decoder_dim: size of decoder's RNN
100 | :param vocab_size: size of vocabulary
101 | :param encoder_dim: feature size of encoded images
102 | :param dropout: dropout
103 | """
104 | super(DecoderWithAttention, self).__init__()
105 |
106 | self.encoder_dim = encoder_dim
107 | self.attention_dim = attention_dim
108 | self.embed_dim = embed_dim
109 | self.decoder_dim = decoder_dim
110 | self.vocab_size = vocab_size
111 | self.dropout = dropout
112 |
113 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network
114 |
115 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
116 | self.dropout = nn.Dropout(p=self.dropout)
117 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
118 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
119 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
120 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
121 | self.sigmoid = nn.Sigmoid()
122 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
123 | self.init_weights() # initialize some layers with the uniform distribution
124 |
125 | def init_weights(self):
126 | """
127 | Initializes some parameters with values from the uniform distribution, for easier convergence.
128 | """
129 | self.embedding.weight.data.uniform_(-0.1, 0.1)
130 | self.fc.bias.data.fill_(0)
131 | self.fc.weight.data.uniform_(-0.1, 0.1)
132 |
133 | def load_pretrained_embeddings(self, embeddings):
134 | """
135 | Loads embedding layer with pre-trained embeddings.
136 |
137 | :param embeddings: pre-trained embeddings
138 | """
139 | self.embedding.weight = nn.Parameter(embeddings)
140 |
141 | def fine_tune_embeddings(self, fine_tune=True):
142 | """
143 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
144 |
145 | :param fine_tune: Allow?
146 | """
147 | for p in self.embedding.parameters():
148 | p.requires_grad = fine_tune
149 |
150 | def init_hidden_state(self, encoder_out):
151 | """
152 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
153 |
154 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
155 | :return: hidden state, cell state
156 | """
157 | mean_encoder_out = encoder_out.mean(dim=1)
158 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
159 | c = self.init_c(mean_encoder_out)
160 | return h, c
161 |
162 | def forward(self, encoder_out, encoded_captions, caption_lengths):
163 | """
164 | Forward propagation.
165 |
166 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
167 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
168 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
169 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
170 | """
171 |
172 | batch_size = encoder_out.size(0)
173 | encoder_dim = encoder_out.size(-1)
174 | vocab_size = self.vocab_size
175 |
176 | # Flatten image
177 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
178 | num_pixels = encoder_out.size(1)
179 |
180 | # Sort input data by decreasing lengths; why? apparent below
181 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
182 | encoder_out = encoder_out[sort_ind]
183 | encoded_captions = encoded_captions[sort_ind]
184 |
185 | # Embedding
186 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
187 |
188 | # Initialize LSTM state
189 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim)
190 |
191 |         # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
192 |         # So, decoding lengths are actual lengths - 1
193 |         decode_lengths = (caption_lengths - 1).tolist()
194 |
195 |         # Create tensors to hold word prediction scores and alphas
196 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
197 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)
198 |
199 | # At each time-step, decode by
200 | # attention-weighing the encoder's output based on the decoder's previous hidden state output
201 | # then generate a new word in the decoder with the previous word and the attention weighted encoding
202 | for t in range(max(decode_lengths)):
203 | batch_size_t = sum([l > t for l in decode_lengths])
204 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
205 | h[:batch_size_t])
206 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
207 | attention_weighted_encoding = gate * attention_weighted_encoding
208 | h, c = self.decode_step(
209 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
210 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
211 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
212 | predictions[:batch_size_t, t, :] = preds
213 | alphas[:batch_size_t, t, :] = alpha
214 |
215 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind
216 |
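A shape-checking sketch of the encoder/decoder pair above; the sizes are illustrative and the encoder downloads ImageNet ResNet-101 weights on first use:

import torch
from net2net.modules.captions.models import Encoder, DecoderWithAttention, device

encoder = Encoder(encoded_image_size=14).to(device)
decoder = DecoderWithAttention(attention_dim=512, embed_dim=512,
                               decoder_dim=512, vocab_size=9490).to(device)

images = torch.randn(2, 3, 256, 256, device=device)
enc = encoder(images)                                   # (2, 14, 14, 2048)
caps = torch.randint(0, 9490, (2, 20), device=device)   # encoded captions
caplens = torch.tensor([[20], [18]], device=device)     # caption lengths, shape (2, 1)
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(enc, caps, caplens)
# scores: (2, max(decode_lengths), 9490), alphas: (2, max(decode_lengths), 14*14)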
--------------------------------------------------------------------------------
/net2net/modules/discriminator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/discriminator/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 | from net2net.modules.autoencoder.basic import ActNorm
5 |
6 |
7 | def weights_init(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1:
10 | nn.init.normal_(m.weight.data, 0.0, 0.02)
11 | elif classname.find('BatchNorm') != -1:
12 | nn.init.normal_(m.weight.data, 1.0, 0.02)
13 | nn.init.constant_(m.bias.data, 0)
14 |
15 |
16 | class NLayerDiscriminator(nn.Module):
17 | """Defines a PatchGAN discriminator as in Pix2Pix
18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
19 | """
20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
21 | """Construct a PatchGAN discriminator
22 | Parameters:
23 | input_nc (int) -- the number of channels in input images
24 |             ndf (int)       -- the number of filters in the first conv layer
25 |             n_layers (int)  -- the number of conv layers in the discriminator
26 |             use_actnorm (bool) -- use ActNorm instead of BatchNorm2d as the normalization layer
27 | """
28 | super(NLayerDiscriminator, self).__init__()
29 | if not use_actnorm:
30 | norm_layer = nn.BatchNorm2d
31 | else:
32 | norm_layer = ActNorm
33 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
34 | use_bias = norm_layer.func != nn.BatchNorm2d
35 | else:
36 | use_bias = norm_layer != nn.BatchNorm2d
37 |
38 | kw = 4
39 | padw = 1
40 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
41 | nf_mult = 1
42 | nf_mult_prev = 1
43 | for n in range(1, n_layers): # gradually increase the number of filters
44 | nf_mult_prev = nf_mult
45 | nf_mult = min(2 ** n, 8)
46 | sequence += [
47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
48 | norm_layer(ndf * nf_mult),
49 | nn.LeakyReLU(0.2, True)
50 | ]
51 |
52 | nf_mult_prev = nf_mult
53 | nf_mult = min(2 ** n_layers, 8)
54 | sequence += [
55 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
56 | norm_layer(ndf * nf_mult),
57 | nn.LeakyReLU(0.2, True)
58 | ]
59 |
60 | sequence += [
61 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
62 | self.main = nn.Sequential(*sequence)
63 |
64 | def forward(self, input):
65 | """Standard forward."""
66 | return self.main(input)
67 |
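A quick sketch showing that the PatchGAN discriminator returns a grid of per-patch logits rather than a single scalar (shapes are illustrative):

import torch
from net2net.modules.discriminator.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
logits = disc(torch.randn(2, 3, 256, 256))   # (2, 1, 30, 30) patch predictions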
--------------------------------------------------------------------------------
/net2net/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=[1, 2, 3])
60 |
61 | def mode(self):
62 | return self.mean
63 |
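A small sketch of DiagonalGaussianDistribution: the channel dimension of the input is split into mean and log-variance halves, sampling uses the reparameterization trick, and kl()/nll() reduce over all non-batch dimensions. The shapes are illustrative:

import torch
from net2net.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(4, 2 * 128, 1, 1)      # e.g. an encoder output with double_z=True
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                       # (4, 128, 1, 1)
kl = posterior.kl()                          # (4,) KL divergence to N(0, I)
nll = posterior.nll(z)                       # (4,) negative log-likelihood of the sample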
--------------------------------------------------------------------------------
/net2net/modules/facenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/facenet/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/facenet/inception_resnet_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import requests
5 | from requests.adapters import HTTPAdapter
6 | import os
7 |
8 |
9 | class BasicConv2d(nn.Module):
10 |
11 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
12 | super().__init__()
13 | self.conv = nn.Conv2d(
14 | in_planes, out_planes,
15 | kernel_size=kernel_size, stride=stride,
16 | padding=padding, bias=False
17 | ) # verify bias false
18 | self.bn = nn.BatchNorm2d(
19 | out_planes,
20 | eps=0.001, # value found in tensorflow
21 | momentum=0.1, # default pytorch value
22 | affine=True
23 | )
24 | self.relu = nn.ReLU(inplace=False)
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = self.bn(x)
29 | x = self.relu(x)
30 | return x
31 |
32 |
33 | class Block35(nn.Module):
34 |
35 | def __init__(self, scale=1.0):
36 | super().__init__()
37 |
38 | self.scale = scale
39 |
40 | self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1)
41 |
42 | self.branch1 = nn.Sequential(
43 | BasicConv2d(256, 32, kernel_size=1, stride=1),
44 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
45 | )
46 |
47 | self.branch2 = nn.Sequential(
48 | BasicConv2d(256, 32, kernel_size=1, stride=1),
49 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
50 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
51 | )
52 |
53 | self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1)
54 | self.relu = nn.ReLU(inplace=False)
55 |
56 | def forward(self, x):
57 | x0 = self.branch0(x)
58 | x1 = self.branch1(x)
59 | x2 = self.branch2(x)
60 | out = torch.cat((x0, x1, x2), 1)
61 | out = self.conv2d(out)
62 | out = out * self.scale + x
63 | out = self.relu(out)
64 | return out
65 |
66 |
67 | class Block17(nn.Module):
68 |
69 | def __init__(self, scale=1.0):
70 | super().__init__()
71 |
72 | self.scale = scale
73 |
74 | self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1)
75 |
76 | self.branch1 = nn.Sequential(
77 | BasicConv2d(896, 128, kernel_size=1, stride=1),
78 | BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)),
79 | BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0))
80 | )
81 |
82 | self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1)
83 | self.relu = nn.ReLU(inplace=False)
84 |
85 | def forward(self, x):
86 | x0 = self.branch0(x)
87 | x1 = self.branch1(x)
88 | out = torch.cat((x0, x1), 1)
89 | out = self.conv2d(out)
90 | out = out * self.scale + x
91 | out = self.relu(out)
92 | return out
93 |
94 |
95 | class Block8(nn.Module):
96 |
97 | def __init__(self, scale=1.0, noReLU=False):
98 | super().__init__()
99 |
100 | self.scale = scale
101 | self.noReLU = noReLU
102 |
103 | self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1)
104 |
105 | self.branch1 = nn.Sequential(
106 | BasicConv2d(1792, 192, kernel_size=1, stride=1),
107 | BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)),
108 | BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0))
109 | )
110 |
111 | self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1)
112 | if not self.noReLU:
113 | self.relu = nn.ReLU(inplace=False)
114 |
115 | def forward(self, x):
116 | x0 = self.branch0(x)
117 | x1 = self.branch1(x)
118 | out = torch.cat((x0, x1), 1)
119 | out = self.conv2d(out)
120 | out = out * self.scale + x
121 | if not self.noReLU:
122 | out = self.relu(out)
123 | return out
124 |
125 |
126 | class Mixed_6a(nn.Module):
127 |
128 | def __init__(self):
129 | super().__init__()
130 |
131 | self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2)
132 |
133 | self.branch1 = nn.Sequential(
134 | BasicConv2d(256, 192, kernel_size=1, stride=1),
135 | BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1),
136 | BasicConv2d(192, 256, kernel_size=3, stride=2)
137 | )
138 |
139 | self.branch2 = nn.MaxPool2d(3, stride=2)
140 |
141 | def forward(self, x):
142 | x0 = self.branch0(x)
143 | x1 = self.branch1(x)
144 | x2 = self.branch2(x)
145 | out = torch.cat((x0, x1, x2), 1)
146 | return out
147 |
148 |
149 | class Mixed_7a(nn.Module):
150 |
151 | def __init__(self):
152 | super().__init__()
153 |
154 | self.branch0 = nn.Sequential(
155 | BasicConv2d(896, 256, kernel_size=1, stride=1),
156 | BasicConv2d(256, 384, kernel_size=3, stride=2)
157 | )
158 |
159 | self.branch1 = nn.Sequential(
160 | BasicConv2d(896, 256, kernel_size=1, stride=1),
161 | BasicConv2d(256, 256, kernel_size=3, stride=2)
162 | )
163 |
164 | self.branch2 = nn.Sequential(
165 | BasicConv2d(896, 256, kernel_size=1, stride=1),
166 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
167 | BasicConv2d(256, 256, kernel_size=3, stride=2)
168 | )
169 |
170 | self.branch3 = nn.MaxPool2d(3, stride=2)
171 |
172 | def forward(self, x):
173 | x0 = self.branch0(x)
174 | x1 = self.branch1(x)
175 | x2 = self.branch2(x)
176 | x3 = self.branch3(x)
177 | out = torch.cat((x0, x1, x2, x3), 1)
178 | return out
179 |
180 |
181 | class InceptionResnetV1(nn.Module):
182 | """Inception Resnet V1 model with optional loading of pretrained weights.
183 |
184 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface
185 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
186 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than
187 | redownloading.
188 |
189 | Keyword Arguments:
190 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
191 | (default: {None})
192 | classify {bool} -- Whether the model should output classification probabilities or feature
193 | embeddings. (default: {False})
194 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not
195 | equal to that used for the pretrained model, the final linear layer will be randomly
196 | initialized. (default: {None})
197 | dropout_prob {float} -- Dropout probability. (default: {0.6})
198 | """
199 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
200 | super().__init__()
201 |
202 | # Set simple attributes
203 | self.pretrained = pretrained
204 | self.classify = classify
205 | self.num_classes = num_classes
206 |
207 | if pretrained == 'vggface2':
208 | tmp_classes = 8631
209 | elif pretrained == 'casia-webface':
210 | tmp_classes = 10575
211 | elif pretrained is None and self.num_classes is None:
212 | raise Exception('At least one of "pretrained" or "num_classes" must be specified')
213 | else:
214 | tmp_classes = self.num_classes
215 |
216 |
217 | # Define layers
218 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
219 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
220 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
221 | self.maxpool_3a = nn.MaxPool2d(3, stride=2)
222 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
223 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
224 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2)
225 | self.repeat_1 = nn.Sequential(
226 | Block35(scale=0.17),
227 | Block35(scale=0.17),
228 | Block35(scale=0.17),
229 | Block35(scale=0.17),
230 | Block35(scale=0.17),
231 | )
232 | self.mixed_6a = Mixed_6a()
233 | self.repeat_2 = nn.Sequential(
234 | Block17(scale=0.10),
235 | Block17(scale=0.10),
236 | Block17(scale=0.10),
237 | Block17(scale=0.10),
238 | Block17(scale=0.10),
239 | Block17(scale=0.10),
240 | Block17(scale=0.10),
241 | Block17(scale=0.10),
242 | Block17(scale=0.10),
243 | Block17(scale=0.10),
244 | )
245 | self.mixed_7a = Mixed_7a()
246 | self.repeat_3 = nn.Sequential(
247 | Block8(scale=0.20),
248 | Block8(scale=0.20),
249 | Block8(scale=0.20),
250 | Block8(scale=0.20),
251 | Block8(scale=0.20),
252 | )
253 | self.block8 = Block8(noReLU=True)
254 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1)
255 | self.dropout = nn.Dropout(dropout_prob)
256 | self.last_linear = nn.Linear(1792, 512, bias=False)
257 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
258 | self.logits = nn.Linear(512, tmp_classes)
259 |
260 | if pretrained is not None:
261 | load_weights(self, pretrained)
262 |
263 | if self.num_classes is not None:
264 | self.logits = nn.Linear(512, self.num_classes)
265 |
266 | self.device = torch.device('cpu')
267 | if device is not None:
268 | self.device = device
269 | self.to(device)
270 |
271 | def forward(self, x):
272 | """Calculate embeddings or probabilities given a batch of input image tensors.
273 |
274 | Arguments:
275 | x {torch.tensor} -- Batch of image tensors representing faces.
276 |
277 | Returns:
278 | torch.tensor -- Batch of embeddings or softmax probabilities.
279 | """
280 | x = self.conv2d_1a(x)
281 | x = self.conv2d_2a(x)
282 | x = self.conv2d_2b(x)
283 | x = self.maxpool_3a(x)
284 | x = self.conv2d_3b(x)
285 | x = self.conv2d_4a(x)
286 | x = self.conv2d_4b(x)
287 | x = self.repeat_1(x)
288 | x = self.mixed_6a(x)
289 | x = self.repeat_2(x)
290 | x = self.mixed_7a(x)
291 | x = self.repeat_3(x)
292 | x = self.block8(x)
293 | x = self.avgpool_1a(x)
294 | x = self.dropout(x)
295 | x = self.last_linear(x.view(x.shape[0], -1))
296 | x = self.last_bn(x)
297 | x = F.normalize(x, p=2, dim=1)
298 | if self.classify:
299 | x = self.logits(x)
300 | return x
301 |
302 |
303 | def load_weights(mdl, name):
304 | """Download pretrained state_dict and load into model.
305 |
306 | Arguments:
307 | mdl {torch.nn.Module} -- Pytorch model.
308 | name {str} -- Name of dataset that was used to generate pretrained state_dict.
309 |
310 | Raises:
311 | ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
312 | """
313 | if name == 'vggface2':
314 | features_path = 'https://drive.google.com/uc?export=download&id=1cWLH_hPns8kSfMz9kKl9PsG5aNV2VSMn'
315 | logits_path = 'https://drive.google.com/uc?export=download&id=1mAie3nzZeno9UIzFXvmVZrDG3kwML46X'
316 | elif name == 'casia-webface':
317 | features_path = 'https://drive.google.com/uc?export=download&id=1LSHHee_IQj5W3vjBcRyVaALv4py1XaGy'
318 | logits_path = 'https://drive.google.com/uc?export=download&id=1QrhPgn1bGlDxAil2uc07ctunCQoDnCzT'
319 | else:
320 | raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"')
321 |
322 | model_dir = os.path.join(get_torch_home(), 'checkpoints')
323 | os.makedirs(model_dir, exist_ok=True)
324 |
325 | state_dict = {}
326 | for i, path in enumerate([features_path, logits_path]):
327 | cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:]))
328 | if not os.path.exists(cached_file):
329 | print('Downloading parameters ({}/2)'.format(i+1))
330 | s = requests.Session()
331 | s.mount('https://', HTTPAdapter(max_retries=10))
332 | r = s.get(path, allow_redirects=True)
333 | with open(cached_file, 'wb') as f:
334 | f.write(r.content)
335 | state_dict.update(torch.load(cached_file))
336 |
337 | mdl.load_state_dict(state_dict)
338 |
339 |
340 | def get_torch_home():
341 | torch_home = os.path.expanduser(
342 | os.getenv(
343 | 'TORCH_HOME',
344 | os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')
345 | )
346 | )
347 | return torch_home
348 |
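A short sketch of the two output modes described in the docstring above; the pretrained VGGFace2 weights are downloaded on first use, and the 128x128 crops are random stand-ins for aligned faces:

import torch
from net2net.modules.facenet.inception_resnet_v1 import InceptionResnetV1

emb_model = InceptionResnetV1(pretrained='vggface2').eval()                 # 512-d embeddings
cls_model = InceptionResnetV1(pretrained='vggface2', classify=True).eval()  # 8631 VGGFace2 logits
faces = torch.randn(2, 3, 128, 128)
with torch.no_grad():
    print(emb_model(faces).shape)   # torch.Size([2, 512]), L2-normalized rows
    print(cls_model(faces).shape)   # torch.Size([2, 8631])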
--------------------------------------------------------------------------------
/net2net/modules/facenet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | from net2net.modules.facenet.inception_resnet_v1 import InceptionResnetV1
6 |
7 |
8 | """FaceNet adopted from https://github.com/timesler/facenet-pytorch"""
9 |
10 |
11 | class FaceNet(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 | # InceptionResnetV1 has a bottleneck of size 512
15 | self.net = InceptionResnetV1(pretrained='vggface2').eval()
16 |
17 | def _pre_process(self, x):
18 |         # TODO: necessary for InceptionResnetV1?
19 |         # seems like mtcnn (multi-task cnn) preprocessing is necessary, but not 100% sure
20 | return x
21 |
22 | def forward(self, x, return_logits=False):
23 | # output are logits of size 8631 or embeddings of size 512
24 | x = self._pre_process(x)
25 | emb = self.net(x)
26 | if return_logits:
27 | return self.net.logits(emb)
28 | return emb
29 |
30 | def encode(self, x):
31 | return self(x)
32 |
33 | def return_features(self, x):
34 | """ returned features have the following dimensions:
35 |
36 | torch.Size([11, 3, 128, 128]), x 49152
37 | torch.Size([11, 192, 28, 28]), x 150528
38 | torch.Size([11, 896, 6, 6]), x 32256
39 | torch.Size([11, 1792, 1, 1]), x 1792
40 | torch.Size([11, 512]) x 512
41 | logits (8xxx) x 8xxx
42 | """
43 |
44 | x = self._pre_process(x)
45 | features = [x] # this
46 | x = self.net.conv2d_1a(x)
47 | x = self.net.conv2d_2a(x)
48 | x = self.net.conv2d_2b(x)
49 | x = self.net.maxpool_3a(x)
50 | x = self.net.conv2d_3b(x)
51 | x = self.net.conv2d_4a(x)
52 | features.append(x) # this
53 | x = self.net.conv2d_4b(x)
54 | x = self.net.repeat_1(x)
55 | x = self.net.mixed_6a(x)
56 | features.append(x) # this
57 | x = self.net.repeat_2(x)
58 | x = self.net.mixed_7a(x)
59 | x = self.net.repeat_3(x)
60 | x = self.net.block8(x)
61 | x = self.net.avgpool_1a(x)
62 | features.append(x) # this
63 | x = self.net.dropout(x)
64 | x = self.net.last_linear(x.view(x.shape[0], -1))
65 | x = self.net.last_bn(x)
66 | emb = F.normalize(x, p=2, dim=1) # the final embeddings
67 | features.append(emb[..., None, None]) # need extra dimensions for flow later
68 | features.append(self.net.logits(emb).unsqueeze(-1).unsqueeze(-1))
69 | return features # has 6 elements as of now
70 |
71 |
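A minimal sketch of the FaceNet wrapper used as a fixed expert; the inputs are random stand-ins for 128x128 face crops:

import torch
from net2net.modules.facenet.model import FaceNet

facenet = FaceNet().eval()
faces = torch.randn(2, 3, 128, 128)
with torch.no_grad():
    emb = facenet.encode(faces)              # (2, 512) L2-normalized embeddings
    feats = facenet.return_features(faces)   # list of 6 feature tensors as documented above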
--------------------------------------------------------------------------------
/net2net/modules/flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/flow/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/flow/base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class NormalizingFlow(nn.Module):
5 | def __init__(self, *args, **kwargs):
6 | super().__init__()
7 |
8 | def forward(self, *args, **kwargs):
9 | # return transformed, logdet
10 | raise NotImplementedError
11 |
12 | def reverse(self, *args, **kwargs):
13 | # return transformed_reverse
14 | raise NotImplementedError
15 |
16 | def sample(self, *args, **kwargs):
17 | # return sample
18 | raise NotImplementedError
19 |
--------------------------------------------------------------------------------
/net2net/modules/flow/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 |
5 | from net2net.modules.autoencoder.basic import ActNorm, BasicFullyConnectedNet
6 |
7 |
8 | class Flow(nn.Module):
9 | def __init__(self, module_list, in_channels, hidden_dim, hidden_depth):
10 | super(Flow, self).__init__()
11 | self.in_channels = in_channels
12 | self.flow = nn.ModuleList(
13 | [module(in_channels, hidden_dim=hidden_dim, depth=hidden_depth) for module in module_list])
14 |
15 | def forward(self, x, condition=None, reverse=False):
16 | if not reverse:
17 | logdet = 0
18 | for i in range(len(self.flow)):
19 | x, logdet_ = self.flow[i](x)
20 | logdet = logdet + logdet_
21 | return x, logdet
22 | else:
23 | for i in reversed(range(len(self.flow))):
24 | x = self.flow[i](x, reverse=True)
25 | return x
26 |
27 |
28 | class UnconditionalFlatDoubleCouplingFlowBlock(nn.Module):
29 | def __init__(self, in_channels, hidden_dim, hidden_depth):
30 | super().__init__()
31 | self.norm_layer = ActNorm(in_channels, logdet=True)
32 | self.coupling = DoubleVectorCouplingBlock(in_channels,
33 | hidden_dim,
34 | hidden_depth)
35 | self.shuffle = Shuffle(in_channels)
36 |
37 | def forward(self, x, reverse=False):
38 | if not reverse:
39 | h = x
40 | logdet = 0.0
41 | h, ld = self.norm_layer(h)
42 | logdet += ld
43 | h, ld = self.coupling(h)
44 | logdet += ld
45 | h, ld = self.shuffle(h)
46 | logdet += ld
47 | return h, logdet
48 | else:
49 | h = x
50 | h = self.shuffle(h, reverse=True)
51 | h = self.coupling(h, reverse=True)
52 | h = self.norm_layer(h, reverse=True)
53 | return h
54 |
55 | def reverse(self, out):
56 | return self.forward(out, reverse=True)
57 |
58 |
59 | class PureAffineDoubleCouplingFlowBlock(nn.Module):
60 | def __init__(self, in_channels, hidden_dim, hidden_depth):
61 | super().__init__()
62 | self.coupling = DoubleVectorCouplingBlock(in_channels,
63 | hidden_dim,
64 | hidden_depth)
65 |
66 | def forward(self, x, reverse=False):
67 | if not reverse:
68 | h = x
69 | logdet = 0.0
70 | h, ld = self.coupling(h)
71 | logdet += ld
72 | return h, logdet
73 | else:
74 | h = x
75 | h = self.coupling(h, reverse=True)
76 | return h
77 |
78 | def reverse(self, out):
79 | return self.forward(out, reverse=True)
80 |
81 |
82 | class ConditionalFlow(nn.Module):
83 | """Flat version. Feeds an embedding into the flow in every block"""
84 | def __init__(self, in_channels, embedding_dim, hidden_dim, hidden_depth,
85 | n_flows, conditioning_option="none", activation='lrelu'):
86 | super().__init__()
87 | self.in_channels = in_channels
88 | self.cond_channels = embedding_dim
89 | self.mid_channels = hidden_dim
90 | self.num_blocks = hidden_depth
91 | self.n_flows = n_flows
92 | self.conditioning_option = conditioning_option
93 |
94 | self.sub_layers = nn.ModuleList()
95 | if self.conditioning_option.lower() != "none":
96 | self.conditioning_layers = nn.ModuleList()
97 | for flow in range(self.n_flows):
98 | self.sub_layers.append(ConditionalFlatDoubleCouplingFlowBlock(
99 | self.in_channels, self.cond_channels, self.mid_channels,
100 | self.num_blocks, activation=activation)
101 | )
102 | if self.conditioning_option.lower() != "none":
103 | self.conditioning_layers.append(nn.Conv2d(self.cond_channels, self.cond_channels, 1))
104 |
105 | def forward(self, x, embedding, reverse=False):
106 | hconds = list()
107 | hcond = embedding[:, :, None, None]
108 | for i in range(self.n_flows):
109 | if self.conditioning_option.lower() == "parallel":
110 | hcond = self.conditioning_layers[i](embedding)
111 | elif self.conditioning_option.lower() == "sequential":
112 | hcond = self.conditioning_layers[i](hcond)
113 | hconds.append(hcond)
114 | if not reverse:
115 | logdet = 0.0
116 | for i in range(self.n_flows):
117 | x, logdet_ = self.sub_layers[i](x, hconds[i])
118 | logdet = logdet + logdet_
119 | return x, logdet
120 | else:
121 | for i in reversed(range(self.n_flows)):
122 | x = self.sub_layers[i](x, hconds[i], reverse=True)
123 | return x
124 |
125 | def reverse(self, out, xcond):
126 | return self(out, xcond, reverse=True)
127 |
128 |
129 | class DoubleVectorCouplingBlock(nn.Module):
130 | """Support uneven inputs"""
131 | def __init__(self, in_channels, hidden_dim, hidden_depth=2):
132 | super().__init__()
133 | dim1 = (in_channels // 2) + (in_channels % 2)
134 | dim2 = in_channels // 2
135 | self.s = nn.ModuleList([
136 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth,
137 | hidden_dim=hidden_dim, use_tanh=True),
138 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth,
139 | hidden_dim=hidden_dim, use_tanh=True),
140 | ])
141 | self.t = nn.ModuleList([
142 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth,
143 | hidden_dim=hidden_dim, use_tanh=False),
144 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth,
145 | hidden_dim=hidden_dim, use_tanh=False),
146 | ])
147 |
148 | def forward(self, x, reverse=False):
149 | assert len(x.shape) == 4
150 | x = x.squeeze(-1).squeeze(-1)
151 | if not reverse:
152 | logdet = 0
153 | for i in range(len(self.s)):
154 | idx_apply, idx_keep = 0, 1
155 | if i % 2 != 0:
156 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
157 | x = torch.chunk(x, 2, dim=1)
158 | scale = self.s[i](x[idx_apply])
159 | x_ = x[idx_keep] * (scale.exp()) + self.t[i](x[idx_apply])
160 | x = torch.cat((x[idx_apply], x_), dim=1)
161 | logdet_ = torch.sum(scale.view(x.size(0), -1), dim=1)
162 | logdet = logdet + logdet_
163 | return x[:,:,None,None], logdet
164 | else:
165 | idx_apply, idx_keep = 0, 1
166 | for i in reversed(range(len(self.s))):
167 | if i % 2 == 0:
168 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
169 | x = torch.chunk(x, 2, dim=1)
170 | x_ = (x[idx_keep] - self.t[i](x[idx_apply])) * (self.s[i](x[idx_apply]).neg().exp())
171 | x = torch.cat((x[idx_apply], x_), dim=1)
172 | return x[:,:,None,None]
173 |
174 |
175 | class ConditionalDoubleVectorCouplingBlock(nn.Module):
176 | def __init__(self, in_channels, cond_channels, hidden_dim, depth=2):
177 | super(ConditionalDoubleVectorCouplingBlock, self).__init__()
178 | self.s = nn.ModuleList([
179 | BasicFullyConnectedNet(dim=in_channels // 2 + cond_channels, depth=depth,
180 | hidden_dim=hidden_dim, use_tanh=True,
181 | out_dim=in_channels // 2) for _ in range(2)])
182 | self.t = nn.ModuleList([
183 | BasicFullyConnectedNet(dim=in_channels // 2 + cond_channels, depth=depth,
184 | hidden_dim=hidden_dim, use_tanh=False,
185 | out_dim=in_channels // 2) for _ in range(2)])
186 |
187 | def forward(self, x, xc, reverse=False):
188 | assert len(x.shape) == 4
189 | assert len(xc.shape) == 4
190 | x = x.squeeze(-1).squeeze(-1)
191 | xc = xc.squeeze(-1).squeeze(-1)
192 | if not reverse:
193 | logdet = 0
194 | for i in range(len(self.s)):
195 | idx_apply, idx_keep = 0, 1
196 | if i % 2 != 0:
197 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
198 | x = torch.chunk(x, 2, dim=1)
199 | conditioner_input = torch.cat((x[idx_apply], xc), dim=1)
200 | scale = self.s[i](conditioner_input)
201 | x_ = x[idx_keep] * scale.exp() + self.t[i](conditioner_input)
202 | x = torch.cat((x[idx_apply], x_), dim=1)
203 | logdet_ = torch.sum(scale, dim=1)
204 | logdet = logdet + logdet_
205 | return x[:, :, None, None], logdet
206 | else:
207 | idx_apply, idx_keep = 0, 1
208 | for i in reversed(range(len(self.s))):
209 | if i % 2 == 0:
210 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1)
211 | x = torch.chunk(x, 2, dim=1)
212 | conditioner_input = torch.cat((x[idx_apply], xc), dim=1)
213 | x_ = (x[idx_keep] - self.t[i](conditioner_input)) * self.s[i](conditioner_input).neg().exp()
214 | x = torch.cat((x[idx_apply], x_), dim=1)
215 | return x[:, :, None, None]
216 |
217 |
218 | class ConditionalFlatDoubleCouplingFlowBlock(nn.Module):
219 | def __init__(self, in_channels, cond_channels, hidden_dim, hidden_depth, activation="lrelu"):
220 | super().__init__()
221 | __possible_activations = {"lrelu": InvLeakyRelu,
222 | "none": IgnoreLeakyRelu
223 | }
224 | self.norm_layer = ActNorm(in_channels, logdet=True)
225 | self.coupling = ConditionalDoubleVectorCouplingBlock(in_channels,
226 | cond_channels,
227 | hidden_dim,
228 | hidden_depth)
229 | self.activation = __possible_activations[activation]()
230 | self.shuffle = Shuffle(in_channels)
231 |
232 | def forward(self, x, xcond, reverse=False):
233 | if not reverse:
234 | h = x
235 | logdet = 0.0
236 | h, ld = self.norm_layer(h)
237 | logdet += ld
238 | h, ld = self.activation(h)
239 | logdet += ld
240 | h, ld = self.coupling(h, xcond)
241 | logdet += ld
242 | h, ld = self.shuffle(h)
243 | logdet += ld
244 | return h, logdet
245 | else:
246 | h = x
247 | h = self.shuffle(h, reverse=True)
248 | h = self.coupling(h, xcond, reverse=True)
249 | h = self.activation(h, reverse=True)
250 | h = self.norm_layer(h, reverse=True)
251 | return h
252 |
253 | def reverse(self, out, xcond):
254 | return self.forward(out, xcond, reverse=True)
255 |
256 |
257 | class Shuffle(nn.Module):
258 | def __init__(self, in_channels, **kwargs):
259 | super(Shuffle, self).__init__()
260 | self.in_channels = in_channels
261 | idx = torch.randperm(in_channels)
262 | self.register_buffer('forward_shuffle_idx', nn.Parameter(idx, requires_grad=False))
263 | self.register_buffer('backward_shuffle_idx', nn.Parameter(torch.argsort(idx), requires_grad=False))
264 |
265 | def forward(self, x, reverse=False, conditioning=None):
266 | if not reverse:
267 | return x[:, self.forward_shuffle_idx, ...], 0
268 | else:
269 | return x[:, self.backward_shuffle_idx, ...]
270 |
271 |
272 | class IgnoreLeakyRelu(nn.Module):
273 | """performs identity op."""
274 |
275 | def __init__(self, alpha=0.9):
276 | super().__init__()
277 |
278 | def forward(self, input, reverse=False):
279 | if reverse:
280 | return self.reverse(input)
281 | h = input
282 | return h, 0.0
283 |
284 | def reverse(self, input):
285 | h = input
286 | return h
287 |
288 |
289 | class InvLeakyRelu(nn.Module):
290 | def __init__(self, alpha=0.9):
291 | super().__init__()
292 | self.alpha = alpha
293 |
294 | def forward(self, input, reverse=False):
295 | if reverse:
296 | return self.reverse(input)
297 | scaling = (input >= 0).to(input) + (input < 0).to(input) * self.alpha
298 | h = input * scaling
299 | return h, 0.0
300 |
301 | def reverse(self, input):
302 | scaling = (input >= 0).to(input) + (input < 0).to(input) * self.alpha
303 | h = input / scaling
304 | return h
305 |
306 |
307 | class InvParametricRelu(InvLeakyRelu):
308 | def __init__(self, alpha=0.9):
309 | super().__init__()
310 | self.alpha = nn.Parameter(torch.tensor(alpha), requires_grad=True)
311 |
312 |
313 | class FeatureLayer(nn.Module):
314 | def __init__(self, scale, in_channels=None, norm='AN', width_multiplier=1):
315 | super().__init__()
316 |
317 | norm_options = {
318 | "in": nn.InstanceNorm2d,
319 | "bn": nn.BatchNorm2d,
320 | "an": ActNorm}
321 |
322 | self.scale = scale
323 | self.norm = norm_options[norm.lower()]
324 | self.wm = width_multiplier
325 | if in_channels is None:
326 | self.in_channels = int(self.wm * 64 * min(2 ** (self.scale - 1), 16))
327 | else:
328 | self.in_channels = in_channels
329 | self.out_channels = int(self.wm * 64 * min(2 ** self.scale, 16))
330 | self.build()
331 |
332 | def forward(self, input):
333 | x = input
334 | for layer in self.sub_layers:
335 | x = layer(x)
336 | return x
337 |
338 | def build(self):
339 | Norm = functools.partial(self.norm, affine=True)
340 | self.sub_layers = nn.ModuleList([
341 | nn.Conv2d(
342 | in_channels=self.in_channels,
343 | out_channels=self.out_channels,
344 | kernel_size=4,
345 | stride=2,
346 | padding=1,
347 | bias=False),
348 | Norm(num_features=self.out_channels),
349 | nn.LeakyReLU(0.2)])
350 |
351 |
352 | class DenseEncoderLayer(nn.Module):
353 | def __init__(self, scale, spatial_size, out_size, in_channels=None,
354 | width_multiplier=1):
355 | super().__init__()
356 | self.scale = scale
357 | self.wm = width_multiplier
358 | self.in_channels = int(self.wm * 64 * min(2 ** (self.scale - 1), 16))
359 | if in_channels is not None:
360 | self.in_channels = in_channels
361 | self.out_channels = out_size
362 | self.kernel_size = spatial_size
363 | self.build()
364 |
365 | def forward(self, input):
366 | x = input
367 | for layer in self.sub_layers:
368 | x = layer(x)
369 | return x
370 |
371 | def build(self):
372 | self.sub_layers = nn.ModuleList([
373 | nn.Conv2d(
374 | in_channels=self.in_channels,
375 | out_channels=self.out_channels,
376 | kernel_size=self.kernel_size,
377 | stride=1,
378 | padding=0,
379 | bias=True)])
380 |
--------------------------------------------------------------------------------
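The coupling blocks above are exactly invertible: the forward pass returns the transformed code together with a per-sample log-determinant, and the reverse pass undoes the transformation. A minimal round-trip sketch, assuming the net2net package from this repository is installed (e.g. via pip install -e .):

    import torch
    from net2net.modules.flow.blocks import DoubleVectorCouplingBlock

    # a single double coupling block on flat 8-dimensional codes
    block = DoubleVectorCouplingBlock(in_channels=8, hidden_dim=32, hidden_depth=2)
    x = torch.randn(4, 8, 1, 1)                  # inputs carry trailing 1x1 spatial dims
    y, logdet = block(x)                         # forward: transformed codes + per-sample log-det
    x_rec = block(y, reverse=True)               # reverse pass recovers the input
    print(torch.allclose(x, x_rec, atol=1e-5))   # expected: True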
/net2net/modules/flow/flatflow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from net2net.modules.autoencoder.basic import ActNorm
6 | from net2net.modules.flow.blocks import UnconditionalFlatDoubleCouplingFlowBlock, PureAffineDoubleCouplingFlowBlock, \
7 | ConditionalFlatDoubleCouplingFlowBlock
8 | from net2net.modules.flow.base import NormalizingFlow
9 | from net2net.modules.autoencoder.basic import FeatureLayer, DenseEncoderLayer
10 | from net2net.modules.flow.blocks import BasicFullyConnectedNet
11 |
12 |
13 | class UnconditionalFlatCouplingFlow(NormalizingFlow):
14 | """Flat, multiple blocks of ActNorm, DoubleAffineCoupling, Shuffle"""
15 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth):
16 | super().__init__()
17 | self.in_channels = in_channels
18 | self.mid_channels = hidden_dim
19 | self.num_blocks = hidden_depth
20 | self.n_flows = n_flows
21 | self.sub_layers = nn.ModuleList()
22 |
23 | for flow in range(self.n_flows):
24 | self.sub_layers.append(UnconditionalFlatDoubleCouplingFlowBlock(
25 | self.in_channels, self.mid_channels,
26 | self.num_blocks)
27 | )
28 |
29 | def forward(self, x, reverse=False):
30 | if len(x.shape) == 2:
31 | x = x[:,:,None,None]
32 | self.last_outs = []
33 | self.last_logdets = []
34 | if not reverse:
35 | logdet = 0.0
36 | for i in range(self.n_flows):
37 | x, logdet_ = self.sub_layers[i](x)
38 | logdet = logdet + logdet_
39 | self.last_outs.append(x)
40 | self.last_logdets.append(logdet)
41 | return x, logdet
42 | else:
43 | for i in reversed(range(self.n_flows)):
44 | x = self.sub_layers[i](x, reverse=True)
45 | return x
46 |
47 | def reverse(self, out):
48 | if len(out.shape) == 2:
49 | out = out[:,:,None,None]
50 | return self(out, reverse=True)
51 |
52 | def sample(self, num_samples, device="cpu"):
53 | zz = torch.randn(num_samples, self.in_channels, 1, 1).to(device)
54 | return self.reverse(zz)
55 |
56 | def get_last_layer(self):
57 | return getattr(self.sub_layers[-1].coupling.t[-1].main[-1], 'weight')
58 |
59 |
60 | class PureAffineFlatCouplingFlow(UnconditionalFlatCouplingFlow):
61 | """Flat, multiple blocks of DoubleAffineCoupling"""
62 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth):
63 | super().__init__(in_channels, n_flows, hidden_dim, hidden_depth)
64 | del self.sub_layers
65 | self.sub_layers = nn.ModuleList()
66 | for flow in range(self.n_flows):
67 | self.sub_layers.append(PureAffineDoubleCouplingFlowBlock(
68 | self.in_channels, self.mid_channels,
69 | self.num_blocks)
70 | )
71 |
72 |
73 | class DenseEmbedder(nn.Module):
74 | """Supposed to map small-scale features (e.g. labels) to some given latent dim"""
75 | def __init__(self, in_dim, up_dim, depth=4, given_dims=None):
76 | super().__init__()
77 | self.net = nn.ModuleList()
78 | if given_dims is not None:
79 | assert given_dims[0] == in_dim
80 | assert given_dims[-1] == up_dim
81 | dims = given_dims
82 | else:
83 | dims = np.linspace(in_dim, up_dim, depth).astype(int)
84 | for l in range(len(dims)-2):
85 | self.net.append(nn.Conv2d(dims[l], dims[l + 1], 1))
86 | self.net.append(ActNorm(dims[l + 1]))
87 | self.net.append(nn.LeakyReLU(0.2))
88 |
89 | self.net.append(nn.Conv2d(dims[-2], dims[-1], 1))
90 |
91 | def forward(self, x):
92 | for layer in self.net:
93 | x = layer(x)
94 | return x.squeeze(-1).squeeze(-1)
95 |
96 |
97 | class Embedder(nn.Module):
98 | """Embeds a 4-dim tensor onto dense latent code, much like the classic encoder."""
99 | def __init__(self, in_spatial_size, in_channels, emb_dim, n_down=4):
100 | super().__init__()
101 | self.feature_layers = nn.ModuleList()
102 | norm = 'an'  # hard-coded to ActNorm
103 | bottleneck_size = in_spatial_size // 2**n_down
104 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm))
105 | for scale in range(1, n_down):
106 | self.feature_layers.append(FeatureLayer(scale, norm=norm))
107 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, emb_dim)
108 | if n_down == 1:
109 | # note: a single down-sampling step gives the embedder very little capacity
110 | print("Warning: Embedder for ConditionalTransformer has only one down-sampling step. "
111 | "You might want to increase its capacity.")
112 |
113 | def forward(self, input):
114 | h = input
115 | for layer in self.feature_layers:
116 | h = layer(h)
117 | h = self.dense_encode(h)
118 | return h.squeeze(-1).squeeze(-1)
119 |
120 |
121 | class ConditionalFlatCouplingFlow(nn.Module):
122 | """Flat version. Feeds an embedding into the flow in every block"""
123 | def __init__(self, in_channels, conditioning_dim, embedding_dim, hidden_dim, hidden_depth,
124 | n_flows, conditioning_option="none", activation='lrelu',
125 | conditioning_hidden_dim=256, conditioning_depth=2, conditioner_use_bn=False,
126 | conditioner_use_an=False):
127 | super().__init__()
128 | self.in_channels = in_channels
129 | self.cond_channels = embedding_dim
130 | self.mid_channels = hidden_dim
131 | self.num_blocks = hidden_depth
132 | self.n_flows = n_flows
133 | self.conditioning_option = conditioning_option
134 | # TODO: also for spatial inputs...
135 | if conditioner_use_bn:
136 | assert not conditioner_use_an, 'Can not use ActNorm and BatchNorm simultaneously in Embedder.'
137 | print("Note: Conditioning network uses batch-normalization. "
138 | "Make sure to train with a sufficiently large batch size")
139 |
140 | self.embedder = BasicFullyConnectedNet(dim=conditioning_dim,
141 | depth=conditioning_depth,
142 | out_dim=embedding_dim,
143 | hidden_dim=conditioning_hidden_dim,
144 | use_bn=conditioner_use_bn,
145 | use_an=conditioner_use_an)
146 |
147 | self.sub_layers = nn.ModuleList()
148 | if self.conditioning_option.lower() != "none":
149 | self.conditioning_layers = nn.ModuleList()
150 | for flow in range(self.n_flows):
151 | self.sub_layers.append(ConditionalFlatDoubleCouplingFlowBlock(
152 | self.in_channels, self.cond_channels, self.mid_channels,
153 | self.num_blocks, activation=activation)
154 | )
155 | if self.conditioning_option.lower() != "none":
156 | self.conditioning_layers.append(nn.Conv2d(self.cond_channels, self.cond_channels, 1))
157 |
158 | def forward(self, x, cond, reverse=False):
159 | hconds = list()
160 | if len(cond.shape) == 4:
161 | if cond.shape[2] == 1:
162 | assert cond.shape[3] == 1
163 | cond = cond.squeeze(-1).squeeze(-1)
164 | else:
165 | raise ValueError("Spatial conditionings not yet supported. TODO")
166 | embedding = self.embedder(cond.float())
167 | hcond = embedding[:, :, None, None]
168 | for i in range(self.n_flows):
169 | if self.conditioning_option.lower() == "parallel":
170 | hcond = self.conditioning_layers[i](embedding)
171 | elif self.conditioning_option.lower() == "sequential":
172 | hcond = self.conditioning_layers[i](hcond)
173 | hconds.append(hcond)
174 | if not reverse:
175 | logdet = 0.0
176 | for i in range(self.n_flows):
177 | x, logdet_ = self.sub_layers[i](x, hconds[i])
178 | logdet = logdet + logdet_
179 | return x, logdet
180 | else:
181 | for i in reversed(range(self.n_flows)):
182 | x = self.sub_layers[i](x, hconds[i], reverse=True)
183 | return x
184 |
185 | def reverse(self, out, xcond):
186 | return self(out, xcond, reverse=True)
187 |
188 | def sample(self, xc):
189 | zz = torch.randn(xc.shape[0], self.in_channels, 1, 1).to(xc)
190 | return self.reverse(zz, xc)
191 |
192 |
193 |
--------------------------------------------------------------------------------
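UnconditionalFlatCouplingFlow stacks n_flows of the blocks above on flat latent codes: the forward pass maps codes toward the Gaussian base distribution while accumulating log-determinants, and reverse/sample invert the stack. A rough usage sketch under the same installation assumption:

    import torch
    from net2net.modules.flow.flatflow import UnconditionalFlatCouplingFlow

    flow = UnconditionalFlatCouplingFlow(in_channels=16, n_flows=4,
                                         hidden_dim=64, hidden_depth=2)
    z = torch.randn(8, 16)            # flat codes; forward() adds the trailing 1x1 dims itself
    out, logdet = flow(z)             # map codes to the base distribution, accumulate log-dets
    samples = flow.sample(8)          # draw from the base and run the flow in reverse
    print(out.shape, logdet.shape, samples.shape)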
/net2net/modules/flow/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def nll(sample):
6 | if len(sample.shape) == 2:
7 | sample = sample[:,:,None,None]
8 | return 0.5*torch.sum(torch.pow(sample, 2), dim=[1,2,3])
9 |
10 |
11 | class NLL(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, sample, logdet, split="train"):
16 | nll_loss = torch.mean(nll(sample))
17 | assert len(logdet.shape) == 1
18 | nlogdet_loss = -torch.mean(logdet)
19 | loss = nll_loss + nlogdet_loss
20 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample)))
21 | log = {f"{split}/total_loss": loss, f"{split}/reference_nll_loss": reference_nll_loss,
22 | f"{split}/nlogdet_loss": nlogdet_loss, f"{split}/nll_loss": nll_loss,
23 | }
24 | return loss, log
--------------------------------------------------------------------------------
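NLL scores flow outputs under a standard normal base distribution: up to the constant 0.5*D*log(2*pi), the loss is mean(0.5*||z||^2) - mean(logdet), and reference_nll_loss reports the same quantity for white noise as a sanity check. A minimal sketch of the interface, with synthetic tensors standing in for real flow outputs:

    import torch
    from net2net.modules.flow.loss import NLL

    criterion = NLL()
    z = torch.randn(8, 16, 1, 1)        # stand-in for flow outputs
    logdet = torch.zeros(8)             # stand-in for the per-sample log-determinants
    loss, log = criterion(z, logdet, split="train")
    print(loss.item(), sorted(log))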
/net2net/modules/gan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/gan/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/gan/bigbigan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import tensorflow.compat.v1 as tf
5 | tf.disable_v2_behavior()
6 |
7 | import tensorflow_hub as hub
8 |
9 |
10 | class BigBiGAN(object):
11 | def __init__(self,
12 | module_path='https://tfhub.dev/deepmind/bigbigan-resnet50/1',
13 | allow_growth=True):
14 | """Initialize a BigBiGAN from the given TF Hub module."""
15 | self._module = hub.Module(module_path)
16 |
17 | # encode graph
18 | self.enc_ph = self.make_encoder_ph()
19 | self.z_sample = self.encode_graph(self.enc_ph)
20 | self.z_mean = self.encode_graph(self.enc_ph, return_all_features=True)['z_mean']
21 |
22 | # decode graph
23 | self.gen_ph = self.make_generator_ph()
24 | self.gen_samples = self.generate_graph(self.gen_ph, upsample=True)
25 |
26 | # session
27 | init = tf.global_variables_initializer()
28 | gpu_options = tf.GPUOptions(allow_growth=allow_growth)
29 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
30 | self.sess.run(init)
31 |
32 | def generate_graph(self, z, upsample=False):
33 | """Run a batch of latents z through the generator to generate images.
34 |
35 | Args:
36 | z: A batch of 120D Gaussian latents, shape [N, 120].
37 |
38 | Returns: a batch of generated RGB images, shape [N, 128, 128, 3], range
39 | [-1, 1].
40 | """
41 | outputs = self._module(z, signature='generate', as_dict=True)
42 | return outputs['upsampled' if upsample else 'default']
43 |
44 | def make_generator_ph(self):
45 | """Creates a tf.placeholder with the dtype & shape of generator inputs."""
46 | info = self._module.get_input_info_dict('generate')['z']
47 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape())
48 |
49 | def encode_graph(self, x, return_all_features=False):
50 | """Run a batch of images x through the encoder.
51 |
52 | Args:
53 | x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range
54 | [-1, 1].
55 | return_all_features: If True, return all features computed by the encoder.
56 | Otherwise (default) just return a sample z_hat.
57 |
58 | Returns: the sample z_hat of shape [N, 120] (or a dict of all features if
59 | return_all_features).
60 | """
61 | outputs = self._module(x, signature='encode', as_dict=True)
62 | return outputs if return_all_features else outputs['z_sample']
63 |
64 | def make_encoder_ph(self):
65 | """Creates a tf.placeholder with the dtype & shape of encoder inputs."""
66 | info = self._module.get_input_info_dict('encode')['x']
67 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape())
68 |
69 | @torch.no_grad()
70 | def encode(self, x_torch):
71 | x_np = x_torch.detach().permute(0,2,3,1).cpu().numpy()
72 | feed_dict = {self.enc_ph: x_np}
73 | z = self.sess.run(self.z_sample, feed_dict=feed_dict)
74 | z_torch = torch.tensor(z).to(device=x_torch.device)
75 | return z_torch.unsqueeze(-1).unsqueeze(-1)
76 |
77 | @torch.no_grad()
78 | def decode(self, z_torch):
79 | z_np = z_torch.detach().squeeze(-1).squeeze(-1).cpu().numpy()
80 | feed_dict = {self.gen_ph: z_np}
81 | x = self.sess.run(self.gen_samples, feed_dict=feed_dict)
82 | x = x.transpose(0,3,1,2)
83 | x_torch = torch.tensor(x).to(device=z_torch.device)
84 | return x_torch
85 |
86 | def eval(self):
87 | # interface requirement
88 | return self
89 |
--------------------------------------------------------------------------------
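The wrapper above exposes a frozen BigBiGAN through a torch-style encode/decode interface while running the actual computation in a TF1-compatible session; the TF Hub module is downloaded when the object is constructed. A rough round-trip sketch (requires tensorflow and tensorflow_hub, and a GPU for reasonable speed):

    import torch
    from net2net.modules.gan.bigbigan import BigBiGAN

    bigbigan = BigBiGAN()                    # builds the TF graph and fetches the hub module
    x = torch.rand(2, 3, 256, 256) * 2 - 1   # NCHW torch images in [-1, 1]
    z = bigbigan.encode(x)                   # -> [2, 120, 1, 1] latent samples
    rec = bigbigan.decode(z)                 # -> NCHW images from the (upsampled) generator
    print(z.shape, rec.shape)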
/net2net/modules/labels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/labels/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/labels/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class Labelator(nn.Module):
8 | def __init__(self, num_classes, as_one_hot=True):
9 | super().__init__()
10 | self.num_classes = num_classes
11 | self.as_one_hot = as_one_hot
12 |
13 | def encode(self, x):
14 | if self.as_one_hot:
15 | x = self.make_one_hot(x)
16 | return x
17 |
18 | def other_label(self, given_label):
19 | # samples a random label different from the given one; with two classes this inverts the label
20 | others = []
21 | for l in given_label:
22 | other = int(np.random.choice(np.arange(self.num_classes)))
23 | while other == l:
24 | other = int(np.random.choice(np.arange(self.num_classes)))
25 | others.append(other)
26 | return torch.LongTensor(others)
27 |
28 | def make_one_hot(self, label):
29 | one_hot = F.one_hot(label, num_classes=self.num_classes)
30 | return one_hot
31 |
--------------------------------------------------------------------------------
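Labelator is a thin helper that turns integer class labels into one-hot codes and can draw a different random label for each example. A small sketch:

    import torch
    from net2net.modules.labels.model import Labelator

    labelator = Labelator(num_classes=5)
    labels = torch.LongTensor([0, 3, 3])
    print(labelator.encode(labels))        # one-hot codes of shape [3, 5]
    print(labelator.other_label(labels))   # random labels guaranteed to differ from the input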
/net2net/modules/mlp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/mlp/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/mlp/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 |
7 | class NeRFMLP(nn.Module):
8 | """basic MLP from the NERF paper"""
9 | def __init__(self, in_dim, out_dim, width=256):
10 | super().__init__()
11 | self.D = 8
12 | self.W = width # hidden dim
13 | self.skips = [4]
14 | self.n_in = in_dim
15 | self.out_dim = out_dim
16 |
17 | self.layers = nn.ModuleList()
18 | self.layers.append(nn.Linear(self.n_in, self.W))
19 | for i in range(1, self.D):
20 | if i-1 in self.skips:
21 | nin = self.W + self.n_in
22 | else:
23 | nin = self.W
24 | self.layers.append(nn.Linear(nin, self.W))
25 | self.out_layer = nn.Linear(self.W, self.out_dim)
26 |
27 | def forward(self, z):
28 | h = z
29 | for i in range(self.D):
30 | h = self.layers[i](h)
31 | h = F.relu(h)
32 | if i in self.skips:
33 | h = torch.cat([h,z], dim=1)
34 | h = self.out_layer(h)
35 | return h
36 |
37 |
38 | class SineLayer(nn.Module):
39 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
40 |
41 | # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the
42 | # nonlinearity. Different signals may require different omega_0 in the first layer - this is a
43 | # hyperparameter.
44 |
45 | # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of
46 | # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
47 |
48 | def __init__(self, in_features, out_features, bias=True,
49 | is_first=False, omega_0=30):
50 | super().__init__()
51 | self.omega_0 = omega_0
52 | self.is_first = is_first
53 | self.in_features = in_features
54 | self.linear = nn.Linear(in_features, out_features, bias=bias)
55 | self.init_weights()
56 |
57 | def init_weights(self):
58 | with torch.no_grad():
59 | if self.is_first:
60 | self.linear.weight.uniform_(-1 / self.in_features,
61 | 1 / self.in_features)
62 | else:
63 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
64 | np.sqrt(6 / self.in_features) / self.omega_0)
65 |
66 | def forward(self, input):
67 | return torch.sin(self.omega_0 * self.linear(input))
68 |
69 | def forward_with_intermediate(self, input):
70 | # For visualization of activation distributions
71 | intermediate = self.omega_0 * self.linear(input)
72 | return torch.sin(intermediate), intermediate
73 |
74 |
75 | class Siren(nn.Module):
76 | def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
77 | first_omega_0=30, hidden_omega_0=30.):
78 | super().__init__()
79 |
80 | self.net = []
81 | self.net.append(SineLayer(in_features, hidden_features,
82 | is_first=True, omega_0=first_omega_0))
83 |
84 | for i in range(hidden_layers):
85 | self.net.append(SineLayer(hidden_features, hidden_features,
86 | is_first=False, omega_0=hidden_omega_0))
87 |
88 | if outermost_linear:
89 | final_linear = nn.Linear(hidden_features, out_features)
90 |
91 | with torch.no_grad():
92 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
93 | np.sqrt(6 / hidden_features) / hidden_omega_0)
94 |
95 | self.net.append(final_linear)
96 | else:
97 | self.net.append(SineLayer(hidden_features, out_features,
98 | is_first=False, omega_0=hidden_omega_0))
99 |
100 | self.net = nn.Sequential(*self.net)
101 |
102 | def forward(self, coords):
103 | #coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
104 | output = self.net(coords)
105 | #return output, coords
106 | return output
107 |
108 | def forward_with_activations(self, coords, retain_grad=False):
109 | '''Returns not only model output, but also intermediate activations.
110 | Only used for visualizing activations later!'''
111 | activations = OrderedDict()
112 | activation_count = 0
113 | x = coords.clone().detach().requires_grad_(True)
114 | activations['input'] = x
115 | for i, layer in enumerate(self.net):
116 | if isinstance(layer, SineLayer):
117 | x, intermed = layer.forward_with_intermediate(x)
118 |
119 | if retain_grad:
120 | x.retain_grad()
121 | intermed.retain_grad()
122 |
123 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
124 | activation_count += 1
125 | else:
126 | x = layer(x)
127 |
128 | if retain_grad:
129 | x.retain_grad()
130 |
131 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
132 | activation_count += 1
133 | return activations
134 |
135 |
136 | # And finally, differential operators that allow us to leverage autograd to
137 | # compute gradients, the laplacian, etc.
138 |
139 | def laplace(y, x):
140 | grad = gradient(y, x)
141 | return divergence(grad, x)
142 |
143 |
144 | def divergence(y, x):
145 | div = 0.
146 | for i in range(y.shape[-1]):
147 | div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]
148 | return div
149 |
150 |
151 | def gradient(y, x, grad_outputs=None):
152 | if grad_outputs is None:
153 | grad_outputs = torch.ones_like(y)
154 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
155 | return grad
156 |
157 |
158 | if __name__ == "__main__":
159 | siren = Siren(2, 64, 2, 2)
160 | x = torch.randn(11, 2)
161 | x.requires_grad = True
162 | y = siren(x).mean()
163 | grad1 = torch.autograd.grad(y, x)[0]
164 | print(grad1.shape)
165 | print("done.")
--------------------------------------------------------------------------------
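The __main__ block above already checks first derivatives of a Siren; the helper operators also support second-order quantities via double backpropagation. A small sketch computing a Laplacian:

    import torch
    from net2net.modules.mlp.models import Siren, laplace

    siren = Siren(in_features=2, hidden_features=64, hidden_layers=2,
                  out_features=1, outermost_linear=True)
    coords = torch.randn(16, 2, requires_grad=True)  # derivatives are taken w.r.t. the coordinates
    values = siren(coords)
    lap = laplace(values, coords)                    # divergence of the gradient field
    print(values.shape, lap.shape)                   # [16, 1] and [16, 1]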
/net2net/modules/sbert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/sbert/__init__.py
--------------------------------------------------------------------------------
/net2net/modules/sbert/model.py:
--------------------------------------------------------------------------------
1 | # check out https://github.com/UKPLab/sentence-transformers,
2 | # list of pretrained models @ https://www.sbert.net/docs/pretrained_models.html
3 |
4 | from sentence_transformers import SentenceTransformer
5 | import numpy as np
6 | import torch.nn as nn
7 |
8 |
9 | class SentenceEmbedder(nn.Module):
10 | def __init__(self, version='bert-large-nli-stsb-mean-tokens'):
11 | super().__init__()
12 | np.set_printoptions(threshold=100)
13 | # Load Sentence model (based on BERT) from URL
14 | self.model = SentenceTransformer(version, device="cuda")
15 | self.model.eval()
16 |
17 | def forward(self, sentences):
18 | """sentences are expect to be a list of strings, e.g.
19 | sentences = ['This framework generates embeddings for each input sentence',
20 | 'Sentences are passed as a list of string.',
21 | 'The quick brown fox jumps over the lazy dog.'
22 | ]
23 | """
24 | sentence_embeddings = self.model.encode(sentences, batch_size=len(sentences), show_progress_bar=False,
25 | convert_to_tensor=True)
26 | return sentence_embeddings.cuda()
27 |
28 | def encode(self, sentences):
29 | embeddings = self(sentences)
30 | return embeddings[:,:,None,None]
31 |
32 |
33 | if __name__ == '__main__':
34 | model = SentenceEmbedder(version='distilroberta-base-paraphrase-v1')
35 | sentences = ['This framework generates embeddings for each input sentence',
36 | 'Sentences are passed as a list of string.',
37 | 'The quick brown fox jumps over the lazy dog.'
38 | ]
39 | emb = model.encode(sentences)
40 | print(emb.shape)
41 | print("done.")
42 |
--------------------------------------------------------------------------------
/net2net/modules/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 | from torchvision import models
7 | import os
8 | from PIL import Image, ImageDraw
9 |
10 |
11 | def log_txt_as_img(wh, xc):
12 | b = len(xc)
13 | txts = list()
14 | for bi in range(b):
15 | txt = Image.new("RGB", wh, color="white")
16 | draw = ImageDraw.Draw(txt)
17 | nc = int(40 * (wh[0]/256))
18 | lines = "\n".join(xc[bi][start:start+nc] for start in range(0, len(xc[bi]), nc))
19 | draw.text((0,0), lines, fill="black")
20 | txt = np.array(txt).transpose(2,0,1)/127.5-1.0
21 | txts.append(txt)
22 | txts = np.stack(txts)
23 | txts = torch.tensor(txts)
24 | return txts
25 |
26 |
27 | class Downscale(nn.Module):
28 | def __init__(self, mode="bilinear", size=32):
29 | super().__init__()
30 | self.mode = mode
31 | self.out_size = size
32 |
33 | def forward(self, x):
34 | x = F.interpolate(x, mode=self.mode, size=self.out_size)
35 | return x
36 |
37 |
38 | class DownscaleUpscale(nn.Module):
39 | def __init__(self, mode_down="bilinear", mode_up="bilinear", size=32):
40 | super().__init__()
41 | self.mode_down = mode_down
42 | self.mode_up = mode_up
43 | self.out_size = size
44 |
45 | def forward(self, x):
46 | assert len(x.shape) == 4
47 | z = F.interpolate(x, mode=self.mode_down, size=self.out_size)
48 | x = F.interpolate(z, mode=self.mode_up, size=x.shape[2])
49 | return x
50 |
51 |
52 | class TpsGridGen(nn.Module):
53 | def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
54 | super(TpsGridGen, self).__init__()
55 | self.out_h, self.out_w = out_h, out_w
56 | self.reg_factor = reg_factor
57 | self.use_cuda = use_cuda
58 |
59 | # create grid in numpy
60 | self.grid = np.zeros( [self.out_h, self.out_w, 3], dtype=np.float32)
61 | # sampling grid with dim-0 coords (Y)
62 | self.grid_X,self.grid_Y = np.meshgrid(np.linspace(-1,1,out_w),np.linspace(-1,1,out_h))
63 | # grid_X,grid_Y: size [1,H,W,1,1]
64 | self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
65 | self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
66 | if use_cuda:
67 | self.grid_X = self.grid_X.cuda()
68 | self.grid_Y = self.grid_Y.cuda()
69 |
70 | # initialize regular grid for control points P_i
71 | if use_regular_grid:
72 | axis_coords = np.linspace(-1,1,grid_size)
73 | self.N = grid_size*grid_size
74 | P_Y,P_X = np.meshgrid(axis_coords,axis_coords)
75 | P_X = np.reshape(P_X,(-1,1)) # size (N,1)
76 | P_Y = np.reshape(P_Y,(-1,1)) # size (N,1)
77 | P_X = torch.FloatTensor(P_X)
78 | P_Y = torch.FloatTensor(P_Y)
79 | self.P_X_base = P_X.clone()
80 | self.P_Y_base = P_Y.clone()
81 | self.Li = self.compute_L_inverse(P_X,P_Y).unsqueeze(0)
82 | self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4)
83 | self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4)
84 | if use_cuda:
85 | self.P_X = self.P_X.cuda()
86 | self.P_Y = self.P_Y.cuda()
87 | self.P_X_base = self.P_X_base.cuda()
88 | self.P_Y_base = self.P_Y_base.cuda()
89 |
90 |
91 | def forward(self, theta):
92 | warped_grid = self.apply_transformation(theta,torch.cat((self.grid_X,self.grid_Y),3))
93 |
94 | return warped_grid
95 |
96 | def compute_L_inverse(self,X,Y):
97 | N = X.size()[0] # num of points (along dim 0)
98 | # construct matrix K
99 | Xmat = X.expand(N,N)
100 | Ymat = Y.expand(N,N)
101 | P_dist_squared = torch.pow(Xmat-Xmat.transpose(0,1),2)+torch.pow(Ymat-Ymat.transpose(0,1),2)
102 | P_dist_squared[P_dist_squared==0]=1 # make diagonal 1 to avoid NaN in log computation
103 | K = torch.mul(P_dist_squared,torch.log(P_dist_squared))
104 | # construct matrix L
105 | O = torch.FloatTensor(N,1).fill_(1)
106 | Z = torch.FloatTensor(3,3).fill_(0)
107 | P = torch.cat((O,X,Y),1)
108 | L = torch.cat((torch.cat((K,P),1),torch.cat((P.transpose(0,1),Z),1)),0)
109 | Li = torch.inverse(L)
110 | if self.use_cuda:
111 | Li = Li.cuda()
112 | return Li
113 |
114 | def apply_transformation(self,theta,points):
115 | if theta.dim()==2:
116 | theta = theta.unsqueeze(2).unsqueeze(3)
117 | # points should be in the [B,H,W,2] format,
118 | # where points[:,:,:,0] are the X coords
119 | # and points[:,:,:,1] are the Y coords
120 |
121 | # input are the corresponding control points P_i
122 | batch_size = theta.size()[0]
123 | # split theta into point coordinates
124 | Q_X=theta[:,:self.N,:,:].squeeze(3)
125 | Q_Y=theta[:,self.N:,:,:].squeeze(3)
126 | Q_X = Q_X + self.P_X_base.expand_as(Q_X)
127 | Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
128 |
129 | # get spatial dimensions of points
130 | points_b = points.size()[0]
131 | points_h = points.size()[1]
132 | points_w = points.size()[2]
133 |
134 | # repeat pre-defined control points along spatial dimensions of points to be transformed
135 | P_X = self.P_X.expand((1,points_h,points_w,1,self.N))
136 | P_Y = self.P_Y.expand((1,points_h,points_w,1,self.N))
137 |
138 | # compute weights for non-linear part
139 | W_X = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_X)
140 | W_Y = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_Y)
141 | # reshape
142 | # W_X, W_Y: size [B,H,W,1,N]
143 | W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
144 | W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
145 | # compute weights for affine part
146 | A_X = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_X)
147 | A_Y = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_Y)
148 | # reshape
149 | # A_X, A_Y: size [B,H,W,1,3]
150 | A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
151 | A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
152 |
153 | # compute distance P_i - (grid_X,grid_Y)
154 | # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch
155 | points_X_for_summation = points[:,:,:,0].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,0].size()+(1,self.N))
156 | points_Y_for_summation = points[:,:,:,1].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,1].size()+(1,self.N))
157 |
158 | if points_b==1:
159 | delta_X = points_X_for_summation-P_X
160 | delta_Y = points_Y_for_summation-P_Y
161 | else:
162 | # use expanded P_X,P_Y in batch dimension
163 | delta_X = points_X_for_summation-P_X.expand_as(points_X_for_summation)
164 | delta_Y = points_Y_for_summation-P_Y.expand_as(points_Y_for_summation)
165 |
166 | dist_squared = torch.pow(delta_X,2)+torch.pow(delta_Y,2)
167 | # U: size [1,H,W,1,N]
168 | dist_squared[dist_squared==0]=1 # avoid NaN in log computation
169 | U = torch.mul(dist_squared,torch.log(dist_squared))
170 |
171 | # expand grid in batch dimension if necessary
172 | points_X_batch = points[:,:,:,0].unsqueeze(3)
173 | points_Y_batch = points[:,:,:,1].unsqueeze(3)
174 | if points_b==1:
175 | points_X_batch = points_X_batch.expand((batch_size,)+points_X_batch.size()[1:])
176 | points_Y_batch = points_Y_batch.expand((batch_size,)+points_Y_batch.size()[1:])
177 |
178 | points_X_prime = A_X[:,:,:,:,0]+ \
179 | torch.mul(A_X[:,:,:,:,1],points_X_batch) + \
180 | torch.mul(A_X[:,:,:,:,2],points_Y_batch) + \
181 | torch.sum(torch.mul(W_X,U.expand_as(W_X)),4)
182 |
183 | points_Y_prime = A_Y[:,:,:,:,0]+ \
184 | torch.mul(A_Y[:,:,:,:,1],points_X_batch) + \
185 | torch.mul(A_Y[:,:,:,:,2],points_Y_batch) + \
186 | torch.sum(torch.mul(W_Y,U.expand_as(W_Y)),4)
187 |
188 | return torch.cat((points_X_prime,points_Y_prime),3)
189 |
190 |
191 | def random_tps(*args, grid_size=4, reg_factor=0, strength_factor=1.0):
192 | """Random TPS. Device and size determined from first argument, all
193 | remaining arguments transformed with the same parameters."""
194 | x = args[0]
195 | is_np = type(x) == np.ndarray
196 | no_batch = len(x.shape) == 3
197 | if no_batch:
198 | args = [x[None,...] for x in args]
199 | if is_np:
200 | args = [torch.tensor(x.copy()).permute(0,3,1,2) for x in args]
201 | x = args[0]
202 | use_cuda = x.is_cuda
203 | b,c,h,w = x.shape
204 | # grid_size = 4  # disabled: would override the grid_size argument
205 | tps = TpsGridGen(out_h=h,
206 | out_w=w,
207 | use_regular_grid=True,
208 | grid_size=grid_size,
209 | reg_factor=reg_factor,
210 | use_cuda=use_cuda)
211 | # theta = b,2*N - first N = X, second N = Y,
212 | control = torch.cat([tps.P_X_base, tps.P_Y_base], dim=0)
213 | control = control[None,:,0][b*[0],...]
214 | theta = (torch.rand(b,2*grid_size*grid_size)*2-1.0) / grid_size * strength_factor
215 | final = control+theta.to(control)
216 | final = torch.clamp(final, -1.0, 1.0)
217 | final = final - control
218 | final[control==-1.0] = 0.0
219 | final[control==1.0] = 0.0
220 | grid = tps(final)
221 |
222 | is_uint8 = [x.dtype == torch.uint8 for x in args]
223 | args = [args[i].to(grid)+0.5 if is_uint8[i] else args[i] for i in range(len(args))]
224 | out = [torch.nn.functional.grid_sample(x, grid, align_corners=True) for x in args]
225 | out = [out[i].to(torch.uint8) if is_uint8[i] else out[i] for i in range(len(out))]
226 | if is_np:
227 | out = [x.permute(0,2,3,1).numpy() for x in out]
228 | if no_batch:
229 | out = [x[0,...] for x in out]
230 | if len(out) == 1:
231 | out = out[0]
232 | return out
233 |
234 |
235 | def count_params(model):
236 | total_params = sum(p.numel() for p in model.parameters())
237 | return total_params
238 |
--------------------------------------------------------------------------------
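random_tps applies one randomly drawn thin-plate-spline warp to all tensors passed to it, which is convenient for warping an image and aligned annotations consistently. A small CPU sketch:

    import torch
    from net2net.modules.util import random_tps, count_params

    img = torch.rand(2, 3, 64, 64)     # NCHW image batch
    mask = torch.rand(2, 1, 64, 64)    # aligned auxiliary tensor, warped with the same parameters
    warped_img, warped_mask = random_tps(img, mask, strength_factor=0.5)
    print(warped_img.shape, warped_mask.shape)
    print(count_params(torch.nn.Linear(4, 4)))   # 4*4 weights + 4 biases = 20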
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='net2net',
5 | version='0.0.1',
6 | description='Translate between networks through their latent representations.',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------