├── README.md ├── assets ├── superresolutionfigure.png ├── teaser.png ├── texttoimage.jpg └── unpairedtranslationfigure.png ├── configs ├── autoencoder │ ├── anime_photography_256.yaml │ ├── celeba_celebahq_ffhq_256.yaml │ ├── coco256.yaml │ ├── faces256.yaml │ ├── faces32.yaml │ └── portraits_photography_256.yaml ├── creativity │ ├── anime_photography_256.yaml │ ├── celeba_celebahq_ffhq_256.yaml │ └── portraits_photography_256.yaml └── translation │ ├── faces32-to-faces256.yaml │ ├── sbert-to-ae-coco256.yaml │ ├── sbert-to-bigbigan.yaml │ └── sbert-to-biggan256.yaml ├── data ├── WORDMAP_coco_5_cap_per_img_5_min_word_freq.json ├── animegwerncroptrain.txt ├── animegwerncropvalidation.txt ├── celebahqtrain.txt ├── celebahqvalidation.txt ├── coco_imagenet_overlap_idx.txt ├── examples │ ├── anime │ │ ├── 176890.jpg │ │ ├── 1931960.jpg │ │ ├── 2775531.jpg │ │ ├── 3007790.jpg │ │ ├── 331480.jpg │ │ ├── 348930.jpg │ │ ├── 4460500.jpg │ │ ├── 499130.jpg │ │ ├── 519280.jpg │ │ ├── 656160.jpg │ │ ├── 708030.jpg │ │ ├── 881700.jpg │ │ ├── 903790.jpg │ │ ├── 910920.jpg │ │ └── 949780.jpg │ ├── humanface │ │ ├── 00010.png │ │ ├── 00012.png │ │ ├── 00018.png │ │ ├── 00048.png │ │ ├── 00059.png │ │ ├── 65710.png │ │ ├── 65719.png │ │ ├── 65732.png │ │ ├── 65831.png │ │ ├── 65843.png │ │ ├── 65855.png │ │ ├── 65946.png │ │ ├── 65949.png │ │ ├── 65990.png │ │ ├── 65998.png │ │ ├── 65999.png │ │ ├── 66040.png │ │ ├── 66048.png │ │ ├── 66056.png │ │ ├── 66090.png │ │ ├── 66100.png │ │ ├── 66118.png │ │ └── 66157.png │ └── oilportrait │ │ ├── beethoven.jpeg │ │ ├── descartes.jpeg │ │ ├── fermat.jpg │ │ ├── galileo.jpeg │ │ ├── jeanjaques.jpeg │ │ ├── kant.jpeg │ │ ├── monalisa.jpeg │ │ └── newton.jpeg ├── ffhqtrain.txt ├── ffhqvalidation.txt ├── portraitstrain.txt └── portraitsvalidation.txt ├── env_bigbigan.yaml ├── environment.yaml ├── ml4cad.py ├── net2net ├── __init__.py ├── ckpt_util.py ├── data │ ├── __init__.py │ ├── base.py │ ├── coco.py │ ├── faces.py │ ├── utils.py │ └── zcodes.py ├── models │ ├── __init__.py │ ├── autoencoder.py │ └── flows │ │ ├── __init__.py │ │ ├── flow.py │ │ └── util.py └── modules │ ├── __init__.py │ ├── autoencoder │ ├── __init__.py │ ├── basic.py │ ├── decoder.py │ ├── encoder.py │ ├── loss.py │ └── lpips.py │ ├── captions │ ├── __init__.py │ ├── model.py │ └── models.py │ ├── discriminator │ ├── __init__.py │ └── model.py │ ├── distributions │ ├── __init__.py │ └── distributions.py │ ├── facenet │ ├── __init__.py │ ├── inception_resnet_v1.py │ └── model.py │ ├── flow │ ├── __init__.py │ ├── base.py │ ├── blocks.py │ ├── flatflow.py │ └── loss.py │ ├── gan │ ├── __init__.py │ ├── bigbigan.py │ └── biggan.py │ ├── labels │ ├── __init__.py │ └── model.py │ ├── mlp │ ├── __init__.py │ └── models.py │ ├── sbert │ ├── __init__.py │ └── model.py │ └── util.py ├── setup.py └── translation.py /README.md: -------------------------------------------------------------------------------- 1 | # Net2Net 2 | Code accompanying the NeurIPS 2020 oral paper 3 | 4 | [**Network-to-Network Translation with Conditional Invertible Neural Networks**](https://compvis.github.io/net2net/)
5 | [Robin Rombach](https://github.com/rromb)\*, 6 | [Patrick Esser](https://github.com/pesser)\*, 7 | [Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
8 | \* equal contribution 9 | 10 | **tl;dr** Our approach distills the residual information of one model with respect to 11 | another and thereby enables translation between fixed off-the-shelf expert 12 | models such as BERT and BigGAN without having to modify or finetune them. 13 | 14 | ![teaser](assets/teaser.png) 15 | [arXiv](https://arxiv.org/abs/2005.13580) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/net2net/) 16 | 17 | **News Dec 19th, 2020**: added SBERT-to-BigGAN, SBERT-to-BigBiGAN and SBERT-to-AE (COCO) 18 | ## Requirements 19 | A suitable [conda](https://conda.io/) environment named `net2net` can be created 20 | and activated with: 21 | 22 | ``` 23 | conda env create -f environment.yaml 24 | conda activate net2net 25 | ``` 26 | 27 | ## Datasets 28 | - **CelebA**: Create a symlink `data/CelebA` pointing to a folder which contains the following files: 29 | ``` 30 | . 31 | ├── identity_CelebA.txt 32 | ├── img_align_celeba 33 | ├── list_attr_celeba.txt 34 | └── list_eval_partition.txt 35 | ``` 36 | These files can be obtained [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). 37 | - **CelebA-HQ**: Create a symlink `data/celebahq` pointing to a folder containing 38 | the `.npy` files of CelebA-HQ (instructions to obtain them can be found in 39 | the [PGGAN repository](https://github.com/tkarras/progressive_growing_of_gans)). 40 | - **FFHQ**: Create a symlink `data/ffhq` pointing to the `images1024x1024` folder 41 | obtained from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). 42 | - **Anime Faces**: First download the face images from the [Anime Crop dataset](https://www.gwern.net/Crops) and then apply 43 | the preprocessing of [FFHQ](https://github.com/NVlabs/ffhq-dataset) to those images. We only keep images 44 | where the underlying [dlib face recognition model](http://dlib.net/face_landmark_detection.py.html) recognizes 45 | a face. Finally, create a symlink `data/anime` which contains the processed anime face images. 46 | - **Oil Portraits**: [Download here.](https://heibox.uni-heidelberg.de/f/4f35bdc16eea4158aa47/?dl=1) 47 | Unpack the content and place the files in `data/portraits`. It consists of 48 | 18k oil portraits, which were obtained by running [dlib](http://dlib.net/face_landmark_detection.py.html) on a subset of the [WikiArt](https://www.wikiart.org/) 49 | dataset, kindly provided by [A Style-Aware Content Loss for Real-time HD Style Transfer](https://github.com/CompVis/adaptive-style-transfer). 50 | - **COCO**: Create a symlink `data/coco` containing the images from the 2017 51 | split in `train2017` and `val2017`, and their annotations in `annotations`. 52 | Files can be obtained from the [COCO webpage](https://cocodataset.org). 53 | 54 | ## ML4Creativity Demo 55 | We include a [streamlit](https://www.streamlit.io/) demo, which utilizes our 56 | approach to demonstrate biases of datasets and their creative applications. 57 | More information can be found in our paper [A Note on Data Biases in Generative 58 | Models](https://drive.google.com/file/d/1PGhBTEAgj2A_FnYMk_1VU-uOxcWY076B/view?usp=sharing) from the [Machine Learning for Creativity and Design](https://neurips2020creativity.github.io/) workshop at [NeurIPS 2020](https://nips.cc/Conferences/2020).
Download the models from 59 | 60 | - [2020-11-30T23-32-28_celeba_celebahq_ffhq_256](https://k00.fr/lro927bu) 61 | - [2020-12-02T13-58-19_anime_photography_256](https://heibox.uni-heidelberg.de/d/075e81e16de948aea7a1/) 62 | - [2020-12-02T16-19-39_portraits_photography_256](https://k00.fr/y3rvnl3j) 63 | 64 | 65 | and place them into `logs`. Run the demo with 66 | 67 | ``` 68 | streamlit run ml4cad.py 69 | ``` 70 | 71 | ## Training 72 | Our code uses [Pytorch-Lightning](https://www.pytorchlightning.ai/) and thus natively supports 73 | things like 16-bit precision, multi-GPU training and gradient accumulation. Training details for any model need to be specified in a dedicated `.yaml` file. 74 | In general, such a config file is structured as follows: 75 | ``` 76 | model: 77 | base_learning_rate: 4.5e-6 78 | target: 79 | params: 80 | ... 81 | data: 82 | target: translation.DataModuleFromConfig 83 | params: 84 | batch_size: ... 85 | num_workers: ... 86 | train: 87 | target: 88 | params: 89 | ... 90 | validation: 91 | target: 92 | params: 93 | ... 94 | ``` 95 | Any Pytorch-Lightning model specified under `model.target` is then trained on the specified data 96 | by running the command: 97 | ``` 98 | python translation.py --base <config.yaml> -t --gpus 0, 99 | ``` 100 | All available Pytorch-Lightning [trainer](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html) arguments can be added via the command line, e.g. run 101 | ``` 102 | python translation.py --base <config.yaml> -t --gpus 0,1,2,3 --precision 16 --accumulate_grad_batches 2 103 | ``` 104 | to train a model on 4 GPUs using 16-bit precision and a 2-step gradient accumulation. 105 | More details are provided in the examples below. 106 | 107 | ### Training a cINN 108 | Training a cINN for network-to-network translation usually utilizes the Lightning module `net2net.models.flows.flow.Net2NetFlow` 109 | and makes a few further assumptions on the configuration file and model interface: 110 | ``` 111 | model: 112 | base_learning_rate: 4.5e-6 113 | target: net2net.models.flows.flow.Net2NetFlow 114 | params: 115 | flow_config: 116 | target: 117 | params: 118 | ... 119 | 120 | cond_stage_config: 121 | target: 122 | params: 123 | ... 124 | 125 | first_stage_config: 126 | target: 127 | params: 128 | ... 129 | ``` 130 | Here, the entries under `flow_config` specify the architecture and parameters of the conditional INN; 131 | `cond_stage_config` specifies the first network whose representation is to be translated into another network 132 | specified by `first_stage_config`. Our model `net2net.models.flows.flow.Net2NetFlow` expects that the first 133 | network has a `.encode()` method which produces the representation of interest, while the second network should 134 | have an `encode()` and a `decode()` method, such that both of them applied sequentially produce the network's output. This allows for a modular combination of arbitrary models of interest. For more details, see the examples below. 135 | 136 | ### Training a cINN - Superresolution 137 | ![superres](assets/superresolutionfigure.png) 138 | Training details for a cINN to concatenate two autoencoders from different image scales for stochastic 139 | superresolution are specified in `configs/translation/faces32-to-faces256.yaml`. 140 | 141 | To train a model for translating from 32 x 32 images to 256 x 256 images on GPU 0, run 142 | ``` 143 | python translation.py --base configs/translation/faces32-to-faces256.yaml -t --gpus 0, 144 | ``` 145 | and specify any additional training arguments as described above.
Note that this setup requires two 146 | pretrained autoencoder models, one on 32 x 32 images and the other on 256 x 256. If you want to 147 | train them yourself on a combination of FFHQ and CelebA-HQ, run 148 | ``` 149 | python translation.py --base configs/autoencoder/faces32.yaml -t --gpus <gpu>, 150 | ``` 151 | for the 32 x 32 images; and 152 | ``` 153 | python translation.py --base configs/autoencoder/faces256.yaml -t --gpus <gpu>, 154 | ``` 155 | for the model on 256 x 256 images. After training, adapt the corresponding model paths in `configs/translation/faces32-to-faces256.yaml`. Additionally, we provide weights of pretrained autoencoders for both settings: 156 | [Weights 32x32](https://heibox.uni-heidelberg.de/f/b0b103af8406467abe48/); [Weights 256x256](https://k00.fr/94lw2vlg). 157 | To run the training as described above, put them into 158 | `logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt` and 159 | `logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt`, respectively. 160 | 161 | ### Training a cINN - Unpaired Translation 162 | ![unpairedtranslation](assets/unpairedtranslationfigure.png) 163 | All training scenarios for unpaired translation are specified in the configs in `configs/creativity`. 164 | We provide code and pretrained autoencoder models for three different translation tasks: 165 | - **Anime** ⟷ **Photography**; see `configs/creativity/anime_photography_256.yaml`. 166 | Download the autoencoder checkpoint ([Download Anime+Photography](https://heibox.uni-heidelberg.de/f/315c628c8b0e40238132/)) and place it into `logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt`. 167 | - **Oil-Portrait** ⟷ **Photography**; see `configs/creativity/portraits_photography_256.yaml`. 168 | Download the autoencoder checkpoint ([Download Portrait+Photography](https://heibox.uni-heidelberg.de/f/4f9449418a2e4025bb5f/)) and place it into `logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt`. 169 | - **FFHQ** ⟷ **CelebA-HQ** ⟷ **CelebA**; see `configs/creativity/celeba_celebahq_ffhq_256.yaml`. 170 | Download the autoencoder checkpoint ([Download FFHQ+CelebAHQ+CelebA](https://k00.fr/94lw2vlg)) and place it into `logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt`. 171 | Note that this is the same autoencoder checkpoint as for the stochastic superresolution experiment. 172 | 173 | To train a cINN on one of these unpaired transfer tasks using the first GPU, simply run 174 | ``` 175 | python translation.py --base configs/creativity/<config>.yaml -t --gpus 0, 176 | ``` 177 | where `<config>.yaml` is one of `portraits_photography_256.yaml`, `celeba_celebahq_ffhq_256.yaml` 178 | or `anime_photography_256.yaml`. Providing additional arguments to the pytorch-lightning 179 | trainer object is also possible as described above. 180 | 181 | In our framework, unpaired translation between domains is formulated as a 182 | translation between expert 1, a model which can infer the domain a given image 183 | belongs to, and expert 2, a model which can synthesize images of each domain. 184 | In the examples provided, we assume that the domain label comes with the 185 | dataset and provide the `net2net.modules.labels.model.Labelator` module, which 186 | simply returns a one-hot encoding of this label. However, one could also use a 187 | classification model which infers the domain label from the image itself. 188 | For expert 2, our examples use an autoencoder trained jointly on all domains, 189 | which is easily achieved by concatenating datasets together.
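For example, a joint dataset over two domains can be assembled as follows (a minimal sketch: the image paths are placeholders, and the `Labelator` is only mentioned in a comment since its call signature lives in `net2net.modules.labels.model`):
```
from net2net.data.base import ImagePaths, ConcatDatasetWithIndex

# placeholder file lists -- in practice these come from lists such as data/ffhqtrain.txt
anime_paths = ["data/anime/000001.jpg", "data/anime/000002.jpg"]
photo_paths = ["data/ffhq/00000.png", "data/ffhq/00001.png"]

# one ImagePaths dataset per domain; the ordering defines the domain index (0, 1, ...)
domains = [ImagePaths(paths=anime_paths, size=256), ImagePaths(paths=photo_paths, size=256)]
joint = ConcatDatasetWithIndex(domains)

example, domain_idx = joint[0]  # image dict plus the index of the dataset it came from
# `domain_idx` is the label that the Labelator turns into a one-hot conditioning vector
```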
The provided 190 | `net2net.data.base.ConcatDatasetWithIndex` concatenates datasets and returns 191 | the corresponding dataset label for each example, which can then be used by the 192 | `Labelator` class for the translation. The training configurations for the 193 | autoencoders used in the creativity experiments are included in 194 | `configs/autoencoder/anime_photography_256.yaml`, 195 | `configs/autoencoder/celeba_celebahq_ffhq_256.yaml` and 196 | `configs/autoencoder/portraits_photography_256.yaml`. 197 | 198 | #### Unpaired Translation on Custom Datasets 199 | Create pytorch datasets for each 200 | of your domains, create a concatenated dataset with `ConcatDatasetWithIndex` 201 | (follow the example in `net2net.data.faces.CCFQTrain`), train an 202 | autoencoder on the concatenated dataset (adjust the `data` section in 203 | `configs/autoencoder/celeba_celebahq_ffhq_256.yaml`) and finally train a 204 | net2net translation model between a `Labelator` and your autoencoder (adjust 205 | the sections `data` and `first_stage_config` in 206 | `configs/creativity/celeba_celebahq_ffhq_256.yaml`). You can then also add your 207 | new model to the available modes in the `ml4cad.py` demo to visualize the 208 | results. 209 | 210 | 211 | ### Training a cINN - Text-to-Image 212 | ![texttoimage](assets/texttoimage.jpg) 213 | We provide code to obtain a text-to-image model by translating between a text 214 | model ([SBERT](https://www.sbert.net/)) and an image decoder. To show the 215 | flexibility of our approach, we include code for three different 216 | decoders: BigGAN, as described in the paper, 217 | [BigBiGAN](https://deepmind.com/research/open-source/BigBiGAN-Large-Scale-Adversarial-Representation-Learning), 218 | which is only available as a [tensorflow](https://www.tensorflow.org/) model 219 | and thus nicely shows how our approach can work with black-box experts, and an 220 | autoencoder. 221 | 222 | #### SBERT-to-BigGAN 223 | Train with 224 | ``` 225 | python translation.py --base configs/translation/sbert-to-biggan256.yaml -t --gpus 0, 226 | ``` 227 | When running it for the first time, the required models will be downloaded 228 | automatically. 229 | 230 | #### SBERT-to-BigBiGAN 231 | Since BigBiGAN is only available on 232 | [tensorflow-hub](https://tfhub.dev/s?q=bigbigan), this example has an 233 | additional dependency on tensorflow. A suitable environment is provided in 234 | `env_bigbigan.yaml`, and you will need COCO for training. You can then start 235 | training with 236 | ``` 237 | python translation.py --base configs/translation/sbert-to-bigbigan.yaml -t --gpus 0, 238 | ``` 239 | Note that the `BigBiGAN` class is just a naive wrapper, which converts pytorch 240 | tensors to numpy arrays, feeds them to the tensorflow graph and converts 241 | the result back to pytorch tensors. It does not require gradients of the expert 242 | model and serves as a good example of how to use black-box experts. 243 | 244 | #### SBERT-to-AE 245 | Similarly to the other examples, you can also train your own autoencoder on 246 | COCO with 247 | ``` 248 | python translation.py --base configs/autoencoder/coco256.yaml -t --gpus 0, 249 | ``` 250 | or [download a pre-trained 251 | one](https://k00.fr/fbti4058), and translate 252 | to it by running 253 | ``` 254 | python translation.py --base configs/translation/sbert-to-ae-coco256.yaml -t --gpus 0, 255 | ``` 256 | 257 | ## Shout-outs 258 | Thanks to everyone who makes their code and models available.
259 | 260 | - BigGAN code and weights from: [LoreGoetschalckx/GANalyze](https://github.com/LoreGoetschalckx/GANalyze) 261 | - Code and weights for the captioning model: [https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning) 262 | 263 | 264 | ## BibTeX 265 | 266 | ``` 267 | @misc{rombach2020networktonetwork, 268 | title={Network-to-Network Translation with Conditional Invertible Neural Networks}, 269 | author={Robin Rombach and Patrick Esser and Björn Ommer}, 270 | year={2020}, 271 | eprint={2005.13580}, 272 | archivePrefix={arXiv}, 273 | primaryClass={cs.CV} 274 | } 275 | ``` 276 | 277 | ``` 278 | @misc{esser2020note, 279 | title={A Note on Data Biases in Generative Models}, 280 | author={Patrick Esser and Robin Rombach and Björn Ommer}, 281 | year={2020}, 282 | eprint={2012.02516}, 283 | archivePrefix={arXiv}, 284 | primaryClass={cs.CV} 285 | } 286 | ``` 287 | -------------------------------------------------------------------------------- /assets/superresolutionfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/superresolutionfigure.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/teaser.png -------------------------------------------------------------------------------- /assets/texttoimage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/texttoimage.jpg -------------------------------------------------------------------------------- /assets/unpairedtranslationfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/assets/unpairedtranslationfigure.png -------------------------------------------------------------------------------- /configs/autoencoder/anime_photography_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.autoencoder.BigAE 4 | params: 5 | loss_config: 6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 7 | params: 8 | disc_start: 75001 9 | kl_weight: 0.000001 10 | disc_weight: 0.5 11 | 12 | encoder_config: 13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 14 | params: 15 | in_channels: 3 16 | in_size: 256 17 | pretrained: false 18 | type: resnet101 19 | z_dim: 128 20 | 21 | decoder_config: 22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 23 | params: 24 | z_dim: 128 25 | in_size: 256 26 | use_actnorm_in_dec: true 27 | 28 | data: 29 | target: translation.DataModuleFromConfig 30 | params: 31 | batch_size: 3 32 | train: 33 | target: net2net.data.faces.FacesHQAndAnimeTrain 34 | params: 35 | size: 256 36 | validation: 37 | target: net2net.data.faces.FacesHQAndAnimeValidation 38 | params: 39 | size: 256 40 | -------------------------------------------------------------------------------- /configs/autoencoder/celeba_celebahq_ffhq_256.yaml: -------------------------------------------------------------------------------- 
1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.autoencoder.BigAE 4 | params: 5 | loss_config: 6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 7 | params: 8 | disc_start: 75001 9 | kl_weight: 0.000001 10 | disc_weight: 0.5 11 | 12 | encoder_config: 13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 14 | params: 15 | in_channels: 3 16 | in_size: 256 17 | pretrained: false 18 | type: resnet101 19 | z_dim: 128 20 | 21 | decoder_config: 22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 23 | params: 24 | z_dim: 128 25 | in_size: 256 26 | use_actnorm_in_dec: true 27 | 28 | 29 | data: 30 | target: translation.DataModuleFromConfig 31 | params: 32 | batch_size: 3 33 | train: 34 | target: net2net.data.faces.CCFQTrain 35 | params: 36 | size: 256 37 | validation: 38 | target: net2net.data.faces.CCFQValidation 39 | params: 40 | size: 256 41 | -------------------------------------------------------------------------------- /configs/autoencoder/coco256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: net2net.models.autoencoder.BigAE 4 | params: 5 | loss_config: 6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 7 | params: 8 | disc_start: 250001 9 | kl_weight: 0.000001 10 | disc_weight: 0.5 11 | 12 | encoder_config: 13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 14 | params: 15 | in_channels: 3 16 | in_size: 256 17 | pretrained: false 18 | type: resnet101 19 | z_dim: 256 20 | 21 | decoder_config: 22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 23 | params: 24 | z_dim: 256 25 | in_size: 256 26 | use_actnorm_in_dec: true 27 | 28 | data: 29 | target: translation.DataModuleFromConfig 30 | params: 31 | batch_size: 3 32 | train: 33 | target: net2net.data.coco.CocoImagesAndCaptionsTrain 34 | params: 35 | size: 256 36 | validation: 37 | target: net2net.data.coco.CocoImagesAndCaptionsValidation 38 | params: 39 | size: 256 40 | -------------------------------------------------------------------------------- /configs/autoencoder/faces256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.autoencoder.BigAE 4 | params: 5 | loss_config: 6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 7 | params: 8 | disc_start: 75001 9 | kl_weight: 0.000001 10 | disc_weight: 0.5 11 | 12 | encoder_config: 13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 14 | params: 15 | in_channels: 3 16 | in_size: 256 17 | pretrained: false 18 | type: resnet101 19 | z_dim: 128 20 | 21 | decoder_config: 22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 23 | params: 24 | z_dim: 128 25 | in_size: 256 26 | use_actnorm_in_dec: true 27 | 28 | 29 | data: 30 | target: translation.DataModuleFromConfig 31 | params: 32 | batch_size: 3 33 | train: 34 | target: net2net.data.faces.CelebFQTrain 35 | params: 36 | size: 256 37 | validation: 38 | target: net2net.data.faces.CelebFQValidation 39 | params: 40 | size: 256 41 | -------------------------------------------------------------------------------- /configs/autoencoder/faces32.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.autoencoder.BasicAE 4 | params: 5 | ae_config: 6 | target: net2net.modules.autoencoder.basic.BasicAEModel 7 | 
params: 8 | in_size: 32 9 | n_down: 4 10 | z_dim: 128 11 | in_channels: 3 12 | deterministic: False 13 | 14 | loss_config: 15 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 16 | params: 17 | disc_start: 10001 18 | kl_weight: 0.000001 19 | disc_weight: 0.5 20 | 21 | 22 | data: 23 | target: translation.DataModuleFromConfig 24 | params: 25 | batch_size: 12 26 | train: 27 | target: net2net.data.faces.CelebFQTrain 28 | params: 29 | size: 32 30 | validation: 31 | target: net2net.data.faces.CelebFQValidation 32 | params: 33 | size: 32 34 | -------------------------------------------------------------------------------- /configs/autoencoder/portraits_photography_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.autoencoder.BigAE 4 | params: 5 | loss_config: 6 | target: net2net.modules.autoencoder.loss.LPIPSWithDiscriminator 7 | params: 8 | disc_start: 75001 9 | kl_weight: 0.000001 10 | disc_weight: 0.5 11 | 12 | encoder_config: 13 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 14 | params: 15 | in_channels: 3 16 | in_size: 256 17 | pretrained: false 18 | type: resnet101 19 | z_dim: 128 20 | 21 | decoder_config: 22 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 23 | params: 24 | z_dim: 128 25 | in_size: 256 26 | use_actnorm_in_dec: true 27 | 28 | 29 | data: 30 | target: translation.DataModuleFromConfig 31 | params: 32 | batch_size: 3 33 | train: 34 | target: net2net.data.faces.FFHQAndPortraitsTrain 35 | params: 36 | size: 256 37 | validation: 38 | target: net2net.data.faces.FFHQAndPortraitsValidation 39 | params: 40 | size: 256 41 | -------------------------------------------------------------------------------- /configs/creativity/anime_photography_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "class" 7 | flow_config: 8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 9 | params: 10 | conditioning_dim: 2 11 | embedding_dim: 10 12 | conditioning_depth: 2 13 | n_flows: 20 14 | in_channels: 128 15 | hidden_dim: 1024 16 | hidden_depth: 2 17 | activation: "none" 18 | conditioner_use_bn: True 19 | 20 | cond_stage_config: 21 | target: net2net.modules.labels.model.Labelator 22 | params: 23 | num_classes: 2 24 | as_one_hot: True 25 | 26 | first_stage_config: 27 | target: net2net.models.autoencoder.BigAE 28 | params: 29 | ckpt_path: "logs/2020-09-30T21-40-22_AnimeAndFHQ/checkpoints/epoch=000007.ckpt" 30 | encoder_config: 31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 32 | params: 33 | in_channels: 3 34 | in_size: 256 35 | pretrained: false 36 | type: resnet101 37 | z_dim: 128 38 | 39 | decoder_config: 40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 41 | params: 42 | z_dim: 128 43 | in_size: 256 44 | use_actnorm_in_dec: true 45 | 46 | loss_config: 47 | target: net2net.modules.autoencoder.loss.DummyLoss 48 | 49 | data: 50 | target: translation.DataModuleFromConfig 51 | params: 52 | batch_size: 15 53 | train: 54 | target: net2net.data.faces.FacesHQAndAnimeTrain 55 | params: 56 | size: 256 57 | validation: 58 | target: net2net.data.faces.FacesHQAndAnimeValidation 59 | params: 60 | size: 256 61 | -------------------------------------------------------------------------------- 
/configs/creativity/celeba_celebahq_ffhq_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "class" 7 | flow_config: 8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 9 | params: 10 | conditioning_dim: 3 11 | embedding_dim: 10 12 | conditioning_depth: 2 13 | n_flows: 20 14 | in_channels: 128 15 | hidden_dim: 1024 16 | hidden_depth: 2 17 | activation: "none" 18 | conditioner_use_bn: True 19 | 20 | cond_stage_config: 21 | target: net2net.modules.labels.model.Labelator 22 | params: 23 | num_classes: 3 24 | as_one_hot: True 25 | 26 | first_stage_config: 27 | target: net2net.models.autoencoder.BigAE 28 | params: 29 | ckpt_path: "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt" 30 | encoder_config: 31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 32 | params: 33 | in_channels: 3 34 | in_size: 256 35 | pretrained: false 36 | type: resnet101 37 | z_dim: 128 38 | 39 | decoder_config: 40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 41 | params: 42 | z_dim: 128 43 | in_size: 256 44 | use_actnorm_in_dec: true 45 | 46 | loss_config: 47 | target: net2net.modules.autoencoder.loss.DummyLoss 48 | 49 | data: 50 | target: translation.DataModuleFromConfig 51 | params: 52 | batch_size: 15 53 | train: 54 | target: net2net.data.faces.CCFQTrain 55 | params: 56 | size: 256 57 | validation: 58 | target: net2net.data.faces.CCFQValidation 59 | params: 60 | size: 256 61 | -------------------------------------------------------------------------------- /configs/creativity/portraits_photography_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "class" 7 | flow_config: 8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 9 | params: 10 | conditioning_dim: 2 11 | embedding_dim: 10 12 | conditioning_depth: 2 13 | n_flows: 20 14 | in_channels: 128 15 | hidden_dim: 1024 16 | hidden_depth: 2 17 | activation: "none" 18 | conditioner_use_bn: True 19 | 20 | cond_stage_config: 21 | target: net2net.modules.labels.model.Labelator 22 | params: 23 | num_classes: 2 24 | as_one_hot: True 25 | 26 | first_stage_config: 27 | target: net2net.models.autoencoder.BigAE 28 | params: 29 | ckpt_path: "logs/2020-09-29T23-47-10_PortraitsAndFFHQ/checkpoints/epoch=000004.ckpt" 30 | encoder_config: 31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 32 | params: 33 | in_channels: 3 34 | in_size: 256 35 | pretrained: false 36 | type: resnet101 37 | z_dim: 128 38 | 39 | decoder_config: 40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 41 | params: 42 | z_dim: 128 43 | in_size: 256 44 | use_actnorm_in_dec: true 45 | 46 | loss_config: 47 | target: net2net.modules.autoencoder.loss.DummyLoss 48 | 49 | data: 50 | target: translation.DataModuleFromConfig 51 | params: 52 | batch_size: 15 53 | train: 54 | target: net2net.data.faces.FFHQAndPortraitsTrain 55 | params: 56 | size: 256 57 | validation: 58 | target: net2net.data.faces.FFHQAndPortraitsValidation 59 | params: 60 | size: 256 61 | -------------------------------------------------------------------------------- /configs/translation/faces32-to-faces256.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "image" 7 | interpolate_cond_size: 32 8 | flow_config: 9 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 10 | params: 11 | conditioning_dim: 128 12 | embedding_dim: 128 13 | conditioning_depth: 2 14 | n_flows: 20 15 | in_channels: 128 16 | hidden_dim: 1024 17 | hidden_depth: 2 18 | activation: "none" 19 | conditioner_use_bn: True 20 | 21 | cond_stage_config: 22 | target: net2net.models.autoencoder.BasicAE 23 | params: 24 | ckpt_path: "logs/2020-10-16T17-11-42_FacesFQ32x32/checkpoints/last.ckpt" 25 | ae_config: 26 | target: net2net.modules.autoencoder.basic.BasicAEModel 27 | params: 28 | in_size: 32 29 | n_down: 4 30 | z_dim: 128 31 | in_channels: 3 32 | deterministic: False 33 | 34 | loss_config: 35 | target: net2net.modules.autoencoder.loss.DummyLoss # dummy 36 | 37 | first_stage_config: 38 | target: net2net.models.autoencoder.BigAE 39 | params: 40 | ckpt_path: "logs/2020-09-16T16-23-39_FacesXL256z128/checkpoints/last.ckpt" 41 | 42 | encoder_config: 43 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 44 | params: 45 | in_channels: 3 46 | in_size: 256 47 | pretrained: false 48 | type: resnet101 49 | z_dim: 128 50 | 51 | decoder_config: 52 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 53 | params: 54 | z_dim: 128 55 | in_size: 256 56 | use_actnorm_in_dec: true 57 | 58 | loss_config: 59 | target: net2net.modules.autoencoder.loss.DummyLoss 60 | 61 | data: 62 | target: translation.DataModuleFromConfig 63 | params: 64 | batch_size: 6 65 | train: 66 | target: net2net.data.faces.CelebFQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: net2net.data.faces.CelebFQValidation 71 | params: 72 | size: 256 73 | -------------------------------------------------------------------------------- /configs/translation/sbert-to-ae-coco256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "caption" 7 | flow_config: 8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 9 | params: 10 | conditioning_dim: 1024 11 | embedding_dim: 256 12 | conditioning_depth: 2 13 | n_flows: 24 14 | in_channels: 256 15 | hidden_dim: 1024 16 | hidden_depth: 2 17 | activation: "none" 18 | conditioner_use_bn: True 19 | 20 | cond_stage_config: 21 | target: net2net.modules.sbert.model.SentenceEmbedder 22 | params: 23 | version: "bert-large-nli-stsb-mean-tokens" 24 | 25 | first_stage_config: 26 | target: net2net.models.autoencoder.BigAE 27 | params: 28 | ckpt_path: "logs/2020-12-18T22-49-43_coco256/checkpoints/last.ckpt" 29 | 30 | encoder_config: 31 | target: net2net.modules.autoencoder.encoder.ResnetEncoder 32 | params: 33 | in_channels: 3 34 | in_size: 256 35 | pretrained: false 36 | type: resnet101 37 | z_dim: 256 38 | 39 | decoder_config: 40 | target: net2net.modules.autoencoder.decoder.BigGANDecoderWrapper 41 | params: 42 | z_dim: 256 43 | in_size: 256 44 | use_actnorm_in_dec: true 45 | 46 | loss_config: 47 | target: net2net.modules.autoencoder.loss.DummyLoss 48 | 49 | data: 50 | target: translation.DataModuleFromConfig 51 | params: 52 | batch_size: 16 53 | train: 54 | target: net2net.data.coco.CocoImagesAndCaptionsTrain 55 | params: 56 | 
size: 256 57 | validation: 58 | target: net2net.data.coco.CocoImagesAndCaptionsValidation 59 | params: 60 | size: 256 61 | -------------------------------------------------------------------------------- /configs/translation/sbert-to-bigbigan.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2NetFlow 4 | params: 5 | first_stage_key: "image" 6 | cond_stage_key: "caption" 7 | flow_config: 8 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 9 | params: 10 | conditioning_dim: 1024 11 | embedding_dim: 256 12 | conditioning_depth: 2 13 | n_flows: 24 14 | in_channels: 120 15 | hidden_dim: 1024 16 | hidden_depth: 2 17 | activation: "none" 18 | conditioner_use_bn: True 19 | 20 | cond_stage_config: 21 | target: net2net.modules.sbert.model.SentenceEmbedder 22 | params: 23 | version: "bert-large-nli-stsb-mean-tokens" 24 | 25 | first_stage_config: 26 | target: net2net.modules.gan.bigbigan.BigBiGAN 27 | 28 | data: 29 | target: translation.DataModuleFromConfig 30 | params: 31 | batch_size: 16 32 | train: 33 | target: net2net.data.coco.CocoImagesAndCaptionsTrain 34 | params: 35 | size: 256 36 | validation: 37 | target: net2net.data.coco.CocoImagesAndCaptionsValidation 38 | params: 39 | size: 256 40 | -------------------------------------------------------------------------------- /configs/translation/sbert-to-biggan256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: net2net.models.flows.flow.Net2BigGANFlow 4 | params: 5 | flow_config: 6 | target: net2net.modules.flow.flatflow.ConditionalFlatCouplingFlow 7 | params: 8 | conditioning_dim: 768 9 | embedding_dim: 268 10 | conditioning_depth: 2 11 | n_flows: 24 12 | in_channels: 268 13 | hidden_dim: 1024 14 | hidden_depth: 2 15 | activation: "none" 16 | conditioner_use_bn: True 17 | 18 | cond_stage_config: 19 | target: net2net.modules.sbert.model.SentenceEmbedder 20 | params: 21 | version: "bert-base-nli-stsb-mean-tokens" 22 | 23 | gan_config: 24 | target: "net2net.modules.gan.biggan.BigGANWrapper" 25 | params: 26 | image_size: 256 27 | 28 | make_cond_config: 29 | # Takes an image produced by BigGAN and produces a textual description 30 | target: "net2net.modules.captions.model.Img2Text" 31 | 32 | data: 33 | target: translation.DataModuleFromConfig 34 | params: 35 | batch_size: 16 36 | train: 37 | target: net2net.data.zcodes.RestrictedTrainSamples 38 | params: 39 | n_samples: 100000 40 | z_shape: 41 | - 140 42 | 43 | truncation: 2.5 44 | validation: 45 | target: net2net.data.zcodes.RestrictedTestSamples 46 | params: 47 | n_samples: 10000 48 | z_shape: 49 | - 140 50 | 51 | truncation: 2.5 52 | -------------------------------------------------------------------------------- /data/coco_imagenet_overlap_idx.txt: -------------------------------------------------------------------------------- 1 | 11 2 | 14 3 | 15 4 | 17 5 | 21 6 | 22 7 | 77 8 | 105 9 | 153 10 | 200 11 | 230 12 | 235 13 | 238 14 | 239 15 | 248 16 | 251 17 | 252 18 | 256 19 | 275 20 | 281 21 | 282 22 | 283 23 | 284 24 | 285 25 | 294 26 | 295 27 | 296 28 | 297 29 | 340 30 | 344 31 | 349 32 | 383 33 | 385 34 | 386 35 | 387 36 | 388 37 | 409 38 | 414 39 | 423 40 | 436 41 | 440 42 | 444 43 | 451 44 | 457 45 | 466 46 | 475 47 | 479 48 | 508 49 | 512 50 | 515 51 | 526 52 | 530 53 | 537 54 | 544 55 | 555 56 | 559 57 | 565 58 | 569 59 | 603 60 | 620 61 | 623 62 | 647 63 | 651 64 | 
659 65 | 672 66 | 693 67 | 703 68 | 705 69 | 717 70 | 720 71 | 729 72 | 737 73 | 751 74 | 760 75 | 761 76 | 765 77 | 770 78 | 779 79 | 788 80 | 795 81 | 799 82 | 809 83 | 817 84 | 829 85 | 831 86 | 850 87 | 859 88 | 861 89 | 864 90 | 867 91 | 878 92 | 879 93 | 883 94 | 892 95 | 895 96 | 898 97 | 904 98 | 905 99 | 906 100 | 907 101 | 910 102 | 917 103 | 921 104 | 923 105 | 934 106 | 937 107 | 950 108 | 954 109 | 956 110 | 963 111 | 966 112 | 968 113 | 990 114 | 999 115 | -------------------------------------------------------------------------------- /data/examples/anime/176890.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/176890.jpg -------------------------------------------------------------------------------- /data/examples/anime/1931960.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/1931960.jpg -------------------------------------------------------------------------------- /data/examples/anime/2775531.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/2775531.jpg -------------------------------------------------------------------------------- /data/examples/anime/3007790.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/3007790.jpg -------------------------------------------------------------------------------- /data/examples/anime/331480.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/331480.jpg -------------------------------------------------------------------------------- /data/examples/anime/348930.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/348930.jpg -------------------------------------------------------------------------------- /data/examples/anime/4460500.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/4460500.jpg -------------------------------------------------------------------------------- /data/examples/anime/499130.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/499130.jpg -------------------------------------------------------------------------------- /data/examples/anime/519280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/519280.jpg -------------------------------------------------------------------------------- /data/examples/anime/656160.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/656160.jpg -------------------------------------------------------------------------------- /data/examples/anime/708030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/708030.jpg -------------------------------------------------------------------------------- /data/examples/anime/881700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/881700.jpg -------------------------------------------------------------------------------- /data/examples/anime/903790.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/903790.jpg -------------------------------------------------------------------------------- /data/examples/anime/910920.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/910920.jpg -------------------------------------------------------------------------------- /data/examples/anime/949780.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/anime/949780.jpg -------------------------------------------------------------------------------- /data/examples/humanface/00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00010.png -------------------------------------------------------------------------------- /data/examples/humanface/00012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00012.png -------------------------------------------------------------------------------- /data/examples/humanface/00018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00018.png -------------------------------------------------------------------------------- /data/examples/humanface/00048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00048.png -------------------------------------------------------------------------------- /data/examples/humanface/00059.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/00059.png -------------------------------------------------------------------------------- /data/examples/humanface/65710.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65710.png -------------------------------------------------------------------------------- /data/examples/humanface/65719.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65719.png -------------------------------------------------------------------------------- /data/examples/humanface/65732.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65732.png -------------------------------------------------------------------------------- /data/examples/humanface/65831.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65831.png -------------------------------------------------------------------------------- /data/examples/humanface/65843.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65843.png -------------------------------------------------------------------------------- /data/examples/humanface/65855.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65855.png -------------------------------------------------------------------------------- /data/examples/humanface/65946.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65946.png -------------------------------------------------------------------------------- /data/examples/humanface/65949.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65949.png -------------------------------------------------------------------------------- /data/examples/humanface/65990.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65990.png -------------------------------------------------------------------------------- /data/examples/humanface/65998.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65998.png -------------------------------------------------------------------------------- /data/examples/humanface/65999.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/65999.png -------------------------------------------------------------------------------- /data/examples/humanface/66040.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66040.png -------------------------------------------------------------------------------- /data/examples/humanface/66048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66048.png -------------------------------------------------------------------------------- /data/examples/humanface/66056.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66056.png -------------------------------------------------------------------------------- /data/examples/humanface/66090.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66090.png -------------------------------------------------------------------------------- /data/examples/humanface/66100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66100.png -------------------------------------------------------------------------------- /data/examples/humanface/66118.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66118.png -------------------------------------------------------------------------------- /data/examples/humanface/66157.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/humanface/66157.png -------------------------------------------------------------------------------- /data/examples/oilportrait/beethoven.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/beethoven.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/descartes.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/descartes.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/fermat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/fermat.jpg -------------------------------------------------------------------------------- /data/examples/oilportrait/galileo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/galileo.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/jeanjaques.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/jeanjaques.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/kant.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/kant.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/monalisa.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/monalisa.jpeg -------------------------------------------------------------------------------- /data/examples/oilportrait/newton.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/data/examples/oilportrait/newton.jpeg -------------------------------------------------------------------------------- /env_bigbigan.yaml: -------------------------------------------------------------------------------- 1 | name: net2net_bigbigan 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.7 7 | - pip=19.3 8 | - cudatoolkit=10.1 9 | - cudnn=7.6.5 10 | - pytorch=1.6 11 | - torchvision=0.7 12 | - numpy=1.18 13 | - pip: 14 | - albumentations==0.4.3 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - imageio==2.9.0 18 | - imageio-ffmpeg==0.4.2 19 | - plotly==4.9.0 20 | - pytorch-lightning==0.9.0 21 | - omegaconf==2.0.0 22 | - streamlit==0.71.0 23 | - test-tube>=0.7.5 24 | - sentence-transformers>=0.3.8 25 | - tensorflow==2.3.1 26 | - tensorflow-hub==0.10.0 27 | - -e . 28 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: net2net 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.7 7 | - pip=19.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.6 10 | - torchvision=0.7 11 | - numpy=1.18 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - plotly==4.9.0 19 | - pytorch-lightning==0.9.0 20 | - omegaconf==2.0.0 21 | - streamlit==0.71.0 22 | - test-tube>=0.7.5 23 | - sentence-transformers>=0.3.8 24 | - -e . 
25 | -------------------------------------------------------------------------------- /net2net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/__init__.py -------------------------------------------------------------------------------- /net2net/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "biggan_128": "https://heibox.uni-heidelberg.de/f/56ed256209fd40968864/?dl=1", 7 | "biggan_256": "https://heibox.uni-heidelberg.de/f/437b501944874bcc92a4/?dl=1", 8 | "dequant_vae": "https://heibox.uni-heidelberg.de/f/e7c8959b50a64f40826e/?dl=1", 9 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", 10 | "coco_captioner": "https://heibox.uni-heidelberg.de/f/b03aae864a0f42f1a2c3/?dl=1", 11 | "coco_word_map": "https://heibox.uni-heidelberg.de/f/1518aa8461d94e0cb3eb/?dl=1" 12 | } 13 | 14 | CKPT_MAP = { 15 | "biggan_128": "biggan-128.pth", 16 | "biggan_256": "biggan-256.pth", 17 | "dequant_vae": "dequantvae-20000.ckpt", 18 | "vgg_lpips": "autoencoders/lpips/vgg.pth", 19 | "coco_captioner": "captioning_model_pt16.ckpt", 20 | } 21 | 22 | MD5_MAP = { 23 | "biggan_128": "a2148cf64807444113fac5eede060d28", 24 | "biggan_256": "e23db3caa34ac4c4ae922a75258dcb8d", 25 | "dequant_vae": "5c2a6fe765142cbdd9f10f15d65a68b6", 26 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a", 27 | "coco_captioner": "db185e0f6791e60d27c00de0f40c376c", 28 | } 29 | 30 | 31 | def download(url, local_path, chunk_size=1024): 32 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 33 | with requests.get(url, stream=True) as r: 34 | total_size = int(r.headers.get("content-length", 0)) 35 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 36 | with open(local_path, "wb") as f: 37 | for data in r.iter_content(chunk_size=chunk_size): 38 | if data: 39 | f.write(data) 40 | pbar.update(chunk_size) 41 | 42 | 43 | def md5_hash(path): 44 | with open(path, "rb") as f: 45 | content = f.read() 46 | return hashlib.md5(content).hexdigest() 47 | 48 | 49 | def get_ckpt_path(name, root, check=False): 50 | assert name in URL_MAP 51 | path = os.path.join(root, CKPT_MAP[name]) 52 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 53 | print("Downloading {} from {} to {}".format(name, URL_MAP[name], path)) 54 | download(URL_MAP[name], path) 55 | md5 = md5_hash(path) 56 | assert md5 == MD5_MAP[name], md5 57 | return path 58 | -------------------------------------------------------------------------------- /net2net/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/data/__init__.py -------------------------------------------------------------------------------- /net2net/data/base.py: -------------------------------------------------------------------------------- 1 | import os, bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should 
not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /net2net/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CocoBase(Dataset): 11 | """needed for (image, caption, segmentation) pairs""" 12 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 13 | crop_size=None, force_no_crop=False): 14 | self.split = self.get_split() 15 | self.size = size 16 | if crop_size is None: 17 | self.crop_size = size 18 | else: 19 | self.crop_size = crop_size 20 | 21 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 22 | self.stuffthing = use_stuffthing # include thing in segmentation 23 | if self.onehot and not self.stuffthing: 24 | raise NotImplemented("One hot mode is only supported for the " 25 | "stuffthings version because labels are stored " 26 | "a bit different.") 27 | 28 | data_json = datajson 29 | with open(data_json) as json_file: 30 | self.json_data = json.load(json_file) 31 | self.img_id_to_captions = dict() 32 | self.img_id_to_filepath = dict() 33 | self.img_id_to_segmentation_filepath = dict() 34 | 35 | assert data_json.split("/")[-1] in ["captions_train2017.json", 36 | "captions_val2017.json"] 37 | 38 | if self.stuffthing: 39 | 
self.segmentation_prefix = ( 40 | "data/cocostuffthings/val2017" if 41 | data_json.endswith("captions_val2017.json") else 42 | "data/cocostuffthings/train2017") 43 | else: 44 | self.segmentation_prefix = ( 45 | "data/coco/annotations/stuff_val2017_pixelmaps" if 46 | data_json.endswith("captions_val2017.json") else 47 | "data/coco/annotations/stuff_train2017_pixelmaps") 48 | 49 | imagedirs = self.json_data["images"] 50 | self.labels = {"image_ids": list()} 51 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 52 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 53 | self.img_id_to_captions[imgdir["id"]] = list() 54 | pngfilename = imgdir["file_name"].replace("jpg", "png") 55 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 56 | self.segmentation_prefix, pngfilename) 57 | self.labels["image_ids"].append(imgdir["id"]) 58 | 59 | capdirs = self.json_data["annotations"] 60 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 61 | # there are in average 5 captions per image 62 | self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 63 | 64 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 65 | if self.split=="validation": 66 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 67 | else: 68 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 69 | self.preprocessor = albumentations.Compose( 70 | [self.rescaler, self.cropper], 71 | additional_targets={"segmentation": "image"}) 72 | if force_no_crop: 73 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 74 | self.preprocessor = albumentations.Compose( 75 | [self.rescaler], 76 | additional_targets={"segmentation": "image"}) 77 | 78 | def __len__(self): 79 | return len(self.labels["image_ids"]) 80 | 81 | def preprocess_image(self, image_path, segmentation_path): 82 | image = Image.open(image_path) 83 | if not image.mode == "RGB": 84 | image = image.convert("RGB") 85 | image = np.array(image).astype(np.uint8) 86 | 87 | segmentation = Image.open(segmentation_path) 88 | if not self.onehot and not segmentation.mode == "RGB": 89 | segmentation = segmentation.convert("RGB") 90 | segmentation = np.array(segmentation).astype(np.uint8) 91 | if self.onehot: 92 | assert self.stuffthing 93 | # stored in caffe format: unlabeled==255. stuff and thing from 94 | # 0-181. 
to be compatible with the labels in 95 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 96 | # we shift stuffthing one to the right and put unlabeled in zero 97 | # as long as segmentation is uint8 shifting to right handles the 98 | # latter too 99 | assert segmentation.dtype == np.uint8 100 | segmentation = segmentation + 1 101 | 102 | processed = self.preprocessor(image=image, segmentation=segmentation) 103 | image, segmentation = processed["image"], processed["segmentation"] 104 | image = (image / 127.5 - 1.0).astype(np.float32) 105 | 106 | if self.onehot: 107 | assert segmentation.dtype == np.uint8 108 | # make it one hot 109 | n_labels = 183 110 | flatseg = np.ravel(segmentation) 111 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 112 | onehot[np.arange(flatseg.size), flatseg] = True 113 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 114 | segmentation = onehot 115 | else: 116 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 117 | return image, segmentation 118 | 119 | def __getitem__(self, i): 120 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 121 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 122 | image, segmentation = self.preprocess_image(img_path, seg_path) 123 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 124 | # randomly draw one of all available captions per image 125 | caption = captions[np.random.randint(0, len(captions))] 126 | example = {"image": image, 127 | "caption": [str(caption[0])], 128 | "segmentation": segmentation, 129 | "img_path": img_path, 130 | "seg_path": seg_path 131 | } 132 | return example 133 | 134 | 135 | class CocoImagesAndCaptionsTrain(CocoBase): 136 | """returns a pair of (image, caption)""" 137 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 138 | super().__init__(size=size, 139 | dataroot="data/coco/train2017", 140 | datajson="data/coco/annotations/captions_train2017.json", 141 | onehot_segmentation=onehot_segmentation, 142 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 143 | 144 | def get_split(self): 145 | return "train" 146 | 147 | 148 | class CocoImagesAndCaptionsValidation(CocoBase): 149 | """returns a pair of (image, caption)""" 150 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): 151 | super().__init__(size=size, 152 | dataroot="data/coco/val2017", 153 | datajson="data/coco/annotations/captions_val2017.json", 154 | onehot_segmentation=onehot_segmentation, 155 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 156 | 157 | def get_split(self): 158 | return "validation" 159 | 160 | -------------------------------------------------------------------------------- /net2net/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import urllib 4 | import tarfile, zipfile 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | 9 | def unpack(path): 10 | if path.endswith("tar.gz"): 11 | with tarfile.open(path, "r:gz") as tar: 12 | tar.extractall(path=os.path.split(path)[0]) 13 | elif path.endswith("tar"): 14 | with tarfile.open(path, "r:") as tar: 15 | tar.extractall(path=os.path.split(path)[0]) 16 | elif path.endswith("zip"): 17 | with zipfile.ZipFile(path, "r") as f: 18 | f.extractall(path=os.path.split(path)[0]) 19 | else: 20 | 
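# none of the supported archive types (.tar.gz, .tar, .zip) matched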
raise NotImplementedError( 21 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 22 | ) 23 | 24 | 25 | def reporthook(bar): 26 | """tqdm progress bar for downloads.""" 27 | 28 | def hook(b=1, bsize=1, tsize=None): 29 | if tsize is not None: 30 | bar.total = tsize 31 | bar.update(b * bsize - bar.n) 32 | 33 | return hook 34 | 35 | 36 | def get_root(name): 37 | base = "data/" 38 | root = os.path.join(base, name) 39 | os.makedirs(root, exist_ok=True) 40 | return root 41 | 42 | 43 | def is_prepared(root): 44 | return Path(root).joinpath(".ready").exists() 45 | 46 | 47 | def mark_prepared(root): 48 | Path(root).joinpath(".ready").touch() 49 | 50 | 51 | def prompt_download(file_, source, target_dir, content_dir=None): 52 | targetpath = os.path.join(target_dir, file_) 53 | while not os.path.exists(targetpath): 54 | if content_dir is not None and os.path.exists( 55 | os.path.join(target_dir, content_dir) 56 | ): 57 | break 58 | print( 59 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 60 | ) 61 | if content_dir is not None: 62 | print( 63 | "Or place its content into '{}'.".format( 64 | os.path.join(target_dir, content_dir) 65 | ) 66 | ) 67 | input("Press Enter when done...") 68 | return targetpath 69 | 70 | 71 | def download_url(file_, url, target_dir): 72 | targetpath = os.path.join(target_dir, file_) 73 | os.makedirs(target_dir, exist_ok=True) 74 | with tqdm( 75 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 76 | ) as bar: 77 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 78 | return targetpath 79 | 80 | 81 | def download_urls(urls, target_dir): 82 | paths = dict() 83 | for fname, url in urls.items(): 84 | outpath = download_url(fname, url, target_dir) 85 | paths[fname] = outpath 86 | return paths 87 | 88 | 89 | def quadratic_crop(x, bbox, alpha=1.0): 90 | """bbox is xmin, ymin, xmax, ymax""" 91 | im_h, im_w = x.shape[:2] 92 | bbox = np.array(bbox, dtype=np.float32) 93 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 94 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 95 | w = bbox[2] - bbox[0] 96 | h = bbox[3] - bbox[1] 97 | l = int(alpha * max(w, h)) 98 | l = max(l, 2) 99 | 100 | required_padding = -1 * min( 101 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 102 | ) 103 | required_padding = int(np.ceil(required_padding)) 104 | if required_padding > 0: 105 | padding = [ 106 | [required_padding, required_padding], 107 | [required_padding, required_padding], 108 | ] 109 | padding += [[0, 0]] * (len(x.shape) - 2) 110 | x = np.pad(x, padding, "reflect") 111 | center = center[0] + required_padding, center[1] + required_padding 112 | xmin = int(center[0] - l / 2) 113 | ymin = int(center[1] - l / 2) 114 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 115 | -------------------------------------------------------------------------------- /net2net/data/zcodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class PRNGMixin(object): 8 | """Adds a prng property which is a numpy RandomState which gets 9 | reinitialized whenever the pid changes to avoid synchronized sampling 10 | behavior when used in conjunction with multiprocessing.""" 11 | 12 | @property 13 | def prng(self): 14 | currentpid = os.getpid() 15 | if getattr(self, "_initpid", None) != currentpid: 16 | self._initpid = currentpid 17 | self._prng = 
np.random.RandomState() 18 | return self._prng 19 | 20 | 21 | class TrainSamples(Dataset, PRNGMixin): 22 | def __init__(self, n_samples, z_shape, n_classes, truncation=0): 23 | self.n_samples = n_samples 24 | self.z_shape = z_shape 25 | self.n_classes = n_classes 26 | self.truncation_threshold = truncation 27 | if self.truncation_threshold > 0: 28 | print("Applying truncation at level {}".format(self.truncation_threshold)) 29 | 30 | def __len__(self): 31 | return self.n_samples 32 | 33 | def __getitem__(self, i): 34 | z = self.prng.randn(*self.z_shape) 35 | if self.truncation_threshold > 0: 36 | for k, zi in enumerate(z): 37 | while abs(zi) > self.truncation_threshold: 38 | zi = self.prng.randn(1) 39 | z[k] = zi 40 | cls = self.prng.randint(self.n_classes) 41 | return {"z": z.astype(np.float32), "class": cls} 42 | 43 | 44 | class TestSamples(Dataset): 45 | def __init__(self, n_samples, z_shape, n_classes, truncation=0): 46 | self.prng = np.random.RandomState(1) 47 | self.n_samples = n_samples 48 | self.z_shape = z_shape 49 | self.n_classes = n_classes 50 | self.truncation_threshold = truncation 51 | if self.truncation_threshold > 0: 52 | print("Applying truncation at level {}".format(self.truncation_threshold)) 53 | self.zs = self.prng.randn(self.n_samples, *self.z_shape) 54 | if self.truncation_threshold > 0: 55 | print("Applying truncation at level {}".format(self.truncation_threshold)) 56 | ix = 0 57 | for z in tqdm(self.zs, desc="Truncation:"): 58 | for k, zi in enumerate(z): 59 | while abs(zi) > self.truncation_threshold: 60 | zi = self.prng.randn(1) 61 | z[k] = zi 62 | self.zs[ix] = z 63 | ix += 1 64 | print("Created truncated test data.") 65 | self.clss = self.prng.randint(self.n_classes, size=(self.n_samples,)) 66 | 67 | def __len__(self): 68 | return self.n_samples 69 | 70 | def __getitem__(self, i): 71 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]} 72 | 73 | 74 | class RestrictedTrainSamples(Dataset, PRNGMixin): 75 | def __init__(self, n_samples, z_shape, truncation=0): 76 | index_path = "data/coco_imagenet_overlap_idx.txt" 77 | self.n_samples = n_samples 78 | self.z_shape = z_shape 79 | self.classes = np.loadtxt(index_path).astype(int) 80 | self.truncation_threshold = truncation 81 | if self.truncation_threshold > 0: 82 | print("Applying truncation at level {}".format(self.truncation_threshold)) 83 | 84 | def __len__(self): 85 | return self.n_samples 86 | 87 | def __getitem__(self, i): 88 | z = self.prng.randn(*self.z_shape) 89 | if self.truncation_threshold > 0: 90 | for k, zi in enumerate(z): 91 | while abs(zi) > self.truncation_threshold: 92 | zi = self.prng.randn(1) 93 | z[k] = zi 94 | cls = self.prng.choice(self.classes) 95 | return {"z": z.astype(np.float32), "class": cls} 96 | 97 | 98 | class RestrictedTestSamples(Dataset): 99 | def __init__(self, n_samples, z_shape, truncation=0): 100 | index_path = "data/coco_imagenet_overlap_idx.txt" 101 | 102 | self.prng = np.random.RandomState(1) 103 | self.n_samples = n_samples 104 | self.z_shape = z_shape 105 | 106 | self.classes = np.loadtxt(index_path).astype(int) 107 | self.clss = self.prng.choice(self.classes, size=(self.n_samples,), replace=True) 108 | self.truncation_threshold = truncation 109 | self.zs = self.prng.randn(self.n_samples, *self.z_shape) 110 | if self.truncation_threshold > 0: 111 | print("Applying truncation at level {}".format(self.truncation_threshold)) 112 | ix = 0 113 | for z in tqdm(self.zs, desc="Truncation:"): 114 | for k, zi in enumerate(z): 115 | while abs(zi) > 
self.truncation_threshold: 116 | zi = self.prng.randn(1) 117 | z[k] = zi 118 | self.zs[ix] = z 119 | ix += 1 120 | print("Created truncated test data.") 121 | 122 | def __len__(self): 123 | return self.n_samples 124 | 125 | def __getitem__(self, i): 126 | return {"z": self.zs[i].astype(np.float32), "class": self.clss[i]} 127 | 128 | 129 | -------------------------------------------------------------------------------- /net2net/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/models/__init__.py -------------------------------------------------------------------------------- /net2net/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from net2net.modules.distributions.distributions import DiagonalGaussianDistribution 5 | from translation import instantiate_from_config 6 | 7 | 8 | class BigAE(pl.LightningModule): 9 | def __init__(self, 10 | encoder_config, 11 | decoder_config, 12 | loss_config, 13 | ckpt_path=None, 14 | ignore_keys=[] 15 | ): 16 | super().__init__() 17 | self.encoder = instantiate_from_config(encoder_config) 18 | self.decoder = instantiate_from_config(decoder_config) 19 | self.loss = instantiate_from_config(loss_config) 20 | 21 | if ckpt_path is not None: 22 | print("Loading model from {}".format(ckpt_path)) 23 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 24 | 25 | def init_from_ckpt(self, path, ignore_keys=list()): 26 | try: 27 | sd = torch.load(path, map_location="cpu")["state_dict"] 28 | except KeyError: 29 | sd = torch.load(path, map_location="cpu") 30 | 31 | keys = list(sd.keys()) 32 | for k in keys: 33 | for ik in ignore_keys: 34 | if k.startswith(ik): 35 | print("Deleting key {} from state_dict.".format(k)) 36 | del sd[k] 37 | missing, unexpected = self.load_state_dict(sd, strict=False) 38 | if len(missing) > 0: 39 | print(f"Missing keys in state dict: {missing}") 40 | if len(unexpected) > 0: 41 | print(f"Unexpected keys in state dict: {unexpected}") 42 | 43 | def encode(self, x, return_mode=False): 44 | moments = self.encoder(x) 45 | posterior = DiagonalGaussianDistribution(moments, deterministic=False) 46 | if return_mode: 47 | return posterior.mode() 48 | return posterior.sample() 49 | 50 | def decode(self, z): 51 | if len(z.shape) == 4: 52 | z = z.squeeze(-1).squeeze(-1) 53 | return self.decoder(z) 54 | 55 | def forward(self, x): 56 | moments = self.encoder(x) 57 | posterior = DiagonalGaussianDistribution(moments) 58 | h = posterior.sample() 59 | reconstructions = self.decoder(h.squeeze(-1).squeeze(-1)) 60 | return reconstructions, posterior 61 | 62 | def get_last_layer(self): 63 | return getattr(self.decoder.decoder.colorize.module, 'weight_bar') 64 | 65 | def log_images(self, batch, split=""): 66 | log = dict() 67 | inputs = batch["image"].permute(0, 3, 1, 2) 68 | inputs = inputs.to(self.device) 69 | reconstructions, posterior = self(inputs) 70 | log["inputs"] = inputs 71 | log["reconstructions"] = reconstructions 72 | return log 73 | 74 | def training_step(self, batch, batch_idx, optimizer_idx): 75 | inputs = batch["image"].permute(0, 3, 1, 2) 76 | reconstructions, posterior = self(inputs) 77 | 78 | if optimizer_idx == 0: 79 | # train encoder+decoder+logvar 80 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 81 | 
last_layer=self.get_last_layer(), split="train") 82 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss) 83 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 84 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 85 | return output 86 | 87 | if optimizer_idx == 1: 88 | # train the discriminator 89 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 90 | last_layer=self.get_last_layer(), split="train") 91 | output = pl.TrainResult(minimize=discloss) 92 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 93 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 94 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO 95 | return output 96 | 97 | def validation_step(self, batch, batch_idx): 98 | inputs = batch["image"].permute(0, 3, 1, 2) 99 | reconstructions, posterior = self(inputs) 100 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 101 | last_layer=self.get_last_layer(), split="val") 102 | 103 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 104 | last_layer=self.get_last_layer(), split="val") 105 | output = pl.EvalResult(checkpoint_on=aeloss) 106 | output.log_dict(log_dict_ae) 107 | output.log_dict(log_dict_disc) 108 | return output 109 | 110 | def configure_optimizers(self): 111 | lr = self.learning_rate 112 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+list(self.decoder.parameters())+[self.loss.logvar], 113 | lr=lr, betas=(0.5, 0.9)) 114 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) 115 | return [opt_ae, opt_disc], [] 116 | 117 | def on_epoch_end(self): 118 | pass 119 | 120 | 121 | class BasicAE(pl.LightningModule): 122 | def __init__(self, ae_config, loss_config, ckpt_path=None, ignore_keys=[]): 123 | super().__init__() 124 | self.autoencoder = instantiate_from_config(ae_config) 125 | self.loss = instantiate_from_config(loss_config) 126 | if ckpt_path is not None: 127 | print("Loading model from {}".format(ckpt_path)) 128 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 129 | 130 | def init_from_ckpt(self, path, ignore_keys=list()): 131 | try: 132 | sd = torch.load(path, map_location="cpu")["state_dict"] 133 | except KeyError: 134 | sd = torch.load(path, map_location="cpu") 135 | 136 | keys = list(sd.keys()) 137 | for k in keys: 138 | for ik in ignore_keys: 139 | if k.startswith(ik): 140 | print("Deleting key {} from state_dict.".format(k)) 141 | del sd[k] 142 | self.load_state_dict(sd, strict=False) 143 | 144 | def forward(self, x): 145 | posterior = self.autoencoder.encode(x) 146 | h = posterior.sample() 147 | reconstructions = self.autoencoder.decode(h) 148 | return reconstructions, posterior 149 | 150 | def encode(self, x): 151 | posterior = self.autoencoder.encode(x) 152 | h = posterior.sample() 153 | return h 154 | 155 | def get_last_layer(self): 156 | return self.autoencoder.get_last_layer() 157 | 158 | def log_images(self, batch, split=""): 159 | log = dict() 160 | inputs = batch["image"].permute(0, 3, 1, 2) 161 | inputs = inputs.to(self.device) 162 | reconstructions, posterior = self(inputs) 163 | log["inputs"] = inputs 164 | log["reconstructions"] = reconstructions 165 | return log 166 | 167 | def training_step(self, batch, batch_idx, optimizer_idx): 168 | inputs 
= batch["image"].permute(0, 3, 1, 2) 169 | reconstructions, posterior = self(inputs) 170 | 171 | if optimizer_idx == 0: 172 | # train encoder+decoder+logvar 173 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 174 | last_layer=self.get_last_layer(), split="train") 175 | output = pl.TrainResult(minimize=aeloss, checkpoint_on=aeloss) 176 | output.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 177 | output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 178 | return output 179 | 180 | if optimizer_idx == 1: 181 | # train the discriminator 182 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 183 | last_layer=self.get_last_layer(), split="train") 184 | output = pl.TrainResult(minimize=discloss) 185 | output.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 186 | output.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 187 | del output["checkpoint_on"] # NOTE pl currently sets checkpoint_on=minimize by default TODO 188 | return output 189 | 190 | def validation_step(self, batch, batch_idx): 191 | inputs = batch["image"].permute(0, 3, 1, 2) 192 | reconstructions, posterior = self(inputs) 193 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 194 | last_layer=self.get_last_layer(), split="val") 195 | 196 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 197 | last_layer=self.get_last_layer(), split="val") 198 | output = pl.EvalResult(checkpoint_on=aeloss) 199 | output.log_dict(log_dict_ae) 200 | output.log_dict(log_dict_disc) 201 | return output 202 | 203 | def configure_optimizers(self): 204 | lr = self.learning_rate 205 | opt_ae = torch.optim.Adam(list(self.autoencoder.parameters())+[self.loss.logvar], 206 | lr=lr, betas=(0.5, 0.9)) 207 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) 208 | return [opt_ae, opt_disc], [] 209 | -------------------------------------------------------------------------------- /net2net/models/flows/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/models/flows/__init__.py -------------------------------------------------------------------------------- /net2net/models/flows/flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | from translation import instantiate_from_config 6 | from net2net.modules.flow.loss import NLL 7 | from net2net.ckpt_util import get_ckpt_path 8 | from net2net.modules.util import log_txt_as_img 9 | 10 | 11 | def disabled_train(self, mode=True): 12 | """Overwrite model.train with this function to make sure train/eval mode 13 | does not change anymore.""" 14 | return self 15 | 16 | 17 | class Flow(pl.LightningModule): 18 | def __init__(self, flow_config): 19 | super().__init__() 20 | self.flow = instantiate_from_config(config=flow_config) 21 | self.loss = NLL() 22 | 23 | def forward(self, x): 24 | zz, logdet = self.flow(x) 25 | return zz, logdet 26 | 27 | def sample_like(self, query): 28 | z = self.flow.sample(query.shape[0], device=query.device).float() 29 | return z 30 | 31 | def shared_step(self, batch, 
batch_idx): 32 | x, labels = batch 33 | x = x.float() 34 | zz, logdet = self(x) 35 | loss, log_dict = self.loss(zz, logdet) 36 | return loss, log_dict 37 | 38 | def training_step(self, batch, batch_idx): 39 | loss, log_dict = self.shared_step(batch, batch_idx) 40 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss) 41 | output.log_dict(log_dict, prog_bar=False, on_epoch=True) 42 | return output 43 | 44 | def validation_step(self, batch, batch_idx): 45 | loss, log_dict = self.shared_step(batch, batch_idx) 46 | output = pl.EvalResult(checkpoint_on=loss) 47 | output.log_dict(log_dict, prog_bar=False) 48 | 49 | x, _ = batch 50 | x = x.float() 51 | sample = self.sample_like(x) 52 | output.sample_like = sample 53 | output.input = x.clone() 54 | 55 | return output 56 | 57 | def configure_optimizers(self): 58 | opt = torch.optim.Adam((self.flow.parameters()),lr=self.learning_rate, betas=(0.5, 0.9)) 59 | return [opt], [] 60 | 61 | 62 | class Net2NetFlow(pl.LightningModule): 63 | def __init__(self, 64 | flow_config, 65 | first_stage_config, 66 | cond_stage_config, 67 | ckpt_path=None, 68 | ignore_keys=[], 69 | first_stage_key="image", 70 | cond_stage_key="image", 71 | interpolate_cond_size=-1 72 | ): 73 | super().__init__() 74 | self.init_first_stage_from_ckpt(first_stage_config) 75 | self.init_cond_stage_from_ckpt(cond_stage_config) 76 | self.flow = instantiate_from_config(config=flow_config) 77 | self.loss = NLL() 78 | self.first_stage_key = first_stage_key 79 | self.cond_stage_key = cond_stage_key 80 | self.interpolate_cond_size = interpolate_cond_size 81 | if ckpt_path is not None: 82 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 83 | 84 | def init_from_ckpt(self, path, ignore_keys=list()): 85 | sd = torch.load(path, map_location="cpu")["state_dict"] 86 | for k in sd.keys(): 87 | for ik in ignore_keys: 88 | if k.startswith(ik): 89 | self.print("Deleting key {} from state_dict.".format(k)) 90 | del sd[k] 91 | self.load_state_dict(sd, strict=False) 92 | print(f"Restored from {path}") 93 | 94 | def init_first_stage_from_ckpt(self, config): 95 | model = instantiate_from_config(config) 96 | model = model.eval() 97 | model.train = disabled_train 98 | self.first_stage_model = model 99 | 100 | def init_cond_stage_from_ckpt(self, config): 101 | model = instantiate_from_config(config) 102 | model = model.eval() 103 | model.train = disabled_train 104 | self.cond_stage_model = model 105 | 106 | def forward(self, x, c): 107 | c = self.encode_to_c(c) 108 | q = self.encode_to_z(x) 109 | zz, logdet = self.flow(q, c) 110 | return zz, logdet 111 | 112 | @torch.no_grad() 113 | def sample_conditional(self, c): 114 | z = self.flow.sample(c) 115 | return z 116 | 117 | @torch.no_grad() 118 | def encode_to_z(self, x): 119 | z = self.first_stage_model.encode(x).detach() 120 | return z 121 | 122 | @torch.no_grad() 123 | def encode_to_c(self, c): 124 | c = self.cond_stage_model.encode(c).detach() 125 | return c 126 | 127 | @torch.no_grad() 128 | def decode_to_img(self, z): 129 | x = self.first_stage_model.decode(z.detach()) 130 | return x 131 | 132 | @torch.no_grad() 133 | def log_images(self, batch, split=""): 134 | log = dict() 135 | x = self.get_input(self.first_stage_key, batch).to(self.device) 136 | xc = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 137 | if self.cond_stage_key not in ["text", "caption"]: 138 | xc = xc.to(self.device) 139 | 140 | z = self.encode_to_z(x) 141 | c = self.encode_to_c(xc) 142 | 143 | zz, _ = self.flow(z, c) 144 | zrec = self.flow.reverse(zz, c) 145 | 
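# xrec decodes the flow's round trip reverse(zz, c) of the input code,
# xsample decodes a draw from the flow prior under the same conditioning,
# and xshift further below re-decodes zz after rolling the conditioning
# batch by one position to visualize conditioning transfer between examples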
xrec = self.decode_to_img(zrec) 146 | z_sample = self.sample_conditional(c) 147 | xsample = self.decode_to_img(z_sample) 148 | 149 | cshift = torch.cat((c[1:],c[:1]),dim=0) 150 | zshift = self.flow.reverse(zz, cshift) 151 | xshift = self.decode_to_img(zshift) 152 | 153 | log["inputs"] = x 154 | if self.cond_stage_key not in ["text", "caption", "class"]: 155 | log["conditioning"] = xc 156 | else: 157 | _,_,h,w = x.shape 158 | log["conditioning"] = log_txt_as_img((w,h), xc) 159 | 160 | log["reconstructions"] = xrec 161 | log["shift"] = xshift 162 | log["samples"] = xsample 163 | return log 164 | 165 | def get_input(self, key, batch, is_conditioning=False): 166 | x = batch[key] 167 | if key in ["caption", "text"]: 168 | x = list(x[0]) 169 | elif key in ["class"]: 170 | pass 171 | else: 172 | if len(x.shape) == 3: 173 | x = x[..., None] 174 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 175 | if is_conditioning: 176 | if self.interpolate_cond_size > -1: 177 | x = F.interpolate(x, size=(self.interpolate_cond_size, self.interpolate_cond_size)) 178 | return x 179 | 180 | def shared_step(self, batch, batch_idx, split="train"): 181 | x = self.get_input(self.first_stage_key, batch) 182 | c = self.get_input(self.cond_stage_key, batch, is_conditioning=True) 183 | zz, logdet = self(x, c) 184 | loss, log_dict = self.loss(zz, logdet, split=split) 185 | return loss, log_dict 186 | 187 | def training_step(self, batch, batch_idx): 188 | loss, log_dict = self.shared_step(batch, batch_idx, split="train") 189 | output = pl.TrainResult(minimize=loss, checkpoint_on=loss) 190 | output.log_dict(log_dict, prog_bar=False, on_epoch=True, logger=True, on_step=True) 191 | return output 192 | 193 | def validation_step(self, batch, batch_idx): 194 | loss, log_dict = self.shared_step(batch, batch_idx, split="val") 195 | output = pl.EvalResult(checkpoint_on=loss) 196 | output.log_dict(log_dict, prog_bar=False, logger=True) 197 | return output 198 | 199 | def configure_optimizers(self): 200 | opt = torch.optim.Adam((self.flow.parameters()), 201 | lr=self.learning_rate, 202 | betas=(0.5, 0.9), 203 | amsgrad=True) 204 | return [opt], [] 205 | 206 | 207 | class Net2BigGANFlow(Net2NetFlow): 208 | def __init__(self, 209 | flow_config, 210 | gan_config, 211 | cond_stage_config, 212 | make_cond_config, 213 | ckpt_path=None, 214 | ignore_keys=[], 215 | cond_stage_key="caption" 216 | ): 217 | super().__init__(flow_config=flow_config, 218 | first_stage_config=gan_config, cond_stage_config=cond_stage_config, 219 | ckpt_path=ckpt_path, ignore_keys=ignore_keys, cond_stage_key=cond_stage_key 220 | ) 221 | 222 | self.init_to_c_model(make_cond_config) 223 | self.init_preprocessing() 224 | 225 | @torch.no_grad() 226 | def get_input(self, batch, move_to_device=False): 227 | zin = batch["z"] 228 | cin = batch["class"] 229 | if move_to_device: 230 | zin, cin = zin.to(self.device), cin.to(self.device) 231 | # dequantize the discrete class code 232 | cin = self.first_stage_model.embed_labels(cin, labels_are_one_hot=False) 233 | split_sizes = [zin.shape[1], cin.shape[1]] 234 | xin = self.first_stage_model.generate_from_embedding(zin, cin) 235 | cin = self.dequantizer(cin) 236 | xc = self.to_c_model(xin) 237 | zflow = torch.cat([zin, cin.detach()], dim=1)[:, :, None, None] # this will be flowed 238 | return {"zcode": zflow, 239 | "xgen": xin, 240 | "xcon": xc, 241 | "split_sizes": split_sizes 242 | } 243 | 244 | def init_to_c_model(self, config): 245 | model = instantiate_from_config(config) 246 | model = model.eval() 247 | 
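# as for the first and cond stage models, .train is overridden with
# disabled_train so that the frozen to_c_model (which maps generated
# images back to the conditioning, e.g. captions) stays in eval mode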
model.train = disabled_train 248 | self.to_c_model = model 249 | 250 | def init_preprocessing(self): 251 | dqcfg = {"target": "net2net.modules.autoencoder.basic.BasicFullyConnectedVAE"} 252 | self.dequantizer = instantiate_from_config(dqcfg) 253 | ckpt = get_ckpt_path("dequant_vae", "net2net/modules/autoencoder/dequant_vae") 254 | self.dequantizer.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 255 | self.dequantizer.eval() 256 | self.dequantizer.train = disabled_train 257 | 258 | def shared_step(self, batch, batch_idx, split="train"): 259 | data = self.get_input(batch) 260 | z, c = data["zcode"], data["xcon"] 261 | zz, logdet = self(z, c) 262 | loss, log_dict = self.loss(zz, logdet, split=split) 263 | return loss, log_dict 264 | 265 | def forward(self, z, c): 266 | c = self.encode_to_c(c) 267 | zz, logdet = self.flow(z, c) 268 | return zz, logdet 269 | 270 | @torch.no_grad() 271 | def log_images(self, batch, split=""): 272 | log = dict() 273 | data = self.get_input(batch, move_to_device=True) 274 | z, xc, x = data["zcode"], data["xcon"], data["xgen"] 275 | c = self.encode_to_c(xc) 276 | zz, _ = self.flow(z, c) 277 | 278 | z_sample = self.sample_conditional(c) 279 | zdec, cdec = torch.split(z_sample, data["split_sizes"], dim=1) 280 | xsample = self.first_stage_model.generate_from_embedding(zdec.squeeze(-1).squeeze(-1), 281 | cdec.squeeze(-1).squeeze(-1)) 282 | 283 | cshift = torch.cat((c[1:],c[:1]),dim=0) 284 | zshift = self.flow.reverse(zz, cshift) 285 | zshift, cshift = torch.split(zshift, data["split_sizes"], dim=1) 286 | xshift = self.first_stage_model.generate_from_embedding(zshift.squeeze(-1).squeeze(-1), 287 | cshift.squeeze(-1).squeeze(-1)) 288 | 289 | log["inputs"] = x 290 | if self.cond_stage_key not in ["text", "caption", "class"]: 291 | log["conditioning"] = xc 292 | else: 293 | _,_,h,w = x.shape 294 | log["conditioning"] = log_txt_as_img((w,h), xc) 295 | 296 | log["shift"] = xshift 297 | log["samples"] = xsample 298 | return log 299 | 300 | @torch.no_grad() 301 | def sample_conditional(self, c): 302 | z = self.flow.sample(c) 303 | return z 304 | -------------------------------------------------------------------------------- /net2net/models/flows/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | #from sklearn.neighbors import KernelDensity 6 | 7 | 8 | def kde2D(x, y, bandwidth, xbins=250j, ybins=250j, **kwargs): 9 | """Build 2D kernel density estimate (KDE).""" 10 | 11 | # create grid of sample locations (default: 100x100) 12 | xx, yy = np.mgrid[x.min():x.max():xbins, 13 | y.min():y.max():ybins] 14 | 15 | xy_sample = np.vstack([yy.ravel(), xx.ravel()]).T 16 | xy_train = np.vstack([y, x]).T 17 | 18 | kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs) 19 | kde_skl.fit(xy_train) 20 | 21 | # score_samples() returns the log-likelihood of the samples 22 | z = np.exp(kde_skl.score_samples(xy_sample)) 23 | return xx, yy, np.reshape(z, xx.shape) 24 | 25 | 26 | def plot2d(x, savepath=None): 27 | """make a scatter plot of x and return an Image of it""" 28 | x = x.cpu().numpy().squeeze() 29 | fig = plt.figure(dpi=300) 30 | xx, yy, zz = kde2D(x[:,0], x[:,1], 0.1) 31 | plt.pcolormesh(xx, yy, zz) 32 | plt.scatter(x[:,0], x[:, 1], s=0.1, c='mistyrose') 33 | if savepath is not None: 34 | plt.savefig(savepath, dpi=300) 35 | return fig 36 | 37 | 38 | def reshape_to_grid(x, num_samples=16, iw=28, ih=28, nc=1): 39 | x = 
x[:num_samples] 40 | x = x.detach().cpu() 41 | x = torch.reshape(x, (x.shape[0], nc, iw, ih)) 42 | xgrid = torchvision.utils.make_grid(x) 43 | return xgrid -------------------------------------------------------------------------------- /net2net/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/__init__.py -------------------------------------------------------------------------------- /net2net/modules/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/autoencoder/__init__.py -------------------------------------------------------------------------------- /net2net/modules/autoencoder/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | 5 | from net2net.modules.distributions.distributions import DiagonalGaussianDistribution 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class ActNorm(nn.Module): 18 | def __init__(self, num_features, logdet=False, affine=True, 19 | allow_reverse_init=False): 20 | assert affine 21 | super().__init__() 22 | self.logdet = logdet 23 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 24 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 25 | self.allow_reverse_init = allow_reverse_init 26 | 27 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 28 | 29 | def initialize(self, input): 30 | with torch.no_grad(): 31 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 32 | mean = ( 33 | flatten.mean(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | std = ( 40 | flatten.std(1) 41 | .unsqueeze(1) 42 | .unsqueeze(2) 43 | .unsqueeze(3) 44 | .permute(1, 0, 2, 3) 45 | ) 46 | 47 | self.loc.data.copy_(-mean) 48 | self.scale.data.copy_(1 / (std + 1e-6)) 49 | 50 | def forward(self, input, reverse=False): 51 | if reverse: 52 | return self.reverse(input) 53 | if len(input.shape) == 2: 54 | input = input[:,:,None,None] 55 | squeeze = True 56 | else: 57 | squeeze = False 58 | 59 | _, _, height, width = input.shape 60 | 61 | if self.training and self.initialized.item() == 0: 62 | self.initialize(input) 63 | self.initialized.fill_(1) 64 | 65 | h = self.scale * (input + self.loc) 66 | 67 | if squeeze: 68 | h = h.squeeze(-1).squeeze(-1) 69 | 70 | if self.logdet: 71 | log_abs = torch.log(torch.abs(self.scale)) 72 | logdet = height*width*torch.sum(log_abs) 73 | logdet = logdet * torch.ones(input.shape[0]).to(input) 74 | return h, logdet 75 | 76 | return h 77 | 78 | def reverse(self, output): 79 | if self.training and self.initialized.item() == 0: 80 | if not self.allow_reverse_init: 81 | raise RuntimeError( 82 | "Initializing ActNorm in reverse direction is " 83 | "disabled by default. Use allow_reverse_init=True to enable." 
84 | ) 85 | else: 86 | self.initialize(output) 87 | self.initialized.fill_(1) 88 | 89 | if len(output.shape) == 2: 90 | output = output[:,:,None,None] 91 | squeeze = True 92 | else: 93 | squeeze = False 94 | 95 | h = output / self.scale - self.loc 96 | 97 | if squeeze: 98 | h = h.squeeze(-1).squeeze(-1) 99 | return h 100 | 101 | 102 | class BasicFullyConnectedNet(nn.Module): 103 | def __init__(self, dim, depth, hidden_dim=256, use_tanh=False, use_bn=False, out_dim=None, use_an=False): 104 | super(BasicFullyConnectedNet, self).__init__() 105 | layers = [] 106 | layers.append(nn.Linear(dim, hidden_dim)) 107 | if use_bn: 108 | assert not use_an 109 | layers.append(nn.BatchNorm1d(hidden_dim)) 110 | if use_an: 111 | assert not use_bn 112 | layers.append(ActNorm(hidden_dim)) 113 | layers.append(nn.LeakyReLU()) 114 | for d in range(depth): 115 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 116 | if use_bn: 117 | layers.append(nn.BatchNorm1d(hidden_dim)) 118 | layers.append(nn.LeakyReLU()) 119 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim)) 120 | if use_tanh: 121 | layers.append(nn.Tanh()) 122 | self.main = nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | return self.main(x) 126 | 127 | 128 | _norm_options = { 129 | "in": nn.InstanceNorm2d, 130 | "bn": nn.BatchNorm2d, 131 | "an": ActNorm} 132 | 133 | 134 | class BasicAEModel(nn.Module): 135 | def __init__(self, n_down, z_dim, in_size, in_channels, deterministic=False): 136 | super().__init__() 137 | bottleneck_size = in_size // 2**n_down 138 | norm = "an" 139 | self.be_deterministic = deterministic 140 | 141 | self.feature_layers = nn.ModuleList() 142 | self.decoder_layers = nn.ModuleList() 143 | 144 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm)) 145 | for scale in range(1, n_down): 146 | self.feature_layers.append(FeatureLayer(scale, norm=norm)) 147 | 148 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, 2*z_dim) 149 | self.dense_decode = DenseDecoderLayer(n_down-1, bottleneck_size, z_dim) 150 | 151 | for scale in range(n_down-1): 152 | self.decoder_layers.append(DecoderLayer(scale, norm=norm)) 153 | self.image_layer = ImageLayer(out_channels=in_channels) 154 | 155 | self.apply(weights_init) 156 | 157 | self.n_down = n_down 158 | self.z_dim = z_dim 159 | self.bottleneck_size = bottleneck_size 160 | 161 | def encode(self, input): 162 | h = input 163 | for layer in self.feature_layers: 164 | h = layer(h) 165 | h = self.dense_encode(h) 166 | return DiagonalGaussianDistribution(h, deterministic=self.be_deterministic) 167 | 168 | def decode(self, input): 169 | h = input 170 | h = self.dense_decode(h) 171 | for layer in reversed(self.decoder_layers): 172 | h = layer(h) 173 | h = self.image_layer(h) 174 | return h 175 | 176 | def get_last_layer(self): 177 | return self.image_layer.sub_layers[0].weight 178 | 179 | 180 | class FeatureLayer(nn.Module): 181 | def __init__(self, scale, in_channels=None, norm='IN'): 182 | super().__init__() 183 | self.scale = scale 184 | self.norm = _norm_options[norm.lower()] 185 | if in_channels is None: 186 | self.in_channels = 64*min(2**(self.scale-1), 16) 187 | else: 188 | self.in_channels = in_channels 189 | self.build() 190 | 191 | def forward(self, input): 192 | x = input 193 | for layer in self.sub_layers: 194 | x = layer(x) 195 | return x 196 | 197 | def build(self): 198 | Norm = functools.partial(self.norm, affine=True) 199 | Activate = lambda: nn.LeakyReLU(0.2) 200 | self.sub_layers = nn.ModuleList([ 201 | nn.Conv2d( 
202 | in_channels=self.in_channels, 203 | out_channels=64*min(2**self.scale, 16), 204 | kernel_size=4, 205 | stride=2, 206 | padding=1, 207 | bias=False), 208 | Norm(num_features=64*min(2**self.scale, 16)), 209 | Activate()]) 210 | 211 | 212 | class LatentLayer(nn.Module): 213 | def __init__(self, in_channels, out_channels): 214 | super(LatentLayer, self).__init__() 215 | self.in_channels = in_channels 216 | self.out_channels = out_channels 217 | self.build() 218 | 219 | def forward(self, input): 220 | x = input 221 | for layer in self.sub_layers: 222 | x = layer(x) 223 | return x 224 | 225 | def build(self): 226 | self.sub_layers = nn.ModuleList([ 227 | nn.Conv2d( 228 | in_channels=self.in_channels, 229 | out_channels=self.out_channels, 230 | kernel_size=1, 231 | stride=1, 232 | padding=0, 233 | bias=True) 234 | ]) 235 | 236 | 237 | class DecoderLayer(nn.Module): 238 | def __init__(self, scale, in_channels=None, norm='IN'): 239 | super().__init__() 240 | self.scale = scale 241 | self.norm = _norm_options[norm.lower()] 242 | if in_channels is not None: 243 | self.in_channels = in_channels 244 | else: 245 | self.in_channels = 64*min(2**(self.scale+1), 16) 246 | self.build() 247 | 248 | def forward(self, input): 249 | d = input 250 | for layer in self.sub_layers: 251 | d = layer(d) 252 | return d 253 | 254 | def build(self): 255 | Norm = functools.partial(self.norm, affine=True) 256 | Activate = lambda: nn.LeakyReLU(0.2) 257 | self.sub_layers = nn.ModuleList([ 258 | nn.ConvTranspose2d( 259 | in_channels=self.in_channels, 260 | out_channels=64*min(2**self.scale, 16), 261 | kernel_size=4, 262 | stride=2, 263 | padding=1, 264 | bias=False), 265 | Norm(num_features=64*min(2**self.scale, 16)), 266 | Activate()]) 267 | 268 | 269 | class DenseEncoderLayer(nn.Module): 270 | def __init__(self, scale, spatial_size, out_size, in_channels=None): 271 | super().__init__() 272 | self.scale = scale 273 | self.in_channels = 64*min(2**(self.scale-1), 16) 274 | if in_channels is not None: 275 | self.in_channels = in_channels 276 | self.out_channels = out_size 277 | self.kernel_size = spatial_size 278 | self.build() 279 | 280 | def forward(self, input): 281 | x = input 282 | for layer in self.sub_layers: 283 | x = layer(x) 284 | return x 285 | 286 | def build(self): 287 | self.sub_layers = nn.ModuleList([ 288 | nn.Conv2d( 289 | in_channels=self.in_channels, 290 | out_channels=self.out_channels, 291 | kernel_size=self.kernel_size, 292 | stride=1, 293 | padding=0, 294 | bias=True)]) 295 | 296 | 297 | class DenseDecoderLayer(nn.Module): 298 | def __init__(self, scale, spatial_size, in_size): 299 | super().__init__() 300 | self.scale = scale 301 | self.in_channels = in_size 302 | self.out_channels = 64*min(2**self.scale, 16) 303 | self.kernel_size = spatial_size 304 | self.build() 305 | 306 | def forward(self, input): 307 | x = input 308 | for layer in self.sub_layers: 309 | x = layer(x) 310 | return x 311 | 312 | def build(self): 313 | self.sub_layers = nn.ModuleList([ 314 | nn.ConvTranspose2d( 315 | in_channels=self.in_channels, 316 | out_channels=self.out_channels, 317 | kernel_size=self.kernel_size, 318 | stride=1, 319 | padding=0, 320 | bias=True)]) 321 | 322 | 323 | class ImageLayer(nn.Module): 324 | def __init__(self, out_channels=3, in_channels=64): 325 | super().__init__() 326 | self.in_channels = in_channels 327 | self.out_channels = out_channels 328 | self.build() 329 | 330 | def forward(self, input): 331 | x = input 332 | for layer in self.sub_layers: 333 | x = layer(x) 334 | return x 335 | 336 | def 
build(self): 337 | FinalActivate = lambda: torch.nn.Tanh() 338 | self.sub_layers = nn.ModuleList([ 339 | nn.ConvTranspose2d( 340 | in_channels=self.in_channels, 341 | out_channels=self.out_channels, 342 | kernel_size=4, 343 | stride=2, 344 | padding=1, 345 | bias=False), 346 | FinalActivate() 347 | ]) 348 | 349 | 350 | class BasicFullyConnectedVAE(nn.Module): 351 | def __init__(self, n_down=2, z_dim=128, in_channels=128, mid_channels=4096, use_bn=False, deterministic=False): 352 | super().__init__() 353 | 354 | self.be_deterministic = deterministic 355 | self.encoder = BasicFullyConnectedNet(dim=in_channels, depth=n_down, 356 | hidden_dim=mid_channels, 357 | out_dim=in_channels, 358 | use_bn=use_bn) 359 | self.mu_layer = BasicFullyConnectedNet(in_channels, depth=n_down, 360 | hidden_dim=mid_channels, 361 | out_dim=z_dim, 362 | use_bn=use_bn) 363 | self.logvar_layer = BasicFullyConnectedNet(in_channels, depth=n_down, 364 | hidden_dim=mid_channels, 365 | out_dim=z_dim, 366 | use_bn=use_bn) 367 | self.decoder = BasicFullyConnectedNet(dim=z_dim, depth=n_down+1, 368 | hidden_dim=mid_channels, 369 | out_dim=in_channels, 370 | use_bn=use_bn) 371 | 372 | def encode(self, x): 373 | h = self.encoder(x) 374 | mu = self.mu_layer(h) 375 | logvar = self.logvar_layer(h) 376 | return DiagonalGaussianDistribution(torch.cat((mu, logvar), dim=1), deterministic=self.be_deterministic) 377 | 378 | def decode(self, x): 379 | x = self.decoder(x) 380 | return x 381 | 382 | def forward(self, x): 383 | x = self.encode(x).sample() 384 | x = self.decoder(x) 385 | return x 386 | 387 | def get_last_layer(self): 388 | return self.decoder.main[-1].weight 389 | -------------------------------------------------------------------------------- /net2net/modules/autoencoder/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from net2net.modules.gan.biggan import load_variable_latsize_generator 5 | 6 | class ClassUp(nn.Module): 7 | def __init__(self, dim, depth, hidden_dim=256, use_sigmoid=False, out_dim=None): 8 | super().__init__() 9 | layers = [] 10 | layers.append(nn.Linear(dim, hidden_dim)) 11 | layers.append(nn.LeakyReLU()) 12 | for d in range(depth): 13 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 14 | layers.append(nn.LeakyReLU()) 15 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim)) 16 | if use_sigmoid: 17 | layers.append(nn.Sigmoid()) 18 | self.main = nn.Sequential(*layers) 19 | 20 | def forward(self, x): 21 | x = self.main(x.squeeze(-1).squeeze(-1)) 22 | x = torch.nn.functional.softmax(x, dim=1) 23 | return x 24 | 25 | 26 | class BigGANDecoderWrapper(nn.Module): 27 | """Wraps a BigGAN into our autoencoding framework""" 28 | def __init__(self, z_dim, in_size=128, use_actnorm_in_dec=False, extra_z_dims=list()): 29 | super().__init__() 30 | self.z_dim = z_dim 31 | class_embedding_dim = 1000 32 | self.extra_z_dims = extra_z_dims 33 | self.map_to_class_embedding = ClassUp(z_dim, depth=2, hidden_dim=2*class_embedding_dim, 34 | use_sigmoid=False, out_dim=class_embedding_dim) 35 | self.decoder = load_variable_latsize_generator(in_size, z_dim, 36 | use_actnorm=use_actnorm_in_dec, 37 | n_class=class_embedding_dim, 38 | extra_z_dims=self.extra_z_dims) 39 | 40 | def forward(self, x, labels=None): 41 | emb = self.map_to_class_embedding(x[:,:self.z_dim,...]) 42 | x = self.decoder(x, emb) 43 | return x -------------------------------------------------------------------------------- 
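As a quick sanity check of the fully connected VAE defined in `net2net/modules/autoencoder/basic.py` above (the class loaded as the `dequant_vae` checkpoint for class-label dequantization), here is a rough shape sketch under the default constructor arguments; the exact sampling behaviour lives in `DiagonalGaussianDistribution`, which is defined in `net2net/modules/distributions/distributions.py` and not shown here:

```
import torch
from net2net.modules.autoencoder.basic import BasicFullyConnectedVAE

vae = BasicFullyConnectedVAE(n_down=2, z_dim=128, in_channels=128)
x = torch.randn(4, 128)        # e.g. a batch of 128-dim class embeddings
posterior = vae.encode(x)      # DiagonalGaussianDistribution built from (mu, logvar)
z = posterior.sample()         # expected shape (4, 128)
xrec = vae.decode(z)           # back to (4, 128)
```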
/net2net/modules/autoencoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import models 5 | 6 | from net2net.modules.autoencoder.basic import ActNorm, DenseEncoderLayer 7 | 8 | 9 | class ResnetEncoder(nn.Module): 10 | def __init__(self, z_dim, in_size, in_channels=3, 11 | pretrained=False, type="resnet50", 12 | double_z=True, pre_process=True, 13 | ): 14 | super().__init__() 15 | __possible_resnets = { 16 | 'resnet18': models.resnet18, 17 | 'resnet34': models.resnet34, 18 | 'resnet50': models.resnet50, 19 | 'resnet101': models.resnet101 20 | } 21 | self.use_preprocess = pre_process 22 | self.in_channels = in_channels 23 | norm_layer = ActNorm 24 | self.z_dim = z_dim 25 | self.model = __possible_resnets[type](pretrained=pretrained, norm_layer=norm_layer) 26 | 27 | self.image_transform = torchvision.transforms.Compose( 28 | [torchvision.transforms.Lambda(self.normscale)] 29 | ) 30 | 31 | size_pre_fc = self.get_spatial_size(in_size) 32 | assert size_pre_fc[2]==size_pre_fc[3], 'Output spatial size is not quadratic' 33 | spatial_size = size_pre_fc[2] 34 | num_channels_pre_fc = size_pre_fc[1] 35 | # replace last fc 36 | self.model.fc = DenseEncoderLayer(0, 37 | spatial_size=spatial_size, 38 | out_size=2*z_dim if double_z else z_dim, 39 | in_channels=num_channels_pre_fc) 40 | if self.in_channels != 3: 41 | self.model.in_ch_match = nn.Conv2d(self.in_channels, 3, 3, 1) 42 | 43 | def forward(self, x): 44 | if self.use_preprocess: 45 | x = self.pre_process(x) 46 | if self.in_channels != 3: 47 | assert not self.use_preprocess 48 | x = self.model.in_ch_match(x) 49 | features = self.features(x) 50 | encoding = self.model.fc(features) 51 | return encoding 52 | 53 | def rescale(self, x): 54 | return 0.5 * (x + 1) 55 | 56 | def normscale(self, image): 57 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 58 | return torch.stack([normalize(self.rescale(x)) for x in image]) 59 | 60 | def features(self, x): 61 | if self.use_preprocess: 62 | x = self.pre_process(x) 63 | x = self.model.conv1(x) 64 | x = self.model.bn1(x) 65 | x = self.model.relu(x) 66 | x = self.model.maxpool(x) 67 | x = self.model.layer1(x) 68 | x = self.model.layer2(x) 69 | x = self.model.layer3(x) 70 | x = self.model.layer4(x) 71 | x = self.model.avgpool(x) 72 | return x 73 | 74 | def post_features(self, x): 75 | x = self.model.fc(x) 76 | return x 77 | 78 | def pre_process(self, x): 79 | x = self.image_transform(x) 80 | return x 81 | 82 | def get_spatial_size(self, ipt_size): 83 | x = torch.randn(1, 3, ipt_size, ipt_size) 84 | return self.features(x).size() 85 | 86 | @property 87 | def mean(self): 88 | return [0.485, 0.456, 0.406] 89 | 90 | @property 91 | def std(self): 92 | return [0.229, 0.224, 0.225] 93 | 94 | @property 95 | def input_size(self): 96 | return [3, 224, 224] 97 | -------------------------------------------------------------------------------- /net2net/modules/autoencoder/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from net2net.modules.autoencoder.lpips import LPIPS # LPIPS loss 6 | from net2net.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | 9 | def adopt_weight(weight, global_step, threshold=0, value=0.): 10 | if global_step < threshold: 11 | weight = value 12 | return weight 13 | 14 | 15 | def 
hinge_d_loss(logits_real, logits_fake): 16 | loss_real = torch.mean(F.relu(1. - logits_real)) 17 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | 22 | def vanilla_d_loss(logits_real, logits_fake): 23 | d_loss = 0.5 * ( 24 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 25 | torch.mean(torch.nn.functional.softplus(logits_fake))) 26 | return d_loss 27 | 28 | 29 | class LPIPSWithDiscriminator(nn.Module): 30 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 31 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 32 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 33 | evo_disc=False, disc_loss="hinge"): 34 | 35 | super().__init__() 36 | assert disc_loss in ["hinge", "vanilla"] 37 | self.kl_weight = kl_weight 38 | self.pixel_weight = pixelloss_weight 39 | self.perceptual_loss = LPIPS().eval() 40 | self.perceptual_weight = perceptual_weight 41 | # output log variance 42 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 43 | 44 | if evo_disc: 45 | self.discriminator = NLayerDiscriminatorEvoNorm(input_nc=disc_in_channels, 46 | n_layers=disc_num_layers 47 | ).apply(weights_init) 48 | else: 49 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 50 | n_layers=disc_num_layers, 51 | use_actnorm=use_actnorm 52 | ).apply(weights_init) 53 | self.discriminator_iter_start = disc_start 54 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 55 | self.disc_factor = disc_factor 56 | self.discriminator_weight = disc_weight 57 | self.disc_conditional = disc_conditional 58 | 59 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 60 | if last_layer is not None: 61 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 62 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 63 | else: 64 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 65 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 66 | 67 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 68 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 69 | d_weight = d_weight * self.discriminator_weight 70 | return d_weight 71 | 72 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 73 | global_step, last_layer=None, cond=None, split="train", 74 | side_outputs=None, weights=None): 75 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 76 | if self.perceptual_weight > 0: 77 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 78 | rec_loss = rec_loss + self.perceptual_weight * p_loss 79 | 80 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 81 | weighted_nll_loss = nll_loss 82 | if weights is not None: 83 | weighted_nll_loss = weights*nll_loss 84 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 85 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 86 | kl_loss = posteriors.kl() 87 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 88 | 89 | # now the GAN part 90 | if optimizer_idx == 0: 91 | # generator update 92 | if cond is None: 93 | assert not self.disc_conditional 94 | logits_fake = self.discriminator(reconstructions.contiguous()) 95 | else: 96 | assert self.disc_conditional 97 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 98 | g_loss = 
-torch.mean(logits_fake) 99 | 100 | if self.disc_factor > 0.0: 101 | try: 102 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 103 | except RuntimeError: 104 | assert not self.training 105 | d_weight = torch.tensor(0.0) 106 | else: 107 | d_weight = torch.tensor(0.0) 108 | 109 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 110 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 111 | 112 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 113 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 114 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 115 | "{}/d_weight".format(split): d_weight.detach(), 116 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 117 | "{}/g_loss".format(split): g_loss.detach().mean(), 118 | } 119 | return loss, log 120 | 121 | if optimizer_idx == 1: 122 | # second pass for discriminator update 123 | if cond is None: 124 | logits_real = self.discriminator(inputs.contiguous().detach()) 125 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 126 | else: 127 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 128 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 129 | 130 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 131 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 132 | 133 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 134 | "{}/logits_real".format(split): logits_real.detach().mean(), 135 | "{}/logits_fake".format(split): logits_fake.detach().mean() 136 | } 137 | return d_loss, log 138 | 139 | class DummyLoss: 140 | pass -------------------------------------------------------------------------------- /net2net/modules/autoencoder/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from net2net.ckpt_util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "net2net/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name is not "vgg_lpips": 35 | 
raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | -------------------------------------------------------------------------------- /net2net/modules/captions/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/captions/__init__.py -------------------------------------------------------------------------------- /net2net/modules/captions/model.py: -------------------------------------------------------------------------------- 1 | """Code is based on https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning""" 2 | 3 | import os, sys 4 | import json 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import torchvision 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from PIL import Image 13 | 14 | from net2net.ckpt_util import get_ckpt_path 15 | 16 | #import warnings 17 | #warnings.filterwarnings("ignore") 18 | 19 | from net2net.modules.captions.models import Encoder, DecoderWithAttention 20 | 21 | 22 | rescale = lambda x: 0.5*(x+1) 23 | 24 | 25 | def imresize(img, size): 26 | return np.array(Image.fromarray(img).resize(size)) 27 | 28 | 29 | class Img2Text(nn.Module): 30 | def __init__(self): 31 | super().__init__() 32 | model_path = get_ckpt_path("coco_captioner", "net2net/modules/captions") 33 | word_map_path = "data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json" 34 | 35 | # Load word map (word2ix) 36 | with open(word_map_path, 'r') as j: 37 | word_map = json.load(j) 38 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 39 | self.word_map = word_map 40 | self.rev_word_map = rev_word_map 41 | 42 | checkpoint = torch.load(model_path) 43 | 44 | self.encoder = Encoder() 45 | self.decoder = DecoderWithAttention(embed_dim=512, decoder_dim=512, attention_dim=512, vocab_size=9490) 46 | missing, unexpected = self.load_state_dict(checkpoint, strict=False) 47 | if len(missing) > 0: 48 | print(f"Missing keys in state-dict: {missing}") 49 | if len(unexpected) > 0: 50 | print(f"Unexpected keys in state-dict: {unexpected}") 51 | self.encoder.eval() 52 | self.decoder.eval() 53 | 54 | resize = transforms.Lambda(lambda image: F.interpolate(image, size=(256, 256), mode="bilinear")) 55 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 56 | norm = torchvision.transforms.Lambda(lambda image: torch.stack([normalize(rescale(x)) for x in image])) 57 | self.img_transform = transforms.Compose([resize, norm]) 58 | self.device = "cuda" 59 | 60 | def _pre_process(self, x): 61 | x = self.img_transform(x) 62 | return x 63 | 64 | @property 65 | def mean(self): 66 | return [0.485, 0.456, 0.406] 67 | 68 | @property 69 | def std(self): 70 | return [0.229, 0.224, 0.225] 71 | 72 | def forward(self, x): 73 | captions = list() 74 | for subx in x: 75 | subx = subx.unsqueeze(0) 76 | captions.append(self.make_single_caption(subx)) 77 | return captions 78 | 79 | def make_single_caption(self, x): 80 | seq = self.caption_image_beam_search(x)[0][0] 81 | words = [self.rev_word_map[ind] for ind in seq] 82 | words = words[:50] 83 | #if len(words) > 50: 84 | # return np.array(['']) 85 | text = '' 86 | for word in words: 87 | text += word + ' ' 88 | return text 89 | 90 | def caption_image_beam_search(self, image, beam_size=3): 91 | """ 92 | Reads a batch of images and captions each of it with beam search. 
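Each image in the batch is decoded independently with its own beam of size beam_size; decoding stops once every beam has produced an end token or after 50 decode steps.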
93 | :param image: batch of pytorch images 94 | :param beam_size: number of sequences to consider at each decode-step 95 | :return: caption, weights for visualization 96 | """ 97 | 98 | k = beam_size 99 | vocab_size = len(self.word_map) 100 | 101 | # Encode 102 | # image is a batch of images 103 | encoder_out_ = self.encoder(image) # (b, enc_image_size, enc_image_size, encoder_dim) 104 | enc_image_size = encoder_out_.size(1) 105 | encoder_dim = encoder_out_.size(3) 106 | batch_size = encoder_out_.size(0) 107 | 108 | # Flatten encoding 109 | encoder_out_ = encoder_out_.view(batch_size, -1, encoder_dim) # (1, num_pixels, encoder_dim) 110 | num_pixels = encoder_out_.size(1) 111 | 112 | sequences = list() 113 | alphas_ = list() 114 | # We'll treat the problem as having a batch size of k per example 115 | for single_example in encoder_out_: 116 | single_example = single_example[None, ...] 117 | encoder_out = single_example.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 118 | 119 | # Tensor to store top k previous words at each step; now they're just 120 | k_prev_words = torch.LongTensor([[self.word_map['']]] * k).to(self.device) # (k, 1) 121 | 122 | # Tensor to store top k sequences; now they're just 123 | seqs = k_prev_words # (k, 1) 124 | 125 | # Tensor to store top k sequences' scores; now they're just 0 126 | top_k_scores = torch.zeros(k, 1).to(self.device) # (k, 1) 127 | 128 | # Tensor to store top k sequences' alphas; now they're just 1s 129 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(self.device) # (k, 1, enc_image_size, enc_image_size) 130 | 131 | # Lists to store completed sequences, their alphas and scores 132 | complete_seqs = list() 133 | complete_seqs_alpha = list() 134 | complete_seqs_scores = list() 135 | 136 | # Start decoding 137 | step = 1 138 | h, c = self.decoder.init_hidden_state(encoder_out) 139 | 140 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 141 | while True: 142 | embeddings = self.decoder.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 143 | awe, alpha = self.decoder.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 144 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 145 | gate = self.decoder.sigmoid(self.decoder.f_beta(h)) # gating scalar, (s, encoder_dim) 146 | awe = gate * awe 147 | h, c = self.decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 148 | scores = self.decoder.fc(h) # (s, vocab_size) 149 | scores = F.log_softmax(scores, dim=1) 150 | # Add 151 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 152 | 153 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 154 | if step == 1: 155 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 156 | else: 157 | # Unroll and find top scores, and their unrolled indices 158 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 159 | 160 | # Convert unrolled indices to actual indices of scores 161 | prev_word_inds = top_k_words // vocab_size # (s) 162 | next_word_inds = top_k_words % vocab_size # (s) 163 | 164 | # Add new words to sequences, alphas 165 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 166 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 167 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 168 | 169 | # Which sequences 
are incomplete (didn't reach )? 170 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 171 | next_word != self.word_map['']] 172 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 173 | 174 | # Set aside complete sequences 175 | if len(complete_inds) > 0: 176 | complete_seqs.extend(seqs[complete_inds].tolist()) 177 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 178 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 179 | k -= len(complete_inds) # reduce beam length accordingly 180 | 181 | # Proceed with incomplete sequences 182 | if k == 0: 183 | break 184 | seqs = seqs[incomplete_inds] 185 | seqs_alpha = seqs_alpha[incomplete_inds] 186 | h = h[prev_word_inds[incomplete_inds]] 187 | c = c[prev_word_inds[incomplete_inds]] 188 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 189 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 190 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 191 | 192 | # Break if things have been going on too long 193 | if step > 50: 194 | break 195 | step += 1 196 | 197 | try: 198 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 199 | seq = complete_seqs[i] 200 | alphas = complete_seqs_alpha[i] 201 | except ValueError: 202 | print("Catching an empty sequence.") 203 | try: 204 | len_ = len(sequences[-1]) 205 | seq = [0]*len_ 206 | alphas = None 207 | except: 208 | seq = [0]*9 209 | alphas = None 210 | 211 | sequences.append(seq) 212 | alphas_.append(alphas) 213 | 214 | return sequences, alphas_ 215 | 216 | def visualize_text(self, root, images, sequences, n_row=5, img_name='examples'): 217 | """ 218 | plot the text corresponding to the given images in a matplotlib figure. 219 | images are a batch of pytorch images 220 | """ 221 | 222 | n_img = images.size(0) 223 | n_col = max(n_img // n_row + 1, 2) 224 | 225 | fig, ax = plt.subplots(n_row, n_col) 226 | 227 | i = 0 228 | j = 0 229 | for image, seq in zip(images, sequences): 230 | if i == n_row: 231 | i = 0 232 | j += 1 233 | image = image.cpu().numpy().transpose(1, 2, 0) 234 | image = 255*(0.5*(image+1)) 235 | image = Image.fromarray(image.astype('uint8')) 236 | image = image.resize([14 * 24, 14 * 24], Image.LANCZOS) 237 | words = [self.rev_word_map[ind] for ind in seq] 238 | if len(words) > 50: 239 | return 240 | text = '' 241 | for word in words: 242 | text += word + ' ' 243 | 244 | ax[i, j].text(0, 1, '%s' % (text), color='black', backgroundcolor='white', fontsize=12) 245 | ax[i, j].imshow(image) 246 | ax[i, j].axis('off') 247 | 248 | plt.savefig(os.path.join(root, img_name + '.png')) 249 | 250 | 251 | if __name__ == '__main__': 252 | model = Img2Text() 253 | print("done.") 254 | -------------------------------------------------------------------------------- /net2net/modules/captions/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class Encoder(nn.Module): 9 | """ 10 | Encoder. 
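Encodes each image with a pretrained ImageNet ResNet-101 (classification head and final pooling removed) and adaptively pools the result to a fixed encoded_image_size x encoded_image_size feature map.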
11 | """ 12 | 13 | def __init__(self, encoded_image_size=14): 14 | super(Encoder, self).__init__() 15 | self.enc_image_size = encoded_image_size 16 | 17 | resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 18 | 19 | # Remove linear and pool layers (since we're not doing classification) 20 | modules = list(resnet.children())[:-2] 21 | self.resnet = nn.Sequential(*modules) 22 | 23 | # Resize image to fixed size to allow input images of variable size 24 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 25 | 26 | self.fine_tune() 27 | 28 | def forward(self, images): 29 | """ 30 | Forward propagation. 31 | 32 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 33 | :return: encoded images 34 | """ 35 | out = self.resnet(images) # (batch_size, 2048, image_size/32, image_size/32) 36 | out = self.adaptive_pool(out) # (batch_size, 2048, encoded_image_size, encoded_image_size) 37 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 2048) 38 | return out 39 | 40 | def fine_tune(self, fine_tune=True): 41 | """ 42 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 43 | 44 | :param fine_tune: Allow? 45 | """ 46 | for p in self.resnet.parameters(): 47 | p.requires_grad = False 48 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 49 | for c in list(self.resnet.children())[5:]: 50 | for p in c.parameters(): 51 | p.requires_grad = fine_tune 52 | 53 | 54 | class Attention(nn.Module): 55 | """ 56 | Attention Network. 57 | """ 58 | 59 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 60 | """ 61 | :param encoder_dim: feature size of encoded images 62 | :param decoder_dim: size of decoder's RNN 63 | :param attention_dim: size of the attention network 64 | """ 65 | super(Attention, self).__init__() 66 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 67 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 68 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 69 | self.relu = nn.ReLU() 70 | #self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 71 | 72 | def forward(self, encoder_out, decoder_hidden): 73 | """ 74 | Forward propagation. 75 | 76 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 77 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 78 | :return: attention weighted encoding, weights 79 | """ 80 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 81 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 82 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 83 | #alpha = self.softmax(att) # (batch_size, num_pixels) 84 | alpha = torch.nn.functional.softmax(att, dim=1) 85 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 86 | 87 | return attention_weighted_encoding, alpha 88 | 89 | 90 | class DecoderWithAttention(nn.Module): 91 | """ 92 | Decoder. 
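Generates the caption word by word with an LSTM cell that, at each step, attends over the encoded image regions and feeds the gated, attention-weighted encoding together with the previous word embedding into the cell.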
93 | """ 94 | 95 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5): 96 | """ 97 | :param attention_dim: size of attention network 98 | :param embed_dim: embedding size 99 | :param decoder_dim: size of decoder's RNN 100 | :param vocab_size: size of vocabulary 101 | :param encoder_dim: feature size of encoded images 102 | :param dropout: dropout 103 | """ 104 | super(DecoderWithAttention, self).__init__() 105 | 106 | self.encoder_dim = encoder_dim 107 | self.attention_dim = attention_dim 108 | self.embed_dim = embed_dim 109 | self.decoder_dim = decoder_dim 110 | self.vocab_size = vocab_size 111 | self.dropout = dropout 112 | 113 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 114 | 115 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 116 | self.dropout = nn.Dropout(p=self.dropout) 117 | self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 118 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 119 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 120 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 121 | self.sigmoid = nn.Sigmoid() 122 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 123 | self.init_weights() # initialize some layers with the uniform distribution 124 | 125 | def init_weights(self): 126 | """ 127 | Initializes some parameters with values from the uniform distribution, for easier convergence. 128 | """ 129 | self.embedding.weight.data.uniform_(-0.1, 0.1) 130 | self.fc.bias.data.fill_(0) 131 | self.fc.weight.data.uniform_(-0.1, 0.1) 132 | 133 | def load_pretrained_embeddings(self, embeddings): 134 | """ 135 | Loads embedding layer with pre-trained embeddings. 136 | 137 | :param embeddings: pre-trained embeddings 138 | """ 139 | self.embedding.weight = nn.Parameter(embeddings) 140 | 141 | def fine_tune_embeddings(self, fine_tune=True): 142 | """ 143 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 144 | 145 | :param fine_tune: Allow? 146 | """ 147 | for p in self.embedding.parameters(): 148 | p.requires_grad = fine_tune 149 | 150 | def init_hidden_state(self, encoder_out): 151 | """ 152 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 153 | 154 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 155 | :return: hidden state, cell state 156 | """ 157 | mean_encoder_out = encoder_out.mean(dim=1) 158 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 159 | c = self.init_c(mean_encoder_out) 160 | return h, c 161 | 162 | def forward(self, encoder_out, encoded_captions, caption_lengths): 163 | """ 164 | Forward propagation. 
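Uses teacher forcing: the ground-truth captions are embedded and fed to the LSTM, and the batch is sorted by decreasing caption length so that finished sequences can be dropped at later time steps.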
165 | 166 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 167 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 168 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 169 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 170 | """ 171 | 172 | batch_size = encoder_out.size(0) 173 | encoder_dim = encoder_out.size(-1) 174 | vocab_size = self.vocab_size 175 | 176 | # Flatten image 177 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 178 | num_pixels = encoder_out.size(1) 179 | 180 | # Sort input data by decreasing lengths; why? apparent below 181 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 182 | encoder_out = encoder_out[sort_ind] 183 | encoded_captions = encoded_captions[sort_ind] 184 | 185 | # Embedding 186 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 187 | 188 | # Initialize LSTM state 189 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 190 | 191 | # We won't decode at the position, since we've finished generating as soon as we generate 192 | # So, decoding lengths are actual lengths - 1 193 | decode_lengths = (caption_lengths - 1).tolist() 194 | 195 | # Create tensors to hold word predicion scores and alphas 196 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 197 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 198 | 199 | # At each time-step, decode by 200 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 201 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 202 | for t in range(max(decode_lengths)): 203 | batch_size_t = sum([l > t for l in decode_lengths]) 204 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 205 | h[:batch_size_t]) 206 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 207 | attention_weighted_encoding = gate * attention_weighted_encoding 208 | h, c = self.decode_step( 209 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 210 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 211 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 212 | predictions[:batch_size_t, t, :] = preds 213 | alphas[:batch_size_t, t, :] = alpha 214 | 215 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind 216 | -------------------------------------------------------------------------------- /net2net/modules/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/discriminator/__init__.py -------------------------------------------------------------------------------- /net2net/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from net2net.modules.autoencoder.basic import ActNorm 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif 
classname.find('BatchNorm') != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class NLayerDiscriminator(nn.Module): 17 | """Defines a PatchGAN discriminator as in Pix2Pix 18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 19 | """ 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm2d 31 | else: 32 | norm_layer = ActNorm 33 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 34 | use_bias = norm_layer.func != nn.BatchNorm2d 35 | else: 36 | use_bias = norm_layer != nn.BatchNorm2d 37 | 38 | kw = 4 39 | padw = 1 40 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 41 | nf_mult = 1 42 | nf_mult_prev = 1 43 | for n in range(1, n_layers): # gradually increase the number of filters 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2 ** n_layers, 8) 54 | sequence += [ 55 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 56 | norm_layer(ndf * nf_mult), 57 | nn.LeakyReLU(0.2, True) 58 | ] 59 | 60 | sequence += [ 61 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 62 | self.main = nn.Sequential(*sequence) 63 | 64 | def forward(self, input): 65 | """Standard forward.""" 66 | return self.main(input) 67 | -------------------------------------------------------------------------------- /net2net/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/distributions/__init__.py -------------------------------------------------------------------------------- /net2net/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = 
torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=[1, 2, 3]) 60 | 61 | def mode(self): 62 | return self.mean 63 | -------------------------------------------------------------------------------- /net2net/modules/facenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/facenet/__init__.py -------------------------------------------------------------------------------- /net2net/modules/facenet/inception_resnet_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import requests 5 | from requests.adapters import HTTPAdapter 6 | import os 7 | 8 | 9 | class BasicConv2d(nn.Module): 10 | 11 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 12 | super().__init__() 13 | self.conv = nn.Conv2d( 14 | in_planes, out_planes, 15 | kernel_size=kernel_size, stride=stride, 16 | padding=padding, bias=False 17 | ) # verify bias false 18 | self.bn = nn.BatchNorm2d( 19 | out_planes, 20 | eps=0.001, # value found in tensorflow 21 | momentum=0.1, # default pytorch value 22 | affine=True 23 | ) 24 | self.relu = nn.ReLU(inplace=False) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | x = self.relu(x) 30 | return x 31 | 32 | 33 | class Block35(nn.Module): 34 | 35 | def __init__(self, scale=1.0): 36 | super().__init__() 37 | 38 | self.scale = scale 39 | 40 | self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1) 41 | 42 | self.branch1 = nn.Sequential( 43 | BasicConv2d(256, 32, kernel_size=1, stride=1), 44 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 45 | ) 46 | 47 | self.branch2 = nn.Sequential( 48 | BasicConv2d(256, 32, kernel_size=1, stride=1), 49 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1), 50 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 51 | ) 52 | 53 | self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1) 54 | self.relu = nn.ReLU(inplace=False) 55 | 56 | def forward(self, x): 57 | x0 = self.branch0(x) 58 | x1 = self.branch1(x) 59 | x2 = self.branch2(x) 60 | out = torch.cat((x0, x1, x2), 1) 61 | out = self.conv2d(out) 62 | out = out * self.scale + x 63 | out = self.relu(out) 64 | return out 65 | 66 | 67 | class Block17(nn.Module): 68 | 69 | def __init__(self, scale=1.0): 70 | super().__init__() 71 | 72 | self.scale = scale 73 | 74 | self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1) 75 | 76 | self.branch1 = nn.Sequential( 77 | BasicConv2d(896, 128, kernel_size=1, stride=1), 78 | BasicConv2d(128, 128, kernel_size=(1,7), stride=1, 
padding=(0,3)), 79 | BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0)) 80 | ) 81 | 82 | self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1) 83 | self.relu = nn.ReLU(inplace=False) 84 | 85 | def forward(self, x): 86 | x0 = self.branch0(x) 87 | x1 = self.branch1(x) 88 | out = torch.cat((x0, x1), 1) 89 | out = self.conv2d(out) 90 | out = out * self.scale + x 91 | out = self.relu(out) 92 | return out 93 | 94 | 95 | class Block8(nn.Module): 96 | 97 | def __init__(self, scale=1.0, noReLU=False): 98 | super().__init__() 99 | 100 | self.scale = scale 101 | self.noReLU = noReLU 102 | 103 | self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1) 104 | 105 | self.branch1 = nn.Sequential( 106 | BasicConv2d(1792, 192, kernel_size=1, stride=1), 107 | BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)), 108 | BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0)) 109 | ) 110 | 111 | self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1) 112 | if not self.noReLU: 113 | self.relu = nn.ReLU(inplace=False) 114 | 115 | def forward(self, x): 116 | x0 = self.branch0(x) 117 | x1 = self.branch1(x) 118 | out = torch.cat((x0, x1), 1) 119 | out = self.conv2d(out) 120 | out = out * self.scale + x 121 | if not self.noReLU: 122 | out = self.relu(out) 123 | return out 124 | 125 | 126 | class Mixed_6a(nn.Module): 127 | 128 | def __init__(self): 129 | super().__init__() 130 | 131 | self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2) 132 | 133 | self.branch1 = nn.Sequential( 134 | BasicConv2d(256, 192, kernel_size=1, stride=1), 135 | BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1), 136 | BasicConv2d(192, 256, kernel_size=3, stride=2) 137 | ) 138 | 139 | self.branch2 = nn.MaxPool2d(3, stride=2) 140 | 141 | def forward(self, x): 142 | x0 = self.branch0(x) 143 | x1 = self.branch1(x) 144 | x2 = self.branch2(x) 145 | out = torch.cat((x0, x1, x2), 1) 146 | return out 147 | 148 | 149 | class Mixed_7a(nn.Module): 150 | 151 | def __init__(self): 152 | super().__init__() 153 | 154 | self.branch0 = nn.Sequential( 155 | BasicConv2d(896, 256, kernel_size=1, stride=1), 156 | BasicConv2d(256, 384, kernel_size=3, stride=2) 157 | ) 158 | 159 | self.branch1 = nn.Sequential( 160 | BasicConv2d(896, 256, kernel_size=1, stride=1), 161 | BasicConv2d(256, 256, kernel_size=3, stride=2) 162 | ) 163 | 164 | self.branch2 = nn.Sequential( 165 | BasicConv2d(896, 256, kernel_size=1, stride=1), 166 | BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 167 | BasicConv2d(256, 256, kernel_size=3, stride=2) 168 | ) 169 | 170 | self.branch3 = nn.MaxPool2d(3, stride=2) 171 | 172 | def forward(self, x): 173 | x0 = self.branch0(x) 174 | x1 = self.branch1(x) 175 | x2 = self.branch2(x) 176 | x3 = self.branch3(x) 177 | out = torch.cat((x0, x1, x2, x3), 1) 178 | return out 179 | 180 | 181 | class InceptionResnetV1(nn.Module): 182 | """Inception Resnet V1 model with optional loading of pretrained weights. 183 | 184 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 185 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 186 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 187 | redownloading. 188 | 189 | Keyword Arguments: 190 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 191 | (default: {None}) 192 | classify {bool} -- Whether the model should output classification probabilities or feature 193 | embeddings. 
(default: {False}) 194 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 195 | equal to that used for the pretrained model, the final linear layer will be randomly 196 | initialized. (default: {None}) 197 | dropout_prob {float} -- Dropout probability. (default: {0.6}) 198 | """ 199 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None): 200 | super().__init__() 201 | 202 | # Set simple attributes 203 | self.pretrained = pretrained 204 | self.classify = classify 205 | self.num_classes = num_classes 206 | 207 | if pretrained == 'vggface2': 208 | tmp_classes = 8631 209 | elif pretrained == 'casia-webface': 210 | tmp_classes = 10575 211 | elif pretrained is None and self.num_classes is None: 212 | raise Exception('At least one of "pretrained" or "num_classes" must be specified') 213 | else: 214 | tmp_classes = self.num_classes 215 | 216 | 217 | # Define layers 218 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 219 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 220 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 221 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 222 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 223 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 224 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 225 | self.repeat_1 = nn.Sequential( 226 | Block35(scale=0.17), 227 | Block35(scale=0.17), 228 | Block35(scale=0.17), 229 | Block35(scale=0.17), 230 | Block35(scale=0.17), 231 | ) 232 | self.mixed_6a = Mixed_6a() 233 | self.repeat_2 = nn.Sequential( 234 | Block17(scale=0.10), 235 | Block17(scale=0.10), 236 | Block17(scale=0.10), 237 | Block17(scale=0.10), 238 | Block17(scale=0.10), 239 | Block17(scale=0.10), 240 | Block17(scale=0.10), 241 | Block17(scale=0.10), 242 | Block17(scale=0.10), 243 | Block17(scale=0.10), 244 | ) 245 | self.mixed_7a = Mixed_7a() 246 | self.repeat_3 = nn.Sequential( 247 | Block8(scale=0.20), 248 | Block8(scale=0.20), 249 | Block8(scale=0.20), 250 | Block8(scale=0.20), 251 | Block8(scale=0.20), 252 | ) 253 | self.block8 = Block8(noReLU=True) 254 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 255 | self.dropout = nn.Dropout(dropout_prob) 256 | self.last_linear = nn.Linear(1792, 512, bias=False) 257 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 258 | self.logits = nn.Linear(512, tmp_classes) 259 | 260 | if pretrained is not None: 261 | load_weights(self, pretrained) 262 | 263 | if self.num_classes is not None: 264 | self.logits = nn.Linear(512, self.num_classes) 265 | 266 | self.device = torch.device('cpu') 267 | if device is not None: 268 | self.device = device 269 | self.to(device) 270 | 271 | def forward(self, x): 272 | """Calculate embeddings or probabilities given a batch of input image tensors. 273 | 274 | Arguments: 275 | x {torch.tensor} -- Batch of image tensors representing faces. 276 | 277 | Returns: 278 | torch.tensor -- Batch of embeddings or softmax probabilities. 
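Embeddings are L2-normalized 512-dimensional vectors; with classify=True the final linear layer returns num_classes class scores (logits) instead.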
279 | """ 280 | x = self.conv2d_1a(x) 281 | x = self.conv2d_2a(x) 282 | x = self.conv2d_2b(x) 283 | x = self.maxpool_3a(x) 284 | x = self.conv2d_3b(x) 285 | x = self.conv2d_4a(x) 286 | x = self.conv2d_4b(x) 287 | x = self.repeat_1(x) 288 | x = self.mixed_6a(x) 289 | x = self.repeat_2(x) 290 | x = self.mixed_7a(x) 291 | x = self.repeat_3(x) 292 | x = self.block8(x) 293 | x = self.avgpool_1a(x) 294 | x = self.dropout(x) 295 | x = self.last_linear(x.view(x.shape[0], -1)) 296 | x = self.last_bn(x) 297 | x = F.normalize(x, p=2, dim=1) 298 | if self.classify: 299 | x = self.logits(x) 300 | return x 301 | 302 | 303 | def load_weights(mdl, name): 304 | """Download pretrained state_dict and load into model. 305 | 306 | Arguments: 307 | mdl {torch.nn.Module} -- Pytorch model. 308 | name {str} -- Name of dataset that was used to generate pretrained state_dict. 309 | 310 | Raises: 311 | ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'. 312 | """ 313 | if name == 'vggface2': 314 | features_path = 'https://drive.google.com/uc?export=download&id=1cWLH_hPns8kSfMz9kKl9PsG5aNV2VSMn' 315 | logits_path = 'https://drive.google.com/uc?export=download&id=1mAie3nzZeno9UIzFXvmVZrDG3kwML46X' 316 | elif name == 'casia-webface': 317 | features_path = 'https://drive.google.com/uc?export=download&id=1LSHHee_IQj5W3vjBcRyVaALv4py1XaGy' 318 | logits_path = 'https://drive.google.com/uc?export=download&id=1QrhPgn1bGlDxAil2uc07ctunCQoDnCzT' 319 | else: 320 | raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"') 321 | 322 | model_dir = os.path.join(get_torch_home(), 'checkpoints') 323 | os.makedirs(model_dir, exist_ok=True) 324 | 325 | state_dict = {} 326 | for i, path in enumerate([features_path, logits_path]): 327 | cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:])) 328 | if not os.path.exists(cached_file): 329 | print('Downloading parameters ({}/2)'.format(i+1)) 330 | s = requests.Session() 331 | s.mount('https://', HTTPAdapter(max_retries=10)) 332 | r = s.get(path, allow_redirects=True) 333 | with open(cached_file, 'wb') as f: 334 | f.write(r.content) 335 | state_dict.update(torch.load(cached_file)) 336 | 337 | mdl.load_state_dict(state_dict) 338 | 339 | 340 | def get_torch_home(): 341 | torch_home = os.path.expanduser( 342 | os.getenv( 343 | 'TORCH_HOME', 344 | os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch') 345 | ) 346 | ) 347 | return torch_home 348 | -------------------------------------------------------------------------------- /net2net/modules/facenet/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from net2net.modules.facenet.inception_resnet_v1 import InceptionResnetV1 6 | 7 | 8 | """FaceNet adopted from https://github.com/timesler/facenet-pytorch""" 9 | 10 | 11 | class FaceNet(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | # InceptionResnetV1 has a bottleneck of size 512 15 | self.net = InceptionResnetV1(pretrained='vggface2').eval() 16 | 17 | def _pre_process(self, x): 18 | # TODO: neccessary for InceptionResnetV1? 
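# note: no resizing or standardization is applied here; inputs are assumed to already be face crops scaled to roughly [-1, 1], matching what facenet-pytorch's fixed image standardization would produce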
19 | # seems like mtcnn (multi-task cnn) preprocessing is neccessary, but not 100% sure 20 | return x 21 | 22 | def forward(self, x, return_logits=False): 23 | # output are logits of size 8631 or embeddings of size 512 24 | x = self._pre_process(x) 25 | emb = self.net(x) 26 | if return_logits: 27 | return self.net.logits(emb) 28 | return emb 29 | 30 | def encode(self, x): 31 | return self(x) 32 | 33 | def return_features(self, x): 34 | """ returned features have the following dimensions: 35 | 36 | torch.Size([11, 3, 128, 128]), x 49152 37 | torch.Size([11, 192, 28, 28]), x 150528 38 | torch.Size([11, 896, 6, 6]), x 32256 39 | torch.Size([11, 1792, 1, 1]), x 1792 40 | torch.Size([11, 512]) x 512 41 | logits (8xxx) x 8xxx 42 | """ 43 | 44 | x = self._pre_process(x) 45 | features = [x] # this 46 | x = self.net.conv2d_1a(x) 47 | x = self.net.conv2d_2a(x) 48 | x = self.net.conv2d_2b(x) 49 | x = self.net.maxpool_3a(x) 50 | x = self.net.conv2d_3b(x) 51 | x = self.net.conv2d_4a(x) 52 | features.append(x) # this 53 | x = self.net.conv2d_4b(x) 54 | x = self.net.repeat_1(x) 55 | x = self.net.mixed_6a(x) 56 | features.append(x) # this 57 | x = self.net.repeat_2(x) 58 | x = self.net.mixed_7a(x) 59 | x = self.net.repeat_3(x) 60 | x = self.net.block8(x) 61 | x = self.net.avgpool_1a(x) 62 | features.append(x) # this 63 | x = self.net.dropout(x) 64 | x = self.net.last_linear(x.view(x.shape[0], -1)) 65 | x = self.net.last_bn(x) 66 | emb = F.normalize(x, p=2, dim=1) # the final embeddings 67 | features.append(emb[..., None, None]) # need extra dimensions for flow later 68 | features.append(self.net.logits(emb).unsqueeze(-1).unsqueeze(-1)) 69 | return features # has 6 elements as of now 70 | 71 | -------------------------------------------------------------------------------- /net2net/modules/flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/flow/__init__.py -------------------------------------------------------------------------------- /net2net/modules/flow/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class NormalizingFlow(nn.Module): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__() 7 | 8 | def forward(self, *args, **kwargs): 9 | # return transformed, logdet 10 | raise NotImplementedError 11 | 12 | def reverse(self, *args, **kwargs): 13 | # return transformed_reverse 14 | raise NotImplementedError 15 | 16 | def sample(self, *args, **kwargs): 17 | # return sample 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /net2net/modules/flow/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | 5 | from net2net.modules.autoencoder.basic import ActNorm, BasicFullyConnectedNet 6 | 7 | 8 | class Flow(nn.Module): 9 | def __init__(self, module_list, in_channels, hidden_dim, hidden_depth): 10 | super(Flow, self).__init__() 11 | self.in_channels = in_channels 12 | self.flow = nn.ModuleList( 13 | [module(in_channels, hidden_dim=hidden_dim, depth=hidden_depth) for module in module_list]) 14 | 15 | def forward(self, x, condition=None, reverse=False): 16 | if not reverse: 17 | logdet = 0 18 | for i in range(len(self.flow)): 19 | x, logdet_ = self.flow[i](x) 20 | logdet = logdet + logdet_ 21 | return x, 
logdet 22 | else: 23 | for i in reversed(range(len(self.flow))): 24 | x = self.flow[i](x, reverse=True) 25 | return x 26 | 27 | 28 | class UnconditionalFlatDoubleCouplingFlowBlock(nn.Module): 29 | def __init__(self, in_channels, hidden_dim, hidden_depth): 30 | super().__init__() 31 | self.norm_layer = ActNorm(in_channels, logdet=True) 32 | self.coupling = DoubleVectorCouplingBlock(in_channels, 33 | hidden_dim, 34 | hidden_depth) 35 | self.shuffle = Shuffle(in_channels) 36 | 37 | def forward(self, x, reverse=False): 38 | if not reverse: 39 | h = x 40 | logdet = 0.0 41 | h, ld = self.norm_layer(h) 42 | logdet += ld 43 | h, ld = self.coupling(h) 44 | logdet += ld 45 | h, ld = self.shuffle(h) 46 | logdet += ld 47 | return h, logdet 48 | else: 49 | h = x 50 | h = self.shuffle(h, reverse=True) 51 | h = self.coupling(h, reverse=True) 52 | h = self.norm_layer(h, reverse=True) 53 | return h 54 | 55 | def reverse(self, out): 56 | return self.forward(out, reverse=True) 57 | 58 | 59 | class PureAffineDoubleCouplingFlowBlock(nn.Module): 60 | def __init__(self, in_channels, hidden_dim, hidden_depth): 61 | super().__init__() 62 | self.coupling = DoubleVectorCouplingBlock(in_channels, 63 | hidden_dim, 64 | hidden_depth) 65 | 66 | def forward(self, x, reverse=False): 67 | if not reverse: 68 | h = x 69 | logdet = 0.0 70 | h, ld = self.coupling(h) 71 | logdet += ld 72 | return h, logdet 73 | else: 74 | h = x 75 | h = self.coupling(h, reverse=True) 76 | return h 77 | 78 | def reverse(self, out): 79 | return self.forward(out, reverse=True) 80 | 81 | 82 | class ConditionalFlow(nn.Module): 83 | """Flat version. Feeds an embedding into the flow in every block""" 84 | def __init__(self, in_channels, embedding_dim, hidden_dim, hidden_depth, 85 | n_flows, conditioning_option="none", activation='lrelu'): 86 | super().__init__() 87 | self.in_channels = in_channels 88 | self.cond_channels = embedding_dim 89 | self.mid_channels = hidden_dim 90 | self.num_blocks = hidden_depth 91 | self.n_flows = n_flows 92 | self.conditioning_option = conditioning_option 93 | 94 | self.sub_layers = nn.ModuleList() 95 | if self.conditioning_option.lower() != "none": 96 | self.conditioning_layers = nn.ModuleList() 97 | for flow in range(self.n_flows): 98 | self.sub_layers.append(ConditionalFlatDoubleCouplingFlowBlock( 99 | self.in_channels, self.cond_channels, self.mid_channels, 100 | self.num_blocks, activation=activation) 101 | ) 102 | if self.conditioning_option.lower() != "none": 103 | self.conditioning_layers.append(nn.Conv2d(self.cond_channels, self.cond_channels, 1)) 104 | 105 | def forward(self, x, embedding, reverse=False): 106 | hconds = list() 107 | hcond = embedding[:, :, None, None] 108 | for i in range(self.n_flows): 109 | if self.conditioning_option.lower() == "parallel": 110 | hcond = self.conditioning_layers[i](embedding) 111 | elif self.conditioning_option.lower() == "sequential": 112 | hcond = self.conditioning_layers[i](hcond) 113 | hconds.append(hcond) 114 | if not reverse: 115 | logdet = 0.0 116 | for i in range(self.n_flows): 117 | x, logdet_ = self.sub_layers[i](x, hconds[i]) 118 | logdet = logdet + logdet_ 119 | return x, logdet 120 | else: 121 | for i in reversed(range(self.n_flows)): 122 | x = self.sub_layers[i](x, hconds[i], reverse=True) 123 | return x 124 | 125 | def reverse(self, out, xcond): 126 | return self(out, xcond, reverse=True) 127 | 128 | 129 | class DoubleVectorCouplingBlock(nn.Module): 130 | """Support uneven inputs""" 131 | def __init__(self, in_channels, hidden_dim, hidden_depth=2): 132 | 
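# the input vector is split into two halves (the first half keeps the extra channel when in_channels is odd);
# that first half conditions an affine transform of the second half: the s-networks predict tanh-bounded
# log-scales and the t-networks predict additive shifts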
super().__init__() 133 | dim1 = (in_channels // 2) + (in_channels % 2) 134 | dim2 = in_channels // 2 135 | self.s = nn.ModuleList([ 136 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth, 137 | hidden_dim=hidden_dim, use_tanh=True), 138 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth, 139 | hidden_dim=hidden_dim, use_tanh=True), 140 | ]) 141 | self.t = nn.ModuleList([ 142 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth, 143 | hidden_dim=hidden_dim, use_tanh=False), 144 | BasicFullyConnectedNet(dim=dim1, out_dim=dim2, depth=hidden_depth, 145 | hidden_dim=hidden_dim, use_tanh=False), 146 | ]) 147 | 148 | def forward(self, x, reverse=False): 149 | assert len(x.shape) == 4 150 | x = x.squeeze(-1).squeeze(-1) 151 | if not reverse: 152 | logdet = 0 153 | for i in range(len(self.s)): 154 | idx_apply, idx_keep = 0, 1 155 | if i % 2 != 0: 156 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 157 | x = torch.chunk(x, 2, dim=1) 158 | scale = self.s[i](x[idx_apply]) 159 | x_ = x[idx_keep] * (scale.exp()) + self.t[i](x[idx_apply]) 160 | x = torch.cat((x[idx_apply], x_), dim=1) 161 | logdet_ = torch.sum(scale.view(x.size(0), -1), dim=1) 162 | logdet = logdet + logdet_ 163 | return x[:,:,None,None], logdet 164 | else: 165 | idx_apply, idx_keep = 0, 1 166 | for i in reversed(range(len(self.s))): 167 | if i % 2 == 0: 168 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 169 | x = torch.chunk(x, 2, dim=1) 170 | x_ = (x[idx_keep] - self.t[i](x[idx_apply])) * (self.s[i](x[idx_apply]).neg().exp()) 171 | x = torch.cat((x[idx_apply], x_), dim=1) 172 | return x[:,:,None,None] 173 | 174 | 175 | class ConditionalDoubleVectorCouplingBlock(nn.Module): 176 | def __init__(self, in_channels, cond_channels, hidden_dim, depth=2): 177 | super(ConditionalDoubleVectorCouplingBlock, self).__init__() 178 | self.s = nn.ModuleList([ 179 | BasicFullyConnectedNet(dim=in_channels // 2 + cond_channels, depth=depth, 180 | hidden_dim=hidden_dim, use_tanh=True, 181 | out_dim=in_channels // 2) for _ in range(2)]) 182 | self.t = nn.ModuleList([ 183 | BasicFullyConnectedNet(dim=in_channels // 2 + cond_channels, depth=depth, 184 | hidden_dim=hidden_dim, use_tanh=False, 185 | out_dim=in_channels // 2) for _ in range(2)]) 186 | 187 | def forward(self, x, xc, reverse=False): 188 | assert len(x.shape) == 4 189 | assert len(xc.shape) == 4 190 | x = x.squeeze(-1).squeeze(-1) 191 | xc = xc.squeeze(-1).squeeze(-1) 192 | if not reverse: 193 | logdet = 0 194 | for i in range(len(self.s)): 195 | idx_apply, idx_keep = 0, 1 196 | if i % 2 != 0: 197 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 198 | x = torch.chunk(x, 2, dim=1) 199 | conditioner_input = torch.cat((x[idx_apply], xc), dim=1) 200 | scale = self.s[i](conditioner_input) 201 | x_ = x[idx_keep] * scale.exp() + self.t[i](conditioner_input) 202 | x = torch.cat((x[idx_apply], x_), dim=1) 203 | logdet_ = torch.sum(scale, dim=1) 204 | logdet = logdet + logdet_ 205 | return x[:, :, None, None], logdet 206 | else: 207 | idx_apply, idx_keep = 0, 1 208 | for i in reversed(range(len(self.s))): 209 | if i % 2 == 0: 210 | x = torch.cat(torch.chunk(x, 2, dim=1)[::-1], dim=1) 211 | x = torch.chunk(x, 2, dim=1) 212 | conditioner_input = torch.cat((x[idx_apply], xc), dim=1) 213 | x_ = (x[idx_keep] - self.t[i](conditioner_input)) * self.s[i](conditioner_input).neg().exp() 214 | x = torch.cat((x[idx_apply], x_), dim=1) 215 | return x[:, :, None, None] 216 | 217 | 218 | class ConditionalFlatDoubleCouplingFlowBlock(nn.Module): 219 | def 
__init__(self, in_channels, cond_channels, hidden_dim, hidden_depth, activation="lrelu"): 220 | super().__init__() 221 | __possible_activations = {"lrelu": InvLeakyRelu, 222 | "none": IgnoreLeakyRelu 223 | } 224 | self.norm_layer = ActNorm(in_channels, logdet=True) 225 | self.coupling = ConditionalDoubleVectorCouplingBlock(in_channels, 226 | cond_channels, 227 | hidden_dim, 228 | hidden_depth) 229 | self.activation = __possible_activations[activation]() 230 | self.shuffle = Shuffle(in_channels) 231 | 232 | def forward(self, x, xcond, reverse=False): 233 | if not reverse: 234 | h = x 235 | logdet = 0.0 236 | h, ld = self.norm_layer(h) 237 | logdet += ld 238 | h, ld = self.activation(h) 239 | logdet += ld 240 | h, ld = self.coupling(h, xcond) 241 | logdet += ld 242 | h, ld = self.shuffle(h) 243 | logdet += ld 244 | return h, logdet 245 | else: 246 | h = x 247 | h = self.shuffle(h, reverse=True) 248 | h = self.coupling(h, xcond, reverse=True) 249 | h = self.activation(h, reverse=True) 250 | h = self.norm_layer(h, reverse=True) 251 | return h 252 | 253 | def reverse(self, out, xcond): 254 | return self.forward(out, xcond, reverse=True) 255 | 256 | 257 | class Shuffle(nn.Module): 258 | def __init__(self, in_channels, **kwargs): 259 | super(Shuffle, self).__init__() 260 | self.in_channels = in_channels 261 | idx = torch.randperm(in_channels) 262 | self.register_buffer('forward_shuffle_idx', nn.Parameter(idx, requires_grad=False)) 263 | self.register_buffer('backward_shuffle_idx', nn.Parameter(torch.argsort(idx), requires_grad=False)) 264 | 265 | def forward(self, x, reverse=False, conditioning=None): 266 | if not reverse: 267 | return x[:, self.forward_shuffle_idx, ...], 0 268 | else: 269 | return x[:, self.backward_shuffle_idx, ...] 270 | 271 | 272 | class IgnoreLeakyRelu(nn.Module): 273 | """performs identity op.""" 274 | 275 | def __init__(self, alpha=0.9): 276 | super().__init__() 277 | 278 | def forward(self, input, reverse=False): 279 | if reverse: 280 | return self.reverse(input) 281 | h = input 282 | return h, 0.0 283 | 284 | def reverse(self, input): 285 | h = input 286 | return h 287 | 288 | 289 | class InvLeakyRelu(nn.Module): 290 | def __init__(self, alpha=0.9): 291 | super().__init__() 292 | self.alpha = alpha 293 | 294 | def forward(self, input, reverse=False): 295 | if reverse: 296 | return self.reverse(input) 297 | scaling = (input >= 0).to(input) + (input < 0).to(input) * self.alpha 298 | h = input * scaling 299 | return h, 0.0 300 | 301 | def reverse(self, input): 302 | scaling = (input >= 0).to(input) + (input < 0).to(input) * self.alpha 303 | h = input / scaling 304 | return h 305 | 306 | 307 | class InvParametricRelu(InvLeakyRelu): 308 | def __init__(self, alpha=0.9): 309 | super().__init__() 310 | self.alpha = nn.Parameter(torch.tensor(alpha), requires_grad=True) 311 | 312 | 313 | class FeatureLayer(nn.Module): 314 | def __init__(self, scale, in_channels=None, norm='AN', width_multiplier=1): 315 | super().__init__() 316 | 317 | norm_options = { 318 | "in": nn.InstanceNorm2d, 319 | "bn": nn.BatchNorm2d, 320 | "an": ActNorm} 321 | 322 | self.scale = scale 323 | self.norm = norm_options[norm.lower()] 324 | self.wm = width_multiplier 325 | if in_channels is None: 326 | self.in_channels = int(self.wm * 64 * min(2 ** (self.scale - 1), 16)) 327 | else: 328 | self.in_channels = in_channels 329 | self.out_channels = int(self.wm * 64 * min(2 ** self.scale, 16)) 330 | self.build() 331 | 332 | def forward(self, input): 333 | x = input 334 | for layer in self.sub_layers: 335 | x = 
layer(x) 336 | return x 337 | 338 | def build(self): 339 | Norm = functools.partial(self.norm, affine=True) 340 | self.sub_layers = nn.ModuleList([ 341 | nn.Conv2d( 342 | in_channels=self.in_channels, 343 | out_channels=self.out_channels, 344 | kernel_size=4, 345 | stride=2, 346 | padding=1, 347 | bias=False), 348 | Norm(num_features=self.out_channels), 349 | nn.LeakyReLU(0.2)]) 350 | 351 | 352 | class DenseEncoderLayer(nn.Module): 353 | def __init__(self, scale, spatial_size, out_size, in_channels=None, 354 | width_multiplier=1): 355 | super().__init__() 356 | self.scale = scale 357 | self.wm = width_multiplier 358 | self.in_channels = int(self.wm * 64 * min(2 ** (self.scale - 1), 16)) 359 | if in_channels is not None: 360 | self.in_channels = in_channels 361 | self.out_channels = out_size 362 | self.kernel_size = spatial_size 363 | self.build() 364 | 365 | def forward(self, input): 366 | x = input 367 | for layer in self.sub_layers: 368 | x = layer(x) 369 | return x 370 | 371 | def build(self): 372 | self.sub_layers = nn.ModuleList([ 373 | nn.Conv2d( 374 | in_channels=self.in_channels, 375 | out_channels=self.out_channels, 376 | kernel_size=self.kernel_size, 377 | stride=1, 378 | padding=0, 379 | bias=True)]) 380 | -------------------------------------------------------------------------------- /net2net/modules/flow/flatflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from net2net.modules.autoencoder.basic import ActNorm 6 | from net2net.modules.flow.blocks import UnconditionalFlatDoubleCouplingFlowBlock, PureAffineDoubleCouplingFlowBlock, \ 7 | ConditionalFlatDoubleCouplingFlowBlock 8 | from net2net.modules.flow.base import NormalizingFlow 9 | from net2net.modules.autoencoder.basic import FeatureLayer, DenseEncoderLayer 10 | from net2net.modules.flow.blocks import BasicFullyConnectedNet 11 | 12 | 13 | class UnconditionalFlatCouplingFlow(NormalizingFlow): 14 | """Flat, multiple blocks of ActNorm, DoubleAffineCoupling, Shuffle""" 15 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth): 16 | super().__init__() 17 | self.in_channels = in_channels 18 | self.mid_channels = hidden_dim 19 | self.num_blocks = hidden_depth 20 | self.n_flows = n_flows 21 | self.sub_layers = nn.ModuleList() 22 | 23 | for flow in range(self.n_flows): 24 | self.sub_layers.append(UnconditionalFlatDoubleCouplingFlowBlock( 25 | self.in_channels, self.mid_channels, 26 | self.num_blocks) 27 | ) 28 | 29 | def forward(self, x, reverse=False): 30 | if len(x.shape) == 2: 31 | x = x[:,:,None,None] 32 | self.last_outs = [] 33 | self.last_logdets = [] 34 | if not reverse: 35 | logdet = 0.0 36 | for i in range(self.n_flows): 37 | x, logdet_ = self.sub_layers[i](x) 38 | logdet = logdet + logdet_ 39 | self.last_outs.append(x) 40 | self.last_logdets.append(logdet) 41 | return x, logdet 42 | else: 43 | for i in reversed(range(self.n_flows)): 44 | x = self.sub_layers[i](x, reverse=True) 45 | return x 46 | 47 | def reverse(self, out): 48 | if len(out.shape) == 2: 49 | out = out[:,:,None,None] 50 | return self(out, reverse=True) 51 | 52 | def sample(self, num_samples, device="cpu"): 53 | zz = torch.randn(num_samples, self.in_channels, 1, 1).to(device) 54 | return self.reverse(zz) 55 | 56 | def get_last_layer(self): 57 | return getattr(self.sub_layers[-1].coupling.t[-1].main[-1], 'weight') 58 | 59 | 60 | class PureAffineFlatCouplingFlow(UnconditionalFlatCouplingFlow): 61 | """Flat, multiple blocks of 
DoubleAffineCoupling""" 62 | def __init__(self, in_channels, n_flows, hidden_dim, hidden_depth): 63 | super().__init__(in_channels, n_flows, hidden_dim, hidden_depth) 64 | del self.sub_layers 65 | self.sub_layers = nn.ModuleList() 66 | for flow in range(self.n_flows): 67 | self.sub_layers.append(PureAffineDoubleCouplingFlowBlock( 68 | self.in_channels, self.mid_channels, 69 | self.num_blocks) 70 | ) 71 | 72 | 73 | class DenseEmbedder(nn.Module): 74 | """Supposed to map small-scale features (e.g. labels) to some given latent dim""" 75 | def __init__(self, in_dim, up_dim, depth=4, given_dims=None): 76 | super().__init__() 77 | self.net = nn.ModuleList() 78 | if given_dims is not None: 79 | assert given_dims[0] == in_dim 80 | assert given_dims[-1] == up_dim 81 | dims = given_dims 82 | else: 83 | dims = np.linspace(in_dim, up_dim, depth).astype(int) 84 | for l in range(len(dims)-2): 85 | self.net.append(nn.Conv2d(dims[l], dims[l + 1], 1)) 86 | self.net.append(ActNorm(dims[l + 1])) 87 | self.net.append(nn.LeakyReLU(0.2)) 88 | 89 | self.net.append(nn.Conv2d(dims[-2], dims[-1], 1)) 90 | 91 | def forward(self, x): 92 | for layer in self.net: 93 | x = layer(x) 94 | return x.squeeze(-1).squeeze(-1) 95 | 96 | 97 | class Embedder(nn.Module): 98 | """Embeds a 4-dim tensor onto dense latent code, much like the classic encoder.""" 99 | def __init__(self, in_spatial_size, in_channels, emb_dim, n_down=4): 100 | super().__init__() 101 | self.feature_layers = nn.ModuleList() 102 | norm = 'an' # hard coded yes 103 | bottleneck_size = in_spatial_size // 2**n_down 104 | self.feature_layers.append(FeatureLayer(0, in_channels=in_channels, norm=norm)) 105 | for scale in range(1, n_down): 106 | self.feature_layers.append(FeatureLayer(scale, norm=norm)) 107 | self.dense_encode = DenseEncoderLayer(n_down, bottleneck_size, emb_dim) 108 | if n_down == 1: 109 | # add some extra parameters to make model a little more powerful ? 110 | print(" Warning: Embedder for ConditionalTransformer has only one down-sampling step. You might want to " 111 | "increase its capacity.") 112 | 113 | def forward(self, input): 114 | h = input 115 | for layer in self.feature_layers: 116 | h = layer(h) 117 | h = self.dense_encode(h) 118 | return h.squeeze(-1).squeeze(-1) 119 | 120 | 121 | class ConditionalFlatCouplingFlow(nn.Module): 122 | """Flat version. Feeds an embedding into the flow in every block""" 123 | def __init__(self, in_channels, conditioning_dim, embedding_dim, hidden_dim, hidden_depth, 124 | n_flows, conditioning_option="none", activation='lrelu', 125 | conditioning_hidden_dim=256, conditioning_depth=2, conditioner_use_bn=False, 126 | conditioner_use_an=False): 127 | super().__init__() 128 | self.in_channels = in_channels 129 | self.cond_channels = embedding_dim 130 | self.mid_channels = hidden_dim 131 | self.num_blocks = hidden_depth 132 | self.n_flows = n_flows 133 | self.conditioning_option = conditioning_option 134 | # TODO: also for spatial inputs... 135 | if conditioner_use_bn: 136 | assert not conditioner_use_an, 'Can not use ActNorm and BatchNorm simultaneously in Embedder.' 137 | print("Note: Conditioning network uses batch-normalization. 
" 138 | "Make sure to train with a sufficiently large batch size") 139 | 140 | self.embedder = BasicFullyConnectedNet(dim=conditioning_dim, 141 | depth=conditioning_depth, 142 | out_dim=embedding_dim, 143 | hidden_dim=conditioning_hidden_dim, 144 | use_bn=conditioner_use_bn, 145 | use_an=conditioner_use_an) 146 | 147 | self.sub_layers = nn.ModuleList() 148 | if self.conditioning_option.lower() != "none": 149 | self.conditioning_layers = nn.ModuleList() 150 | for flow in range(self.n_flows): 151 | self.sub_layers.append(ConditionalFlatDoubleCouplingFlowBlock( 152 | self.in_channels, self.cond_channels, self.mid_channels, 153 | self.num_blocks, activation=activation) 154 | ) 155 | if self.conditioning_option.lower() != "none": 156 | self.conditioning_layers.append(nn.Conv2d(self.cond_channels, self.cond_channels, 1)) 157 | 158 | def forward(self, x, cond, reverse=False): 159 | hconds = list() 160 | if len(cond.shape) == 4: 161 | if cond.shape[2] == 1: 162 | assert cond.shape[3] == 1 163 | cond = cond.squeeze(-1).squeeze(-1) 164 | else: 165 | raise ValueError("Spatial conditionings not yet supported. TODO") 166 | embedding = self.embedder(cond.float()) 167 | hcond = embedding[:, :, None, None] 168 | for i in range(self.n_flows): 169 | if self.conditioning_option.lower() == "parallel": 170 | hcond = self.conditioning_layers[i](embedding) 171 | elif self.conditioning_option.lower() == "sequential": 172 | hcond = self.conditioning_layers[i](hcond) 173 | hconds.append(hcond) 174 | if not reverse: 175 | logdet = 0.0 176 | for i in range(self.n_flows): 177 | x, logdet_ = self.sub_layers[i](x, hconds[i]) 178 | logdet = logdet + logdet_ 179 | return x, logdet 180 | else: 181 | for i in reversed(range(self.n_flows)): 182 | x = self.sub_layers[i](x, hconds[i], reverse=True) 183 | return x 184 | 185 | def reverse(self, out, xcond): 186 | return self(out, xcond, reverse=True) 187 | 188 | def sample(self, xc): 189 | zz = torch.randn(xc.shape[0], self.in_channels, 1, 1).to(xc) 190 | return self.reverse(zz, xc) 191 | 192 | 193 | -------------------------------------------------------------------------------- /net2net/modules/flow/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def nll(sample): 6 | if len(sample.shape) == 2: 7 | sample = sample[:,:,None,None] 8 | return 0.5*torch.sum(torch.pow(sample, 2), dim=[1,2,3]) 9 | 10 | 11 | class NLL(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, sample, logdet, split="train"): 16 | nll_loss = torch.mean(nll(sample)) 17 | assert len(logdet.shape) == 1 18 | nlogdet_loss = -torch.mean(logdet) 19 | loss = nll_loss + nlogdet_loss 20 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample))) 21 | log = {f"{split}/total_loss": loss, f"{split}/reference_nll_loss": reference_nll_loss, 22 | f"{split}/nlogdet_loss": nlogdet_loss, f"{split}/nll_loss": nll_loss, 23 | } 24 | return loss, log -------------------------------------------------------------------------------- /net2net/modules/gan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/gan/__init__.py -------------------------------------------------------------------------------- /net2net/modules/gan/bigbigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 
| 4 | import tensorflow.compat.v1 as tf 5 | tf.disable_v2_behavior() 6 | 7 | import tensorflow_hub as hub 8 | 9 | 10 | class BigBiGAN(object): 11 | def __init__(self, 12 | module_path='https://tfhub.dev/deepmind/bigbigan-resnet50/1', 13 | allow_growth=True): 14 | """Initialize a BigBiGAN from the given TF Hub module.""" 15 | self._module = hub.Module(module_path) 16 | 17 | # encode graph 18 | self.enc_ph = self.make_encoder_ph() 19 | self.z_sample = self.encode_graph(self.enc_ph) 20 | self.z_mean = self.encode_graph(self.enc_ph, return_all_features=True)['z_mean'] 21 | 22 | # decode graph 23 | self.gen_ph = self.make_generator_ph() 24 | self.gen_samples = self.generate_graph(self.gen_ph, upsample=True) 25 | 26 | # session 27 | init = tf.global_variables_initializer() 28 | gpu_options = tf.GPUOptions(allow_growth=allow_growth) 29 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 30 | self.sess.run(init) 31 | 32 | def generate_graph(self, z, upsample=False): 33 | """Run a batch of latents z through the generator to generate images. 34 | 35 | Args: 36 | z: A batch of 120D Gaussian latents, shape [N, 120]. 37 | 38 | Returns: a batch of generated RGB images, shape [N, 128, 128, 3], range 39 | [-1, 1]. 40 | """ 41 | outputs = self._module(z, signature='generate', as_dict=True) 42 | return outputs['upsampled' if upsample else 'default'] 43 | 44 | def make_generator_ph(self): 45 | """Creates a tf.placeholder with the dtype & shape of generator inputs.""" 46 | info = self._module.get_input_info_dict('generate')['z'] 47 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape()) 48 | 49 | def encode_graph(self, x, return_all_features=False): 50 | """Run a batch of images x through the encoder. 51 | 52 | Args: 53 | x: A batch of data (256x256 RGB images), shape [N, 256, 256, 3], range 54 | [-1, 1]. 55 | return_all_features: If True, return all features computed by the encoder. 56 | Otherwise (default) just return a sample z_hat. 57 | 58 | Returns: the sample z_hat of shape [N, 120] (or a dict of all features if 59 | return_all_features). 
60 | """ 61 | outputs = self._module(x, signature='encode', as_dict=True) 62 | return outputs if return_all_features else outputs['z_sample'] 63 | 64 | def make_encoder_ph(self): 65 | """Creates a tf.placeholder with the dtype & shape of encoder inputs.""" 66 | info = self._module.get_input_info_dict('encode')['x'] 67 | return tf.placeholder(dtype=info.dtype, shape=info.get_shape()) 68 | 69 | @torch.no_grad() 70 | def encode(self, x_torch): 71 | x_np = x_torch.detach().permute(0,2,3,1).cpu().numpy() 72 | feed_dict = {self.enc_ph: x_np} 73 | z = self.sess.run(self.z_sample, feed_dict=feed_dict) 74 | z_torch = torch.tensor(z).to(device=x_torch.device) 75 | return z_torch.unsqueeze(-1).unsqueeze(-1) 76 | 77 | @torch.no_grad() 78 | def decode(self, z_torch): 79 | z_np = z_torch.detach().squeeze(-1).squeeze(-1).cpu().numpy() 80 | feed_dict = {self.gen_ph: z_np} 81 | x = self.sess.run(self.gen_samples, feed_dict=feed_dict) 82 | x = x.transpose(0,3,1,2) 83 | x_torch = torch.tensor(x).to(device=z_torch.device) 84 | return x_torch 85 | 86 | def eval(self): 87 | # interface requirement 88 | return self 89 | -------------------------------------------------------------------------------- /net2net/modules/labels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/labels/__init__.py -------------------------------------------------------------------------------- /net2net/modules/labels/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Labelator(nn.Module): 8 | def __init__(self, num_classes, as_one_hot=True): 9 | super().__init__() 10 | self.num_classes = num_classes 11 | self.as_one_hot = as_one_hot 12 | 13 | def encode(self, x): 14 | if self.as_one_hot: 15 | x = self.make_one_hot(x) 16 | return x 17 | 18 | def other_label(self, given_label): 19 | # if only two classes are present, inverts them 20 | others = [] 21 | for l in given_label: 22 | other = int(np.random.choice(np.arange(self.num_classes))) 23 | while other == l: 24 | other = int(np.random.choice(np.arange(self.num_classes))) 25 | others.append(other) 26 | return torch.LongTensor(others) 27 | 28 | def make_one_hot(self, label): 29 | one_hot = F.one_hot(label, num_classes=self.num_classes) 30 | return one_hot 31 | -------------------------------------------------------------------------------- /net2net/modules/mlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/mlp/__init__.py -------------------------------------------------------------------------------- /net2net/modules/mlp/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class NeRFMLP(nn.Module): 8 | """basic MLP from the NERF paper""" 9 | def __init__(self, in_dim, out_dim, width=256): 10 | super().__init__() 11 | self.D = 8 12 | self.W = width # hidden dim 13 | self.skips = [4] 14 | self.n_in = in_dim 15 | self.out_dim = out_dim 16 | 17 | self.layers = nn.ModuleList() 18 | self.layers.append(nn.Linear(self.n_in, self.W)) 19 | for i in range(1, self.D): 20 | if i-1 in self.skips: 
21 | nin = self.W + self.n_in 22 | else: 23 | nin = self.W 24 | self.layers.append(nn.Linear(nin, self.W)) 25 | self.out_layer = nn.Linear(self.W, self.out_dim) 26 | 27 | def forward(self, z): 28 | h = z 29 | for i in range(self.D): 30 | h = self.layers[i](h) 31 | h = F.relu(h) 32 | if i in self.skips: 33 | h = torch.cat([h,z], dim=1) 34 | h = self.out_layer(h) 35 | return h 36 | 37 | 38 | class SineLayer(nn.Module): 39 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0. 40 | 41 | # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 42 | # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 43 | # hyperparameter. 44 | 45 | # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 46 | # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5) 47 | 48 | def __init__(self, in_features, out_features, bias=True, 49 | is_first=False, omega_0=30): 50 | super().__init__() 51 | self.omega_0 = omega_0 52 | self.is_first = is_first 53 | self.in_features = in_features 54 | self.linear = nn.Linear(in_features, out_features, bias=bias) 55 | self.init_weights() 56 | 57 | def init_weights(self): 58 | with torch.no_grad(): 59 | if self.is_first: 60 | self.linear.weight.uniform_(-1 / self.in_features, 61 | 1 / self.in_features) 62 | else: 63 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 64 | np.sqrt(6 / self.in_features) / self.omega_0) 65 | 66 | def forward(self, input): 67 | return torch.sin(self.omega_0 * self.linear(input)) 68 | 69 | def forward_with_intermediate(self, input): 70 | # For visualization of activation distributions 71 | intermediate = self.omega_0 * self.linear(input) 72 | return torch.sin(intermediate), intermediate 73 | 74 | 75 | class Siren(nn.Module): 76 | def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False, 77 | first_omega_0=30, hidden_omega_0=30.): 78 | super().__init__() 79 | 80 | self.net = [] 81 | self.net.append(SineLayer(in_features, hidden_features, 82 | is_first=True, omega_0=first_omega_0)) 83 | 84 | for i in range(hidden_layers): 85 | self.net.append(SineLayer(hidden_features, hidden_features, 86 | is_first=False, omega_0=hidden_omega_0)) 87 | 88 | if outermost_linear: 89 | final_linear = nn.Linear(hidden_features, out_features) 90 | 91 | with torch.no_grad(): 92 | final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, 93 | np.sqrt(6 / hidden_features) / hidden_omega_0) 94 | 95 | self.net.append(final_linear) 96 | else: 97 | self.net.append(SineLayer(hidden_features, out_features, 98 | is_first=False, omega_0=hidden_omega_0)) 99 | 100 | self.net = nn.Sequential(*self.net) 101 | 102 | def forward(self, coords): 103 | #coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input 104 | output = self.net(coords) 105 | #return output, coords 106 | return output 107 | 108 | def forward_with_activations(self, coords, retain_grad=False): 109 | '''Returns not only model output, but also intermediate activations. 
110 | Only used for visualizing activations later!''' 111 | activations = OrderedDict() 112 | activation_count = 0 113 | x = coords.clone().detach().requires_grad_(True) 114 | activations['input'] = x 115 | for i, layer in enumerate(self.net): 116 | if isinstance(layer, SineLayer): 117 | x, intermed = layer.forward_with_intermediate(x) 118 | 119 | if retain_grad: 120 | x.retain_grad() 121 | intermed.retain_grad() 122 | 123 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed 124 | activation_count += 1 125 | else: 126 | x = layer(x) 127 | 128 | if retain_grad: 129 | x.retain_grad() 130 | 131 | activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x 132 | activation_count += 1 133 | return activations 134 | 135 | 136 | # And finally, differential operators that allow us to leverage autograd to 137 | # compute gradients, the laplacian, etc. 138 | 139 | def laplace(y, x): 140 | grad = gradient(y, x) 141 | return divergence(grad, x) 142 | 143 | 144 | def divergence(y, x): 145 | div = 0. 146 | for i in range(y.shape[-1]): 147 | div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1] 148 | return div 149 | 150 | 151 | def gradient(y, x, grad_outputs=None): 152 | if grad_outputs is None: 153 | grad_outputs = torch.ones_like(y) 154 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0] 155 | return grad 156 | 157 | 158 | if __name__ == "__main__": 159 | siren = Siren(2, 64, 2, 2) 160 | x = torch.randn(11, 2) 161 | x.requires_grad = True 162 | y = siren(x).mean() 163 | grad1 = torch.autograd.grad(y, x)[0] 164 | print(grad1.shape) 165 | print("done.") -------------------------------------------------------------------------------- /net2net/modules/sbert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/net2net/5311cb8495e378158be1f5374004eb9a3f7de992/net2net/modules/sbert/__init__.py -------------------------------------------------------------------------------- /net2net/modules/sbert/model.py: -------------------------------------------------------------------------------- 1 | # check out https://github.com/UKPLab/sentence-transformers, 2 | # list of pretrained models @ https://www.sbert.net/docs/pretrained_models.html 3 | 4 | from sentence_transformers import SentenceTransformer 5 | import numpy as np 6 | import torch.nn as nn 7 | 8 | 9 | class SentenceEmbedder(nn.Module): 10 | def __init__(self, version='bert-large-nli-stsb-mean-tokens'): 11 | super().__init__() 12 | np.set_printoptions(threshold=100) 13 | # Load Sentence model (based on BERT) from URL 14 | self.model = SentenceTransformer(version, device="cuda") 15 | self.model.eval() 16 | 17 | def forward(self, sentences): 18 | """sentences are expect to be a list of strings, e.g. 19 | sentences = ['This framework generates embeddings for each input sentence', 20 | 'Sentences are passed as a list of string.', 21 | 'The quick brown fox jumps over the lazy dog.' 
22 | ] 23 | """ 24 | sentence_embeddings = self.model.encode(sentences, batch_size=len(sentences), show_progress_bar=False, 25 | convert_to_tensor=True) 26 | return sentence_embeddings.cuda() 27 | 28 | def encode(self, sentences): 29 | embeddings = self(sentences) 30 | return embeddings[:,:,None,None] 31 | 32 | 33 | if __name__ == '__main__': 34 | model = SentenceEmbedder(version='distilroberta-base-paraphrase-v1') 35 | sentences = ['This framework generates embeddings for each input sentence', 36 | 'Sentences are passed as a list of string.', 37 | 'The quick brown fox jumps over the lazy dog.' 38 | ] 39 | emb = model.encode(sentences) 40 | print(emb.shape) 41 | print("done.") 42 | -------------------------------------------------------------------------------- /net2net/modules/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | from torchvision import models 7 | import os 8 | from PIL import Image, ImageDraw 9 | 10 | 11 | def log_txt_as_img(wh, xc): 12 | b = len(xc) 13 | txts = list() 14 | for bi in range(b): 15 | txt = Image.new("RGB", wh, color="white") 16 | draw = ImageDraw.Draw(txt) 17 | nc = int(40 * (wh[0]/256)) 18 | lines = "\n".join(xc[bi][start:start+nc] for start in range(0, len(xc[bi]), nc)) 19 | draw.text((0,0), lines, fill="black") 20 | txt = np.array(txt).transpose(2,0,1)/127.5-1.0 21 | txts.append(txt) 22 | txts = np.stack(txts) 23 | txts = torch.tensor(txts) 24 | return txts 25 | 26 | 27 | class Downscale(nn.Module): 28 | def __init__(self, mode="bilinear", size=32): 29 | super().__init__() 30 | self.mode = mode 31 | self.out_size = size 32 | 33 | def forward(self, x): 34 | x = F.interpolate(x, mode=self.mode, size=self.out_size) 35 | return x 36 | 37 | 38 | class DownscaleUpscale(nn.Module): 39 | def __init__(self, mode_down="bilinear", mode_up="bilinear", size=32): 40 | super().__init__() 41 | self.mode_down = mode_down 42 | self.mode_up = mode_up 43 | self.out_size = size 44 | 45 | def forward(self, x): 46 | assert len(x.shape) == 4 47 | z = F.interpolate(x, mode=self.mode_down, size=self.out_size) 48 | x = F.interpolate(z, mode=self.mode_up, size=x.shape[2]) 49 | return x 50 | 51 | 52 | class TpsGridGen(nn.Module): 53 | def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True): 54 | super(TpsGridGen, self).__init__() 55 | self.out_h, self.out_w = out_h, out_w 56 | self.reg_factor = reg_factor 57 | self.use_cuda = use_cuda 58 | 59 | # create grid in numpy 60 | self.grid = np.zeros( [self.out_h, self.out_w, 3], dtype=np.float32) 61 | # sampling grid with dim-0 coords (Y) 62 | self.grid_X,self.grid_Y = np.meshgrid(np.linspace(-1,1,out_w),np.linspace(-1,1,out_h)) 63 | # grid_X,grid_Y: size [1,H,W,1,1] 64 | self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3) 65 | self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3) 66 | if use_cuda: 67 | self.grid_X = self.grid_X.cuda() 68 | self.grid_Y = self.grid_Y.cuda() 69 | 70 | # initialize regular grid for control points P_i 71 | if use_regular_grid: 72 | axis_coords = np.linspace(-1,1,grid_size) 73 | self.N = grid_size*grid_size 74 | P_Y,P_X = np.meshgrid(axis_coords,axis_coords) 75 | P_X = np.reshape(P_X,(-1,1)) # size (N,1) 76 | P_Y = np.reshape(P_Y,(-1,1)) # size (N,1) 77 | P_X = torch.FloatTensor(P_X) 78 | P_Y = torch.FloatTensor(P_Y) 79 | self.P_X_base = P_X.clone() 80 | 
self.P_Y_base = P_Y.clone() 81 | self.Li = self.compute_L_inverse(P_X,P_Y).unsqueeze(0) 82 | self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 83 | self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 84 | if use_cuda: 85 | self.P_X = self.P_X.cuda() 86 | self.P_Y = self.P_Y.cuda() 87 | self.P_X_base = self.P_X_base.cuda() 88 | self.P_Y_base = self.P_Y_base.cuda() 89 | 90 | 91 | def forward(self, theta): 92 | warped_grid = self.apply_transformation(theta,torch.cat((self.grid_X,self.grid_Y),3)) 93 | 94 | return warped_grid 95 | 96 | def compute_L_inverse(self,X,Y): 97 | N = X.size()[0] # num of points (along dim 0) 98 | # construct matrix K 99 | Xmat = X.expand(N,N) 100 | Ymat = Y.expand(N,N) 101 | P_dist_squared = torch.pow(Xmat-Xmat.transpose(0,1),2)+torch.pow(Ymat-Ymat.transpose(0,1),2) 102 | P_dist_squared[P_dist_squared==0]=1 # make diagonal 1 to avoid NaN in log computation 103 | K = torch.mul(P_dist_squared,torch.log(P_dist_squared)) 104 | # construct matrix L 105 | O = torch.FloatTensor(N,1).fill_(1) 106 | Z = torch.FloatTensor(3,3).fill_(0) 107 | P = torch.cat((O,X,Y),1) 108 | L = torch.cat((torch.cat((K,P),1),torch.cat((P.transpose(0,1),Z),1)),0) 109 | Li = torch.inverse(L) 110 | if self.use_cuda: 111 | Li = Li.cuda() 112 | return Li 113 | 114 | def apply_transformation(self,theta,points): 115 | if theta.dim()==2: 116 | theta = theta.unsqueeze(2).unsqueeze(3) 117 | # points should be in the [B,H,W,2] format, 118 | # where points[:,:,:,0] are the X coords 119 | # and points[:,:,:,1] are the Y coords 120 | 121 | # input are the corresponding control points P_i 122 | batch_size = theta.size()[0] 123 | # split theta into point coordinates 124 | Q_X=theta[:,:self.N,:,:].squeeze(3) 125 | Q_Y=theta[:,self.N:,:,:].squeeze(3) 126 | Q_X = Q_X + self.P_X_base.expand_as(Q_X) 127 | Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y) 128 | 129 | # get spatial dimensions of points 130 | points_b = points.size()[0] 131 | points_h = points.size()[1] 132 | points_w = points.size()[2] 133 | 134 | # repeat pre-defined control points along spatial dimensions of points to be transformed 135 | P_X = self.P_X.expand((1,points_h,points_w,1,self.N)) 136 | P_Y = self.P_Y.expand((1,points_h,points_w,1,self.N)) 137 | 138 | # compute weigths for non-linear part 139 | W_X = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_X) 140 | W_Y = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_Y) 141 | # reshape 142 | # W_X,W,Y: size [B,H,W,1,N] 143 | W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 144 | W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 145 | # compute weights for affine part 146 | A_X = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_X) 147 | A_Y = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_Y) 148 | # reshape 149 | # A_X,A,Y: size [B,H,W,1,3] 150 | A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 151 | A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 152 | 153 | # compute distance P_i - (grid_X,grid_Y) 154 | # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch 155 | points_X_for_summation = points[:,:,:,0].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,0].size()+(1,self.N)) 156 | points_Y_for_summation = points[:,:,:,1].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,1].size()+(1,self.N)) 157 | 158 | if 
points_b==1: 159 | delta_X = points_X_for_summation-P_X 160 | delta_Y = points_Y_for_summation-P_Y 161 | else: 162 | # use expanded P_X,P_Y in batch dimension 163 | delta_X = points_X_for_summation-P_X.expand_as(points_X_for_summation) 164 | delta_Y = points_Y_for_summation-P_Y.expand_as(points_Y_for_summation) 165 | 166 | dist_squared = torch.pow(delta_X,2)+torch.pow(delta_Y,2) 167 | # U: size [1,H,W,1,N] 168 | dist_squared[dist_squared==0]=1 # avoid NaN in log computation 169 | U = torch.mul(dist_squared,torch.log(dist_squared)) 170 | 171 | # expand grid in batch dimension if necessary 172 | points_X_batch = points[:,:,:,0].unsqueeze(3) 173 | points_Y_batch = points[:,:,:,1].unsqueeze(3) 174 | if points_b==1: 175 | points_X_batch = points_X_batch.expand((batch_size,)+points_X_batch.size()[1:]) 176 | points_Y_batch = points_Y_batch.expand((batch_size,)+points_Y_batch.size()[1:]) 177 | 178 | points_X_prime = A_X[:,:,:,:,0]+ \ 179 | torch.mul(A_X[:,:,:,:,1],points_X_batch) + \ 180 | torch.mul(A_X[:,:,:,:,2],points_Y_batch) + \ 181 | torch.sum(torch.mul(W_X,U.expand_as(W_X)),4) 182 | 183 | points_Y_prime = A_Y[:,:,:,:,0]+ \ 184 | torch.mul(A_Y[:,:,:,:,1],points_X_batch) + \ 185 | torch.mul(A_Y[:,:,:,:,2],points_Y_batch) + \ 186 | torch.sum(torch.mul(W_Y,U.expand_as(W_Y)),4) 187 | 188 | return torch.cat((points_X_prime,points_Y_prime),3) 189 | 190 | 191 | def random_tps(*args, grid_size=4, reg_factor=0, strength_factor=1.0): 192 | """Random TPS. Device and size determined from first argument, all 193 | remaining arguments transformed with the same parameters.""" 194 | x = args[0] 195 | is_np = type(x) == np.ndarray 196 | no_batch = len(x.shape) == 3 197 | if no_batch: 198 | args = [x[None,...] for x in args] 199 | if is_np: 200 | args = [torch.tensor(x.copy()).permute(0,3,1,2) for x in args] 201 | x = args[0] 202 | use_cuda = x.is_cuda 203 | b,c,h,w = x.shape 204 | # grid_size is taken from the keyword argument 205 | tps = TpsGridGen(out_h=h, 206 | out_w=w, 207 | use_regular_grid=True, 208 | grid_size=grid_size, 209 | reg_factor=reg_factor, 210 | use_cuda=use_cuda) 211 | # theta = b,2*N - first N = X, second N = Y, 212 | control = torch.cat([tps.P_X_base, tps.P_Y_base], dim=0) 213 | control = control[None,:,0][b*[0],...] 214 | theta = (torch.rand(b,2*grid_size*grid_size)*2-1.0) / grid_size * strength_factor 215 | final = control+theta.to(control) 216 | final = torch.clamp(final, -1.0, 1.0) 217 | final = final - control 218 | final[control==-1.0] = 0.0 219 | final[control==1.0] = 0.0 220 | grid = tps(final) 221 | 222 | is_uint8 = [x.dtype == torch.uint8 for x in args] 223 | args = [args[i].to(grid)+0.5 if is_uint8[i] else args[i] for i in range(len(args))] 224 | out = [torch.nn.functional.grid_sample(x, grid, align_corners=True) for x in args] 225 | out = [out[i].to(torch.uint8) if is_uint8[i] else out[i] for i in range(len(out))] 226 | if is_np: 227 | out = [x.permute(0,2,3,1).numpy() for x in out] 228 | if no_batch: 229 | out = [x[0,...]
for x in out] 230 | if len(out) == 1: 231 | out = out[0] 232 | return out 233 | 234 | 235 | def count_params(model): 236 | total_params = sum(p.numel() for p in model.parameters()) 237 | return total_params 238 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='net2net', 5 | version='0.0.1', 6 | description='Translate between networks through their latent representations.', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
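A minimal smoke test can be useful after setting up the package. The sketch below is not part of the repository; it assumes the code has been installed in editable mode (`pip install -e .`, using the `setup.py` above) inside the `net2net` conda environment, and the latent dimensionality of 128 is purely illustrative. It exercises `UnconditionalFlatCouplingFlow` and `NLL` from `net2net/modules/flow/` together with the `count_params` helper from `net2net/modules/util.py`:

```
import torch

from net2net.modules.flow.flatflow import UnconditionalFlatCouplingFlow
from net2net.modules.flow.loss import NLL
from net2net.modules.util import count_params

# Build a small flat coupling flow operating on 128-dimensional codes
# (the sizes here are illustrative, not taken from any config).
flow = UnconditionalFlatCouplingFlow(in_channels=128, n_flows=4,
                                     hidden_dim=512, hidden_depth=2)
print(f"flow parameters: {count_params(flow)}")

z = torch.randn(8, 128, 1, 1)     # stand-in for latent codes of an expert model
out, logdet = flow(z)             # forward pass returns transformed codes and log-determinant
loss, log = NLL()(out, logdet)    # negative log-likelihood under a standard-normal prior
print(loss.item(), log["train/nll_loss"].item())

z_rec = flow.reverse(out)         # invert the flow; should reproduce z up to numerical error
print((z - z_rec).abs().max().item())

samples = flow.sample(4)          # draw new codes by pushing Gaussian noise through the inverse
print(samples.shape)              # expected: torch.Size([4, 128, 1, 1])
```

For the conditional case, `ConditionalFlatCouplingFlow.sample(xc)` plays the analogous role: it draws Gaussian noise of the flow's input dimensionality and inverts the flow under the given conditioning `xc`.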