├── LICENSE ├── README.md ├── animations └── carla_256.gif ├── configs ├── carla.yaml ├── cats.yaml ├── celebA.yaml ├── celebAHQ.yaml ├── cub.yaml ├── debug.yaml ├── default.yaml └── pretrained_models.yaml ├── data ├── cub │ └── filtered_files.txt ├── download_carla.sh ├── download_carla_poses.sh ├── preprocess_cats.py └── preprocess_cub.py ├── environment.yml ├── eval.py ├── external └── colmap │ ├── __init__.py │ ├── filter_points.py │ └── run_colmap_automatic.sh ├── graf ├── config.py ├── datasets.py ├── gan_training.py ├── models │ ├── discriminator.py │ └── generator.py ├── transforms.py └── utils.py ├── submodules ├── GAN_stability │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── configs │ │ ├── celebAHQ.yaml │ │ ├── default.yaml │ │ ├── imagenet.yaml │ │ ├── lsun_bedroom.yaml │ │ ├── lsun_bridge.yaml │ │ ├── lsun_church.yaml │ │ ├── lsun_tower.yaml │ │ └── pretrained │ │ │ ├── celebAHQ_pretrained.yaml │ │ │ ├── celebA_pretrained.yaml │ │ │ ├── imagenet_pretrained.yaml │ │ │ ├── lsun_bedroom_pretrained.yaml │ │ │ ├── lsun_bridge_pretrained.yaml │ │ │ ├── lsun_church_pretrained.yaml │ │ │ └── lsun_tower_pretrained.yaml │ ├── gan_training │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── config.py │ │ ├── distributions.py │ │ ├── eval.py │ │ ├── inputs.py │ │ ├── logger.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── fid_score.py │ │ │ ├── inception.py │ │ │ ├── inception_score.py │ │ │ └── kid_score.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── resnet.py │ │ │ ├── resnet2.py │ │ │ ├── resnet3.py │ │ │ └── resnet4.py │ │ ├── ops.py │ │ ├── train.py │ │ └── utils.py │ ├── interpolate.py │ ├── interpolate_class.py │ ├── notebooks │ │ ├── DiracGAN.ipynb │ │ ├── create_video.sh │ │ └── diracgan │ │ │ ├── __init__.py │ │ │ ├── gans.py │ │ │ ├── plotting.py │ │ │ ├── simulate.py │ │ │ ├── subplots.py │ │ │ └── util.py │ ├── results │ │ ├── celebA-HQ.jpg │ │ ├── imagenet_00.jpg │ │ ├── imagenet_01.jpg │ │ ├── imagenet_02.jpg │ │ ├── imagenet_03.jpg │ │ └── imagenet_04.jpg │ ├── test.py │ └── train.py └── nerf_pytorch │ ├── .gitignore │ ├── .gitmodules │ ├── LICENSE │ ├── README.md │ ├── configs │ ├── config_fern.txt │ ├── config_flower.txt │ ├── config_fortress.txt │ ├── config_horns.txt │ ├── config_lego.txt │ └── config_trex.txt │ ├── download_example_data.sh │ ├── imgs │ └── pipeline.jpg │ ├── load_blender.py │ ├── load_deepvoxels.py │ ├── load_llff.py │ ├── requirements.txt │ ├── run_nerf.py │ ├── run_nerf_helpers.py │ ├── run_nerf_helpers_mod.py │ ├── run_nerf_mod.py │ └── torchsearchsorted │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── examples │ ├── benchmark.py │ └── test.py │ ├── setup.py │ ├── src │ ├── cpu │ │ ├── searchsorted_cpu_wrapper.cpp │ │ └── searchsorted_cpu_wrapper.h │ ├── cuda │ │ ├── searchsorted_cuda_kernel.cu │ │ ├── searchsorted_cuda_kernel.h │ │ ├── searchsorted_cuda_wrapper.cpp │ │ └── searchsorted_cuda_wrapper.h │ └── torchsearchsorted │ │ ├── __init__.py │ │ ├── searchsorted.py │ │ └── utils.py │ └── test │ ├── conftest.py │ └── test_searchsorted.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of 
the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRAF 2 | 3 |
4 | ![GRAF teaser animation](animations/carla_256.gif) 5 |
6 | 7 | This repository contains the official code for the paper 8 | [GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis](https://avg.is.tuebingen.mpg.de/publications/schwarz2020neurips). 9 | 10 | You can find detailed usage instructions for training your own models and using pre-trained models below. 11 | 12 | 13 | If you find our code or paper useful, please consider citing 14 | 15 | @inproceedings{Schwarz2020NEURIPS, 16 | title = {GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis}, 17 | author = {Schwarz, Katja and Liao, Yiyi and Niemeyer, Michael and Geiger, Andreas}, 18 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 19 | year = {2020} 20 | } 21 | 22 | ## Installation 23 | First you have to make sure that you have all dependencies in place. 24 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/). 25 | 26 | You can create an anaconda environment called `graf` using 27 | ``` 28 | conda env create -f environment.yml 29 | conda activate graf 30 | ``` 31 | 32 | Next, install torchsearchsorted for nerf-pytorch. Note that this requires `torch>=1.4.0` and `CUDA >= v10.1`. 33 | You can install torchsearchsorted via 34 | ``` 35 | cd submodules/nerf_pytorch 36 | pip install -r requirements.txt 37 | cd torchsearchsorted 38 | pip install . 39 | cd ../../../ 40 | ``` 41 | 42 | ## Demo 43 | 44 | You can now test our code via: 45 | ``` 46 | python eval.py configs/carla.yaml --pretrained --rotation_elevation 47 | ``` 48 | This script should create a folder `results/carla_128_from_pretrained/eval/` where you can find generated videos with varying camera pose for the Cars dataset. 49 | 50 | ## Datasets 51 | 52 | If you only want to generate images using our pretrained models, you do not need to download the datasets. 53 | The datasets are only needed if you want to train a model from scratch. 54 | 55 | ### Cars 56 | 57 | To download the Cars dataset from the paper, simply run 58 | ``` 59 | cd data 60 | ./download_carla.sh 61 | cd .. 62 | ``` 63 | This creates a folder `data/carla/`, downloads the images as a zip file and extracts them to `data/carla/`. 64 | While we do not use camera poses in this project, we provide them for completeness. You can download them by running 65 | ``` 66 | cd data 67 | ./download_carla_poses.sh 68 | cd .. 69 | ``` 70 | This downloads the camera intrinsics (a single file, shared across all images) and the extrinsics corresponding to each image. 71 | 72 | ### Faces 73 | 74 | Download [celebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). 75 | Then replace `data/celebA` in `configs/celebA.yaml` with `*PATH/TO/CELEBA*/Img/img_align_celebA`. 76 | 77 | Download [celebA_hq](https://github.com/tkarras/progressive_growing_of_gans). 78 | Then replace `data/celebA_hq` in `configs/celebAHQ.yaml` with `*PATH/TO/CELEBA_HQ*`. 79 | 80 | ### Cats 81 | Download the [CatDataset](https://www.kaggle.com/crawford/cat-dataset). 82 | Run 83 | ``` 84 | cd data 85 | python preprocess_cats.py PATH/TO/CATS/DATASET 86 | cd .. 87 | ``` 88 | to preprocess the data and save it to `data/cats`. 89 | If successful, this script should print: `Preprocessed 9407 images.` 90 | 91 | ### Birds 92 | Download [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and the corresponding [Segmentation Masks](https://drive.google.com/file/d/1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP/view). 93 | Run 94 | ``` 95 | cd data 96 | python preprocess_cub.py PATH/TO/CUB-200-2011 PATH/TO/SEGMENTATION/MASKS 97 | cd ..
98 | ``` 99 | to preprocess the data and save it to `data/cub`. 100 | If successful, this script should print: `Preprocessed 8444 images.` 101 | 102 | ## Usage 103 | 104 | When you have installed all dependencies, you are ready to run our pre-trained models for 3D-aware image synthesis. 105 | 106 | ### Generate images using a pretrained model 107 | 108 | To evaluate a pretrained model, run 109 | ``` 110 | python eval.py CONFIG.yaml --pretrained --fid_kid --rotation_elevation --shape_appearance 111 | ``` 112 | where you replace `CONFIG.yaml` with one of the config files in `./configs`. 113 | 114 | This script should create a folder `results/EXPNAME/eval` with FID and KID scores in `fid_kid.csv`, videos for rotation and elevation in the respective folders, and an interpolation of shape and appearance in `shape_appearance.png`. 115 | 116 | Note that some pretrained models are available for different image sizes, which you can choose by setting `data:imsize` in the config file to one of the following values: 117 | ``` 118 | configs/carla.yaml: 119 | data:imsize 64 or 128 or 256 or 512 120 | configs/celebA.yaml: 121 | data:imsize 64 or 128 122 | configs/celebAHQ.yaml: 123 | data:imsize 256 or 512 124 | ``` 125 | 126 | ### Train a model from scratch 127 | 128 | To train a 3D-aware generative model from scratch, run 129 | ``` 130 | python train.py CONFIG.yaml 131 | ``` 132 | where you replace `CONFIG.yaml` with your config file. 133 | The easiest way is to use one of the existing config files in the `./configs` directory, 134 | which correspond to the experiments presented in the paper. 135 | Note that this will train the model from scratch and will not resume training from a pretrained model. 136 | 137 | You can monitor the training process using [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard): 138 | ``` 139 | cd OUTPUT_DIR 140 | tensorboard --logdir ./monitoring --port 6006 141 | ``` 142 | where you replace `OUTPUT_DIR` with the respective output directory. 143 | 144 | For available training options, please take a look at `configs/default.yaml`. 145 | 146 | ### Evaluation of a new model 147 | 148 | To evaluate a trained model, run 149 | ``` 150 | python eval.py CONFIG.yaml --fid_kid --rotation_elevation --shape_appearance 151 | ``` 152 | where you replace `CONFIG.yaml` with your config file. 153 | 154 | ## Multi-View Consistency Check 155 | 156 | You can evaluate the multi-view consistency of the generated images by running a Multi-View Stereo (MVS) algorithm on them. This evaluation uses [COLMAP](https://colmap.github.io/), so make sure that you have COLMAP installed before running 157 | ``` 158 | python eval.py CONFIG.yaml --reconstruction 159 | ``` 160 | where you replace `CONFIG.yaml` with your config file. You can also evaluate our pretrained models via: 161 | ``` 162 | python eval.py configs/carla.yaml --pretrained --reconstruction 163 | ``` 164 | This script should create a folder `results/EXPNAME/eval/reconstruction/` where you can find generated multi-view images in `images/` and the corresponding 3D reconstructions in `models/`. 165 | 166 | ## Further Information 167 | 168 | ### GAN training 169 | 170 | This repository uses Lars Mescheder's awesome framework for [GAN training](https://github.com/LMescheder/GAN_stability). 171 | 172 | ### NeRF 173 | 174 | We base our code for the Generator on this great [Pytorch reimplementation](https://github.com/yenchenlin/nerf-pytorch) of Neural Radiance Fields.
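### Overriding config options

The helper `update_config` in `graf/config.py` parses command line flags of the form `--GROUP:KEY VALUE` and writes them into the nested config, casting the value to the type of the existing entry. Assuming the entry scripts (`train.py`, `eval.py`) forward their unrecognized arguments to this helper, individual settings can be tweaked without editing a config file, for example:
```
python train.py configs/carla.yaml --data:imsize 64 --training:batch_size 4
```
Here `--data:imsize 64` overrides `data: imsize` and `--training:batch_size 4` overrides `training: batch_size`; see `configs/default.yaml` for the available options.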
175 | -------------------------------------------------------------------------------- /animations/carla_256.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/animations/carla_256.gif -------------------------------------------------------------------------------- /configs/carla.yaml: -------------------------------------------------------------------------------- 1 | expname: carla_128 2 | data: 3 | imsize: 128 4 | datadir: data/carla 5 | type: carla 6 | radius: 10. 7 | near: 7.5 8 | far: 12.5 9 | fov: 30.0 10 | -------------------------------------------------------------------------------- /configs/cats.yaml: -------------------------------------------------------------------------------- 1 | expname: cats_64 2 | data: 3 | datadir: data/cats 4 | type: cats 5 | imsize: 64 6 | white_bkgd: False 7 | radius: 10 8 | near: 7.5 9 | far: 12.5 10 | fov: 10 11 | umin: 0 12 | umax: 0.19444444444444445 #70 deg 13 | vmin: 0.32898992833716556 # 70 deg 14 | vmax: 0.45642212862617093 # 85 deg 15 | 16 | -------------------------------------------------------------------------------- /configs/celebA.yaml: -------------------------------------------------------------------------------- 1 | expname: celebA_64 2 | data: 3 | datadir: /PATH/TO/CELEBA/Img/img_align_celebA 4 | type: celebA 5 | imsize: 64 6 | white_bkgd: False 7 | radius: 9.5,10.5 8 | near: 7.5 9 | far: 12.5 10 | fov: 10. 11 | umin: 0 12 | umax: 0.25 13 | vmin: 0.32898992833716556 # 70 deg 14 | vmax: 0.45642212862617093 # 85 deg 15 | 16 | -------------------------------------------------------------------------------- /configs/celebAHQ.yaml: -------------------------------------------------------------------------------- 1 | expname: celebAHQ_256 2 | data: 3 | datadir: /PATH/TO/CELEBA_HQ 4 | type: celebA_hq 5 | imsize: 256 6 | white_bkgd: False 7 | radius: 9.5,10.5 8 | near: 7.5 9 | far: 12.5 10 | fov: 10 11 | umin: 0 12 | umax: 0.25 13 | vmin: 0.32898992833716556 # 70 deg 14 | vmax: 0.45642212862617093 # 85 deg 15 | ray_sampler: 16 | min_scale: 0.125 17 | scale_anneal: 0.0019 18 | training: 19 | fid_every: 10000 20 | -------------------------------------------------------------------------------- /configs/cub.yaml: -------------------------------------------------------------------------------- 1 | expname: cub_64 2 | data: 3 | imsize: 64 4 | datadir: data/cub 5 | type: cub 6 | radius: 9,11 7 | near: 7.5 8 | far: 12.5 9 | fov: 30 10 | vmin: 0.24999999999999994 # 60 deg 11 | vmax: 0.5435778713738291 # 95 deg 12 | discriminator: 13 | hflip: True 14 | nerf: 15 | use_viewdirs: False -------------------------------------------------------------------------------- /configs/debug.yaml: -------------------------------------------------------------------------------- 1 | expname: debug 2 | data: 3 | imsize: 64 4 | datadir: data/cub 5 | type: cub 6 | radius: 10. 7 | near: 7.5 8 | far: 12.5 9 | fov: 30.0 10 | training: 11 | batch_size: 2 12 | nworkers: 0 13 | fid_every: -1 14 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | expname: default 2 | data: 3 | datadir: data/carla 4 | type: carla 5 | imsize: 64 6 | white_bkgd: True 7 | near: 1. 8 | far: 6. 9 | radius: 3.4 # set according to near and far plane 10 | fov: 90. 11 | orthographic: False 12 | umin: 0. 
# 0 deg, convert to degree via 360. * u 13 | umax: 1. # 360 deg, convert to degree via 360. * u 14 | vmin: 0. # 0 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi 15 | vmax: 0.45642212862617093 # 85 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi 16 | nerf: 17 | i_embed: 0 18 | use_viewdirs: True 19 | multires: 10 20 | multires_views: 4 21 | N_samples: 64 22 | N_importance: 0 23 | netdepth: 8 24 | netwidth: 256 25 | netdepth_fine: 8 26 | netwidth_fine: 256 27 | perturb: 1. 28 | raw_noise_std: 1. 29 | decrease_noise: True 30 | z_dist: 31 | type: gauss 32 | dim: 256 33 | dim_appearance: 128 # This dimension is subtracted from "dim" 34 | ray_sampler: 35 | min_scale: 0.25 36 | max_scale: 1. 37 | scale_anneal: 0.0025 # no effect if scale_anneal<0, else the minimum scale decreases exponentially until converge to min_scale 38 | N_samples: 1024 # 32*32, patchsize 39 | discriminator: 40 | ndf: 64 41 | hflip: False # Randomly flip discriminator input horizontally 42 | training: 43 | outdir: ./results 44 | model_file: model.pt 45 | monitoring: tensorboard 46 | use_amp: False # Use automated mixed precision 47 | nworkers: 6 48 | batch_size: 8 49 | chunk: 32768 # 1024*32 50 | netchunk: 65536 # 1024*64 51 | lr_g: 0.0005 52 | lr_d: 0.0001 53 | lr_anneal: 0.5 54 | lr_anneal_every: 50000,100000,200000 55 | equalize_lr: False 56 | gan_type: standard 57 | reg_type: real 58 | reg_param: 10. 59 | optimizer: rmsprop 60 | n_test_samples_with_same_shape_code: 4 61 | take_model_average: true 62 | model_average_beta: 0.999 63 | model_average_reinit: false 64 | restart_every: -1 65 | save_best: fid 66 | fid_every: 5000 # Valid for FID and KID 67 | print_every: 10 68 | sample_every: 500 69 | save_every: 900 70 | backup_every: 50000 71 | video_every: 10000 72 | -------------------------------------------------------------------------------- /configs/pretrained_models.yaml: -------------------------------------------------------------------------------- 1 | carla: 2 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_64.pt 3 | 128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_128.pt 4 | 256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_256.pt 5 | 512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_512.pt 6 | celebA: 7 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_64.pt 8 | 128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_128.pt 9 | celebA_hq: 10 | 256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_256.pt 11 | 512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_512.pt 12 | cats: 13 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/cats/cats_64.pt 14 | cub: 15 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/birds/cub_64.pt -------------------------------------------------------------------------------- /data/download_carla.sh: -------------------------------------------------------------------------------- 1 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/graf/data/carla.zip 2 | unzip carla.zip 3 | cd .. 
4 | -------------------------------------------------------------------------------- /data/download_carla_poses.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ./carla 2 | cd ./carla 3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/graf/data/carla_poses.zip 4 | unzip carla_poses.zip 5 | cd .. -------------------------------------------------------------------------------- /data/preprocess_cats.py: -------------------------------------------------------------------------------- 1 | ### adapted from https://github.com/AlexiaJM/RelativisticGAN/blob/master/code/preprocess_cat_dataset.py 2 | ### original code from https://github.com/microe/angora-blue/blob/master/cascade_training/describe.py by Erik Hovland 3 | import argparse 4 | import cv2 5 | import glob 6 | import math 7 | import os 8 | from tqdm import tqdm 9 | 10 | 11 | def rotateCoords(coords, center, angleRadians): 12 | # Positive y is down so reverse the angle, too. 13 | angleRadians = -angleRadians 14 | xs, ys = coords[::2], coords[1::2] 15 | newCoords = [] 16 | n = min(len(xs), len(ys)) 17 | i = 0 18 | centerX = center[0] 19 | centerY = center[1] 20 | cosAngle = math.cos(angleRadians) 21 | sinAngle = math.sin(angleRadians) 22 | while i < n: 23 | xOffset = xs[i] - centerX 24 | yOffset = ys[i] - centerY 25 | newX = xOffset * cosAngle - yOffset * sinAngle + centerX 26 | newY = xOffset * sinAngle + yOffset * cosAngle + centerY 27 | newCoords += [newX, newY] 28 | i += 1 29 | return newCoords 30 | 31 | 32 | def preprocessCatFace(coords, image): 33 | leftEyeX, leftEyeY = coords[0], coords[1] 34 | rightEyeX, rightEyeY = coords[2], coords[3] 35 | mouthX = coords[4] 36 | if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \ 37 | mouthX > rightEyeX: 38 | # The "right eye" is in the second quadrant of the face, 39 | # while the "left eye" is in the fourth quadrant (from the 40 | # viewer's perspective.) Swap the eyes' labels in order to 41 | # simplify the rotation logic. 42 | leftEyeX, rightEyeX = rightEyeX, leftEyeX 43 | leftEyeY, rightEyeY = rightEyeY, leftEyeY 44 | 45 | eyesCenter = (0.5 * (leftEyeX + rightEyeX), 46 | 0.5 * (leftEyeY + rightEyeY)) 47 | 48 | eyesDeltaX = rightEyeX - leftEyeX 49 | eyesDeltaY = rightEyeY - leftEyeY 50 | eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX) 51 | eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi 52 | 53 | # Straighten the image and fill in gray for blank borders. 54 | rotation = cv2.getRotationMatrix2D( 55 | eyesCenter, eyesAngleDegrees, 1.0) 56 | imageSize = image.shape[1::-1] 57 | straight = cv2.warpAffine(image, rotation, imageSize, 58 | borderValue=(128, 128, 128)) 59 | 60 | # Straighten the coordinates of the features. 61 | newCoords = rotateCoords( 62 | coords, eyesCenter, eyesAngleRadians) 63 | 64 | # Make the face as wide as the space between the ear bases. 65 | w = abs(newCoords[16] - newCoords[6]) 66 | # Make the face square. 67 | h = w 68 | # Put the center point between the eyes at (0.5, 0.4) in 69 | # proportion to the entire face. 70 | minX = eyesCenter[0] - w / 2 71 | if minX < 0: 72 | w += minX 73 | minX = 0 74 | minY = eyesCenter[1] - h * 2 / 5 75 | if minY < 0: 76 | h += minY 77 | minY = 0 78 | 79 | # Crop the face. 80 | crop = straight[int(minY):int(minY + h), int(minX):int(minX + w)] 81 | # Return the crop. 
82 | return crop 83 | 84 | 85 | def describePositive(root, outdir): 86 | filenames = glob.glob('%s/CAT_*/*.jpg' % root) 87 | 88 | for imagePath in tqdm(filenames, total=len(filenames), desc='Process images...'): 89 | # Open the '.cat' annotation file associated with this 90 | # image. 91 | if not os.path.isfile('%s.cat' % imagePath): 92 | print('.cat file missing at %s' % imagePath) 93 | continue 94 | input = open('%s.cat' % imagePath, 'r') 95 | # Read the coordinates of the cat features from the 96 | # file. Discard the first number, which is the number 97 | # of features. 98 | coords = [int(i) for i in input.readline().split()[1:]] 99 | # Read the image. 100 | image = cv2.imread(imagePath) 101 | # Straighten and crop the cat face. 102 | crop = preprocessCatFace(coords, image) 103 | if crop is None: 104 | print('Failed to preprocess image at %s' % imagePath) 105 | continue 106 | # Save the crop to folders based on size 107 | h, w, colors = crop.shape 108 | if min(h, w) >= 64: 109 | Path1 = imagePath.replace(root, outdir) 110 | os.makedirs(os.path.dirname(Path1), exist_ok=True) 111 | resized_crop = cv2.resize(crop, (64, 64)) 112 | cv2.imwrite(Path1, resized_crop) 113 | 114 | 115 | if __name__ == '__main__': 116 | # Arguments 117 | parser = argparse.ArgumentParser( 118 | description='Crop cats from the CatDataset.' 119 | ) 120 | parser.add_argument('root', type=str, help='Path to data directory containing "CAT_00" - "CAT_06" folders.') 121 | args = parser.parse_args() 122 | 123 | outdir = './cats' 124 | os.makedirs(outdir, exist_ok=True) 125 | 126 | describePositive(args.root, outdir) 127 | print('Preprocessed {} images.'.format(len(glob.glob(os.path.join(outdir, '*/*.jpg'))))) -------------------------------------------------------------------------------- /data/preprocess_cub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | from PIL import Image 5 | import numpy as np 6 | import glob 7 | from torchvision.transforms import CenterCrop 8 | 9 | 10 | if __name__ == '__main__': 11 | # Arguments 12 | parser = argparse.ArgumentParser( 13 | description='Select split from CUB200-2011 and crop birds from images.' 
14 | ) 15 | parser.add_argument('root', type=str, 16 | help='Path to data directory containing bounding box file and "images" folder.') 17 | parser.add_argument('maskdir', type=str, help='Path to data directory containing the segmentation masks.') 18 | args = parser.parse_args() 19 | 20 | imdir = os.path.join(args.root, 'images') 21 | bboxfile = os.path.join(args.root, 'bounding_boxes.txt') 22 | maskdir = args.maskdir 23 | namefile = './cub/filtered_files.txt' 24 | outdir = './cub' 25 | os.makedirs(outdir, exist_ok=True) 26 | 27 | # load files 28 | with open(namefile, 'r') as f: 29 | id_filename = [line.split(' ') for line in f.read().splitlines()] 30 | 31 | # load bounding boxes 32 | boxes = {} 33 | with open(bboxfile, 'r') as f: 34 | for line in f.read().splitlines(): 35 | k, x, y, w, h = line.split(' ') 36 | box = float(x), float(y), float(x) + float(w), float(y) + float(h) # (left, up, right, down) 37 | boxes[k] = box 38 | 39 | for i, (id, filename) in tqdm(enumerate(id_filename), total=len(id_filename)): 40 | path = os.path.join(imdir, filename) 41 | img = Image.open(path).convert('RGBA') 42 | 43 | # load alpha 44 | path = os.path.join(maskdir, filename.replace('.jpg', '.png')) 45 | alpha = Image.open(path) 46 | if alpha.mode == 'RGBA': 47 | alpha = alpha.split()[-1] 48 | alpha = alpha.convert('L') 49 | img.putalpha(alpha) 50 | 51 | # crop square images using bbox 52 | img = img.crop(boxes[id]) 53 | s = max(img.size) 54 | img = CenterCrop(s)(img) # CenterCrop pads image to square using zeros (also for alpha) 55 | 56 | # composite 57 | img = np.array(img) 58 | alpha = (img[..., 3:4]) > 127 # convert to binary mask 59 | bg = np.array(255 * (1. - alpha), np.uint8) 60 | img = img[..., :3] * alpha + bg 61 | img = Image.fromarray(img) 62 | 63 | img.save(os.path.join(outdir, '%06d.png' % i)) 64 | 65 | print('Preprocessed {} images.'.format(len(glob.glob(os.path.join(outdir, '*.png'))))) 66 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: graf 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - blas=1.0=mkl 8 | - ca-certificates=2020.11.8=ha878542_0 9 | - certifi=2020.11.8=py38h578d9bd_0 10 | - intel-openmp=2020.2=254 11 | - joblib=0.17.0=py_0 12 | - ld_impl_linux-64=2.33.1=h53a641e_7 13 | - libedit=3.1.20191231=h14c3975_1 14 | - libffi=3.3=he6710b0_2 15 | - libgcc-ng=9.1.0=hdf63c60_0 16 | - libgfortran-ng=7.3.0=hdf63c60_0 17 | - libprotobuf=3.13.0.1=h8b12597_0 18 | - libstdcxx-ng=9.1.0=hdf63c60_0 19 | - mkl=2019.4=243 20 | - mkl-service=2.3.0=py38he904b0f_0 21 | - mkl_fft=1.2.0=py38h23d657b_0 22 | - mkl_random=1.1.0=py38h962f231_0 23 | - ncurses=6.2=he6710b0_1 24 | - numpy-base=1.19.1=py38hfa32c7d_0 25 | - openssl=1.1.1h=h516909a_0 26 | - pip=20.2.4=py38_0 27 | - python=3.8.5=h7579374_1 28 | - python_abi=3.8=1_cp38 29 | - pyyaml=5.3.1=py38h7b6447c_1 30 | - readline=8.0=h7b6447c_0 31 | - scikit-learn=0.23.2=py38h0573a6f_0 32 | - scipy=1.5.2=py38h0b6359f_0 33 | - setuptools=50.3.0=py38hb0f4dca_1 34 | - six=1.15.0=py_0 35 | - sqlite=3.33.0=h62c20be_0 36 | - tensorboardx=2.1=py_0 37 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 38 | - tk=8.6.10=hbc83047_0 39 | - wheel=0.35.1=py_0 40 | - xz=5.2.5=h7b6447c_0 41 | - yaml=0.2.5=h7b6447c_0 42 | - zlib=1.2.11=h7b6447c_3 43 | - pip: 44 | - absl-py==0.11.0 45 | - configargparse==1.2.3 46 | - cycler==0.10.0 47 | - dataclasses==0.6 48 | - future==0.18.2 49 | - grpcio==1.33.2 50 | - 
imageio==2.9.0 51 | - imageio-ffmpeg==0.4.2 52 | - kiwisolver==1.3.1 53 | - markdown==3.3.3 54 | - matplotlib==3.3.3 55 | - numpy==1.19.4 56 | - opencv-python==4.4.0.46 57 | - pillow==8.0.1 58 | - protobuf==3.14.0 59 | - pyparsing==2.4.7 60 | - python-dateutil==2.8.1 61 | - tensorboard==1.14.0 62 | - torch==1.7.0 63 | - torchvision==0.8.1 64 | - tqdm==4.53.0 65 | - typing-extensions==3.7.4.3 66 | - werkzeug==1.0.1 67 | -------------------------------------------------------------------------------- /external/colmap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/external/colmap/__init__.py -------------------------------------------------------------------------------- /external/colmap/filter_points.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import struct 5 | import numpy as np 6 | 7 | def readBinaryPly(pcdFile, fmt='ffffffBBB', fmt_len=27): 8 | 9 | with open(pcdFile, 'rb') as f: 10 | plyData = f.readlines() 11 | 12 | headLine = plyData.index(b'end_header\n')+1 13 | plyData = plyData[headLine:] 14 | plyData = b"".join(plyData) 15 | 16 | n_pts_loaded = int(len(plyData)/fmt_len) 17 | 18 | data = [] 19 | for i in range(n_pts_loaded): 20 | pts=struct.unpack(fmt, plyData[i*fmt_len:(i+1)*fmt_len]) 21 | data.append(pts) 22 | data=np.asarray(data) 23 | 24 | return data 25 | 26 | def writeBinaryPly(pcdFile, data): 27 | fmt = '=ffffffBBB' 28 | fmt_len = 27 29 | n_pts = data.shape[0] 30 | 31 | with open(pcdFile, 'wb') as f: 32 | f.write(b'ply\n') 33 | f.write(b'format binary_little_endian 1.0\n') 34 | f.write(b'comment\n') 35 | f.write(b'element vertex %d\n' % n_pts) 36 | f.write(b'property float x\n') 37 | f.write(b'property float y\n') 38 | f.write(b'property float z\n') 39 | f.write(b'property float nx\n') 40 | f.write(b'property float ny\n') 41 | f.write(b'property float nz\n') 42 | f.write(b'property uchar red\n') 43 | f.write(b'property uchar green\n') 44 | f.write(b'property uchar blue\n') 45 | f.write(b'end_header\n') 46 | 47 | for i in range(n_pts): 48 | f.write(struct.pack(fmt, *data[i,0:6], *data[i,6:9].astype(np.uint8))) 49 | 50 | 51 | def filter_ply(object_dir): 52 | 53 | ply_files = sorted(glob.glob(os.path.join(object_dir, 'dense', '*', 'fused.ply'))) 54 | 55 | for ply_file in ply_files: 56 | ply_filter_file = ply_file.replace('.ply', '_filtered.ply') 57 | plydata = readBinaryPly(ply_file) 58 | vertex = plydata[:,0:3] 59 | normal = plydata[:,3:6] 60 | color = plydata[:,6:9] 61 | 62 | mask = np.mean(color,1)<(0.85 * 255.) 63 | color = color[mask, :] 64 | normal = normal[mask, :] 65 | vertex = vertex[mask, :] 66 | plydata = np.hstack((vertex, normal, color)) 67 | writeBinaryPly(ply_filter_file, plydata) 68 | print('Processed file {}'.format(ply_filter_file)) 69 | 70 | if __name__=='__main__': 71 | 72 | object_dir=sys.argv[1] 73 | filter_ply(object_dir) 74 | -------------------------------------------------------------------------------- /external/colmap/run_colmap_automatic.sh: -------------------------------------------------------------------------------- 1 | #set -e 2 | 3 | input_dir=$1 4 | output_dir=$2 5 | 6 | mkdir -p ${output_dir} 7 | echo Processing ${input_dir} ... 
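# Note: COLMAP's automatic_reconstructor runs the full pipeline (feature extraction,
# matching, sparse and dense reconstruction) in a single call. --single_camera=1
# assumes that all generated views share the same camera intrinsics, and --dense=1
# produces the fused point clouds (dense/*/fused.ply) that filter_points.py post-processes.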
8 | colmap automatic_reconstructor \ 9 | --workspace_path ${output_dir} \ 10 | --image_path ${input_dir}/ \ 11 | --single_camera=1 \ 12 | --dense=1 \ 13 | --gpu_index=0 14 | -------------------------------------------------------------------------------- /graf/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.transforms import * 4 | 5 | from .datasets import * 6 | from .transforms import FlexGridRaySampler 7 | from .utils import polar_to_cartesian, look_at, to_phi, to_theta 8 | 9 | 10 | def save_config(outpath, config): 11 | from yaml import safe_dump 12 | with open(outpath, 'w') as f: 13 | safe_dump(config, f) 14 | 15 | 16 | def update_config(config, unknown): 17 | # update config given args 18 | for idx,arg in enumerate(unknown): 19 | if arg.startswith("--"): 20 | if (':') in arg: 21 | k1,k2 = arg.replace("--","").split(':') 22 | argtype = type(config[k1][k2]) 23 | if argtype == bool: 24 | v = unknown[idx+1].lower() == 'true' 25 | else: 26 | if config[k1][k2] is not None: 27 | v = type(config[k1][k2])(unknown[idx+1]) 28 | else: 29 | v = unknown[idx+1] 30 | print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}') 31 | config[k1][k2] = v 32 | else: 33 | k = arg.replace('--','') 34 | v = unknown[idx+1] 35 | argtype = type(config[k]) 36 | print(f'Changing {k} ---- {config[k]} to {v}') 37 | config[k] = v 38 | 39 | return config 40 | 41 | 42 | def get_data(config): 43 | H = W = imsize = config['data']['imsize'] 44 | dset_type = config['data']['type'] 45 | fov = config['data']['fov'] 46 | 47 | transforms = Compose([ 48 | Resize(imsize), 49 | ToTensor(), 50 | Lambda(lambda x: x * 2 - 1), 51 | ]) 52 | 53 | kwargs = { 54 | 'data_dirs': config['data']['datadir'], 55 | 'transforms': transforms 56 | } 57 | 58 | if dset_type == 'carla': 59 | dset = Carla(**kwargs) 60 | 61 | elif dset_type == 'celebA': 62 | assert imsize <= 128, 'cropped GT data has lower resolution than imsize, consider using celebA_hq instead' 63 | transforms.transforms.insert(0, RandomHorizontalFlip()) 64 | transforms.transforms.insert(0, CenterCrop(108)) 65 | 66 | dset = CelebA(**kwargs) 67 | 68 | elif dset_type == 'celebA_hq': 69 | transforms.transforms.insert(0, RandomHorizontalFlip()) 70 | transforms.transforms.insert(0, CenterCrop(650)) 71 | 72 | dset = CelebAHQ(**kwargs) 73 | 74 | elif dset_type == 'cats': 75 | transforms.transforms.insert(0, RandomHorizontalFlip()) 76 | dset = Cats(**kwargs) 77 | 78 | elif dset_type == 'cub': 79 | dset = CUB(**kwargs) 80 | 81 | dset.H = dset.W = imsize 82 | dset.focal = W/2 * 1 / np.tan((.5 * fov * np.pi/180.)) 83 | radius = config['data']['radius'] 84 | render_radius = radius 85 | if isinstance(radius, str): 86 | radius = tuple(float(r) for r in radius.split(',')) 87 | render_radius = max(radius) 88 | dset.radius = radius 89 | 90 | # compute render poses 91 | N = 40 92 | theta = 0.5 * (to_theta(config['data']['vmin']) + to_theta(config['data']['vmax'])) 93 | angle_range = (to_phi(config['data']['umin']), to_phi(config['data']['umax'])) 94 | render_poses = get_render_poses(render_radius, angle_range=angle_range, theta=theta, N=N) 95 | 96 | print('Loaded {}'.format(dset_type), imsize, len(dset), render_poses.shape, [H,W,dset.focal,dset.radius], config['data']['datadir']) 97 | return dset, [H,W,dset.focal,dset.radius], render_poses 98 | 99 | 100 | def get_render_poses(radius, angle_range=(0, 360), theta=0, N=40, swap_angles=False): 101 | poses = [] 102 | theta = max(0.1, theta) 103 | for 
angle in np.linspace(angle_range[0],angle_range[1],N+1)[:-1]: 104 | angle = max(0.1, angle) 105 | if swap_angles: 106 | loc = polar_to_cartesian(radius, theta, angle, deg=True) 107 | else: 108 | loc = polar_to_cartesian(radius, angle, theta, deg=True) 109 | R = look_at(loc)[0] 110 | RT = np.concatenate([R, loc.reshape(3, 1)], axis=1) 111 | poses.append(RT) 112 | return torch.from_numpy(np.stack(poses)) 113 | 114 | 115 | def build_models(config, disc=True): 116 | from argparse import Namespace 117 | from submodules.nerf_pytorch.run_nerf_mod import create_nerf 118 | from .models.generator import Generator 119 | from .models.discriminator import Discriminator 120 | 121 | config_nerf = Namespace(**config['nerf']) 122 | # Update config for NERF 123 | config_nerf.chunk = min(config['training']['chunk'], 1024*config['training']['batch_size']) # let batch size for training with patches limit the maximal memory 124 | config_nerf.netchunk = config['training']['netchunk'] 125 | config_nerf.white_bkgd = config['data']['white_bkgd'] 126 | config_nerf.feat_dim = config['z_dist']['dim'] 127 | config_nerf.feat_dim_appearance = config['z_dist']['dim_appearance'] 128 | 129 | render_kwargs_train, render_kwargs_test, params, named_parameters = create_nerf(config_nerf) 130 | 131 | bds_dict = {'near': config['data']['near'], 'far': config['data']['far']} 132 | render_kwargs_train.update(bds_dict) 133 | render_kwargs_test.update(bds_dict) 134 | 135 | ray_sampler = FlexGridRaySampler(N_samples=config['ray_sampler']['N_samples'], 136 | min_scale=config['ray_sampler']['min_scale'], 137 | max_scale=config['ray_sampler']['max_scale'], 138 | scale_anneal=config['ray_sampler']['scale_anneal'], 139 | orthographic=config['data']['orthographic']) 140 | 141 | H, W, f, r = config['data']['hwfr'] 142 | generator = Generator(H, W, f, r, 143 | ray_sampler=ray_sampler, 144 | render_kwargs_train=render_kwargs_train, render_kwargs_test=render_kwargs_test, 145 | parameters=params, named_parameters=named_parameters, 146 | chunk=config_nerf.chunk, 147 | range_u=(float(config['data']['umin']), float(config['data']['umax'])), 148 | range_v=(float(config['data']['vmin']), float(config['data']['vmax'])), 149 | orthographic=config['data']['orthographic'], 150 | ) 151 | 152 | discriminator = None 153 | if disc: 154 | disc_kwargs = {'nc': 3, # channels for patch discriminator 155 | 'ndf': config['discriminator']['ndf'], 156 | 'imsize': int(np.sqrt(config['ray_sampler']['N_samples'])), 157 | 'hflip': config['discriminator']['hflip']} 158 | 159 | discriminator = Discriminator(**disc_kwargs) 160 | 161 | return generator, discriminator 162 | 163 | 164 | def build_lr_scheduler(optimizer, config, last_epoch=-1): 165 | import torch.optim as optim 166 | step_size = config['training']['lr_anneal_every'] 167 | if isinstance(step_size, str): 168 | milestones = [int(m) for m in step_size.split(',')] 169 | lr_scheduler = optim.lr_scheduler.MultiStepLR( 170 | optimizer, 171 | milestones=milestones, 172 | gamma=config['training']['lr_anneal'], 173 | last_epoch=last_epoch) 174 | else: 175 | lr_scheduler = optim.lr_scheduler.StepLR( 176 | optimizer, 177 | step_size=step_size, 178 | gamma=config['training']['lr_anneal'], 179 | last_epoch=last_epoch 180 | ) 181 | return lr_scheduler 182 | -------------------------------------------------------------------------------- /graf/datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | from PIL import Image 4 | 5 | from 
torchvision.datasets.vision import VisionDataset 6 | 7 | 8 | class ImageDataset(VisionDataset): 9 | """ 10 | Load images from multiple data directories. 11 | Folder structure: data_dir/filename.png 12 | """ 13 | 14 | def __init__(self, data_dirs, transforms=None): 15 | # Use multiple root folders 16 | if not isinstance(data_dirs, list): 17 | data_dirs = [data_dirs] 18 | 19 | # initialize base class 20 | VisionDataset.__init__(self, root=data_dirs, transform=transforms) 21 | 22 | self.filenames = [] 23 | root = [] 24 | 25 | for ddir in self.root: 26 | filenames = self._get_files(ddir) 27 | self.filenames.extend(filenames) 28 | root.append(ddir) 29 | 30 | def __len__(self): 31 | return len(self.filenames) 32 | 33 | @staticmethod 34 | def _get_files(root_dir): 35 | return glob.glob(f'{root_dir}/*.png') + glob.glob(f'{root_dir}/*.jpg') 36 | 37 | def __getitem__(self, idx): 38 | filename = self.filenames[idx] 39 | img = Image.open(filename).convert('RGB') 40 | if self.transform is not None: 41 | img = self.transform(img) 42 | return img 43 | 44 | 45 | class Carla(ImageDataset): 46 | def __init__(self, *args, **kwargs): 47 | super(Carla, self).__init__(*args, **kwargs) 48 | 49 | 50 | class CelebA(ImageDataset): 51 | def __init__(self, *args, **kwargs): 52 | super(CelebA, self).__init__(*args, **kwargs) 53 | 54 | 55 | class CUB(ImageDataset): 56 | def __init__(self, *args, **kwargs): 57 | super(CUB, self).__init__(*args, **kwargs) 58 | 59 | 60 | class Cats(ImageDataset): 61 | def __init__(self, *args, **kwargs): 62 | super(Cats, self).__init__(*args, **kwargs) 63 | 64 | @staticmethod 65 | def _get_files(root_dir): 66 | return glob.glob(f'{root_dir}/CAT_*/*.jpg') 67 | 68 | 69 | class CelebAHQ(ImageDataset): 70 | def __init__(self, *args, **kwargs): 71 | super(CelebAHQ, self).__init__(*args, **kwargs) 72 | 73 | def _get_files(self, root): 74 | return glob.glob(f'{root}/*.npy') 75 | 76 | def __getitem__(self, idx): 77 | img = np.load(self.filenames[idx]).squeeze(0).transpose(1,2,0) 78 | if img.dtype == np.uint8: 79 | pass 80 | elif img.dtype == np.float32: 81 | img = (img * 255).astype(np.uint8) 82 | else: 83 | raise NotImplementedError 84 | img = Image.fromarray(img).convert('RGB') 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | 88 | return img 89 | -------------------------------------------------------------------------------- /graf/gan_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | 6 | from submodules.GAN_stability.gan_training.train import toggle_grad, Trainer as TrainerBase 7 | from submodules.GAN_stability.gan_training.eval import Evaluator as EvaluatorBase 8 | from submodules.GAN_stability.gan_training.metrics import FIDEvaluator, KIDEvaluator 9 | 10 | from .utils import save_video, color_depth_map 11 | 12 | 13 | class Trainer(TrainerBase): 14 | def __init__(self, *args, use_amp=False, **kwargs): 15 | super(Trainer, self).__init__(*args, **kwargs) 16 | self.use_amp = use_amp 17 | if self.use_amp: 18 | self.scaler = torch.cuda.amp.GradScaler() 19 | 20 | def generator_trainstep(self, y, z): 21 | if not self.use_amp: 22 | return super(Trainer, self).generator_trainstep(y, z) 23 | assert (y.size(0) == z.size(0)) 24 | toggle_grad(self.generator, True) 25 | toggle_grad(self.discriminator, False) 26 | self.generator.train() 27 | self.discriminator.train() 28 | self.g_optimizer.zero_grad() 29 | 30 | with torch.cuda.amp.autocast(): 31 | x_fake = 
self.generator(z, y) 32 | d_fake = self.discriminator(x_fake, y) 33 | gloss = self.compute_loss(d_fake, 1) 34 | self.scaler.scale(gloss).backward() 35 | 36 | self.scaler.step(self.g_optimizer) 37 | self.scaler.update() 38 | 39 | return gloss.item() 40 | 41 | def discriminator_trainstep(self, x_real, y, z): 42 | return super(Trainer, self).discriminator_trainstep(x_real, y, z) # spectral norm raises error for when using amp 43 | 44 | 45 | class Evaluator(EvaluatorBase): 46 | def __init__(self, eval_fid_kid, *args, **kwargs): 47 | super(Evaluator, self).__init__(*args, **kwargs) 48 | if eval_fid_kid: 49 | self.inception_eval = FIDEvaluator( 50 | device=self.device, 51 | batch_size=self.batch_size, 52 | resize=True, 53 | n_samples=20000, 54 | n_samples_fake=1000, 55 | ) 56 | 57 | def get_rays(self, pose): 58 | return self.generator.val_ray_sampler(self.generator.H, self.generator.W, 59 | self.generator.focal, pose)[0] 60 | 61 | def create_samples(self, z, poses=None): 62 | self.generator.eval() 63 | 64 | N_samples = len(z) 65 | device = self.generator.device 66 | z = z.to(device).split(self.batch_size) 67 | if poses is None: 68 | rays = [None] * len(z) 69 | else: 70 | rays = torch.stack([self.get_rays(poses[i].to(device)) for i in range(N_samples)]) 71 | rays = rays.split(self.batch_size) 72 | 73 | rgb, disp, acc = [], [], [] 74 | with torch.no_grad(): 75 | for z_i, rays_i in tqdm(zip(z, rays), total=len(z), desc='Create samples...'): 76 | bs = len(z_i) 77 | if rays_i is not None: 78 | rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2) # Bx2x(HxW)xC -> 2x(BxHxW)x3 79 | rgb_i, disp_i, acc_i, _ = self.generator(z_i, rays=rays_i) 80 | 81 | reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2) # (NxHxW)xC -> NxCxHxW 82 | rgb.append(reshape(rgb_i).cpu()) 83 | disp.append(reshape(disp_i).cpu()) 84 | acc.append(reshape(acc_i).cpu()) 85 | 86 | rgb = torch.cat(rgb) 87 | disp = torch.cat(disp) 88 | acc = torch.cat(acc) 89 | 90 | depth = self.disp_to_cdepth(disp) 91 | 92 | return rgb, depth, acc 93 | 94 | def make_video(self, basename, z, poses, as_gif=True): 95 | """ Generate images and save them as video. 96 | z (N_samples, zdim): latent codes 97 | poses (N_frames, 3 x 4): camera poses for all frames of video 98 | """ 99 | N_samples, N_frames = len(z), len(poses) 100 | 101 | # reshape inputs 102 | z = z.unsqueeze(1).expand(-1, N_frames, -1).flatten(0, 1) # (N_samples x N_frames) x z_dim 103 | poses = poses.unsqueeze(0) \ 104 | .expand(N_samples, -1, -1, -1).flatten(0, 1) # (N_samples x N_frames) x 3 x 4 105 | 106 | rgbs, depths, accs = self.create_samples(z, poses=poses) 107 | 108 | reshape = lambda x: x.view(N_samples, N_frames, *x.shape[1:]) 109 | rgbs = reshape(rgbs) 110 | depths = reshape(depths) 111 | print('Done, saving', rgbs.shape) 112 | 113 | fps = min(int(N_frames / 2.), 25) # aim for at least 2 second video 114 | for i in range(N_samples): 115 | save_video(rgbs[i], basename + '{:04d}_rgb.mp4'.format(i), as_gif=as_gif, fps=fps) 116 | save_video(depths[i], basename + '{:04d}_depth.mp4'.format(i), as_gif=as_gif, fps=fps) 117 | 118 | def disp_to_cdepth(self, disps): 119 | """Convert depth to color values""" 120 | if (disps == 2e10).all(): # no values predicted 121 | return torch.ones_like(disps) 122 | 123 | near, far = self.generator.render_kwargs_test['near'], self.generator.render_kwargs_test['far'] 124 | 125 | disps = disps / 2 + 0.5 # [-1, 1] -> [0, 1] 126 | 127 | depth = 1. 
/ torch.max(1e-10 * torch.ones_like(disps), disps) # disparity -> depth 128 | depth[disps == 1e10] = far # set undefined values to far plane 129 | 130 | # scale between near, far plane for better visualization 131 | depth = (depth - near) / (far - near) 132 | 133 | depth = np.stack([color_depth_map(d) for d in depth[:, 0].detach().cpu().numpy()]) # convert to color 134 | depth = (torch.from_numpy(depth).permute(0, 3, 1, 2) / 255.) * 2 - 1 # [0, 255] -> [-1, 1] 135 | 136 | return depth 137 | 138 | def compute_fid_kid(self, sample_generator=None): 139 | if sample_generator is None: 140 | def sample(): 141 | while True: 142 | z = self.zdist.sample((self.batch_size,)) 143 | rgb, _, _ = self.create_samples(z) 144 | # convert to uint8 and back to get correct binning 145 | rgb = (rgb / 2 + 0.5).mul_(255).clamp_(0, 255).to(torch.uint8).to(torch.float) / 255. * 2 - 1 146 | yield rgb.cpu() 147 | 148 | sample_generator = sample() 149 | 150 | fid, (kids, vars) = self.inception_eval.get_fid_kid(sample_generator) 151 | kid = np.mean(kids) 152 | return fid, kid 153 | -------------------------------------------------------------------------------- /graf/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, nc=3, ndf=64, imsize=64, hflip=False): 7 | super(Discriminator, self).__init__() 8 | self.nc = nc 9 | assert(imsize==32 or imsize==64 or imsize==128) 10 | self.imsize = imsize 11 | self.hflip = hflip 12 | 13 | SN = torch.nn.utils.spectral_norm 14 | IN = lambda x : nn.InstanceNorm2d(x) 15 | 16 | blocks = [] 17 | if self.imsize==128: 18 | blocks += [ 19 | # input is (nc) x 128 x 128 20 | SN(nn.Conv2d(nc, ndf//2, 4, 2, 1, bias=False)), 21 | nn.LeakyReLU(0.2, inplace=True), 22 | # input is (ndf//2) x 64 x 64 23 | SN(nn.Conv2d(ndf//2, ndf, 4, 2, 1, bias=False)), 24 | IN(ndf), 25 | nn.LeakyReLU(0.2, inplace=True), 26 | # state size. (ndf) x 32 x 32 27 | SN(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)), 28 | #nn.BatchNorm2d(ndf * 2), 29 | IN(ndf * 2), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | ] 32 | elif self.imsize==64: 33 | blocks += [ 34 | # input is (nc) x 64 x 64 35 | SN(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)), 36 | nn.LeakyReLU(0.2, inplace=True), 37 | # state size. (ndf) x 32 x 32 38 | SN(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)), 39 | #nn.BatchNorm2d(ndf * 2), 40 | IN(ndf * 2), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | ] 43 | else: 44 | blocks += [ 45 | # input is (nc) x 32 x 32 46 | SN(nn.Conv2d(nc, ndf * 2, 4, 2, 1, bias=False)), 47 | #nn.BatchNorm2d(ndf * 2), 48 | IN(ndf * 2), 49 | nn.LeakyReLU(0.2, inplace=True), 50 | ] 51 | 52 | blocks += [ 53 | # state size. (ndf*2) x 16 x 16 54 | SN(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)), 55 | #nn.BatchNorm2d(ndf * 4), 56 | IN(ndf * 4), 57 | nn.LeakyReLU(0.2, inplace=True), 58 | # state size. (ndf*4) x 8 x 8 59 | SN(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)), 60 | #nn.BatchNorm2d(ndf * 8), 61 | IN(ndf * 8), 62 | nn.LeakyReLU(0.2, inplace=True), 63 | # state size. 
(ndf*8) x 4 x 4 64 | SN(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)), 65 | # nn.Sigmoid() 66 | ] 67 | blocks = [x for x in blocks if x] 68 | self.main = nn.Sequential(*blocks) 69 | 70 | def forward(self, input, y=None): 71 | input = input[:, :self.nc] 72 | input = input.view(-1, self.imsize, self.imsize, self.nc).permute(0, 3, 1, 2) # (BxN_samples)xC -> BxCxHxW 73 | 74 | if self.hflip: # Randomly flip input horizontally 75 | input_flipped = input.flip(3) 76 | mask = torch.randint(0, 2, (len(input),1, 1, 1)).bool().expand(-1, *input.shape[1:]) 77 | input = torch.where(mask, input, input_flipped) 78 | 79 | return self.main(input) 80 | 81 | 82 | -------------------------------------------------------------------------------- /graf/models/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..utils import sample_on_sphere, look_at, to_sphere 4 | from ..transforms import FullRaySampler 5 | from submodules.nerf_pytorch.run_nerf_mod import render, run_network # import conditional render 6 | from functools import partial 7 | 8 | 9 | class Generator(object): 10 | def __init__(self, H, W, focal, radius, ray_sampler, render_kwargs_train, render_kwargs_test, parameters, named_parameters, 11 | range_u=(0,1), range_v=(0.01,0.49), chunk=None, device='cuda', orthographic=False): 12 | self.device = device 13 | self.H = int(H) 14 | self.W = int(W) 15 | self.focal = focal 16 | self.radius = radius 17 | self.range_u = range_u 18 | self.range_v = range_v 19 | self.chunk = chunk 20 | coords = torch.from_numpy(np.stack(np.meshgrid(np.arange(H), np.arange(W), indexing='ij'), -1)) 21 | self.coords = coords.view(-1, 2) 22 | 23 | self.ray_sampler = ray_sampler 24 | self.val_ray_sampler = FullRaySampler(orthographic=orthographic) 25 | self.render_kwargs_train = render_kwargs_train 26 | self.render_kwargs_test = render_kwargs_test 27 | self.initial_raw_noise_std = self.render_kwargs_train['raw_noise_std'] 28 | self._parameters = parameters 29 | self._named_parameters = named_parameters 30 | self.module_dict = {'generator': self.render_kwargs_train['network_fn']} 31 | for name, module in [('generator_fine', self.render_kwargs_train['network_fine'])]: 32 | if module is not None: 33 | self.module_dict[name] = module 34 | 35 | for k, v in self.module_dict.items(): 36 | if k in ['generator', 'generator_fine']: 37 | continue # parameters already included 38 | self._parameters += list(v.parameters()) 39 | self._named_parameters += list(v.named_parameters()) 40 | 41 | self.parameters = lambda: self._parameters # save as function to enable calling model.parameters() 42 | self.named_parameters = lambda: self._named_parameters # save as function to enable calling model.named_parameters() 43 | self.use_test_kwargs = False 44 | 45 | self.render = partial(render, H=self.H, W=self.W, focal=self.focal, chunk=self.chunk) 46 | 47 | def __call__(self, z, y=None, rays=None): 48 | bs = z.shape[0] 49 | if rays is None: 50 | rays = torch.cat([self.sample_rays() for _ in range(bs)], dim=1) 51 | 52 | render_kwargs = self.render_kwargs_test if self.use_test_kwargs else self.render_kwargs_train 53 | render_kwargs = dict(render_kwargs) # copy 54 | 55 | # in the case of a variable radius 56 | # we need to adjust near and far plane for the rays 57 | # so they stay within the bounds defined wrt. 
maximal radius 58 | # otherwise each camera samples within its own near/far plane (relative to this camera's radius) 59 | # instead of the absolute value (relative to maximum camera radius) 60 | if isinstance(self.radius, tuple): 61 | assert self.radius[1] - self.radius[0] <= render_kwargs['near'], 'Your smallest radius lies behind your near plane!' 62 | 63 | rays_radius = rays[0].norm(dim=-1) 64 | shift = (self.radius[1] - rays_radius).view(-1, 1).float() # reshape s.t. shape matches required shape in run_nerf 65 | render_kwargs['near'] = render_kwargs['near'] - shift 66 | render_kwargs['far'] = render_kwargs['far'] - shift 67 | assert (render_kwargs['near'] >= 0).all() and (render_kwargs['far'] >= 0).all(), \ 68 | (rays_radius.min(), rays_radius.max(), shift.min(), shift.max()) 69 | 70 | 71 | render_kwargs['features'] = z 72 | rgb, disp, acc, extras = render(self.H, self.W, self.focal, chunk=self.chunk, rays=rays, 73 | **render_kwargs) 74 | 75 | rays_to_output = lambda x: x.view(len(x), -1) * 2 - 1 # (BxN_samples)xC 76 | 77 | if self.use_test_kwargs: # return all outputs 78 | return rays_to_output(rgb), \ 79 | rays_to_output(disp), \ 80 | rays_to_output(acc), extras 81 | 82 | rgb = rays_to_output(rgb) 83 | return rgb 84 | 85 | def decrease_nerf_noise(self, it): 86 | end_it = 5000 87 | if it < end_it: 88 | noise_std = self.initial_raw_noise_std - self.initial_raw_noise_std/end_it * it 89 | self.render_kwargs_train['raw_noise_std'] = noise_std 90 | 91 | def sample_pose(self): 92 | # sample location on unit sphere 93 | loc = sample_on_sphere(self.range_u, self.range_v) 94 | 95 | # sample radius if necessary 96 | radius = self.radius 97 | if isinstance(radius, tuple): 98 | radius = np.random.uniform(*radius) 99 | 100 | loc = loc * radius 101 | R = look_at(loc)[0] 102 | 103 | RT = np.concatenate([R, loc.reshape(3, 1)], axis=1) 104 | RT = torch.Tensor(RT.astype(np.float32)) 105 | return RT 106 | 107 | def sample_rays(self): 108 | pose = self.sample_pose() 109 | sampler = self.val_ray_sampler if self.use_test_kwargs else self.ray_sampler 110 | batch_rays, _, _ = sampler(self.H, self.W, self.focal, pose) 111 | return batch_rays 112 | 113 | def to(self, device): 114 | self.render_kwargs_train['network_fn'].to(device) 115 | if self.render_kwargs_train['network_fine'] is not None: 116 | self.render_kwargs_train['network_fine'].to(device) 117 | self.device = device 118 | return self 119 | 120 | def train(self): 121 | self.use_test_kwargs = False 122 | self.render_kwargs_train['network_fn'].train() 123 | if self.render_kwargs_train['network_fine'] is not None: 124 | self.render_kwargs_train['network_fine'].train() 125 | 126 | def eval(self): 127 | self.use_test_kwargs = True 128 | self.render_kwargs_train['network_fn'].eval() 129 | if self.render_kwargs_train['network_fine'] is not None: 130 | self.render_kwargs_train['network_fine'].eval() 131 | -------------------------------------------------------------------------------- /graf/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import sqrt, exp 3 | 4 | from submodules.nerf_pytorch.run_nerf_helpers_mod import get_rays, get_rays_ortho 5 | 6 | 7 | class ImgToPatch(object): 8 | def __init__(self, ray_sampler, hwf): 9 | self.ray_sampler = ray_sampler 10 | self.hwf = hwf # camera intrinsics 11 | 12 | def __call__(self, img): 13 | rgbs = [] 14 | for img_i in img: 15 | pose = torch.eye(4) # use dummy pose to infer pixel values 16 | _, selected_idcs, pixels_i = 
self.ray_sampler(H=self.hwf[0], W=self.hwf[1], focal=self.hwf[2], pose=pose) 17 | if selected_idcs is not None: 18 | rgbs_i = img_i.flatten(1, 2).t()[selected_idcs] 19 | else: 20 | rgbs_i = torch.nn.functional.grid_sample(img_i.unsqueeze(0), 21 | pixels_i.unsqueeze(0), mode='bilinear', align_corners=True)[0] 22 | rgbs_i = rgbs_i.flatten(1, 2).t() 23 | rgbs.append(rgbs_i) 24 | 25 | rgbs = torch.cat(rgbs, dim=0) # (B*N)x3 26 | 27 | return rgbs 28 | 29 | 30 | class RaySampler(object): 31 | def __init__(self, N_samples, orthographic=False): 32 | super(RaySampler, self).__init__() 33 | self.N_samples = N_samples 34 | self.scale = torch.ones(1,).float() 35 | self.return_indices = True 36 | self.orthographic = orthographic 37 | 38 | def __call__(self, H, W, focal, pose): 39 | if self.orthographic: 40 | size_h, size_w = focal # Hacky 41 | rays_o, rays_d = get_rays_ortho(H, W, pose, size_h, size_w) 42 | else: 43 | rays_o, rays_d = get_rays(H, W, focal, pose) 44 | 45 | select_inds = self.sample_rays(H, W) 46 | 47 | if self.return_indices: 48 | rays_o = rays_o.view(-1, 3)[select_inds] 49 | rays_d = rays_d.view(-1, 3)[select_inds] 50 | 51 | h = (select_inds // W) / float(H) - 0.5 52 | w = (select_inds % W) / float(W) - 0.5 53 | 54 | hw = torch.stack([h,w]).t() 55 | 56 | else: 57 | rays_o = torch.nn.functional.grid_sample(rays_o.permute(2,0,1).unsqueeze(0), 58 | select_inds.unsqueeze(0), mode='bilinear', align_corners=True)[0] 59 | rays_d = torch.nn.functional.grid_sample(rays_d.permute(2,0,1).unsqueeze(0), 60 | select_inds.unsqueeze(0), mode='bilinear', align_corners=True)[0] 61 | rays_o = rays_o.permute(1,2,0).view(-1, 3) 62 | rays_d = rays_d.permute(1,2,0).view(-1, 3) 63 | 64 | hw = select_inds 65 | select_inds = None 66 | 67 | return torch.stack([rays_o, rays_d]), select_inds, hw 68 | 69 | def sample_rays(self, H, W): 70 | raise NotImplementedError 71 | 72 | 73 | class FullRaySampler(RaySampler): 74 | def __init__(self, **kwargs): 75 | super(FullRaySampler, self).__init__(N_samples=None, **kwargs) 76 | 77 | def sample_rays(self, H, W): 78 | return torch.arange(0, H*W) 79 | 80 | 81 | class FlexGridRaySampler(RaySampler): 82 | def __init__(self, N_samples, random_shift=True, random_scale=True, min_scale=0.25, max_scale=1., scale_anneal=-1, 83 | **kwargs): 84 | self.N_samples_sqrt = int(sqrt(N_samples)) 85 | super(FlexGridRaySampler, self).__init__(self.N_samples_sqrt**2, **kwargs) 86 | 87 | self.random_shift = random_shift 88 | self.random_scale = random_scale 89 | 90 | self.min_scale = min_scale 91 | self.max_scale = max_scale 92 | 93 | # nn.functional.grid_sample grid value range in [-1,1] 94 | self.w, self.h = torch.meshgrid([torch.linspace(-1,1,self.N_samples_sqrt), 95 | torch.linspace(-1,1,self.N_samples_sqrt)]) 96 | self.h = self.h.unsqueeze(2) 97 | self.w = self.w.unsqueeze(2) 98 | 99 | # directly return grid for grid_sample 100 | self.return_indices = False 101 | 102 | self.iterations = 0 103 | self.scale_anneal = scale_anneal 104 | 105 | def sample_rays(self, H, W): 106 | 107 | if self.scale_anneal>0: 108 | k_iter = self.iterations // 1000 * 3 109 | min_scale = max(self.min_scale, self.max_scale * exp(-k_iter*self.scale_anneal)) 110 | min_scale = min(0.9, min_scale) 111 | else: 112 | min_scale = self.min_scale 113 | 114 | scale = 1 115 | if self.random_scale: 116 | scale = torch.Tensor(1).uniform_(min_scale, self.max_scale) 117 | h = self.h * scale 118 | w = self.w * scale 119 | 120 | if self.random_shift: 121 | max_offset = 1-scale.item() 122 | h_offset = torch.Tensor(1).uniform_(0, 
max_offset) * (torch.randint(2,(1,)).float()-0.5)*2 123 | w_offset = torch.Tensor(1).uniform_(0, max_offset) * (torch.randint(2,(1,)).float()-0.5)*2 124 | 125 | h += h_offset 126 | w += w_offset 127 | 128 | self.scale = scale 129 | 130 | return torch.cat([h, w], dim=2) 131 | -------------------------------------------------------------------------------- /graf/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import imageio 4 | import os 5 | 6 | 7 | def get_nsamples(data_loader, N): 8 | x = [] 9 | n = 0 10 | while n < N: 11 | x_next = next(iter(data_loader)) 12 | x.append(x_next) 13 | n += x_next.size(0) 14 | x = torch.cat(x, dim=0)[:N] 15 | return x 16 | 17 | 18 | def count_trainable_parameters(model): 19 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 20 | return sum([np.prod(p.size()) for p in model_parameters]) 21 | 22 | 23 | def save_video(imgs, fname, as_gif=False, fps=24, quality=8): 24 | # convert to np.uint8 25 | imgs = (255 * np.clip(imgs.permute(0, 2, 3, 1).detach().cpu().numpy() / 2 + 0.5, 0, 1)).astype(np.uint8) 26 | imageio.mimwrite(fname, imgs, fps=fps, quality=quality) 27 | 28 | if as_gif: # save as gif, too 29 | os.system(f'ffmpeg -i {fname} -r 15 ' 30 | f'-vf "scale=512:-1,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" {os.path.splitext(fname)[0] + ".gif"}') 31 | 32 | 33 | def color_depth_map(depths, scale=None): 34 | """ 35 | Color an input depth map. 36 | 37 | Arguments: 38 | depths -- HxW numpy array of depths 39 | [scale=None] -- scaling the values (defaults to the maximum depth) 40 | 41 | Returns: 42 | colored_depths -- HxWx3 numpy array visualizing the depths 43 | """ 44 | 45 | _color_map_depths = np.array([ 46 | [0, 0, 0], # 0.000 47 | [0, 0, 255], # 0.114 48 | [255, 0, 0], # 0.299 49 | [255, 0, 255], # 0.413 50 | [0, 255, 0], # 0.587 51 | [0, 255, 255], # 0.701 52 | [255, 255, 0], # 0.886 53 | [255, 255, 255], # 1.000 54 | [255, 255, 255], # 1.000 55 | ]).astype(float) 56 | _color_map_bincenters = np.array([ 57 | 0.0, 58 | 0.114, 59 | 0.299, 60 | 0.413, 61 | 0.587, 62 | 0.701, 63 | 0.886, 64 | 1.000, 65 | 2.000, # doesn't make a difference, just strictly higher than 1 66 | ]) 67 | 68 | if scale is None: 69 | scale = depths.max() 70 | 71 | values = np.clip(depths.flatten() / scale, 0, 1) 72 | # for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value? 
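# e.g. a normalized value of 0.5 falls between the bincenters 0.413 and 0.587, so lower_bin = 3
# and the output color is a linear blend of the two neighbouring map entries (alpha ~ 0.5)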
73 | lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1) 74 | lower_bin_value = _color_map_bincenters[lower_bin] 75 | higher_bin_value = _color_map_bincenters[lower_bin + 1] 76 | alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value) 77 | colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[ 78 | lower_bin + 1] * alphas.reshape(-1, 1) 79 | return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8) 80 | 81 | 82 | # Virtual camera utils 83 | 84 | 85 | def to_sphere(u, v): 86 | theta = 2 * np.pi * u 87 | phi = np.arccos(1 - 2 * v) 88 | cx = np.sin(phi) * np.cos(theta) 89 | cy = np.sin(phi) * np.sin(theta) 90 | cz = np.cos(phi) 91 | s = np.stack([cx, cy, cz]) 92 | return s 93 | 94 | 95 | def polar_to_cartesian(r, theta, phi, deg=True): 96 | if deg: 97 | phi = phi * np.pi / 180 98 | theta = theta * np.pi / 180 99 | cx = np.sin(phi) * np.cos(theta) 100 | cy = np.sin(phi) * np.sin(theta) 101 | cz = np.cos(phi) 102 | return r * np.stack([cx, cy, cz]) 103 | 104 | 105 | def to_uv(loc): 106 | # normalize to unit sphere 107 | loc = loc / loc.norm(dim=1, keepdim=True) 108 | 109 | cx, cy, cz = loc.t() 110 | v = (1 - cz) / 2 111 | 112 | phi = torch.acos(cz) 113 | sin_phi = torch.sin(phi) 114 | 115 | # ensure we do not divide by zero 116 | eps = 1e-8 117 | sin_phi[sin_phi.abs() < eps] = eps 118 | 119 | theta = torch.acos(cx / sin_phi) 120 | 121 | # check for sign of phi 122 | cx_rec = sin_phi * torch.cos(theta) 123 | if not np.isclose(cx.numpy(), cx_rec.numpy(), atol=1e-5).all(): 124 | sin_phi = -sin_phi 125 | 126 | # check for sign of theta 127 | cy_rec = sin_phi * torch.sin(theta) 128 | if not np.isclose(cy.numpy(), cy_rec.numpy(), atol=1e-5).all(): 129 | theta = -theta 130 | 131 | u = theta / (2 * np.pi) 132 | assert np.isclose(to_sphere(u, v).detach().cpu().numpy(), loc.t().detach().cpu().numpy(), atol=1e-5).all() 133 | 134 | return u, v 135 | 136 | 137 | def to_phi(u): 138 | return 360 * u # 2*pi*u*180/pi 139 | 140 | 141 | def to_theta(v): 142 | return np.arccos(1 - 2 * v) * 180. 
/ np.pi 143 | 144 | 145 | def sample_on_sphere(range_u=(0, 1), range_v=(0, 1)): 146 | u = np.random.uniform(*range_u) 147 | v = np.random.uniform(*range_v) 148 | return to_sphere(u, v) 149 | 150 | 151 | def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5): 152 | at = at.astype(float).reshape(1, 3) 153 | up = up.astype(float).reshape(1, 3) 154 | 155 | eye = eye.reshape(-1, 3) 156 | up = up.repeat(eye.shape[0] // up.shape[0], axis=0) 157 | eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0) 158 | 159 | z_axis = eye - at 160 | z_axis /= np.max(np.stack([np.linalg.norm(z_axis, axis=1, keepdims=True), eps])) 161 | 162 | x_axis = np.cross(up, z_axis) 163 | x_axis /= np.max(np.stack([np.linalg.norm(x_axis, axis=1, keepdims=True), eps])) 164 | 165 | y_axis = np.cross(z_axis, x_axis) 166 | y_axis /= np.max(np.stack([np.linalg.norm(y_axis, axis=1, keepdims=True), eps])) 167 | 168 | r_mat = np.concatenate((x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(-1, 3, 1)), axis=2) 169 | 170 | return r_mat 171 | -------------------------------------------------------------------------------- /submodules/GAN_stability/.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | data 3 | *_lmdb 4 | __pycache__ 5 | *.pyc 6 | -------------------------------------------------------------------------------- /submodules/GAN_stability/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Lars Mescheder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /submodules/GAN_stability/README.md: -------------------------------------------------------------------------------- 1 | # GAN stability 2 | This repository contains the experiments in the supplementary material for the paper [Which Training Methods for GANs do actually Converge?](https://avg.is.tuebingen.mpg.de/publications/meschedericml2018). 
3 | 4 | To cite this work, please use 5 | ``` 6 | @INPROCEEDINGS{Mescheder2018ICML, 7 | author = {Lars Mescheder and Sebastian Nowozin and Andreas Geiger}, 8 | title = {Which Training Methods for GANs do actually Converge?}, 9 | booktitle = {International Conference on Machine Learning (ICML)}, 10 | year = {2018} 11 | } 12 | ``` 13 | You can find further details on [our project page](https://avg.is.tuebingen.mpg.de/research_projects/convergence-and-stability-of-gan-training). 14 | 15 | # Usage 16 | First download your data and put it into the `./data` folder. 17 | 18 | To train a new model, first create a config script similar to the ones provided in the `./configs` folder. You can then train you model using 19 | ``` 20 | python train.py PATH_TO_CONFIG 21 | ``` 22 | 23 | To compute the inception score for your model and generate samples, use 24 | ``` 25 | python test.py PATH_TO_CONFIG 26 | ``` 27 | 28 | Finally, you can create nice latent space interpolations using 29 | ``` 30 | python interpolate.py PATH_TO_CONFIG 31 | ``` 32 | or 33 | ``` 34 | python interpolate_class.py PATH_TO_CONFIG 35 | ``` 36 | 37 | # Pretrained models 38 | We also provide several pretrained models. 39 | 40 | You can use the models for sampling by entering 41 | ``` 42 | python test.py PATH_TO_CONFIG 43 | ``` 44 | where `PATH_TO_CONFIG` is one of the config files 45 | ``` 46 | configs/pretrained/celebA_pretrained.yaml 47 | configs/pretrained/celebAHQ_pretrained.yaml 48 | configs/pretrained/imagenet_pretrained.yaml 49 | configs/pretrained/lsun_bedroom_pretrained.yaml 50 | configs/pretrained/lsun_bridge_pretrained.yaml 51 | configs/pretrained/lsun_church_pretrained.yaml 52 | configs/pretrained/lsun_tower_pretrained.yaml 53 | ``` 54 | Our script will automatically download the model checkpoints and run the generation. 55 | You can find the outputs in the `output/pretrained` folders. 56 | Similarly, you can use the scripts `interpolate.py` and `interpolate_class.py` for generating interpolations for the pretrained models. 57 | 58 | Please note that the config files `*_pretrained.yaml` are only for generation, not for training new models: when these configs are used for training, the model will be trained from scratch, but during inference our code will still use the pretrained model. 59 | 60 | # Notes 61 | * Batch normalization is currently *not* supported when using an exponential running average, as the running average is only computed over the parameters of the models and not the other buffers of the model. 
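For context, below is a minimal sketch of how such an exponential running average over parameters can be maintained; the names `update_average`, `model_avg` and `model` are placeholders rather than this repository's exact API, and the sketch only illustrates why buffers (such as batch-norm running statistics) are not covered by the average.
```
import torch

def update_average(model_avg, model, beta=0.999):
    # Blend parameters only; buffers (e.g. BatchNorm running means/variances) are left
    # untouched, which is why batch normalization does not combine well with this scheme.
    with torch.no_grad():
        for p_avg, p in zip(model_avg.parameters(), model.parameters()):
            p_avg.mul_(beta).add_((1. - beta) * p)
```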
62 | 63 | # Results 64 | ## celebA-HQ 65 | ![celebA-HQ](results/celebA-HQ.jpg) 66 | 67 | ## Imagenet 68 | ![Imagenet 0](results/imagenet_00.jpg) 69 | ![Imagenet 1](results/imagenet_01.jpg) 70 | ![Imagenet 2](results/imagenet_02.jpg) 71 | ![Imagenet 3](results/imagenet_03.jpg) 72 | ![Imagenet 4](results/imagenet_04.jpg) 73 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/celebAHQ.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: npy 3 | train_dir: data/celebA-HQ 4 | test_dir: data/celebA-HQ 5 | img_size: 1024 6 | generator: 7 | name: resnet 8 | kwargs: 9 | nfilter: 16 10 | nfilter_max: 512 11 | embed_size: 1 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | nfilter: 16 16 | nfilter_max: 512 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/celebAHQ 23 | batch_size: 24 24 | test: 25 | batch_size: 4 26 | sample_size: 6 27 | sample_nrow: 3 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | nlabels: 1 9 | generator: 10 | name: resnet 11 | kwargs: 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | z_dist: 16 | type: gauss 17 | dim: 256 18 | training: 19 | out_dir: output/default 20 | gan_type: standard 21 | reg_type: real 22 | reg_param: 10. 23 | batch_size: 64 24 | nworkers: 16 25 | take_model_average: true 26 | model_average_beta: 0.999 27 | model_average_reinit: false 28 | monitoring: tensorboard 29 | sample_every: 1000 30 | sample_nlabels: 20 31 | inception_every: -1 32 | save_every: 900 33 | backup_every: 100000 34 | restart_every: -1 35 | optimizer: rmsprop 36 | lr_g: 0.0001 37 | lr_d: 0.0001 38 | lr_anneal: 1. 
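# lr_anneal is the multiplicative decay factor applied every lr_anneal_every iterations; 1. keeps the learning rate constant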
39 | lr_anneal_every: 150000 40 | d_steps: 1 41 | equalize_lr: false 42 | model_file: model.pt 43 | test: 44 | batch_size: 32 45 | sample_size: 64 46 | sample_nrow: 8 47 | use_model_average: true 48 | compute_inception: false 49 | conditional_samples: false 50 | model_file: model.pt 51 | interpolations: 52 | nzs: 10 53 | nsubsteps: 75 54 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/imagenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/Imagenet 4 | test_dir: data/Imagenet 5 | img_size: 128 6 | nlabels: 1000 7 | generator: 8 | name: resnet2 9 | kwargs: 10 | nfilter: 64 11 | nfilter_max: 1024 12 | embed_size: 256 13 | discriminator: 14 | name: resnet2 15 | kwargs: 16 | nfilter: 64 17 | nfilter_max: 1024 18 | embed_size: 256 19 | z_dist: 20 | type: gauss 21 | dim: 256 22 | training: 23 | out_dir: output/imagenet 24 | gan_type: standard 25 | sample_nlabels: 20 26 | inception_every: 10000 27 | batch_size: 128 28 | test: 29 | batch_size: 32 30 | sample_size: 64 31 | sample_nrow: 8 32 | compute_inception: true 33 | conditional_samples: true 34 | interpolations: 35 | ys: [15, 157, 307, 321, 442, 483, 484, 525, 36 | 536, 598, 607, 734, 768, 795, 927, 977, 37 | 963, 946, 979] 38 | nzs: 10 39 | nsubsteps: 75 40 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/lsun_bedroom.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_bedroom 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/lsun_bridge.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bridge_train] 6 | lsun_categories_test: [bridge_train] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_bridge 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/lsun_church.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [church_outdoor_train] 6 | lsun_categories_test: [church_outdoor_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | 
discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_church 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/lsun_tower.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [tower_train] 6 | lsun_categories_test: [tower_test] 7 | img_size: 256 8 | generator: 9 | name: resnet 10 | kwargs: 11 | nfilter: 64 12 | nfilter_max: 1024 13 | embed_size: 1 14 | discriminator: 15 | name: resnet 16 | kwargs: 17 | nfilter: 64 18 | nfilter_max: 1024 19 | embed_size: 1 20 | z_dist: 21 | type: gauss 22 | dim: 256 23 | training: 24 | out_dir: output/lsun_tower 25 | test: 26 | batch_size: 32 27 | sample_size: 64 28 | sample_nrow: 8 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/celebAHQ_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: npy 3 | train_dir: data/celebA-HQ 4 | test_dir: data/celebA-HQ 5 | img_size: 1024 6 | generator: 7 | name: resnet 8 | kwargs: 9 | nfilter: 16 10 | nfilter_max: 512 11 | embed_size: 1 12 | discriminator: 13 | name: resnet 14 | kwargs: 15 | nfilter: 16 16 | nfilter_max: 512 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/celebAHQ 23 | batch_size: 24 24 | test: 25 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt 26 | batch_size: 4 27 | sample_size: 6 28 | sample_nrow: 3 29 | interpolations: 30 | nzs: 10 31 | nsubsteps: 75 32 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/celebA_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/celebA 4 | test_dir: data/celebA 5 | img_size: 256 6 | generator: 7 | name: resnet4 8 | kwargs: 9 | nfilter: 64 10 | embed_size: 1 11 | discriminator: 12 | name: resnet4 13 | kwargs: 14 | nfilter: 64 15 | embed_size: 1 16 | z_dist: 17 | type: gauss 18 | dim: 256 19 | training: 20 | out_dir: output/pretrained/celebA 21 | test: 22 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celeba-ab478c9d.pt 23 | batch_size: 32 24 | sample_size: 64 25 | sample_nrow: 8 26 | interpolations: 27 | nzs: 10 28 | nsubsteps: 75 29 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/imagenet_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: image 3 | train_dir: data/Imagenet 4 | test_dir: data/Imagenet 5 | img_size: 128 6 | nlabels: 1000 7 | generator: 8 | name: resnet2 9 | kwargs: 10 | nfilter: 64 11 | nfilter_max: 1024 12 | embed_size: 256 13 | discriminator: 14 | name: resnet2 15 | kwargs: 16 | nfilter: 64 17 | nfilter_max: 1024 18 | embed_size: 256 19 | z_dist: 20 | type: gauss 21 | dim: 256 22 | training: 23 | out_dir: output/pretrained/imagenet 
24 | sample_nlabels: 20 25 | inception_every: 10000 26 | batch_size: 128 27 | test: 28 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/imagenet-8c505f47.pt 29 | batch_size: 32 30 | sample_size: 64 31 | sample_nrow: 8 32 | compute_inception: false 33 | conditional_samples: true 34 | interpolations: 35 | ys: [15, 157, 307, 321, 442, 483, 484, 525, 36 | 536, 598, 607, 734, 768, 795, 927, 977, 37 | 963, 946, 979] 38 | nzs: 10 39 | nsubsteps: 75 40 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/lsun_bedroom_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bedroom_train] 6 | lsun_categories_test: [bedroom_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_bedroom 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/lsun_bridge_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [bridge_train] 6 | lsun_categories_test: [bridge_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_bridge 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/lsun_church_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [church_outdoor_train] 6 | lsun_categories_test: [church_outdoor_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_church 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_church-b6f0191b.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/configs/pretrained/lsun_tower_pretrained.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: lsun 3 | 
train_dir: data/LSUN 4 | test_dir: data/LSUN 5 | lsun_categories_train: [tower_train] 6 | lsun_categories_test: [tower_test] 7 | img_size: 256 8 | generator: 9 | name: resnet3 10 | kwargs: 11 | nfilter: 64 12 | embed_size: 1 13 | discriminator: 14 | name: resnet3 15 | kwargs: 16 | nfilter: 64 17 | embed_size: 1 18 | z_dist: 19 | type: gauss 20 | dim: 256 21 | training: 22 | out_dir: output/pretrained/lsun_tower 23 | test: 24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt 25 | batch_size: 32 26 | sample_size: 64 27 | sample_nrow: 8 28 | interpolations: 29 | nzs: 10 30 | nsubsteps: 75 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/gan_training/__init__.py -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/checkpoints.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import urllib 4 | import torch 5 | from torch.utils import model_zoo 6 | 7 | 8 | class CheckpointIO(object): 9 | ''' CheckpointIO class. 10 | 11 | It handles saving and loading checkpoints. 12 | 13 | Args: 14 | checkpoint_dir (str): path where checkpoints are saved 15 | ''' 16 | def __init__(self, checkpoint_dir='./chkpts', **kwargs): 17 | self.module_dict = kwargs 18 | self.checkpoint_dir = checkpoint_dir 19 | if not os.path.exists(checkpoint_dir): 20 | os.makedirs(checkpoint_dir) 21 | 22 | def register_modules(self, **kwargs): 23 | ''' Registers modules in current module dictionary. 24 | ''' 25 | self.module_dict.update(kwargs) 26 | 27 | def save(self, filename, **kwargs): 28 | ''' Saves the current module dictionary. 29 | 30 | Args: 31 | filename (str): name of output file 32 | ''' 33 | if not os.path.isabs(filename): 34 | filename = os.path.join(self.checkpoint_dir, filename) 35 | 36 | outdict = kwargs 37 | for k, v in self.module_dict.items(): 38 | outdict[k] = v.state_dict() 39 | torch.save(outdict, filename) 40 | 41 | def load(self, filename): 42 | '''Loads a module dictionary from local file or url. 43 | 44 | Args: 45 | filename (str): name of saved module dictionary 46 | ''' 47 | if is_url(filename): 48 | return self.load_url(filename) 49 | else: 50 | return self.load_file(filename) 51 | 52 | def load_file(self, filename): 53 | '''Loads a module dictionary from file. 54 | 55 | Args: 56 | filename (str): name of saved module dictionary 57 | ''' 58 | 59 | if not os.path.isabs(filename): 60 | filename = os.path.join(self.checkpoint_dir, filename) 61 | 62 | if os.path.exists(filename): 63 | print(filename) 64 | print('=> Loading checkpoint from local file...') 65 | state_dict = torch.load(filename) 66 | scalars = self.parse_state_dict(state_dict) 67 | return scalars 68 | else: 69 | raise FileNotFoundError 70 | 71 | def load_url(self, url): 72 | '''Load a module dictionary from url. 73 | 74 | Args: 75 | url (str): url to saved model 76 | ''' 77 | print(url) 78 | print('=> Loading checkpoint from url...') 79 | state_dict = model_zoo.load_url(url, progress=True) 80 | scalars = self.parse_state_dict(state_dict) 81 | return scalars 82 | 83 | def parse_state_dict(self, state_dict): 84 | '''Parse state_dict of model and return scalars. 
85 | 86 | Args: 87 | state_dict (dict): State dict of model 88 | ''' 89 | 90 | for k, v in self.module_dict.items(): 91 | if k in state_dict: 92 | v.load_state_dict(state_dict[k]) 93 | else: 94 | print('Warning: Could not find %s in checkpoint!' % k) 95 | scalars = {k: v for k, v in state_dict.items() 96 | if k not in self.module_dict} 97 | return scalars 98 | 99 | def is_url(url): 100 | scheme = urllib.parse.urlparse(url).scheme 101 | return scheme in ('http', 'https') -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torch import optim 3 | from os import path 4 | from GAN_stability.gan_training.models import generator_dict, discriminator_dict 5 | from GAN_stability.gan_training.train import toggle_grad 6 | 7 | 8 | # General config 9 | def load_config(path, default_path): 10 | ''' Loads config file. 11 | 12 | Args: 13 | path (str): path to config file 14 | default_path (bool): whether to use default path 15 | ''' 16 | # Load configuration from file itself 17 | with open(path, 'r') as f: 18 | cfg_special = yaml.load(f) 19 | 20 | # Check if we should inherit from a config 21 | inherit_from = cfg_special.get('inherit_from') 22 | 23 | # If yes, load this config first as default 24 | # If no, use the default_path 25 | if inherit_from is not None: 26 | cfg = load_config(inherit_from, default_path) 27 | elif default_path is not None: 28 | with open(default_path, 'r') as f: 29 | cfg = yaml.load(f) 30 | else: 31 | cfg = dict() 32 | 33 | # Include main configuration 34 | update_recursive(cfg, cfg_special) 35 | 36 | return cfg 37 | 38 | 39 | def update_recursive(dict1, dict2): 40 | ''' Update two config dictionaries recursively. 
41 | 42 | Args: 43 | dict1 (dict): first dictionary to be updated 44 | dict2 (dict): second dictionary which entries should be used 45 | 46 | ''' 47 | for k, v in dict2.items(): 48 | # Add item if not yet in dict1 49 | if k not in dict1: 50 | dict1[k] = None 51 | # Update 52 | if isinstance(dict1[k], dict): 53 | update_recursive(dict1[k], v) 54 | else: 55 | dict1[k] = v 56 | 57 | 58 | def build_models(config): 59 | # Get classes 60 | Generator = generator_dict[config['generator']['name']] 61 | Discriminator = discriminator_dict[config['discriminator']['name']] 62 | 63 | # Build models 64 | generator = Generator( 65 | z_dim=config['z_dist']['dim'], 66 | nlabels=config['data']['nlabels'], 67 | size=config['data']['img_size'], 68 | **config['generator']['kwargs'] 69 | ) 70 | discriminator = Discriminator( 71 | config['discriminator']['name'], 72 | nlabels=config['data']['nlabels'], 73 | size=config['data']['img_size'], 74 | **config['discriminator']['kwargs'] 75 | ) 76 | 77 | return generator, discriminator 78 | 79 | 80 | def build_optimizers(generator, discriminator, config): 81 | optimizer = config['training']['optimizer'] 82 | lr_g = config['training']['lr_g'] 83 | lr_d = config['training']['lr_d'] 84 | equalize_lr = config['training']['equalize_lr'] 85 | 86 | toggle_grad(generator, True) 87 | toggle_grad(discriminator, True) 88 | 89 | if equalize_lr: 90 | g_gradient_scales = getattr(generator, 'gradient_scales', dict()) 91 | d_gradient_scales = getattr(discriminator, 'gradient_scales', dict()) 92 | 93 | g_params = get_parameter_groups(generator.parameters(), 94 | g_gradient_scales, 95 | base_lr=lr_g) 96 | d_params = get_parameter_groups(discriminator.parameters(), 97 | d_gradient_scales, 98 | base_lr=lr_d) 99 | else: 100 | g_params = generator.parameters() 101 | d_params = discriminator.parameters() 102 | 103 | # Optimizers 104 | if optimizer == 'rmsprop': 105 | g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8) 106 | d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8) 107 | elif optimizer == 'adam': 108 | g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(0., 0.99), eps=1e-8) 109 | d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(0., 0.99), eps=1e-8) 110 | elif optimizer == 'sgd': 111 | g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.) 112 | d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.) 113 | 114 | return g_optimizer, d_optimizer 115 | 116 | 117 | def build_lr_scheduler(optimizer, config, last_epoch=-1): 118 | lr_scheduler = optim.lr_scheduler.StepLR( 119 | optimizer, 120 | step_size=config['training']['lr_anneal_every'], 121 | gamma=config['training']['lr_anneal'], 122 | last_epoch=last_epoch 123 | ) 124 | return lr_scheduler 125 | 126 | 127 | # Some utility functions 128 | def get_parameter_groups(parameters, gradient_scales, base_lr): 129 | param_groups = [] 130 | for p in parameters: 131 | c = gradient_scales.get(p, 1.) 
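# one parameter group per parameter so its learning rate can be scaled individually (used when equalize_lr is enabled)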
132 | param_groups.append({ 133 | 'params': [p], 134 | 'lr': c * base_lr 135 | }) 136 | return param_groups 137 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | 4 | 5 | def get_zdist(dist_name, dim, device=None): 6 | # Get distribution 7 | if dist_name == 'uniform': 8 | low = -torch.ones(dim, device=device) 9 | high = torch.ones(dim, device=device) 10 | zdist = distributions.Uniform(low, high) 11 | elif dist_name == 'gauss': 12 | mu = torch.zeros(dim, device=device) 13 | scale = torch.ones(dim, device=device) 14 | zdist = distributions.Normal(mu, scale) 15 | else: 16 | raise NotImplementedError 17 | 18 | # Add dim attribute 19 | zdist.dim = dim 20 | 21 | return zdist 22 | 23 | 24 | def get_ydist(nlabels, device=None): 25 | logits = torch.zeros(nlabels, device=device) 26 | ydist = distributions.categorical.Categorical(logits=logits) 27 | 28 | # Add nlabels attribute 29 | ydist.nlabels = nlabels 30 | 31 | return ydist 32 | 33 | 34 | def interpolate_sphere(z1, z2, t): 35 | p = (z1 * z2).sum(dim=-1, keepdim=True) 36 | p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt() 37 | p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt() 38 | omega = torch.acos(p) 39 | s1 = torch.sin((1-t)*omega)/torch.sin(omega) 40 | s2 = torch.sin(t*omega)/torch.sin(omega) 41 | z = s1 * z1 + s2 * z2 42 | 43 | return z 44 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from GAN_stability.gan_training.metrics import inception_score 3 | 4 | 5 | class Evaluator(object): 6 | def __init__(self, generator, zdist, ydist, batch_size=64, 7 | inception_nsamples=60000, device=None): 8 | self.generator = generator 9 | self.zdist = zdist 10 | self.ydist = ydist 11 | self.inception_nsamples = inception_nsamples 12 | self.batch_size = batch_size 13 | self.device = device 14 | 15 | def compute_inception_score(self): 16 | self.generator.eval() 17 | imgs = [] 18 | while(len(imgs) < self.inception_nsamples): 19 | ztest = self.zdist.sample((self.batch_size,)) 20 | ytest = self.ydist.sample((self.batch_size,)) 21 | 22 | samples = self.generator(ztest, ytest) 23 | samples = [s.data.cpu().numpy() for s in samples] 24 | imgs.extend(samples) 25 | 26 | imgs = imgs[:self.inception_nsamples] 27 | score, score_std = inception_score( 28 | imgs, device=self.device, resize=True, splits=10 29 | ) 30 | 31 | return score, score_std 32 | 33 | def create_samples(self, z, y=None): 34 | self.generator.eval() 35 | batch_size = z.size(0) 36 | # Parse y 37 | if y is None: 38 | y = self.ydist.sample((batch_size,)) 39 | elif isinstance(y, int): 40 | y = torch.full((batch_size,), y, 41 | device=self.device, dtype=torch.int64) 42 | # Sample x 43 | with torch.no_grad(): 44 | x = self.generator(z, y) 45 | return x 46 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.datasets as datasets 4 | import numpy as np 5 | 6 | 7 | def get_dataset(name, data_dir, size=64, lsun_categories=None): 8 | transform = 
transforms.Compose([ 9 | transforms.Resize(size), 10 | transforms.CenterCrop(size), 11 | transforms.RandomHorizontalFlip(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 14 | transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())), 15 | ]) 16 | 17 | if name == 'image': 18 | dataset = datasets.ImageFolder(data_dir, transform) 19 | nlabels = len(dataset.classes) 20 | elif name == 'npy': 21 | # Only support normalization for now 22 | dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy']) 23 | nlabels = len(dataset.classes) 24 | elif name == 'cifar10': 25 | dataset = datasets.CIFAR10(root=data_dir, train=True, download=True, 26 | transform=transform) 27 | nlabels = 10 28 | elif name == 'lsun': 29 | if lsun_categories is None: 30 | lsun_categories = 'train' 31 | dataset = datasets.LSUN(data_dir, lsun_categories, transform) 32 | nlabels = len(dataset.classes) 33 | elif name == 'lsun_class': 34 | dataset = datasets.LSUNClass(data_dir, transform, 35 | target_transform=(lambda t: 0)) 36 | nlabels = 1 37 | else: 38 | raise NotImplemented 39 | 40 | return dataset, nlabels 41 | 42 | 43 | def npy_loader(path): 44 | img = np.load(path) 45 | 46 | if img.dtype == np.uint8: 47 | img = img.astype(np.float32) 48 | img = img/127.5 - 1. 49 | elif img.dtype == np.float32: 50 | img = img * 2 - 1. 51 | else: 52 | raise NotImplementedError 53 | 54 | img = torch.Tensor(img) 55 | if len(img.size()) == 4: 56 | img.squeeze_(0) 57 | 58 | return img 59 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/logger.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import torchvision 4 | 5 | 6 | class Logger(object): 7 | def __init__(self, log_dir='./logs', img_dir='./imgs', 8 | monitoring=None, monitoring_dir=None): 9 | self.stats = dict() 10 | self.log_dir = log_dir 11 | self.img_dir = img_dir 12 | 13 | if not os.path.exists(log_dir): 14 | os.makedirs(log_dir) 15 | 16 | if not os.path.exists(img_dir): 17 | os.makedirs(img_dir) 18 | 19 | if not (monitoring is None or monitoring == 'none'): 20 | self.setup_monitoring(monitoring, monitoring_dir) 21 | else: 22 | self.monitoring = None 23 | self.monitoring_dir = None 24 | 25 | def setup_monitoring(self, monitoring, monitoring_dir=None): 26 | self.monitoring = monitoring 27 | self.monitoring_dir = monitoring_dir 28 | 29 | if monitoring == 'telemetry': 30 | import telemetry 31 | self.tm = telemetry.ApplicationTelemetry() 32 | if self.tm.get_status() == 0: 33 | print('Telemetry successfully connected.') 34 | elif monitoring == 'tensorboard': 35 | import tensorboardX 36 | self.tb = tensorboardX.SummaryWriter(monitoring_dir) 37 | else: 38 | raise NotImplementedError('Monitoring tool "%s" not supported!' 
39 | % monitoring) 40 | 41 | def add(self, category, k, v, it): 42 | if category not in self.stats: 43 | self.stats[category] = {} 44 | 45 | if k not in self.stats[category]: 46 | self.stats[category][k] = [] 47 | 48 | self.stats[category][k].append((it, v)) 49 | 50 | k_name = '%s/%s' % (category, k) 51 | if self.monitoring == 'telemetry': 52 | self.tm.metric_push_async({ 53 | 'metric': k_name, 'value': v, 'it': it 54 | }) 55 | elif self.monitoring == 'tensorboard': 56 | self.tb.add_scalar(k_name, v, it) 57 | 58 | def add_imgs(self, imgs, class_name, it): 59 | outdir = os.path.join(self.img_dir, class_name) 60 | if not os.path.exists(outdir): 61 | os.makedirs(outdir) 62 | outfile = os.path.join(outdir, '%08d.png' % it) 63 | 64 | imgs = imgs / 2 + 0.5 65 | imgs = torchvision.utils.make_grid(imgs) 66 | torchvision.utils.save_image(imgs.clone(), outfile, nrow=8) 67 | 68 | if self.monitoring == 'tensorboard': 69 | self.tb.add_image(class_name, imgs, it) 70 | 71 | def get_last(self, category, k, default=0.): 72 | if category not in self.stats: 73 | return default 74 | elif k not in self.stats[category]: 75 | return default 76 | else: 77 | return self.stats[category][k][-1][1] 78 | 79 | def save_stats(self, filename): 80 | filename = os.path.join(self.log_dir, filename) 81 | with open(filename, 'wb') as f: 82 | pickle.dump(self.stats, f) 83 | 84 | def load_stats(self, filename): 85 | filename = os.path.join(self.log_dir, filename) 86 | if not os.path.exists(filename): 87 | print('Warning: file "%s" does not exist!' % filename) 88 | return 89 | 90 | try: 91 | with open(filename, 'rb') as f: 92 | self.stats = pickle.load(f) 93 | except EOFError: 94 | print('Warning: log file corrupted!') 95 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from GAN_stability.gan_training.metrics.inception_score import inception_score 2 | from GAN_stability.gan_training.metrics.fid_score import FIDEvaluator 3 | from GAN_stability.gan_training.metrics.kid_score import KIDEvaluator 4 | 5 | __all__ = [ 6 | inception_score, 7 | FIDEvaluator, 8 | KIDEvaluator 9 | ] 10 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | 6 | from torchvision.models.inception import inception_v3 7 | 8 | import numpy as np 9 | from scipy.stats import entropy 10 | 11 | 12 | def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1): 13 | """Computes the inception score of the generated images imgs 14 | 15 | Args: 16 | imgs: Torch dataset of (3xHxW) numpy images normalized in the 17 | range [-1, 1] 18 | cuda: whether or not to run on GPU 19 | batch_size: batch size for feeding into Inception v3 20 | splits: number of splits 21 | """ 22 | N = len(imgs) 23 | 24 | assert batch_size > 0 25 | assert N > batch_size 26 | 27 | # Set up dataloader 28 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 29 | 30 | # Load inception model 31 | inception_model = inception_v3(pretrained=True, transform_input=False) 32 | inception_model = inception_model.to(device) 33 | inception_model.eval() 34 | up = nn.Upsample(size=(299, 299), mode='bilinear').to(device) 
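# Inception v3 expects 299x299 inputs, so smaller images are upsampled here when resize=True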
35 | 36 | def get_pred(x): 37 | with torch.no_grad(): 38 | if resize: 39 | x = up(x) 40 | x = inception_model(x) 41 | out = F.softmax(x, dim=-1) 42 | out = out.cpu().numpy() 43 | return out 44 | 45 | # Get predictions 46 | preds = np.zeros((N, 1000)) 47 | 48 | for i, batch in enumerate(dataloader, 0): 49 | batchv = batch.to(device) 50 | batch_size_i = batch.size()[0] 51 | 52 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 53 | 54 | # Now compute the mean kl-div 55 | split_scores = [] 56 | 57 | for k in range(splits): 58 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 59 | py = np.mean(part, axis=0) 60 | scores = [] 61 | for i in range(part.shape[0]): 62 | pyx = part[i, :] 63 | scores.append(entropy(pyx, py)) 64 | split_scores.append(np.exp(np.mean(scores))) 65 | 66 | return np.mean(split_scores), np.std(split_scores) 67 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/models/__init__.py: -------------------------------------------------------------------------------- 1 | from GAN_stability.gan_training.models import ( 2 | resnet, resnet2, resnet3, resnet4, 3 | ) 4 | 5 | generator_dict = { 6 | 'resnet': resnet.Generator, 7 | 'resnet2': resnet2.Generator, 8 | 'resnet3': resnet3.Generator, 9 | 'resnet4': resnet4.Generator, 10 | } 11 | 12 | discriminator_dict = { 13 | 'resnet': resnet.Discriminator, 14 | 'resnet2': resnet2.Discriminator, 15 | 'resnet3': resnet3.Discriminator, 16 | 'resnet4': resnet4.Discriminator, 17 | } 18 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import numpy as np 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=512, **kwargs): 12 | super().__init__() 13 | s0 = self.s0 = 4 14 | nf = self.nf = nfilter 15 | nf_max = self.nf_max = nfilter_max 16 | 17 | self.z_dim = z_dim 18 | 19 | # Submodules 20 | nlayers = int(np.log2(size / s0)) 21 | self.nf0 = min(nf_max, nf * 2**nlayers) 22 | 23 | self.embedding = nn.Embedding(nlabels, embed_size) 24 | self.fc = nn.Linear(z_dim + embed_size, self.nf0*s0*s0) 25 | 26 | blocks = [] 27 | for i in range(nlayers): 28 | nf0 = min(nf * 2**(nlayers-i), nf_max) 29 | nf1 = min(nf * 2**(nlayers-i-1), nf_max) 30 | blocks += [ 31 | ResnetBlock(nf0, nf1), 32 | nn.Upsample(scale_factor=2) 33 | ] 34 | 35 | blocks += [ 36 | ResnetBlock(nf, nf), 37 | ] 38 | 39 | self.resnet = nn.Sequential(*blocks) 40 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) 41 | 42 | def forward(self, z, y): 43 | assert(z.size(0) == y.size(0)) 44 | batch_size = z.size(0) 45 | 46 | if y.dtype is torch.int64: 47 | yembed = self.embedding(y) 48 | else: 49 | yembed = y 50 | 51 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) 52 | 53 | yz = torch.cat([z, yembed], dim=1) 54 | out = self.fc(yz) 55 | out = out.view(batch_size, self.nf0, self.s0, self.s0) 56 | 57 | out = self.resnet(out) 58 | 59 | out = self.conv_img(actvn(out)) 60 | out = torch.tanh(out) 61 | 62 | return out 63 | 64 | 65 | class Discriminator(nn.Module): 66 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=1024): 67 | 
super().__init__() 68 | self.embed_size = embed_size 69 | s0 = self.s0 = 4 70 | nf = self.nf = nfilter 71 | nf_max = self.nf_max = nfilter_max 72 | 73 | # Submodules 74 | nlayers = int(np.log2(size / s0)) 75 | self.nf0 = min(nf_max, nf * 2**nlayers) 76 | 77 | blocks = [ 78 | ResnetBlock(nf, nf) 79 | ] 80 | 81 | for i in range(nlayers): 82 | nf0 = min(nf * 2**i, nf_max) 83 | nf1 = min(nf * 2**(i+1), nf_max) 84 | blocks += [ 85 | nn.AvgPool2d(3, stride=2, padding=1), 86 | ResnetBlock(nf0, nf1), 87 | ] 88 | 89 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 90 | self.resnet = nn.Sequential(*blocks) 91 | self.fc = nn.Linear(self.nf0*s0*s0, nlabels) 92 | 93 | def forward(self, x, y): 94 | assert(x.size(0) == y.size(0)) 95 | batch_size = x.size(0) 96 | 97 | out = self.conv_img(x) 98 | out = self.resnet(out) 99 | out = out.view(batch_size, self.nf0*self.s0*self.s0) 100 | out = self.fc(actvn(out)) 101 | 102 | index = Variable(torch.LongTensor(range(out.size(0)))) 103 | if y.is_cuda: 104 | index = index.cuda() 105 | out = out[index, y] 106 | 107 | return out 108 | 109 | 110 | class ResnetBlock(nn.Module): 111 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 112 | super().__init__() 113 | # Attributes 114 | self.is_bias = is_bias 115 | self.learned_shortcut = (fin != fout) 116 | self.fin = fin 117 | self.fout = fout 118 | if fhidden is None: 119 | self.fhidden = min(fin, fout) 120 | else: 121 | self.fhidden = fhidden 122 | 123 | # Submodules 124 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 125 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 126 | if self.learned_shortcut: 127 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 128 | 129 | def forward(self, x): 130 | x_s = self._shortcut(x) 131 | dx = self.conv_0(actvn(x)) 132 | dx = self.conv_1(actvn(dx)) 133 | out = x_s + 0.1*dx 134 | 135 | return out 136 | 137 | def _shortcut(self, x): 138 | if self.learned_shortcut: 139 | x_s = self.conv_s(x) 140 | else: 141 | x_s = x 142 | return x_s 143 | 144 | 145 | def actvn(x): 146 | out = F.leaky_relu(x, 2e-1) 147 | return out 148 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/models/resnet2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 32 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf) 21 | self.resnet_0_1 = ResnetBlock(16*nf, 16*nf) 22 | 23 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 24 | self.resnet_1_1 = ResnetBlock(16*nf, 16*nf) 25 | 26 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 27 | self.resnet_2_1 = ResnetBlock(8*nf, 8*nf) 28 | 29 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 30 | self.resnet_3_1 = ResnetBlock(4*nf, 4*nf) 31 | 32 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 33 | self.resnet_4_1 = ResnetBlock(2*nf, 2*nf) 34 | 35 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 36 | self.resnet_5_1 = ResnetBlock(1*nf, 1*nf) 37 | 
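# forward() upsamples 2x before each of the five later stages (resnet_1 ... resnet_5), growing the s0 x s0 map (s0 = size // 32) back to the full image resolution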
38 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) 39 | 40 | def forward(self, z, y): 41 | assert(z.size(0) == y.size(0)) 42 | batch_size = z.size(0) 43 | 44 | if y.dtype is torch.int64: 45 | yembed = self.embedding(y) 46 | else: 47 | yembed = y 48 | 49 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True) 50 | 51 | yz = torch.cat([z, yembed], dim=1) 52 | out = self.fc(yz) 53 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0) 54 | 55 | out = self.resnet_0_0(out) 56 | out = self.resnet_0_1(out) 57 | 58 | out = F.interpolate(out, scale_factor=2) 59 | out = self.resnet_1_0(out) 60 | out = self.resnet_1_1(out) 61 | 62 | out = F.interpolate(out, scale_factor=2) 63 | out = self.resnet_2_0(out) 64 | out = self.resnet_2_1(out) 65 | 66 | out = F.interpolate(out, scale_factor=2) 67 | out = self.resnet_3_0(out) 68 | out = self.resnet_3_1(out) 69 | 70 | out = F.interpolate(out, scale_factor=2) 71 | out = self.resnet_4_0(out) 72 | out = self.resnet_4_1(out) 73 | 74 | out = F.interpolate(out, scale_factor=2) 75 | out = self.resnet_5_0(out) 76 | out = self.resnet_5_1(out) 77 | 78 | out = self.conv_img(actvn(out)) 79 | out = torch.tanh(out) 80 | 81 | return out 82 | 83 | 84 | class Discriminator(nn.Module): 85 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 86 | super().__init__() 87 | self.embed_size = embed_size 88 | s0 = self.s0 = size // 32 89 | nf = self.nf = nfilter 90 | ny = nlabels 91 | 92 | # Submodules 93 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1) 94 | 95 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf) 96 | self.resnet_0_1 = ResnetBlock(1*nf, 2*nf) 97 | 98 | self.resnet_1_0 = ResnetBlock(2*nf, 2*nf) 99 | self.resnet_1_1 = ResnetBlock(2*nf, 4*nf) 100 | 101 | self.resnet_2_0 = ResnetBlock(4*nf, 4*nf) 102 | self.resnet_2_1 = ResnetBlock(4*nf, 8*nf) 103 | 104 | self.resnet_3_0 = ResnetBlock(8*nf, 8*nf) 105 | self.resnet_3_1 = ResnetBlock(8*nf, 16*nf) 106 | 107 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf) 108 | self.resnet_4_1 = ResnetBlock(16*nf, 16*nf) 109 | 110 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf) 111 | self.resnet_5_1 = ResnetBlock(16*nf, 16*nf) 112 | 113 | self.fc = nn.Linear(16*nf*s0*s0, nlabels) 114 | 115 | 116 | def forward(self, x, y): 117 | assert(x.size(0) == y.size(0)) 118 | batch_size = x.size(0) 119 | 120 | out = self.conv_img(x) 121 | 122 | out = self.resnet_0_0(out) 123 | out = self.resnet_0_1(out) 124 | 125 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 126 | out = self.resnet_1_0(out) 127 | out = self.resnet_1_1(out) 128 | 129 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 130 | out = self.resnet_2_0(out) 131 | out = self.resnet_2_1(out) 132 | 133 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 134 | out = self.resnet_3_0(out) 135 | out = self.resnet_3_1(out) 136 | 137 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 138 | out = self.resnet_4_0(out) 139 | out = self.resnet_4_1(out) 140 | 141 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 142 | out = self.resnet_5_0(out) 143 | out = self.resnet_5_1(out) 144 | 145 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0) 146 | out = self.fc(actvn(out)) 147 | 148 | index = Variable(torch.LongTensor(range(out.size(0)))) 149 | if y.is_cuda: 150 | index = index.cuda() 151 | out = out[index, y] 152 | 153 | return out 154 | 155 | 156 | class ResnetBlock(nn.Module): 157 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 158 | super().__init__() 159 | # Attributes 160 | self.is_bias = is_bias 161 | self.learned_shortcut = (fin != fout) 162 | self.fin = 
fin 163 | self.fout = fout 164 | if fhidden is None: 165 | self.fhidden = min(fin, fout) 166 | else: 167 | self.fhidden = fhidden 168 | 169 | # Submodules 170 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 171 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 172 | if self.learned_shortcut: 173 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 174 | 175 | 176 | def forward(self, x): 177 | x_s = self._shortcut(x) 178 | dx = self.conv_0(actvn(x)) 179 | dx = self.conv_1(actvn(dx)) 180 | out = x_s + 0.1*dx 181 | 182 | return out 183 | 184 | def _shortcut(self, x): 185 | if self.learned_shortcut: 186 | x_s = self.conv_s(x) 187 | else: 188 | x_s = x 189 | return x_s 190 | 191 | 192 | def actvn(x): 193 | out = F.leaky_relu(x, 2e-1) 194 | return out 195 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/models/resnet3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 64 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 32*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(32*nf, 16*nf) 21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 25 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 26 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3) 27 | 28 | def forward(self, z, y): 29 | assert(z.size(0) == y.size(0)) 30 | batch_size = z.size(0) 31 | 32 | yembed = self.embedding(y) 33 | yz = torch.cat([z, yembed], dim=1) 34 | out = self.fc(yz) 35 | out = out.view(batch_size, 32*self.nf, self.s0, self.s0) 36 | 37 | out = self.resnet_0_0(out) 38 | 39 | out = F.interpolate(out, scale_factor=2) 40 | out = self.resnet_1_0(out) 41 | 42 | out = F.interpolate(out, scale_factor=2) 43 | out = self.resnet_2_0(out) 44 | 45 | out = F.interpolate(out, scale_factor=2) 46 | out = self.resnet_3_0(out) 47 | 48 | out = F.interpolate(out, scale_factor=2) 49 | out = self.resnet_4_0(out) 50 | 51 | out = F.interpolate(out, scale_factor=2) 52 | out = self.resnet_5_0(out) 53 | 54 | out = F.interpolate(out, scale_factor=2) 55 | 56 | out = self.conv_img(actvn(out)) 57 | out = torch.tanh(out) 58 | 59 | return out 60 | 61 | 62 | class Discriminator(nn.Module): 63 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 64 | super().__init__() 65 | self.embed_size = embed_size 66 | s0 = self.s0 = size // 64 67 | nf = self.nf = nfilter 68 | 69 | # Submodules 70 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3) 71 | 72 | self.resnet_0_0 = ResnetBlock(1*nf, 2*nf) 73 | self.resnet_1_0 = ResnetBlock(2*nf, 4*nf) 74 | self.resnet_2_0 = ResnetBlock(4*nf, 8*nf) 75 | self.resnet_3_0 = ResnetBlock(8*nf, 16*nf) 76 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf) 77 | self.resnet_5_0 = ResnetBlock(16*nf, 32*nf) 78 | 79 | self.fc = nn.Linear(32*nf*s0*s0, nlabels) 80 | 81 | def forward(self, x, y): 82 | 
assert(x.size(0) == y.size(0)) 83 | batch_size = x.size(0) 84 | 85 | out = self.conv_img(x) 86 | 87 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 88 | out = self.resnet_0_0(out) 89 | 90 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 91 | out = self.resnet_1_0(out) 92 | 93 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 94 | out = self.resnet_2_0(out) 95 | 96 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 97 | out = self.resnet_3_0(out) 98 | 99 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 100 | out = self.resnet_4_0(out) 101 | 102 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 103 | out = self.resnet_5_0(out) 104 | 105 | out = out.view(batch_size, 32*self.nf*self.s0*self.s0) 106 | out = self.fc(actvn(out)) 107 | 108 | index = Variable(torch.LongTensor(range(out.size(0)))) 109 | if y.is_cuda: 110 | index = index.cuda() 111 | out = out[index, y] 112 | 113 | return out 114 | 115 | 116 | class ResnetBlock(nn.Module): 117 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 118 | super().__init__() 119 | # Attributes 120 | self.is_bias = is_bias 121 | self.learned_shortcut = (fin != fout) 122 | self.fin = fin 123 | self.fout = fout 124 | if fhidden is None: 125 | self.fhidden = min(fin, fout) 126 | else: 127 | self.fhidden = fhidden 128 | 129 | # Submodules 130 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 131 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 132 | if self.learned_shortcut: 133 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 134 | 135 | def forward(self, x): 136 | x_s = self._shortcut(x) 137 | dx = self.conv_0(actvn(x)) 138 | dx = self.conv_1(actvn(dx)) 139 | out = x_s + 0.1*dx 140 | 141 | return out 142 | 143 | def _shortcut(self, x): 144 | if self.learned_shortcut: 145 | x_s = self.conv_s(x) 146 | else: 147 | x_s = x 148 | return x_s 149 | 150 | 151 | def actvn(x): 152 | out = F.leaky_relu(x, 2e-1) 153 | return out 154 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/models/resnet4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 11 | super().__init__() 12 | s0 = self.s0 = size // 64 13 | nf = self.nf = nfilter 14 | self.z_dim = z_dim 15 | 16 | # Submodules 17 | self.embedding = nn.Embedding(nlabels, embed_size) 18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0) 19 | 20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf) 21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf) 22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf) 23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf) 24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf) 25 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf) 26 | self.resnet_6_0 = ResnetBlock(1*nf, 1*nf) 27 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3) 28 | 29 | 30 | def forward(self, z, y): 31 | assert(z.size(0) == y.size(0)) 32 | batch_size = z.size(0) 33 | 34 | yembed = self.embedding(y) 35 | yz = torch.cat([z, yembed], dim=1) 36 | out = self.fc(yz) 37 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0) 38 | 39 | out = self.resnet_0_0(out) 40 | 41 | out = F.interpolate(out, scale_factor=2) 42 | out = self.resnet_1_0(out) 
43 | 44 | out = F.interpolate(out, scale_factor=2) 45 | out = self.resnet_2_0(out) 46 | 47 | out = F.interpolate(out, scale_factor=2) 48 | out = self.resnet_3_0(out) 49 | 50 | out = F.interpolate(out, scale_factor=2) 51 | out = self.resnet_4_0(out) 52 | 53 | out = F.interpolate(out, scale_factor=2) 54 | out = self.resnet_5_0(out) 55 | 56 | out = F.interpolate(out, scale_factor=2) 57 | out = self.resnet_6_0(out) 58 | out = self.conv_img(actvn(out)) 59 | out = torch.tanh(out) 60 | 61 | return out 62 | 63 | 64 | class Discriminator(nn.Module): 65 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs): 66 | super().__init__() 67 | self.embed_size = embed_size 68 | s0 = self.s0 = size // 64 69 | nf = self.nf = nfilter 70 | 71 | # Submodules 72 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3) 73 | 74 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf) 75 | self.resnet_1_0 = ResnetBlock(1*nf, 2*nf) 76 | self.resnet_2_0 = ResnetBlock(2*nf, 4*nf) 77 | self.resnet_3_0 = ResnetBlock(4*nf, 8*nf) 78 | self.resnet_4_0 = ResnetBlock(8*nf, 16*nf) 79 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf) 80 | self.resnet_6_0 = ResnetBlock(16*nf, 16*nf) 81 | 82 | self.fc = nn.Linear(16*nf*s0*s0, nlabels) 83 | 84 | def forward(self, x, y): 85 | assert(x.size(0) == y.size(0)) 86 | batch_size = x.size(0) 87 | 88 | out = self.conv_img(x) 89 | out = self.resnet_0_0(out) 90 | 91 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 92 | out = self.resnet_1_0(out) 93 | 94 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 95 | out = self.resnet_2_0(out) 96 | 97 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 98 | out = self.resnet_3_0(out) 99 | 100 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 101 | out = self.resnet_4_0(out) 102 | 103 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 104 | out = self.resnet_5_0(out) 105 | 106 | out = F.avg_pool2d(out, 3, stride=2, padding=1) 107 | out = self.resnet_6_0(out) 108 | 109 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0) 110 | out = self.fc(actvn(out)) 111 | 112 | index = Variable(torch.LongTensor(range(out.size(0)))) 113 | if y.is_cuda: 114 | index = index.cuda() 115 | out = out[index, y] 116 | 117 | return out 118 | 119 | 120 | class ResnetBlock(nn.Module): 121 | def __init__(self, fin, fout, fhidden=None, is_bias=True): 122 | super().__init__() 123 | # Attributes 124 | self.is_bias = is_bias 125 | self.learned_shortcut = (fin != fout) 126 | self.fin = fin 127 | self.fout = fout 128 | if fhidden is None: 129 | self.fhidden = min(fin, fout) 130 | else: 131 | self.fhidden = fhidden 132 | 133 | # Submodules 134 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1) 135 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias) 136 | if self.learned_shortcut: 137 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) 138 | 139 | def forward(self, x): 140 | x_s = self._shortcut(x) 141 | dx = self.conv_0(actvn(x)) 142 | dx = self.conv_1(actvn(dx)) 143 | out = x_s + 0.1*dx 144 | 145 | return out 146 | 147 | def _shortcut(self, x): 148 | if self.learned_shortcut: 149 | x_s = self.conv_s(x) 150 | else: 151 | x_s = x 152 | return x_s 153 | 154 | 155 | def actvn(x): 156 | out = F.leaky_relu(x, 2e-1) 157 | return out 158 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/ops.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from 
torch.nn import Parameter 4 | 5 | 6 | class SpectralNorm(nn.Module): 7 | def __init__(self, module, name='weight', power_iterations=1): 8 | super(SpectralNorm, self).__init__() 9 | self.module = module 10 | self.name = name 11 | self.power_iterations = power_iterations 12 | if not self._made_params(): 13 | self._make_params() 14 | 15 | def _update_u_v(self): 16 | u = getattr(self.module, self.name + "_u") 17 | v = getattr(self.module, self.name + "_v") 18 | w = getattr(self.module, self.name + "_bar") 19 | 20 | height = w.data.shape[0] 21 | for _ in range(self.power_iterations): 22 | v.data = l2normalize( 23 | torch.mv(torch.t(w.view(height, -1).data), u.data)) 24 | u.data = l2normalize( 25 | torch.mv(w.view(height, -1).data, v.data)) 26 | 27 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 28 | sigma = u.dot(w.view(height, -1).mv(v)) 29 | setattr(self.module, self.name, w / sigma.expand_as(w)) 30 | 31 | def _made_params(self): 32 | made_params = ( 33 | hasattr(self.module, self.name + "_u") 34 | and hasattr(self.module, self.name + "_v") 35 | and hasattr(self.module, self.name + "_bar") 36 | ) 37 | return made_params 38 | 39 | def _make_params(self): 40 | w = getattr(self.module, self.name) 41 | 42 | height = w.data.shape[0] 43 | width = w.view(height, -1).data.shape[1] 44 | 45 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 46 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 47 | u.data = l2normalize(u.data) 48 | v.data = l2normalize(v.data) 49 | w_bar = Parameter(w.data) 50 | 51 | del self.module._parameters[self.name] 52 | 53 | self.module.register_parameter(self.name + "_u", u) 54 | self.module.register_parameter(self.name + "_v", v) 55 | self.module.register_parameter(self.name + "_bar", w_bar) 56 | 57 | def forward(self, *args): 58 | self._update_u_v() 59 | return self.module.forward(*args) 60 | 61 | 62 | def l2normalize(v, eps=1e-12): 63 | return v / (v.norm() + eps) 64 | 65 | 66 | class CBatchNorm(nn.Module): 67 | def __init__(self, nfilter, nlabels): 68 | super().__init__() 69 | # Attributes 70 | self.nlabels = nlabels 71 | self.nfilter = nfilter 72 | # Submodules 73 | self.alpha_embedding = nn.Embedding(nlabels, nfilter) 74 | self.beta_embedding = nn.Embedding(nlabels, nfilter) 75 | self.bn = nn.BatchNorm2d(nfilter, affine=False) 76 | # Initialize 77 | nn.init.constant_(self.alpha_embedding.weight, 1.) 78 | nn.init.constant_(self.beta_embedding.weight, 0.) 79 | 80 | def forward(self, x, y): 81 | dim = len(x.size()) 82 | batch_size = x.size(0) 83 | assert(dim >= 2) 84 | assert(x.size(1) == self.nfilter) 85 | 86 | s = [batch_size, self.nfilter] + [1] * (dim - 2) 87 | alpha = self.alpha_embedding(y) 88 | alpha = alpha.view(s) 89 | beta = self.beta_embedding(y) 90 | beta = beta.view(s) 91 | 92 | out = self.bn(x) 93 | out = alpha * out + beta 94 | 95 | return out 96 | 97 | 98 | class CInstanceNorm(nn.Module): 99 | def __init__(self, nfilter, nlabels): 100 | super().__init__() 101 | # Attributes 102 | self.nlabels = nlabels 103 | self.nfilter = nfilter 104 | # Submodules 105 | self.alpha_embedding = nn.Embedding(nlabels, nfilter) 106 | self.beta_embedding = nn.Embedding(nlabels, nfilter) 107 | self.bn = nn.InstanceNorm2d(nfilter, affine=False) 108 | # Initialize 109 | nn.init.uniform(self.alpha_embedding.weight, -1., 1.) 110 | nn.init.constant_(self.beta_embedding.weight, 0.) 
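# Conditional instance normalization: forward() normalizes activations with InstanceNorm2d
# (no learned affine) and then rescales them with per-label scale/offset vectors looked up
# from the alpha/beta embeddings, broadcast over the spatial dimensions.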
111 | 112 | def forward(self, x, y): 113 | dim = len(x.size()) 114 | batch_size = x.size(0) 115 | assert(dim >= 2) 116 | assert(x.size(1) == self.nfilter) 117 | 118 | s = [batch_size, self.nfilter] + [1] * (dim - 2) 119 | alpha = self.alpha_embedding(y) 120 | alpha = alpha.view(s) 121 | beta = self.beta_embedding(y) 122 | beta = beta.view(s) 123 | 124 | out = self.bn(x) 125 | out = alpha * out + beta 126 | 127 | return out 128 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | import torch.utils.data.distributed 6 | from torch import autograd 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, generator, discriminator, g_optimizer, d_optimizer, 11 | gan_type, reg_type, reg_param): 12 | self.generator = generator 13 | self.discriminator = discriminator 14 | self.g_optimizer = g_optimizer 15 | self.d_optimizer = d_optimizer 16 | 17 | self.gan_type = gan_type 18 | self.reg_type = reg_type 19 | self.reg_param = reg_param 20 | 21 | def generator_trainstep(self, y, z): 22 | assert(y.size(0) == z.size(0)) 23 | toggle_grad(self.generator, True) 24 | toggle_grad(self.discriminator, False) 25 | self.generator.train() 26 | self.discriminator.train() 27 | self.g_optimizer.zero_grad() 28 | 29 | x_fake = self.generator(z, y) 30 | d_fake = self.discriminator(x_fake, y) 31 | gloss = self.compute_loss(d_fake, 1) 32 | gloss.backward() 33 | 34 | self.g_optimizer.step() 35 | 36 | return gloss.item() 37 | 38 | def discriminator_trainstep(self, x_real, y, z): 39 | toggle_grad(self.generator, False) 40 | toggle_grad(self.discriminator, True) 41 | self.generator.train() 42 | self.discriminator.train() 43 | self.d_optimizer.zero_grad() 44 | 45 | # On real data 46 | x_real.requires_grad_() 47 | 48 | d_real = self.discriminator(x_real, y) 49 | dloss_real = self.compute_loss(d_real, 1) 50 | 51 | if self.reg_type == 'real' or self.reg_type == 'real_fake': 52 | dloss_real.backward(retain_graph=True) 53 | reg = self.reg_param * compute_grad2(d_real, x_real).mean() 54 | reg.backward() 55 | else: 56 | dloss_real.backward() 57 | 58 | # On fake data 59 | with torch.no_grad(): 60 | x_fake = self.generator(z, y) 61 | 62 | x_fake.requires_grad_() 63 | d_fake = self.discriminator(x_fake, y) 64 | dloss_fake = self.compute_loss(d_fake, 0) 65 | 66 | if self.reg_type == 'fake' or self.reg_type == 'real_fake': 67 | dloss_fake.backward(retain_graph=True) 68 | reg = self.reg_param * compute_grad2(d_fake, x_fake).mean() 69 | reg.backward() 70 | else: 71 | dloss_fake.backward() 72 | 73 | if self.reg_type == 'wgangp': 74 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y) 75 | reg.backward() 76 | elif self.reg_type == 'wgangp0': 77 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y, center=0.) 78 | reg.backward() 79 | 80 | self.d_optimizer.step() 81 | 82 | toggle_grad(self.discriminator, False) 83 | 84 | # Output 85 | dloss = (dloss_real + dloss_fake) 86 | 87 | if self.reg_type == 'none': 88 | reg = torch.tensor(0.) 
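# Return plain Python floats for logging: dloss is the sum of the real and fake terms, and
# reg is the gradient penalty on real/fake samples ('real'/'fake'/'real_fake'), the WGAN-GP
# term ('wgangp'/'wgangp0'), or zero when no regularizer is configured.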
89 | 90 | return dloss.item(), reg.item() 91 | 92 | def compute_loss(self, d_outs, target): 93 | 94 | d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs 95 | loss = 0 96 | 97 | for d_out in d_outs: 98 | 99 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 100 | 101 | if self.gan_type == 'standard': 102 | loss += F.binary_cross_entropy_with_logits(d_out, targets) 103 | elif self.gan_type == 'wgan': 104 | loss += (2*target - 1) * d_out.mean() 105 | else: 106 | raise NotImplementedError 107 | 108 | return loss / len(d_outs) 109 | 110 | def wgan_gp_reg(self, x_real, x_fake, y, center=1.): 111 | batch_size = y.size(0) 112 | eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1) 113 | x_interp = (1 - eps) * x_real + eps * x_fake 114 | x_interp = x_interp.detach() 115 | x_interp.requires_grad_() 116 | d_out = self.discriminator(x_interp, y) 117 | 118 | reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean() 119 | 120 | return reg 121 | 122 | 123 | # Utility functions 124 | def toggle_grad(model, requires_grad): 125 | for p in model.parameters(): 126 | p.requires_grad_(requires_grad) 127 | 128 | 129 | def compute_grad2(d_outs, x_in): 130 | d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs 131 | reg = 0 132 | for d_out in d_outs: 133 | batch_size = x_in.size(0) 134 | grad_dout = autograd.grad( 135 | outputs=d_out.sum(), inputs=x_in, 136 | create_graph=True, retain_graph=True, only_inputs=True 137 | )[0] 138 | grad_dout2 = grad_dout.pow(2) 139 | assert(grad_dout2.size() == x_in.size()) 140 | reg += grad_dout2.view(batch_size, -1).sum(1) 141 | return reg / len(d_outs) 142 | 143 | 144 | def update_average(model_tgt, model_src, beta): 145 | toggle_grad(model_src, False) 146 | toggle_grad(model_tgt, False) 147 | 148 | param_dict_src = dict(model_src.named_parameters()) 149 | 150 | for p_name, p_tgt in model_tgt.named_parameters(): 151 | p_src = param_dict_src[p_name] 152 | assert(p_src is not p_tgt) 153 | p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src) 154 | -------------------------------------------------------------------------------- /submodules/GAN_stability/gan_training/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.utils.data 4 | import torch.utils.data.distributed 5 | import torchvision 6 | 7 | 8 | def save_images(imgs, outfile, nrow=8): 9 | imgs = imgs / 2 + 0.5 # unnormalize 10 | torchvision.utils.save_image(imgs, outfile, nrow=nrow) 11 | 12 | 13 | def get_nsamples(data_loader, N): 14 | x = [] 15 | y = [] 16 | n = 0 17 | while n < N: 18 | x_next, y_next = next(iter(data_loader)) 19 | x.append(x_next) 20 | y.append(y_next) 21 | n += x_next.size(0) 22 | x = torch.cat(x, dim=0)[:N] 23 | y = torch.cat(y, dim=0)[:N] 24 | return x, y 25 | 26 | 27 | def update_average(model_tgt, model_src, beta): 28 | param_dict_src = dict(model_src.named_parameters()) 29 | 30 | for p_name, p_tgt in model_tgt.named_parameters(): 31 | p_src = param_dict_src[p_name] 32 | assert(p_src is not p_tgt) 33 | p_tgt.copy_(beta*p_tgt + (1. 
- beta)*p_src) 34 | -------------------------------------------------------------------------------- /submodules/GAN_stability/interpolate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere 11 | from gan_training.config import ( 12 | load_config, build_models 13 | ) 14 | 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Create interpolations for a trained GAN.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | 22 | args = parser.parse_args() 23 | 24 | config = load_config(args.config, 'configs/default.yaml') 25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 26 | 27 | # Shorthands 28 | nlabels = config['data']['nlabels'] 29 | out_dir = config['training']['out_dir'] 30 | batch_size = config['test']['batch_size'] 31 | sample_size = config['test']['sample_size'] 32 | sample_nrow = config['test']['sample_nrow'] 33 | checkpoint_dir = path.join(out_dir, 'chkpts') 34 | interp_dir = path.join(out_dir, 'test', 'interp') 35 | 36 | # Creat missing directories 37 | if not path.exists(interp_dir): 38 | os.makedirs(interp_dir) 39 | 40 | # Logger 41 | checkpoint_io = CheckpointIO( 42 | checkpoint_dir=checkpoint_dir 43 | ) 44 | 45 | # Get model file 46 | model_file = config['test']['model_file'] 47 | 48 | # Models 49 | device = torch.device("cuda:0" if is_cuda else "cpu") 50 | 51 | generator, discriminator = build_models(config) 52 | print(generator) 53 | print(discriminator) 54 | 55 | # Put models on gpu if needed 56 | generator = generator.to(device) 57 | discriminator = discriminator.to(device) 58 | 59 | # Use multiple GPUs if possible 60 | generator = nn.DataParallel(generator) 61 | discriminator = nn.DataParallel(discriminator) 62 | 63 | # Register modules to checkpoint 64 | checkpoint_io.register_modules( 65 | generator=generator, 66 | discriminator=discriminator, 67 | ) 68 | 69 | # Test generator 70 | if config['test']['use_model_average']: 71 | generator_test = copy.deepcopy(generator) 72 | checkpoint_io.register_modules(generator_test=generator_test) 73 | else: 74 | generator_test = generator 75 | 76 | # Distributions 77 | ydist = get_ydist(nlabels, device=device) 78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 79 | device=device) 80 | 81 | 82 | # Load checkpoint if existant 83 | load_dict = checkpoint_io.load(model_file) 84 | it = load_dict.get('it', -1) 85 | epoch_idx = load_dict.get('epoch_idx', -1) 86 | 87 | # Interpolations 88 | print('Creating interplations...') 89 | nsteps = config['interpolations']['nzs'] 90 | nsubsteps = config['interpolations']['nsubsteps'] 91 | 92 | y = ydist.sample((sample_size,)) 93 | zs = [zdist.sample((sample_size,)) for i in range(nsteps)] 94 | ts = np.linspace(0, 1, nsubsteps) 95 | 96 | it = 0 97 | for z1, z2 in zip(zs, zs[1:] + [zs[0]]): 98 | for t in ts: 99 | z = interpolate_sphere(z1, z2, float(t)) 100 | with torch.no_grad(): 101 | x = generator_test(z, y) 102 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it), 103 | nrow=sample_nrow) 104 | it += 1 105 | print('%d/%d done!' 
% (it, nsteps * nsubsteps)) 106 | -------------------------------------------------------------------------------- /submodules/GAN_stability/interpolate_class.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere 11 | from gan_training.config import ( 12 | load_config, build_models 13 | ) 14 | 15 | # Arguments 16 | parser = argparse.ArgumentParser( 17 | description='Create interpolations for a trained GAN.' 18 | ) 19 | parser.add_argument('config', type=str, help='Path to config file.') 20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 21 | 22 | args = parser.parse_args() 23 | 24 | config = load_config(args.config, 'configs/default.yaml') 25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 26 | 27 | # Shorthands 28 | nlabels = config['data']['nlabels'] 29 | out_dir = config['training']['out_dir'] 30 | batch_size = config['test']['batch_size'] 31 | sample_size = config['test']['sample_size'] 32 | sample_nrow = config['test']['sample_nrow'] 33 | checkpoint_dir = path.join(out_dir, 'chkpts') 34 | interp_dir = path.join(out_dir, 'test', 'interp_class') 35 | 36 | # Creat missing directories 37 | if not path.exists(interp_dir): 38 | os.makedirs(interp_dir) 39 | 40 | # Logger 41 | checkpoint_io = CheckpointIO( 42 | checkpoint_dir=checkpoint_dir 43 | ) 44 | 45 | # Get model file 46 | model_file = config['test']['model_file'] 47 | 48 | # Models 49 | device = torch.device("cuda:0" if is_cuda else "cpu") 50 | 51 | generator, discriminator = build_models(config) 52 | print(generator) 53 | print(discriminator) 54 | 55 | # Put models on gpu if needed 56 | generator = generator.to(device) 57 | discriminator = discriminator.to(device) 58 | 59 | # Use multiple GPUs if possible 60 | generator = nn.DataParallel(generator) 61 | discriminator = nn.DataParallel(discriminator) 62 | 63 | # Register modules to checkpoint 64 | checkpoint_io.register_modules( 65 | generator=generator, 66 | discriminator=discriminator, 67 | ) 68 | 69 | # Test generator 70 | if config['test']['use_model_average']: 71 | generator_test = copy.deepcopy(generator) 72 | checkpoint_io.register_modules(generator_test=generator_test) 73 | else: 74 | generator_test = generator 75 | 76 | # Distributions 77 | ydist = get_ydist(nlabels, device=device) 78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 79 | device=device) 80 | 81 | 82 | # Load checkpoint if existant 83 | load_dict = checkpoint_io.load(model_file) 84 | it = load_dict.get('it', -1) 85 | epoch_idx = load_dict.get('epoch_idx', -1) 86 | 87 | # Interpolations 88 | print('Creating interplations...') 89 | nsubsteps = config['interpolations']['nsubsteps'] 90 | ys = config['interpolations']['ys'] 91 | 92 | nsteps = len(ys) 93 | z = zdist.sample((sample_size,)) 94 | ts = np.linspace(0, 1, nsubsteps) 95 | 96 | it = 0 97 | for y1, y2 in zip(ys, ys[1:] + [ys[0]]): 98 | for t in ts: 99 | y1_pt = torch.full((sample_size,), y1, dtype=torch.int64, device=device) 100 | y2_pt = torch.full((sample_size,), y2, dtype=torch.int64, device=device) 101 | y1_embed = generator_test.module.embedding(y1_pt) 102 | y2_embed = generator_test.module.embedding(y2_pt) 103 | t = float(t) 104 | y_embed = (1 - t) * 
y1_embed + t * y2_embed 105 | with torch.no_grad(): 106 | x = generator_test(z, y_embed) 107 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it), 108 | nrow=sample_nrow) 109 | it += 1 110 | print('%d/%d done!' % (it, nsteps * nsubsteps)) 111 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/create_video.sh: -------------------------------------------------------------------------------- 1 | EXEC=ffmpeg 2 | declare -a FOLDERS=( 3 | "simgd" 4 | "altgd1" 5 | "altgd5" 6 | ) 7 | declare -a SUBFOLDERS=( 8 | "gan" 9 | "gan_consensus" 10 | "gan_gradpen" 11 | "gan_gradpen_critical" 12 | "gan_instnoise" 13 | "nsgan" 14 | "nsgan_gradpen" 15 | "wgan" 16 | "wgan_gp" 17 | ) 18 | OPTIONS="-y" 19 | 20 | cd ./out 21 | for FOLDER in ${FOLDERS[@]}; do 22 | for SUBFOLDER in ${SUBFOLDERS[@]}; do 23 | INPUT="$FOLDER/animations/$SUBFOLDER/%06d.png" 24 | OUTPUT="$FOLDER/animations/$SUBFOLDER.mp4" 25 | $EXEC -framerate 30 -i $INPUT $OPTIONS $OUTPUT 26 | echo $FOLDER 27 | echo $SUBFOLDER 28 | done 29 | 30 | done 31 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/notebooks/diracgan/__init__.py -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/gans.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from diracgan.util import sigmoid, clip 3 | 4 | 5 | class VectorField(object): 6 | def __call__(self, theta, psi): 7 | theta_isfloat = isinstance(theta, float) 8 | psi_isfloat = isinstance(psi, float) 9 | if theta_isfloat: 10 | theta = np.array([theta]) 11 | if psi_isfloat: 12 | psi = np.array([psi]) 13 | 14 | v1, v2 = self._get_vector(theta, psi) 15 | 16 | if theta_isfloat: 17 | v1 = v1[0] 18 | if psi_isfloat: 19 | v2 = v2[0] 20 | 21 | return v1, v2 22 | 23 | def postprocess(self, theta, psi): 24 | theta_isfloat = isinstance(theta, float) 25 | psi_isfloat = isinstance(psi, float) 26 | if theta_isfloat: 27 | theta = np.array([theta]) 28 | if psi_isfloat: 29 | psi = np.array([psi]) 30 | theta, psi = self._postprocess(theta, psi) 31 | if theta_isfloat: 32 | theta = theta[0] 33 | if psi_isfloat: 34 | psi = psi[0] 35 | 36 | return theta, psi 37 | 38 | def step_sizes(self, h): 39 | return h, h 40 | 41 | def _get_vector(self, theta, psi): 42 | raise NotImplemented 43 | 44 | def _postprocess(self, theta, psi): 45 | return theta, psi 46 | 47 | 48 | # GANs 49 | def fp(x): 50 | return sigmoid(-x) 51 | 52 | 53 | def fp2(x): 54 | return -sigmoid(-x) * sigmoid(x) 55 | 56 | 57 | class GAN(VectorField): 58 | def _get_vector(self, theta, psi): 59 | v1 = -psi * fp(psi*theta) 60 | v2 = theta * fp(psi*theta) 61 | return v1, v2 62 | 63 | 64 | class NSGAN(VectorField): 65 | def _get_vector(self, theta, psi): 66 | v1 = -psi * fp(-psi*theta) 67 | v2 = theta * fp(psi*theta) 68 | return v1, v2 69 | 70 | 71 | class WGAN(VectorField): 72 | def __init__(self, clip=0.3): 73 | super().__init__() 74 | self.clip = clip 75 | 76 | def _get_vector(self, theta, psi): 77 | v1 = -psi 78 | v2 = theta 79 | 80 | return v1, v2 81 | 82 | def _postprocess(self, theta, psi): 83 | psi = clip(psi, self.clip) 84 | return theta, psi 85 | 86 | 87 | class 
WGAN_GP(VectorField): 88 | def __init__(self, reg=1., target=0.3): 89 | super().__init__() 90 | self.reg = reg 91 | self.target = target 92 | 93 | def _get_vector(self, theta, psi): 94 | v1 = -psi 95 | v2 = theta - self.reg * (np.abs(psi) - self.target) * np.sign(psi) 96 | return v1, v2 97 | 98 | 99 | class GAN_InstNoise(VectorField): 100 | def __init__(self, std=1): 101 | self.std = std 102 | 103 | def _get_vector(self, theta, psi): 104 | theta_eps = ( 105 | theta + self.std*np.random.randn(*([1000] + list(theta.shape))) 106 | ) 107 | x_eps = ( 108 | self.std * np.random.randn(*([1000] + list(theta.shape))) 109 | ) 110 | v1 = -psi * fp(psi*theta_eps) 111 | v2 = theta_eps * fp(psi*theta_eps) - x_eps * fp(-x_eps * psi) 112 | v1 = v1.mean(axis=0) 113 | v2 = v2.mean(axis=0) 114 | return v1, v2 115 | 116 | 117 | class GAN_GradPenalty(VectorField): 118 | def __init__(self, reg=0.3): 119 | self.reg = reg 120 | 121 | def _get_vector(self, theta, psi): 122 | v1 = -psi * fp(psi*theta) 123 | v2 = +theta * fp(psi*theta) - self.reg * psi 124 | return v1, v2 125 | 126 | 127 | class NSGAN_GradPenalty(VectorField): 128 | def __init__(self, reg=0.3): 129 | self.reg = reg 130 | 131 | def _get_vector(self, theta, psi): 132 | v1 = -psi * fp(-psi*theta) 133 | v2 = theta * fp(psi*theta) - self.reg * psi 134 | return v1, v2 135 | 136 | 137 | class GAN_Consensus(VectorField): 138 | def __init__(self, reg=0.3): 139 | self.reg = reg 140 | 141 | def _get_vector(self, theta, psi): 142 | v1 = -psi * fp(psi*theta) 143 | v2 = +theta * fp(psi*theta) 144 | 145 | # L 0.5*(psi**2 + theta**2)*f(psi*theta)**2 146 | v1reg = ( 147 | theta * fp(psi*theta)**2 148 | + 0.5*psi * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta) 149 | ) 150 | v2reg = ( 151 | psi * fp(psi*theta)**2 152 | + 0.5*theta * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta) 153 | ) 154 | v1 -= self.reg * v1reg 155 | v2 -= self.reg * v2reg 156 | 157 | return v1, v2 158 | 159 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import matplotlib.patches as patches 4 | import os 5 | from diracgan.gans import WGAN 6 | from diracgan.subplots import vector_field_plot 7 | from tqdm import tqdm 8 | 9 | 10 | def plot_vector(vecfn, theta, psi, outfile, trajectory=None, marker='b^'): 11 | fig, ax = plt.subplots(1, 1) 12 | theta, psi = np.meshgrid(theta, psi) 13 | v1, v2 = vecfn(theta, psi) 14 | if isinstance(vecfn, WGAN): 15 | clip_y = vecfn.clip 16 | else: 17 | clip_y = None 18 | vector_field_plot(theta, psi, v1, v2, trajectory, clip_y=clip_y, marker=marker) 19 | plt.savefig(outfile, bbox_inches='tight') 20 | plt.show() 21 | 22 | 23 | def simulate_trajectories(vecfn, theta, psi, trajectory, outfolder, maxframes=300): 24 | if not os.path.exists(outfolder): 25 | os.makedirs(outfolder) 26 | theta, psi = np.meshgrid(theta, psi) 27 | 28 | N = min(len(trajectory[0]), maxframes) 29 | 30 | v1, v2 = vecfn(theta, psi) 31 | if isinstance(vecfn, WGAN): 32 | clip_y = vecfn.clip 33 | else: 34 | clip_y = None 35 | 36 | for i in tqdm(range(1, N)): 37 | fig, (ax1, ax2) = plt.subplots(1, 2, 38 | subplot_kw=dict(adjustable='box', aspect=0.7)) 39 | 40 | plt.sca(ax1) 41 | trajectory_i = [trajectory[0][:i], trajectory[1][:i]] 42 | vector_field_plot(theta, psi, v1, v2, trajectory_i, clip_y=clip_y, marker='b-') 43 | plt.plot(trajectory_i[0][-1], 
trajectory_i[1][-1], 'bo') 44 | 45 | plt.sca(ax2) 46 | ax2.set_axisbelow(True) 47 | plt.grid() 48 | 49 | x = np.linspace(np.min(theta), np.max(theta)) 50 | y = x*trajectory[1][i] 51 | plt.plot(x, y, 'C1') 52 | 53 | ax2.add_patch(patches.Rectangle( 54 | (-0.05, 0), .1, 2.5, facecolor='C2' 55 | )) 56 | 57 | ax2.add_patch(patches.Rectangle( 58 | (trajectory[0][i]-0.05, 0), .1, 2.5, facecolor='C0' 59 | )) 60 | 61 | plt.xlim(np.min(theta), np.max(theta)) 62 | plt.ylim(-1, 3.) 63 | plt.xlabel(r'$\theta$') 64 | plt.xticks(np.linspace(np.min(theta), np.max(theta), 5)) 65 | ax2.set_yticklabels([]) 66 | 67 | plt.savefig(os.path.join(outfolder, '%06d.png' % i), dpi=200, bbox_inches='tight') 68 | plt.close() 69 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/simulate.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Simulate 4 | def trajectory_simgd(vec_fn, theta0, psi0, 5 | nsteps=50, hs_g=0.1, hs_d=0.1): 6 | theta, psi = vec_fn.postprocess(theta0, psi0) 7 | thetas, psis = [theta], [psi] 8 | 9 | if isinstance(hs_g, float): 10 | hs_g = [hs_g] * nsteps 11 | if isinstance(hs_d, float): 12 | hs_d = [hs_d] * nsteps 13 | assert(len(hs_g) == nsteps) 14 | assert(len(hs_d) == nsteps) 15 | 16 | for h_g, h_d in zip(hs_g, hs_d): 17 | v1, v2 = vec_fn(theta, psi) 18 | theta += h_g * v1 19 | psi += h_d * v2 20 | theta, psi = vec_fn.postprocess(theta, psi) 21 | thetas.append(theta) 22 | psis.append(psi) 23 | 24 | return thetas, psis 25 | 26 | 27 | def trajectory_altgd(vec_fn, theta0, psi0, 28 | nsteps=50, hs_g=0.1, hs_d=0.1, gsteps=1, dsteps=1): 29 | theta, psi = vec_fn.postprocess(theta0, psi0) 30 | thetas, psis = [theta], [psi] 31 | 32 | if isinstance(hs_g, float): 33 | hs_g = [hs_g] * nsteps 34 | if isinstance(hs_d, float): 35 | hs_d = [hs_d] * nsteps 36 | assert(len(hs_g) == nsteps) 37 | assert(len(hs_d) == nsteps) 38 | 39 | for h_g, h_d in zip(hs_g, hs_d): 40 | for it in range(gsteps): 41 | v1, v2 = vec_fn(theta, psi) 42 | theta += h_g * v1 43 | theta, psi = vec_fn.postprocess(theta, psi) 44 | 45 | for it in range(dsteps): 46 | v1, v2 = vec_fn(theta, psi) 47 | psi += h_d * v2 48 | theta, psi = vec_fn.postprocess(theta, psi) 49 | thetas.append(theta) 50 | psis.append(psi) 51 | 52 | return thetas, psis 53 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/subplots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | 5 | def arrow_plot(x, y, color='C1'): 6 | plt.quiver(x[:-1], y[:-1], x[1:]-x[:-1], y[1:]-y[:-1], 7 | color=color, scale_units='xy', angles='xy', scale=1) 8 | 9 | 10 | def vector_field_plot(theta, psi, v1, v2, trajectory=None, clip_y=None, marker='b^'): 11 | plt.quiver(theta, psi, v1, v2) 12 | if clip_y is not None: 13 | plt.axhspan(np.min(psi), -clip_y, facecolor='0.2', alpha=0.5) 14 | plt.plot([np.min(theta), np.max(theta)], [-clip_y, -clip_y], 'k-') 15 | plt.axhspan(clip_y, np.max(psi), facecolor='0.2', alpha=0.5) 16 | plt.plot([np.min(theta), np.max(theta)], [clip_y, clip_y], 'k-') 17 | 18 | if trajectory is not None: 19 | psis, thetas = trajectory 20 | plt.plot(psis, thetas, marker, markerfacecolor='None') 21 | plt.plot(psis[0], thetas[0], 'ro') 22 | 23 | plt.xlim(np.min(theta), np.max(theta)) 24 | plt.ylim(np.min(psi), np.max(psi)) 25 | plt.xlabel(r'$\theta$') 26 | plt.ylabel(r'$\psi$') 27 | 
plt.xticks(np.linspace(np.min(theta), np.max(theta), 5)) 28 | plt.yticks(np.linspace(np.min(psi), np.max(psi), 5)) 29 | -------------------------------------------------------------------------------- /submodules/GAN_stability/notebooks/diracgan/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def sigmoid(x): 4 | m = np.minimum(0, x) 5 | return np.exp(m)/(np.exp(m) + np.exp(-x + m)) 6 | 7 | 8 | def clip(x, clipval=0.3): 9 | x = np.clip(x, -clipval, clipval) 10 | return x 11 | -------------------------------------------------------------------------------- /submodules/GAN_stability/results/celebA-HQ.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/celebA-HQ.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/results/imagenet_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_00.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/results/imagenet_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_01.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/results/imagenet_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_02.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/results/imagenet_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_03.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/results/imagenet_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_04.jpg -------------------------------------------------------------------------------- /submodules/GAN_stability/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import copy 5 | from tqdm import tqdm 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.checkpoints import CheckpointIO 10 | from gan_training.distributions import get_ydist, get_zdist 11 | from gan_training.eval import Evaluator 12 | from gan_training.config import ( 13 | load_config, build_models 14 | ) 15 | 16 | # Arguments 17 | parser = argparse.ArgumentParser( 18 | description='Test a trained GAN and create visualizations.' 
19 | ) 20 | parser.add_argument('config', type=str, help='Path to config file.') 21 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 22 | 23 | args = parser.parse_args() 24 | 25 | config = load_config(args.config, 'configs/default.yaml') 26 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 27 | 28 | # Shorthands 29 | nlabels = config['data']['nlabels'] 30 | out_dir = config['training']['out_dir'] 31 | batch_size = config['test']['batch_size'] 32 | sample_size = config['test']['sample_size'] 33 | sample_nrow = config['test']['sample_nrow'] 34 | checkpoint_dir = path.join(out_dir, 'chkpts') 35 | img_dir = path.join(out_dir, 'test', 'img') 36 | img_all_dir = path.join(out_dir, 'test', 'img_all') 37 | 38 | # Creat missing directories 39 | if not path.exists(img_dir): 40 | os.makedirs(img_dir) 41 | if not path.exists(img_all_dir): 42 | os.makedirs(img_all_dir) 43 | 44 | # Logger 45 | checkpoint_io = CheckpointIO( 46 | checkpoint_dir=checkpoint_dir 47 | ) 48 | 49 | # Get model file 50 | model_file = config['test']['model_file'] 51 | 52 | # Models 53 | device = torch.device("cuda:0" if is_cuda else "cpu") 54 | 55 | generator, discriminator = build_models(config) 56 | print(generator) 57 | print(discriminator) 58 | 59 | # Put models on gpu if needed 60 | generator = generator.to(device) 61 | discriminator = discriminator.to(device) 62 | 63 | # Use multiple GPUs if possible 64 | generator = nn.DataParallel(generator) 65 | discriminator = nn.DataParallel(discriminator) 66 | 67 | # Register modules to checkpoint 68 | checkpoint_io.register_modules( 69 | generator=generator, 70 | discriminator=discriminator, 71 | ) 72 | 73 | # Test generator 74 | if config['test']['use_model_average']: 75 | generator_test = copy.deepcopy(generator) 76 | checkpoint_io.register_modules(generator_test=generator_test) 77 | else: 78 | generator_test = generator 79 | 80 | # Distributions 81 | ydist = get_ydist(nlabels, device=device) 82 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 83 | device=device) 84 | 85 | # Evaluator 86 | evaluator = Evaluator(generator_test, zdist, ydist, 87 | batch_size=batch_size, device=device) 88 | 89 | # Load checkpoint if existant 90 | load_dict = checkpoint_io.load(model_file) 91 | it = load_dict.get('it', -1) 92 | epoch_idx = load_dict.get('epoch_idx', -1) 93 | 94 | # Inception score 95 | if config['test']['compute_inception']: 96 | print('Computing inception score...') 97 | inception_mean, inception_std = evaluator.compute_inception_score() 98 | print('Inception score: %.4f +- %.4f' % (inception_mean, inception_std)) 99 | 100 | # Samples 101 | print('Creating samples...') 102 | ztest = zdist.sample((sample_size,)) 103 | x = evaluator.create_samples(ztest) 104 | utils.save_images(x, path.join(img_all_dir, '%08d.png' % it), 105 | nrow=sample_nrow) 106 | if config['test']['conditional_samples']: 107 | for y_inst in tqdm(range(nlabels)): 108 | x = evaluator.create_samples(ztest, y_inst) 109 | utils.save_images(x, path.join(img_dir, '%04d.png' % y_inst), 110 | nrow=sample_nrow) 111 | -------------------------------------------------------------------------------- /submodules/GAN_stability/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path 4 | import time 5 | import copy 6 | import torch 7 | from torch import nn 8 | from gan_training import utils 9 | from gan_training.train import Trainer, update_average 10 | from gan_training.logger 
import Logger 11 | from gan_training.checkpoints import CheckpointIO 12 | from gan_training.inputs import get_dataset 13 | from gan_training.distributions import get_ydist, get_zdist 14 | from gan_training.eval import Evaluator 15 | from gan_training.config import ( 16 | load_config, build_models, build_optimizers, build_lr_scheduler, 17 | ) 18 | 19 | # Arguments 20 | parser = argparse.ArgumentParser( 21 | description='Train a GAN with different regularization strategies.' 22 | ) 23 | parser.add_argument('config', type=str, help='Path to config file.') 24 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.') 25 | 26 | args = parser.parse_args() 27 | 28 | config = load_config(args.config, 'configs/default.yaml') 29 | is_cuda = (torch.cuda.is_available() and not args.no_cuda) 30 | 31 | # Short hands 32 | batch_size = config['training']['batch_size'] 33 | d_steps = config['training']['d_steps'] 34 | restart_every = config['training']['restart_every'] 35 | inception_every = config['training']['inception_every'] 36 | save_every = config['training']['save_every'] 37 | backup_every = config['training']['backup_every'] 38 | sample_nlabels = config['training']['sample_nlabels'] 39 | 40 | out_dir = config['training']['out_dir'] 41 | checkpoint_dir = path.join(out_dir, 'chkpts') 42 | 43 | # Create missing directories 44 | if not path.exists(out_dir): 45 | os.makedirs(out_dir) 46 | if not path.exists(checkpoint_dir): 47 | os.makedirs(checkpoint_dir) 48 | 49 | # Logger 50 | checkpoint_io = CheckpointIO( 51 | checkpoint_dir=checkpoint_dir 52 | ) 53 | 54 | device = torch.device("cuda:0" if is_cuda else "cpu") 55 | 56 | 57 | # Dataset 58 | train_dataset, nlabels = get_dataset( 59 | name=config['data']['type'], 60 | data_dir=config['data']['train_dir'], 61 | size=config['data']['img_size'], 62 | lsun_categories=config['data']['lsun_categories_train'] 63 | ) 64 | train_loader = torch.utils.data.DataLoader( 65 | train_dataset, 66 | batch_size=batch_size, 67 | num_workers=config['training']['nworkers'], 68 | shuffle=True, pin_memory=True, sampler=None, drop_last=True 69 | ) 70 | 71 | # Number of labels 72 | nlabels = min(nlabels, config['data']['nlabels']) 73 | sample_nlabels = min(nlabels, sample_nlabels) 74 | 75 | # Create models 76 | generator, discriminator = build_models(config) 77 | print(generator) 78 | print(discriminator) 79 | 80 | # Put models on gpu if needed 81 | generator = generator.to(device) 82 | discriminator = discriminator.to(device) 83 | 84 | g_optimizer, d_optimizer = build_optimizers( 85 | generator, discriminator, config 86 | ) 87 | 88 | # Use multiple GPUs if possible 89 | generator = nn.DataParallel(generator) 90 | discriminator = nn.DataParallel(discriminator) 91 | 92 | # Register modules to checkpoint 93 | checkpoint_io.register_modules( 94 | generator=generator, 95 | discriminator=discriminator, 96 | g_optimizer=g_optimizer, 97 | d_optimizer=d_optimizer, 98 | ) 99 | 100 | # Get model file 101 | model_file = config['training']['model_file'] 102 | 103 | # Logger 104 | logger = Logger( 105 | log_dir=path.join(out_dir, 'logs'), 106 | img_dir=path.join(out_dir, 'imgs'), 107 | monitoring=config['training']['monitoring'], 108 | monitoring_dir=path.join(out_dir, 'monitoring') 109 | ) 110 | 111 | # Distributions 112 | ydist = get_ydist(nlabels, device=device) 113 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], 114 | device=device) 115 | 116 | # Save for tests 117 | ntest = batch_size 118 | x_real, ytest = utils.get_nsamples(train_loader, 
ntest) 119 | ytest.clamp_(None, nlabels-1) 120 | ztest = zdist.sample((ntest,)) 121 | utils.save_images(x_real, path.join(out_dir, 'real.png')) 122 | 123 | # Test generator 124 | if config['training']['take_model_average']: 125 | generator_test = copy.deepcopy(generator) 126 | checkpoint_io.register_modules(generator_test=generator_test) 127 | else: 128 | generator_test = generator 129 | 130 | # Evaluator 131 | evaluator = Evaluator(generator_test, zdist, ydist, 132 | batch_size=batch_size, device=device) 133 | 134 | # Train 135 | tstart = t0 = time.time() 136 | 137 | # Load checkpoint if it exists 138 | try: 139 | load_dict = checkpoint_io.load(model_file) 140 | except FileNotFoundError: 141 | it = epoch_idx = -1 142 | else: 143 | it = load_dict.get('it', -1) 144 | epoch_idx = load_dict.get('epoch_idx', -1) 145 | logger.load_stats('stats.p') 146 | 147 | # Reinitialize model average if needed 148 | if (config['training']['take_model_average'] 149 | and config['training']['model_average_reinit']): 150 | update_average(generator_test, generator, 0.) 151 | 152 | # Learning rate anneling 153 | g_scheduler = build_lr_scheduler(g_optimizer, config, last_epoch=it) 154 | d_scheduler = build_lr_scheduler(d_optimizer, config, last_epoch=it) 155 | 156 | # Trainer 157 | trainer = Trainer( 158 | generator, discriminator, g_optimizer, d_optimizer, 159 | gan_type=config['training']['gan_type'], 160 | reg_type=config['training']['reg_type'], 161 | reg_param=config['training']['reg_param'] 162 | ) 163 | 164 | # Training loop 165 | print('Start training...') 166 | while True: 167 | epoch_idx += 1 168 | print('Start epoch %d...' % epoch_idx) 169 | 170 | for x_real, y in train_loader: 171 | it += 1 172 | g_scheduler.step() 173 | d_scheduler.step() 174 | 175 | d_lr = d_optimizer.param_groups[0]['lr'] 176 | g_lr = g_optimizer.param_groups[0]['lr'] 177 | logger.add('learning_rates', 'discriminator', d_lr, it=it) 178 | logger.add('learning_rates', 'generator', g_lr, it=it) 179 | 180 | x_real, y = x_real.to(device), y.to(device) 181 | y.clamp_(None, nlabels-1) 182 | 183 | # Discriminator updates 184 | z = zdist.sample((batch_size,)) 185 | dloss, reg = trainer.discriminator_trainstep(x_real, y, z) 186 | logger.add('losses', 'discriminator', dloss, it=it) 187 | logger.add('losses', 'regularizer', reg, it=it) 188 | 189 | # Generators updates 190 | if ((it + 1) % d_steps) == 0: 191 | z = zdist.sample((batch_size,)) 192 | gloss = trainer.generator_trainstep(y, z) 193 | logger.add('losses', 'generator', gloss, it=it) 194 | 195 | if config['training']['take_model_average']: 196 | update_average(generator_test, generator, 197 | beta=config['training']['model_average_beta']) 198 | 199 | # Print stats 200 | g_loss_last = logger.get_last('losses', 'generator') 201 | d_loss_last = logger.get_last('losses', 'discriminator') 202 | d_reg_last = logger.get_last('losses', 'regularizer') 203 | print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f' 204 | % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last)) 205 | 206 | # (i) Sample if necessary 207 | if (it % config['training']['sample_every']) == 0: 208 | print('Creating samples...') 209 | x = evaluator.create_samples(ztest, ytest) 210 | logger.add_imgs(x, 'all', it) 211 | for y_inst in range(sample_nlabels): 212 | x = evaluator.create_samples(ztest, y_inst) 213 | logger.add_imgs(x, '%04d' % y_inst, it) 214 | 215 | # (ii) Compute inception if necessary 216 | if inception_every > 0 and ((it + 1) % inception_every) == 0: 217 | inception_mean, inception_std = 
evaluator.compute_inception_score() 218 | logger.add('inception_score', 'mean', inception_mean, it=it) 219 | logger.add('inception_score', 'stddev', inception_std, it=it) 220 | 221 | # (iii) Backup if necessary 222 | if ((it + 1) % backup_every) == 0: 223 | print('Saving backup...') 224 | checkpoint_io.save('model_%08d.pt' % it, it=it) 225 | logger.save_stats('stats_%08d.p' % it) 226 | 227 | # (iv) Save checkpoint if necessary 228 | if time.time() - t0 > save_every: 229 | print('Saving checkpoint...') 230 | checkpoint_io.save(model_file, it=it) 231 | logger.save_stats('stats.p') 232 | t0 = time.time() 233 | 234 | if (restart_every > 0 and t0 - tstart > restart_every): 235 | exit(3) 236 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | **/__pycache__ 3 | *.png 4 | *.mp4 5 | *.npy 6 | *.npz 7 | *.dae 8 | data/* 9 | logs/* -------------------------------------------------------------------------------- /submodules/nerf_pytorch/.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/nerf_pytorch/.gitmodules -------------------------------------------------------------------------------- /submodules/nerf_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 bmild 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # NeRF-pytorch 2 | 3 | 4 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. 
Here are some videos generated by this repository (pre-trained models are provided below): 5 | 6 | ![](https://user-images.githubusercontent.com/7057863/78472232-cf374a00-7769-11ea-8871-0bc710951839.gif) 7 | ![](https://user-images.githubusercontent.com/7057863/78472235-d1010d80-7769-11ea-9be9-51365180e063.gif) 8 | 9 | This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is based on authors' Tensorflow implementation [here](https://github.com/bmild/nerf), and has been tested to match it numerically. 10 | 11 | ## Installation 12 | 13 | ``` 14 | git clone https://github.com/yenchenlin/nerf-pytorch.git 15 | cd nerf-pytorch 16 | pip install -r requirements.txt 17 | cd torchsearchsorted 18 | pip install . 19 | cd ../ 20 | ``` 21 | 22 |
23 | Dependencies (click to expand) 24 | 25 | ## Dependencies 26 | - PyTorch 1.4 27 | - matplotlib 28 | - numpy 29 | - imageio 30 | - imageio-ffmpeg 31 | - configargparse 32 | 33 | The LLFF data loader requires ImageMagick. 34 | 35 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data. 36 | 37 |
38 | 39 | ## How To Run? 40 | 41 | ### Quick Start 42 | 43 | Download data for two example datasets: `lego` and `fern` 44 | ``` 45 | bash download_example_data.sh 46 | ``` 47 | 48 | To train a low-res `lego` NeRF: 49 | ``` 50 | python run_nerf.py --config configs/config_lego.txt 51 | ``` 52 | After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`. 53 | 54 | ![](https://user-images.githubusercontent.com/7057863/78473103-9353b300-7770-11ea-98ed-6ba2d877b62c.gif) 55 | 56 | --- 57 | 58 | To train a low-res `fern` NeRF: 59 | ``` 60 | python run_nerf.py --config configs/config_fern.txt 61 | ``` 62 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4` 63 | 64 | ![](https://user-images.githubusercontent.com/7057863/78473081-58ea1600-7770-11ea-92ce-2bbf6a3f9add.gif) 65 | 66 | --- 67 | 68 | ### More Datasets 69 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure: 70 | ``` 71 | ├── configs 72 | │   ├── ... 73 | │   74 | ├── data 75 | │   ├── nerf_llff_data 76 | │   │   └── fern 77 | │   │  └── flower # downloaded llff dataset 78 | │   │  └── horns # downloaded llff dataset 79 | | | └── ... 80 | | ├── nerf_synthetic 81 | | | └── lego 82 | | | └── ship # downloaded synthetic dataset 83 | | | └── ... 84 | ``` 85 | 86 | --- 87 | 88 | To train NeRF on different datasets: 89 | 90 | ``` 91 | python run_nerf.py --config configs/config_{DATASET}.txt 92 | ``` 93 | 94 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc. 95 | 96 | --- 97 | 98 | To test NeRF trained on different datasets: 99 | 100 | ``` 101 | python run_nerf.py --config configs/config_{DATASET}.txt --render_only 102 | ``` 103 | 104 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc. 105 | 106 | 107 | ### Pre-trained Models 108 | 109 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example: 110 | 111 | ``` 112 | ├── logs 113 | │   ├── fern_test 114 | │   ├── flower_test # downloaded logs 115 | │ ├── trex_test # downloaded logs 116 | ``` 117 | 118 | ### Reproducibility 119 | 120 | Tests that ensure the results of all functions and training loop match the official implentation are contained in a different branch `reproduce`. One can check it out and run the tests: 121 | ``` 122 | git checkout reproduce 123 | py.test 124 | ``` 125 | 126 | ## Method 127 | 128 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf) 129 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*1, 130 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*1, 131 | [Matthew Tancik](http://tancik.com/)\*1, 132 | [Jonathan T. Barron](http://jonbarron.info/)2, 133 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)3, 134 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)1
135 | 1UC Berkeley, 2Google Research, 3UC San Diego 136 | \*denotes equal contribution 137 | 138 | 139 | 140 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views 141 | 142 | 143 | ## Citation 144 | Kudos to the authors for their amazing results: 145 | ``` 146 | @misc{mildenhall2020nerf, 147 | title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis}, 148 | author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng}, 149 | year={2020}, 150 | eprint={2003.08934}, 151 | archivePrefix={arXiv}, 152 | primaryClass={cs.CV} 153 | } 154 | ``` 155 | 156 | However, if you find this implementation or pre-trained models helpful, please consider to cite: 157 | ``` 158 | @misc{lin2020nerfpytorch, 159 | title={NeRF-pytorch}, 160 | author={Yen-Chen, Lin}, 161 | howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}}, 162 | year={2020} 163 | } 164 | ``` 165 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_fern.txt: -------------------------------------------------------------------------------- 1 | expname = fern_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fern 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_flower.txt: -------------------------------------------------------------------------------- 1 | expname = flower_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/flower 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_fortress.txt: -------------------------------------------------------------------------------- 1 | expname = fortress_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/fortress 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_horns.txt: -------------------------------------------------------------------------------- 1 | expname = horns_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/horns 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_lego.txt: -------------------------------------------------------------------------------- 1 | expname = lego_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_synthetic/lego 4 | dataset_type = blender 
5 | 6 | half_res = True 7 | 8 | N_samples = 64 9 | N_importance = 64 10 | 11 | use_viewdirs = True 12 | 13 | white_bkgd = True 14 | 15 | N_rand = 1024 -------------------------------------------------------------------------------- /submodules/nerf_pytorch/configs/config_trex.txt: -------------------------------------------------------------------------------- 1 | expname = trex_test 2 | basedir = ./logs 3 | datadir = ./data/nerf_llff_data/trex 4 | dataset_type = llff 5 | 6 | factor = 8 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/download_example_data.sh: -------------------------------------------------------------------------------- 1 | wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz 2 | mkdir -p data 3 | cd data 4 | wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip 5 | unzip nerf_example_data.zip 6 | cd .. 7 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/imgs/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/nerf_pytorch/imgs/pipeline.jpg -------------------------------------------------------------------------------- /submodules/nerf_pytorch/load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_blender_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for frame in meta['frames'][::skip]: 57 | fname = os.path.join(basedir, frame['file_path'] + '.png') 58 | imgs.append(imageio.imread(fname)) 59 | poses.append(np.array(frame['transform_matrix'])) 60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 61 | poses = np.array(poses).astype(np.float32) 62 | counts.append(counts[-1] + imgs.shape[0]) 63 | all_imgs.append(imgs) 64 | all_poses.append(poses) 65 | 66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 67 | 68 | imgs = np.concatenate(all_imgs, 0) 69 | poses = 
np.concatenate(all_poses, 0) 70 | 71 | H, W = imgs[0].shape[:2] 72 | camera_angle_x = float(meta['camera_angle_x']) 73 | focal = .5 * W / np.tan(.5 * camera_angle_x) 74 | 75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 76 | 77 | if half_res: 78 | H = H//2 79 | W = W//2 80 | focal = focal/2. 81 | 82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 83 | for i, img in enumerate(imgs): 84 | imgs_half_res[i] = cv2.resize(img, (H, W), interpolation=cv2.INTER_AREA) 85 | imgs = imgs_half_res 86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 87 | 88 | 89 | return imgs, poses, render_poses, [H, W, focal], i_split 90 | 91 | 92 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = 
np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.2.1 3 | imageio 4 | imageio-ffmpeg 5 | matplotlib 6 | configargparse 7 | tensorboard==1.14.0 8 | tqdm 9 | opencv-python 10 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | 55 | # Byte-compiled / optimized / DLL files 56 | __pycache__/ 57 | *.py[cod] 58 | *$py.class 59 | 60 | # C extensions 61 | *.so 62 | 63 | # Distribution / packaging 64 | .Python 65 | build/ 66 | develop-eggs/ 67 | dist/ 68 | downloads/ 69 | eggs/ 70 | .eggs/ 71 | lib/ 72 | lib64/ 73 | parts/ 74 | sdist/ 75 | var/ 76 | wheels/ 77 | *.egg-info/ 78 | .installed.cfg 79 | *.egg 80 | MANIFEST 81 | 82 | # PyInstaller 83 | # Usually these files are written by a python script from a template 84 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
85 | *.manifest 86 | *.spec 87 | 88 | # Installer logs 89 | pip-log.txt 90 | pip-delete-this-directory.txt 91 | 92 | # Unit test / coverage reports 93 | htmlcov/ 94 | .tox/ 95 | .coverage 96 | .coverage.* 97 | .cache 98 | nosetests.xml 99 | coverage.xml 100 | *.cover 101 | .hypothesis/ 102 | .pytest_cache/ 103 | 104 | # Translations 105 | *.mo 106 | *.pot 107 | 108 | # Django stuff: 109 | *.log 110 | local_settings.py 111 | db.sqlite3 112 | 113 | # Flask stuff: 114 | instance/ 115 | .webassets-cache 116 | 117 | # Scrapy stuff: 118 | .scrapy 119 | 120 | # Sphinx documentation 121 | docs/_build/ 122 | 123 | # PyBuilder 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # pyenv 130 | .python-version 131 | 132 | # celery beat schedule file 133 | celerybeat-schedule 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Inria (Antoine Liutkus) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Custom CUDA kernel for searchsorted 2 | 3 | This repository is an implementation of the searchsorted function to work for pytorch CUDA Tensors. 
Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but completely rewritten since then, because building C extensions that way is no longer supported as of pytorch 1.0. 4 | 5 | 6 | > Warnings: 7 | > * only works with pytorch > v1.3 and CUDA >= v10.1 8 | > * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications. 9 | 10 | ## Description 11 | 12 | Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices. 13 | * `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this). 14 | * `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this). 15 | * `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there, which avoids costly memory allocations if the user has already allocated it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this). 16 | * `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help is welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7). 17 | 18 | The output is of size `(nrows, ncols_v)`. If all input tensors are on GPU, the CUDA version will be called; otherwise, the CPU version will be used. 19 | 20 | 21 | **Disclaimers** 22 | 23 | * This function has not been heavily tested. Use it at your own risk. 24 | * When `a` is not sorted, the results differ from numpy's version. But I decided not to handle this, because the function should not be called in this case. 25 | * In some cases, the results differ from numpy's version. However, as far as I could see, this only happens when values are equal, which means the order in which equal values are inserted does not actually matter. I decided not to handle this either. 26 | * Vectors have to be contiguous for torchsearchsorted to give consistent results; call `.contiguous()` on all tensor arguments before use. 27 | 28 | 29 | ## Installation 30 | 31 | Just run `pip install .` in the root folder of this repo. This will compile 32 | and install the torchsearchsorted module. 33 | 34 | Be aware that sometimes `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is). 35 | 36 | For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8. 37 | So I had to do: 38 | 39 | > sudo apt-get install g++-8 gcc-8 40 | > sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc 41 | > sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++ 42 | 43 | Note that pytorch needs to be installed on your system; the code was tested with pytorch v1.3. 44 | 45 | ## Usage 46 | 47 | Just import the torchsearchsorted package after installation.
I typically do: 48 | 49 | ``` 50 | from torchsearchsorted import searchsorted 51 | ``` 52 | 53 | 54 | ## Testing 55 | 56 | Under the `examples` subfolder, you may: 57 | 58 | 1. try `python test.py` with `torch` available. 59 | 60 | ``` 61 | Looking for 50000x1000 values in 50000x300 entries 62 | NUMPY: searchsorted in 4851.592ms 63 | CPU: searchsorted in 4805.432ms 64 | difference between CPU and NUMPY: 0.000 65 | GPU: searchsorted in 1.055ms 66 | difference between GPU and NUMPY: 0.000 67 | 68 | Looking for 50000x1000 values in 50000x300 entries 69 | NUMPY: searchsorted in 4333.964ms 70 | CPU: searchsorted in 4753.958ms 71 | difference between CPU and NUMPY: 0.000 72 | GPU: searchsorted in 0.391ms 73 | difference between GPU and NUMPY: 0.000 74 | ``` 75 | The first run comprises the time of allocation, while the second one does not. 76 | 77 | 2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs: 78 | 79 | ``` 80 | Benchmark searchsorted: 81 | - a [5000 x 300] 82 | - v [5000 x 100] 83 | - reporting fastest time of 20 runs 84 | - each run executes searchsorted 100 times 85 | 86 | Numpy: 4.6302046799100935 87 | CPU: 5.041533078998327 88 | CUDA: 0.0007955809123814106 89 | ``` 90 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/examples/benchmark.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | 3 | import torch 4 | import numpy as np 5 | from torchsearchsorted import searchsorted, numpy_searchsorted 6 | 7 | B = 5_000 8 | A = 300 9 | V = 100 10 | 11 | repeats = 20 12 | number = 100 13 | 14 | print( 15 | f'Benchmark searchsorted:', 16 | f'- a [{B} x {A}]', 17 | f'- v [{B} x {V}]', 18 | f'- reporting fastest time of {repeats} runs', 19 | f'- each run executes searchsorted {number} times', 20 | sep='\n', 21 | end='\n\n' 22 | ) 23 | 24 | 25 | def get_arrays(): 26 | a = np.sort(np.random.randn(B, A), axis=1) 27 | v = np.random.randn(B, V) 28 | out = np.empty_like(v, dtype=np.long) 29 | return a, v, out 30 | 31 | 32 | def get_tensors(device): 33 | a = torch.sort(torch.randn(B, A, device=device), dim=1)[0] 34 | v = torch.randn(B, V, device=device) 35 | out = torch.empty(B, V, device=device, dtype=torch.long) 36 | if torch.cuda.is_available(): 37 | torch.cuda.synchronize() 38 | return a, v, out 39 | 40 | def searchsorted_synchronized(a,v,out=None,side='left'): 41 | out = searchsorted(a,v,out,side) 42 | torch.cuda.synchronize() 43 | return out 44 | 45 | numpy = timeit.repeat( 46 | stmt="numpy_searchsorted(a, v, side='left')", 47 | setup="a, v, out = get_arrays()", 48 | globals=globals(), 49 | repeat=repeats, 50 | number=number 51 | ) 52 | print('Numpy: ', min(numpy), sep='\t') 53 | 54 | cpu = timeit.repeat( 55 | stmt="searchsorted(a, v, out, side='left')", 56 | setup="a, v, out = get_tensors(device='cpu')", 57 | globals=globals(), 58 | repeat=repeats, 59 | number=number 60 | ) 61 | print('CPU: ', min(cpu), sep='\t') 62 | 63 | if torch.cuda.is_available(): 64 | gpu = timeit.repeat( 65 | stmt="searchsorted_synchronized(a, v, out, side='left')", 66 | setup="a, v, out = get_tensors(device='cuda')", 67 | globals=globals(), 68 | repeat=repeats, 69 | number=number 70 | ) 71 | print('CUDA: ', min(gpu), sep='\t') 72 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/examples/test.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torchsearchsorted import searchsorted, numpy_searchsorted 3 | import time 4 | 5 | if __name__ == '__main__': 6 | # defining the number of tests 7 | ntests = 2 8 | 9 | # defining the problem dimensions 10 | nrows_a = 50000 11 | nrows_v = 50000 12 | nsorted_values = 300 13 | nvalues = 1000 14 | 15 | # defines the variables. The first run will comprise allocation, the 16 | # further ones will not 17 | test_GPU = None 18 | test_CPU = None 19 | 20 | for ntest in range(ntests): 21 | print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues, 22 | nrows_a, 23 | nsorted_values)) 24 | 25 | side = 'right' 26 | # generate a matrix with sorted rows 27 | a = torch.randn(nrows_a, nsorted_values, device='cpu') 28 | a = torch.sort(a, dim=1)[0] 29 | # generate a matrix of values to searchsort 30 | v = torch.randn(nrows_v, nvalues, device='cpu') 31 | 32 | # a = torch.tensor([[0., 1.]]) 33 | # v = torch.tensor([[1.]]) 34 | 35 | t0 = time.time() 36 | test_NP = torch.tensor(numpy_searchsorted(a, v, side)) 37 | print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 38 | t0 = time.time() 39 | test_CPU = searchsorted(a, v, test_CPU, side) 40 | print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 41 | # compute the difference between both 42 | error_CPU = torch.norm(test_NP.double() 43 | - test_CPU.double()).numpy() 44 | if error_CPU: 45 | import ipdb; ipdb.set_trace() 46 | print(' difference between CPU and NUMPY: %0.3f' % error_CPU) 47 | 48 | if not torch.cuda.is_available(): 49 | print('CUDA is not available on this machine, cannot go further.') 50 | continue 51 | else: 52 | # now do the CPU 53 | a = a.to('cuda') 54 | v = v.to('cuda') 55 | torch.cuda.synchronize() 56 | # launch searchsorted on those 57 | t0 = time.time() 58 | test_GPU = searchsorted(a, v, test_GPU, side) 59 | torch.cuda.synchronize() 60 | print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 61 | 62 | # compute the difference between both 63 | error_CUDA = torch.norm(test_NP.to('cuda').double() 64 | - test_GPU.double()).cpu().numpy() 65 | 66 | print(' difference between GPU and NUMPY: %0.3f' % error_CUDA) 67 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME 3 | from torch.utils.cpp_extension import CppExtension, CUDAExtension 4 | 5 | # In any case, include the CPU version 6 | modules = [ 7 | CppExtension('torchsearchsorted.cpu', 8 | ['src/cpu/searchsorted_cpu_wrapper.cpp']), 9 | ] 10 | 11 | # If nvcc is available, add the CUDA extension 12 | if CUDA_HOME: 13 | modules.append( 14 | CUDAExtension('torchsearchsorted.cuda', 15 | ['src/cuda/searchsorted_cuda_wrapper.cpp', 16 | 'src/cuda/searchsorted_cuda_kernel.cu']) 17 | ) 18 | 19 | tests_require = [ 20 | 'pytest', 21 | ] 22 | 23 | # Now proceed to setup 24 | setup( 25 | name='torchsearchsorted', 26 | version='1.1', 27 | description='A searchsorted implementation for pytorch', 28 | keywords='searchsorted', 29 | author='Antoine Liutkus', 30 | author_email='antoine.liutkus@inria.fr', 31 | packages=find_packages(where='src'), 32 | package_dir={"": "src"}, 33 | ext_modules=modules, 34 | tests_require=tests_require, 35 | extras_require={ 36 | 'test': tests_require, 37 | }, 38 
| cmdclass={ 39 | 'build_ext': BuildExtension 40 | } 41 | ) 42 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cpu_wrapper.h" 2 | #include <stdio.h> 3 | 4 | template <typename scalar_t> 5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left) 6 | { 7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ 8 | 9 | if (col == ncol - 1) 10 | { 11 | // special case: we are on the right border 12 | if (a[row * ncol + col] <= val){ 13 | return 1;} 14 | else { 15 | return -1;} 16 | } 17 | bool is_lower; 18 | bool is_next_higher; 19 | 20 | if (side_left) { 21 | // a[row, col] < v <= a[row, col+1] 22 | is_lower = (a[row * ncol + col] < val); 23 | is_next_higher = (a[row*ncol + col + 1] >= val); 24 | } else { 25 | // a[row, col] <= v < a[row, col+1] 26 | is_lower = (a[row * ncol + col] <= val); 27 | is_next_higher = (a[row * ncol + col + 1] > val); 28 | } 29 | if (is_lower && is_next_higher) { 30 | // we found the right spot 31 | return 0; 32 | } else if (is_lower) { 33 | // answer is on the right side 34 | return 1; 35 | } else { 36 | // answer is on the left side 37 | return -1; 38 | } 39 | } 40 | 41 | template <typename scalar_t> 42 | int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left) 43 | { 44 | /* Look for the value `val` within row `row` of matrix `a`, which 45 | has `ncol` columns. 46 | 47 | the `a` matrix is assumed sorted in increasing order, row-wise 48 | 49 | returns: 50 | * -1 if `val` is smaller than the smallest value found within that row of `a` 51 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a` 52 | * Otherwise, return the column index `res` such that: 53 | - a[row, col] < val <= a[row, col+1] (if side_left), or 54 | - a[row, col] <= val < a[row, col+1] (if not side_left). 55 | */ 56 | 57 | //start with left at 0 and right at number of columns of a 58 | int64_t right = ncol; 59 | int64_t left = 0; 60 | 61 | while (right >= left) { 62 | // take the midpoint of current left and right cursors 63 | int64_t mid = left + (right-left)/2; 64 | 65 | // check the relative position of val: are we good here ? 66 | int rel_pos = eval(val, a, row, mid, ncol, side_left); 67 | // we found the point 68 | if(rel_pos == 0) { 69 | return mid; 70 | } else if (rel_pos > 0) { 71 | if (mid==ncol-1){return ncol-1;} 72 | // the answer is on the right side 73 | left = mid; 74 | } else { 75 | if (mid==0){return -1;} 76 | right = mid; 77 | } 78 | } 79 | return -1; 80 | } 81 | 82 | void searchsorted_cpu_wrapper( 83 | at::Tensor a, 84 | at::Tensor v, 85 | at::Tensor res, 86 | bool side_left) 87 | { 88 | 89 | // Get the dimensions 90 | auto nrow_a = a.size(/*dim=*/0); 91 | auto ncol_a = a.size(/*dim=*/1); 92 | auto nrow_v = v.size(/*dim=*/0); 93 | auto ncol_v = v.size(/*dim=*/1); 94 | 95 | auto nrow_res = fmax(nrow_a, nrow_v); 96 | 97 | //auto acc_v = v.accessor(); 98 | //auto acc_res = res.accessor(); 99 | 100 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] { 101 | 102 | scalar_t* a_data = a.data_ptr<scalar_t>(); 103 | scalar_t* v_data = v.data_ptr<scalar_t>(); 104 | int64_t* res_data = res.data<int64_t>(); 105 | 106 | for (int64_t row = 0; row < nrow_res; row++) 107 | { 108 | for (int64_t col = 0; col < ncol_v; col++) 109 | { 110 | // get the value to look for 111 | int64_t row_in_v = (nrow_v == 1) ? 0 : row; 112 | int64_t row_in_a = (nrow_a == 1) ?
0 : row; 113 | 114 | int64_t idx_in_v = row_in_v * ncol_v + col; 115 | int64_t idx_in_res = row * ncol_v + col; 116 | 117 | // apply binary search 118 | res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1); 119 | } 120 | } 121 | }); 122 | } 123 | 124 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 125 | m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)"); 126 | } 127 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CPU 2 | #define _SEARCHSORTED_CPU 3 | 4 | #include <torch/extension.h> 5 | 6 | void searchsorted_cpu_wrapper( 7 | at::Tensor a, 8 | at::Tensor v, 9 | at::Tensor res, 10 | bool side_left); 11 | 12 | #endif -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cuda_kernel.h" 2 | 3 | template <typename scalar_t> 4 | __device__ 5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left) 6 | { 7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ 8 | 9 | if (col == ncol - 1) 10 | { 11 | // special case: we are on the right border 12 | if (a[row * ncol + col] <= val){ 13 | return 1;} 14 | else { 15 | return -1;} 16 | } 17 | bool is_lower; 18 | bool is_next_higher; 19 | 20 | if (side_left) { 21 | // a[row, col] < v <= a[row, col+1] 22 | is_lower = (a[row * ncol + col] < val); 23 | is_next_higher = (a[row*ncol + col + 1] >= val); 24 | } else { 25 | // a[row, col] <= v < a[row, col+1] 26 | is_lower = (a[row * ncol + col] <= val); 27 | is_next_higher = (a[row * ncol + col + 1] > val); 28 | } 29 | if (is_lower && is_next_higher) { 30 | // we found the right spot 31 | return 0; 32 | } else if (is_lower) { 33 | // answer is on the right side 34 | return 1; 35 | } else { 36 | // answer is on the left side 37 | return -1; 38 | } 39 | } 40 | 41 | template <typename scalar_t> 42 | __device__ 43 | int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left) 44 | { 45 | /* Look for the value `val` within row `row` of matrix `a`, which 46 | has `ncol` columns. 47 | 48 | the `a` matrix is assumed sorted in increasing order, row-wise 49 | 50 | Returns 51 | * -1 if `val` is smaller than the smallest value found within that row of `a` 52 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a` 53 | * Otherwise, return the column index `res` such that: 54 | - a[row, col] < val <= a[row, col+1] (if side_left), or 55 | - a[row, col] <= val < a[row, col+1] (if not side_left). 56 | */ 57 | 58 | //start with left at 0 and right at number of columns of a 59 | int64_t right = ncol; 60 | int64_t left = 0; 61 | 62 | while (right >= left) { 63 | // take the midpoint of current left and right cursors 64 | int64_t mid = left + (right-left)/2; 65 | 66 | // check the relative position of val: are we good here ?
67 | int rel_pos = eval(val, a, row, mid, ncol, side_left); 68 | // we found the point 69 | if(rel_pos == 0) { 70 | return mid; 71 | } else if (rel_pos > 0) { 72 | if (mid==ncol-1){return ncol-1;} 73 | // the answer is on the right side 74 | left = mid; 75 | } else { 76 | if (mid==0){return -1;} 77 | right = mid; 78 | } 79 | } 80 | return -1; 81 | } 82 | 83 | template <typename scalar_t> 84 | __global__ 85 | void searchsorted_kernel( 86 | int64_t *res, 87 | scalar_t *a, 88 | scalar_t *v, 89 | int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left) 90 | { 91 | // get current row and column 92 | int64_t row = blockIdx.y*blockDim.y+threadIdx.y; 93 | int64_t col = blockIdx.x*blockDim.x+threadIdx.x; 94 | 95 | // check whether we are outside the bounds of what needs be computed. 96 | if ((row >= nrow_res) || (col >= ncol_v)) { 97 | return;} 98 | 99 | // get the value to look for 100 | int64_t row_in_v = (nrow_v==1) ? 0: row; 101 | int64_t row_in_a = (nrow_a==1) ? 0: row; 102 | int64_t idx_in_v = row_in_v*ncol_v+col; 103 | int64_t idx_in_res = row*ncol_v+col; 104 | 105 | // apply binary search 106 | res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1; 107 | } 108 | 109 | 110 | void searchsorted_cuda( 111 | at::Tensor a, 112 | at::Tensor v, 113 | at::Tensor res, 114 | bool side_left){ 115 | 116 | // Get the dimensions 117 | auto nrow_a = a.size(/*dim=*/0); 118 | auto nrow_v = v.size(/*dim=*/0); 119 | auto ncol_a = a.size(/*dim=*/1); 120 | auto ncol_v = v.size(/*dim=*/1); 121 | 122 | auto nrow_res = fmax(double(nrow_a), double(nrow_v)); 123 | 124 | // prepare the kernel configuration 125 | dim3 threads(ncol_v, nrow_res); 126 | dim3 blocks(1, 1); 127 | if (nrow_res*ncol_v > 1024){ 128 | threads.x = int(fmin(double(1024), double(ncol_v))); 129 | threads.y = floor(1024/threads.x); 130 | blocks.x = ceil(double(ncol_v)/double(threads.x)); 131 | blocks.y = ceil(double(nrow_res)/double(threads.y)); 132 | } 133 | 134 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] { 135 | searchsorted_kernel<scalar_t><<<blocks, threads>>>( 136 | res.data<int64_t>(), 137 | a.data<scalar_t>(), 138 | v.data<scalar_t>(), 139 | nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left); 140 | })); 141 | 142 | } 143 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CUDA_KERNEL 2 | #define _SEARCHSORTED_CUDA_KERNEL 3 | 4 | #include <torch/extension.h> 5 | 6 | void searchsorted_cuda( 7 | at::Tensor a, 8 | at::Tensor v, 9 | at::Tensor res, 10 | bool side_left); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cuda_wrapper.h" 2 | 3 | // C++ interface 4 | 5 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | 9 | void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left) 10 | { 11 | CHECK_INPUT(a); 12 | CHECK_INPUT(v); 13 | CHECK_INPUT(res); 14 | 15 | searchsorted_cuda(a, v, res, side_left); 16 | } 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 |
m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)"); 20 | } 21 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CUDA_WRAPPER 2 | #define _SEARCHSORTED_CUDA_WRAPPER 3 | 4 | #include <torch/extension.h> 5 | #include "searchsorted_cuda_kernel.h" 6 | 7 | void searchsorted_cuda_wrapper( 8 | at::Tensor a, 9 | at::Tensor v, 10 | at::Tensor res, 11 | bool side_left); 12 | 13 | #endif 14 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/__init__.py: -------------------------------------------------------------------------------- 1 | from .searchsorted import searchsorted 2 | from .utils import numpy_searchsorted 3 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/searchsorted.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | # trying to import the CPU searchsorted 6 | SEARCHSORTED_CPU_AVAILABLE = True 7 | try: 8 | from torchsearchsorted.cpu import searchsorted_cpu_wrapper 9 | except ImportError: 10 | SEARCHSORTED_CPU_AVAILABLE = False 11 | 12 | # trying to import the CUDA searchsorted 13 | SEARCHSORTED_GPU_AVAILABLE = True 14 | try: 15 | from torchsearchsorted.cuda import searchsorted_cuda_wrapper 16 | except ImportError: 17 | SEARCHSORTED_GPU_AVAILABLE = False 18 | 19 | 20 | def searchsorted(a: torch.Tensor, v: torch.Tensor, 21 | out: Optional[torch.LongTensor] = None, 22 | side='left') -> torch.LongTensor: 23 | assert len(a.shape) == 2, "input `a` must be 2-D." 24 | assert len(v.shape) == 2, "input `v` must be 2-D." 25 | assert (a.shape[0] == v.shape[0] 26 | or a.shape[0] == 1 27 | or v.shape[0] == 1), ("`a` and `v` must have the same number of " 28 | "rows or one of them must have only one ") 29 | assert a.device == v.device, '`a` and `v` must be on the same device' 30 | 31 | result_shape = (max(a.shape[0], v.shape[0]), v.shape[1]) 32 | if out is not None: 33 | assert out.device == a.device, "`out` must be on the same device as `a`" 34 | assert out.dtype == torch.long, "out.dtype must be torch.long" 35 | assert out.shape == result_shape, ("If the output tensor is provided, " 36 | "its shape must be correct.") 37 | else: 38 | out = torch.empty(result_shape, device=v.device, dtype=torch.long) 39 | 40 | if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE: 41 | raise Exception('torchsearchsorted on CUDA device is asked, but it seems ' 42 | 'that it is not available. Please install it') 43 | if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE: 44 | raise Exception('torchsearchsorted on CPU is not available.
' 45 | 'Please install it.') 46 | 47 | left_side = 1 if side=='left' else 0 48 | if a.is_cuda: 49 | searchsorted_cuda_wrapper(a, v, out, left_side) 50 | else: 51 | searchsorted_cpu_wrapper(a, v, out, left_side) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'): 5 | """Numpy version of searchsorted that works batch-wise on pytorch tensors 6 | """ 7 | nrows_a = a.shape[0] 8 | (nrows_v, ncols_v) = v.shape 9 | nrows_out = max(nrows_a, nrows_v) 10 | out = np.empty((nrows_out, ncols_v), dtype=np.long) 11 | def sel(data, row): 12 | return data[0] if data.shape[0] == 1 else data[row] 13 | for row in range(nrows_out): 14 | out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side) 15 | return out 16 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/test/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | devices = {'cpu': torch.device('cpu')} 5 | if torch.cuda.is_available(): 6 | devices['cuda'] = torch.device('cuda:0') 7 | 8 | 9 | @pytest.fixture(params=devices.values(), ids=devices.keys()) 10 | def device(request): 11 | return request.param 12 | -------------------------------------------------------------------------------- /submodules/nerf_pytorch/torchsearchsorted/test/test_searchsorted.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import numpy as np 5 | from torchsearchsorted import searchsorted, numpy_searchsorted 6 | from itertools import product, repeat 7 | 8 | 9 | def test_searchsorted_output_dtype(device): 10 | B = 100 11 | A = 50 12 | V = 12 13 | 14 | a = torch.sort(torch.rand(B, V, device=device), dim=1)[0] 15 | v = torch.rand(B, A, device=device) 16 | 17 | out = searchsorted(a, v) 18 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy()) 19 | assert out.dtype == torch.long 20 | np.testing.assert_array_equal(out.cpu().numpy(), out_np) 21 | 22 | out = torch.empty(v.shape, dtype=torch.long, device=device) 23 | searchsorted(a, v, out) 24 | assert out.dtype == torch.long 25 | np.testing.assert_array_equal(out.cpu().numpy(), out_np) 26 | 27 | Ba_val = [1, 100, 200] 28 | Bv_val = [1, 100, 200] 29 | A_val = [1, 50, 500] 30 | V_val = [1, 12, 120] 31 | side_val = ['left', 'right'] 32 | nrepeat = 100 33 | 34 | @pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val)) 35 | def test_searchsorted_correct(Ba, Bv, A, V, side, device): 36 | if Ba > 1 and Bv > 1 and Ba != Bv: 37 | return 38 | for test in range(nrepeat): 39 | a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0] 40 | v = torch.rand(Bv, V, device=device) 41 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), 42 | side=side) 43 | out = searchsorted(a, v, side=side).cpu().numpy() 44 | np.testing.assert_array_equal(out, out_np) 45 | --------------------------------------------------------------------------------
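
For quick reference, here is a minimal usage sketch of the `searchsorted(a, v, out, side)` API documented in the torchsearchsorted README above. It assumes the extension has been installed with `pip install .`; the tensor shapes and the reuse of an `out` buffer follow the README's Description section, while the random data and the cross-check against `numpy_searchsorted` are purely illustrative.

```python
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted

# each row of `a` must be sorted; both inputs should be contiguous
a = torch.sort(torch.randn(5, 300), dim=1)[0].contiguous()
v = torch.randn(5, 100).contiguous()

# preallocate the long-typed output once and reuse it across calls
out = torch.empty(5, 100, dtype=torch.long)
searchsorted(a, v, out, side='left')

# sanity check against the bundled numpy reference implementation
ref = numpy_searchsorted(a.numpy(), v.numpy(), side='left')
assert (out.numpy() == ref).all()
```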