├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets.py ├── eval.py ├── hubconf.py ├── illustartionCelebaHQ.jpg ├── illustration.png ├── models ├── DCGAN.py ├── README.md ├── UTs │ ├── __init__.py │ └── test_ac_criterion.py ├── __init__.py ├── base_GAN.py ├── datasets │ ├── __init__.py │ ├── attrib_dataset.py │ ├── hd5.py │ └── utils │ │ ├── __init__.py │ │ └── db_stats.py ├── eval │ ├── build_nn_db.py │ ├── inception.py │ ├── inspirational_generation.py │ ├── laplacian_SWD.py │ ├── metric_plot.py │ ├── nn_metric.py │ └── visualization.py ├── gan_visualizer.py ├── loss_criterions │ ├── GDPP_loss.py │ ├── __init__.py │ ├── ac_criterion.py │ ├── base_loss_criterions.py │ ├── gradient_losses.py │ ├── logistic_loss.py │ └── loss_texture.py ├── metrics │ ├── __init__.py │ ├── inception_score.py │ ├── laplacian_swd.py │ └── nn_score.py ├── networks │ ├── DCGAN_nets.py │ ├── __init__.py │ ├── constant_net.py │ ├── custom_layers.py │ ├── mini_batch_stddev_module.py │ ├── progressive_conv_net.py │ └── styleGAN.py ├── progressive_gan.py ├── styleGAN.py ├── trainer │ ├── DCGAN_trainer.py │ ├── __init__.py │ ├── gan_trainer.py │ ├── progressive_gan_trainer.py │ ├── standard_configurations │ │ ├── __init__.py │ │ ├── dcgan_config.py │ │ ├── pgan_config.py │ │ └── stylegan_config.py │ └── styleGAN_trainer.py └── utils │ ├── __init__.py │ ├── config.py │ ├── image_transform.py │ ├── product_module.py │ └── utils.py ├── requirements.txt ├── save_feature_extractor.py ├── train.py └── visualization ├── __init__.py ├── np_visualizer.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | output_networks/* 2 | *.pyc 3 | setup.py 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to pytorch_GAN_zoo 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. 
In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * flake8 standards 31 | 32 | ## License 33 | By contributing to pytorch_GAN_zoo, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019, Facebook 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch GAN Zoo 2 | 3 | A GAN toolbox for researchers and developers with: 4 | - Progressive Growing of GAN(PGAN): https://arxiv.org/pdf/1710.10196.pdf 5 | - DCGAN: https://arxiv.org/pdf/1511.06434.pdf 6 | - StyleGAN (beta): https://arxiv.org/abs/1812.04948 7 | 8 | illustration 9 | Picture: Generated samples from GANs trained on celebaHQ, fashionGen, DTD. 
10 | 
11 | 
12 | celeba
13 | Picture: fake faces generated by a model trained on celebaHQ
14 | 
15 | This code also implements diverse tools:
16 | - GDPP method from [GDPP: Learning Diverse Generations Using Determinantal Point Process](https://arxiv.org/abs/1812.00068)
17 | - Image generation "inspired" by a reference image, using an already trained GAN, from [Inspirational Adversarial Image Generation](https://arxiv.org/abs/1906.11661)
18 | - AC-GAN conditioning from [Conditional Image Synthesis With Auxiliary Classifier GANs](https://arxiv.org/abs/1610.09585)
19 | - [SWD metric](https://hal.archives-ouvertes.fr/hal-00476064/document)
20 | - [Inception Score](https://papers.nips.cc/paper/6125-improved-techniques-for-training-gans.pdf)
21 | - Logistic loss from [Which Training Methods for GANs do actually Converge?](https://arxiv.org/pdf/1801.04406.pdf)
22 | 
23 | ## Requirements
24 | 
25 | This project requires:
26 | - pytorch
27 | - torchvision
28 | - numpy
29 | - scipy
30 | - h5py (fashionGen)
31 | 
32 | Optional:
33 | - visdom
34 | - nevergrad (inspirational generation)
35 | 
36 | If you don't already have pytorch or torchvision, please have a look at https://pytorch.org/ as the installation command may vary depending on your OS and your version of CUDA.
37 | 
38 | You can install all the other dependencies with pip by running:
39 | 
40 | ```
41 | pip install -r requirements.txt
42 | ```
43 | 
44 | ## Recommended datasets
45 | - celebA: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
46 | - celebAHQ: https://github.com/nperraud/download-celebA-HQ
47 | - fashionGen: https://fashion-gen.com/
48 | - DTD: https://www.robots.ox.ac.uk/~vgg/data/dtd/
49 | - CIFAR10: http://www.cs.toronto.edu/~kriz/cifar.html
50 | 
51 | ## Quick training
52 | 
53 | The datasets.py script allows you to prepare your datasets and build their corresponding configuration files.
54 | 
55 | If you want to waste no time and just launch a training session on celeba cropped:
56 | 
57 | ```
58 | python datasets.py celeba_cropped $PATH_TO_CELEBA/img_align_celeba/ -o $OUTPUT_DATASET
59 | python train.py PGAN -c config_celeba_cropped.json --restart -n celeba_cropped
60 | ```
61 | 
62 | And wait for a few days. Your checkpoints will be dumped in output_networks/celeba_cropped. You should get 128x128 generations at the end.
63 | 
64 | For celebaHQ:
65 | 
66 | ```
67 | python datasets.py celebaHQ $PATH_TO_CELEBAHQ -o $OUTPUT_DATASET -f
68 | python train.py PGAN -c config_celebaHQ.json --restart -n celebaHQ
69 | ```
70 | 
71 | Your checkpoints will be dumped in output_networks/celebaHQ. You should get 1024x1024 generations at the end.
72 | 
73 | For fashionGen:
74 | 
75 | ```
76 | python datasets.py fashionGen $PATH_TO_FASHIONGEN_RES_256 -o $OUTPUT_DIR
77 | python train.py PGAN -c config_fashionGen.json --restart -n fashionGen
78 | ```
79 | 
80 | The above command will train the fashionGen model up to a resolution of 256x256. If you want to train fashionGen on a specific sub-dataset, for example CLOTHING, run:
81 | 
82 | ```
83 | python train.py PGAN -c config_fashionGen.json --restart -n fashionGen -v CLOTHING
84 | ```
85 | 
86 | Four sub-datasets are available: CLOTHING, SHOES, BAGS and ACCESSORIES.
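If you want to train one model per sub-dataset, a minimal shell sketch is shown below. It assumes the fashionGen dataset and config_fashionGen.json have already been prepared with datasets.py as above; the run names (fashionGen_CLOTHING, etc.) are arbitrary choices:

```
# Train a separate PGAN on each fashionGen sub-dataset.
for PARTITION in CLOTHING SHOES BAGS ACCESSORIES; do
    python train.py PGAN -c config_fashionGen.json --restart \
        -n fashionGen_$PARTITION -v $PARTITION
done
```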
87 | 
88 | For the DTD texture dataset:
89 | 
90 | ```
91 | python datasets.py dtd $PATH_TO_DTD
92 | python train.py PGAN -c config_dtd.json --restart -n dtd
93 | ```
94 | 
95 | For cifar10:
96 | 
97 | ```
98 | python datasets.py cifar10 $PATH_TO_CIFAR10 -o $OUTPUT_DATASET
99 | python train.py PGAN -c config_cifar10.json --restart -n cifar10
100 | ```
101 | 
102 | ## Load a pretrained model with torch.hub
103 | 
104 | Models trained on celebaHQ, fashionGen, cifar10 and celeba cropped are available with [torch.hub](https://pytorch.org/docs/stable/hub.html).
105 | 
106 | Checkpoints:
107 | - PGAN:
108 |   - celebaHQ: https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ_s6_i80000-6196db68.pth
109 |   - celeba_cropped: https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaCropped_s5_i83000-2b0acc76.pth
110 |   - dtd: https://dl.fbaipublicfiles.com/gan_zoo/PGAN/testDTD_s5_i96000-04efa39f.pth
111 | 
112 | - DCGAN:
113 |   - fashionGen: https://dl.fbaipublicfiles.com/gan_zoo/DCGAN_fashionGen-1d67302.pth
114 | 
115 | See hubconf.py for how to load a checkpoint!
116 | 
117 | ## GDPP
118 | 
119 | To apply the GDPP loss to your model, just add the option --GDPP true to your training command.
120 | 
121 | ## (beta) StyleGAN
122 | 
123 | To run StyleGAN, use the model name StyleGAN when running train.py. You can reuse the pre-computed configurations for celeba and celebaHQ. For example:
124 | 
125 | ```
126 | python train.py StyleGAN -c config_celebaHQ.json --restart -n style_gan_celeba
127 | ```
128 | 
129 | ## Advanced guidelines
130 | 
131 | ### How to run a training session?
132 | 
133 | ```
134 | python train.py $MODEL_NAME -c $CONFIGURATION_FILE [-n $RUN_NAME] [-d $OUTPUT_DIRECTORY] [OVERRIDES]
135 | ```
136 | 
137 | Where:
138 | 
139 | 1 - MODEL_NAME is the name of the model you want to run. Currently, the following models are available:
140 | - PGAN (progressive growing of GAN)
141 | - PPGAN (decoupled version of PGAN)
142 | 
143 | 2 - CONFIGURATION_FILE (mandatory): path to a training configuration file. This file is a json file containing at least a pathDB entry with the path to the training dataset. See below for more information about this file.
144 | 
145 | 3 - RUN_NAME is the name you want to give to your training session. All checkpoints will be saved in $OUTPUT_DIRECTORY/$RUN_NAME. Its default value is "default".
146 | 
147 | 4 - OUTPUT_DIRECTORY is the directory where all training sessions are saved. Its default value is output_networks.
148 | 
149 | 5 - OVERRIDES: you can override some of the model parameters defined in the "config" field of the configuration file (see below) from the command line. For example:
150 | 
151 | ```
152 | python train.py PPGAN -c coin.json -n PAN --learningRate 0.2
153 | ```
154 | 
155 | This will force the learning rate to be 0.2 during training, whatever the configuration file coin.json specifies.
156 | 
157 | To get all the possible override options, please type:
158 | 
159 | ```
160 | python train.py $MODEL_NAME --overrides
161 | ```
162 | 
163 | ## Configuration file of a training session
164 | 
165 | The minimal configuration file for a training session is a json file with the following lines:
166 | 
167 | ```
168 | {
169 |     "pathDB": PATH_TO_YOUR_DATASET
170 | }
171 | ```
172 | 
173 | Where a dataset can be:
174 | - a folder with all your images in .jpg, .png or .npy format
175 | - a folder with N subfolders containing images (see imagefolderDataset = True below)
176 | - a .h5 file (cf. fashionGen)
177 | 
178 | To this, you can add a "config" entry giving overrides to the standard configuration.
See models/trainer/standard_configurations to see all the possible options. For example:
179 | 
180 | ```
181 | {
182 |     "pathDB": PATH_TO_YOUR_DATASET,
183 |     "config": {"baseLearningRate": 0.1,
184 |                "miniBatchSize": 22}
185 | }
186 | ```
187 | 
188 | This will override the learning rate and the mini-batch size. Please note that if you specify a --baseLearningRate option on the command line, the command line will prevail. Depending on how you work, you might prefer to have a specific configuration file for each run, or to rely on a single configuration file and input your training parameters via the command line.
189 | 
190 | Other fields are available in the configuration file, such as:
191 | - pathAttribDict (string): path to a .json file matching each image with its attributes. More precisely, with a standard dataset it is a dictionary with the following entries:
192 | 
193 | ```
194 | {
195 |     image_name1.jpg: {attribute1: label, attribute2: label ...}
196 |     image_name2.jpg: {attribute1: label, attribute2: label ...}
197 |     ...
198 | }
199 | ```
200 | 
201 | With a dataset in the fashionGen format (.h5), it is a dictionary summing up statistics on the classes to be sampled.
202 | 
203 | - imagefolderDataset (bool): set to true to handle datasets in the torchvision.datasets.ImageFolder format
204 | - selectedAttributes (list): if specified, learn only the given attributes during the training session
205 | - pathPartition (string): path to a partition of the training dataset
206 | - partitionValue (string): if pathPartition is specified, name of the partition to choose
207 | - miniBatchScheduler (dictionary): dictionary updating the size of the mini-batch at different scales of the training,
208 |   e.g. {"2": 16, "7": 8}, meaning that the mini-batch size will be 16 from scale 2 to 6 and 8 from scale 7 onward
209 | - configScheduler (dictionary): dictionary updating the model configuration at different scales of the training,
210 |   e.g. {"2": {"baseLearningRate": 0.1, "epsilonD": 1}}, meaning that the learning rate and epsilonD will be updated to 0.1 and 1 from scale 2 onward
211 | 
212 | ## How to run an evaluation of the results of your training session?
213 | 
214 | You need to use the eval.py script.
215 | 
216 | ### Image generation
217 | 
218 | You can generate more images from an existing checkpoint using:
219 | ```
220 | python eval.py visualization -n $modelName -m $modelType
221 | ```
222 | 
223 | Where modelType is one of [PGAN, PPGAN, DCGAN] and modelName is the name given to your model. This script will load the last checkpoint detected in output_networks/$modelName (or in the directory given with -d). If you want to load a specific iteration, please call:
224 | 
225 | ```
226 | python eval.py visualization -n $modelName -m $modelType -s $SCALE -i $ITER
227 | ```
228 | 
229 | If your model is conditioned, you can ask the visualizer to print out some conditioned generations. First, use --showLabels to see all the available categories and their labels:
230 | 
231 | ```
232 | python eval.py visualization -n $modelName -m $modelType --showLabels
233 | ```
234 | 
235 | Then, run your generation with:
236 | 
237 | ```
238 | python eval.py visualization -n $modelName -m $modelType --$CATEGORY_NAME $LABEL_NAME
239 | ```
240 | 
241 | For example, with a model trained on fashionGen:
242 | 
243 | ```
244 | python eval.py visualization -n $modelName -m $modelType --Class T_SHIRT
245 | ```
246 | 
247 | This will plot a batch of T_SHIRT samples in visdom.
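The pretrained checkpoints listed in the torch.hub section above can also be used to generate images directly from Python, without going through eval.py. Below is a minimal sketch based on the entry points documented in hubconf.py; the model name, the number of images and the output file name are arbitrary choices, and useGPU should be set to False on a CPU-only machine:

```
import torch
import torchvision.utils as vutils

# Download and load the progressive GAN pretrained on celebAHQ (256x256).
model = torch.hub.load('facebookresearch/pytorch_gan_zoo:master', 'PGAN',
                       model_name='celebAHQ-256', pretrained=True, useGPU=True)

# Build a batch of random latent vectors and run the smoothed generator.
num_images = 4
noise, _ = model.buildNoiseData(num_images)
with torch.no_grad():
    generated_images = model.test(noise, getAvG=True, toCPU=True)

# The generator outputs images in [-1, 1]; normalize them before saving a grid.
vutils.save_image(generated_images, 'fake_celebaHQ.png', normalize=True)
```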
248 | 
249 | ### Fake dataset generation
250 | 
251 | To save a randomly generated fake dataset from a checkpoint, please use:
252 | 
253 | ```
254 | python eval.py visualization -n $modelName -m $modelType --save_dataset $PATH_TO_THE_OUTPUT_DATASET --size_dataset $SIZE_OF_THE_OUTPUT
255 | ```
256 | 
257 | ### SWD metric
258 | 
259 | Using the same kind of configuration file as above, just launch:
260 | 
261 | ```
262 | python eval.py laplacian_SWD -c $CONFIGURATION_FILE -n $modelName -m $modelType
263 | ```
264 | 
265 | Where $CONFIGURATION_FILE is the training configuration file called by train.py (see above): it must contain a "pathDB" field pointing to the path of the dataset's directory. For example, if you followed the instructions of the Quick Training section to launch a training session on celebaHQ, your configuration file will be config_celebaHQ.json.
266 | 
267 | You can add optional arguments:
268 | 
269 | - -s $SCALE: specify the scale at which the evaluation should be done (if not set, the highest one will be taken)
270 | - -i $ITER: specify the iteration to evaluate (if not set, the highest one will be taken)
271 | - --selfNoise: returns the typical noise of the SWD distance for each resolution
272 | 
273 | ### Inspirational generation
274 | 
275 | To make an inspirational generation, you first need to build a feature extractor:
276 | 
277 | ```
278 | python save_feature_extractor.py {vgg16, vgg19} $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR --layers 3 4 5
279 | ```
280 | 
281 | Then run your model:
282 | 
283 | ```
284 | python eval.py inspirational_generation -n $modelName -m $modelType --inputImage $pathTotheInputImage -f $PATH_TO_THE_OUTPUT_FEATURE_EXTRACTOR
285 | ```
286 | 
287 | ### I have generated my metrics. How can I plot them in visdom?
288 | 
289 | Just run:
290 | ```
291 | python eval.py metric_plot -n $modelName
292 | ```
293 | 
294 | ## LICENSE
295 | 
296 | This project is under the BSD-3 license.
297 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import importlib
3 | import argparse
4 | import sys
5 | 
6 | if __name__ == "__main__":
7 | 
8 |     parser = argparse.ArgumentParser(description='Testing script', add_help=False)
9 |     parser.add_argument('evaluation_name', type=str,
10 |                         help='Name of the evaluation method to launch.
To get \ 11 | the arguments specific to an evaluation method please \ 12 | use: eval.py evaluation_name -h') 13 | parser.add_argument('--no_vis', help='Print more data', 14 | action='store_true') 15 | parser.add_argument('--np_vis', help=' Replace visdom by a numpy based \ 16 | visualizer (SLURM)', 17 | action='store_true') 18 | parser.add_argument('-m', '--module', help="Module to evaluate, available\ 19 | modules: PGAN, PPGAN, DCGAN", 20 | type=str, dest="module") 21 | parser.add_argument('-n', '--name', help="Model's name", 22 | type=str, dest="name") 23 | parser.add_argument('-d', '--dir', help='Output directory', 24 | type=str, dest="dir", default="output_networks") 25 | parser.add_argument('-i', '--iter', help='Iteration to evaluate', 26 | type=int, dest="iter") 27 | parser.add_argument('-s', '--scale', help='Scale to evaluate', 28 | type=int, dest="scale") 29 | parser.add_argument('-c', '--config', help='Training configuration', 30 | type=str, dest="config") 31 | parser.add_argument('-v', '--partitionValue', help="Partition's value", 32 | type=str, dest="partition_value") 33 | parser.add_argument("-A", "--statsFile", dest="statsFile", 34 | type=str, help="Path to the statistics file") 35 | 36 | if len(sys.argv) > 1 and sys.argv[1] in ['-h', '--help']: 37 | parser.print_help() 38 | sys.exit() 39 | 40 | args, unknown = parser.parse_known_args() 41 | 42 | vis_module = None 43 | if args.np_vis: 44 | vis_module = importlib.import_module("visualization.np_visualizer") 45 | elif args.no_vis: 46 | print("Visualization disabled") 47 | else: 48 | vis_module = importlib.import_module("visualization.visualizer") 49 | 50 | module = importlib.import_module("models.eval." + args.evaluation_name) 51 | print("Running " + args.evaluation_name) 52 | 53 | parser.add_argument('-h', '--help', action='help') 54 | out = module.test(parser, visualisation=vis_module) 55 | 56 | if out is not None and not out: 57 | print("...FAIL") 58 | 59 | else: 60 | print("...OK") 61 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | ''' 3 | hubconf.py for pytorch_gan_zoo repo 4 | 5 | ## Users can get the diverse models of pytorch_gan_zoo by calling 6 | hub_model = hub.load( 7 | 'facebookresearch/pytorch_gan_zoo:master', 8 | $MODEL_NAME, # 9 | config = None, 10 | useGPU = True, 11 | pretrained=False) # (Not pretrained models online yet) 12 | 13 | Available model'names are [DCGAN, PGAN, StyleGAN]. 14 | The config option should be a dictionnary defining the training parameters of 15 | the model. See ??/pytorch_gan_zoo/models/trainer/standard_configurations to see 16 | all possible options 17 | 18 | ## How can I use my model ? 19 | 20 | ### Build a random vector 21 | 22 | inputRandom, randomLabels = model.buildNoiseData((int) $BATCH_SIZE) 23 | 24 | ### Feed a random vector to the model 25 | 26 | model.test(inputRandom, 27 | getAvG=True, 28 | toCPU=True) 29 | 30 | Arguments: 31 | - getAvG (bool) get the smoothed version of the generator (advised) 32 | - toCPU (bool) if set to False the output tensor will be a torch.cuda tensor 33 | 34 | ### Acces the generator 35 | 36 | model.netG() 37 | 38 | ### Acces the discriminator 39 | 40 | model.netD() 41 | 42 | ## Can I train my model ? 43 | 44 | Of course. You can set all training parameters in the constructor (losses to use, 45 | learning rate, number of iterations etc...) 
and use the optimizeParameters() 46 | method to make a training steps. 47 | 48 | Typically here will be a sample code: 49 | 50 | for input_real in dataset: 51 | 52 | allLosses = model.optimizeParameters(inputs_real) 53 | 54 | # Do something with the losses 55 | 56 | Please have a look at 57 | 58 | models/trainer/standard_configurations to see all the 59 | training parameters you can use. 60 | 61 | ''' 62 | 63 | import torch.utils.model_zoo as model_zoo 64 | 65 | # Optional list of dependencies required by the package 66 | dependencies = ['torch'] 67 | 68 | 69 | def PGAN(pretrained=False, *args, **kwargs): 70 | """ 71 | Progressive growing model 72 | pretrained (bool): load a pretrained model ? 73 | model_name (string): if pretrained, load one of the following models 74 | celebaHQ-256, celebaHQ-512, DTD, celeba, cifar10. Default is celebaHQ. 75 | """ 76 | from models.progressive_gan import ProgressiveGAN as PGAN 77 | if 'config' not in kwargs or kwargs['config'] is None: 78 | kwargs['config'] = {} 79 | 80 | model = PGAN(useGPU=kwargs.get('useGPU', True), 81 | storeAVG=True, 82 | **kwargs['config']) 83 | 84 | checkpoint = {"celebAHQ-256": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ_s6_i80000-6196db68.pth', 85 | "celebAHQ-512": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaHQ16_december_s7_i96000-9c72988c.pth', 86 | "DTD": 'https://dl.fbaipublicfiles.com/gan_zoo/PGAN/testDTD_s5_i96000-04efa39f.pth', 87 | "celeba": "https://dl.fbaipublicfiles.com/gan_zoo/PGAN/celebaCropped_s5_i83000-2b0acc76.pth"} 88 | if pretrained: 89 | if "model_name" in kwargs: 90 | if kwargs["model_name"] not in checkpoint.keys(): 91 | raise ValueError("model_name should be in " 92 | + str(checkpoint.keys())) 93 | else: 94 | print("Loading default model : celebaHQ-256") 95 | kwargs["model_name"] = "celebAHQ-256" 96 | state_dict = model_zoo.load_url(checkpoint[kwargs["model_name"]], 97 | map_location='cpu') 98 | model.load_state_dict(state_dict) 99 | return model 100 | 101 | 102 | def StyleGAN(pretrained=False, *args, **kwargs): 103 | """ 104 | NVIDIA StyleGAN 105 | pretrained (bool): load a 1024x1024 model trained on FlickrHQ 106 | """ 107 | from models.styleGAN import StyleGAN 108 | if 'config' not in kwargs or kwargs['config'] is None: 109 | kwargs['config'] = {} 110 | 111 | model = StyleGAN(useGPU=kwargs.get('useGPU', True), 112 | storeAVG=True, 113 | **kwargs['config']) 114 | 115 | checkpoint = 'https://dl.fbaipublicfiles.com/gan_zoo/StyleGAN/FFHQ_styleGAN-7cbdec00.pth' 116 | if pretrained: 117 | print("Loading default model : Flickr-HQ") 118 | state_dict = model_zoo.load_url(checkpoint, 119 | map_location='cpu') 120 | model.load_state_dict(state_dict) 121 | return model 122 | 123 | 124 | def DCGAN(pretrained=False, *args, **kwargs): 125 | """ 126 | DCGAN basic model 127 | pretrained (bool): load a pretrained model ? 
In this case load a model 128 | trained on fashionGen cloth 129 | """ 130 | from models.DCGAN import DCGAN 131 | if 'config' not in kwargs or kwargs['config'] is None: 132 | kwargs['config'] = {} 133 | 134 | model = DCGAN(useGPU=kwargs.get('useGPU', True), 135 | storeAVG=True, 136 | **kwargs['config']) 137 | 138 | checkpoint = 'https://dl.fbaipublicfiles.com/gan_zoo/DCGAN_fashionGen-1d67302.pth' 139 | if pretrained: 140 | state_dict = model_zoo.load_url(checkpoint, map_location='cpu') 141 | model.load_state_dict(state_dict) 142 | return model 143 | -------------------------------------------------------------------------------- /illustartionCelebaHQ.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytorch_GAN_zoo/b75dee40918caabb4fe7ec561522717bf096a8cb/illustartionCelebaHQ.jpg -------------------------------------------------------------------------------- /illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytorch_GAN_zoo/b75dee40918caabb4fe7ec561522717bf096a8cb/illustration.png -------------------------------------------------------------------------------- /models/DCGAN.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch.optim as optim 3 | 4 | from .base_GAN import BaseGAN 5 | from .utils.config import BaseConfig 6 | from .networks.DCGAN_nets import GNet, DNet 7 | 8 | 9 | class DCGAN(BaseGAN): 10 | r""" 11 | Implementation of DCGAN 12 | """ 13 | 14 | def __init__(self, 15 | dimLatentVector=64, 16 | dimG=64, 17 | dimD=64, 18 | depth=3, 19 | **kwargs): 20 | r""" 21 | Args: 22 | 23 | Specific Arguments: 24 | - latentVectorDim (int): dimension of the input latent vector 25 | - dimG (int): reference depth of a layer in the generator 26 | - dimD (int): reference depth of a layer in the discriminator 27 | - depth (int): number of convolution layer in the model 28 | - **kwargs: arguments of the BaseGAN class 29 | 30 | """ 31 | if 'config' not in vars(self): 32 | self.config = BaseConfig() 33 | 34 | self.config.dimG = dimG 35 | self.config.dimD = dimD 36 | self.config.depth = depth 37 | 38 | BaseGAN.__init__(self, dimLatentVector, **kwargs) 39 | 40 | def getNetG(self): 41 | 42 | gnet = GNet(self.config.latentVectorDim, 43 | self.config.dimOutput, 44 | self.config.dimG, 45 | depthModel=self.config.depth, 46 | generationActivation=self.lossCriterion.generationActivation) 47 | return gnet 48 | 49 | def getNetD(self): 50 | 51 | dnet = DNet(self.config.dimOutput, 52 | self.config.dimD, 53 | self.lossCriterion.sizeDecisionLayer 54 | + self.config.categoryVectorDim, 55 | depthModel=self.config.depth) 56 | return dnet 57 | 58 | def getOptimizerD(self): 59 | return optim.Adam(filter(lambda p: p.requires_grad, self.netD.parameters()), 60 | betas=[0.5, 0.999], lr=self.config.learningRate) 61 | 62 | def getOptimizerG(self): 63 | return optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()), 64 | betas=[0.5, 0.999], lr=self.config.learningRate) 65 | 66 | def getSize(self): 67 | size = 2**(self.config.depth + 3) 68 | return (size, size) 69 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # models 2 | 3 | ## Architecture 4 | 5 | **networks/**: networks' 
architectures 6 | 7 | **utils/**: various utilities 8 | 9 | **eval/**: all evaluation scripts are defined here 10 | 11 | **datasets/**: specific dataset models 12 | 13 | **loss_criterions/**: gan loss criterions. Whether they are "basic" (MSE...) or 14 | more model specific (AC-GAN) 15 | 16 | **merics/**: metrics used to estimate the quality of the trained models 17 | 18 | **trainer/**: wrappers used to handle the GAN's training. Things like: number of iterations, logging, visualization... will be handled here. 19 | 20 | --- 21 | **base_gan.py** : the reference structure for GANs. All GANs must inherit from this class. 22 | 23 | This mother BaseGANs handles: 24 | * the GAN training sequence as described in Generative Adversarial Nets 25 | * GPU and multi-GPU 26 | * saving and loading into a file 27 | * gradient penalty 28 | * nature of the loss 29 | * conditional generation (ACGAN) 30 | 31 | What should be handled in a child class 32 | * nature of the G and D networks 33 | * kind of optimizers used (will be moved in BaseGAN) 34 | * and other functions model specific 35 | 36 | **progressive_gan.py**: an implementation of [NVIDIA's progressive gan](http://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf). This class inherits from the BaseGAN abstract class. 37 | 38 | **DCGAN.py**: an implementation of [DCGAN]( https://arxiv.org/pdf/1511.06434.pdf) a very simple and basic GAN structure. This class inherits from the BaseGAN abstract class. 39 | 40 | Among other things, it gives the user the possibility to add new layers to the model during the training. 41 | 42 | **trainer/std_p_gan_config.py**: standard configuration for a ProgressiveGAN training. 43 | **trainer/std_dcgan_config.py**: standard configuration for a DCAGN training. 44 | 45 | All possible configuration parameters for ProgressiveGANTrainer are described here. 46 | -------------------------------------------------------------------------------- /models/UTs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/UTs/test_ac_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from math import exp, log 3 | import torch 4 | from ..loss_criterions.ac_criterion import ACGanCriterion 5 | 6 | 7 | def test(): 8 | 9 | # test 10 | attribKeysList = {"Gender": {"order": 0, 11 | "values": ['M', 'W', 'Racoon'] 12 | }, 13 | "Console": {"order": 1, 14 | "values": ["None", "PC", "PS", "XBOX"] 15 | } 16 | } 17 | allowMultiple = ["Console"] 18 | test = ACGanCriterion(attribKeysList, allowMultiple=allowMultiple) 19 | tar, inLat = test.buildRandomCriterionTensor(2) 20 | 21 | if tar.size()[0] != 2: 22 | print("Invalid batch size for the target") 23 | return False 24 | if inLat.size()[0] != 2: 25 | print("Invalid batch size for the input latent vector") 26 | return False 27 | if tar.size()[1] != 5: 28 | print("Invalid feature size for the target") 29 | return False 30 | if inLat.size()[1] != 7: 31 | print("Invalid feature size for the input latent vector") 32 | return False 33 | 34 | testTarget = torch.tensor([[0., 1., 0., 0., 1.], 35 | [2., 0., 1., 0., 0.]]) 36 | testTensor = torch.tensor([[0.2, 0.1, 0.7, 0.5, 0.8, 0.9, 0.01], 37 | [0.2, 0.1, 0.7, 0.5, 0.8, 0.9, 0.01]]) 38 | 39 | a = -0.2 + log(exp(0.2) + exp(0.1) + exp(0.7)) 40 | b = -0.7 + log(exp(0.2) + exp(0.1) + exp(0.7)) 41 | c = log(1 + exp(-0.5)) - log(exp(-0.8)/(1+exp(-0.8))) \ 42 | - log(exp(-0.9)/(1+exp(-0.9))) + log(1 + exp(-0.01)) 43 | d = - log(exp(-0.5)/(1+exp(-0.5))) + log(1 + exp(-0.8)) \ 44 | - log(exp(-0.9)/(1+exp(-0.9))) - log(exp(-0.01)/(1 + exp(-0.01))) 45 | 46 | expectedResult = (a+b+(c+d)/4)/2 47 | result = test.getLoss(testTensor, testTarget) 48 | 49 | assert abs(result.item() - expectedResult) <= 0.001 50 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytorch_GAN_zoo/b75dee40918caabb4fe7ec561522717bf096a8cb/models/__init__.py -------------------------------------------------------------------------------- /models/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/datasets/attrib_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import json 4 | 5 | from copy import deepcopy 6 | 7 | import torchvision.transforms as Transforms 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | from ..utils.image_transform import pil_loader 12 | 13 | 14 | from .utils.db_stats import buildKeyOrder 15 | 16 | 17 | class AttribDataset(Dataset): 18 | r""" 19 | A dataset class adapted to the specificites of our YSL fashion dataset. 20 | It loads both the images and an attribute dictionnary describing each image's 21 | attribute. 
22 | """ 23 | 24 | def __init__(self, 25 | pathdb, 26 | attribDictPath=None, 27 | specificAttrib=None, 28 | transform=None, 29 | mimicImageFolder=False, 30 | ignoreAttribs=False, 31 | getEqualizer=False, 32 | pathMask=None): 33 | r""" 34 | Args: 35 | 36 | - root (string): path to the directory containing the images 37 | - attribDictPath (string): path to a json file containing the images' 38 | specific attributes 39 | - specificAttrib (list of string): if not None, specify which attributes 40 | be selected 41 | - transform (torchvision.transforms): transformation to apply to the 42 | loaded images. 43 | - mimicImageFolder (bool): set to True if the dataset is stored in the 44 | torchvision.datasets.ImageFolder format 45 | - ignoreAttribs (bool): set to True if you just want to use the attrib 46 | dict as a filter on images' name 47 | """ 48 | 49 | self.totAttribSize = 0 50 | self.hasAttrib = attribDictPath is not None or mimicImageFolder 51 | self.pathdb = pathdb 52 | self.transform = transform 53 | self.shiftAttrib = None 54 | self.stats = None 55 | self.pathMask = None 56 | 57 | if attribDictPath: 58 | if ignoreAttribs: 59 | self.attribDict = None 60 | 61 | with open(attribDictPath, 'rb') as file: 62 | tmpDict = json.load(file) 63 | self.listImg = [imgName for imgName in os.listdir(pathdb) 64 | if (os.path.splitext(imgName)[1] in [".jpg", 65 | ".png", ".npy"] and imgName in tmpDict)] 66 | else: 67 | self.loadAttribDict(attribDictPath, pathdb, specificAttrib) 68 | 69 | elif mimicImageFolder: 70 | self.loadImageFolder(pathdb) 71 | else: 72 | self.attribDict = None 73 | self.listImg = [imgName for imgName in os.listdir(pathdb) 74 | if os.path.splitext(imgName)[1] in [".jpg", ".png", 75 | ".npy"]] 76 | 77 | if pathMask is not None: 78 | print("Path mask found " + pathMask) 79 | self.pathMask = pathMask 80 | self.listImg = [imgName for imgName in self.listImg 81 | if os.path.isfile(os.path.join(pathMask, 82 | os.path.splitext(imgName)[0] + "_mask.jpg"))] 83 | 84 | if len(self.listImg) == 0: 85 | raise AttributeError("Empty dataset") 86 | 87 | self.buildStatsOnDict() 88 | 89 | print("%d images found" % len(self)) 90 | 91 | def __len__(self): 92 | return len(self.listImg) 93 | 94 | def hasMask(self): 95 | return self.pathMask is not None 96 | 97 | def buildStatsOnDict(self): 98 | 99 | if self.attribDict is None: 100 | return 101 | 102 | self.stats = {} 103 | for item in self.attribDict: 104 | 105 | for category, value in self.attribDict[item].items(): 106 | 107 | if category not in self.stats: 108 | self.stats[category] = {} 109 | 110 | if value not in self.stats[category]: 111 | self.stats[category][value] = 0 112 | 113 | self.stats[category][value] += 1 114 | 115 | def loadAttribDict(self, 116 | dictPath, 117 | dbDir, 118 | specificAttrib): 119 | r""" 120 | Load a dictionnary describing the attributes of each image in the 121 | dataset and save the list of all the possible attributes and their 122 | acceptable values. 123 | 124 | Args: 125 | 126 | - dictPath (string): path to a json file describing the dictionnary. 
127 | If None, no attribute will be loaded 128 | - dbDir (string): path to the directory containing the dataset 129 | - specificAttrib (list of string): if not None, specify which 130 | attributes should be selected 131 | """ 132 | 133 | self.attribDict = {} 134 | attribList = {} 135 | 136 | with open(dictPath, 'rb') as file: 137 | tmpDict = json.load(file) 138 | 139 | for fileName, attrib in tmpDict.items(): 140 | 141 | if not os.path.isfile(os.path.join(dbDir, fileName)): 142 | continue 143 | 144 | if specificAttrib is None: 145 | self.attribDict[fileName] = deepcopy(attrib) 146 | else: 147 | self.attribDict[fileName] = { 148 | k: attrib[k] for k in specificAttrib} 149 | 150 | for attribName, attribVal in self.attribDict[fileName].items(): 151 | if attribName not in attribList: 152 | attribList[attribName] = set() 153 | 154 | attribList[attribName].add(attribVal) 155 | 156 | # Filter the attrib list 157 | self.totAttribSize = 0 158 | 159 | self.shiftAttrib = {} 160 | self.shiftAttribVal = {} 161 | 162 | for attribName, attribVals in attribList.items(): 163 | 164 | if len(attribVals) == 1: 165 | continue 166 | 167 | self.shiftAttrib[attribName] = self.totAttribSize 168 | self.totAttribSize += 1 169 | 170 | self.shiftAttribVal[attribName] = { 171 | name: c for c, name in enumerate(attribVals)} 172 | 173 | # Img list 174 | self.listImg = list(self.attribDict.keys()) 175 | 176 | def loadImageFolder(self, pathdb): 177 | r""" 178 | Load a dataset saved in the torchvision.datasets.ImageFolder format. 179 | 180 | Arguments: 181 | - pathdb: path to the directory containing the dataset 182 | """ 183 | 184 | listDir = [dirName for dirName in os.listdir(pathdb) 185 | if os.path.isdir(os.path.join(pathdb, dirName))] 186 | 187 | imgExt = [".jpg", ".png", ".JPEG"] 188 | 189 | self.attribDict = {} 190 | 191 | self.totAttribSize = 1 192 | self.shiftAttrib = {"Main": 0} 193 | self.shiftAttribVal = {"Main": {}} 194 | 195 | for index, dirName in enumerate(listDir): 196 | 197 | dirPath = os.path.join(pathdb, dirName) 198 | self.shiftAttribVal["Main"][dirName] = index 199 | 200 | for img in os.listdir(dirPath): 201 | 202 | if os.path.splitext(img)[1] in imgExt: 203 | fullName = os.path.join(dirName, img) 204 | self.attribDict[fullName] = {"Main": dirName} 205 | 206 | # Img list 207 | self.listImg = list(self.attribDict.keys()) 208 | 209 | def __getitem__(self, idx): 210 | 211 | imgName = self.listImg[idx] 212 | imgPath = os.path.join(self.pathdb, imgName) 213 | img = pil_loader(imgPath) 214 | 215 | if self.transform is not None: 216 | img = self.transform(img) 217 | 218 | # Build the attribute tensor 219 | attr = [0 for i in range(self.totAttribSize)] 220 | 221 | if self.hasAttrib: 222 | attribVals = self.attribDict[imgName] 223 | for key, val in attribVals.items(): 224 | baseShift = self.shiftAttrib[key] 225 | attr[baseShift] = self.shiftAttribVal[key][val] 226 | else: 227 | attr = [0] 228 | 229 | if self.pathMask is not None: 230 | mask_path = os.path.join( 231 | self.pathMask, os.path.splitext(imgName)[0] + "_mask.jpg") 232 | mask = pil_loader(mask_path) 233 | mask = Transforms.Grayscale(1)(mask) 234 | mask = self.transform(mask) 235 | 236 | return img, torch.tensor(attr, dtype=torch.long), mask 237 | 238 | return img, torch.tensor(attr, dtype=torch.long) 239 | 240 | def getName(self, idx): 241 | 242 | return self.listImg[idx] 243 | 244 | def getTextDescriptor(self, idx): 245 | r""" 246 | Get the text descriptor of the idx th image in the dataset 247 | """ 248 | imgName = self.listImg[idx] 249 | 250 | if 
not self.hasAttrib: 251 | return {} 252 | 253 | return self.attribDict[imgName] 254 | 255 | def getKeyOrders(self, equlizationWeights=False): 256 | r""" 257 | If the dataset is labelled, give the order in which the attributes are 258 | given 259 | 260 | Returns: 261 | 262 | A dictionary output[key] = { "order" : int , "values" : list of 263 | string} 264 | """ 265 | 266 | if self.attribDict is None: 267 | return None 268 | 269 | if equlizationWeights: 270 | 271 | if self.stats is None: 272 | raise ValueError("The weight equalization can only be \ 273 | performed on labelled datasets") 274 | 275 | return buildKeyOrder(self.shiftAttrib, 276 | self.shiftAttribVal, 277 | stats=self.stats) 278 | return buildKeyOrder(self.shiftAttrib, 279 | self.shiftAttribVal, 280 | stats=None) 281 | -------------------------------------------------------------------------------- /models/datasets/hd5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import json 3 | import torch 4 | import h5py 5 | 6 | import copy 7 | 8 | from .utils.db_stats import buildKeyOrder 9 | 10 | 11 | class H5Dataset(torch.utils.data.Dataset): 12 | 13 | def __init__(self, 14 | file_path, 15 | partition_path=None, 16 | partition_value=None, 17 | transform=None, 18 | specificAttrib=None, 19 | stats_file=None, 20 | pathDBMask=None): 21 | super(H5Dataset, self).__init__() 22 | 23 | self.path = file_path 24 | self.partition_path = partition_path 25 | self.partition_value = partition_value 26 | 27 | if self.partition_value is None: 28 | self.partition_path = None 29 | print("No partition value found, ignoring the partition file") 30 | 31 | self.h5_file = None 32 | self.partition_file = None 33 | 34 | self.transform = transform 35 | 36 | self.attribKeys = copy.deepcopy(specificAttrib) 37 | self.statsData = None 38 | self.totAttribSize = 0 39 | 40 | if stats_file is not None: 41 | with open(stats_file, 'rb') as file: 42 | self.statsData = json.load(file) 43 | 44 | if self.partition_value is None and "GLOBAL" in self.statsData: 45 | self.statsData = self.statsData["GLOBAL"] 46 | 47 | elif self.partition_value in self.statsData: 48 | self.statsData = self.statsData[self.partition_value] 49 | 50 | self.buildAttribShift() 51 | 52 | self.pathDBMask = pathDBMask 53 | self.maskFile = None 54 | 55 | def __getitem__(self, index): 56 | 57 | if self.h5_file is None: 58 | self.h5_file = h5py.File(self.path, 'r') 59 | 60 | if self.partition_path is not None: 61 | self.partition_file = h5py.File(self.partition_path, 'r') 62 | 63 | if self.partition_file is not None: 64 | index = self.partition_file[self.partition_value][index] 65 | 66 | img = self.h5_file['input_image'][index] 67 | 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | if self.statsData is not None: 72 | 73 | attr = [None for x in range(self.totAttribSize)] 74 | 75 | for key in self.attribKeys: 76 | 77 | label = str(self.h5_file[key][index][0]) 78 | shift = self.attribShift[key] 79 | attr[shift] = self.attribShiftVal[key][label] 80 | 81 | else: 82 | 83 | attr = [0] 84 | 85 | if self.pathDBMask is not None: 86 | 87 | if self.maskFile is None: 88 | self.maskFile = h5py.File(self.pathDBMask, 'r') 89 | 90 | mask = self.maskFile["mask"][index] 91 | mask = self.transform(mask) 92 | 93 | img = img * (mask + 1.0) * 0.5 + (1 - mask) * 0.5 94 | 95 | return img, torch.tensor(attr), mask 96 | 97 | return img, torch.tensor(attr) 98 | 99 | def __len__(self): 100 | if 
self.partition_path is None: 101 | with h5py.File(self.path, 'r') as db: 102 | lens = len(db['input_image']) 103 | else: 104 | with h5py.File(self.partition_path, 'r') as db: 105 | lens = len(db[self.partition_value]) 106 | return lens 107 | 108 | def getName(self, index): 109 | 110 | if self.partition_path is not None: 111 | if self.partition_file is None: 112 | self.partition_file = h5py.File(self.partition_path, 'r') 113 | 114 | return self.partition_file[self.partition_value][index] 115 | 116 | return index 117 | 118 | def buildAttribShift(self): 119 | 120 | self.attribShift = None 121 | self.attribShiftVal = None 122 | 123 | if self.statsData is None: 124 | return 125 | 126 | if self.attribKeys is None: 127 | self.attribKeys = [x for x in self.statsData.keys() if 128 | x != "totalSize"] 129 | 130 | self.attribShift = {} 131 | self.attribShiftVal = {} 132 | 133 | self.totAttribSize = 0 134 | 135 | for key in self.attribKeys: 136 | 137 | self.attribShift[key] = self.totAttribSize 138 | self.attribShiftVal[key] = { 139 | name: c 140 | for c, name in enumerate(list(self.statsData[key].keys()))} 141 | self.totAttribSize += 1 142 | 143 | def getKeyOrders(self, equlizationWeights=False): 144 | 145 | if equlizationWeights: 146 | raise ValueError("Equalization weight not implemented yet") 147 | 148 | return buildKeyOrder(self.attribShift, self.attribShiftVal) 149 | -------------------------------------------------------------------------------- /models/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/datasets/utils/db_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | def getClassStats(inputDict, className): 3 | 4 | outStats = {} 5 | 6 | for item in inputDict: 7 | 8 | val = item[className] 9 | if val not in outStats: 10 | outStats[val] = 0 11 | 12 | outStats[val] += 1 13 | 14 | return outStats 15 | 16 | 17 | def buildDictStats(inputDict, classList): 18 | 19 | locStats = {"total": len(inputDict)} 20 | 21 | for cat in classList: 22 | 23 | locStats[cat] = getClassStats(inputDict, cat) 24 | 25 | return locStats 26 | 27 | 28 | def buildKeyOrder(shiftAttrib, 29 | shiftAttribVal, 30 | stats=None): 31 | r""" 32 | If the dataset is labelled, give the order in which the attributes are given 33 | 34 | Args: 35 | 36 | - shiftAttrib (dict): order of each category in the category vector 37 | - shiftAttribVal (dict): list (ordered) of each possible labels for each 38 | category of the category vector 39 | - stats (dict): if not None, number of representant of each label for 40 | each category. Will update the output dictionary with a 41 | "weights" index telling how each labels should be 42 | balanced in the classification loss. 
43 | 44 | Returns: 45 | 46 | A dictionary output[key] = { "order" : int , "values" : list of string} 47 | """ 48 | 49 | MAX_VAL_EQUALIZATION = 10 50 | 51 | output = {} 52 | for key in shiftAttrib: 53 | output[key] = {} 54 | output[key]["order"] = shiftAttrib[key] 55 | output[key]["values"] = [None for i in range(len(shiftAttribVal[key]))] 56 | for cat, shift in shiftAttribVal[key].items(): 57 | output[key]["values"][shift] = cat 58 | 59 | if stats is not None: 60 | for key in output: 61 | 62 | n = sum([x for key, x in stats[key].items()]) 63 | 64 | output[key]["weights"] = {} 65 | for item, value in stats[key].items(): 66 | output[key]["weights"][item] = min( 67 | MAX_VAL_EQUALIZATION, n / float(value + 1.0)) 68 | 69 | return output 70 | -------------------------------------------------------------------------------- /models/eval/build_nn_db.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import json 4 | 5 | from ..metrics.nn_score import buildFeatureExtractor, saveFeatures 6 | from ..utils.utils import getVal, toStrKey 7 | from ..networks.constant_net import FeatureTransform 8 | 9 | 10 | def test(parser, visualisation=None): 11 | 12 | parser.add_argument('--size', help="Image size", 13 | type=int, dest="size", default=224) 14 | parser.add_argument('-f', '--featureExtractor', help="Path of the feature \ 15 | extractor", 16 | type=str, dest="featureExtractor") 17 | 18 | kwargs = vars(parser.parse_known_args()[0]) 19 | 20 | # Parameters 21 | configPath = getVal(kwargs, "config", None) 22 | if configPath is None: 23 | raise ValueError("You need to input a configuratrion file") 24 | 25 | pathFeatureExtractor = getVal(kwargs, "featureExtractor", None) 26 | if pathFeatureExtractor is None: 27 | raise ValueError("You need to input a feature extractor") 28 | 29 | with open(configPath, 'rb') as file: 30 | wholeConfig = json.load(file) 31 | 32 | # Load the dataset 33 | pathDB = wholeConfig["pathDB"] 34 | pathAttrib = wholeConfig.get("pathAttrib", None) 35 | pathMask = wholeConfig.get("pathDBMask", None) 36 | pathPartition = wholeConfig.get("pathPartition", None) 37 | partitionValue = wholeConfig.get("partitionValue", None) 38 | 39 | partitionValue = getVal(kwargs, "partition_value", None) 40 | 41 | model, mean, std = buildFeatureExtractor(pathFeatureExtractor) 42 | imgTransform = FeatureTransform(mean, std, size=kwargs['size']) 43 | 44 | print("Building the model's feature data") 45 | 46 | pathOutFeatures = os.path.splitext(pathFeatureExtractor)[0] + "_" + \ 47 | os.path.splitext(os.path.basename(pathDB))[0] + \ 48 | "_" + str(kwargs['size']) + \ 49 | toStrKey(partitionValue) + "_features.pkl" 50 | 51 | print("Saving the features at : " + pathOutFeatures) 52 | 53 | saveFeatures(model, imgTransform, pathDB, pathMask, pathAttrib, 54 | pathOutFeatures, pathPartition, partitionValue) 55 | 56 | print("All done") 57 | -------------------------------------------------------------------------------- /models/eval/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | import torchvision 5 | 6 | from ..metrics.inception_score import InceptionScore 7 | from ..utils.utils import printProgressBar 8 | from ..utils.utils import getVal, loadmodule, getLastCheckPoint, \ 9 | parse_state_name, getNameAndPackage, saveScore 10 | from ..networks.constant_net import FeatureTransform 11 | 12 | 13 | def test(parser, visualisation=None): 14 | 15 | kwargs = vars(parser.parse_args()) 16 | 17 | # Are all parameters available ? 18 | name = getVal(kwargs, "name", None) 19 | if name is None and not kwargs['selfNoise']: 20 | raise ValueError("You need to input a name") 21 | 22 | module = getVal(kwargs, "module", None) 23 | if module is None: 24 | raise ValueError("You need to input a module") 25 | 26 | # Loading the model 27 | scale = getVal(kwargs, "scale", None) 28 | 29 | if name is not None: 30 | iter = getVal(kwargs, "iter", None) 31 | 32 | checkPointDir = os.path.join(kwargs["dir"], name) 33 | checkpointData = getLastCheckPoint( 34 | checkPointDir, name, scale=scale, iter=iter) 35 | 36 | if checkpointData is None: 37 | print(scale, iter) 38 | if scale is not None or iter is not None: 39 | raise FileNotFoundError("Not checkpoint found for model " 40 | + name + " at directory " + dir + 41 | " for scale " + str(scale) + 42 | " at iteration " + str(iter)) 43 | raise FileNotFoundError( 44 | "Not checkpoint found for model " + name + " at directory " 45 | + dir) 46 | 47 | modelConfig, pathModel, _ = checkpointData 48 | with open(modelConfig, 'rb') as file: 49 | configData = json.load(file) 50 | 51 | modelPackage, modelName = getNameAndPackage(module) 52 | modelType = loadmodule(modelPackage, modelName) 53 | 54 | model = modelType(useGPU=True, 55 | storeAVG=True, 56 | **configData) 57 | 58 | if scale is None or iter is None: 59 | _, scale, iter = parse_state_name(pathModel) 60 | 61 | print("Checkpoint found at scale %d, iter %d" % (scale, iter)) 62 | model.load(pathModel) 63 | 64 | elif scale is None: 65 | raise AttributeError("Please provide a scale to compute the noise of \ 66 | the dataset") 67 | 68 | # Building the score instance 69 | classifier = torchvision.models.inception_v3(pretrained=True).cuda() 70 | scoreMaker = InceptionScore(classifier) 71 | 72 | batchSize = 16 73 | nBatch = 1000 74 | 75 | refMean = [2*p - 1 for p in[0.485, 0.456, 0.406]] 76 | refSTD = [2*p for p in [0.229, 0.224, 0.225]] 77 | imgTransform = FeatureTransform(mean=refMean, 78 | std=refSTD, 79 | size=299).cuda() 80 | 81 | print("Computing the inception score...") 82 | for index in range(nBatch): 83 | 84 | inputFake = model.test(model.buildNoiseData(batchSize)[0], 85 | toCPU=False, getAvG=True) 86 | 87 | scoreMaker.updateWithMiniBatch(imgTransform(inputFake)) 88 | printProgressBar(index, nBatch) 89 | 90 | printProgressBar(nBatch, nBatch) 91 | print("Merging the results, please wait it can take some time...") 92 | score = scoreMaker.getScore() 93 | 94 | # Now printing the results 95 | print(score) 96 | 97 | # Saving the results 98 | if name is not None: 99 | 100 | outPath = os.path.join(checkPointDir, name + "_swd.json") 101 | saveScore(outPath, score, 102 | scale, iter) 103 | -------------------------------------------------------------------------------- /models/eval/laplacian_SWD.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | 5 | import torch 6 | 7 | from ..metrics.laplacian_swd import LaplacianSWDMetric 8 | from ..utils.utils import printProgressBar 9 | from ..datasets.attrib_dataset import AttribDataset 10 | from ..datasets.hd5 import H5Dataset 11 | from ..utils.utils import getVal, loadmodule, getLastCheckPoint, \ 12 | parse_state_name, getNameAndPackage, saveScore 13 | from ..utils.image_transform import standardTransform 14 | 15 | 16 | def test(parser, visualisation=None): 17 | 18 | parser.add_argument('--selfNoise', action='store_true', 19 | help="Compute the inner noise of the dataset") 20 | kwargs = vars(parser.parse_args()) 21 | 22 | # Are all parameters available ? 23 | name = getVal(kwargs, "name", None) 24 | if name is None and not kwargs['selfNoise']: 25 | raise ValueError("You need to input a name") 26 | 27 | module = getVal(kwargs, "module", None) 28 | if module is None: 29 | raise ValueError("You need to input a module") 30 | 31 | trainingConfig = getVal(kwargs, "config", None) 32 | if trainingConfig is None: 33 | raise ValueError("You need to input a configuration file") 34 | 35 | # Loading the model 36 | scale = getVal(kwargs, "scale", None) 37 | 38 | if name is not None: 39 | iter = getVal(kwargs, "iter", None) 40 | 41 | checkPointDir = os.path.join(kwargs["dir"], name) 42 | checkpointData = getLastCheckPoint( 43 | checkPointDir, name, scale=scale, iter=iter) 44 | 45 | if checkpointData is None: 46 | print(scale, iter) 47 | if scale is not None or iter is not None: 48 | raise FileNotFoundError("Not checkpoint found for model " 49 | + name + " at directory " + dir + 50 | " for scale " + str(scale) + 51 | " at iteration " + str(iter)) 52 | raise FileNotFoundError( 53 | "Not checkpoint found for model " + name + " at directory " 54 | + dir) 55 | 56 | modelConfig, pathModel, _ = checkpointData 57 | with open(modelConfig, 'rb') as file: 58 | configData = json.load(file) 59 | 60 | modelPackage, modelName = getNameAndPackage(module) 61 | modelType = loadmodule(modelPackage, modelName) 62 | 63 | model = modelType(useGPU=True, 64 | storeAVG=True, 65 | **configData) 66 | 67 | if scale is None or iter is None: 68 | _, scale, iter = parse_state_name(pathModel) 69 | 70 | print("Checkpoint found at scale %d, iter %d" % (scale, iter)) 71 | model.load(pathModel) 72 | 73 | elif scale is None: 74 | raise AttributeError("Please provide a scale to compute the noise of \ 75 | the dataset") 76 | 77 | # Building the score instance 78 | depthPyramid = min(scale, 4) 79 | SWDMetric = LaplacianSWDMetric(7, 128, depthPyramid) 80 | 81 | # Building the dataset 82 | with open(trainingConfig, 'rb') as file: 83 | wholeConfig = json.load(file) 84 | 85 | pathPartition = wholeConfig.get("pathPartition", None) 86 | partitionValue = wholeConfig.get("partitionValue", None) 87 | attribDict = wholeConfig.get('pathAttrib', None) 88 | partitionValue = getVal(kwargs, "partition_value", None) 89 | 90 | # Training dataset properties 91 | pathDB = wholeConfig["pathDB"] 92 | size = 2**(2 + scale) 93 | db_transform = standardTransform((size, size)) 94 | 95 | if os.path.splitext(pathDB)[1] == '.h5': 96 | dataset = H5Dataset(pathDB, 97 | transform=db_transform, 98 | partition_path=pathPartition, 99 | partition_value=partitionValue) 100 | else: 101 | dataset = AttribDataset(pathdb=pathDB, 102 | transform=db_transform, 103 | attribDictPath=attribDict) 104 | 105 | batchSize = 16 106 | dbLoader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, 107 | num_workers=2, shuffle=True) 108 
| 109 | # Metric parameters 110 | nImagesSampled = min(len(dataset), 16000) 111 | maxBatch = nImagesSampled / batchSize 112 | 113 | if kwargs['selfNoise']: 114 | 115 | print("Computing the inner noise of the dataset...") 116 | loader2 = torch.utils.data.DataLoader(dataset, batch_size=batchSize, 117 | num_workers=2, shuffle=True) 118 | 119 | for item, data in enumerate(zip(dbLoader, loader2)): 120 | 121 | if item > maxBatch: 122 | break 123 | 124 | real, fake = data 125 | SWDMetric.updateWithMiniBatch(real[0], fake[0]) 126 | printProgressBar(item, maxBatch) 127 | 128 | else: 129 | 130 | print("Generating the fake dataset...") 131 | for item, data in enumerate(dbLoader, 0): 132 | 133 | if item > maxBatch: 134 | break 135 | 136 | inputsReal, _ = data 137 | inputFake = model.test(model.buildNoiseData( 138 | inputsReal.size(0))[0], toCPU=False, getAvG=True) 139 | 140 | SWDMetric.updateWithMiniBatch(inputFake, inputsReal) 141 | printProgressBar(item, maxBatch) 142 | 143 | printProgressBar(maxBatch, maxBatch) 144 | print("Merging the results, please wait it can take some time...") 145 | score = SWDMetric.getScore() 146 | 147 | # Saving the results 148 | if name is not None: 149 | 150 | outPath = os.path.join(checkPointDir, name + "_swd.json") 151 | if kwargs['selfNoise']: 152 | saveScore(outPath, score, 153 | scale, "inner noise") 154 | else: 155 | saveScore(outPath, score, 156 | scale, iter) 157 | 158 | # Now printing the results 159 | print("") 160 | 161 | resolution = ['resolution '] + \ 162 | [str(int(size / (2**factor))) for factor in range(depthPyramid)] 163 | resolution[-1] += ' (background)' 164 | 165 | strScores = ['score'] + ["{:10.6f}".format(s) for s in score] 166 | 167 | formatCommand = ' '.join(['{:>16}' for x in range(depthPyramid + 1)]) 168 | 169 | print(formatCommand.format(*resolution)) 170 | print(formatCommand.format(*strScores)) 171 | -------------------------------------------------------------------------------- /models/eval/metric_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | 5 | from ..utils.utils import getVal 6 | 7 | 8 | def test(parser, visualisation=None): 9 | 10 | # Parameters 11 | kwargs = vars(parser.parse_args()) 12 | 13 | name = getVal(kwargs, "name", None) 14 | if name is None: 15 | raise ValueError("You need to input a name") 16 | 17 | if visualisation is None: 18 | raise ValueError("A visualizer is mandatory for this evaluation") 19 | 20 | checkPointDir = os.path.join(kwargs["dir"], name) 21 | 22 | suffixes = {"SWD": "_swd", "NN": "_nn_metric", 23 | "INCEPTION": "_inception_metric"} 24 | 25 | for key, value in suffixes.items(): 26 | 27 | pathFile = os.path.join(checkPointDir, name + value + ".json") 28 | 29 | if not os.path.isfile(pathFile): 30 | continue 31 | 32 | with open(pathFile, 'rb') as file: 33 | data = json.load(file) 34 | 35 | for scale in data: 36 | 37 | itemType = next(iter(data[scale].values())) 38 | 39 | if isinstance(itemType, dict): 40 | 41 | attribs = list(itemType.keys()) 42 | withAttribs = True 43 | nData = len(next(iter(itemType.values()))) 44 | 45 | else: 46 | attribs = [''] 47 | withAttribs = False 48 | nData = len(next(iter(data[scale].values()))) 49 | 50 | for attrib in attribs: 51 | 52 | env_name = name + "_" + key + "_scale_" + \ 53 | scale + "_" + os.path.basename(attrib) 54 | visualisation.delete_env(env_name) 55 | 56 | locIter = [] 57 | outYData = [[] for x in range(nData)] 58 | 59 | iterations = [int(x) for x in data[scale].keys() if x.isdigit()] 60 | 61 | iterations.sort() 62 | 63 | for iteration in iterations: 64 | 65 | locIter.append(iteration) 66 | 67 | if not withAttribs: 68 | 69 | for i in range(nData): 70 | 71 | outYData[i].append(data[scale][str(iteration)][i]) 72 | 73 | else: 74 | 75 | if attrib not in data[scale][str(iteration)]: 76 | continue 77 | 78 | for i in range(nData): 79 | outYData[i].append( 80 | data[scale][str(iteration)][attrib][i]) 81 | 82 | for i in range(nData): 83 | plotName = key + " " + str(i) 84 | visualisation.publishLinePlot([(plotName, outYData[i])], locIter, 85 | name=plotName, 86 | env=env_name) 87 | 88 | print(scale, plotName, sum(outYData[i]) / len(outYData[i])) 89 | -------------------------------------------------------------------------------- /models/eval/nn_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from ..networks.constant_net import FeatureTransform 3 | from ..metrics.nn_score import buildFeatureExtractor 4 | from ..utils.utils import getVal, loadmodule, getLastCheckPoint, \ 5 | parse_state_name, getNameAndPackage, toStrKey, saveScore 6 | from ..gan_visualizer import GANVisualizer 7 | import torch.nn as nn 8 | import torch 9 | import os 10 | import json 11 | 12 | import pickle 13 | 14 | import sys 15 | sys.path.append("..") 16 | 17 | 18 | def getModelName(pathConfig): 19 | 20 | pathConfig = os.path.basename(pathConfig) 21 | 22 | if pathConfig[-18:] != '_train_config.json': 23 | raise ValueError("Invalid configuration name") 24 | 25 | return pathConfig[:-18] 26 | 27 | 28 | def update_parser(parser): 29 | 30 | parser.add_argument('--showNN', action='store_true') 31 | parser.add_argument('--size', help="Image size", 32 | type=int, dest="size", default=224) 33 | parser.add_argument('-f', '--featureExtractor', help="Path to the feature \ 34 | extractor", 35 | type=str, dest="featureExtractor") 36 | 37 | 38 | def test(parser, visualisation=None): 39 | 40 | update_parser(parser) 41 | 42 | kwargs = vars(parser.parse_args()) 43 | # Parameters 44 | name = getVal(kwargs, "name", None) 45 | if name is None: 46 | raise ValueError("You need to input a name") 47 | 48 | module = getVal(kwargs, "module", None) 49 | if module is None: 50 | raise ValueError("You need to input a module") 51 | 52 | trainingConfig = getVal(kwargs, "config", None) 53 | if trainingConfig is None: 54 | raise ValueError("You need to input a configuration file") 55 | 56 | pathNNFeatureExtractor = getVal(kwargs, "featureExtractor", None) 57 | if pathNNFeatureExtractor is None: 58 | raise ValueError("You need to give a feature extractor") 59 | 60 | # Mandatory fields 61 | checkPointDir = os.path.join(kwargs["dir"], name) 62 | scale = getVal(kwargs, "scale", None) 63 | iter = getVal(kwargs, "iter", None) 64 | 65 | checkpointData = getLastCheckPoint( 66 | checkPointDir, name, scale=scale, iter=iter) 67 | 68 | if checkpointData is None: 69 | if scale is not None or iter is not None: 70 | raise FileNotFoundError("Not checkpoint found for model " + name 71 | + " at directory " + dir + " for scale " + 72 | str(scale) + " at iteration " + str(iter)) 73 | raise FileNotFoundError( 74 | "Not checkpoint found for model " + name + " at directory " + dir) 75 | 76 | modelConfig, pathModel, _ = checkpointData 77 | 78 | if scale is None or iter is None: 79 | _, scale, iter = parse_state_name(pathModel) 80 | 81 | # Feature extraction 82 | 83 | # Look for NN data 84 | 85 | with open(trainingConfig, 'rb') as file: 86 | wholeConfig = json.load(file) 87 | 88 | pathDB = wholeConfig.get("pathDB", None) 89 | if pathDB is None: 90 | raise ValueError("No training database found") 91 | 92 | partitionValue = wholeConfig.get("partitionValue", None) 93 | partitionValue = getVal(kwargs, "partition_value", None) 94 | 95 | pathOutFeatures = os.path.splitext(pathNNFeatureExtractor)[0] + "_" + \ 96 | os.path.splitext(os.path.basename(pathDB))[0] + "_" + \ 97 | str(kwargs['size']) + \ 98 | toStrKey(partitionValue) + "_features.pkl" 99 | 100 | if not os.path.isfile(pathNNFeatureExtractor) \ 101 | or not os.path.isfile(pathOutFeatures): 102 | raise FileNotFoundError("No model found at " + pathOutFeatures) 103 | 104 | print("Loading model " + pathModel) 105 | modelPackage, modelName = getNameAndPackage(module) 106 | modelType = loadmodule(modelPackage, modelName) 107 | visualizer = GANVisualizer( 108 | pathModel, modelConfig, modelType, 
visualisation) 109 | 110 | print("NN model found ! " + pathNNFeatureExtractor) 111 | featureExtractor, mean, std = buildFeatureExtractor(pathNNFeatureExtractor) 112 | 113 | imgTransform = nn.DataParallel(FeatureTransform( 114 | mean, std, kwargs['size'])).to(torch.device("cuda:0")) 115 | featureExtractor = nn.DataParallel( 116 | featureExtractor).to(torch.device("cuda:0")) 117 | 118 | with open(pathOutFeatures, 'rb') as file: 119 | nnSearch, names = pickle.load(file) 120 | 121 | if kwargs['showNN']: 122 | print("Retriving 10 neighbors for visualization") 123 | visualizer.visualizeNN(10, 5, featureExtractor, 124 | imgTransform, nnSearch, names, pathDB) 125 | print("Ready, please check out visdom main environement") 126 | else: 127 | 128 | outMetric = visualizer.exportNN( 129 | 1600, 8, featureExtractor, imgTransform, nnSearch) 130 | outPath = modelConfig[:-18] + "_nn_metric.json" 131 | 132 | saveScore(outPath, list(outMetric), scale, 133 | iter, pathNNFeatureExtractor) 134 | -------------------------------------------------------------------------------- /models/eval/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import json 4 | import sys 5 | 6 | import torch 7 | 8 | from ..gan_visualizer import GANVisualizer 9 | from ..utils.utils import loadmodule, getLastCheckPoint, getVal, \ 10 | getNameAndPackage, parse_state_name 11 | 12 | 13 | def getModelName(pathConfig): 14 | 15 | pathConfig = os.path.basename(pathConfig) 16 | 17 | if pathConfig[-18:] != '_train_config.json': 18 | raise ValueError("Invalid configuration name") 19 | 20 | return pathConfig[:-18] 21 | 22 | 23 | def updateParserWithLabels(parser, labels): 24 | 25 | for key in labels: 26 | parser.add_argument('--' + key, type=str, 27 | help=str(labels[key]["values"])) 28 | return parser 29 | 30 | 31 | def test(parser, visualisation=None): 32 | 33 | # Parameters 34 | parser.add_argument('--showLabels', action='store_true', 35 | help="For labelled datasets, show available labels") 36 | parser.add_argument('--interpolate', type=str, 37 | dest='interpolationPath', 38 | help="Path to some latent vectors to interpolate") 39 | parser.add_argument('--random_interpolate', action='store_true', 40 | help="Save a random interpolation") 41 | parser.add_argument('--save_dataset', type=str, dest="output_dataset", 42 | help="Save a dataset at the given location") 43 | parser.add_argument('--size_dataset', type=int, dest="size_dataset", 44 | default=10000, 45 | help="Size of the dataset to be saved") 46 | 47 | kwargs = vars(parser.parse_known_args()[0]) 48 | 49 | name = getVal(kwargs, "name", None) 50 | if name is None: 51 | parser.print_help() 52 | raise ValueError("You need to input a name") 53 | 54 | module = getVal(kwargs, "module", None) 55 | if module is None: 56 | parser.print_help() 57 | raise ValueError("You need to input a module") 58 | 59 | scale = getVal(kwargs, "scale", None) 60 | iter = getVal(kwargs, "iter", None) 61 | 62 | checkPointDir = os.path.join(kwargs["dir"], name) 63 | checkpointData = getLastCheckPoint(checkPointDir, 64 | name, 65 | scale=scale, 66 | iter=iter) 67 | 68 | if checkpointData is None: 69 | raise FileNotFoundError( 70 | "Not checkpoint found for model " + name + " at directory " + dir) 71 | 72 | modelConfig, pathModel, _ = checkpointData 73 | if scale is None: 74 | _, scale, _ = parse_state_name(pathModel) 75 | 76 | keysLabels = None 77 | with open(modelConfig, 'rb') as 
file: 78 | keysLabels = json.load(file)["attribKeysOrder"] 79 | if keysLabels is None: 80 | keysLabels = {} 81 | 82 | parser = updateParserWithLabels(parser, keysLabels) 83 | 84 | kwargs = vars(parser.parse_args()) 85 | 86 | if kwargs['showLabels']: 87 | parser.print_help() 88 | sys.exit() 89 | 90 | interpolationPath = getVal(kwargs, 'interpolationPath', None) 91 | 92 | pathLoss = os.path.join(checkPointDir, name + "_losses.pkl") 93 | pathOut = os.path.splitext(pathModel)[0] + "_fullavg.jpg" 94 | 95 | packageStr, modelTypeStr = getNameAndPackage(module) 96 | modelType = loadmodule(packageStr, modelTypeStr) 97 | exportMask = module in ["PPGAN"] 98 | 99 | visualizer = GANVisualizer( 100 | pathModel, modelConfig, modelType, visualisation) 101 | 102 | if interpolationPath is None and not kwargs['random_interpolate']: 103 | nImages = (256 // 2**(max(scale - 2, 3))) * 8 104 | visualizer.exportVisualization(pathOut, nImages, 105 | export_mask=exportMask) 106 | 107 | toPlot = {} 108 | for key in keysLabels: 109 | if kwargs.get(key, None) is not None: 110 | toPlot[key] = kwargs[key] 111 | 112 | if len(toPlot) > 0: 113 | visualizer.generateImagesFomConstraints( 114 | 16, toPlot, env=name + "_pictures") 115 | 116 | interpolationVectors = None 117 | if interpolationPath is not None: 118 | interpolationVectors = torch.load(interpolationPath) 119 | pathOut = os.path.splitext(interpolationPath)[0] + "_interpolations" 120 | elif kwargs['random_interpolate']: 121 | interpolationVectors, _ = visualizer.model.buildNoiseData(3) 122 | pathOut = os.path.splitext(pathModel)[0] + "_interpolations" 123 | 124 | if interpolationVectors is not None: 125 | 126 | if not os.path.isdir(pathOut): 127 | os.mkdir(pathOut) 128 | 129 | nImgs = interpolationVectors.size(0) 130 | for img in range(nImgs): 131 | 132 | indexNext = (img + 1) % nImgs 133 | path = os.path.join(pathOut, str(img) + "_" + str(indexNext)) 134 | 135 | if not os.path.isdir(path): 136 | os.mkdir(path) 137 | 138 | path = os.path.join(path, "") 139 | 140 | visualizer.saveInterpolation( 141 | 100, interpolationVectors[img], 142 | interpolationVectors[indexNext], path) 143 | 144 | outputDatasetPath = getVal(kwargs, "output_dataset", None) 145 | if outputDatasetPath is not None: 146 | print("Exporting a fake dataset at path " + outputDatasetPath) 147 | visualizer.exportDB(outputDatasetPath, kwargs["size_dataset"]) 148 | 149 | visualizer.plotLosses(pathLoss, name) 150 | -------------------------------------------------------------------------------- /models/gan_visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | import math 5 | import pickle as pkl 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as Transforms 9 | import numpy as np 10 | 11 | from .utils.image_transform import NumpyResize, NumpyToTensor 12 | from .datasets.attrib_dataset import pil_loader 13 | from .utils.utils import printProgressBar 14 | from .datasets.hd5 import H5Dataset 15 | 16 | 17 | class GANVisualizer(): 18 | r""" 19 | Several tools to export GAN generations 20 | """ 21 | 22 | def __init__(self, 23 | pathGan, 24 | pathConfig, 25 | ganType, 26 | visualizer): 27 | r""" 28 | Args 29 | pathGan (string): path to the GAN to load 30 | pathConfig (string): path to the GAN configuration 31 | ganType (BaseGANClass): type of GAn to load 32 | visualizer (visualizer class): either visualizer or np_visualizer 33 | """ 34 | 35 | with open(pathConfig, 'rb') as file: 36 | self.config = json.load(file) 37 | 38 | # TODO : update me 39 | self.model = ganType(useGPU=True, 40 | storeAVG=True, 41 | **self.config) 42 | 43 | self.model.load(pathGan) 44 | 45 | self.visualizer = visualizer 46 | self.keyShift = None 47 | 48 | self.buildKeyShift() 49 | 50 | def buildKeyShift(self): 51 | r""" 52 | Inilialize the labels shift for labelled models 53 | """ 54 | 55 | if self.model.config.attribKeysOrder is None: 56 | return 57 | 58 | self.keyShift = {f: {} 59 | for f in self.model.config.attribKeysOrder.keys()} 60 | 61 | for f in self.keyShift: 62 | 63 | order = self.model.config.attribKeysOrder[f]["order"] 64 | 65 | baseShift = sum([len(self.model.config.attribKeysOrder[f]["values"]) 66 | for f in self.model.config.attribKeysOrder 67 | if self.model.config.attribKeysOrder[f]["order"] < order]) 68 | for index, item in enumerate(self.model.config.attribKeysOrder[f]["values"]): 69 | self.keyShift[f][item] = baseShift + index 70 | 71 | def exportVisualization(self, 72 | path, 73 | nVisual=128, 74 | export_mask=False): 75 | r""" 76 | Save an image gathering sevral generations 77 | 78 | Args: 79 | path (string): output path of the image 80 | nVisual (int): number of generation to build 81 | export_mask (bool): for decoupled model, export the mask as well 82 | as the full output 83 | """ 84 | 85 | size = self.model.getSize()[0] 86 | maxBatchSize = max(1, int(256 / math.log(size, 2))) 87 | remaining = nVisual 88 | out = [] 89 | 90 | outTexture, outShape = [], [] 91 | 92 | while remaining > 0: 93 | 94 | currBatch = min(remaining, maxBatchSize) 95 | noiseData, _ = self.model.buildNoiseData(currBatch) 96 | img = self.model.test(noiseData, getAvG=True) 97 | out.append(img) 98 | 99 | if export_mask: 100 | try: 101 | _, shape, texture = self.model.getDetailledOutput( 102 | noiseData) 103 | outTexture.append(texture) 104 | outShape.append(shape) 105 | except AttributeError: 106 | print("WARNING, no mask available for this model") 107 | 108 | remaining -= currBatch 109 | 110 | toSave = torch.cat(out, dim=0) 111 | self.visualizer.saveTensor( 112 | toSave, (toSave.size()[2], toSave.size()[3]), path) 113 | 114 | if len(outTexture) > 0: 115 | toSave = torch.cat(outTexture, dim=0) 116 | pathTexture = os.path.splitext(path)[0] + "_texture.png" 117 | self.visualizer.saveTensor( 118 | toSave, (toSave.size()[2], toSave.size()[3]), pathTexture) 119 | 120 | toSave = torch.cat(outShape, dim=0) 121 | pathShape = os.path.splitext(path)[0] + "_shape.png" 122 | self.visualizer.saveTensor( 123 | toSave, (toSave.size()[2], toSave.size()[3]), pathShape) 124 | 125 | def exportDB(self, path, nItems): 
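        # In short: nItems generations are produced in batches of at most
        # max(1, 256 / log2(image side)) samples and each one is written as
        # "<path>/gen_<index>.jpg", so the output directory can be reused
        # directly as a flat fake-image dataset.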
126 | r""" 127 | Save dataset of fake generations 128 | 129 | Args: 130 | path (string): output path of the dataset 131 | nItems (int): number of generation to build 132 | """ 133 | 134 | size = self.model.getSize() 135 | maxBatchSize = max(1, int(256 / math.log(size[0], 2))) 136 | remaining = nItems 137 | 138 | index = 0 139 | 140 | if not os.path.isdir(path): 141 | os.mkdir(path) 142 | 143 | while remaining > 0: 144 | currBatch = min(remaining, maxBatchSize) 145 | noiseData, _ = self.model.buildNoiseData(currBatch) 146 | img = self.model.test(noiseData, getAvG=True, toCPU=True) 147 | 148 | for i in range(currBatch): 149 | imgPath = os.path.join(path, "gen_" + str(index) + ".jpg") 150 | self.visualizer.saveTensor(img[i].view(1, 3, size[0], size[1]), 151 | size, imgPath) 152 | index += 1 153 | 154 | remaining -= currBatch 155 | 156 | def generateImagesFomConstraints(self, 157 | nImages, 158 | constraints, 159 | env="visual", 160 | path=None): 161 | r""" 162 | Given label constraints, generate a set of images. 163 | 164 | Args: 165 | nImages (int): number of images to generate 166 | constraints (dict): set of constraints in the form of 167 | {attribute:label}. For example 168 | 169 | {"Gender": "Man", 170 | "Color": blue} 171 | env (string): visdom only, visdom environement where the 172 | generations should be exported 173 | path (string): if not None. Path wher the generations should be 174 | saved 175 | """ 176 | 177 | input = self.model.buildNoiseDataWithConstraints(nImages, constraints) 178 | outImg = self.model.test(input, getAvG=True) 179 | 180 | outSize = (outImg.size()[2], outImg.size()[3]) 181 | self.visualizer.publishTensors( 182 | outImg, outSize, 183 | caption="test", 184 | env=env) 185 | 186 | if path is not None: 187 | self.visualizer.saveTensor(outImg, outSize, path) 188 | 189 | def plotLosses(self, pathLoss, name="Data", clear=True): 190 | r""" 191 | Plot some losses in visdom 192 | 193 | Args: 194 | 195 | pathLoss (string): path to the pickle file where the loss are 196 | stored 197 | name (string): model name 198 | clear (bool): if True clear the visdom environement before plotting 199 | """ 200 | 201 | with open(pathLoss, 'rb') as file: 202 | lossData = pkl.load(file) 203 | 204 | nScales = len(lossData) 205 | 206 | for scale in range(nScales): 207 | 208 | locName = name + ("_s%d" % scale) 209 | 210 | if clear: 211 | self.visualizer.delete_env(locName) 212 | 213 | self.visualizer.publishLoss(lossData[scale], 214 | locName, 215 | env=locName) 216 | 217 | def saveInterpolation(self, N, vectorStart, vectorEnd, pathOut): 218 | r""" 219 | Given two latent vactors, export the interpolated generations between 220 | them. 
221 | 222 | Args: 223 | 224 | N (int): number of interpolation to make 225 | vectorStart (torch.tensor): start latent vector 226 | vectorEnd (torch.tensor): end latent vector 227 | pathOut (string): path where the images sould be saved 228 | """ 229 | 230 | sizeStep = 1.0 / (N - 1) 231 | pathOut = os.path.splitext(pathOut)[0] 232 | 233 | vectorStart = vectorStart.view(1, -1, 1, 1) 234 | vectorEnd = vectorEnd.view(1, -1, 1, 1) 235 | 236 | nZeros = int(math.log10(N) + 1) 237 | 238 | for i in range(N): 239 | path = pathOut + str(i).zfill(nZeros) + ".jpg" 240 | t = i * sizeStep 241 | vector = (1 - t) * vectorStart + t * vectorEnd 242 | 243 | outImg = self.model.test(vector, getAvG=True, toCPU=True) 244 | self.visualizer.saveTensor( 245 | outImg, (outImg.size(2), outImg.size(3)), path) 246 | 247 | def visualizeNN(self, 248 | N, 249 | k, 250 | featureExtractor, 251 | imgTransform, 252 | nnSearch, 253 | names, 254 | pathDB): 255 | r""" 256 | Visualize the nearest neighbors of some random generations 257 | 258 | Args: 259 | 260 | N (int): number of generation to make 261 | k (int): number of neighbors to fetch 262 | featureExtractor (nn.Module): feature extractor 263 | imgTransform (nn.Module): image transform module 264 | nnSearch (np.KDTree): serach tree for the features 265 | names (list): a match between an image index and its name 266 | """ 267 | 268 | batchSize = 16 269 | nImages = 0 270 | 271 | vectorOut = [] 272 | 273 | size = self.model.getSize()[0] 274 | 275 | transform = Transforms.Compose([NumpyResize((size, size)), 276 | NumpyToTensor(), 277 | Transforms.Normalize((0.5, 0.5, 0.5), 278 | (0.5, 0.5, 0.5))]) 279 | 280 | dataset = None 281 | 282 | if os.path.splitext(pathDB)[1] == ".h5": 283 | dataset = H5Dataset(pathDB, 284 | transform=Transforms.Compose( 285 | [NumpyToTensor(), 286 | Transforms.Normalize((0.5, 0.5, 0.5), 287 | (0.5, 0.5, 0.5))])) 288 | 289 | while nImages < N: 290 | 291 | noiseData, _ = self.model.buildNoiseData(batchSize) 292 | imgOut = self.model.test( 293 | noiseData, getAvG=True, toCPU=False).detach() 294 | 295 | features = featureExtractor(imgTransform(imgOut)).detach().view( 296 | imgOut.size(0), -1).cpu().numpy() 297 | distances, indexes = nnSearch.query(features, k) 298 | nImages += batchSize 299 | 300 | for p in range(N): 301 | 302 | vectorOut.append(imgOut[p].cpu().view( 303 | 1, imgOut.size(1), imgOut.size(2), imgOut.size(3))) 304 | for ki in range(k): 305 | 306 | i = indexes[p][ki] 307 | if dataset is None: 308 | path = os.path.join(pathDB, names[i]) 309 | imgSource = transform(pil_loader(path)) 310 | imgSource = imgSource.view(1, imgSource.size( 311 | 0), imgSource.size(1), imgSource.size(2)) 312 | 313 | else: 314 | imgSource, _ = dataset[names[i]] 315 | imgSource = imgSource.view(1, imgSource.size( 316 | 0), imgSource.size(1), imgSource.size(2)) 317 | imgSource = F.upsample( 318 | imgSource, size=(size, size), mode='bilinear') 319 | 320 | vectorOut.append(imgSource) 321 | 322 | vectorOut = torch.cat(vectorOut, dim=0) 323 | self.visualizer.publishTensors(vectorOut, (224, 224), nrow=k + 1) 324 | 325 | def exportNN(self, N, k, featureExtractor, imgTransform, nnSearch): 326 | r""" 327 | Compute the nearest neighbors metric 328 | 329 | Args: 330 | 331 | N (int): number of generation to sample 332 | k (int): number of nearest neighbors to fetch 333 | featureExtractor (nn.Module): feature extractor 334 | imgTransform (nn.Module): image transform module 335 | nnSearch (np.KDTree): serach tree for the features 336 | """ 337 | 338 | batchSize = 16 339 | nImages = 0 
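        # Accumulation scheme used below: for every batch of generations,
        # features = featureExtractor(imgTransform(images)) are computed, the
        # k nearest training-set features are retrieved from the precomputed
        # KD-tree (nnSearch.query returns their distances), and those
        # distances are summed per neighbour rank into vectorOut. Dividing by
        # the number of generated samples at the end gives a length-k vector
        # of average distances; roughly, smaller values mean the generations
        # lie closer to the training set in feature space.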
340 | 341 | vectorOut = np.zeros(k) 342 | 343 | print("Computing the NN metric") 344 | while nImages < N: 345 | 346 | printProgressBar(nImages, N) 347 | 348 | noiseData, _ = self.model.buildNoiseData(batchSize) 349 | imgOut = self.model.test( 350 | noiseData, getAvG=True, toCPU=False).detach() 351 | 352 | features = featureExtractor(imgTransform(imgOut)).detach().view( 353 | imgOut.size(0), -1).cpu().numpy() 354 | distances = nnSearch.query(features, k)[0] 355 | vectorOut += distances.sum(axis=0) 356 | nImages += batchSize 357 | 358 | printProgressBar(N, N) 359 | vectorOut /= nImages 360 | return vectorOut 361 | -------------------------------------------------------------------------------- /models/loss_criterions/GDPP_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def GDPPLoss(phiFake, phiReal, backward=True): 7 | r""" 8 | Implementation of the GDPP loss. Can be used with any kind of GAN 9 | architecture. 10 | 11 | Args: 12 | 13 | phiFake (tensor) : last feature layer of the discriminator on real data 14 | phiReal (tensor) : last feature layer of the discriminator on fake data 15 | backward (bool) : should we perform the backward operation ? 16 | 17 | Returns: 18 | 19 | Loss's value. The backward operation in performed within this operator 20 | """ 21 | def compute_diversity(phi): 22 | phi = F.normalize(phi, p=2, dim=1) 23 | SB = torch.mm(phi, phi.t()) 24 | eigVals, eigVecs = torch.symeig(SB, eigenvectors=True) 25 | return eigVals, eigVecs 26 | 27 | def normalize_min_max(eigVals): 28 | minV, maxV = torch.min(eigVals), torch.max(eigVals) 29 | if abs(minV - maxV) < 1e-10: 30 | return eigVals 31 | return (eigVals - minV) / (maxV - minV) 32 | 33 | fakeEigVals, fakeEigVecs = compute_diversity(phiFake) 34 | realEigVals, realEigVecs = compute_diversity(phiReal) 35 | 36 | # Scaling factor to make the two losses operating in comparable ranges. 37 | magnitudeLoss = 0.0001 * F.mse_loss(target=realEigVals, input=fakeEigVals) 38 | structureLoss = -torch.sum(torch.mul(fakeEigVecs, realEigVecs), 0) 39 | normalizedRealEigVals = normalize_min_max(realEigVals) 40 | weightedStructureLoss = torch.sum( 41 | torch.mul(normalizedRealEigVals, structureLoss)) 42 | gdppLoss = magnitudeLoss + weightedStructureLoss 43 | 44 | if backward: 45 | gdppLoss.backward(retain_graph=True) 46 | 47 | return gdppLoss.item() 48 | -------------------------------------------------------------------------------- /models/loss_criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/loss_criterions/ac_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from copy import deepcopy 3 | from random import randint 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | class ACGANCriterion: 10 | r""" 11 | Class implementing all tools necessary for a GAN to take into account class 12 | conditionning while generating a model (cf Odena's AC-GAN) 13 | https://arxiv.org/pdf/1610.09585.pdf 14 | """ 15 | 16 | def __init__(self, 17 | attribKeysOrder): 18 | r""" 19 | Args: 20 | 21 | attribKeysOrder (dict): dictionary containing the all the possible 22 | categories of the DCGAN model and their 23 | order. 24 | 25 | Each entry of the attribKeysOrder is 26 | another dictionary with two fields: 27 | 28 | order: order of the category in the input 29 | vector 30 | values: possible values taken by this 31 | category 32 | 33 | Such a dictionary is returned by 34 | models.datasets.attrib_dataset.AttribDataset.getKeyOrders() 35 | Ex: 36 | attribKeysOrder = {"Gender": {"order": 0, "values":["M", "W"]}, 37 | "Nationality": {"order": 1, 38 | "values":["english", 39 | "french", 40 | "indian"]} 41 | } 42 | allowMultiple = ["Nationality"] 43 | 44 | Then a category vector corresponding to this pair could be: 45 | V = [0, 1, 1, 1, 0] 46 | 47 | Which would correspond to a sample of gender "W" and 48 | nationalities "english" and "french" 49 | """ 50 | self.nAttrib = len(attribKeysOrder) 51 | self.attribSize = [0 for i in range(self.nAttrib)] 52 | self.keyOrder = ['' for x in range(self.nAttrib)] 53 | self.labelsOrder = {} 54 | 55 | self.inputDict = deepcopy(attribKeysOrder) 56 | 57 | for key in attribKeysOrder: 58 | order = attribKeysOrder[key]["order"] 59 | self.keyOrder[order] = key 60 | self.attribSize[order] = len(attribKeysOrder[key]["values"]) 61 | self.labelsOrder[key] = {index: label for label, index in 62 | enumerate(attribKeysOrder[key]["values"])} 63 | 64 | self.labelWeights = torch.tensor( 65 | [1.0 for x in range(self.getInputDim())]) 66 | 67 | for key in attribKeysOrder: 68 | order = attribKeysOrder[key]["order"] 69 | if attribKeysOrder[key].get('weights', None) is not None: 70 | shift = sum(self.attribSize[:order]) 71 | 72 | for value, weight in attribKeysOrder[key]['weights'].items(): 73 | self.labelWeights[shift + 74 | self.labelsOrder[key][value]] = weight 75 | 76 | self.sizeOutput = self.nAttrib 77 | 78 | def generateConstraintsFromVector(self, n, labels): 79 | 80 | vect = [] 81 | 82 | for i in range(self.nAttrib): 83 | C = self.attribSize[i] 84 | key = self.keyOrder[i] 85 | 86 | if key in labels: 87 | value = labels[key] 88 | index = self.labelsOrder[key][value] 89 | out = torch.zeros(n, C, 1, 1) 90 | out[:, index] = 1 91 | else: 92 | v = np.random.randint(0, C, n) 93 | w = np.zeros((n, C), dtype='float32') 94 | w[np.arange(n), v] = 1 95 | out = torch.tensor(w).view(n, C, 1, 1) 96 | 97 | vect.append(out) 98 | return torch.cat(vect, dim=1) 99 | 100 | def buildRandomCriterionTensor(self, sizeBatch): 101 | r""" 102 | Build a batch of vectors with a random combination of the values of the 103 | existing classes 104 | 105 | Args: 106 | sizeBatch (int): number of vectors to generate 107 | 108 | Return: 109 | targetVector, latentVector 110 | 111 | targetVector : [sizeBatch, M] tensor used as a reference for the 112 | loss computation (see self.getLoss) 113 | latentVector : [sizeBatch, M', 1, 1] tensor. 
Should be 114 | concatenatenated with the random GAN input latent 115 | veCtor 116 | 117 | M' > M, input latent data should be coded with one-hot inputs while 118 | pytorch requires a different format for softmax loss 119 | (see self.getLoss) 120 | """ 121 | targetOut = [] 122 | inputLatent = [] 123 | 124 | for i in range(self.nAttrib): 125 | C = self.attribSize[i] 126 | v = np.random.randint(0, C, sizeBatch) 127 | w = np.zeros((sizeBatch, C), dtype='float32') 128 | w[np.arange(sizeBatch), v] = 1 129 | y = torch.tensor(w).view(sizeBatch, C) 130 | 131 | inputLatent.append(y) 132 | targetOut.append(torch.tensor(v).float().view(sizeBatch, 1)) 133 | 134 | return torch.cat(targetOut, dim=1), torch.cat(inputLatent, dim=1) 135 | 136 | def buildLatentCriterion(self, targetCat): 137 | 138 | batchSize = targetCat.size(0) 139 | idx = torch.arange(batchSize, device=targetCat.device) 140 | targetOut = torch.zeros((batchSize, sum(self.attribSize))) 141 | shift = 0 142 | 143 | for i in range(self.nAttrib): 144 | targetOut[idx, shift + targetCat[:, i]] = 1 145 | shift += self.attribSize[i] 146 | 147 | return targetOut 148 | 149 | def getInputDim(self): 150 | r""" 151 | Size of the latent vector given by self.buildRandomCriterionTensor 152 | """ 153 | return sum(self.attribSize) 154 | 155 | def getCriterion(self, outputD, target): 156 | r""" 157 | Compute the conditional loss between the network's output and the 158 | target. This loss, L, is the sum of the losses Lc of the categories 159 | defined in the criterion. We have: 160 | 161 | | Cross entropy loss for the class c if c is attached to a 162 | classification task. 163 | Lc = | Multi label soft margin loss for the class c if c is 164 | attached to a tagging task 165 | """ 166 | loss = 0 167 | shiftInput = 0 168 | shiftTarget = 0 169 | 170 | self.labelWeights = self.labelWeights.to(outputD.device) 171 | 172 | for i in range(self.nAttrib): 173 | C = self.attribSize[i] 174 | locInput = outputD[:, shiftInput:(shiftInput+C)] 175 | locTarget = target[:, shiftTarget].long() 176 | locLoss = F.cross_entropy(locInput, locTarget 177 | , weight=self.labelWeights[shiftInput:(shiftInput+C)]) 178 | shiftTarget += 1 179 | loss += locLoss 180 | shiftInput += C 181 | 182 | return loss 183 | -------------------------------------------------------------------------------- /models/loss_criterions/base_loss_criterions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class BaseLossWrapper: 7 | r""" 8 | Loss criterion class. Must define 4 members: 9 | sizeDecisionLayer : size of the decision layer of the discrimator 10 | 11 | getCriterion : how the loss is actually computed 12 | 13 | !! The activation function of the discriminator is computed within the 14 | loss !! 15 | """ 16 | 17 | def __init__(self, device): 18 | self.device = device 19 | 20 | def getCriterion(self, input, status): 21 | r""" 22 | Given an input tensor and its targeted status (detected as real or 23 | detected as fake) build the associated loss 24 | 25 | Args: 26 | 27 | - input (Tensor): decision tensor build by the model's discrimator 28 | - status (bool): if True -> this tensor should have been detected 29 | as a real input 30 | else -> it shouldn't have 31 | """ 32 | pass 33 | 34 | 35 | class MSE(BaseLossWrapper): 36 | r""" 37 | Mean Square error loss. 
38 | """ 39 | 40 | def __init__(self, device): 41 | self.generationActivation = F.tanh 42 | self.sizeDecisionLayer = 1 43 | 44 | BaseLossWrapper.__init__(self, device) 45 | 46 | def getCriterion(self, input, status): 47 | size = input.size()[0] 48 | value = float(status) 49 | reference = torch.tensor([value]).expand(size, 1).to(self.device) 50 | return F.mse_loss(F.sigmoid(input[:, :self.sizeDecisionLayer]), 51 | reference) 52 | 53 | 54 | class WGANGP(BaseLossWrapper): 55 | r""" 56 | Paper WGANGP loss : linear activation for the generator. 57 | https://arxiv.org/pdf/1704.00028.pdf 58 | """ 59 | 60 | def __init__(self, device): 61 | 62 | self.generationActivation = None 63 | self.sizeDecisionLayer = 1 64 | 65 | BaseLossWrapper.__init__(self, device) 66 | 67 | def getCriterion(self, input, status): 68 | if status: 69 | return -input[:, 0].sum() 70 | return input[:, 0].sum() 71 | 72 | 73 | class Logistic(BaseLossWrapper): 74 | r""" 75 | "Which training method of GANs actually converge" 76 | https://arxiv.org/pdf/1801.04406.pdf 77 | """ 78 | 79 | def __init__(self, device): 80 | 81 | self.generationActivation = None 82 | self.sizeDecisionLayer = 1 83 | BaseLossWrapper.__init__(self, device) 84 | 85 | def getCriterion(self, input, status): 86 | if status: 87 | return F.softplus(-input[:, 0]).mean() 88 | return F.softplus(input[:, 0]).mean() 89 | 90 | 91 | class DCGAN(BaseLossWrapper): 92 | r""" 93 | Cross entropy loss. 94 | """ 95 | 96 | def __init__(self, device): 97 | 98 | self.generationActivation = F.tanh 99 | self.sizeDecisionLayer = 1 100 | 101 | BaseLossWrapper.__init__(self, device) 102 | 103 | def getCriterion(self, input, status): 104 | size = input.size()[0] 105 | value = int(status) 106 | reference = torch.tensor( 107 | [value], dtype=torch.float).expand(size).to(self.device) 108 | return F.binary_cross_entropy(torch.sigmoid(input[:, :self.sizeDecisionLayer]), reference) 109 | -------------------------------------------------------------------------------- /models/loss_criterions/gradient_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | 5 | def WGANGPGradientPenalty(input, fake, discriminator, weight, backward=True): 6 | r""" 7 | Gradient penalty as described in 8 | "Improved Training of Wasserstein GANs" 9 | https://arxiv.org/pdf/1704.00028.pdf 10 | 11 | Args: 12 | 13 | - input (Tensor): batch of real data 14 | - fake (Tensor): batch of generated data. 
Must have the same size 15 | as the input 16 | - discrimator (nn.Module): discriminator network 17 | - weight (float): weight to apply to the penalty term 18 | - backward (bool): loss backpropagation 19 | """ 20 | 21 | batchSize = input.size(0) 22 | alpha = torch.rand(batchSize, 1) 23 | alpha = alpha.expand(batchSize, int(input.nelement() / 24 | batchSize)).contiguous().view( 25 | input.size()) 26 | alpha = alpha.to(input.device) 27 | interpolates = alpha * input + ((1 - alpha) * fake) 28 | 29 | interpolates = torch.autograd.Variable( 30 | interpolates, requires_grad=True) 31 | 32 | decisionInterpolate = discriminator(interpolates, False) 33 | decisionInterpolate = decisionInterpolate[:, 0].sum() 34 | 35 | gradients = torch.autograd.grad(outputs=decisionInterpolate, 36 | inputs=interpolates, 37 | create_graph=True, retain_graph=True) 38 | 39 | gradients = gradients[0].view(batchSize, -1) 40 | gradients = (gradients * gradients).sum(dim=1).sqrt() 41 | gradient_penalty = (((gradients - 1.0)**2)).sum() * weight 42 | 43 | if backward: 44 | gradient_penalty.backward(retain_graph=True) 45 | 46 | return gradient_penalty.item() 47 | 48 | 49 | def logisticGradientPenalty(input, discrimator, weight, backward=True): 50 | r""" 51 | Gradient penalty described in "Which training method of GANs actually 52 | converge 53 | https://arxiv.org/pdf/1801.04406.pdf 54 | 55 | Args: 56 | 57 | - input (Tensor): batch of real data 58 | - discrimator (nn.Module): discriminator network 59 | - weight (float): weight to apply to the penalty term 60 | - backward (bool): loss backpropagation 61 | """ 62 | 63 | locInput = torch.autograd.Variable( 64 | input, requires_grad=True) 65 | gradients = torch.autograd.grad(outputs=discrimator(locInput)[:, 0].sum(), 66 | inputs=locInput, 67 | create_graph=True, retain_graph=True)[0] 68 | 69 | gradients = gradients.view(gradients.size(0), -1) 70 | gradients = (gradients * gradients).sum(dim=1).mean() 71 | 72 | gradient_penalty = gradients * weight 73 | if backward: 74 | gradient_penalty.backward(retain_graph=True) 75 | 76 | return gradient_penalty.item() 77 | -------------------------------------------------------------------------------- /models/loss_criterions/logistic_loss.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytorch_GAN_zoo/b75dee40918caabb4fe7ec561522717bf096a8cb/models/loss_criterions/logistic_loss.py -------------------------------------------------------------------------------- /models/loss_criterions/loss_texture.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | 6 | from ..utils.utils import loadmodule 7 | from ..networks.constant_net import FeatureTransform 8 | 9 | 10 | def extractRelUIndexes(sequence, layers): 11 | 12 | layers.sort() 13 | 14 | index = 0 15 | output = [] 16 | 17 | indexRef = 0 18 | indexScale = 1 19 | 20 | hasCaughtRelUOnLayer = False 21 | while indexRef < len(layers) and index < len(sequence): 22 | 23 | if isinstance(sequence[index], torch.nn.ReLU): 24 | 25 | if not hasCaughtRelUOnLayer and indexScale == layers[indexRef]: 26 | 27 | hasCaughtRelUOnLayer = True 28 | output.append(index) 29 | indexRef += 1 30 | 31 | if isinstance(sequence[index], torch.nn.MaxPool2d) \ 32 | or isinstance(sequence[index], torch.nn.AvgPool2d): 33 | 34 | hasCaughtRelUOnLayer = False 35 | indexScale += 1 36 | 37 | index += 1 38 | 39 | return output 40 | 41 | 42 | def extractIndexedLayers(sequence, 43 | x, 44 | indexes, 45 | detach): 46 | 47 | index = 0 48 | output = [] 49 | 50 | indexes.sort() 51 | 52 | for iSeq, layer in enumerate(sequence): 53 | 54 | if index >= len(indexes): 55 | break 56 | 57 | x = layer(x) 58 | 59 | if iSeq == indexes[index]: 60 | if detach: 61 | output.append(x.view(x.size(0), x.size(1), -1).detach()) 62 | else: 63 | output.append(x.view(x.size(0), x.size(1), -1)) 64 | index += 1 65 | 66 | return output 67 | 68 | 69 | class LossTexture(torch.nn.Module): 70 | r""" 71 | An implenetation of style transfer's (http://arxiv.org/abs/1703.06868) like 72 | loss. 73 | """ 74 | 75 | def __init__(self, 76 | device, 77 | modelName, 78 | scalesOut): 79 | r""" 80 | Args: 81 | - device (torch.device): torch.device("cpu") or 82 | torch.device("cuda:0") 83 | - modelName (string): name of the torchvision.models model. For 84 | example vgg19 85 | - scalesOut (list): index of the scales to extract. In the Style 86 | transfer paper it was [1,2,3,4] 87 | """ 88 | 89 | super(LossTexture, self).__init__() 90 | scalesOut.sort() 91 | 92 | model = loadmodule("torchvision.models", modelName, prefix='') 93 | self.featuresSeq = model(pretrained=True).features.to(device) 94 | self.indexLayers = extractRelUIndexes(self.featuresSeq, scalesOut) 95 | 96 | self.reductionFactor = [1 / float(2**(i - 1)) for i in scalesOut] 97 | 98 | refMean = [2*p - 1 for p in[0.485, 0.456, 0.406]] 99 | refSTD = [2*p for p in [0.229, 0.224, 0.225]] 100 | 101 | self.imgTransform = FeatureTransform(mean=refMean, 102 | std=refSTD, 103 | size=None) 104 | 105 | self.imgTransform = self.imgTransform.to(device) 106 | 107 | def getLoss(self, fake, reals, mask=None): 108 | 109 | featuresReals = self.getFeatures( 110 | reals, detach=True, prepImg=True, mask=mask).mean(dim=0) 111 | featuresFakes = self.getFeatures( 112 | fake, detach=False, prepImg=True, mask=None).mean(dim=0) 113 | 114 | outLoss = ((featuresReals - featuresFakes)**2).mean() 115 | return outLoss 116 | 117 | def getFeatures(self, image, detach=True, prepImg=True, mask=None): 118 | 119 | if prepImg: 120 | image = self.imgTransform(image) 121 | 122 | fullSequence = extractIndexedLayers(self.featuresSeq, 123 | image, 124 | self.indexLayers, 125 | detach) 126 | outFeatures = [] 127 | nFeatures = len(fullSequence) 128 | 129 | for i in range(nFeatures): 130 | 131 | if mask is not None: 132 | locMask = (1. 
+ F.upsample(mask, 133 | size=(image.size(2) * self.reductionFactor[i], 134 | image.size(3) * self.reductionFactor[i]), 135 | mode='bilinear')) * 0.5 136 | locMask = locMask.view(locMask.size(0), locMask.size(1), -1) 137 | 138 | totVal = locMask.sum(dim=2) 139 | 140 | meanReals = (fullSequence[i] * locMask).sum(dim=2) / totVal 141 | varReals = ( 142 | (fullSequence[i]*fullSequence[i] * locMask).sum(dim=2) / totVal) - meanReals*meanReals 143 | 144 | else: 145 | meanReals = fullSequence[i].mean(dim=2) 146 | varReals = ( 147 | (fullSequence[i]*fullSequence[i]).mean(dim=2))\ 148 | - meanReals*meanReals 149 | 150 | outFeatures.append(meanReals) 151 | outFeatures.append(varReals) 152 | 153 | return torch.cat(outFeatures, dim=1) 154 | 155 | def forward(self, x, mask=None): 156 | 157 | return self.getFeatures(x, detach=False, prepImg=False, mask=mask) 158 | 159 | def saveModel(self, pathOut): 160 | 161 | torch.save(dict(model=self, fullDump=True, 162 | mean=self.imgTransform.mean.view(-1).tolist(), 163 | std=self.imgTransform.std.view(-1).tolist()), 164 | pathOut) 165 | -------------------------------------------------------------------------------- /models/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class InceptionScore(): 8 | def __init__(self, classifier): 9 | 10 | self.sumEntropy = 0 11 | self.sumSoftMax = None 12 | self.nItems = 0 13 | self.classifier = classifier.eval() 14 | 15 | def updateWithMiniBatch(self, ref): 16 | y = self.classifier(ref).detach() 17 | 18 | if self.sumSoftMax is None: 19 | self.sumSoftMax = torch.zeros(y.size()[1]).to(ref.device) 20 | 21 | # Entropy 22 | x = F.softmax(y, dim=1) * F.log_softmax(y, dim=1) 23 | self.sumEntropy += x.sum().item() 24 | 25 | # Sum soft max 26 | self.sumSoftMax += F.softmax(y, dim=1).sum(dim=0) 27 | 28 | # N items 29 | self.nItems += y.size()[0] 30 | 31 | def getScore(self): 32 | 33 | x = self.sumSoftMax 34 | x = x * torch.log(x / self.nItems) 35 | output = self.sumEntropy - (x.sum().item()) 36 | output /= self.nItems 37 | return math.exp(output) 38 | -------------------------------------------------------------------------------- /models/metrics/laplacian_swd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | 5 | import numpy as np 6 | from ..utils.utils import printProgressBar 7 | 8 | 9 | def getDescriptorsForMinibatch(minibatch, patchSize, nPatches): 10 | r""" 11 | Extract @param nPatches randomly chosen of size patchSize x patchSize 12 | from each image of the input @param minibatch 13 | 14 | Returns: 15 | 16 | A tensor of SxCxpatchSizexpatchSize where 17 | S = minibatch.size()[0] * nPatches is the total number of patches 18 | extracted from the minibatch. 
19 | """ 20 | S = minibatch.size() 21 | 22 | maxX = S[2] - patchSize 23 | maxY = S[3] - patchSize 24 | 25 | baseX = torch.arange(0, patchSize, dtype=torch.long).expand(S[0] * nPatches, 26 | patchSize) \ 27 | + torch.randint(0, maxX, (S[0] * nPatches, 1), dtype=torch.long) 28 | baseY = torch.arange(0, patchSize, dtype=torch.long).expand(S[0] * nPatches, 29 | patchSize) \ 30 | + torch.randint(0, maxY, (S[0] * nPatches, 1), dtype=torch.long) 31 | 32 | baseX = baseX.view(S[0], nPatches, 1, patchSize).expand( 33 | S[0], nPatches, patchSize, patchSize) 34 | baseY = S[2] * baseY.view(S[0], nPatches, patchSize, 1) 35 | baseY = baseY.expand(S[0], nPatches, patchSize, patchSize) 36 | 37 | coords = baseX + baseY 38 | coords = coords.view(S[0], nPatches, 1, patchSize, patchSize).expand( 39 | S[0], nPatches, S[1], patchSize, patchSize) 40 | C = torch.arange(0, S[1], dtype=torch.long).view( 41 | 1, S[1]).expand(nPatches * S[0], S[1])*S[2]*S[3] 42 | coords = C.view(S[0], nPatches, S[1], 1, 1) + coords 43 | coords = coords.view(-1) 44 | 45 | return (minibatch.contiguous().view(-1)[coords]).view(-1, S[1], 46 | patchSize, patchSize) 47 | 48 | 49 | def getMeanStdDesc(desc): 50 | r""" 51 | Get the mean and the standard deviation of each channel accross the input 52 | batch. 53 | """ 54 | S = desc.size() 55 | assert len(S) == 4 56 | mean = torch.sum(desc.view(S[0], S[1], -1), 57 | dim=2).sum(dim=0) / (S[0] * S[3] * S[2]) 58 | var = torch.sum( 59 | (desc*desc).view(S[0], S[1], -1), dim=2).sum(dim=0) / \ 60 | (S[0] * S[3] * S[2]) 61 | var -= mean*mean 62 | var = var.clamp(min=0).sqrt().view(1, S[1]).expand(S[0], S[1]) 63 | mean = (mean.view(1, S[1])).expand(S[0], S[1]) 64 | 65 | return mean.view(S[0], S[1], 1, 1), var.view(S[0], S[1], 1, 1) 66 | 67 | 68 | # ------------------------------------------------------------------------------- 69 | # Laplacian pyramid generation, with LaplacianSWDMetric.convolution as input, 70 | # matches the corresponding openCV functions 71 | # ------------------------------------------------------------------------------- 72 | 73 | def pyrDown(minibatch, convolution): 74 | x = torch.nn.ReflectionPad2d(2)(minibatch) 75 | return convolution(x)[:, :, ::2, ::2].detach() 76 | 77 | 78 | def pyrUp(minibatch, convolution): 79 | S = minibatch.size() 80 | res = torch.zeros((S[0], S[1], S[2] * 2, S[3] * 2), 81 | dtype=minibatch.dtype).to(minibatch.device) 82 | res[:, :, ::2, ::2] = minibatch 83 | res = torch.nn.ReflectionPad2d(2)(res) 84 | return 4 * convolution(res).detach() 85 | 86 | # ---------------------------------------------------------------------------- 87 | 88 | 89 | def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat): 90 | r""" 91 | NVIDIA's approximation of the SWD distance. 
92 | """ 93 | # (neighborhood, descriptor_component) 94 | assert A.ndim == 2 and A.shape == B.shape 95 | results = [] 96 | for repeat in range(dir_repeats): 97 | # (descriptor_component, direction) 98 | dirs = np.random.randn(A.shape[1], dirs_per_repeat) 99 | # normalize descriptor components for each direction 100 | dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True)) 101 | dirs = dirs.astype(np.float32) 102 | # (neighborhood, direction) 103 | projA = np.matmul(A, dirs) 104 | projB = np.matmul(B, dirs) 105 | # sort neighborhood projections for each direction 106 | projA = np.sort(projA, axis=0) 107 | projB = np.sort(projB, axis=0) 108 | # pointwise wasserstein distances 109 | dists = np.abs(projA - projB) 110 | # average over neighborhoods and directions 111 | results.append(np.mean(dists)) 112 | return np.mean(results).item() 113 | 114 | 115 | def sliced_wasserstein_torch(A, B, dir_repeats, dirs_per_repeat): 116 | r""" 117 | NVIDIA's approximation of the SWD distance. 118 | """ 119 | results = [] 120 | for repeat in range(dir_repeats): 121 | # (descriptor_component, direction) 122 | dirs = torch.randn(A.size()[1], dirs_per_repeat, 123 | device=A.device, dtype=torch.float32) 124 | # normalize descriptor components for each direction 125 | dirs /= torch.sqrt(torch.sum(dirs*dirs, 0, keepdim=True)) 126 | # (neighborhood, direction) 127 | projA = torch.matmul(A, dirs) 128 | projB = torch.matmul(B, dirs) 129 | # sort neighborhood projections for each direction 130 | projA = torch.sort(projA, dim=0)[0] 131 | projB = torch.sort(projB, dim=0)[0] 132 | # pointwise wasserstein distances 133 | dists = torch.abs(projA - projB) 134 | # average over neighborhoods and directions 135 | results.append(torch.mean(dists).item()) 136 | return sum(results) / float(len(results)) 137 | 138 | 139 | def finalize_descriptors(desc): 140 | if isinstance(desc, list): 141 | desc = np.concatenate(desc, axis=0) 142 | assert desc.ndim == 4 # (neighborhood, channel, height, width) 143 | desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True) 144 | desc /= np.std(desc, axis=(0, 2, 3), keepdims=True) 145 | desc = desc.reshape(desc.shape[0], -1) 146 | return desc 147 | 148 | 149 | class LaplacianSWDMetric: 150 | r""" 151 | SWD metrics used on patches extracted from laplacian pyramids of the input 152 | images. 153 | """ 154 | 155 | def __init__(self, 156 | patchSize, 157 | nDescriptorLevel, 158 | depthPyramid): 159 | r""" 160 | Args: 161 | patchSize (int): side length of each patch to extract 162 | nDescriptorLevel (int): number of patches to extract at each level 163 | of the pyramid 164 | depthPyramid (int): depth of the laplacian pyramid 165 | """ 166 | self.patchSize = patchSize 167 | self.nDescriptorLevel = nDescriptorLevel 168 | self.depthPyramid = depthPyramid 169 | 170 | self.descriptorsRef = [[] for x in range(depthPyramid)] 171 | self.descriptorsTarget = [[] for x in range(depthPyramid)] 172 | 173 | self.convolution = None 174 | 175 | def updateWithMiniBatch(self, ref, target): 176 | r""" 177 | Extract and store decsriptors from the current minibatch 178 | Args: 179 | ref (tensor): reference data. 180 | target (tensor): target data. 
181 | 182 | Both tensor must have the same format: NxCxWxD 183 | N: minibatch size 184 | C: number of channels 185 | W: with 186 | H: height 187 | """ 188 | target = target.to(ref.device) 189 | modes = [(ref, self.descriptorsRef), (target, self.descriptorsTarget)] 190 | 191 | assert(ref.size() == target.size()) 192 | 193 | if not self.convolution: 194 | self.initConvolution(ref.device) 195 | 196 | for item, dest in modes: 197 | pyramid = self.generateLaplacianPyramid(item, self.depthPyramid) 198 | for scale in range(self.depthPyramid): 199 | dest[scale].append(getDescriptorsForMinibatch(pyramid[scale], 200 | self.patchSize, 201 | self.nDescriptorLevel).cpu().numpy()) 202 | 203 | def getScore(self): 204 | r""" 205 | Output the SWD distance between both distributions using the stored 206 | descriptors. 207 | """ 208 | output = [] 209 | 210 | descTarget = [finalize_descriptors(d) for d in self.descriptorsTarget] 211 | del self.descriptorsTarget 212 | 213 | descRef = [finalize_descriptors(d) for d in self.descriptorsRef] 214 | del self.descriptorsRef 215 | 216 | for scale in range(self.depthPyramid): 217 | printProgressBar(scale, self.depthPyramid) 218 | distance = sliced_wasserstein( 219 | descTarget[scale], descRef[scale], 4, 128) 220 | output.append(distance) 221 | printProgressBar(self.depthPyramid, self.depthPyramid) 222 | 223 | del descRef, descTarget 224 | 225 | return output 226 | 227 | def generateLaplacianPyramid(self, minibatch, num_levels): 228 | r""" 229 | Build the laplacian pyramids corresponding to the current minibatch. 230 | Args: 231 | minibatch (tensor): NxCxWxD, input batch 232 | num_levels (int): number of levels of the pyramids 233 | """ 234 | pyramid = [minibatch] 235 | for i in range(1, num_levels): 236 | pyramid.append(pyrDown(pyramid[-1], self.convolution)) 237 | pyramid[-2] -= pyrUp(pyramid[-1], self.convolution) 238 | return pyramid 239 | 240 | def reconstructLaplacianPyramid(self, pyramid): 241 | r""" 242 | Given a laplacian pyramid, reconstruct the corresponding minibatch 243 | 244 | Returns: 245 | A list L of tensors NxCxWxD, where L[i] represents the pyramids of 246 | the batch for the ith scale 247 | """ 248 | minibatch = pyramid[-1] 249 | for level in pyramid[-2::-1]: 250 | minibatch = pyrUp(minibatch, self.convolution) + level 251 | return minibatch 252 | 253 | def initConvolution(self, device): 254 | r""" 255 | Initialize the convolution used in openCV.pyrDown() and .pyrUp() 256 | """ 257 | gaussianFilter = torch.tensor([ 258 | [1, 4, 6, 4, 1], 259 | [4, 16, 24, 16, 4], 260 | [6, 24, 36, 24, 6], 261 | [4, 16, 24, 16, 4], 262 | [1, 4, 6, 4, 1]], dtype=torch.float) / 256.0 263 | 264 | self.convolution = nn.Conv2d(3, 3, (5, 5)) 265 | self.convolution.weight.data.fill_(0) 266 | self.convolution.weight.data[0][0] = gaussianFilter 267 | self.convolution.weight.data[1][1] = gaussianFilter 268 | self.convolution.weight.data[2][2] = gaussianFilter 269 | self.convolution.weight.requires_grad = False 270 | self.convolution = self.convolution.to(device) 271 | -------------------------------------------------------------------------------- /models/metrics/nn_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | import json 4 | import random 5 | import pickle 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.models as models 10 | import torchvision.transforms as Transforms 11 | 12 | import scipy 13 | import scipy.spatial 14 | import numpy 15 | 16 | from models.datasets.attrib_dataset import AttribDataset 17 | from ..datasets.hd5 import H5Dataset 18 | from ..datasets.attrib_dataset import AttribDataset 19 | from ..loss_criterions.ac_criterion import ACGANCriterion 20 | from ..utils.utils import printProgressBar, loadmodule 21 | 22 | random.seed() 23 | 24 | 25 | def getStatsOnDataset(attributes): 26 | 27 | stats = {} 28 | 29 | for name, data in attributes.items(): 30 | for key, value in data.items(): 31 | 32 | if key not in stats: 33 | stats[key] = {} 34 | 35 | if value not in stats[key]: 36 | stats[key][value] = 0 37 | 38 | stats[key][value] += 1 39 | return stats 40 | 41 | 42 | def updateStatsWithData(stats, item): 43 | 44 | for key, value in item.items(): 45 | stats[key][value] += 1 46 | 47 | 48 | def buildTrainValTest(pathAttrib, 49 | shareTrain=0.8, 50 | shareVal=0.2): 51 | 52 | with open(pathAttrib, 'rb') as file: 53 | data = json.load(file) 54 | 55 | stats = getStatsOnDataset(data) 56 | 57 | shareTest = max(0., 1. - shareTrain - shareVal) 58 | 59 | targetTrain = {key: {value: stats[key][value] * shareTrain 60 | for value in stats[key]} for key in stats} 61 | targetVal = {key: {value: stats[key][value] * shareVal 62 | for value in stats[key]} for key in stats} 63 | targetTest = {key: {value: stats[key][value] * shareTest 64 | for value in stats[key]} for key in stats} 65 | 66 | keys = [key for key in data.keys()] 67 | random.shuffle(keys) 68 | 69 | outTrain = {} 70 | outVal = {} 71 | outTest = {} 72 | 73 | trainStats = {key: {value: 0 for value in stats[key]} for key in stats} 74 | valStats = {key: {value: 0 for value in stats[key]} for key in stats} 75 | testStats = {key: {value: 0 for value in stats[key]} for key in stats} 76 | 77 | for name in keys: 78 | 79 | scoreTrain = 0 80 | scoreVal = 0 81 | scoreTest = 0 82 | 83 | for category in data[name]: 84 | label = data[name][category] 85 | deltaTrain = max(0, targetTrain[category][label] - trainStats[category][label]) / ( 86 | targetTrain[category][label] + 1e-8) 87 | deltaVal = max(0, targetVal[category][label] - valStats[category] 88 | [label]) / (targetVal[category][label] + 1e-8) 89 | deltaTest = max(0, targetTest[category][label] - testStats[category][label]) / ( 90 | targetTest[category][label] + 1e-8) 91 | 92 | scoreTrain += deltaTrain**2 93 | scoreVal += deltaVal**2 94 | scoreTest += deltaTest**2 95 | 96 | if scoreTrain >= 0.999 or scoreTrain >= max(scoreVal, scoreTest): 97 | outTrain[name] = data[name] 98 | updateStatsWithData(trainStats, data[name]) 99 | elif scoreVal >= scoreTest: 100 | outVal[name] = data[name] 101 | updateStatsWithData(valStats, data[name]) 102 | else: 103 | outTest[name] = data[name] 104 | updateStatsWithData(testStats, data[name]) 105 | 106 | stats = {"Train": trainStats, "Val": valStats, "Test": testStats} 107 | 108 | return outTrain, outVal, outTest, stats 109 | 110 | 111 | def buildFeatureMaker(pathDB, 112 | pathTrainAttrib, 113 | pathValAttrib, 114 | specificAttrib=None, 115 | visualisation=None): 116 | 117 | # Parameters 118 | batchSize = 16 119 | nEpochs = 3 120 | learningRate = 1e-4 121 | beta1 = 0.9 122 | beta2 = 0.99 123 | device = torch.device("cuda:0") 124 | n_devices = torch.cuda.device_count() 125 | 126 | # Model 127 | resnet18 = 
models.resnet18(pretrained=True) 128 | resnet18.train() 129 | 130 | # Dataset 131 | size = 224 132 | transformList = [Transforms.Resize((size, size)), 133 | Transforms.RandomHorizontalFlip(), 134 | Transforms.ToTensor(), 135 | Transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] 136 | 137 | transform = Transforms.Compose(transformList) 138 | 139 | dataset = AttribDataset(pathDB, transform=transform, 140 | attribDictPath=pathTrainAttrib, 141 | specificAttrib=specificAttrib, 142 | mimicImageFolder=False) 143 | 144 | validationDataset = AttribDataset(pathDB, transform=transform, 145 | attribDictPath=pathValAttrib, 146 | specificAttrib=specificAttrib, 147 | mimicImageFolder=False) 148 | 149 | print("%d training images detected, %d validation images detected" 150 | % (len(dataset), len(validationDataset))) 151 | 152 | # Optimization 153 | optimizer = torch.optim.Adam(resnet18.parameters(), 154 | betas=[beta1, beta2], 155 | lr=learningRate) 156 | 157 | lossMode = ACGANCriterion(dataset.getKeyOrders()) 158 | 159 | num_ftrs = resnet18.fc.in_features 160 | resnet18.fc = nn.Linear(num_ftrs, lossMode.getInputDim()) 161 | resnet18 = nn.DataParallel(resnet18).to(device) 162 | 163 | # Visualization data 164 | lossTrain = [] 165 | lossVal = [] 166 | iterList = [] 167 | tokenTrain = None 168 | tokenVal = None 169 | step = 0 170 | tmpLoss = 0 171 | 172 | for epoch in range(nEpochs): 173 | 174 | loader = torch.utils.data.DataLoader(dataset, 175 | batch_size=batchSize, 176 | shuffle=True, 177 | num_workers=n_devices) 178 | 179 | for iter, data in enumerate(loader): 180 | 181 | optimizer.zero_grad() 182 | 183 | inputs_real, labels = data 184 | inputs_real = inputs_real.to(device) 185 | labels = labels.to(device) 186 | 187 | predictedLabels = resnet18(inputs_real) 188 | 189 | loss = lossMode.getLoss(predictedLabels, labels) 190 | 191 | tmpLoss += loss.item() 192 | 193 | loss.backward() 194 | optimizer.step() 195 | 196 | if step % 100 == 0 and visualisation is not None: 197 | 198 | divisor = 100 199 | if step == 0: 200 | divisor = 1 201 | lossTrain.append(tmpLoss / divisor) 202 | iterList.append(step) 203 | tokenTrain = visualisation.publishLinePlot([('lossTrain', lossTrain)], iterList, 204 | name="Loss train", 205 | window_token=tokenTrain, 206 | env="main") 207 | 208 | validationLoader = torch.utils.data.DataLoader(validationDataset, 209 | batch_size=batchSize, 210 | shuffle=True, 211 | num_workers=n_devices) 212 | 213 | resnet18.eval() 214 | lossEval = 0 215 | i = 0 216 | for valData in validationLoader: 217 | 218 | inputs_real, labels = data 219 | inputs_real = inputs_real.to(device) 220 | labels = labels.to(device) 221 | lossEval += lossMode.getLoss(predictedLabels, 222 | labels).item() 223 | i += 1 224 | 225 | if i == 100: 226 | break 227 | 228 | lossEval /= i 229 | lossVal.append(lossEval) 230 | tokenVal = visualisation.publishLinePlot([('lossValidation', lossVal)], iterList, 231 | name="Loss validation", 232 | window_token=tokenVal, 233 | env="main") 234 | resnet18.train() 235 | 236 | print("[%5d ; %5d ] Loss train : %f ; Loss validation %f" 237 | % (epoch, iter, tmpLoss / divisor, lossEval)) 238 | tmpLoss = 0 239 | 240 | step += 1 241 | 242 | return resnet18.module 243 | 244 | 245 | def cutModelHead(model): 246 | 247 | modules = list(model.children())[:-1] 248 | model = nn.Sequential(*modules) 249 | 250 | return model 251 | 252 | 253 | def buildFeatureExtractor(pathModel, resetGrad=True): 254 | 255 | modelData = torch.load(pathModel) 256 | 257 | fullDump = modelData.get("fullDump", 
False) 258 | if fullDump: 259 | model = modelData['model'] 260 | else: 261 | modelType = loadmodule( 262 | modelData['package'], modelData['network'], prefix='') 263 | model = modelType(**modelData['kwargs']) 264 | model = cutModelHead(model) 265 | model.load_state_dict(modelData['data']) 266 | 267 | for param in model.parameters(): 268 | param.requires_grad = resetGrad 269 | 270 | mean = modelData['mean'] 271 | std = modelData['std'] 272 | 273 | return model, mean, std 274 | 275 | 276 | def saveFeatures(model, 277 | imgTransform, 278 | pathDB, 279 | pathMask, 280 | pathAttrib, 281 | outputFile, 282 | pathPartition=None, 283 | partitionValue=None): 284 | 285 | batchSize = 16 286 | 287 | transformList = [Transforms.ToTensor(), 288 | Transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 289 | 290 | transform = Transforms.Compose(transformList) 291 | 292 | device = torch.device("cuda:0") 293 | n_devices = torch.cuda.device_count() 294 | 295 | parallelModel = nn.DataParallel(model).to(device).eval() 296 | parallelTransorm = nn.DataParallel(imgTransform).to(device) 297 | 298 | if os.path.splitext(pathDB)[1] == ".h5": 299 | dataset = H5Dataset(pathDB, 300 | transform=transform, 301 | pathDBMask=pathMask, 302 | partition_path=pathPartition, 303 | partition_value=partitionValue) 304 | 305 | else: 306 | dataset = AttribDataset(pathDB, transform=transform, 307 | attribDictPath=pathAttrib, 308 | specificAttrib=None, 309 | mimicImageFolder=False, 310 | pathMask=pathMask) 311 | 312 | loader = torch.utils.data.DataLoader(dataset, 313 | batch_size=batchSize, 314 | shuffle=False, 315 | num_workers=n_devices) 316 | 317 | outFeatures = [] 318 | 319 | nImg = 0 320 | totImg = len(dataset) 321 | 322 | for item in loader: 323 | 324 | if len(item) == 3: 325 | data, label, mask = item 326 | else: 327 | data, label = item 328 | mask = None 329 | 330 | printProgressBar(nImg, totImg) 331 | features = parallelModel(parallelTransorm( 332 | data)).detach().view(data.size(0), -1).cpu() 333 | outFeatures.append(features) 334 | 335 | nImg += batchSize 336 | 337 | printProgressBar(totImg, totImg) 338 | 339 | import sys 340 | sys.setrecursionlimit(10000) 341 | 342 | outFeatures = torch.cat(outFeatures, dim=0).numpy() 343 | tree = scipy.spatial.KDTree(outFeatures, leafsize=10) 344 | names = [dataset.getName(x) for x in range(totImg)] 345 | with open(outputFile, 'wb') as file: 346 | pickle.dump([tree, names], file) 347 | -------------------------------------------------------------------------------- /models/networks/DCGAN_nets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from collections import OrderedDict 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find('Conv') != -1: 9 | m.weight.data.normal_(0.0, 0.02) 10 | elif classname.find('BatchNorm') != -1: 11 | m.weight.data.normal_(1.0, 0.02) 12 | m.bias.data.fill_(0) 13 | 14 | 15 | class GNet(nn.Module): 16 | def __init__(self, 17 | dimLatentVector, 18 | dimOutput, 19 | dimModelG, 20 | depthModel=3, 21 | generationActivation=nn.Tanh()): 22 | super(GNet, self).__init__() 23 | 24 | self.depthModel = depthModel 25 | self.refDim = dimModelG 26 | 27 | self.initFormatLayer(dimLatentVector) 28 | 29 | currDepth = int(dimModelG * (2**depthModel)) 30 | 31 | sequence = OrderedDict([]) 32 | # input is Z, going into a convolution 33 | sequence["batchNorm0"] = nn.BatchNorm2d(currDepth) 34 | sequence["relu0"] = nn.ReLU(True) 35 | 36 | for i in range(depthModel): 37 | 38 | nextDepth = int(currDepth / 2) 39 | 40 | # state size. (currDepth) x 2**(i+1) x 2**(i+1) 41 | sequence["convTranspose" + str(i+1)] = nn.ConvTranspose2d( 42 | currDepth, nextDepth, 4, 2, 1, bias=False) 43 | sequence["batchNorm" + str(i+1)] = nn.BatchNorm2d(nextDepth) 44 | sequence["relu" + str(i+1)] = nn.ReLU(True) 45 | 46 | currDepth = nextDepth 47 | 48 | sequence["outlayer"] = nn.ConvTranspose2d( 49 | dimModelG, dimOutput, 4, 2, 1, bias=False) 50 | 51 | self.outputAcctivation = generationActivation 52 | 53 | self.main = nn.Sequential(sequence) 54 | self.main.apply(weights_init) 55 | 56 | def initFormatLayer(self, dimLatentVector): 57 | 58 | currDepth = int(self.refDim * (2**self.depthModel)) 59 | self.formatLayer = nn.ConvTranspose2d( 60 | dimLatentVector, currDepth, 4, 1, 0, bias=False) 61 | 62 | def forward(self, input): 63 | 64 | x = input.view(-1, input.size(1), 1, 1) 65 | x = self.formatLayer(x) 66 | x = self.main(x) 67 | 68 | if self.outputAcctivation is None: 69 | return x 70 | 71 | return self.outputAcctivation(x) 72 | 73 | 74 | class DNet(nn.Module): 75 | def __init__(self, 76 | dimInput, 77 | dimModelD, 78 | sizeDecisionLayer, 79 | depthModel=3): 80 | super(DNet, self).__init__() 81 | 82 | currDepth = dimModelD 83 | sequence = OrderedDict([]) 84 | 85 | # input is (nc) x 2**(depthModel + 3) x 2**(depthModel + 3) 86 | sequence["convTranspose" + 87 | str(depthModel)] = nn.Conv2d(dimInput, currDepth, 88 | 4, 2, 1, bias=False) 89 | sequence["relu" + str(depthModel)] = nn.LeakyReLU(0.2, inplace=True) 90 | 91 | for i in range(depthModel): 92 | 93 | index = depthModel - i - 1 94 | nextDepth = currDepth * 2 95 | 96 | # state size. 
97 | # (currDepth) x 2**(depthModel + 2 -i) x 2**(depthModel + 2 -i) 98 | sequence["convTranspose" + 99 | str(index)] = nn.Conv2d(currDepth, nextDepth, 100 | 4, 2, 1, bias=False) 101 | sequence["batchNorm" + str(index)] = nn.BatchNorm2d(nextDepth) 102 | sequence["relu" + str(index)] = nn.LeakyReLU(0.2, inplace=True) 103 | 104 | currDepth = nextDepth 105 | 106 | self.dimFeatureMap = currDepth 107 | 108 | self.main = nn.Sequential(sequence) 109 | self.main.apply(weights_init) 110 | 111 | self.initDecisionLayer(sizeDecisionLayer) 112 | 113 | def initDecisionLayer(self, sizeDecisionLayer): 114 | self.decisionLayer = nn.Conv2d( 115 | self.dimFeatureMap, sizeDecisionLayer, 4, 1, 0, bias=False) 116 | self.decisionLayer.apply(weights_init) 117 | self.sizeDecisionLayer = sizeDecisionLayer 118 | 119 | def forward(self, input, getFeature = False): 120 | x = self.main(input) 121 | 122 | if getFeature: 123 | 124 | return self.decisionLayer(x).view(-1, self.sizeDecisionLayer), \ 125 | x.view(-1, self.dimFeatureMap * 16) 126 | 127 | x = self.decisionLayer(x) 128 | return x.view(-1, self.sizeDecisionLayer) 129 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/networks/constant_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ConstantNet(nn.Module): 7 | r"A network that does nothing" 8 | 9 | def __init__(self, 10 | shapeOut=None): 11 | 12 | super(ConstantNet, self).__init__() 13 | self.shapeOut = shapeOut 14 | 15 | def forward(self, x): 16 | 17 | if self.shapeOut is not None: 18 | x = x.view(x.size[0], self.shapeOut[0], 19 | self.shapeOut[1], self.shapeOut[2]) 20 | 21 | return x 22 | 23 | 24 | class MeanStd(nn.Module): 25 | def __init__(self): 26 | super(MeanStd, self).__init__() 27 | 28 | def forward(self,x): 29 | 30 | # Size : N C W H 31 | x = x.view(x.size(0), x.size(1), -1) 32 | mean_x = torch.mean(x, dim=2) 33 | var_x = torch.mean(x**2, dim=2) - mean_x * mean_x 34 | return torch.cat([mean_x, var_x], dim=1) 35 | 36 | 37 | class FeatureTransform(nn.Module): 38 | r""" 39 | Concatenation of a resize tranform and a normalization 40 | """ 41 | 42 | def __init__(self, 43 | mean=None, 44 | std=None, 45 | size=224): 46 | 47 | super(FeatureTransform, self).__init__() 48 | self.size = size 49 | 50 | if mean is None: 51 | mean = [0., 0., 0.] 52 | 53 | if std is None: 54 | std = [1., 1., 1.] 
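# The mean and std statistics below are registered as buffers of shape (1, 3, 1, 1): they broadcast over N x C x W x H inputs and follow the module across devices (e.g. .to(device)) without being treated as trainable parameters.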
55 | 56 | self.register_buffer('mean', torch.tensor( 57 | mean, dtype=torch.float).view(1, 3, 1, 1)) 58 | self.register_buffer('std', torch.tensor( 59 | std, dtype=torch.float).view(1, 3, 1, 1)) 60 | 61 | if size is None: 62 | self.upsamplingModule = None 63 | else: 64 | self.upsamplingModule = torch.nn.Upsample( 65 | (size, size), mode='bilinear') 66 | 67 | def forward(self, x): 68 | 69 | if self.upsamplingModule is not None: 70 | x = self.upsamplingModule(x) 71 | 72 | x = x - self.mean 73 | x = x / self.std 74 | 75 | return x 76 | -------------------------------------------------------------------------------- /models/networks/custom_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | from numpy import prod 7 | 8 | 9 | class NormalizationLayer(nn.Module): 10 | 11 | def __init__(self): 12 | super(NormalizationLayer, self).__init__() 13 | 14 | def forward(self, x, epsilon=1e-8): 15 | return x * (((x**2).mean(dim=1, keepdim=True) + epsilon).rsqrt()) 16 | 17 | 18 | def Upscale2d(x, factor=2): 19 | assert isinstance(factor, int) and factor >= 1 20 | if factor == 1: 21 | return x 22 | s = x.size() 23 | x = x.view(-1, s[1], s[2], 1, s[3], 1) 24 | x = x.expand(-1, s[1], s[2], factor, s[3], factor) 25 | x = x.contiguous().view(-1, s[1], s[2] * factor, s[3] * factor) 26 | return x 27 | 28 | 29 | def getLayerNormalizationFactor(x): 30 | r""" 31 | Get He's constant for the given layer 32 | https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf 33 | """ 34 | size = x.weight.size() 35 | fan_in = prod(size[1:]) 36 | 37 | return math.sqrt(2.0 / fan_in) 38 | 39 | 40 | class ConstrainedLayer(nn.Module): 41 | r""" 42 | A handy refactor that allows the user to: 43 | - initialize one layer's bias to zero 44 | - apply He's initialization at runtime 45 | """ 46 | 47 | def __init__(self, 48 | module, 49 | equalized=True, 50 | lrMul=1.0, 51 | initBiasToZero=True): 52 | r""" 53 | equalized (bool): if true, the layer's weight should evolve within 54 | the range (-1, 1) 55 | initBiasToZero (bool): if true, bias will be initialized to zero 56 | """ 57 | 58 | super(ConstrainedLayer, self).__init__() 59 | 60 | self.module = module 61 | self.equalized = equalized 62 | 63 | if initBiasToZero: 64 | self.module.bias.data.fill_(0) 65 | if self.equalized: 66 | self.module.weight.data.normal_(0, 1) 67 | self.module.weight.data /= lrMul 68 | self.weight = getLayerNormalizationFactor(self.module) * lrMul 69 | 70 | def forward(self, x): 71 | 72 | x = self.module(x) 73 | if self.equalized: 74 | x *= self.weight 75 | return x 76 | 77 | 78 | class EqualizedConv2d(ConstrainedLayer): 79 | 80 | def __init__(self, 81 | nChannelsPrevious, 82 | nChannels, 83 | kernelSize, 84 | padding=0, 85 | bias=True, 86 | **kwargs): 87 | r""" 88 | A nn.Conv2d module with specific constraints 89 | Args: 90 | nChannelsPrevious (int): number of channels in the previous layer 91 | nChannels (int): number of channels of the current layer 92 | kernelSize (int): size of the convolutional kernel 93 | padding (int): convolution's padding 94 | bias (bool): with bias ? 
95 | """ 96 | 97 | ConstrainedLayer.__init__(self, 98 | nn.Conv2d(nChannelsPrevious, nChannels, 99 | kernelSize, padding=padding, 100 | bias=bias), 101 | **kwargs) 102 | 103 | 104 | class EqualizedLinear(ConstrainedLayer): 105 | 106 | def __init__(self, 107 | nChannelsPrevious, 108 | nChannels, 109 | bias=True, 110 | **kwargs): 111 | r""" 112 | A nn.Linear module with specific constraints 113 | Args: 114 | nChannelsPrevious (int): number of channels in the previous layer 115 | nChannels (int): number of channels of the current layer 116 | bias (bool): with bias ? 117 | """ 118 | 119 | ConstrainedLayer.__init__(self, 120 | nn.Linear(nChannelsPrevious, nChannels, 121 | bias=bias), **kwargs) 122 | -------------------------------------------------------------------------------- /models/networks/mini_batch_stddev_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | 5 | def miniBatchStdDev(x, subGroupSize=4): 6 | r""" 7 | Add a minibatch standard deviation channel to the current layer. 8 | In other words: 9 | 1) Compute the standard deviation of the feature map over the minibatch 10 | 2) Get the mean, over all pixels and all channels of thsi ValueError 11 | 3) expand the layer and cocatenate it with the input 12 | 13 | Args: 14 | 15 | - x (tensor): previous layer 16 | - subGroupSize (int): size of the mini-batches on which the standard deviation 17 | should be computed 18 | """ 19 | size = x.size() 20 | subGroupSize = min(size[0], subGroupSize) 21 | if size[0] % subGroupSize != 0: 22 | subGroupSize = size[0] 23 | G = int(size[0] / subGroupSize) 24 | if subGroupSize > 1: 25 | y = x.view(-1, subGroupSize, size[1], size[2], size[3]) 26 | y = torch.var(y, 1) 27 | y = torch.sqrt(y + 1e-8) 28 | y = y.view(G, -1) 29 | y = torch.mean(y, 1).view(G, 1) 30 | y = y.expand(G, size[2]*size[3]).view((G, 1, 1, size[2], size[3])) 31 | y = y.expand(G, subGroupSize, -1, -1, -1) 32 | y = y.contiguous().view((-1, 1, size[2], size[3])) 33 | else: 34 | y = torch.zeros(x.size(0), 1, x.size(2), x.size(3), device=x.device) 35 | 36 | return torch.cat([x, y], dim=1) 37 | -------------------------------------------------------------------------------- /models/networks/styleGAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .custom_layers import EqualizedConv2d, EqualizedLinear,\ 5 | NormalizationLayer, Upscale2d 6 | 7 | 8 | class AdaIN(nn.Module): 9 | 10 | def __init__(self, dimIn, dimOut, epsilon=1e-8): 11 | super(AdaIN, self).__init__() 12 | self.epsilon = epsilon 13 | self.styleModulator = EqualizedLinear(dimIn, 2*dimOut, equalized=True, 14 | initBiasToZero=True) 15 | self.dimOut = dimOut 16 | 17 | def forward(self, x, y): 18 | 19 | # x: N x C x W x H 20 | batchSize, nChannel, width, height = x.size() 21 | tmpX = x.view(batchSize, nChannel, -1) 22 | mux = tmpX.mean(dim=2).view(batchSize, nChannel, 1, 1) 23 | varx = torch.clamp((tmpX*tmpX).mean(dim=2).view(batchSize, nChannel, 1, 1) - mux*mux, min=0) 24 | varx = torch.rsqrt(varx + self.epsilon) 25 | x = (x - mux) * varx 26 | 27 | # Adapt style 28 | styleY = self.styleModulator(y) 29 | yA = styleY[:, : self.dimOut].view(batchSize, self.dimOut, 1, 1) 30 | yB = styleY[:, self.dimOut:].view(batchSize, self.dimOut, 1, 1) 31 | 32 | return yA * x + yB 33 | 34 | 35 | class NoiseMultiplier(nn.Module): 36 | 37 | def __init__(self): 38 | 
super(NoiseMultiplier, self).__init__() 39 | self.module = nn.Conv2d(1, 1, 1, bias=False) 40 | self.module.weight.data.fill_(0) 41 | 42 | def forward(self, x): 43 | 44 | return self.module(x) 45 | 46 | 47 | class MappingLayer(nn.Module): 48 | 49 | def __init__(self, dimIn, dimLatent, nLayers, leakyReluLeak=0.2): 50 | super(MappingLayer, self).__init__() 51 | self.FC = nn.ModuleList() 52 | 53 | inDim = dimIn 54 | for i in range(nLayers): 55 | self.FC.append(EqualizedLinear(inDim, dimLatent, lrMul=0.01, equalized=True, initBiasToZero=True)) 56 | inDim = dimLatent 57 | 58 | self.activation = torch.nn.LeakyReLU(leakyReluLeak) 59 | 60 | def forward(self, x): 61 | for layer in self.FC: 62 | x = self.activation(layer(x)) 63 | 64 | return x 65 | 66 | class GNet(nn.Module): 67 | 68 | def __init__(self, 69 | dimInput=512, 70 | dimMapping=512, 71 | dimOutput=3, 72 | nMappingLayers=8, 73 | leakyReluLeak=0.2, 74 | generationActivation=None, 75 | phiTruncation=0.5, 76 | gamma_avg=0.99): 77 | 78 | super(GNet, self).__init__() 79 | self.dimMapping = dimMapping 80 | self.mapping = MappingLayer(dimInput, dimMapping, nMappingLayers) 81 | self.baseScale0 = nn.Parameter(torch.ones(1, dimMapping, 4, 4), requires_grad=True) 82 | 83 | self.scaleLayers = nn.ModuleList() 84 | self.toRGBLayers = nn.ModuleList() 85 | self.noiseModulators = nn.ModuleList() 86 | self.depthScales = [dimMapping] 87 | self.noramlizationLayer = NormalizationLayer() 88 | 89 | self.adain00 = AdaIN(dimMapping, dimMapping) 90 | self.noiseMod00 = NoiseMultiplier() 91 | self.adain01 = AdaIN(dimMapping, dimMapping) 92 | self.noiseMod01 = NoiseMultiplier() 93 | self.conv0 = EqualizedConv2d(dimMapping, dimMapping, 3, equalized=True, 94 | initBiasToZero=True, padding=1) 95 | 96 | self.activation = torch.nn.LeakyReLU(leakyReluLeak) 97 | self.alpha = 0 98 | self.generationActivation = generationActivation 99 | self.dimOutput = dimOutput 100 | self.phiTruncation = phiTruncation 101 | 102 | self.register_buffer('mean_w', torch.randn(1, dimMapping)) 103 | self.gamma_avg = gamma_avg 104 | 105 | def setNewAlpha(self, alpha): 106 | r""" 107 | Update the value of the merging factor alpha 108 | 109 | Args: 110 | 111 | - alpha (float): merging factor, must be in [0, 1] 112 | """ 113 | 114 | if alpha < 0 or alpha > 1: 115 | raise ValueError("alpha must be in [0,1]") 116 | 117 | if not self.toRGBLayers: 118 | raise AttributeError("Can't set an alpha layer if only the scale 0" 119 | "is defined") 120 | 121 | self.alpha = alpha 122 | 123 | def addScale(self, dimNewScale): 124 | 125 | lastDim = self.depthScales[-1] 126 | self.scaleLayers.append(nn.ModuleList()) 127 | self.scaleLayers[-1].append(EqualizedConv2d(lastDim, 128 | dimNewScale, 129 | 3, 130 | padding=1, 131 | equalized=True, 132 | initBiasToZero=True)) 133 | 134 | self.scaleLayers[-1].append(AdaIN(self.dimMapping, dimNewScale)) 135 | self.scaleLayers[-1].append(EqualizedConv2d(dimNewScale, 136 | dimNewScale, 137 | 3, 138 | padding=1, 139 | equalized=True, 140 | initBiasToZero=True)) 141 | self.scaleLayers[-1].append(AdaIN(self.dimMapping, dimNewScale)) 142 | self.toRGBLayers.append(EqualizedConv2d(dimNewScale, 143 | self.dimOutput, 144 | 1, 145 | equalized=True, 146 | initBiasToZero=True)) 147 | 148 | self.noiseModulators.append(nn.ModuleList()) 149 | self.noiseModulators[-1].append(NoiseMultiplier()) 150 | self.noiseModulators[-1].append(NoiseMultiplier()) 151 | self.depthScales.append(dimNewScale) 152 | 153 | def forward(self, x): 154 | 155 | batchSize = x.size(0) 156 | mapping = 
self.mapping(self.noramlizationLayer(x)) 157 | if self.training: 158 | self.mean_w = self.gamma_avg * self.mean_w + (1-self.gamma_avg) * mapping.mean(dim=0, keepdim=True) 159 | 160 | if self.phiTruncation < 1: 161 | mapping = self.mean_w + self.phiTruncation * (mapping - self.mean_w) 162 | 163 | feature = self.baseScale0.expand(batchSize, -1, 4, 4) 164 | feature = feature + self.noiseMod00(torch.randn((batchSize, 1, 4, 4), device=x.device)) 165 | 166 | feature = self.activation(feature) 167 | feature = self.adain00(feature, mapping) 168 | feature = self.conv0(feature) 169 | feature = feature + self.noiseMod01(torch.randn((batchSize, 1, 4, 4), device=x.device)) 170 | feature = self.activation(feature) 171 | feature = self.adain01(feature, mapping) 172 | 173 | for nLayer, group in enumerate(self.scaleLayers): 174 | 175 | noiseMod = self.noiseModulators[nLayer] 176 | feature = Upscale2d(feature) 177 | feature = group[0](feature) + noiseMod[0](torch.randn((batchSize, 1, 178 | feature.size(2), 179 | feature.size(3)), device=x.device)) 180 | feature = self.activation(feature) 181 | feature = group[1](feature, mapping) 182 | feature = group[2](feature) + noiseMod[1](torch.randn((batchSize, 1, 183 | feature.size(2), 184 | feature.size(3)), device=x.device)) 185 | feature = self.activation(feature) 186 | feature = group[3](feature, mapping) 187 | 188 | if self.alpha > 0 and nLayer == len(self.scaleLayers) -2: 189 | y = self.toRGBLayers[-2](feature) 190 | y = Upscale2d(y) 191 | 192 | feature = self.toRGBLayers[-1](feature) 193 | # Blending with the lower resolution output when alpha > 0 194 | if self.alpha > 0: 195 | feature = self.alpha * y + (1.0-self.alpha) * feature 196 | 197 | if self.generationActivation is not None: 198 | feature = self.generationActivation(feature) 199 | 200 | return feature 201 | 202 | def getOutputSize(self): 203 | 204 | side = 2**(2 + len(self.toRGBLayers)) 205 | return (side, side) 206 | -------------------------------------------------------------------------------- /models/progressive_gan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch.optim as optim 3 | 4 | from .base_GAN import BaseGAN 5 | from .utils.config import BaseConfig 6 | from .networks.progressive_conv_net import GNet, DNet 7 | 8 | 9 | class ProgressiveGAN(BaseGAN): 10 | r""" 11 | Implementation of NVIDIA's progressive GAN. 12 | """ 13 | 14 | def __init__(self, 15 | dimLatentVector=512, 16 | depthScale0=512, 17 | initBiasToZero=True, 18 | leakyness=0.2, 19 | perChannelNormalization=True, 20 | miniBatchStdDev=False, 21 | equalizedlR=True, 22 | **kwargs): 23 | r""" 24 | Args: 25 | 26 | Specific Arguments: 27 | - depthScale0 (int) 28 | - initBiasToZero (bool): should layer's bias be initialized to 29 | zero ? 30 | - leakyness (float): negative slope of the leakyRelU activation 31 | function 32 | - perChannelNormalization (bool): do we normalize the output of 33 | each convolutional layer ? 
34 | - miniBatchStdDev (bool): mini batch regularization for the 35 | discriminator 36 | - equalizedlR (bool): if True, forces the optimizer to see weights 37 | in range (-1, 1) 38 | 39 | """ 40 | if not 'config' in vars(self): 41 | self.config = BaseConfig() 42 | 43 | self.config.depthScale0 = depthScale0 44 | self.config.initBiasToZero = initBiasToZero 45 | self.config.leakyReluLeak = leakyness 46 | self.config.depthOtherScales = [] 47 | self.config.perChannelNormalization = perChannelNormalization 48 | self.config.alpha = 0 49 | self.config.miniBatchStdDev = miniBatchStdDev 50 | self.config.equalizedlR = equalizedlR 51 | 52 | BaseGAN.__init__(self, dimLatentVector, **kwargs) 53 | 54 | def getNetG(self): 55 | 56 | gnet = GNet(self.config.latentVectorDim, 57 | self.config.depthScale0, 58 | initBiasToZero=self.config.initBiasToZero, 59 | leakyReluLeak=self.config.leakyReluLeak, 60 | normalization=self.config.perChannelNormalization, 61 | generationActivation=self.lossCriterion.generationActivation, 62 | dimOutput=self.config.dimOutput, 63 | equalizedlR=self.config.equalizedlR) 64 | 65 | # Add scales if necessary 66 | for depth in self.config.depthOtherScales: 67 | gnet.addScale(depth) 68 | 69 | # If new scales are added, give the generator a blending layer 70 | if self.config.depthOtherScales: 71 | gnet.setNewAlpha(self.config.alpha) 72 | 73 | return gnet 74 | 75 | def getNetD(self): 76 | 77 | dnet = DNet(self.config.depthScale0, 78 | initBiasToZero=self.config.initBiasToZero, 79 | leakyReluLeak=self.config.leakyReluLeak, 80 | sizeDecisionLayer=self.lossCriterion.sizeDecisionLayer + 81 | self.config.categoryVectorDim, 82 | miniBatchNormalization=self.config.miniBatchStdDev, 83 | dimInput=self.config.dimOutput, 84 | equalizedlR=self.config.equalizedlR) 85 | 86 | # Add scales if necessary 87 | for depth in self.config.depthOtherScales: 88 | dnet.addScale(depth) 89 | 90 | # If new scales are added, give the discriminator a blending layer 91 | if self.config.depthOtherScales: 92 | dnet.setNewAlpha(self.config.alpha) 93 | 94 | return dnet 95 | 96 | def getOptimizerD(self): 97 | return optim.Adam(filter(lambda p: p.requires_grad, self.netD.parameters()), 98 | betas=[0, 0.99], lr=self.config.learningRate) 99 | 100 | def getOptimizerG(self): 101 | return optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()), 102 | betas=[0, 0.99], lr=self.config.learningRate) 103 | 104 | def addScale(self, depthNewScale): 105 | r""" 106 | Add a new scale to the model. The output resolution becomes twice 107 | bigger. 108 | """ 109 | self.netG = self.getOriginalG() 110 | self.netD = self.getOriginalD() 111 | 112 | self.netG.addScale(depthNewScale) 113 | self.netD.addScale(depthNewScale) 114 | 115 | self.config.depthOtherScales.append(depthNewScale) 116 | 117 | self.updateSolversDevice() 118 | 119 | def updateAlpha(self, newAlpha): 120 | r""" 121 | Update the blending factor alpha. 122 | 123 | Args: 124 | - alpha (float): blending factor (in [0,1]). 0 means only the 125 | highest resolution in considered (no blend), 1 126 | means the highest resolution is fully discarded. 
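Example (illustrative): with alpha = 0.25, the output is blended as 0.25 * upscaled_lower_resolution + 0.75 * highest_resolution, consistent with the convention above.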
127 | """ 128 | print("Changing alpha to %.3f" % newAlpha) 129 | 130 | self.getOriginalG().setNewAlpha(newAlpha) 131 | self.getOriginalD().setNewAlpha(newAlpha) 132 | 133 | if self.avgG: 134 | self.avgG.module.setNewAlpha(newAlpha) 135 | 136 | self.config.alpha = newAlpha 137 | 138 | def getSize(self): 139 | r""" 140 | Get output image size (W, H) 141 | """ 142 | return self.getOriginalG().getOutputSize() 143 | -------------------------------------------------------------------------------- /models/styleGAN.py: -------------------------------------------------------------------------------- 1 | from .progressive_gan import ProgressiveGAN 2 | from .networks.styleGAN import GNet 3 | from .utils.config import BaseConfig 4 | 5 | 6 | class StyleGAN(ProgressiveGAN): 7 | 8 | def __init__(self, 9 | nMappings=8, 10 | phiTruncation=0.5, 11 | gamma_avg=0.99, 12 | **kwargs): 13 | 14 | if not 'config' in vars(self): 15 | self.config = BaseConfig() 16 | 17 | self.config.nMappings = nMappings 18 | self.config.phiTruncation = phiTruncation 19 | self.config.gamma_avg = gamma_avg 20 | 21 | if self.config.phiTruncation >= 1: 22 | print("Disabling the truncation trick") 23 | ProgressiveGAN.__init__(self, **kwargs) 24 | 25 | def getNetG(self): 26 | 27 | gnet = GNet(dimInput=self.config.latentVectorDim, 28 | dimMapping=self.config.depthScale0, 29 | leakyReluLeak=self.config.leakyReluLeak, 30 | nMappingLayers=self.config.nMappings, 31 | generationActivation=self.lossCriterion.generationActivation, 32 | dimOutput=self.config.dimOutput, 33 | phiTruncation=self.config.phiTruncation, 34 | gamma_avg=self.config.gamma_avg) 35 | 36 | # Add scales if necessary 37 | for depth in self.config.depthOtherScales: 38 | gnet.addScale(depth) 39 | 40 | # If new scales are added, give the generator a blending layer 41 | if self.config.depthOtherScales: 42 | gnet.setNewAlpha(self.config.alpha) 43 | 44 | return gnet 45 | -------------------------------------------------------------------------------- /models/trainer/DCGAN_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import os 3 | 4 | from ..DCGAN import DCGAN 5 | from .gan_trainer import GANTrainer 6 | from .standard_configurations.dcgan_config import _C 7 | 8 | 9 | class DCGANTrainer(GANTrainer): 10 | r""" 11 | A trainer structure for the DCGAN and DCGAN product models 12 | """ 13 | 14 | _defaultConfig = _C 15 | 16 | def getDefaultConfig(self): 17 | return DCGANTrainer._defaultConfig 18 | 19 | def __init__(self, 20 | pathdb, 21 | **kwargs): 22 | r""" 23 | Args: 24 | 25 | pathdb (string): path to the input dataset 26 | **kwargs: other arguments specific to the GANTrainer class 27 | """ 28 | 29 | GANTrainer.__init__(self, pathdb, **kwargs) 30 | 31 | self.lossProfile.append({"iter": [], "scale": 0}) 32 | 33 | def initModel(self): 34 | self.model = DCGAN(useGPU=self.useGPU, 35 | **vars(self.modelConfig)) 36 | 37 | def train(self): 38 | 39 | shift = 0 40 | if self.startIter >0: 41 | shift+= self.startIter 42 | 43 | if self.checkPointDir is not None: 44 | pathBaseConfig = os.path.join(self.checkPointDir, self.modelLabel 45 | + "_train_config.json") 46 | self.saveBaseConfig(pathBaseConfig) 47 | 48 | maxShift = int(self.modelConfig.nEpoch * len(self.getDBLoader(0))) 49 | 50 | for epoch in range(self.modelConfig.nEpoch): 51 | dbLoader = self.getDBLoader(0) 52 | self.trainOnEpoch(dbLoader, 0, shiftIter=shift) 53 | 54 | shift += len(dbLoader) 55 | 56 | if shift > maxShift: 57 | break 58 | 59 | label = self.modelLabel + ("_s%d_i%d" % 60 | (0, shift)) 61 | self.saveCheckpoint(self.checkPointDir, 62 | label, 0, shift) 63 | 64 | def initializeWithPretrainNetworks(self, 65 | pathD, 66 | pathGShape, 67 | pathGTexture, 68 | finetune=True): 69 | r""" 70 | Initialize a product gan by loading 3 pretrained networks 71 | 72 | Args: 73 | 74 | pathD (string): Path to the .pt file where the DCGAN discrimator is saved 75 | pathGShape (string): Path to .pt file where the DCGAN shape generator 76 | is saved 77 | pathGTexture (string): Path to .pt file where the DCGAN texture generator 78 | is saved 79 | 80 | finetune (bool): set to True to reinitialize the first layer of the 81 | generator and the last layer of the discriminator 82 | """ 83 | 84 | if not self.modelConfig.productGan: 85 | raise ValueError("Only product gan can be cross-initialized") 86 | 87 | self.model.loadG(pathGShape, pathGTexture, resetFormatLayer=finetune) 88 | self.model.load(pathD, loadG=False, loadD=True, 89 | loadConfig=False, finetuning=True) 90 | -------------------------------------------------------------------------------- /models/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/trainer/progressive_gan_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | 4 | from .standard_configurations.pgan_config import _C 5 | from ..progressive_gan import ProgressiveGAN 6 | from .gan_trainer import GANTrainer 7 | from ..utils.utils import getMinOccurence 8 | import torch.nn.functional as F 9 | 10 | 11 | class ProgressiveGANTrainer(GANTrainer): 12 | r""" 13 | A class managing a progressive GAN training. Logs, chekpoints, 14 | visualization, and number iterations are managed here. 
15 | """ 16 | _defaultConfig = _C 17 | 18 | def getDefaultConfig(self): 19 | return ProgressiveGANTrainer._defaultConfig 20 | 21 | def __init__(self, 22 | pathdb, 23 | miniBatchScheduler=None, 24 | datasetProfile=None, 25 | configScheduler=None, 26 | **kwargs): 27 | r""" 28 | Args: 29 | - pathdb (string): path to the directorty containing the image 30 | dataset 31 | - useGPU (bool): set to True if you want to use the available GPUs 32 | for the training procedure 33 | - visualisation (module): if not None, a visualisation module to 34 | follow the evolution of the training 35 | - lossIterEvaluation (int): size of the interval on which the 36 | model'sloss will be evaluated 37 | - saveIter (int): frequency at which at checkpoint should be saved 38 | (relevant only if modelLabel != None) 39 | - checkPointDir (string): if not None, directory where the checkpoints 40 | should be saved 41 | - modelLabel (string): name of the model 42 | - config (dictionary): configuration dictionnary. See std_p_gan_config.py 43 | for all the possible options 44 | - numWorkers (int): number of GOU to use. Will be set to one if not 45 | useGPU 46 | - stopOnShitStorm (bool): should we stop the training if a diverging 47 | behavior is detected ? 48 | """ 49 | 50 | self.configScheduler = {} 51 | if configScheduler is not None: 52 | self.configScheduler = { 53 | int(key): value for key, value in configScheduler.items()} 54 | 55 | self.miniBatchScheduler = {} 56 | if miniBatchScheduler is not None: 57 | self.miniBatchScheduler = { 58 | int(x): value for x, value in miniBatchScheduler.items()} 59 | 60 | self.datasetProfile = {} 61 | if datasetProfile is not None: 62 | self.datasetProfile = { 63 | int(x): value for x, value in datasetProfile.items()} 64 | 65 | GANTrainer.__init__(self, pathdb, **kwargs) 66 | 67 | def initModel(self): 68 | r""" 69 | Initialize the GAN model. 70 | """ 71 | 72 | config = {key: value for key, value in vars(self.modelConfig).items()} 73 | config["depthScale0"] = self.modelConfig.depthScales[0] 74 | self.model = ProgressiveGAN(useGPU=self.useGPU, **config) 75 | 76 | def readTrainConfig(self, config): 77 | r""" 78 | Load a permanent configuration describing a models. The variables 79 | described in this file are constant through the training. 
80 | """ 81 | 82 | GANTrainer.readTrainConfig(self, config) 83 | 84 | if self.modelConfig.alphaJumpMode not in ["custom", "linear"]: 85 | raise ValueError( 86 | "alphaJumpMode should be one of the followings: \ 87 | 'custom', 'linear'") 88 | 89 | if self.modelConfig.alphaJumpMode == "linear": 90 | 91 | self.modelConfig.alphaNJumps[0] = 0 92 | self.modelConfig.iterAlphaJump = [] 93 | self.modelConfig.alphaJumpVals = [] 94 | 95 | self.updateAlphaJumps( 96 | self.modelConfig.alphaNJumps, self.modelConfig.alphaSizeJumps) 97 | 98 | self.scaleSanityCheck() 99 | 100 | def scaleSanityCheck(self): 101 | 102 | # Sanity check 103 | n_scales = min(len(self.modelConfig.depthScales), 104 | len(self.modelConfig.maxIterAtScale), 105 | len(self.modelConfig.iterAlphaJump), 106 | len(self.modelConfig.alphaJumpVals)) 107 | 108 | self.modelConfig.depthScales = self.modelConfig.depthScales[:n_scales] 109 | self.modelConfig.maxIterAtScale = self.modelConfig.maxIterAtScale[:n_scales] 110 | self.modelConfig.iterAlphaJump = self.modelConfig.iterAlphaJump[:n_scales] 111 | self.modelConfig.alphaJumpVals = self.modelConfig.alphaJumpVals[:n_scales] 112 | 113 | self.modelConfig.size_scales = [4] 114 | for scale in range(1, n_scales): 115 | self.modelConfig.size_scales.append( 116 | self.modelConfig.size_scales[-1] * 2) 117 | 118 | self.modelConfig.n_scales = n_scales 119 | 120 | def updateAlphaJumps(self, nJumpScale, sizeJumpScale): 121 | r""" 122 | Given the number of iterations between two updates of alpha at each 123 | scale and the number of updates per scale, build the effective values of 124 | self.maxIterAtScale and self.alphaJumpVals. 125 | 126 | Args: 127 | 128 | - nJumpScale (list of int): for each scale, the number of times 129 | alpha should be updated 130 | - sizeJumpScale (list of int): for each scale, the number of 131 | iterations between two updates 132 | """ 133 | 134 | n_scales = min(len(nJumpScale), len(sizeJumpScale)) 135 | 136 | for scale in range(n_scales): 137 | 138 | self.modelConfig.iterAlphaJump.append([]) 139 | self.modelConfig.alphaJumpVals.append([]) 140 | 141 | if nJumpScale[scale] == 0: 142 | self.modelConfig.iterAlphaJump[-1].append(0) 143 | self.modelConfig.alphaJumpVals[-1].append(0.0) 144 | continue 145 | 146 | diffJump = 1.0 / float(nJumpScale[scale]) 147 | currVal = 1.0 148 | currIter = 0 149 | 150 | while currVal > 0: 151 | 152 | self.modelConfig.iterAlphaJump[-1].append(currIter) 153 | self.modelConfig.alphaJumpVals[-1].append(currVal) 154 | 155 | currIter += sizeJumpScale[scale] 156 | currVal -= diffJump 157 | 158 | self.modelConfig.iterAlphaJump[-1].append(currIter) 159 | self.modelConfig.alphaJumpVals[-1].append(0.0) 160 | 161 | def inScaleUpdate(self, iter, scale, input_real): 162 | 163 | if self.indexJumpAlpha < len(self.modelConfig.iterAlphaJump[scale]): 164 | if iter == self.modelConfig.iterAlphaJump[scale][self.indexJumpAlpha]: 165 | alpha = self.modelConfig.alphaJumpVals[scale][self.indexJumpAlpha] 166 | self.model.updateAlpha(alpha) 167 | self.indexJumpAlpha += 1 168 | 169 | if self.model.config.alpha > 0: 170 | low_res_real = F.avg_pool2d(input_real, (2, 2)) 171 | low_res_real = F.upsample( 172 | low_res_real, scale_factor=2, mode='nearest') 173 | 174 | alpha = self.model.config.alpha 175 | input_real = alpha * low_res_real + (1-alpha) * input_real 176 | 177 | return input_real 178 | 179 | def updateDatasetForScale(self, scale): 180 | 181 | self.modelConfig.miniBatchSize = getMinOccurence( 182 | self.miniBatchScheduler, scale, self.modelConfig.miniBatchSize) 183 | 
self.path_db = getMinOccurence( 184 | self.datasetProfile, scale, self.path_db) 185 | 186 | # Scale scheduler 187 | if self.configScheduler is not None: 188 | if scale in self.configScheduler: 189 | print("Scale %d, updating the training configuration" % scale) 190 | print(self.configScheduler[scale]) 191 | self.model.updateConfig(self.configScheduler[scale]) 192 | 193 | def train(self): 194 | r""" 195 | Launch the training. This one will stop if a divergent behavior is 196 | detected. 197 | 198 | Returns: 199 | 200 | - True if the training completed 201 | - False if the training was interrupted due to a divergent behavior 202 | """ 203 | 204 | n_scales = len(self.modelConfig.depthScales) 205 | 206 | if self.checkPointDir is not None: 207 | pathBaseConfig = os.path.join(self.checkPointDir, self.modelLabel 208 | + "_train_config.json") 209 | self.saveBaseConfig(pathBaseConfig) 210 | 211 | for scale in range(self.startScale, n_scales): 212 | 213 | self.updateDatasetForScale(scale) 214 | 215 | while scale >= len(self.lossProfile): 216 | self.lossProfile.append( 217 | {"scale": scale, "iter": []}) 218 | 219 | dbLoader = self.getDBLoader(scale) 220 | sizeDB = len(dbLoader) 221 | 222 | shiftIter = 0 223 | if self.startIter > 0: 224 | shiftIter = self.startIter 225 | self.startIter = 0 226 | 227 | shiftAlpha = 0 228 | while shiftAlpha < len(self.modelConfig.iterAlphaJump[scale]) and \ 229 | self.modelConfig.iterAlphaJump[scale][shiftAlpha] < shiftIter: 230 | shiftAlpha += 1 231 | 232 | while shiftIter < self.modelConfig.maxIterAtScale[scale]: 233 | 234 | self.indexJumpAlpha = shiftAlpha 235 | status = self.trainOnEpoch(dbLoader, scale, 236 | shiftIter=shiftIter, 237 | maxIter=self.modelConfig.maxIterAtScale[scale]) 238 | 239 | if not status: 240 | return False 241 | 242 | shiftIter += sizeDB 243 | while shiftAlpha < len(self.modelConfig.iterAlphaJump[scale]) and \ 244 | self.modelConfig.iterAlphaJump[scale][shiftAlpha] < shiftIter: 245 | shiftAlpha += 1 246 | 247 | # Save a checkpoint 248 | if self.checkPointDir is not None: 249 | realIter = min( 250 | shiftIter, self.modelConfig.maxIterAtScale[scale]) 251 | label = self.modelLabel + ("_s%d_i%d" % 252 | (scale, realIter)) 253 | self.saveCheckpoint(self.checkPointDir, 254 | label, scale, realIter) 255 | if scale == n_scales - 1: 256 | break 257 | 258 | self.model.addScale(self.modelConfig.depthScales[scale + 1]) 259 | 260 | self.startScale = n_scales 261 | self.startIter = self.modelConfig.maxIterAtScale[-1] 262 | return True 263 | 264 | def addNewScales(self, configNewScales): 265 | 266 | if configNewScales["alphaJumpMode"] not in ["custom", "linear"]: 267 | raise ValueError("alphaJumpMode should be one of the followings: \ 268 | 'custom', 'linear'") 269 | 270 | if configNewScales["alphaJumpMode"] == 'custom': 271 | self.modelConfig.iterAlphaJump = self.modelConfig.iterAlphaJump + \ 272 | configNewScales["iterAlphaJump"] 273 | self.modelConfig.alphaJumpVals = self.modelConfig.alphaJumpVals + \ 274 | configNewScales["alphaJumpVals"] 275 | 276 | else: 277 | self.updateAlphaJumps(configNewScales["alphaNJumps"], 278 | configNewScales["alphaSizeJumps"]) 279 | 280 | self.modelConfig.depthScales = self.modelConfig.depthScales + \ 281 | configNewScales["depthScales"] 282 | self.modelConfig.maxIterAtScale = self.modelConfig.maxIterAtScale + \ 283 | configNewScales["maxIterAtScale"] 284 | 285 | self.scaleSanityCheck() 286 | -------------------------------------------------------------------------------- /models/trainer/standard_configurations/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytorch_GAN_zoo/b75dee40918caabb4fe7ec561522717bf096a8cb/models/trainer/standard_configurations/__init__.py -------------------------------------------------------------------------------- /models/trainer/standard_configurations/dcgan_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from ...utils.config import BaseConfig 3 | 4 | # Default configuration for ProgressiveGANTrainer 5 | _C = BaseConfig() 6 | 7 | ############################################################ 8 | 9 | # Depth of a convolutional layer for each scale 10 | _C.depth = 3 11 | 12 | # Mini batch size 13 | _C.miniBatchSize = 16 14 | 15 | # Dimension of the latent vector 16 | _C.dimLatentVector = 100 17 | 18 | # Dimension of the output image 19 | _C.dimOutput = 3 20 | 21 | # Dimension of the generator 22 | _C.dimG = 64 23 | 24 | # Dimension of the discrimator 25 | _C.dimD = 64 26 | 27 | # Loss mode 28 | _C.lossMode = 'DCGAN' 29 | 30 | # Gradient penalty coefficient (WGANGP) 31 | _C.lambdaGP = 0. 32 | 33 | # Noise standard deviation in case of instance noise (0 <=> no Instance noise) 34 | _C.sigmaNoise = 0. 35 | 36 | # Weight penalty on |D(x)|^2 37 | _C.epsilonD = 0. 38 | 39 | # Base learning rate 40 | _C.baseLearningRate = 0.0002 41 | 42 | # In case of AC GAN, weight on the classification loss (per scale) 43 | _C.weightConditionG = 0.0 44 | _C.weightConditionD = 0.0 45 | 46 | # Activate GDPP loss ? 47 | _C.GDPP = False 48 | 49 | # Number of epochs 50 | _C.nEpoch = 10 51 | 52 | # Do not modify. Field used to save the attribute dictionnary for labelled 53 | # datasets 54 | _C.attribKeysOrder = None 55 | -------------------------------------------------------------------------------- /models/trainer/standard_configurations/pgan_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from ...utils.config import BaseConfig 3 | 4 | # Default configuration for ProgressiveGANTrainer 5 | _C = BaseConfig() 6 | 7 | # Maximum number of iteration at each scale 8 | _C.maxIterAtScale = [48000, 96000, 96000, 9 | 96000, 96000, 96000, 96000, 96000, 200000] 10 | 11 | # Blending mode. 
12 | 13 | ############################################################ 14 | 15 | 16 | # 2 possible values are possible: 17 | # - custom: iterations at which alpha should be updated and new value after the update 18 | # are fully described by the user 19 | # - linear: The user just inputs the number of updates of alpha, and the number of iterations 20 | # between two updates for each scale 21 | _C.alphaJumpMode = "linear" 22 | 23 | # If _C.alphaJumpMode == "custom", then the following fields should be completed 24 | 25 | # For each scale, iteration at wich the blending factor alpha should be 26 | # updated 27 | _C.iterAlphaJump = [[], [0, 1000, 2000], [0, 1000, 4000, 8000, 16000], 28 | [0, 2000, 4000, 8000]] 29 | 30 | # New value of the blending factor alpha during the update (see above) 31 | _C.alphaJumpVals = [[], [1., 0.5, 0], [ 32 | 1, 0.75, 0.5, 0.25, 0.], [1., 0.75, 0.5, 0.]] 33 | 34 | # If _C.alphaJumpMode == "linear", then the following fields should be completed 35 | 36 | # Number of jumps per scale 37 | _C.alphaNJumps = [0, 600, 600, 600, 600, 600, 600, 600, 600] 38 | 39 | # Number of iterations between two jumps 40 | _C.alphaSizeJumps = [0, 32, 32, 32, 32, 32, 32, 32, 32, 32] 41 | 42 | ############################################################# 43 | 44 | # Depth of a convolutional layer for each scale 45 | _C.depthScales = [512, 512, 512, 512, 256, 128, 64, 32, 16] 46 | 47 | # Mini batch size 48 | _C.miniBatchSize = 16 49 | 50 | # Dimension of the latent vector 51 | _C.dimLatentVector = 512 52 | 53 | # Should bias be initialized to zero ? 54 | _C.initBiasToZero = True 55 | 56 | # Per channel normalization 57 | _C.perChannelNormalization = True 58 | 59 | # Loss mode 60 | _C.lossMode = 'WGANGP' 61 | 62 | # Gradient penalty coefficient (WGANGP) 63 | _C.lambdaGP = 10. 64 | 65 | # Leakyness of the leakyRelU activation function 66 | _C.leakyness = 0.2 67 | 68 | # Weight penalty on |D(x)|^2 69 | _C.epsilonD = 0.001 70 | 71 | # Mini batch regularization 72 | _C.miniBatchStdDev = True 73 | 74 | # Base learning rate 75 | _C.baseLearningRate = 0.001 76 | 77 | # RGB or grey level output ? 78 | _C.dimOutput = 3 79 | 80 | # In case of AC GAN, weight on the classification loss (per scale) 81 | _C.weightConditionG = 0.0 82 | _C.weightConditionD = 0.0 83 | 84 | # Do not fill. Loaded automatically 85 | _C.attribKeysOrder = None 86 | 87 | #Activate GDPP loss ? 88 | _C.GDPP = False 89 | -------------------------------------------------------------------------------- /models/trainer/standard_configurations/stylegan_config.py: -------------------------------------------------------------------------------- 1 | from ...utils.config import BaseConfig 2 | 3 | # Default configuration for ProgressiveGANTrainer 4 | _C = BaseConfig() 5 | 6 | # Maximum number of iteration at each scale 7 | _C.maxIterAtScale = [96000, 96000, 8 | 96000, 96000, 96000, 96000, 96000, 200000] 9 | 10 | # Blending mode. 
11 | 12 | ############################################################ 13 | 14 | 15 | # 2 possible values are possible: 16 | # - custom: iterations at which alpha should be updated and new value after the update 17 | # are fully described by the user 18 | # - linear: The user just inputs the number of updates of alpha, and the number of iterations 19 | # between two updates for each scale 20 | _C.alphaJumpMode = "linear" 21 | 22 | # If _C.alphaJumpMode == "custom", then the following fields should be completed 23 | 24 | # For each scale, iteration at wich the blending factor alpha should be 25 | # updated 26 | _C.iterAlphaJump = [] 27 | 28 | # New value of the blending factor alpha during the update (see above) 29 | _C.alphaJumpVals = [] 30 | 31 | # If _C.alphaJumpMode == "linear", then the following fields should be completed 32 | 33 | # Number of jumps per scale 34 | _C.alphaNJumps = [0, 0, 600, 600, 600, 600, 600, 600, 600] 35 | 36 | # Number of iterations between two jumps 37 | _C.alphaSizeJumps = [0, 0, 32, 32, 32, 32, 32, 32, 32, 32] 38 | 39 | ############################################################# 40 | 41 | # Depth of a convolutional layer for each scale 42 | _C.depthScales = [512, 512, 512, 512, 256, 128, 64, 32, 16] 43 | 44 | # Mini batch size 45 | _C.miniBatchSize = 16 46 | 47 | # Dimension of the latent vector 48 | _C.dimLatentVector = 512 49 | 50 | # We are doing an alternative training. Number of consecutive updates of G 51 | _C.kInnerG = 1 52 | 53 | # We are doing an alternative training. Number of consecutive updates of D 54 | _C.kInnerD = 1 55 | 56 | # Should bias be initialized to zero ? 57 | _C.initBiasToZero = True 58 | 59 | # Per channel normalization 60 | _C.perChannelNormalization = True 61 | 62 | # Loss mode 63 | _C.lossMode = 'Logistic' 64 | 65 | # Gradient penalty coefficient (WGANGP) 66 | _C.lambdaGP = 0. 67 | 68 | # Leakyness of the leakyRelU activation function 69 | _C.leakyness = 0.2 70 | 71 | # Weight penalty on |D(x)|^2 72 | _C.epsilonD = 0. 73 | 74 | # Mini batch regularization 75 | _C.miniBatchStdDev = True 76 | 77 | # Base learning rate 78 | _C.baseLearningRate = 0.001 79 | 80 | # RGB or grey level output ? 81 | _C.dimOutput = 3 82 | 83 | # In case of AC GAN, weight on the classification loss (per scale) 84 | _C.weightConditionG = 0.0 85 | _C.weightConditionD = 0.0 86 | 87 | # Equalized learning rate 88 | _C.equalizedlR = True 89 | 90 | _C.attribKeysOrder = None 91 | 92 | _C.equalizeLabels = False 93 | 94 | # Truncation trick 95 | _C.phiTruncation = 0.5 96 | 97 | _C.gamma_avg = 0.99 98 | 99 | #Activate GDPP loss ? 100 | _C.GDPP = False 101 | 102 | _C.logisticGradReal = 5. 
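# --- Illustrative usage sketch (not part of stylegan_config.py) ---------------
# The _C objects above are plain attribute holders. A minimal sketch of how a
# user-supplied override dictionary is merged with these defaults, using the
# getConfigFromDict helper defined in models/utils/config.py further below
# (override values here are hypothetical):
#
#     from models.utils.config import BaseConfig, getConfigFromDict
#     from models.trainer.standard_configurations.stylegan_config import _C as defaults
#
#     overrides = {"miniBatchSize": 8, "baseLearningRate": 0.002}
#     modelConfig = BaseConfig()
#     getConfigFromDict(modelConfig, overrides, defaults)
#
#     print(modelConfig.miniBatchSize)    # 8   -> taken from the override dictionary
#     print(modelConfig.dimLatentVector)  # 512 -> falls back to the default above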
103 | -------------------------------------------------------------------------------- /models/trainer/styleGAN_trainer.py: -------------------------------------------------------------------------------- 1 | from ..styleGAN import StyleGAN 2 | from .progressive_gan_trainer import ProgressiveGANTrainer 3 | 4 | from .standard_configurations.stylegan_config import _C 5 | 6 | class StyleGANTrainer(ProgressiveGANTrainer): 7 | 8 | _defaultConfig = _C 9 | 10 | def getDefaultConfig(self): 11 | return StyleGANTrainer._defaultConfig 12 | 13 | def __init__(self, pathdb, **kwargs): 14 | ProgressiveGANTrainer.__init__(self, pathdb, **kwargs) 15 | 16 | def initModel(self): 17 | config = {key: value for key, value in vars(self.modelConfig).items()} 18 | config["depthScale0"] = self.modelConfig.depthScales[0] 19 | self.model = StyleGAN(useGPU=self.useGPU, **config) 20 | if self.startScale ==0: 21 | self.startScale = 1 22 | self.model.addScale(self.modelConfig.depthScales[1]) 23 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /models/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | class BaseConfig(): 3 | r""" 4 | An empty class used for configuration members 5 | """ 6 | 7 | def __init__(self, orig=None): 8 | if orig is not None: 9 | print("cawet") 10 | 11 | 12 | def getConfigFromDict(obj, inputDict, defaultConfig): 13 | r""" 14 | Using a new configuration dictionary and a default configuration 15 | setup an object with the given configuration. 16 | 17 | for example, if you have 18 | inputDict = {"coin": 22} 19 | defaultConfig.coin = 23 20 | defaultConfig.pan = 12 21 | 22 | Then the given obj will get two new members 'coin' and 'pan' with 23 | obj.coin = 22 24 | obj.pan = 12 25 | 26 | Args: 27 | 28 | - obj (Object): the object to modify. 
29 | - inputDict (dictionary): new configuration 30 | - defaultConfig (Object): default configuration 31 | """ 32 | if not inputDict: 33 | for member, value in vars(defaultConfig).items(): 34 | setattr(obj, member, value) 35 | else: 36 | for member, value in vars(defaultConfig).items(): 37 | setattr(obj, member, inputDict.get(member, value)) 38 | 39 | 40 | def updateConfig(obj, ref): 41 | r""" 42 | Update a configuration with the fields of another given configuration 43 | """ 44 | 45 | if isinstance(ref, dict): 46 | for member, value in ref.items(): 47 | setattr(obj, member, value) 48 | 49 | else: 50 | 51 | for member, value in vars(ref).items(): 52 | setattr(obj, member, value) 53 | 54 | 55 | def str2bool(v): 56 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 57 | return True 58 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 59 | return False 60 | else: 61 | raise AttributeError('Boolean value expected.') 62 | 63 | 64 | def updateParserWithConfig(parser, defaultConfig): 65 | 66 | for name, key in vars(defaultConfig).items(): 67 | if key is None: 68 | continue 69 | 70 | if isinstance(key, bool): 71 | parser.add_argument('--' + name, type=str2bool, dest=name) 72 | else: 73 | parser.add_argument('--' + name, type=type(key), dest=name) 74 | 75 | parser.add_argument('--overrides', 76 | action='store_true', 77 | help= "For more information on attribute parameters, \ 78 | please have a look at \ 79 | models/trainer/standard_configurations") 80 | return parser 81 | 82 | 83 | def getConfigOverrideFromParser(parsedArgs, defaultConfig): 84 | 85 | output = {} 86 | for arg, value in parsedArgs.items(): 87 | if value is None: 88 | continue 89 | 90 | if arg in vars(defaultConfig): 91 | output[arg] = value 92 | 93 | return output 94 | 95 | 96 | def getDictFromConfig(obj, referenceConfig, printDefault=True): 97 | r""" 98 | Retrieve all the members of obj which are also members of referenceConfig 99 | and dump them into a dictionnary 100 | 101 | If printDefault is activated, members of referenceConfig which are not found 102 | in obj will also be dumped 103 | """ 104 | 105 | output = {} 106 | for member, value in vars(referenceConfig).items(): 107 | if hasattr(obj, member): 108 | output[member] = getattr(obj, member) 109 | elif printDefault: 110 | output[member] = value 111 | 112 | return output 113 | -------------------------------------------------------------------------------- /models/utils/image_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torchvision.transforms as Transforms 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | from PIL import Image 8 | 9 | # The equivalent of some torchvision.transforms operations but for numpy array 10 | # instead of PIL images 11 | 12 | 13 | class NumpyResize(object): 14 | 15 | def __init__(self, size): 16 | self.size = size 17 | 18 | def __call__(self, img): 19 | r""" 20 | Args: 21 | 22 | img (np array): image to be resized 23 | 24 | Returns: 25 | 26 | np array: resized image 27 | """ 28 | if not isinstance(img, Image.Image): 29 | img = Image.fromarray(img) 30 | return np.array(img.resize(self.size, resample=Image.BILINEAR)) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(p={})'.format(self.p) 34 | 35 | 36 | class NumpyFlip(object): 37 | 38 | def __init__(self, p=0.5): 39 | self.p = p 40 | random.seed(None) 41 | 42 | def __call__(self, img): 43 | """ 44 | Args: 45 | img (PIL Image): Image to be flipped. 46 | Returns: 47 | PIL Image: Randomly flipped image. 48 | """ 49 | if random.random() < self.p: 50 | return np.flip(img, 1).copy() 51 | return img 52 | 53 | def __repr__(self): 54 | return self.__class__.__name__ + '(p={})'.format(self.p) 55 | 56 | 57 | class NumpyToTensor(object): 58 | 59 | def __init__(self): 60 | return 61 | 62 | def __call__(self, img): 63 | r""" 64 | Turn a numpy objevt into a tensor. 65 | """ 66 | 67 | if len(img.shape) == 2: 68 | img = img.reshape(img.shape[0], img.shape[1], 1) 69 | 70 | return Transforms.functional.to_tensor(img) 71 | 72 | 73 | def pil_loader(path): 74 | imgExt = os.path.splitext(path)[1] 75 | if imgExt == ".npy": 76 | img = np.load(path)[0] 77 | return np.swapaxes(np.swapaxes(img, 0, 2), 0, 1) 78 | 79 | # open path as file to avoid ResourceWarning 80 | # (https://github.com/python-pillow/Pillow/issues/835) 81 | with open(path, 'rb') as f: 82 | img = Image.open(f) 83 | return img.convert('RGB') 84 | 85 | 86 | def standardTransform(size): 87 | return Transforms.Compose([NumpyResize(size), 88 | Transforms.ToTensor(), 89 | Transforms.Normalize((0.5, 0.5, 0.5), 90 | (0.5, 0.5, 0.5))]) 91 | -------------------------------------------------------------------------------- /models/utils/product_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | def buildMaskSplit(noiseGShape, 4 | noiseGTexture, 5 | categoryVectorDim, 6 | attribKeysOrder, 7 | attribShift, 8 | keySplits=None, 9 | mixedNoise=False): 10 | r""" 11 | Build a 8bits mask that split a full input latent vector into two 12 | intermediate latent vectors: one for the shape network and one for the 13 | texture network. 
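For instance (illustrative values): with noiseGShape=2, noiseGTexture=3, mixedNoise=False and no conditional part, the function returns maskShape = [1, 1, 0, 0, 0] and maskTexture = [0, 0, 1, 1, 1].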
14 | """ 15 | 16 | # latent vector split 17 | # Reminder, a latent vector is organized as follow 18 | # [z1,......, z_N, c_1, ....., c_C] 19 | # N : size of the noise part 20 | # C : size of the conditional part (ACGAN) 21 | 22 | # Here we will split the vector in 23 | # [y1, ..., y_N1, z1,......, z_N2, c_1, ....., c_C] 24 | 25 | N1 = noiseGShape 26 | N2 = noiseGTexture 27 | 28 | if not mixedNoise: 29 | maskShape = [1 for x in range(N1)] + [0 for x in range(N2)] 30 | maskTexture = [0 for x in range(N1)] + [1 for x in range(N2)] 31 | else: 32 | maskShape = [1 for x in range(N1 + N2)] 33 | maskTexture = [1 for x in range(N1 + N2)] 34 | 35 | # Now the conditional part 36 | # Some conditions apply to the shape, other to the texture, and sometimes 37 | # to both 38 | if attribKeysOrder is not None: 39 | 40 | C = categoryVectorDim 41 | 42 | if keySplits is not None: 43 | maskShape = maskShape + [0 for x in range(C)] 44 | maskTexture = maskTexture + [0 for x in range(C)] 45 | 46 | for key in keySplits["GShape"]: 47 | 48 | index = attribKeysOrder[key]["order"] 49 | shift = N1 + N2 + attribShift[index] 50 | 51 | for i in range(shift, shift + len(attribKeysOrder[key]["values"])): 52 | maskShape[i] = 1 53 | 54 | for key in keySplits["GTexture"]: 55 | 56 | index = attribKeysOrder[key]["order"] 57 | shift = N1 + N2 + attribShift[index] 58 | for i in range(shift, shift + len(attribKeysOrder[key]["values"])): 59 | maskTexture[i] = 1 60 | else: 61 | 62 | maskShape = maskShape + [1 for x in range(C)] 63 | maskTexture = maskTexture + [1 for x in range(C)] 64 | 65 | return maskShape, maskTexture 66 | -------------------------------------------------------------------------------- /models/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import time 4 | import json 5 | import math 6 | 7 | import torch 8 | 9 | 10 | def isinf(tensor): 11 | r"""Returns a new tensor with boolean elements representing if each element 12 | is `+/-INF` or not. 13 | 14 | Arguments: 15 | tensor (Tensor): A tensor to check 16 | 17 | Returns: 18 | Tensor: A ``torch.ByteTensor`` containing a 1 at each location of 19 | `+/-INF` elements and 0 otherwise 20 | 21 | Example:: 22 | 23 | >>> torch.isinf(torch.Tensor([1, float('inf'), 2, 24 | float('-inf'), float('nan')])) 25 | tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8) 26 | """ 27 | if not isinstance(tensor, torch.Tensor): 28 | raise ValueError("The argument is not a tensor", str(tensor)) 29 | return tensor.abs() == math.inf 30 | 31 | 32 | def isnan(tensor): 33 | r"""Returns a new tensor with boolean elements representing if each element 34 | is `NaN` or not. 35 | 36 | Arguments: 37 | tensor (Tensor): A tensor to check 38 | 39 | Returns: 40 | Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` 41 | elements. 
42 | 43 | Example:: 44 | 45 | >>> torch.isnan(torch.tensor([1, float('nan'), 2])) 46 | tensor([ 0, 1, 0], dtype=torch.uint8) 47 | """ 48 | if not isinstance(tensor, torch.Tensor): 49 | raise ValueError("The argument is not a tensor", str(tensor)) 50 | return tensor != tensor 51 | 52 | 53 | def finiteCheck(parameters): 54 | if isinstance(parameters, torch.Tensor): 55 | parameters = [parameters] 56 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 57 | 58 | for p in parameters: 59 | infGrads = isinf(p.grad.data) 60 | p.grad.data[infGrads] = 0 61 | 62 | nanGrads = isnan(p.grad.data) 63 | p.grad.data[nanGrads] = 0 64 | 65 | 66 | def prepareClassifier(module, outFeatures): 67 | 68 | model = module() 69 | inFeatures = model.fc.in_features 70 | model.fc = torch.nn.Linear(inFeatures, outFeatures) 71 | 72 | return model 73 | 74 | 75 | def getMinOccurence(inputDict, value, default): 76 | 77 | keys = list(inputDict.keys()) 78 | outKeys = [x for x in keys if x <= value] 79 | outKeys.sort() 80 | 81 | if len(outKeys) == 0: 82 | return default 83 | 84 | return inputDict[outKeys[-1]] 85 | 86 | 87 | def getNameAndPackage(strCode): 88 | 89 | if strCode == 'PGAN': 90 | return "progressive_gan", "ProgressiveGAN" 91 | 92 | if strCode == 'PPGAN': 93 | return "pp_gan", "PPGAN" 94 | 95 | if strCode == "DCGAN": 96 | return "DCGAN", "DCGAN" 97 | 98 | if strCode == "StyleGAN": 99 | return "styleGAN", "StyleGAN" 100 | 101 | raise ValueError("Unrecognized code " + strCode) 102 | 103 | 104 | def parse_state_name(path): 105 | r""" 106 | Parse a file name with the given pattern: 107 | pattern = ($model_name)_s($scale)_i($iteration).pt 108 | 109 | Returns: None if the path doesn't fulfill the pattern 110 | """ 111 | path = os.path.splitext(os.path.basename(path))[0] 112 | 113 | data = path.split('_') 114 | 115 | if len(data) < 3: 116 | return None 117 | 118 | # Iteration 119 | if data[-1][0] == "i" and data[-1][1:].isdigit(): 120 | iteration = int(data[-1][1:]) 121 | else: 122 | return None 123 | 124 | if data[-2][0] == "s" and data[-2][1:].isdigit(): 125 | scale = int(data[-2][1:]) 126 | else: 127 | return None 128 | 129 | name = "_".join(data[:-2]) 130 | 131 | return name, scale, iteration 132 | 133 | 134 | def parse_config_name(path): 135 | r""" 136 | Parse a file name with the given pattern: 137 | pattern = ($model_name)_train_config.json 138 | 139 | Raise an error if the pattern doesn't match 140 | """ 141 | 142 | path = os.path.basename(path) 143 | 144 | if len(path) < 18 or path[-18:] != "_train_config.json": 145 | raise ValueError("Invalid configuration path") 146 | 147 | return path[:-18] 148 | 149 | 150 | def getLastCheckPoint(dir, name, scale=None, iter=None): 151 | r""" 152 | Get the last checkpoint of the model with name @param name detected in the 153 | directory (@param dir) 154 | 155 | Returns: 156 | trainConfig, pathModel, pathTmpData 157 | 158 | trainConfig: path to the training configuration (.json) 159 | pathModel: path to the model's weight data (.pt) 160 | pathTmpData: path to the temporary configuration (.json) 161 | """ 162 | trainConfig = os.path.join(dir, name + "_train_config.json") 163 | 164 | if not os.path.isfile(trainConfig): 165 | return None 166 | 167 | listFiles = [f for f in os.listdir(dir) if ( 168 | os.path.splitext(f)[1] == ".pt" and 169 | parse_state_name(f) is not None and 170 | parse_state_name(f)[0] == name)] 171 | 172 | if scale is not None: 173 | listFiles = [f for f in listFiles if parse_state_name(f)[1] == scale] 174 | 175 | if iter is not None: 176 | 
listFiles = [f for f in listFiles if parse_state_name(f)[2] == iter] 177 | 178 | listFiles.sort(reverse=True, key=lambda x: ( 179 | parse_state_name(x)[1], parse_state_name(x)[2])) 180 | 181 | if len(listFiles) == 0: 182 | return None 183 | 184 | pathModel = os.path.join(dir, listFiles[0]) 185 | pathTmpData = os.path.splitext(pathModel)[0] + "_tmp_config.json" 186 | 187 | if not os.path.isfile(pathTmpData): 188 | return None 189 | 190 | return trainConfig, pathModel, pathTmpData 191 | 192 | 193 | def getVal(kwargs, key, default): 194 | 195 | out = kwargs.get(key, default) 196 | if out is None: 197 | return default 198 | 199 | return out 200 | 201 | 202 | def toStrKey(item): 203 | 204 | if item is None: 205 | return "" 206 | 207 | out = "_" + str(item) 208 | out = out.replace("'", "") 209 | return out 210 | 211 | 212 | def num_flat_features(x): 213 | size = x.size()[1:] # all dimensions except the batch dimension 214 | num_features = 1 215 | for s in size: 216 | num_features *= s 217 | return num_features 218 | 219 | 220 | def printProgressBar(iteration, 221 | total, 222 | prefix='', 223 | suffix='', 224 | decimals=1, 225 | length=100, 226 | fill='#'): 227 | """ 228 | Call in a loop to create terminal progress bar 229 | @params: 230 | iteration - Required : current iteration (Int) 231 | total - Required : total iterations (Int) 232 | prefix - Optional : prefix string (Str) 233 | suffix - Optional : suffix string (Str) 234 | decimals - Optional : positive number of decimals in percent 235 | complete (Int) 236 | length - Optional : character length of bar (Int) 237 | fill - Optional : bar fill character (Str) 238 | """ 239 | percent = ("{0:." + str(decimals) + "f}").format(100 * 240 | (iteration / float(total))) 241 | filledLength = int(length * iteration // total) 242 | bar = fill * filledLength + '-' * (length - filledLength) 243 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='\r') 244 | # Print New Line on Complete 245 | if iteration == total: 246 | print() 247 | 248 | 249 | def loadPartOfStateDict(module, state_dict, forbiddenLayers=None): 250 | r""" 251 | Load the input state dict to the module except for the weights corresponding 252 | to one of the forbidden layers 253 | """ 254 | own_state = module.state_dict() 255 | if forbiddenLayers is None: 256 | forbiddenLayers = [] 257 | for name, param in state_dict.items(): 258 | if name.split(".")[0] in forbiddenLayers: 259 | continue 260 | if isinstance(param, torch.nn.Parameter): 261 | # backwards compatibility for serialized parameters 262 | param = param.data 263 | 264 | own_state[name].copy_(param) 265 | 266 | 267 | def loadStateDictCompatible(module, state_dict): 268 | r""" 269 | Load the input state dict to the module except for the weights corresponding 270 | to one of the forbidden layers 271 | """ 272 | own_state = module.state_dict() 273 | for name, param in state_dict.items(): 274 | if isinstance(param, torch.nn.Parameter): 275 | # backwards compatibility for serialized parameters 276 | param = param.data 277 | 278 | if name in own_state: 279 | own_state[name].copy_(param) 280 | continue 281 | 282 | # Else see if the input name is a prefix 283 | suffixes = ["bias", "weight"] 284 | found = False 285 | for suffix in suffixes: 286 | indexEnd = name.find(suffix) 287 | if indexEnd > 0: 288 | newKey = name[:indexEnd] + "module." 
+ suffix
289 |                     if newKey in own_state:
290 |                         own_state[newKey].copy_(param)
291 |                         found = True
292 |                         break
293 | 
294 |         if not found:
295 |             raise AttributeError("Unknown key " + name)
296 | 
297 | 
298 | def loadmodule(package, name, prefix='..'):
299 |     r"""
300 |     A dirty hack to load a module from a string input
301 | 
302 |     Args:
303 |         package (string): package name
304 |         name (string): module name
305 | 
306 |     Returns:
307 |         A pointer to the loaded module
308 |     """
309 |     strCmd = "from " + prefix + package + " import " + name + " as module"
310 |     exec(strCmd)
311 |     return eval('module')
312 | 
313 | 
314 | def saveScore(outPath, outValue, *args):
315 | 
316 |     flagPath = outPath + ".flag"
317 | 
318 |     while os.path.isfile(flagPath):
319 |         time.sleep(1)
320 | 
321 |     open(flagPath, 'a').close()
322 | 
323 |     if os.path.isfile(outPath):
324 |         with open(outPath, 'rb') as file:
325 |             outDict = json.load(file)
326 |         if not isinstance(outDict, dict):
327 |             outDict = {}
328 |     else:
329 |         outDict = {}
330 | 
331 |     fullDict = outDict
332 | 
333 |     for item in args[:-1]:
334 |         if str(item) not in outDict:
335 |             outDict[str(item)] = {}
336 |         outDict = outDict[str(item)]
337 | 
338 |     outDict[args[-1]] = outValue
339 | 
340 |     with open(outPath, 'w') as file:
341 |         json.dump(fullDict, file, indent=2)
342 | 
343 |     os.remove(flagPath)
344 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | visdom
4 | h5py
5 | nevergrad
--------------------------------------------------------------------------------
/save_feature_extractor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch
3 | import argparse
4 | from models.loss_criterions.loss_texture import LossTexture
5 | 
6 | if __name__ == "__main__":
7 | 
8 |     parser = argparse.ArgumentParser(description='Testing script')
9 |     parser.add_argument('model_name', type=str,
10 |                         choices=["vgg19", "vgg16"],
11 |                         help="""Name of the desired feature extractor:
12 |                         - vgg19, vgg16 : a variation of the style transfer \
13 |                         feature developed in \
14 |                         http://arxiv.org/abs/1703.06868""")
15 |     parser.add_argument('--layers', type=int, nargs='*',
16 |                         help="For vgg models only. Layers to select. \
17 |                         Default ones are 3, 4, 5.", default=None)
18 |     parser.add_argument('output_path', type=str,
19 |                         help="""Path of the output feature extractor""")
20 | 
21 |     args = parser.parse_args()
22 | 
23 |     if args.model_name in ["vgg19", "vgg16"]:
24 |         if args.layers is None:
25 |             args.layers = [3, 4, 5]
26 |         featureExtractor = LossTexture(torch.device("cpu"),
27 |                                        args.model_name,
28 |                                        args.layers)
29 |         featureExtractor.saveModel(args.output_path)
30 |     else:
31 |         raise AttributeError(args.model_name + " not implemented yet")
32 | 
--------------------------------------------------------------------------------
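Editor's note (not part of the repository): a typical invocation of the export script above would look like the following; the output path is illustrative.

    python save_feature_extractor.py vgg19 output_networks/vgg19_features.pt --layers 3 4 5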
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import os
3 | import sys
4 | import importlib
5 | import argparse
6 | 
7 | from models.utils.utils import getVal, getLastCheckPoint, loadmodule
8 | from models.utils.config import getConfigOverrideFromParser, \
9 |     updateParserWithConfig
10 | 
11 | import json
12 | 
13 | 
14 | def getTrainer(name):
15 | 
16 |     match = {"PGAN": ("progressive_gan_trainer", "ProgressiveGANTrainer"),
17 |              "StyleGAN": ("styleGAN_trainer", "StyleGANTrainer"),
18 |              "DCGAN": ("DCGAN_trainer", "DCGANTrainer")}
19 | 
20 |     if name not in match:
21 |         raise AttributeError("Invalid module name")
22 | 
23 |     return loadmodule("models.trainer." + match[name][0],
24 |                       match[name][1],
25 |                       prefix='')
26 | 
27 | 
28 | if __name__ == "__main__":
29 | 
30 |     parser = argparse.ArgumentParser(description='Training script')
31 |     parser.add_argument('model_name', type=str,
32 |                         help='Name of the model to launch; available models are\
33 |                         PGAN, DCGAN and StyleGAN. To get all possible options for a model,\
34 |                         please run train.py $MODEL_NAME --overrides')
35 |     parser.add_argument('--no_vis', help='Disable all visualizations',
36 |                         action='store_true')
37 |     parser.add_argument('--np_vis', help='Replace visdom with a numpy based \
38 |                         visualizer (SLURM)',
39 |                         action='store_true')
40 |     parser.add_argument('--restart', help='If a checkpoint is detected, do \
41 |                         not try to load it',
42 |                         action='store_true')
43 |     parser.add_argument('-n', '--name', help="Model's name",
44 |                         type=str, dest="name", default="default")
45 |     parser.add_argument('-d', '--dir', help='Output directory',
46 |                         type=str, dest="dir", default='output_networks')
47 |     parser.add_argument('-c', '--config', help="Path to the training configuration file",
48 |                         type=str, dest="configPath")
49 |     parser.add_argument('-s', '--save_iter', help="If it applies, frequency at\
50 |                         which a checkpoint should be saved. In the case of an\
51 |                         evaluation test, iteration to work on.",
52 |                         type=int, dest="saveIter", default=16000)
53 |     parser.add_argument('-e', '--eval_iter', help="If it applies, frequency at\
54 |                         which the losses should be evaluated and plotted",
55 |                         type=int, dest="evalIter", default=100)
56 |     parser.add_argument('-S', '--Scale_iter', help="If it applies, scale to work\
57 |                         on")
58 |     parser.add_argument('-v', '--partitionValue', help="Partition's value",
59 |                         type=str, dest="partition_value")
60 | 
61 |     # Retrieve the model we want to launch
62 |     baseArgs, unknown = parser.parse_known_args()
63 |     trainerModule = getTrainer(baseArgs.model_name)
64 | 
65 |     # Build the output directory if necessary
66 |     if not os.path.isdir(baseArgs.dir):
67 |         os.mkdir(baseArgs.dir)
68 | 
69 |     # Add overrides to the parser: changes to the model configuration can be
70 |     # done via the command line
71 |     parser = updateParserWithConfig(parser, trainerModule._defaultConfig)
72 |     kwargs = vars(parser.parse_args())
73 |     configOverride = getConfigOverrideFromParser(
74 |         kwargs, trainerModule._defaultConfig)
75 | 
76 |     if kwargs['overrides']:
77 |         parser.print_help()
78 |         sys.exit()
79 | 
80 |     # Checkpoint data
81 |     modelLabel = kwargs["name"]
82 |     restart = kwargs["restart"]
83 |     checkPointDir = os.path.join(kwargs["dir"], modelLabel)
84 |     checkPointData = getLastCheckPoint(checkPointDir, modelLabel)
85 | 
86 |     if not os.path.isdir(checkPointDir):
87 |         os.mkdir(checkPointDir)
88 | 
89 |     # Training configuration
90 |     configPath = kwargs.get("configPath", None)
91 |     if configPath is None:
92 |         raise ValueError("You need to input a configuration file")
93 | 
94 |     with open(kwargs["configPath"], 'rb') as file:
95 |         trainingConfig = json.load(file)
96 | 
97 |     # Model configuration
98 |     modelConfig = trainingConfig.get("config", {})
99 |     for item, val in configOverride.items():
100 |         modelConfig[item] = val
101 |     trainingConfig["config"] = modelConfig
102 | 
103 |     # Visualization module
104 |     vis_module = None
105 |     if baseArgs.np_vis:
106 |         vis_module = importlib.import_module("visualization.np_visualizer")
107 |     elif baseArgs.no_vis:
108 |         print("Visualization disabled")
109 |     else:
110 |         vis_module = importlib.import_module("visualization.visualizer")
111 | 
112 |     print("Running " + baseArgs.model_name)
113 | 
114 |     # Path to the image dataset
115 |     pathDB = trainingConfig["pathDB"]
116 |     trainingConfig.pop("pathDB", None)
117 | 
118 |     partitionValue = getVal(kwargs, "partition_value",
119 |                             trainingConfig.get("partitionValue", None))
120 | 
121 |     GANTrainer = trainerModule(pathDB,
122 |                                useGPU=True,
123 |                                visualisation=vis_module,
124 |                                lossIterEvaluation=kwargs["evalIter"],
125 |                                checkPointDir=checkPointDir,
126 |                                saveIter=kwargs["saveIter"],
127 |                                modelLabel=modelLabel,
128 |                                partitionValue=partitionValue,
129 |                                **trainingConfig)
130 | 
131 |     # If a checkpoint is found, load it
132 |     if not restart and checkPointData is not None:
133 |         trainConfig, pathModel, pathTmpData = checkPointData
134 |         print(f"Model found at path {pathModel}, pursuing the training")
135 |         GANTrainer.loadSavedTraining(pathModel, trainConfig, pathTmpData)
136 | 
137 |     GANTrainer.train()
138 | 
--------------------------------------------------------------------------------
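Editor's note (not part of the repository): a hypothetical launch of the training script above; the configuration file name, run name and dataset path are illustrative, and the JSON shows only the one key the script reads directly (pathDB) plus the optional "config" block that the overrides are merged into.

    python train.py PGAN -c config_celebaHQ.json -n my_pgan_run -d output_networks

    # config_celebaHQ.json (minimal, illustrative content):
    {
      "pathDB": "/path/to/dataset",
      "config": {}
    }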
/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
--------------------------------------------------------------------------------
/visualization/np_visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import numpy as np
3 | import scipy
4 | import scipy.misc
5 | import imageio
6 | import torch
7 | from PIL import Image
8 | 
9 | 
10 | def make_numpy_grid(arrays_list, gridMaxWidth=2048,
11 |                     imgMinSize=128,
12 |                     interpolation='nearest'):
13 | 
14 |     # NCWH format
15 |     N, C, W, H = arrays_list.shape
16 | 
17 |     arrays_list = ((arrays_list + 1.0) * 255.0 / 2.0).astype(np.uint8)
18 | 
19 |     if C == 1:
20 |         arrays_list = np.reshape(arrays_list, (N, W, H))
21 | 
22 |     gridMaxWidth = max(gridMaxWidth, W)
23 | 
24 |     imgSize = max(W, imgMinSize)
25 |     imgHeight = int((float(imgSize) / W) * H)
26 |     nImgsPerRows = min(N, int(gridMaxWidth // imgSize))
27 | 
28 |     gridWidth = nImgsPerRows * imgSize
29 | 
30 |     nRows = N // nImgsPerRows
31 |     if N % nImgsPerRows > 0:
32 |         nRows += 1
33 | 
34 |     gridHeight = nRows * imgHeight
35 |     if C == 1:
36 |         outGrid = np.zeros((gridHeight, gridWidth), dtype='uint8')
37 |     else:
38 |         outGrid = np.zeros((gridHeight, gridWidth, C), dtype='uint8')
39 |     outGrid += 255
40 | 
41 |     interp = {
42 |         'nearest': Image.NEAREST,
43 |         'lanczos': Image.LANCZOS,
44 |         'bilinear': Image.BILINEAR,
45 |         'bicubic': Image.BICUBIC
46 |     }
47 | 
48 |     indexImage = 0
49 |     for r in range(nRows):
50 |         for c in range(nImgsPerRows):
51 | 
52 |             if indexImage == N:
53 |                 break
54 | 
55 |             xStart = c * imgSize
56 |             yStart = r * imgHeight
57 | 
58 |             img = np.array(arrays_list[indexImage])  # CHW (color) or HW (grayscale)
59 |             img = Image.fromarray(np.transpose(img, (1, 2, 0)) if img.ndim == 3 else img)
60 | 
61 |             tmpImage = np.array(img.resize((imgSize, imgHeight), resample=interp[interpolation]))
62 | 
63 |             if C == 1:
64 |                 outGrid[yStart:(yStart + imgHeight),
65 |                         xStart:(xStart + imgSize)] = tmpImage
66 |             else:
67 |                 outGrid[yStart:(yStart + imgHeight),
68 |                         xStart:(xStart + imgSize), :] = tmpImage
69 | 
70 |             indexImage += 1
71 | 
72 |     return outGrid
73 | 
74 | 
75 | def publishTensors(data, out_size_image, caption="", window_token=None, env="main"):
76 |     return None
77 | 
78 | 
79 | def publishLoss(*args, **kwargs):
80 |     return None
81 | 
82 | 
83 | def publishLinePlot(data, xData, name="", window_token=None, env="main"):
84 |     return None
85 | 
86 | 
87 | def publishScatterPlot(data, name="", window_token=None):
88 |     return None
89 | 
90 | 
91 | def saveTensor(data, out_size_image, path):
92 | 
93 |     interpolation = 'nearest'
94 |     if isinstance(out_size_image, tuple):
95 |         out_size_image = out_size_image[0]
96 |     data = torch.clamp(data, min=-1, max=1)
97 |     outdata = make_numpy_grid(
98 |         data.numpy(), imgMinSize=out_size_image, interpolation=interpolation)
99 |     imageio.imwrite(path, outdata)
100 | 
101 | 
102 | def delete_env(env_name):
103 |     return None
104 | 
--------------------------------------------------------------------------------
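Editor's sketch (not a file in the repository): dumping a grid of samples to disk with the numpy visualizer above. The batch here is random noise standing in for generator output in [-1, 1]; the output file name is illustrative.

    import torch
    from visualization import np_visualizer

    # hypothetical batch of 16 RGB samples, N x C x H x W, values in [-1, 1]
    fake_batch = torch.tanh(torch.randn(16, 3, 128, 128))
    np_visualizer.saveTensor(fake_batch, (128, 128), "samples.png")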
/visualization/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import visdom
3 | import torch
4 | import torchvision.transforms as Transforms
5 | import torchvision.utils as vutils
6 | import numpy as np
7 | import random
8 | 
9 | vis = visdom.Visdom()
10 | 
11 | 
12 | def resizeTensor(data, out_size_image):
13 | 
14 |     out_data_size = (data.size()[0], data.size()[1],
15 |                      out_size_image[0], out_size_image[1])
16 | 
17 |     outdata = torch.empty(out_data_size)
18 |     data = torch.clamp(data, min=-1, max=1)
19 | 
20 |     interpolationMode = 0  # NEAREST
21 |     if out_size_image[0] < data.size()[2] and out_size_image[1] < data.size()[3]:
22 |         interpolationMode = 2  # BILINEAR when downscaling the spatial dimensions
23 | 
24 |     transform = Transforms.Compose([Transforms.Normalize((-1., -1., -1.), (2, 2, 2)),
25 |                                     Transforms.ToPILImage(),
26 |                                     Transforms.Resize(
27 |                                         out_size_image, interpolation=interpolationMode),
28 |                                     Transforms.ToTensor()])
29 | 
30 |     for img in range(out_data_size[0]):
31 |         outdata[img] = transform(data[img])
32 | 
33 |     return outdata
34 | 
35 | 
36 | def publishTensors(data, out_size_image, caption="", window_token=None, env="main", nrow=16):
37 |     global vis
38 |     outdata = resizeTensor(data, out_size_image)
39 |     return vis.images(outdata, opts=dict(caption=caption), win=window_token, env=env, nrow=nrow)
40 | 
41 | 
42 | def saveTensor(data, out_size_image, path):
43 |     outdata = resizeTensor(data, out_size_image)
44 |     vutils.save_image(outdata, path)
45 | 
46 | 
47 | def publishLoss(data, name="", window_tokens=None, env="main"):
48 | 
49 |     if window_tokens is None:
50 |         window_tokens = {key: None for key in data}
51 | 
52 |     for key, plot in data.items():
53 | 
54 |         if key in ("scale", "iter"):
55 |             continue
56 | 
57 |         nItems = len(plot)
58 |         inputY = np.array([plot[x] for x in range(nItems) if plot[x] is not None])
59 |         inputX = np.array([data["iter"][x] for x in range(nItems) if plot[x] is not None])
60 | 
61 |         opts = {'title': key + (' scale %d loss over time' % data["scale"]),
62 |                 'legend': [key], 'xlabel': 'iteration', 'ylabel': 'loss'}
63 | 
64 |         window_tokens[key] = vis.line(X=inputX, Y=inputY, opts=opts,
65 |                                       win=window_tokens[key], env=env)
66 | 
67 |     return window_tokens
68 | 
69 | 
70 | def delete_env(name):
71 | 
72 |     vis.delete_env(name)
73 | 
74 | 
75 | def publishScatterPlot(data, name="", window_token=None):
76 |     r"""
77 |     Draws 2D or 3D scatter plots
78 | 
79 |     Args:
80 | 
81 |         data (list of tensors): list of Ni x 2 or Ni x 3 tensors. Each tensor
82 |                                 representing a cloud of data
83 |         name (string): plot name
84 |         window_token (token): ID of the window where the plot should be drawn
85 | 
86 |     Returns:
87 | 
88 |         ID of the window where the data are plotted
89 |     """
90 | 
91 |     if not isinstance(data, list):
92 |         raise ValueError("Input data should be a list of tensors")
93 | 
94 |     nPlots = len(data)
95 |     colors = []
96 | 
97 |     random.seed(None)
98 | 
99 |     for item in range(nPlots):
100 |         N = data[item].size()[0]
101 |         colors.append(torch.randint(0, 256, (1, 3)).expand(N, 3))
102 | 
103 |     colors = torch.cat(colors, dim=0).numpy()
104 |     opts = {'markercolor': colors,
105 |             'caption': name}
106 |     activeData = torch.cat(data, dim=0)
107 | 
108 |     return vis.scatter(activeData, opts=opts, win=window_token, name=name)
109 | 
--------------------------------------------------------------------------------
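Editor's sketch (not a file in the repository): publishing a batch of images with the visdom visualizer above. It assumes a visdom server is running locally (python -m visdom.server); the batch is random noise standing in for generator output.

    import torch
    from visualization import visualizer

    # hypothetical batch of 8 RGB samples in [-1, 1]
    samples = torch.tanh(torch.randn(8, 3, 64, 64))
    win = visualizer.publishTensors(samples, (64, 64), caption="demo", env="main")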