├── LICENSE
├── README.md
├── animations
│   └── carla_256.gif
├── configs
│   ├── carla.yaml
│   ├── cats.yaml
│   ├── celebA.yaml
│   ├── celebAHQ.yaml
│   ├── cub.yaml
│   ├── debug.yaml
│   ├── default.yaml
│   └── pretrained_models.yaml
├── data
│   ├── cub
│   │   └── filtered_files.txt
│   ├── download_carla.sh
│   ├── download_carla_poses.sh
│   ├── preprocess_cats.py
│   └── preprocess_cub.py
├── environment.yml
├── eval.py
├── external
│   └── colmap
│       ├── __init__.py
│       ├── filter_points.py
│       └── run_colmap_automatic.sh
├── graf
│   ├── config.py
│   ├── datasets.py
│   ├── gan_training.py
│   ├── models
│   │   ├── discriminator.py
│   │   └── generator.py
│   ├── transforms.py
│   └── utils.py
├── submodules
│   ├── GAN_stability
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── celebAHQ.yaml
│   │   │   ├── default.yaml
│   │   │   ├── imagenet.yaml
│   │   │   ├── lsun_bedroom.yaml
│   │   │   ├── lsun_bridge.yaml
│   │   │   ├── lsun_church.yaml
│   │   │   ├── lsun_tower.yaml
│   │   │   └── pretrained
│   │   │       ├── celebAHQ_pretrained.yaml
│   │   │       ├── celebA_pretrained.yaml
│   │   │       ├── imagenet_pretrained.yaml
│   │   │       ├── lsun_bedroom_pretrained.yaml
│   │   │       ├── lsun_bridge_pretrained.yaml
│   │   │       ├── lsun_church_pretrained.yaml
│   │   │       └── lsun_tower_pretrained.yaml
│   │   ├── gan_training
│   │   │   ├── __init__.py
│   │   │   ├── checkpoints.py
│   │   │   ├── config.py
│   │   │   ├── distributions.py
│   │   │   ├── eval.py
│   │   │   ├── inputs.py
│   │   │   ├── logger.py
│   │   │   ├── metrics
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fid_score.py
│   │   │   │   ├── inception.py
│   │   │   │   ├── inception_score.py
│   │   │   │   └── kid_score.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── resnet.py
│   │   │   │   ├── resnet2.py
│   │   │   │   ├── resnet3.py
│   │   │   │   └── resnet4.py
│   │   │   ├── ops.py
│   │   │   ├── train.py
│   │   │   └── utils.py
│   │   ├── interpolate.py
│   │   ├── interpolate_class.py
│   │   ├── notebooks
│   │   │   ├── DiracGAN.ipynb
│   │   │   ├── create_video.sh
│   │   │   └── diracgan
│   │   │       ├── __init__.py
│   │   │       ├── gans.py
│   │   │       ├── plotting.py
│   │   │       ├── simulate.py
│   │   │       ├── subplots.py
│   │   │       └── util.py
│   │   ├── results
│   │   │   ├── celebA-HQ.jpg
│   │   │   ├── imagenet_00.jpg
│   │   │   ├── imagenet_01.jpg
│   │   │   ├── imagenet_02.jpg
│   │   │   ├── imagenet_03.jpg
│   │   │   └── imagenet_04.jpg
│   │   ├── test.py
│   │   └── train.py
│   └── nerf_pytorch
│       ├── .gitignore
│       ├── .gitmodules
│       ├── LICENSE
│       ├── README.md
│       ├── configs
│       │   ├── config_fern.txt
│       │   ├── config_flower.txt
│       │   ├── config_fortress.txt
│       │   ├── config_horns.txt
│       │   ├── config_lego.txt
│       │   └── config_trex.txt
│       ├── download_example_data.sh
│       ├── imgs
│       │   └── pipeline.jpg
│       ├── load_blender.py
│       ├── load_deepvoxels.py
│       ├── load_llff.py
│       ├── requirements.txt
│       ├── run_nerf.py
│       ├── run_nerf_helpers.py
│       ├── run_nerf_helpers_mod.py
│       ├── run_nerf_mod.py
│       └── torchsearchsorted
│           ├── .gitignore
│           ├── LICENSE
│           ├── README.md
│           ├── examples
│           │   ├── benchmark.py
│           │   └── test.py
│           ├── setup.py
│           ├── src
│           │   ├── cpu
│           │   │   ├── searchsorted_cpu_wrapper.cpp
│           │   │   └── searchsorted_cpu_wrapper.h
│           │   ├── cuda
│           │   │   ├── searchsorted_cuda_kernel.cu
│           │   │   ├── searchsorted_cuda_kernel.h
│           │   │   ├── searchsorted_cuda_wrapper.cpp
│           │   │   └── searchsorted_cuda_wrapper.h
│           │   └── torchsearchsorted
│           │       ├── __init__.py
│           │       ├── searchsorted.py
│           │       └── utils.py
│           └── test
│               ├── conftest.py
│               └── test_searchsorted.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GRAF
2 |
3 |
4 |
5 |
6 |
7 | This repository contains official code for the paper
8 | [GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis](https://avg.is.tuebingen.mpg.de/publications/schwarz2020neurips).
9 |
10 | You can find detailed usage instructions for training your own models and using pre-trained models below.
11 |
12 |
13 | If you find our code or paper useful, please consider citing
14 |
15 | @inproceedings{Schwarz2020NEURIPS,
16 | title = {GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis},
17 | author = {Schwarz, Katja and Liao, Yiyi and Niemeyer, Michael and Geiger, Andreas},
18 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
19 | year = {2020}
20 | }
21 |
22 | ## Installation
23 | First, you have to make sure that you have all dependencies in place.
24 | The simplest way to do so is to use [anaconda](https://www.anaconda.com/).
25 |
26 | You can create an anaconda environment called `graf` using
27 | ```
28 | conda env create -f environment.yml
29 | conda activate graf
30 | ```
31 |
32 | Next, install torchsearchsorted for nerf-pytorch. Note that this requires `torch>=1.4.0` and `CUDA >= v10.1`.
33 | You can install torchsearchsorted via
34 | ```
35 | cd submodules/nerf_pytorch
36 | pip install -r requirements.txt
37 | cd torchsearchsorted
38 | pip install .
39 | cd ../../../
40 | ```
41 |
42 | ## Demo
43 |
44 | You can now test our code via:
45 | ```
46 | python eval.py configs/carla.yaml --pretrained --rotation_elevation
47 | ```
48 | This script should create a folder `results/carla_128_from_pretrained/eval/` where you can find generated videos with varying camera pose for the Cars dataset.
49 |
50 | ## Datasets
51 |
52 | If you only want to generate images using our pretrained models, you do not need to download the datasets.
53 | The datasets are only needed if you want to train a model from scratch.
54 |
55 | ### Cars
56 |
57 | To download the Cars dataset from the paper, simply run
58 | ```
59 | cd data
60 | ./download_carla.sh
61 | cd ..
62 | ```
63 | This creates a folder `data/carla/`, downloads the images as a zip file and extracts them to `data/carla/`.
64 | While we do not use camera poses in this project, we provide them for completeness. You can download them by running
65 | ```
66 | cd data
67 | ./download_carla_poses.sh
68 | cd ..
69 | ```
70 | This downloads the camera intrinsics (single file, equal for all images) and extrinsics corresponding to each image.
71 |
72 | ### Faces
73 |
74 | Download [celebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
75 | Then replace `data/celebA` in `configs/celebA.yaml` with `*PATH/TO/CELEBA*/Img/img_align_celebA`.
76 |
77 | Download [celebA_hq](https://github.com/tkarras/progressive_growing_of_gans).
78 | Then replace `data/celebA_hq` in `configs/celebAHQ.yaml` with `*PATH/TO/CELEBA_HQ*`.
79 |
80 | ### Cats
81 | Download the [CatDataset](https://www.kaggle.com/crawford/cat-dataset).
82 | Run
83 | ```
84 | cd data
85 | python preprocess_cats.py PATH/TO/CATS/DATASET
86 | cd ..
87 | ```
88 | to preprocess the data and save it to `data/cats`.
89 | If successful, this script should print: `Preprocessed 9407 images.`
90 |
91 | ### Birds
92 | Download [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and the corresponding [Segmentation Masks](https://drive.google.com/file/d/1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP/view).
93 | Run
94 | ```
95 | cd data
96 | python preprocess_cub.py PATH/TO/CUB-200-2011 PATH/TO/SEGMENTATION/MASKS
97 | cd ..
98 | ```
99 | to preprocess the data and save it to `data/cub`.
100 | If successful, this script should print: `Preprocessed 8444 images.`
101 |
102 | ## Usage
103 |
104 | When you have installed all dependencies, you are ready to run our pre-trained models for 3D-aware image synthesis.
105 |
106 | ### Generate images using a pretrained model
107 |
108 | To evaluate a pretrained model, run
109 | ```
110 | python eval.py CONFIG.yaml --pretrained --fid_kid --rotation_elevation --shape_appearance
111 | ```
112 | where you replace `CONFIG.yaml` with one of the config files in `./configs`.
113 |
114 | This script should create a folder `results/EXPNAME/eval` with FID and KID scores in `fid_kid.csv`, videos for rotation and elevation in the respective folders and an interpolation for shape and appearance, `shape_appearance.png`.
115 |
116 | Note that some pretrained models are available at multiple image sizes, which you can choose by setting `data:imsize` in the config file to one of the following values:
117 | ```
118 | configs/carla.yaml:
119 | data:imsize 64 or 128 or 256 or 512
120 | configs/celebA.yaml:
121 | data:imsize 64 or 128
122 | configs/celebAHQ.yaml:
123 | data:imsize 256 or 512
124 | ```
125 |
126 | ### Train a model from scratch
127 |
128 | To train a 3D-aware generative model from scratch, run
129 | ```
130 | python train.py CONFIG.yaml
131 | ```
132 | where you replace `CONFIG.yaml` with your config file.
133 | The easiest way is to use one of the existing config files in the `./configs` directory,
134 | which correspond to the experiments presented in the paper.
135 | Note that this trains the model from scratch and does not resume training from a pretrained model.
136 |
137 | You can monitor the training process using [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard):
138 | ```
139 | cd OUTPUT_DIR
140 | tensorboard --logdir ./monitoring --port 6006
141 | ```
142 | where you replace `OUTPUT_DIR` with the respective output directory.
143 |
144 | For available training options, please take a look at `configs/default.yaml`.
145 |
146 | ### Evaluation of a new model
147 |
148 | To evaluate a model you trained yourself, run
149 | ```
150 | python eval.py CONFIG.yaml --fid_kid --rotation_elevation --shape_appearance
151 | ```
152 | where you replace `CONFIG.yaml` with your config file.
153 |
154 | ## Multi-View Consistency Check
155 |
156 | You can evaluate the multi-view consistency of the generated images by running a Multi-View Stereo (MVS) algorithm on them. This evaluation uses [COLMAP](https://colmap.github.io/); make sure that COLMAP is installed, then run
157 | ```
158 | python eval.py CONFIG.yaml --reconstruction
159 | ```
160 | where you replace `CONFIG.yaml` with your config file. You can also evaluate our pretrained models via:
161 | ```
162 | python eval.py configs/carla.yaml --pretrained --reconstruction
163 | ```
164 | This script should create a folder `results/EXPNAME/eval/reconstruction/` where you can find generated multi-view images in `images/` and the corresponding 3D reconstructions in `models/`.
165 |
166 | ## Further Information
167 |
168 | ### GAN training
169 |
170 | This repository uses Lars Mescheder's awesome framework for [GAN training](https://github.com/LMescheder/GAN_stability).
171 |
172 | ### NeRF
173 |
174 | We base our code for the Generator on this great [PyTorch reimplementation](https://github.com/yenchenlin/nerf-pytorch) of Neural Radiance Fields.
175 |
--------------------------------------------------------------------------------
/animations/carla_256.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/animations/carla_256.gif
--------------------------------------------------------------------------------
/configs/carla.yaml:
--------------------------------------------------------------------------------
1 | expname: carla_128
2 | data:
3 | imsize: 128
4 | datadir: data/carla
5 | type: carla
6 | radius: 10.
7 | near: 7.5
8 | far: 12.5
9 | fov: 30.0
10 |
--------------------------------------------------------------------------------
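A quick arithmetic check of the values above: `get_data` in `graf/config.py` (shown further below) derives the pinhole focal length from `imsize` and `fov` as `W/2 / tan(fov/2)`. A small illustrative sketch, not part of the repository:

```
import numpy as np

imsize, fov = 128, 30.0   # values from configs/carla.yaml
focal = imsize / 2 / np.tan(0.5 * fov * np.pi / 180.)
print(focal)              # ~238.9 pixels
```

--------------------------------------------------------------------------------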
/configs/cats.yaml:
--------------------------------------------------------------------------------
1 | expname: cats_64
2 | data:
3 | datadir: data/cats
4 | type: cats
5 | imsize: 64
6 | white_bkgd: False
7 | radius: 10
8 | near: 7.5
9 | far: 12.5
10 | fov: 10
11 | umin: 0
12 | umax: 0.19444444444444445 #70 deg
13 | vmin: 0.32898992833716556 # 70 deg
14 | vmax: 0.45642212862617093 # 85 deg
15 |
16 |
--------------------------------------------------------------------------------
/configs/celebA.yaml:
--------------------------------------------------------------------------------
1 | expname: celebA_64
2 | data:
3 | datadir: /PATH/TO/CELEBA/Img/img_align_celebA
4 | type: celebA
5 | imsize: 64
6 | white_bkgd: False
7 | radius: 9.5,10.5
8 | near: 7.5
9 | far: 12.5
10 | fov: 10.
11 | umin: 0
12 | umax: 0.25
13 | vmin: 0.32898992833716556 # 70 deg
14 | vmax: 0.45642212862617093 # 85 deg
15 |
16 |
--------------------------------------------------------------------------------
/configs/celebAHQ.yaml:
--------------------------------------------------------------------------------
1 | expname: celebAHQ_256
2 | data:
3 | datadir: /PATH/TO/CELEBA_HQ
4 | type: celebA_hq
5 | imsize: 256
6 | white_bkgd: False
7 | radius: 9.5,10.5
8 | near: 7.5
9 | far: 12.5
10 | fov: 10
11 | umin: 0
12 | umax: 0.25
13 | vmin: 0.32898992833716556 # 70 deg
14 | vmax: 0.45642212862617093 # 85 deg
15 | ray_sampler:
16 | min_scale: 0.125
17 | scale_anneal: 0.0019
18 | training:
19 | fid_every: 10000
20 |
--------------------------------------------------------------------------------
/configs/cub.yaml:
--------------------------------------------------------------------------------
1 | expname: cub_64
2 | data:
3 | imsize: 64
4 | datadir: data/cub
5 | type: cub
6 | radius: 9,11
7 | near: 7.5
8 | far: 12.5
9 | fov: 30
10 | vmin: 0.24999999999999994 # 60 deg
11 | vmax: 0.5435778713738291 # 95 deg
12 | discriminator:
13 | hflip: True
14 | nerf:
15 | use_viewdirs: False
--------------------------------------------------------------------------------
/configs/debug.yaml:
--------------------------------------------------------------------------------
1 | expname: debug
2 | data:
3 | imsize: 64
4 | datadir: data/cub
5 | type: cub
6 | radius: 10.
7 | near: 7.5
8 | far: 12.5
9 | fov: 30.0
10 | training:
11 | batch_size: 2
12 | nworkers: 0
13 | fid_every: -1
14 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | expname: default
2 | data:
3 | datadir: data/carla
4 | type: carla
5 | imsize: 64
6 | white_bkgd: True
7 | near: 1.
8 | far: 6.
9 | radius: 3.4 # set according to near and far plane
10 | fov: 90.
11 | orthographic: False
12 | umin: 0. # 0 deg, convert to degree via 360. * u
13 | umax: 1. # 360 deg, convert to degree via 360. * u
14 | vmin: 0. # 0 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi
15 | vmax: 0.45642212862617093 # 85 deg, convert to degrees via arccos(1 - 2 * v) * 180. / pi
16 | nerf:
17 | i_embed: 0
18 | use_viewdirs: True
19 | multires: 10
20 | multires_views: 4
21 | N_samples: 64
22 | N_importance: 0
23 | netdepth: 8
24 | netwidth: 256
25 | netdepth_fine: 8
26 | netwidth_fine: 256
27 | perturb: 1.
28 | raw_noise_std: 1.
29 | decrease_noise: True
30 | z_dist:
31 | type: gauss
32 | dim: 256
33 | dim_appearance: 128 # This dimension is subtracted from "dim"
34 | ray_sampler:
35 | min_scale: 0.25
36 | max_scale: 1.
37 | scale_anneal: 0.0025 # no effect if scale_anneal<0, else the minimum scale decreases exponentially until converge to min_scale
38 | N_samples: 1024 # 32*32, patchsize
39 | discriminator:
40 | ndf: 64
41 | hflip: False # Randomly flip discriminator input horizontally
42 | training:
43 | outdir: ./results
44 | model_file: model.pt
45 | monitoring: tensorboard
46 | use_amp: False # Use automated mixed precision
47 | nworkers: 6
48 | batch_size: 8
49 | chunk: 32768 # 1024*32
50 | netchunk: 65536 # 1024*64
51 | lr_g: 0.0005
52 | lr_d: 0.0001
53 | lr_anneal: 0.5
54 | lr_anneal_every: 50000,100000,200000
55 | equalize_lr: False
56 | gan_type: standard
57 | reg_type: real
58 | reg_param: 10.
59 | optimizer: rmsprop
60 | n_test_samples_with_same_shape_code: 4
61 | take_model_average: true
62 | model_average_beta: 0.999
63 | model_average_reinit: false
64 | restart_every: -1
65 | save_best: fid
66 | fid_every: 5000 # Valid for FID and KID
67 | print_every: 10
68 | sample_every: 500
69 | save_every: 900
70 | backup_every: 50000
71 | video_every: 10000
72 |
--------------------------------------------------------------------------------
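The `umin/umax/vmin/vmax` entries above are uniform variables that the comments map to camera angles. A minimal sketch of this conversion, consistent with the comments and with the `to_phi`/`to_theta` helpers imported in `graf/config.py` (redefined locally here for illustration):

```
import numpy as np

def to_phi(u):    # azimuth (rotation) in degrees
    return 360. * u

def to_theta(v):  # polar angle (elevation) in degrees
    return np.arccos(1 - 2 * v) * 180. / np.pi

print(to_phi(1.0))                     # 360.0 deg
print(to_theta(0.45642212862617093))   # ~85.0 deg
```

--------------------------------------------------------------------------------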
/configs/pretrained_models.yaml:
--------------------------------------------------------------------------------
1 | carla:
2 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_64.pt
3 | 128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_128.pt
4 | 256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_256.pt
5 | 512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/carla/carla_512.pt
6 | celebA:
7 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_64.pt
8 | 128: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_128.pt
9 | celebA_hq:
10 | 256: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_256.pt
11 | 512: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/faces/celebA_hq_512.pt
12 | cats:
13 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/cats/cats_64.pt
14 | cub:
15 | 64: https://s3.eu-central-1.amazonaws.com/avg-projects/graf/models/birds/cub_64.pt
--------------------------------------------------------------------------------
/data/download_carla.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./carla
2 | cd ./carla
3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/graf/data/carla.zip
4 | unzip carla.zip
5 | cd ..
6 |
--------------------------------------------------------------------------------
/data/download_carla_poses.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./carla
2 | cd ./carla
3 | wget https://s3.eu-central-1.amazonaws.com/avg-projects/graf/data/carla_poses.zip
4 | unzip carla_poses.zip
5 | cd ..
--------------------------------------------------------------------------------
/data/preprocess_cats.py:
--------------------------------------------------------------------------------
1 | ### adapted from https://github.com/AlexiaJM/RelativisticGAN/blob/master/code/preprocess_cat_dataset.py
2 | ### original code from https://github.com/microe/angora-blue/blob/master/cascade_training/describe.py by Erik Hovland
3 | import argparse
4 | import cv2
5 | import glob
6 | import math
7 | import os
8 | from tqdm import tqdm
9 |
10 |
11 | def rotateCoords(coords, center, angleRadians):
12 | # Positive y is down so reverse the angle, too.
13 | angleRadians = -angleRadians
14 | xs, ys = coords[::2], coords[1::2]
15 | newCoords = []
16 | n = min(len(xs), len(ys))
17 | i = 0
18 | centerX = center[0]
19 | centerY = center[1]
20 | cosAngle = math.cos(angleRadians)
21 | sinAngle = math.sin(angleRadians)
22 | while i < n:
23 | xOffset = xs[i] - centerX
24 | yOffset = ys[i] - centerY
25 | newX = xOffset * cosAngle - yOffset * sinAngle + centerX
26 | newY = xOffset * sinAngle + yOffset * cosAngle + centerY
27 | newCoords += [newX, newY]
28 | i += 1
29 | return newCoords
30 |
31 |
32 | def preprocessCatFace(coords, image):
33 | leftEyeX, leftEyeY = coords[0], coords[1]
34 | rightEyeX, rightEyeY = coords[2], coords[3]
35 | mouthX = coords[4]
36 | if leftEyeX > rightEyeX and leftEyeY < rightEyeY and \
37 | mouthX > rightEyeX:
38 | # The "right eye" is in the second quadrant of the face,
39 | # while the "left eye" is in the fourth quadrant (from the
40 | # viewer's perspective.) Swap the eyes' labels in order to
41 | # simplify the rotation logic.
42 | leftEyeX, rightEyeX = rightEyeX, leftEyeX
43 | leftEyeY, rightEyeY = rightEyeY, leftEyeY
44 |
45 | eyesCenter = (0.5 * (leftEyeX + rightEyeX),
46 | 0.5 * (leftEyeY + rightEyeY))
47 |
48 | eyesDeltaX = rightEyeX - leftEyeX
49 | eyesDeltaY = rightEyeY - leftEyeY
50 | eyesAngleRadians = math.atan2(eyesDeltaY, eyesDeltaX)
51 | eyesAngleDegrees = eyesAngleRadians * 180.0 / math.pi
52 |
53 | # Straighten the image and fill in gray for blank borders.
54 | rotation = cv2.getRotationMatrix2D(
55 | eyesCenter, eyesAngleDegrees, 1.0)
56 | imageSize = image.shape[1::-1]
57 | straight = cv2.warpAffine(image, rotation, imageSize,
58 | borderValue=(128, 128, 128))
59 |
60 | # Straighten the coordinates of the features.
61 | newCoords = rotateCoords(
62 | coords, eyesCenter, eyesAngleRadians)
63 |
64 | # Make the face as wide as the space between the ear bases.
65 | w = abs(newCoords[16] - newCoords[6])
66 | # Make the face square.
67 | h = w
68 | # Put the center point between the eyes at (0.5, 0.4) in
69 | # proportion to the entire face.
70 | minX = eyesCenter[0] - w / 2
71 | if minX < 0:
72 | w += minX
73 | minX = 0
74 | minY = eyesCenter[1] - h * 2 / 5
75 | if minY < 0:
76 | h += minY
77 | minY = 0
78 |
79 | # Crop the face.
80 | crop = straight[int(minY):int(minY + h), int(minX):int(minX + w)]
81 | # Return the crop.
82 | return crop
83 |
84 |
85 | def describePositive(root, outdir):
86 | filenames = glob.glob('%s/CAT_*/*.jpg' % root)
87 |
88 | for imagePath in tqdm(filenames, total=len(filenames), desc='Process images...'):
89 | # Open the '.cat' annotation file associated with this
90 | # image.
91 | if not os.path.isfile('%s.cat' % imagePath):
92 | print('.cat file missing at %s' % imagePath)
93 | continue
94 | input = open('%s.cat' % imagePath, 'r')
95 | # Read the coordinates of the cat features from the
96 | # file. Discard the first number, which is the number
97 | # of features.
98 | coords = [int(i) for i in input.readline().split()[1:]]
99 | # Read the image.
100 | image = cv2.imread(imagePath)
101 | # Straighten and crop the cat face.
102 | crop = preprocessCatFace(coords, image)
103 | if crop is None:
104 | print('Failed to preprocess image at %s' % imagePath)
105 | continue
106 | # Save the crop to folders based on size
107 | h, w, colors = crop.shape
108 | if min(h, w) >= 64:
109 | Path1 = imagePath.replace(root, outdir)
110 | os.makedirs(os.path.dirname(Path1), exist_ok=True)
111 | resized_crop = cv2.resize(crop, (64, 64))
112 | cv2.imwrite(Path1, resized_crop)
113 |
114 |
115 | if __name__ == '__main__':
116 | # Arguments
117 | parser = argparse.ArgumentParser(
118 | description='Crop cats from the CatDataset.'
119 | )
120 | parser.add_argument('root', type=str, help='Path to data directory containing "CAT_00" - "CAT_06" folders.')
121 | args = parser.parse_args()
122 |
123 | outdir = './cats'
124 | os.makedirs(outdir, exist_ok=True)
125 |
126 | describePositive(args.root, outdir)
127 | print('Preprocessed {} images.'.format(len(glob.glob(os.path.join(outdir, '*/*.jpg')))))
--------------------------------------------------------------------------------
/data/preprocess_cub.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | from PIL import Image
5 | import numpy as np
6 | import glob
7 | from torchvision.transforms import CenterCrop
8 |
9 |
10 | if __name__ == '__main__':
11 | # Arguments
12 | parser = argparse.ArgumentParser(
13 | description='Select split from CUB200-2011 and crop birds from images.'
14 | )
15 | parser.add_argument('root', type=str,
16 | help='Path to data directory containing bounding box file and "images" folder.')
17 | parser.add_argument('maskdir', type=str, help='Path to data directory containing the segmentation masks.')
18 | args = parser.parse_args()
19 |
20 | imdir = os.path.join(args.root, 'images')
21 | bboxfile = os.path.join(args.root, 'bounding_boxes.txt')
22 | maskdir = args.maskdir
23 | namefile = './cub/filtered_files.txt'
24 | outdir = './cub'
25 | os.makedirs(outdir, exist_ok=True)
26 |
27 | # load files
28 | with open(namefile, 'r') as f:
29 | id_filename = [line.split(' ') for line in f.read().splitlines()]
30 |
31 | # load bounding boxes
32 | boxes = {}
33 | with open(bboxfile, 'r') as f:
34 | for line in f.read().splitlines():
35 | k, x, y, w, h = line.split(' ')
36 | box = float(x), float(y), float(x) + float(w), float(y) + float(h) # (left, up, right, down)
37 | boxes[k] = box
38 |
39 | for i, (id, filename) in tqdm(enumerate(id_filename), total=len(id_filename)):
40 | path = os.path.join(imdir, filename)
41 | img = Image.open(path).convert('RGBA')
42 |
43 | # load alpha
44 | path = os.path.join(maskdir, filename.replace('.jpg', '.png'))
45 | alpha = Image.open(path)
46 | if alpha.mode == 'RGBA':
47 | alpha = alpha.split()[-1]
48 | alpha = alpha.convert('L')
49 | img.putalpha(alpha)
50 |
51 | # crop square images using bbox
52 | img = img.crop(boxes[id])
53 | s = max(img.size)
54 | img = CenterCrop(s)(img) # CenterCrop pads image to square using zeros (also for alpha)
55 |
56 | # composite
57 | img = np.array(img)
58 | alpha = (img[..., 3:4]) > 127 # convert to binary mask
59 | bg = np.array(255 * (1. - alpha), np.uint8)
60 | img = img[..., :3] * alpha + bg
61 | img = Image.fromarray(img)
62 |
63 | img.save(os.path.join(outdir, '%06d.png' % i))
64 |
65 | print('Preprocessed {} images.'.format(len(glob.glob(os.path.join(outdir, '*.png')))))
66 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: graf
2 | channels:
3 | - conda-forge
4 | - anaconda
5 | - defaults
6 | dependencies:
7 | - blas=1.0=mkl
8 | - ca-certificates=2020.11.8=ha878542_0
9 | - certifi=2020.11.8=py38h578d9bd_0
10 | - intel-openmp=2020.2=254
11 | - joblib=0.17.0=py_0
12 | - ld_impl_linux-64=2.33.1=h53a641e_7
13 | - libedit=3.1.20191231=h14c3975_1
14 | - libffi=3.3=he6710b0_2
15 | - libgcc-ng=9.1.0=hdf63c60_0
16 | - libgfortran-ng=7.3.0=hdf63c60_0
17 | - libprotobuf=3.13.0.1=h8b12597_0
18 | - libstdcxx-ng=9.1.0=hdf63c60_0
19 | - mkl=2019.4=243
20 | - mkl-service=2.3.0=py38he904b0f_0
21 | - mkl_fft=1.2.0=py38h23d657b_0
22 | - mkl_random=1.1.0=py38h962f231_0
23 | - ncurses=6.2=he6710b0_1
24 | - numpy-base=1.19.1=py38hfa32c7d_0
25 | - openssl=1.1.1h=h516909a_0
26 | - pip=20.2.4=py38_0
27 | - python=3.8.5=h7579374_1
28 | - python_abi=3.8=1_cp38
29 | - pyyaml=5.3.1=py38h7b6447c_1
30 | - readline=8.0=h7b6447c_0
31 | - scikit-learn=0.23.2=py38h0573a6f_0
32 | - scipy=1.5.2=py38h0b6359f_0
33 | - setuptools=50.3.0=py38hb0f4dca_1
34 | - six=1.15.0=py_0
35 | - sqlite=3.33.0=h62c20be_0
36 | - tensorboardx=2.1=py_0
37 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
38 | - tk=8.6.10=hbc83047_0
39 | - wheel=0.35.1=py_0
40 | - xz=5.2.5=h7b6447c_0
41 | - yaml=0.2.5=h7b6447c_0
42 | - zlib=1.2.11=h7b6447c_3
43 | - pip:
44 | - absl-py==0.11.0
45 | - configargparse==1.2.3
46 | - cycler==0.10.0
47 | - dataclasses==0.6
48 | - future==0.18.2
49 | - grpcio==1.33.2
50 | - imageio==2.9.0
51 | - imageio-ffmpeg==0.4.2
52 | - kiwisolver==1.3.1
53 | - markdown==3.3.3
54 | - matplotlib==3.3.3
55 | - numpy==1.19.4
56 | - opencv-python==4.4.0.46
57 | - pillow==8.0.1
58 | - protobuf==3.14.0
59 | - pyparsing==2.4.7
60 | - python-dateutil==2.8.1
61 | - tensorboard==1.14.0
62 | - torch==1.7.0
63 | - torchvision==0.8.1
64 | - tqdm==4.53.0
65 | - typing-extensions==3.7.4.3
66 | - werkzeug==1.0.1
67 |
--------------------------------------------------------------------------------
/external/colmap/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/external/colmap/__init__.py
--------------------------------------------------------------------------------
/external/colmap/filter_points.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import struct
5 | import numpy as np
6 |
7 | def readBinaryPly(pcdFile, fmt='ffffffBBB', fmt_len=27):
8 |
9 | with open(pcdFile, 'rb') as f:
10 | plyData = f.readlines()
11 |
12 | headLine = plyData.index(b'end_header\n')+1
13 | plyData = plyData[headLine:]
14 | plyData = b"".join(plyData)
15 |
16 | n_pts_loaded = int(len(plyData)/fmt_len)
17 |
18 | data = []
19 | for i in range(n_pts_loaded):
20 | pts=struct.unpack(fmt, plyData[i*fmt_len:(i+1)*fmt_len])
21 | data.append(pts)
22 | data=np.asarray(data)
23 |
24 | return data
25 |
26 | def writeBinaryPly(pcdFile, data):
27 | fmt = '=ffffffBBB'
28 | fmt_len = 27
29 | n_pts = data.shape[0]
30 |
31 | with open(pcdFile, 'wb') as f:
32 | f.write(b'ply\n')
33 | f.write(b'format binary_little_endian 1.0\n')
34 | f.write(b'comment\n')
35 | f.write(b'element vertex %d\n' % n_pts)
36 | f.write(b'property float x\n')
37 | f.write(b'property float y\n')
38 | f.write(b'property float z\n')
39 | f.write(b'property float nx\n')
40 | f.write(b'property float ny\n')
41 | f.write(b'property float nz\n')
42 | f.write(b'property uchar red\n')
43 | f.write(b'property uchar green\n')
44 | f.write(b'property uchar blue\n')
45 | f.write(b'end_header\n')
46 |
47 | for i in range(n_pts):
48 | f.write(struct.pack(fmt, *data[i,0:6], *data[i,6:9].astype(np.uint8)))
49 |
50 |
51 | def filter_ply(object_dir):
52 |
53 | ply_files = sorted(glob.glob(os.path.join(object_dir, 'dense', '*', 'fused.ply')))
54 |
55 | for ply_file in ply_files:
56 | ply_filter_file = ply_file.replace('.ply', '_filtered.ply')
57 | plydata = readBinaryPly(ply_file)
58 | vertex = plydata[:,0:3]
59 | normal = plydata[:,3:6]
60 | color = plydata[:,6:9]
61 |
62 | mask = np.mean(color,1)<(0.85 * 255.)
63 | color = color[mask, :]
64 | normal = normal[mask, :]
65 | vertex = vertex[mask, :]
66 | plydata = np.hstack((vertex, normal, color))
67 | writeBinaryPly(ply_filter_file, plydata)
68 | print('Processed file {}'.format(ply_filter_file))
69 |
70 | if __name__=='__main__':
71 |
72 | object_dir=sys.argv[1]
73 | filter_ply(object_dir)
74 |
--------------------------------------------------------------------------------
/external/colmap/run_colmap_automatic.sh:
--------------------------------------------------------------------------------
1 | #set -e
2 |
3 | input_dir=$1
4 | output_dir=$2
5 |
6 | mkdir -p ${output_dir}
7 | echo Processing ${input_dir} ...
8 | colmap automatic_reconstructor \
9 | --workspace_path ${output_dir} \
10 | --image_path ${input_dir}/ \
11 | --single_camera=1 \
12 | --dense=1 \
13 | --gpu_index=0
14 |
--------------------------------------------------------------------------------
/graf/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.transforms import *
4 |
5 | from .datasets import *
6 | from .transforms import FlexGridRaySampler
7 | from .utils import polar_to_cartesian, look_at, to_phi, to_theta
8 |
9 |
10 | def save_config(outpath, config):
11 | from yaml import safe_dump
12 | with open(outpath, 'w') as f:
13 | safe_dump(config, f)
14 |
15 |
16 | def update_config(config, unknown):
17 | # update config given args
18 | for idx,arg in enumerate(unknown):
19 | if arg.startswith("--"):
20 | if (':') in arg:
21 | k1,k2 = arg.replace("--","").split(':')
22 | argtype = type(config[k1][k2])
23 | if argtype == bool:
24 | v = unknown[idx+1].lower() == 'true'
25 | else:
26 | if config[k1][k2] is not None:
27 | v = type(config[k1][k2])(unknown[idx+1])
28 | else:
29 | v = unknown[idx+1]
30 | print(f'Changing {k1}:{k2} ---- {config[k1][k2]} to {v}')
31 | config[k1][k2] = v
32 | else:
33 | k = arg.replace('--','')
34 | v = unknown[idx+1]
35 | argtype = type(config[k])
36 | print(f'Changing {k} ---- {config[k]} to {v}')
37 | config[k] = v
38 |
39 | return config
40 |
41 |
42 | def get_data(config):
43 | H = W = imsize = config['data']['imsize']
44 | dset_type = config['data']['type']
45 | fov = config['data']['fov']
46 |
47 | transforms = Compose([
48 | Resize(imsize),
49 | ToTensor(),
50 | Lambda(lambda x: x * 2 - 1),
51 | ])
52 |
53 | kwargs = {
54 | 'data_dirs': config['data']['datadir'],
55 | 'transforms': transforms
56 | }
57 |
58 | if dset_type == 'carla':
59 | dset = Carla(**kwargs)
60 |
61 | elif dset_type == 'celebA':
62 | assert imsize <= 128, 'cropped GT data has lower resolution than imsize, consider using celebA_hq instead'
63 | transforms.transforms.insert(0, RandomHorizontalFlip())
64 | transforms.transforms.insert(0, CenterCrop(108))
65 |
66 | dset = CelebA(**kwargs)
67 |
68 | elif dset_type == 'celebA_hq':
69 | transforms.transforms.insert(0, RandomHorizontalFlip())
70 | transforms.transforms.insert(0, CenterCrop(650))
71 |
72 | dset = CelebAHQ(**kwargs)
73 |
74 | elif dset_type == 'cats':
75 | transforms.transforms.insert(0, RandomHorizontalFlip())
76 | dset = Cats(**kwargs)
77 |
78 | elif dset_type == 'cub':
79 | dset = CUB(**kwargs)
80 |
81 | dset.H = dset.W = imsize
82 | dset.focal = W/2 * 1 / np.tan((.5 * fov * np.pi/180.))
83 | radius = config['data']['radius']
84 | render_radius = radius
85 | if isinstance(radius, str):
86 | radius = tuple(float(r) for r in radius.split(','))
87 | render_radius = max(radius)
88 | dset.radius = radius
89 |
90 | # compute render poses
91 | N = 40
92 | theta = 0.5 * (to_theta(config['data']['vmin']) + to_theta(config['data']['vmax']))
93 | angle_range = (to_phi(config['data']['umin']), to_phi(config['data']['umax']))
94 | render_poses = get_render_poses(render_radius, angle_range=angle_range, theta=theta, N=N)
95 |
96 | print('Loaded {}'.format(dset_type), imsize, len(dset), render_poses.shape, [H,W,dset.focal,dset.radius], config['data']['datadir'])
97 | return dset, [H,W,dset.focal,dset.radius], render_poses
98 |
99 |
100 | def get_render_poses(radius, angle_range=(0, 360), theta=0, N=40, swap_angles=False):
101 | poses = []
102 | theta = max(0.1, theta)
103 | for angle in np.linspace(angle_range[0],angle_range[1],N+1)[:-1]:
104 | angle = max(0.1, angle)
105 | if swap_angles:
106 | loc = polar_to_cartesian(radius, theta, angle, deg=True)
107 | else:
108 | loc = polar_to_cartesian(radius, angle, theta, deg=True)
109 | R = look_at(loc)[0]
110 | RT = np.concatenate([R, loc.reshape(3, 1)], axis=1)
111 | poses.append(RT)
112 | return torch.from_numpy(np.stack(poses))
113 |
114 |
115 | def build_models(config, disc=True):
116 | from argparse import Namespace
117 | from submodules.nerf_pytorch.run_nerf_mod import create_nerf
118 | from .models.generator import Generator
119 | from .models.discriminator import Discriminator
120 |
121 | config_nerf = Namespace(**config['nerf'])
122 | # Update config for NERF
123 | config_nerf.chunk = min(config['training']['chunk'], 1024*config['training']['batch_size']) # let batch size for training with patches limit the maximal memory
124 | config_nerf.netchunk = config['training']['netchunk']
125 | config_nerf.white_bkgd = config['data']['white_bkgd']
126 | config_nerf.feat_dim = config['z_dist']['dim']
127 | config_nerf.feat_dim_appearance = config['z_dist']['dim_appearance']
128 |
129 | render_kwargs_train, render_kwargs_test, params, named_parameters = create_nerf(config_nerf)
130 |
131 | bds_dict = {'near': config['data']['near'], 'far': config['data']['far']}
132 | render_kwargs_train.update(bds_dict)
133 | render_kwargs_test.update(bds_dict)
134 |
135 | ray_sampler = FlexGridRaySampler(N_samples=config['ray_sampler']['N_samples'],
136 | min_scale=config['ray_sampler']['min_scale'],
137 | max_scale=config['ray_sampler']['max_scale'],
138 | scale_anneal=config['ray_sampler']['scale_anneal'],
139 | orthographic=config['data']['orthographic'])
140 |
141 | H, W, f, r = config['data']['hwfr']
142 | generator = Generator(H, W, f, r,
143 | ray_sampler=ray_sampler,
144 | render_kwargs_train=render_kwargs_train, render_kwargs_test=render_kwargs_test,
145 | parameters=params, named_parameters=named_parameters,
146 | chunk=config_nerf.chunk,
147 | range_u=(float(config['data']['umin']), float(config['data']['umax'])),
148 | range_v=(float(config['data']['vmin']), float(config['data']['vmax'])),
149 | orthographic=config['data']['orthographic'],
150 | )
151 |
152 | discriminator = None
153 | if disc:
154 | disc_kwargs = {'nc': 3, # channels for patch discriminator
155 | 'ndf': config['discriminator']['ndf'],
156 | 'imsize': int(np.sqrt(config['ray_sampler']['N_samples'])),
157 | 'hflip': config['discriminator']['hflip']}
158 |
159 | discriminator = Discriminator(**disc_kwargs)
160 |
161 | return generator, discriminator
162 |
163 |
164 | def build_lr_scheduler(optimizer, config, last_epoch=-1):
165 | import torch.optim as optim
166 | step_size = config['training']['lr_anneal_every']
167 | if isinstance(step_size, str):
168 | milestones = [int(m) for m in step_size.split(',')]
169 | lr_scheduler = optim.lr_scheduler.MultiStepLR(
170 | optimizer,
171 | milestones=milestones,
172 | gamma=config['training']['lr_anneal'],
173 | last_epoch=last_epoch)
174 | else:
175 | lr_scheduler = optim.lr_scheduler.StepLR(
176 | optimizer,
177 | step_size=step_size,
178 | gamma=config['training']['lr_anneal'],
179 | last_epoch=last_epoch
180 | )
181 | return lr_scheduler
182 |
--------------------------------------------------------------------------------
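A hypothetical usage sketch for `update_config` above: leftover command-line arguments (e.g. the `unknown` list returned by `argparse`'s `parse_known_args`) override nested config entries, with values cast to the type of the existing entry. The dictionary below is made up for illustration:

```
config = {'expname': 'default', 'data': {'imsize': 64}}
unknown = ['--data:imsize', '128', '--expname', 'carla_128']
config = update_config(config, unknown)
print(config['data']['imsize'])   # 128 (cast to int via the existing value's type)
print(config['expname'])          # 'carla_128'
```

--------------------------------------------------------------------------------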
/graf/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | from PIL import Image
4 |
5 | from torchvision.datasets.vision import VisionDataset
6 |
7 |
8 | class ImageDataset(VisionDataset):
9 | """
10 | Load images from multiple data directories.
11 | Folder structure: data_dir/filename.png
12 | """
13 |
14 | def __init__(self, data_dirs, transforms=None):
15 | # Use multiple root folders
16 | if not isinstance(data_dirs, list):
17 | data_dirs = [data_dirs]
18 |
19 | # initialize base class
20 | VisionDataset.__init__(self, root=data_dirs, transform=transforms)
21 |
22 | self.filenames = []
23 | root = []
24 |
25 | for ddir in self.root:
26 | filenames = self._get_files(ddir)
27 | self.filenames.extend(filenames)
28 | root.append(ddir)
29 |
30 | def __len__(self):
31 | return len(self.filenames)
32 |
33 | @staticmethod
34 | def _get_files(root_dir):
35 | return glob.glob(f'{root_dir}/*.png') + glob.glob(f'{root_dir}/*.jpg')
36 |
37 | def __getitem__(self, idx):
38 | filename = self.filenames[idx]
39 | img = Image.open(filename).convert('RGB')
40 | if self.transform is not None:
41 | img = self.transform(img)
42 | return img
43 |
44 |
45 | class Carla(ImageDataset):
46 | def __init__(self, *args, **kwargs):
47 | super(Carla, self).__init__(*args, **kwargs)
48 |
49 |
50 | class CelebA(ImageDataset):
51 | def __init__(self, *args, **kwargs):
52 | super(CelebA, self).__init__(*args, **kwargs)
53 |
54 |
55 | class CUB(ImageDataset):
56 | def __init__(self, *args, **kwargs):
57 | super(CUB, self).__init__(*args, **kwargs)
58 |
59 |
60 | class Cats(ImageDataset):
61 | def __init__(self, *args, **kwargs):
62 | super(Cats, self).__init__(*args, **kwargs)
63 |
64 | @staticmethod
65 | def _get_files(root_dir):
66 | return glob.glob(f'{root_dir}/CAT_*/*.jpg')
67 |
68 |
69 | class CelebAHQ(ImageDataset):
70 | def __init__(self, *args, **kwargs):
71 | super(CelebAHQ, self).__init__(*args, **kwargs)
72 |
73 | def _get_files(self, root):
74 | return glob.glob(f'{root}/*.npy')
75 |
76 | def __getitem__(self, idx):
77 | img = np.load(self.filenames[idx]).squeeze(0).transpose(1,2,0)
78 | if img.dtype == np.uint8:
79 | pass
80 | elif img.dtype == np.float32:
81 | img = (img * 255).astype(np.uint8)
82 | else:
83 | raise NotImplementedError
84 | img = Image.fromarray(img).convert('RGB')
85 | if self.transform is not None:
86 | img = self.transform(img)
87 |
88 | return img
89 |
--------------------------------------------------------------------------------
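A minimal standalone sketch of loading one of the datasets above, mirroring the transform pipeline that `get_data` in `graf/config.py` builds; it assumes the Cars images have already been downloaded to `data/carla`:

```
from torchvision.transforms import Compose, Resize, ToTensor, Lambda
from graf.datasets import Carla

transforms = Compose([
    Resize(64),
    ToTensor(),
    Lambda(lambda x: x * 2 - 1),   # map images from [0, 1] to [-1, 1]
])
dset = Carla('data/carla', transforms=transforms)
print(len(dset))                   # number of images found in data/carla
img = dset[0]                      # 3 x 64 x 64 tensor in [-1, 1]
```

--------------------------------------------------------------------------------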
/graf/gan_training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from tqdm import tqdm
5 |
6 | from submodules.GAN_stability.gan_training.train import toggle_grad, Trainer as TrainerBase
7 | from submodules.GAN_stability.gan_training.eval import Evaluator as EvaluatorBase
8 | from submodules.GAN_stability.gan_training.metrics import FIDEvaluator, KIDEvaluator
9 |
10 | from .utils import save_video, color_depth_map
11 |
12 |
13 | class Trainer(TrainerBase):
14 | def __init__(self, *args, use_amp=False, **kwargs):
15 | super(Trainer, self).__init__(*args, **kwargs)
16 | self.use_amp = use_amp
17 | if self.use_amp:
18 | self.scaler = torch.cuda.amp.GradScaler()
19 |
20 | def generator_trainstep(self, y, z):
21 | if not self.use_amp:
22 | return super(Trainer, self).generator_trainstep(y, z)
23 | assert (y.size(0) == z.size(0))
24 | toggle_grad(self.generator, True)
25 | toggle_grad(self.discriminator, False)
26 | self.generator.train()
27 | self.discriminator.train()
28 | self.g_optimizer.zero_grad()
29 |
30 | with torch.cuda.amp.autocast():
31 | x_fake = self.generator(z, y)
32 | d_fake = self.discriminator(x_fake, y)
33 | gloss = self.compute_loss(d_fake, 1)
34 | self.scaler.scale(gloss).backward()
35 |
36 | self.scaler.step(self.g_optimizer)
37 | self.scaler.update()
38 |
39 | return gloss.item()
40 |
41 | def discriminator_trainstep(self, x_real, y, z):
42 | return super(Trainer, self).discriminator_trainstep(x_real, y, z)  # spectral norm raises an error when using amp
43 |
44 |
45 | class Evaluator(EvaluatorBase):
46 | def __init__(self, eval_fid_kid, *args, **kwargs):
47 | super(Evaluator, self).__init__(*args, **kwargs)
48 | if eval_fid_kid:
49 | self.inception_eval = FIDEvaluator(
50 | device=self.device,
51 | batch_size=self.batch_size,
52 | resize=True,
53 | n_samples=20000,
54 | n_samples_fake=1000,
55 | )
56 |
57 | def get_rays(self, pose):
58 | return self.generator.val_ray_sampler(self.generator.H, self.generator.W,
59 | self.generator.focal, pose)[0]
60 |
61 | def create_samples(self, z, poses=None):
62 | self.generator.eval()
63 |
64 | N_samples = len(z)
65 | device = self.generator.device
66 | z = z.to(device).split(self.batch_size)
67 | if poses is None:
68 | rays = [None] * len(z)
69 | else:
70 | rays = torch.stack([self.get_rays(poses[i].to(device)) for i in range(N_samples)])
71 | rays = rays.split(self.batch_size)
72 |
73 | rgb, disp, acc = [], [], []
74 | with torch.no_grad():
75 | for z_i, rays_i in tqdm(zip(z, rays), total=len(z), desc='Create samples...'):
76 | bs = len(z_i)
77 | if rays_i is not None:
78 | rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2) # Bx2x(HxW)xC -> 2x(BxHxW)x3
79 | rgb_i, disp_i, acc_i, _ = self.generator(z_i, rays=rays_i)
80 |
81 | reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2) # (NxHxW)xC -> NxCxHxW
82 | rgb.append(reshape(rgb_i).cpu())
83 | disp.append(reshape(disp_i).cpu())
84 | acc.append(reshape(acc_i).cpu())
85 |
86 | rgb = torch.cat(rgb)
87 | disp = torch.cat(disp)
88 | acc = torch.cat(acc)
89 |
90 | depth = self.disp_to_cdepth(disp)
91 |
92 | return rgb, depth, acc
93 |
94 | def make_video(self, basename, z, poses, as_gif=True):
95 | """ Generate images and save them as video.
96 | z (N_samples, zdim): latent codes
97 | poses (N_frames, 3 x 4): camera poses for all frames of video
98 | """
99 | N_samples, N_frames = len(z), len(poses)
100 |
101 | # reshape inputs
102 | z = z.unsqueeze(1).expand(-1, N_frames, -1).flatten(0, 1) # (N_samples x N_frames) x z_dim
103 | poses = poses.unsqueeze(0) \
104 | .expand(N_samples, -1, -1, -1).flatten(0, 1) # (N_samples x N_frames) x 3 x 4
105 |
106 | rgbs, depths, accs = self.create_samples(z, poses=poses)
107 |
108 | reshape = lambda x: x.view(N_samples, N_frames, *x.shape[1:])
109 | rgbs = reshape(rgbs)
110 | depths = reshape(depths)
111 | print('Done, saving', rgbs.shape)
112 |
113 | fps = min(int(N_frames / 2.), 25) # aim for at least 2 second video
114 | for i in range(N_samples):
115 | save_video(rgbs[i], basename + '{:04d}_rgb.mp4'.format(i), as_gif=as_gif, fps=fps)
116 | save_video(depths[i], basename + '{:04d}_depth.mp4'.format(i), as_gif=as_gif, fps=fps)
117 |
118 | def disp_to_cdepth(self, disps):
119 | """Convert depth to color values"""
120 | if (disps == 2e10).all(): # no values predicted
121 | return torch.ones_like(disps)
122 |
123 | near, far = self.generator.render_kwargs_test['near'], self.generator.render_kwargs_test['far']
124 |
125 | disps = disps / 2 + 0.5 # [-1, 1] -> [0, 1]
126 |
127 | depth = 1. / torch.max(1e-10 * torch.ones_like(disps), disps) # disparity -> depth
128 | depth[disps == 1e10] = far # set undefined values to far plane
129 |
130 | # scale between near, far plane for better visualization
131 | depth = (depth - near) / (far - near)
132 |
133 | depth = np.stack([color_depth_map(d) for d in depth[:, 0].detach().cpu().numpy()]) # convert to color
134 | depth = (torch.from_numpy(depth).permute(0, 3, 1, 2) / 255.) * 2 - 1 # [0, 255] -> [-1, 1]
135 |
136 | return depth
137 |
138 | def compute_fid_kid(self, sample_generator=None):
139 | if sample_generator is None:
140 | def sample():
141 | while True:
142 | z = self.zdist.sample((self.batch_size,))
143 | rgb, _, _ = self.create_samples(z)
144 | # convert to uint8 and back to get correct binning
145 | rgb = (rgb / 2 + 0.5).mul_(255).clamp_(0, 255).to(torch.uint8).to(torch.float) / 255. * 2 - 1
146 | yield rgb.cpu()
147 |
148 | sample_generator = sample()
149 |
150 | fid, (kids, vars) = self.inception_eval.get_fid_kid(sample_generator)
151 | kid = np.mean(kids)
152 | return fid, kid
153 |
--------------------------------------------------------------------------------
/graf/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Discriminator(nn.Module):
6 | def __init__(self, nc=3, ndf=64, imsize=64, hflip=False):
7 | super(Discriminator, self).__init__()
8 | self.nc = nc
9 | assert(imsize==32 or imsize==64 or imsize==128)
10 | self.imsize = imsize
11 | self.hflip = hflip
12 |
13 | SN = torch.nn.utils.spectral_norm
14 | IN = lambda x : nn.InstanceNorm2d(x)
15 |
16 | blocks = []
17 | if self.imsize==128:
18 | blocks += [
19 | # input is (nc) x 128 x 128
20 | SN(nn.Conv2d(nc, ndf//2, 4, 2, 1, bias=False)),
21 | nn.LeakyReLU(0.2, inplace=True),
22 | # input is (ndf//2) x 64 x 64
23 | SN(nn.Conv2d(ndf//2, ndf, 4, 2, 1, bias=False)),
24 | IN(ndf),
25 | nn.LeakyReLU(0.2, inplace=True),
26 | # state size. (ndf) x 32 x 32
27 | SN(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
28 | #nn.BatchNorm2d(ndf * 2),
29 | IN(ndf * 2),
30 | nn.LeakyReLU(0.2, inplace=True),
31 | ]
32 | elif self.imsize==64:
33 | blocks += [
34 | # input is (nc) x 64 x 64
35 | SN(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)),
36 | nn.LeakyReLU(0.2, inplace=True),
37 | # state size. (ndf) x 32 x 32
38 | SN(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
39 | #nn.BatchNorm2d(ndf * 2),
40 | IN(ndf * 2),
41 | nn.LeakyReLU(0.2, inplace=True),
42 | ]
43 | else:
44 | blocks += [
45 | # input is (nc) x 32 x 32
46 | SN(nn.Conv2d(nc, ndf * 2, 4, 2, 1, bias=False)),
47 | #nn.BatchNorm2d(ndf * 2),
48 | IN(ndf * 2),
49 | nn.LeakyReLU(0.2, inplace=True),
50 | ]
51 |
52 | blocks += [
53 | # state size. (ndf*2) x 16 x 16
54 | SN(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
55 | #nn.BatchNorm2d(ndf * 4),
56 | IN(ndf * 4),
57 | nn.LeakyReLU(0.2, inplace=True),
58 | # state size. (ndf*4) x 8 x 8
59 | SN(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
60 | #nn.BatchNorm2d(ndf * 8),
61 | IN(ndf * 8),
62 | nn.LeakyReLU(0.2, inplace=True),
63 | # state size. (ndf*8) x 4 x 4
64 | SN(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)),
65 | # nn.Sigmoid()
66 | ]
67 | blocks = [x for x in blocks if x]
68 | self.main = nn.Sequential(*blocks)
69 |
70 | def forward(self, input, y=None):
71 | input = input[:, :self.nc]
72 | input = input.view(-1, self.imsize, self.imsize, self.nc).permute(0, 3, 1, 2) # (BxN_samples)xC -> BxCxHxW
73 |
74 | if self.hflip: # Randomly flip input horizontally
75 | input_flipped = input.flip(3)
76 | mask = torch.randint(0, 2, (len(input),1, 1, 1)).bool().expand(-1, *input.shape[1:])
77 | input = torch.where(mask, input, input_flipped)
78 |
79 | return self.main(input)
80 |
81 |
82 |
--------------------------------------------------------------------------------
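For reference, `build_models` in `graf/config.py` instantiates this patch discriminator with `imsize = sqrt(ray_sampler:N_samples)`, i.e. 32 for the default 1024-ray patches. A small illustrative sketch with random input (not repository code):

```
import torch
from graf.models.discriminator import Discriminator

disc = Discriminator(nc=3, ndf=64, imsize=32, hflip=False)
patches = torch.randn(8 * 32 * 32, 3)   # (B x N_samples) x C, the flattened layout the generator produces
logits = disc(patches)                  # internally reshaped to B x C x 32 x 32
print(logits.shape)                     # torch.Size([8, 1, 1, 1])
```

--------------------------------------------------------------------------------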
/graf/models/generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from ..utils import sample_on_sphere, look_at, to_sphere
4 | from ..transforms import FullRaySampler
5 | from submodules.nerf_pytorch.run_nerf_mod import render, run_network # import conditional render
6 | from functools import partial
7 |
8 |
9 | class Generator(object):
10 | def __init__(self, H, W, focal, radius, ray_sampler, render_kwargs_train, render_kwargs_test, parameters, named_parameters,
11 | range_u=(0,1), range_v=(0.01,0.49), chunk=None, device='cuda', orthographic=False):
12 | self.device = device
13 | self.H = int(H)
14 | self.W = int(W)
15 | self.focal = focal
16 | self.radius = radius
17 | self.range_u = range_u
18 | self.range_v = range_v
19 | self.chunk = chunk
20 | coords = torch.from_numpy(np.stack(np.meshgrid(np.arange(H), np.arange(W), indexing='ij'), -1))
21 | self.coords = coords.view(-1, 2)
22 |
23 | self.ray_sampler = ray_sampler
24 | self.val_ray_sampler = FullRaySampler(orthographic=orthographic)
25 | self.render_kwargs_train = render_kwargs_train
26 | self.render_kwargs_test = render_kwargs_test
27 | self.initial_raw_noise_std = self.render_kwargs_train['raw_noise_std']
28 | self._parameters = parameters
29 | self._named_parameters = named_parameters
30 | self.module_dict = {'generator': self.render_kwargs_train['network_fn']}
31 | for name, module in [('generator_fine', self.render_kwargs_train['network_fine'])]:
32 | if module is not None:
33 | self.module_dict[name] = module
34 |
35 | for k, v in self.module_dict.items():
36 | if k in ['generator', 'generator_fine']:
37 | continue # parameters already included
38 | self._parameters += list(v.parameters())
39 | self._named_parameters += list(v.named_parameters())
40 |
41 | self.parameters = lambda: self._parameters # save as function to enable calling model.parameters()
42 | self.named_parameters = lambda: self._named_parameters # save as function to enable calling model.named_parameters()
43 | self.use_test_kwargs = False
44 |
45 | self.render = partial(render, H=self.H, W=self.W, focal=self.focal, chunk=self.chunk)
46 |
47 | def __call__(self, z, y=None, rays=None):
48 | bs = z.shape[0]
49 | if rays is None:
50 | rays = torch.cat([self.sample_rays() for _ in range(bs)], dim=1)
51 |
52 | render_kwargs = self.render_kwargs_test if self.use_test_kwargs else self.render_kwargs_train
53 | render_kwargs = dict(render_kwargs) # copy
54 |
55 | # in the case of a variable radius
56 | # we need to adjust near and far plane for the rays
57 | # so they stay within the bounds defined wrt. maximal radius
58 | # otherwise each camera samples within its own near/far plane (relative to this camera's radius)
59 | # instead of the absolute value (relative to maximum camera radius)
60 | if isinstance(self.radius, tuple):
61 | assert self.radius[1] - self.radius[0] <= render_kwargs['near'], 'Your smallest radius lies behind your near plane!'
62 |
63 | rays_radius = rays[0].norm(dim=-1)
64 | shift = (self.radius[1] - rays_radius).view(-1, 1).float() # reshape s.t. shape matches required shape in run_nerf
65 | render_kwargs['near'] = render_kwargs['near'] - shift
66 | render_kwargs['far'] = render_kwargs['far'] - shift
67 | assert (render_kwargs['near'] >= 0).all() and (render_kwargs['far'] >= 0).all(), \
68 | (rays_radius.min(), rays_radius.max(), shift.min(), shift.max())
69 |
70 |
71 | render_kwargs['features'] = z
72 | rgb, disp, acc, extras = render(self.H, self.W, self.focal, chunk=self.chunk, rays=rays,
73 | **render_kwargs)
74 |
75 | rays_to_output = lambda x: x.view(len(x), -1) * 2 - 1 # (BxN_samples)xC
76 |
77 | if self.use_test_kwargs: # return all outputs
78 | return rays_to_output(rgb), \
79 | rays_to_output(disp), \
80 | rays_to_output(acc), extras
81 |
82 | rgb = rays_to_output(rgb)
83 | return rgb
84 |
85 | def decrease_nerf_noise(self, it):
86 | end_it = 5000
87 | if it < end_it:
88 | noise_std = self.initial_raw_noise_std - self.initial_raw_noise_std/end_it * it
89 | self.render_kwargs_train['raw_noise_std'] = noise_std
90 |
91 | def sample_pose(self):
92 | # sample location on unit sphere
93 | loc = sample_on_sphere(self.range_u, self.range_v)
94 |
95 | # sample radius if necessary
96 | radius = self.radius
97 | if isinstance(radius, tuple):
98 | radius = np.random.uniform(*radius)
99 |
100 | loc = loc * radius
101 | R = look_at(loc)[0]
102 |
103 | RT = np.concatenate([R, loc.reshape(3, 1)], axis=1)
104 | RT = torch.Tensor(RT.astype(np.float32))
105 | return RT
106 |
107 | def sample_rays(self):
108 | pose = self.sample_pose()
109 | sampler = self.val_ray_sampler if self.use_test_kwargs else self.ray_sampler
110 | batch_rays, _, _ = sampler(self.H, self.W, self.focal, pose)
111 | return batch_rays
112 |
113 | def to(self, device):
114 | self.render_kwargs_train['network_fn'].to(device)
115 | if self.render_kwargs_train['network_fine'] is not None:
116 | self.render_kwargs_train['network_fine'].to(device)
117 | self.device = device
118 | return self
119 |
120 | def train(self):
121 | self.use_test_kwargs = False
122 | self.render_kwargs_train['network_fn'].train()
123 | if self.render_kwargs_train['network_fine'] is not None:
124 | self.render_kwargs_train['network_fine'].train()
125 |
126 | def eval(self):
127 | self.use_test_kwargs = True
128 | self.render_kwargs_train['network_fn'].eval()
129 | if self.render_kwargs_train['network_fine'] is not None:
130 | self.render_kwargs_train['network_fine'].eval()
131 |
--------------------------------------------------------------------------------
/graf/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import sqrt, exp
3 |
4 | from submodules.nerf_pytorch.run_nerf_helpers_mod import get_rays, get_rays_ortho
5 |
6 |
7 | class ImgToPatch(object):
8 | def __init__(self, ray_sampler, hwf):
9 | self.ray_sampler = ray_sampler
10 | self.hwf = hwf # camera intrinsics
11 |
12 | def __call__(self, img):
13 | rgbs = []
14 | for img_i in img:
15 | pose = torch.eye(4) # use dummy pose to infer pixel values
16 | _, selected_idcs, pixels_i = self.ray_sampler(H=self.hwf[0], W=self.hwf[1], focal=self.hwf[2], pose=pose)
17 | if selected_idcs is not None:
18 | rgbs_i = img_i.flatten(1, 2).t()[selected_idcs]
19 | else:
20 | rgbs_i = torch.nn.functional.grid_sample(img_i.unsqueeze(0),
21 | pixels_i.unsqueeze(0), mode='bilinear', align_corners=True)[0]
22 | rgbs_i = rgbs_i.flatten(1, 2).t()
23 | rgbs.append(rgbs_i)
24 |
25 | rgbs = torch.cat(rgbs, dim=0) # (B*N)x3
26 |
27 | return rgbs
28 |
29 |
30 | class RaySampler(object):
31 | def __init__(self, N_samples, orthographic=False):
32 | super(RaySampler, self).__init__()
33 | self.N_samples = N_samples
34 | self.scale = torch.ones(1,).float()
35 | self.return_indices = True
36 | self.orthographic = orthographic
37 |
38 | def __call__(self, H, W, focal, pose):
39 | if self.orthographic:
40 | size_h, size_w = focal # Hacky
41 | rays_o, rays_d = get_rays_ortho(H, W, pose, size_h, size_w)
42 | else:
43 | rays_o, rays_d = get_rays(H, W, focal, pose)
44 |
45 | select_inds = self.sample_rays(H, W)
46 |
47 | if self.return_indices:
48 | rays_o = rays_o.view(-1, 3)[select_inds]
49 | rays_d = rays_d.view(-1, 3)[select_inds]
50 |
51 | h = (select_inds // W) / float(H) - 0.5
52 | w = (select_inds % W) / float(W) - 0.5
53 |
54 | hw = torch.stack([h,w]).t()
55 |
56 | else:
57 | rays_o = torch.nn.functional.grid_sample(rays_o.permute(2,0,1).unsqueeze(0),
58 | select_inds.unsqueeze(0), mode='bilinear', align_corners=True)[0]
59 | rays_d = torch.nn.functional.grid_sample(rays_d.permute(2,0,1).unsqueeze(0),
60 | select_inds.unsqueeze(0), mode='bilinear', align_corners=True)[0]
61 | rays_o = rays_o.permute(1,2,0).view(-1, 3)
62 | rays_d = rays_d.permute(1,2,0).view(-1, 3)
63 |
64 | hw = select_inds
65 | select_inds = None
66 |
67 | return torch.stack([rays_o, rays_d]), select_inds, hw
68 |
69 | def sample_rays(self, H, W):
70 | raise NotImplementedError
71 |
72 |
73 | class FullRaySampler(RaySampler):
74 | def __init__(self, **kwargs):
75 | super(FullRaySampler, self).__init__(N_samples=None, **kwargs)
76 |
77 | def sample_rays(self, H, W):
78 | return torch.arange(0, H*W)
79 |
80 |
81 | class FlexGridRaySampler(RaySampler):
82 | def __init__(self, N_samples, random_shift=True, random_scale=True, min_scale=0.25, max_scale=1., scale_anneal=-1,
83 | **kwargs):
84 | self.N_samples_sqrt = int(sqrt(N_samples))
85 | super(FlexGridRaySampler, self).__init__(self.N_samples_sqrt**2, **kwargs)
86 |
87 | self.random_shift = random_shift
88 | self.random_scale = random_scale
89 |
90 | self.min_scale = min_scale
91 | self.max_scale = max_scale
92 |
93 | # nn.functional.grid_sample grid value range in [-1,1]
94 | self.w, self.h = torch.meshgrid([torch.linspace(-1,1,self.N_samples_sqrt),
95 | torch.linspace(-1,1,self.N_samples_sqrt)])
96 | self.h = self.h.unsqueeze(2)
97 | self.w = self.w.unsqueeze(2)
98 |
99 | # directly return grid for grid_sample
100 | self.return_indices = False
101 |
102 | self.iterations = 0
103 | self.scale_anneal = scale_anneal
104 |
105 | def sample_rays(self, H, W):
106 |
107 | if self.scale_anneal>0:
108 | k_iter = self.iterations // 1000 * 3
109 | min_scale = max(self.min_scale, self.max_scale * exp(-k_iter*self.scale_anneal))
110 | min_scale = min(0.9, min_scale)
111 | else:
112 | min_scale = self.min_scale
113 |
114 | scale = 1
115 | if self.random_scale:
116 | scale = torch.Tensor(1).uniform_(min_scale, self.max_scale)
117 | h = self.h * scale
118 | w = self.w * scale
119 |
120 | if self.random_shift:
121 | max_offset = 1-scale.item()
122 | h_offset = torch.Tensor(1).uniform_(0, max_offset) * (torch.randint(2,(1,)).float()-0.5)*2
123 | w_offset = torch.Tensor(1).uniform_(0, max_offset) * (torch.randint(2,(1,)).float()-0.5)*2
124 |
125 | h += h_offset
126 | w += w_offset
127 |
128 | self.scale = scale
129 |
130 | return torch.cat([h, w], dim=2)
131 |
--------------------------------------------------------------------------------
/graf/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import imageio
4 | import os
5 |
6 |
7 | def get_nsamples(data_loader, N):
8 | x = []
9 | n = 0
10 | while n < N:
11 | x_next = next(iter(data_loader))
12 | x.append(x_next)
13 | n += x_next.size(0)
14 | x = torch.cat(x, dim=0)[:N]
15 | return x
16 |
17 |
18 | def count_trainable_parameters(model):
19 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
20 | return sum([np.prod(p.size()) for p in model_parameters])
21 |
22 |
23 | def save_video(imgs, fname, as_gif=False, fps=24, quality=8):
24 | # convert to np.uint8
25 | imgs = (255 * np.clip(imgs.permute(0, 2, 3, 1).detach().cpu().numpy() / 2 + 0.5, 0, 1)).astype(np.uint8)
26 | imageio.mimwrite(fname, imgs, fps=fps, quality=quality)
27 |
28 | if as_gif: # save as gif, too
29 | os.system(f'ffmpeg -i {fname} -r 15 '
30 | f'-vf "scale=512:-1,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" {os.path.splitext(fname)[0] + ".gif"}')
31 |
32 |
33 | def color_depth_map(depths, scale=None):
34 | """
35 | Color an input depth map.
36 |
37 | Arguments:
38 | depths -- HxW numpy array of depths
39 | [scale=None] -- scaling the values (defaults to the maximum depth)
40 |
41 | Returns:
42 | colored_depths -- HxWx3 numpy array visualizing the depths
43 | """
44 |
45 | _color_map_depths = np.array([
46 | [0, 0, 0], # 0.000
47 | [0, 0, 255], # 0.114
48 | [255, 0, 0], # 0.299
49 | [255, 0, 255], # 0.413
50 | [0, 255, 0], # 0.587
51 | [0, 255, 255], # 0.701
52 | [255, 255, 0], # 0.886
53 | [255, 255, 255], # 1.000
54 | [255, 255, 255], # 1.000
55 | ]).astype(float)
56 | _color_map_bincenters = np.array([
57 | 0.0,
58 | 0.114,
59 | 0.299,
60 | 0.413,
61 | 0.587,
62 | 0.701,
63 | 0.886,
64 | 1.000,
65 | 2.000, # doesn't make a difference, just strictly higher than 1
66 | ])
67 |
68 | if scale is None:
69 | scale = depths.max()
70 |
71 | values = np.clip(depths.flatten() / scale, 0, 1)
72 |     # for each value, figure out where it falls in the bincenters: what is the last bincenter smaller than this value?
73 | lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
74 | lower_bin_value = _color_map_bincenters[lower_bin]
75 | higher_bin_value = _color_map_bincenters[lower_bin + 1]
76 | alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
77 | colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
78 | lower_bin + 1] * alphas.reshape(-1, 1)
79 | return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
80 |
81 |
82 | # Virtual camera utils
83 |
84 |
85 | def to_sphere(u, v):
86 | theta = 2 * np.pi * u
87 | phi = np.arccos(1 - 2 * v)
88 | cx = np.sin(phi) * np.cos(theta)
89 | cy = np.sin(phi) * np.sin(theta)
90 | cz = np.cos(phi)
91 | s = np.stack([cx, cy, cz])
92 | return s
93 |
94 |
95 | def polar_to_cartesian(r, theta, phi, deg=True):
96 | if deg:
97 | phi = phi * np.pi / 180
98 | theta = theta * np.pi / 180
99 | cx = np.sin(phi) * np.cos(theta)
100 | cy = np.sin(phi) * np.sin(theta)
101 | cz = np.cos(phi)
102 | return r * np.stack([cx, cy, cz])
103 |
104 |
105 | def to_uv(loc):
106 | # normalize to unit sphere
107 | loc = loc / loc.norm(dim=1, keepdim=True)
108 |
109 | cx, cy, cz = loc.t()
110 | v = (1 - cz) / 2
111 |
112 | phi = torch.acos(cz)
113 | sin_phi = torch.sin(phi)
114 |
115 | # ensure we do not divide by zero
116 | eps = 1e-8
117 | sin_phi[sin_phi.abs() < eps] = eps
118 |
119 | theta = torch.acos(cx / sin_phi)
120 |
121 | # check for sign of phi
122 | cx_rec = sin_phi * torch.cos(theta)
123 | if not np.isclose(cx.numpy(), cx_rec.numpy(), atol=1e-5).all():
124 | sin_phi = -sin_phi
125 |
126 | # check for sign of theta
127 | cy_rec = sin_phi * torch.sin(theta)
128 | if not np.isclose(cy.numpy(), cy_rec.numpy(), atol=1e-5).all():
129 | theta = -theta
130 |
131 | u = theta / (2 * np.pi)
132 | assert np.isclose(to_sphere(u, v).detach().cpu().numpy(), loc.t().detach().cpu().numpy(), atol=1e-5).all()
133 |
134 | return u, v
135 |
136 |
137 | def to_phi(u):
138 | return 360 * u # 2*pi*u*180/pi
139 |
140 |
141 | def to_theta(v):
142 | return np.arccos(1 - 2 * v) * 180. / np.pi
143 |
144 |
145 | def sample_on_sphere(range_u=(0, 1), range_v=(0, 1)):
146 | u = np.random.uniform(*range_u)
147 | v = np.random.uniform(*range_v)
148 | return to_sphere(u, v)
149 |
150 |
151 | def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5):
152 | at = at.astype(float).reshape(1, 3)
153 | up = up.astype(float).reshape(1, 3)
154 |
155 | eye = eye.reshape(-1, 3)
156 | up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
157 | eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)
158 |
159 | z_axis = eye - at
160 | z_axis /= np.max(np.stack([np.linalg.norm(z_axis, axis=1, keepdims=True), eps]))
161 |
162 | x_axis = np.cross(up, z_axis)
163 | x_axis /= np.max(np.stack([np.linalg.norm(x_axis, axis=1, keepdims=True), eps]))
164 |
165 | y_axis = np.cross(z_axis, x_axis)
166 | y_axis /= np.max(np.stack([np.linalg.norm(y_axis, axis=1, keepdims=True), eps]))
167 |
168 | r_mat = np.concatenate((x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(-1, 3, 1)), axis=2)
169 |
170 | return r_mat
171 |
--------------------------------------------------------------------------------
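For illustration, a minimal sketch of how the virtual-camera helpers in `graf/utils.py` above can be combined into a camera pose (an assumed usage, not code copied from this repository; the import path assumes the repository root is on `PYTHONPATH`):

```
import numpy as np
from graf.utils import sample_on_sphere, look_at  # helpers defined in the file above

radius = 4.0
# sample a camera position on (part of) the unit sphere and scale it by the radius
eye = radius * sample_on_sphere(range_u=(0, 1), range_v=(0.25, 0.5))
# orient the camera towards the origin; look_at returns a batch of 3x3 rotation matrices
R = look_at(eye)[0]
# 3x4 pose matrix (rotation | camera position), e.g. as input to a ray sampler
RT = np.concatenate([R, eye.reshape(3, 1)], axis=1)
print(RT.shape)  # (3, 4)
```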
/submodules/GAN_stability/.gitignore:
--------------------------------------------------------------------------------
1 | output
2 | data
3 | *_lmdb
4 | __pycache__
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Lars Mescheder
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/README.md:
--------------------------------------------------------------------------------
1 | # GAN stability
2 | This repository contains the experiments in the supplementary material for the paper [Which Training Methods for GANs do actually Converge?](https://avg.is.tuebingen.mpg.de/publications/meschedericml2018).
3 |
4 | To cite this work, please use
5 | ```
6 | @INPROCEEDINGS{Mescheder2018ICML,
7 | author = {Lars Mescheder and Sebastian Nowozin and Andreas Geiger},
8 | title = {Which Training Methods for GANs do actually Converge?},
9 | booktitle = {International Conference on Machine Learning (ICML)},
10 | year = {2018}
11 | }
12 | ```
13 | You can find further details on [our project page](https://avg.is.tuebingen.mpg.de/research_projects/convergence-and-stability-of-gan-training).
14 |
15 | # Usage
16 | First download your data and put it into the `./data` folder.
17 |
18 | To train a new model, first create a config file similar to the ones provided in the `./configs` folder. You can then train your model using
19 | ```
20 | python train.py PATH_TO_CONFIG
21 | ```
22 |
23 | To compute the inception score for your model and generate samples, use
24 | ```
25 | python test.py PATH_TO_CONFIG
26 | ```
27 |
28 | Finally, you can create nice latent space interpolations using
29 | ```
30 | python interpolate.py PATH_TO_CONFIG
31 | ```
32 | or
33 | ```
34 | python interpolate_class.py PATH_TO_CONFIG
35 | ```
36 |
37 | # Pretrained models
38 | We also provide several pretrained models.
39 |
40 | You can use the models for sampling by entering
41 | ```
42 | python test.py PATH_TO_CONFIG
43 | ```
44 | where `PATH_TO_CONFIG` is one of the config files
45 | ```
46 | configs/pretrained/celebA_pretrained.yaml
47 | configs/pretrained/celebAHQ_pretrained.yaml
48 | configs/pretrained/imagenet_pretrained.yaml
49 | configs/pretrained/lsun_bedroom_pretrained.yaml
50 | configs/pretrained/lsun_bridge_pretrained.yaml
51 | configs/pretrained/lsun_church_pretrained.yaml
52 | configs/pretrained/lsun_tower_pretrained.yaml
53 | ```
54 | Our script will automatically download the model checkpoints and run the generation.
55 | You can find the outputs in the `output/pretrained` folders.
56 | Similarly, you can use the scripts `interpolate.py` and `interpolate_class.py` for generating interpolations for the pretrained models.
57 |
58 | Please note that the `*_pretrained.yaml` config files are intended for generating samples only, not for training new models: if you use one of them for training, the model is trained from scratch, while during inference our code will still load the pretrained checkpoint.
59 |
60 | # Notes
61 | * Batch normalization is currently *not* supported when using an exponential running average, because the running average is computed only over the model's parameters and not over its other buffers (see the sketch below).
62 |
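For illustration, a minimal sketch of such a parameter-only exponential moving average (not the exact code used in this repository; `generator` is a stand-in module): it touches only `model.parameters()`, so buffers such as batch-norm running statistics are never averaged, which is why the two features do not combine cleanly.

```
import copy
import torch
from torch import nn

def update_average(model_avg, model, beta=0.999):
    # exponential moving average over parameters only;
    # buffers (e.g. batch-norm running statistics) are not touched
    with torch.no_grad():
        for p_avg, p in zip(model_avg.parameters(), model.parameters()):
            p_avg.copy_(beta * p_avg + (1. - beta) * p)

generator = nn.Linear(8, 8)               # stand-in for the real generator
generator_avg = copy.deepcopy(generator)  # averaged copy used at test time
update_average(generator_avg, generator)  # call after every generator update
```
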
63 | # Results
64 | ## celebA-HQ
65 | 
66 |
67 | ## Imagenet
68 | 
69 | 
70 | 
71 | 
72 | 
73 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/celebAHQ.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: npy
3 | train_dir: data/celebA-HQ
4 | test_dir: data/celebA-HQ
5 | img_size: 1024
6 | generator:
7 | name: resnet
8 | kwargs:
9 | nfilter: 16
10 | nfilter_max: 512
11 | embed_size: 1
12 | discriminator:
13 | name: resnet
14 | kwargs:
15 | nfilter: 16
16 | nfilter_max: 512
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/celebAHQ
23 | batch_size: 24
24 | test:
25 | batch_size: 4
26 | sample_size: 6
27 | sample_nrow: 3
28 | interpolations:
29 | nzs: 10
30 | nsubsteps: 75
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [bedroom_train]
6 | lsun_categories_test: [bedroom_test]
7 | img_size: 256
8 | nlabels: 1
9 | generator:
10 | name: resnet
11 | kwargs:
12 | discriminator:
13 | name: resnet
14 | kwargs:
15 | z_dist:
16 | type: gauss
17 | dim: 256
18 | training:
19 | out_dir: output/default
20 | gan_type: standard
21 | reg_type: real
22 | reg_param: 10.
23 | batch_size: 64
24 | nworkers: 16
25 | take_model_average: true
26 | model_average_beta: 0.999
27 | model_average_reinit: false
28 | monitoring: tensorboard
29 | sample_every: 1000
30 | sample_nlabels: 20
31 | inception_every: -1
32 | save_every: 900
33 | backup_every: 100000
34 | restart_every: -1
35 | optimizer: rmsprop
36 | lr_g: 0.0001
37 | lr_d: 0.0001
38 | lr_anneal: 1.
39 | lr_anneal_every: 150000
40 | d_steps: 1
41 | equalize_lr: false
42 | model_file: model.pt
43 | test:
44 | batch_size: 32
45 | sample_size: 64
46 | sample_nrow: 8
47 | use_model_average: true
48 | compute_inception: false
49 | conditional_samples: false
50 | model_file: model.pt
51 | interpolations:
52 | nzs: 10
53 | nsubsteps: 75
54 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/imagenet.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: image
3 | train_dir: data/Imagenet
4 | test_dir: data/Imagenet
5 | img_size: 128
6 | nlabels: 1000
7 | generator:
8 | name: resnet2
9 | kwargs:
10 | nfilter: 64
11 | nfilter_max: 1024
12 | embed_size: 256
13 | discriminator:
14 | name: resnet2
15 | kwargs:
16 | nfilter: 64
17 | nfilter_max: 1024
18 | embed_size: 256
19 | z_dist:
20 | type: gauss
21 | dim: 256
22 | training:
23 | out_dir: output/imagenet
24 | gan_type: standard
25 | sample_nlabels: 20
26 | inception_every: 10000
27 | batch_size: 128
28 | test:
29 | batch_size: 32
30 | sample_size: 64
31 | sample_nrow: 8
32 | compute_inception: true
33 | conditional_samples: true
34 | interpolations:
35 | ys: [15, 157, 307, 321, 442, 483, 484, 525,
36 | 536, 598, 607, 734, 768, 795, 927, 977,
37 | 963, 946, 979]
38 | nzs: 10
39 | nsubsteps: 75
40 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/lsun_bedroom.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [bedroom_train]
6 | lsun_categories_test: [bedroom_test]
7 | img_size: 256
8 | generator:
9 | name: resnet
10 | kwargs:
11 | nfilter: 64
12 | nfilter_max: 1024
13 | embed_size: 1
14 | discriminator:
15 | name: resnet
16 | kwargs:
17 | nfilter: 64
18 | nfilter_max: 1024
19 | embed_size: 1
20 | z_dist:
21 | type: gauss
22 | dim: 256
23 | training:
24 | out_dir: output/lsun_bedroom
25 | test:
26 | batch_size: 32
27 | sample_size: 64
28 | sample_nrow: 8
29 | interpolations:
30 | nzs: 10
31 | nsubsteps: 75
32 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/lsun_bridge.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [bridge_train]
6 |   lsun_categories_test: [bridge_test]
7 | img_size: 256
8 | generator:
9 | name: resnet
10 | kwargs:
11 | nfilter: 64
12 | nfilter_max: 1024
13 | embed_size: 1
14 | discriminator:
15 | name: resnet
16 | kwargs:
17 | nfilter: 64
18 | nfilter_max: 1024
19 | embed_size: 1
20 | z_dist:
21 | type: gauss
22 | dim: 256
23 | training:
24 | out_dir: output/lsun_bridge
25 | test:
26 | batch_size: 32
27 | sample_size: 64
28 | sample_nrow: 8
29 | interpolations:
30 | nzs: 10
31 | nsubsteps: 75
32 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/lsun_church.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [church_outdoor_train]
6 | lsun_categories_test: [church_outdoor_test]
7 | img_size: 256
8 | generator:
9 | name: resnet
10 | kwargs:
11 | nfilter: 64
12 | nfilter_max: 1024
13 | embed_size: 1
14 | discriminator:
15 | name: resnet
16 | kwargs:
17 | nfilter: 64
18 | nfilter_max: 1024
19 | embed_size: 1
20 | z_dist:
21 | type: gauss
22 | dim: 256
23 | training:
24 | out_dir: output/lsun_church
25 | test:
26 | batch_size: 32
27 | sample_size: 64
28 | sample_nrow: 8
29 | interpolations:
30 | nzs: 10
31 | nsubsteps: 75
32 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/lsun_tower.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [tower_train]
6 | lsun_categories_test: [tower_test]
7 | img_size: 256
8 | generator:
9 | name: resnet
10 | kwargs:
11 | nfilter: 64
12 | nfilter_max: 1024
13 | embed_size: 1
14 | discriminator:
15 | name: resnet
16 | kwargs:
17 | nfilter: 64
18 | nfilter_max: 1024
19 | embed_size: 1
20 | z_dist:
21 | type: gauss
22 | dim: 256
23 | training:
24 | out_dir: output/lsun_tower
25 | test:
26 | batch_size: 32
27 | sample_size: 64
28 | sample_nrow: 8
29 | interpolations:
30 | nzs: 10
31 | nsubsteps: 75
32 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/celebAHQ_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: npy
3 | train_dir: data/celebA-HQ
4 | test_dir: data/celebA-HQ
5 | img_size: 1024
6 | generator:
7 | name: resnet
8 | kwargs:
9 | nfilter: 16
10 | nfilter_max: 512
11 | embed_size: 1
12 | discriminator:
13 | name: resnet
14 | kwargs:
15 | nfilter: 16
16 | nfilter_max: 512
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/pretrained/celebAHQ
23 | batch_size: 24
24 | test:
25 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt
26 | batch_size: 4
27 | sample_size: 6
28 | sample_nrow: 3
29 | interpolations:
30 | nzs: 10
31 | nsubsteps: 75
32 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/celebA_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: image
3 | train_dir: data/celebA
4 | test_dir: data/celebA
5 | img_size: 256
6 | generator:
7 | name: resnet4
8 | kwargs:
9 | nfilter: 64
10 | embed_size: 1
11 | discriminator:
12 | name: resnet4
13 | kwargs:
14 | nfilter: 64
15 | embed_size: 1
16 | z_dist:
17 | type: gauss
18 | dim: 256
19 | training:
20 | out_dir: output/pretrained/celebA
21 | test:
22 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celeba-ab478c9d.pt
23 | batch_size: 32
24 | sample_size: 64
25 | sample_nrow: 8
26 | interpolations:
27 | nzs: 10
28 | nsubsteps: 75
29 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/imagenet_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: image
3 | train_dir: data/Imagenet
4 | test_dir: data/Imagenet
5 | img_size: 128
6 | nlabels: 1000
7 | generator:
8 | name: resnet2
9 | kwargs:
10 | nfilter: 64
11 | nfilter_max: 1024
12 | embed_size: 256
13 | discriminator:
14 | name: resnet2
15 | kwargs:
16 | nfilter: 64
17 | nfilter_max: 1024
18 | embed_size: 256
19 | z_dist:
20 | type: gauss
21 | dim: 256
22 | training:
23 | out_dir: output/pretrained/imagenet
24 | sample_nlabels: 20
25 | inception_every: 10000
26 | batch_size: 128
27 | test:
28 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/imagenet-8c505f47.pt
29 | batch_size: 32
30 | sample_size: 64
31 | sample_nrow: 8
32 | compute_inception: false
33 | conditional_samples: true
34 | interpolations:
35 | ys: [15, 157, 307, 321, 442, 483, 484, 525,
36 | 536, 598, 607, 734, 768, 795, 927, 977,
37 | 963, 946, 979]
38 | nzs: 10
39 | nsubsteps: 75
40 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/lsun_bedroom_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [bedroom_train]
6 | lsun_categories_test: [bedroom_test]
7 | img_size: 256
8 | generator:
9 | name: resnet3
10 | kwargs:
11 | nfilter: 64
12 | embed_size: 1
13 | discriminator:
14 | name: resnet3
15 | kwargs:
16 | nfilter: 64
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/pretrained/lsun_bedroom
23 | test:
24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt
25 | batch_size: 32
26 | sample_size: 64
27 | sample_nrow: 8
28 | interpolations:
29 | nzs: 10
30 | nsubsteps: 75
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/lsun_bridge_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [bridge_train]
6 | lsun_categories_test: [bridge_test]
7 | img_size: 256
8 | generator:
9 | name: resnet3
10 | kwargs:
11 | nfilter: 64
12 | embed_size: 1
13 | discriminator:
14 | name: resnet3
15 | kwargs:
16 | nfilter: 64
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/pretrained/lsun_bridge
23 | test:
24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt
25 | batch_size: 32
26 | sample_size: 64
27 | sample_nrow: 8
28 | interpolations:
29 | nzs: 10
30 | nsubsteps: 75
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/lsun_church_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [church_outdoor_train]
6 | lsun_categories_test: [church_outdoor_test]
7 | img_size: 256
8 | generator:
9 | name: resnet3
10 | kwargs:
11 | nfilter: 64
12 | embed_size: 1
13 | discriminator:
14 | name: resnet3
15 | kwargs:
16 | nfilter: 64
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/pretrained/lsun_church
23 | test:
24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_church-b6f0191b.pt
25 | batch_size: 32
26 | sample_size: 64
27 | sample_nrow: 8
28 | interpolations:
29 | nzs: 10
30 | nsubsteps: 75
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/configs/pretrained/lsun_tower_pretrained.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | type: lsun
3 | train_dir: data/LSUN
4 | test_dir: data/LSUN
5 | lsun_categories_train: [tower_train]
6 | lsun_categories_test: [tower_test]
7 | img_size: 256
8 | generator:
9 | name: resnet3
10 | kwargs:
11 | nfilter: 64
12 | embed_size: 1
13 | discriminator:
14 | name: resnet3
15 | kwargs:
16 | nfilter: 64
17 | embed_size: 1
18 | z_dist:
19 | type: gauss
20 | dim: 256
21 | training:
22 | out_dir: output/pretrained/lsun_tower
23 | test:
24 | model_file: https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt
25 | batch_size: 32
26 | sample_size: 64
27 | sample_nrow: 8
28 | interpolations:
29 | nzs: 10
30 | nsubsteps: 75
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/gan_training/__init__.py
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/checkpoints.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import urllib
4 | import torch
5 | from torch.utils import model_zoo
6 |
7 |
8 | class CheckpointIO(object):
9 | ''' CheckpointIO class.
10 |
11 | It handles saving and loading checkpoints.
12 |
13 | Args:
14 | checkpoint_dir (str): path where checkpoints are saved
15 | '''
16 | def __init__(self, checkpoint_dir='./chkpts', **kwargs):
17 | self.module_dict = kwargs
18 | self.checkpoint_dir = checkpoint_dir
19 | if not os.path.exists(checkpoint_dir):
20 | os.makedirs(checkpoint_dir)
21 |
22 | def register_modules(self, **kwargs):
23 | ''' Registers modules in current module dictionary.
24 | '''
25 | self.module_dict.update(kwargs)
26 |
27 | def save(self, filename, **kwargs):
28 | ''' Saves the current module dictionary.
29 |
30 | Args:
31 | filename (str): name of output file
32 | '''
33 | if not os.path.isabs(filename):
34 | filename = os.path.join(self.checkpoint_dir, filename)
35 |
36 | outdict = kwargs
37 | for k, v in self.module_dict.items():
38 | outdict[k] = v.state_dict()
39 | torch.save(outdict, filename)
40 |
41 | def load(self, filename):
42 | '''Loads a module dictionary from local file or url.
43 |
44 | Args:
45 | filename (str): name of saved module dictionary
46 | '''
47 | if is_url(filename):
48 | return self.load_url(filename)
49 | else:
50 | return self.load_file(filename)
51 |
52 | def load_file(self, filename):
53 | '''Loads a module dictionary from file.
54 |
55 | Args:
56 | filename (str): name of saved module dictionary
57 | '''
58 |
59 | if not os.path.isabs(filename):
60 | filename = os.path.join(self.checkpoint_dir, filename)
61 |
62 | if os.path.exists(filename):
63 | print(filename)
64 | print('=> Loading checkpoint from local file...')
65 | state_dict = torch.load(filename)
66 | scalars = self.parse_state_dict(state_dict)
67 | return scalars
68 | else:
69 | raise FileNotFoundError
70 |
71 | def load_url(self, url):
72 | '''Load a module dictionary from url.
73 |
74 | Args:
75 | url (str): url to saved model
76 | '''
77 | print(url)
78 | print('=> Loading checkpoint from url...')
79 | state_dict = model_zoo.load_url(url, progress=True)
80 | scalars = self.parse_state_dict(state_dict)
81 | return scalars
82 |
83 | def parse_state_dict(self, state_dict):
84 | '''Parse state_dict of model and return scalars.
85 |
86 | Args:
87 | state_dict (dict): State dict of model
88 | '''
89 |
90 | for k, v in self.module_dict.items():
91 | if k in state_dict:
92 | v.load_state_dict(state_dict[k])
93 | else:
94 | print('Warning: Could not find %s in checkpoint!' % k)
95 | scalars = {k: v for k, v in state_dict.items()
96 | if k not in self.module_dict}
97 | return scalars
98 |
99 | def is_url(url):
100 | scheme = urllib.parse.urlparse(url).scheme
101 | return scheme in ('http', 'https')
--------------------------------------------------------------------------------
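A minimal usage sketch for the `CheckpointIO` class above (the model, optimizer and import path below are placeholders, not code from this repository): registered modules are stored in a single state-dict file, extra keyword arguments are saved as scalars, and `load` restores from either a local file or a URL.

```
from torch import nn, optim
from gan_training.checkpoints import CheckpointIO  # import path depends on how the submodule is installed

model = nn.Linear(10, 10)                      # hypothetical module to checkpoint
opt = optim.Adam(model.parameters(), lr=1e-4)

checkpoint_io = CheckpointIO(checkpoint_dir='./chkpts', model=model)
checkpoint_io.register_modules(optimizer=opt)

checkpoint_io.save('model.pt', it=1000, epoch=3)  # extra kwargs are stored as scalars
scalars = checkpoint_io.load('model.pt')          # state dicts are restored in place
print(scalars)                                    # {'it': 1000, 'epoch': 3}
```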
/submodules/GAN_stability/gan_training/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from torch import optim
3 | from os import path
4 | from GAN_stability.gan_training.models import generator_dict, discriminator_dict
5 | from GAN_stability.gan_training.train import toggle_grad
6 |
7 |
8 | # General config
9 | def load_config(path, default_path):
10 | ''' Loads config file.
11 |
12 | Args:
13 | path (str): path to config file
14 |         default_path (str): path to the default config file, used when the config does not inherit from another config
15 | '''
16 | # Load configuration from file itself
17 | with open(path, 'r') as f:
18 |         cfg_special = yaml.safe_load(f)
19 |
20 | # Check if we should inherit from a config
21 | inherit_from = cfg_special.get('inherit_from')
22 |
23 | # If yes, load this config first as default
24 | # If no, use the default_path
25 | if inherit_from is not None:
26 | cfg = load_config(inherit_from, default_path)
27 | elif default_path is not None:
28 | with open(default_path, 'r') as f:
29 |             cfg = yaml.safe_load(f)
30 | else:
31 | cfg = dict()
32 |
33 | # Include main configuration
34 | update_recursive(cfg, cfg_special)
35 |
36 | return cfg
37 |
38 |
39 | def update_recursive(dict1, dict2):
40 | ''' Update two config dictionaries recursively.
41 |
42 | Args:
43 | dict1 (dict): first dictionary to be updated
44 |         dict2 (dict): second dictionary whose entries should be used
45 |
46 | '''
47 | for k, v in dict2.items():
48 | # Add item if not yet in dict1
49 | if k not in dict1:
50 | dict1[k] = None
51 | # Update
52 | if isinstance(dict1[k], dict):
53 | update_recursive(dict1[k], v)
54 | else:
55 | dict1[k] = v
56 |
57 |
58 | def build_models(config):
59 | # Get classes
60 | Generator = generator_dict[config['generator']['name']]
61 | Discriminator = discriminator_dict[config['discriminator']['name']]
62 |
63 | # Build models
64 | generator = Generator(
65 | z_dim=config['z_dist']['dim'],
66 | nlabels=config['data']['nlabels'],
67 | size=config['data']['img_size'],
68 | **config['generator']['kwargs']
69 | )
70 | discriminator = Discriminator(
71 |         z_dim=config['z_dist']['dim'],
72 | nlabels=config['data']['nlabels'],
73 | size=config['data']['img_size'],
74 | **config['discriminator']['kwargs']
75 | )
76 |
77 | return generator, discriminator
78 |
79 |
80 | def build_optimizers(generator, discriminator, config):
81 | optimizer = config['training']['optimizer']
82 | lr_g = config['training']['lr_g']
83 | lr_d = config['training']['lr_d']
84 | equalize_lr = config['training']['equalize_lr']
85 |
86 | toggle_grad(generator, True)
87 | toggle_grad(discriminator, True)
88 |
89 | if equalize_lr:
90 | g_gradient_scales = getattr(generator, 'gradient_scales', dict())
91 | d_gradient_scales = getattr(discriminator, 'gradient_scales', dict())
92 |
93 | g_params = get_parameter_groups(generator.parameters(),
94 | g_gradient_scales,
95 | base_lr=lr_g)
96 | d_params = get_parameter_groups(discriminator.parameters(),
97 | d_gradient_scales,
98 | base_lr=lr_d)
99 | else:
100 | g_params = generator.parameters()
101 | d_params = discriminator.parameters()
102 |
103 | # Optimizers
104 | if optimizer == 'rmsprop':
105 | g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8)
106 | d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8)
107 | elif optimizer == 'adam':
108 | g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(0., 0.99), eps=1e-8)
109 | d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(0., 0.99), eps=1e-8)
110 | elif optimizer == 'sgd':
111 | g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.)
112 | d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.)
113 |
114 | return g_optimizer, d_optimizer
115 |
116 |
117 | def build_lr_scheduler(optimizer, config, last_epoch=-1):
118 | lr_scheduler = optim.lr_scheduler.StepLR(
119 | optimizer,
120 | step_size=config['training']['lr_anneal_every'],
121 | gamma=config['training']['lr_anneal'],
122 | last_epoch=last_epoch
123 | )
124 | return lr_scheduler
125 |
126 |
127 | # Some utility functions
128 | def get_parameter_groups(parameters, gradient_scales, base_lr):
129 | param_groups = []
130 | for p in parameters:
131 | c = gradient_scales.get(p, 1.)
132 | param_groups.append({
133 | 'params': [p],
134 | 'lr': c * base_lr
135 | })
136 | return param_groups
137 |
--------------------------------------------------------------------------------
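To make the `inherit_from` handling in `load_config` above concrete, a small sketch of how `update_recursive` merges a default config with experiment-specific overrides (hypothetical dictionaries, not configs shipped with this repository):

```
from gan_training.config import update_recursive  # import path depends on how the submodule is installed

default_cfg = {
    'training': {'batch_size': 64, 'optimizer': 'rmsprop'},
    'z_dist': {'type': 'gauss', 'dim': 256},
}
experiment_cfg = {
    'training': {'batch_size': 24},  # only this key is overridden
}

update_recursive(default_cfg, experiment_cfg)
print(default_cfg['training'])  # {'batch_size': 24, 'optimizer': 'rmsprop'}
print(default_cfg['z_dist'])    # unchanged: {'type': 'gauss', 'dim': 256}
```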
/submodules/GAN_stability/gan_training/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributions
3 |
4 |
5 | def get_zdist(dist_name, dim, device=None):
6 | # Get distribution
7 | if dist_name == 'uniform':
8 | low = -torch.ones(dim, device=device)
9 | high = torch.ones(dim, device=device)
10 | zdist = distributions.Uniform(low, high)
11 | elif dist_name == 'gauss':
12 | mu = torch.zeros(dim, device=device)
13 | scale = torch.ones(dim, device=device)
14 | zdist = distributions.Normal(mu, scale)
15 | else:
16 | raise NotImplementedError
17 |
18 | # Add dim attribute
19 | zdist.dim = dim
20 |
21 | return zdist
22 |
23 |
24 | def get_ydist(nlabels, device=None):
25 | logits = torch.zeros(nlabels, device=device)
26 | ydist = distributions.categorical.Categorical(logits=logits)
27 |
28 | # Add nlabels attribute
29 | ydist.nlabels = nlabels
30 |
31 | return ydist
32 |
33 |
34 | def interpolate_sphere(z1, z2, t):
35 | p = (z1 * z2).sum(dim=-1, keepdim=True)
36 | p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt()
37 | p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt()
38 | omega = torch.acos(p)
39 | s1 = torch.sin((1-t)*omega)/torch.sin(omega)
40 | s2 = torch.sin(t*omega)/torch.sin(omega)
41 | z = s1 * z1 + s2 * z2
42 |
43 | return z
44 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from GAN_stability.gan_training.metrics import inception_score
3 |
4 |
5 | class Evaluator(object):
6 | def __init__(self, generator, zdist, ydist, batch_size=64,
7 | inception_nsamples=60000, device=None):
8 | self.generator = generator
9 | self.zdist = zdist
10 | self.ydist = ydist
11 | self.inception_nsamples = inception_nsamples
12 | self.batch_size = batch_size
13 | self.device = device
14 |
15 | def compute_inception_score(self):
16 | self.generator.eval()
17 | imgs = []
18 | while(len(imgs) < self.inception_nsamples):
19 | ztest = self.zdist.sample((self.batch_size,))
20 | ytest = self.ydist.sample((self.batch_size,))
21 |
22 | samples = self.generator(ztest, ytest)
23 | samples = [s.data.cpu().numpy() for s in samples]
24 | imgs.extend(samples)
25 |
26 | imgs = imgs[:self.inception_nsamples]
27 | score, score_std = inception_score(
28 | imgs, device=self.device, resize=True, splits=10
29 | )
30 |
31 | return score, score_std
32 |
33 | def create_samples(self, z, y=None):
34 | self.generator.eval()
35 | batch_size = z.size(0)
36 | # Parse y
37 | if y is None:
38 | y = self.ydist.sample((batch_size,))
39 | elif isinstance(y, int):
40 | y = torch.full((batch_size,), y,
41 | device=self.device, dtype=torch.int64)
42 | # Sample x
43 | with torch.no_grad():
44 | x = self.generator(z, y)
45 | return x
46 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/inputs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import torchvision.datasets as datasets
4 | import numpy as np
5 |
6 |
7 | def get_dataset(name, data_dir, size=64, lsun_categories=None):
8 | transform = transforms.Compose([
9 | transforms.Resize(size),
10 | transforms.CenterCrop(size),
11 | transforms.RandomHorizontalFlip(),
12 | transforms.ToTensor(),
13 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
14 | transforms.Lambda(lambda x: x + 1./128 * torch.rand(x.size())),
15 | ])
16 |
17 | if name == 'image':
18 | dataset = datasets.ImageFolder(data_dir, transform)
19 | nlabels = len(dataset.classes)
20 | elif name == 'npy':
21 | # Only support normalization for now
22 | dataset = datasets.DatasetFolder(data_dir, npy_loader, ['npy'])
23 | nlabels = len(dataset.classes)
24 | elif name == 'cifar10':
25 | dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
26 | transform=transform)
27 | nlabels = 10
28 | elif name == 'lsun':
29 | if lsun_categories is None:
30 | lsun_categories = 'train'
31 | dataset = datasets.LSUN(data_dir, lsun_categories, transform)
32 | nlabels = len(dataset.classes)
33 | elif name == 'lsun_class':
34 | dataset = datasets.LSUNClass(data_dir, transform,
35 | target_transform=(lambda t: 0))
36 | nlabels = 1
37 | else:
38 |         raise NotImplementedError
39 |
40 | return dataset, nlabels
41 |
42 |
43 | def npy_loader(path):
44 | img = np.load(path)
45 |
46 | if img.dtype == np.uint8:
47 | img = img.astype(np.float32)
48 | img = img/127.5 - 1.
49 | elif img.dtype == np.float32:
50 | img = img * 2 - 1.
51 | else:
52 | raise NotImplementedError
53 |
54 | img = torch.Tensor(img)
55 | if len(img.size()) == 4:
56 | img.squeeze_(0)
57 |
58 | return img
59 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/logger.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import torchvision
4 |
5 |
6 | class Logger(object):
7 | def __init__(self, log_dir='./logs', img_dir='./imgs',
8 | monitoring=None, monitoring_dir=None):
9 | self.stats = dict()
10 | self.log_dir = log_dir
11 | self.img_dir = img_dir
12 |
13 | if not os.path.exists(log_dir):
14 | os.makedirs(log_dir)
15 |
16 | if not os.path.exists(img_dir):
17 | os.makedirs(img_dir)
18 |
19 | if not (monitoring is None or monitoring == 'none'):
20 | self.setup_monitoring(monitoring, monitoring_dir)
21 | else:
22 | self.monitoring = None
23 | self.monitoring_dir = None
24 |
25 | def setup_monitoring(self, monitoring, monitoring_dir=None):
26 | self.monitoring = monitoring
27 | self.monitoring_dir = monitoring_dir
28 |
29 | if monitoring == 'telemetry':
30 | import telemetry
31 | self.tm = telemetry.ApplicationTelemetry()
32 | if self.tm.get_status() == 0:
33 | print('Telemetry successfully connected.')
34 | elif monitoring == 'tensorboard':
35 | import tensorboardX
36 | self.tb = tensorboardX.SummaryWriter(monitoring_dir)
37 | else:
38 | raise NotImplementedError('Monitoring tool "%s" not supported!'
39 | % monitoring)
40 |
41 | def add(self, category, k, v, it):
42 | if category not in self.stats:
43 | self.stats[category] = {}
44 |
45 | if k not in self.stats[category]:
46 | self.stats[category][k] = []
47 |
48 | self.stats[category][k].append((it, v))
49 |
50 | k_name = '%s/%s' % (category, k)
51 | if self.monitoring == 'telemetry':
52 | self.tm.metric_push_async({
53 | 'metric': k_name, 'value': v, 'it': it
54 | })
55 | elif self.monitoring == 'tensorboard':
56 | self.tb.add_scalar(k_name, v, it)
57 |
58 | def add_imgs(self, imgs, class_name, it):
59 | outdir = os.path.join(self.img_dir, class_name)
60 | if not os.path.exists(outdir):
61 | os.makedirs(outdir)
62 | outfile = os.path.join(outdir, '%08d.png' % it)
63 |
64 | imgs = imgs / 2 + 0.5
65 | imgs = torchvision.utils.make_grid(imgs)
66 | torchvision.utils.save_image(imgs.clone(), outfile, nrow=8)
67 |
68 | if self.monitoring == 'tensorboard':
69 | self.tb.add_image(class_name, imgs, it)
70 |
71 | def get_last(self, category, k, default=0.):
72 | if category not in self.stats:
73 | return default
74 | elif k not in self.stats[category]:
75 | return default
76 | else:
77 | return self.stats[category][k][-1][1]
78 |
79 | def save_stats(self, filename):
80 | filename = os.path.join(self.log_dir, filename)
81 | with open(filename, 'wb') as f:
82 | pickle.dump(self.stats, f)
83 |
84 | def load_stats(self, filename):
85 | filename = os.path.join(self.log_dir, filename)
86 | if not os.path.exists(filename):
87 | print('Warning: file "%s" does not exist!' % filename)
88 | return
89 |
90 | try:
91 | with open(filename, 'rb') as f:
92 | self.stats = pickle.load(f)
93 | except EOFError:
94 | print('Warning: log file corrupted!')
95 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from GAN_stability.gan_training.metrics.inception_score import inception_score
2 | from GAN_stability.gan_training.metrics.fid_score import FIDEvaluator
3 | from GAN_stability.gan_training.metrics.kid_score import KIDEvaluator
4 |
5 | __all__ = [
6 |     'inception_score',
7 |     'FIDEvaluator',
8 |     'KIDEvaluator'
9 | ]
10 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/metrics/inception_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import torch.utils.data
5 |
6 | from torchvision.models.inception import inception_v3
7 |
8 | import numpy as np
9 | from scipy.stats import entropy
10 |
11 |
12 | def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1):
13 | """Computes the inception score of the generated images imgs
14 |
15 | Args:
16 | imgs: Torch dataset of (3xHxW) numpy images normalized in the
17 | range [-1, 1]
18 |         device: device to run the Inception model on (e.g. 'cuda')
19 | batch_size: batch size for feeding into Inception v3
20 | splits: number of splits
21 | """
22 | N = len(imgs)
23 |
24 | assert batch_size > 0
25 | assert N > batch_size
26 |
27 | # Set up dataloader
28 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
29 |
30 | # Load inception model
31 | inception_model = inception_v3(pretrained=True, transform_input=False)
32 | inception_model = inception_model.to(device)
33 | inception_model.eval()
34 | up = nn.Upsample(size=(299, 299), mode='bilinear').to(device)
35 |
36 | def get_pred(x):
37 | with torch.no_grad():
38 | if resize:
39 | x = up(x)
40 | x = inception_model(x)
41 | out = F.softmax(x, dim=-1)
42 | out = out.cpu().numpy()
43 | return out
44 |
45 | # Get predictions
46 | preds = np.zeros((N, 1000))
47 |
48 | for i, batch in enumerate(dataloader, 0):
49 | batchv = batch.to(device)
50 | batch_size_i = batch.size()[0]
51 |
52 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
53 |
54 | # Now compute the mean kl-div
55 | split_scores = []
56 |
57 | for k in range(splits):
58 | part = preds[k * (N // splits): (k+1) * (N // splits), :]
59 | py = np.mean(part, axis=0)
60 | scores = []
61 | for i in range(part.shape[0]):
62 | pyx = part[i, :]
63 | scores.append(entropy(pyx, py))
64 | split_scores.append(np.exp(np.mean(scores)))
65 |
66 | return np.mean(split_scores), np.std(split_scores)
67 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/models/__init__.py:
--------------------------------------------------------------------------------
1 | from GAN_stability.gan_training.models import (
2 | resnet, resnet2, resnet3, resnet4,
3 | )
4 |
5 | generator_dict = {
6 | 'resnet': resnet.Generator,
7 | 'resnet2': resnet2.Generator,
8 | 'resnet3': resnet3.Generator,
9 | 'resnet4': resnet4.Generator,
10 | }
11 |
12 | discriminator_dict = {
13 | 'resnet': resnet.Discriminator,
14 | 'resnet2': resnet2.Discriminator,
15 | 'resnet3': resnet3.Discriminator,
16 | 'resnet4': resnet4.Discriminator,
17 | }
18 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 | import torch.utils.data
6 | import torch.utils.data.distributed
7 | import numpy as np
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=512, **kwargs):
12 | super().__init__()
13 | s0 = self.s0 = 4
14 | nf = self.nf = nfilter
15 | nf_max = self.nf_max = nfilter_max
16 |
17 | self.z_dim = z_dim
18 |
19 | # Submodules
20 | nlayers = int(np.log2(size / s0))
21 | self.nf0 = min(nf_max, nf * 2**nlayers)
22 |
23 | self.embedding = nn.Embedding(nlabels, embed_size)
24 | self.fc = nn.Linear(z_dim + embed_size, self.nf0*s0*s0)
25 |
26 | blocks = []
27 | for i in range(nlayers):
28 | nf0 = min(nf * 2**(nlayers-i), nf_max)
29 | nf1 = min(nf * 2**(nlayers-i-1), nf_max)
30 | blocks += [
31 | ResnetBlock(nf0, nf1),
32 | nn.Upsample(scale_factor=2)
33 | ]
34 |
35 | blocks += [
36 | ResnetBlock(nf, nf),
37 | ]
38 |
39 | self.resnet = nn.Sequential(*blocks)
40 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
41 |
42 | def forward(self, z, y):
43 | assert(z.size(0) == y.size(0))
44 | batch_size = z.size(0)
45 |
46 | if y.dtype is torch.int64:
47 | yembed = self.embedding(y)
48 | else:
49 | yembed = y
50 |
51 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True)
52 |
53 | yz = torch.cat([z, yembed], dim=1)
54 | out = self.fc(yz)
55 | out = out.view(batch_size, self.nf0, self.s0, self.s0)
56 |
57 | out = self.resnet(out)
58 |
59 | out = self.conv_img(actvn(out))
60 | out = torch.tanh(out)
61 |
62 | return out
63 |
64 |
65 | class Discriminator(nn.Module):
66 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, nfilter_max=1024):
67 | super().__init__()
68 | self.embed_size = embed_size
69 | s0 = self.s0 = 4
70 | nf = self.nf = nfilter
71 | nf_max = self.nf_max = nfilter_max
72 |
73 | # Submodules
74 | nlayers = int(np.log2(size / s0))
75 | self.nf0 = min(nf_max, nf * 2**nlayers)
76 |
77 | blocks = [
78 | ResnetBlock(nf, nf)
79 | ]
80 |
81 | for i in range(nlayers):
82 | nf0 = min(nf * 2**i, nf_max)
83 | nf1 = min(nf * 2**(i+1), nf_max)
84 | blocks += [
85 | nn.AvgPool2d(3, stride=2, padding=1),
86 | ResnetBlock(nf0, nf1),
87 | ]
88 |
89 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1)
90 | self.resnet = nn.Sequential(*blocks)
91 | self.fc = nn.Linear(self.nf0*s0*s0, nlabels)
92 |
93 | def forward(self, x, y):
94 | assert(x.size(0) == y.size(0))
95 | batch_size = x.size(0)
96 |
97 | out = self.conv_img(x)
98 | out = self.resnet(out)
99 | out = out.view(batch_size, self.nf0*self.s0*self.s0)
100 | out = self.fc(actvn(out))
101 |
102 | index = Variable(torch.LongTensor(range(out.size(0))))
103 | if y.is_cuda:
104 | index = index.cuda()
105 | out = out[index, y]
106 |
107 | return out
108 |
109 |
110 | class ResnetBlock(nn.Module):
111 | def __init__(self, fin, fout, fhidden=None, is_bias=True):
112 | super().__init__()
113 | # Attributes
114 | self.is_bias = is_bias
115 | self.learned_shortcut = (fin != fout)
116 | self.fin = fin
117 | self.fout = fout
118 | if fhidden is None:
119 | self.fhidden = min(fin, fout)
120 | else:
121 | self.fhidden = fhidden
122 |
123 | # Submodules
124 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
125 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
126 | if self.learned_shortcut:
127 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
128 |
129 | def forward(self, x):
130 | x_s = self._shortcut(x)
131 | dx = self.conv_0(actvn(x))
132 | dx = self.conv_1(actvn(dx))
133 | out = x_s + 0.1*dx
134 |
135 | return out
136 |
137 | def _shortcut(self, x):
138 | if self.learned_shortcut:
139 | x_s = self.conv_s(x)
140 | else:
141 | x_s = x
142 | return x_s
143 |
144 |
145 | def actvn(x):
146 | out = F.leaky_relu(x, 2e-1)
147 | return out
148 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/models/resnet2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 | import torch.utils.data
6 | import torch.utils.data.distributed
7 |
8 |
9 | class Generator(nn.Module):
10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
11 | super().__init__()
12 | s0 = self.s0 = size // 32
13 | nf = self.nf = nfilter
14 | self.z_dim = z_dim
15 |
16 | # Submodules
17 | self.embedding = nn.Embedding(nlabels, embed_size)
18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0)
19 |
20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf)
21 | self.resnet_0_1 = ResnetBlock(16*nf, 16*nf)
22 |
23 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf)
24 | self.resnet_1_1 = ResnetBlock(16*nf, 16*nf)
25 |
26 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf)
27 | self.resnet_2_1 = ResnetBlock(8*nf, 8*nf)
28 |
29 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf)
30 | self.resnet_3_1 = ResnetBlock(4*nf, 4*nf)
31 |
32 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf)
33 | self.resnet_4_1 = ResnetBlock(2*nf, 2*nf)
34 |
35 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf)
36 | self.resnet_5_1 = ResnetBlock(1*nf, 1*nf)
37 |
38 | self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
39 |
40 | def forward(self, z, y):
41 | assert(z.size(0) == y.size(0))
42 | batch_size = z.size(0)
43 |
44 | if y.dtype is torch.int64:
45 | yembed = self.embedding(y)
46 | else:
47 | yembed = y
48 |
49 | yembed = yembed / torch.norm(yembed, p=2, dim=1, keepdim=True)
50 |
51 | yz = torch.cat([z, yembed], dim=1)
52 | out = self.fc(yz)
53 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0)
54 |
55 | out = self.resnet_0_0(out)
56 | out = self.resnet_0_1(out)
57 |
58 | out = F.interpolate(out, scale_factor=2)
59 | out = self.resnet_1_0(out)
60 | out = self.resnet_1_1(out)
61 |
62 | out = F.interpolate(out, scale_factor=2)
63 | out = self.resnet_2_0(out)
64 | out = self.resnet_2_1(out)
65 |
66 | out = F.interpolate(out, scale_factor=2)
67 | out = self.resnet_3_0(out)
68 | out = self.resnet_3_1(out)
69 |
70 | out = F.interpolate(out, scale_factor=2)
71 | out = self.resnet_4_0(out)
72 | out = self.resnet_4_1(out)
73 |
74 | out = F.interpolate(out, scale_factor=2)
75 | out = self.resnet_5_0(out)
76 | out = self.resnet_5_1(out)
77 |
78 | out = self.conv_img(actvn(out))
79 | out = torch.tanh(out)
80 |
81 | return out
82 |
83 |
84 | class Discriminator(nn.Module):
85 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
86 | super().__init__()
87 | self.embed_size = embed_size
88 | s0 = self.s0 = size // 32
89 | nf = self.nf = nfilter
90 | ny = nlabels
91 |
92 | # Submodules
93 | self.conv_img = nn.Conv2d(3, 1*nf, 3, padding=1)
94 |
95 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf)
96 | self.resnet_0_1 = ResnetBlock(1*nf, 2*nf)
97 |
98 | self.resnet_1_0 = ResnetBlock(2*nf, 2*nf)
99 | self.resnet_1_1 = ResnetBlock(2*nf, 4*nf)
100 |
101 | self.resnet_2_0 = ResnetBlock(4*nf, 4*nf)
102 | self.resnet_2_1 = ResnetBlock(4*nf, 8*nf)
103 |
104 | self.resnet_3_0 = ResnetBlock(8*nf, 8*nf)
105 | self.resnet_3_1 = ResnetBlock(8*nf, 16*nf)
106 |
107 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf)
108 | self.resnet_4_1 = ResnetBlock(16*nf, 16*nf)
109 |
110 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf)
111 | self.resnet_5_1 = ResnetBlock(16*nf, 16*nf)
112 |
113 | self.fc = nn.Linear(16*nf*s0*s0, nlabels)
114 |
115 |
116 | def forward(self, x, y):
117 | assert(x.size(0) == y.size(0))
118 | batch_size = x.size(0)
119 |
120 | out = self.conv_img(x)
121 |
122 | out = self.resnet_0_0(out)
123 | out = self.resnet_0_1(out)
124 |
125 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
126 | out = self.resnet_1_0(out)
127 | out = self.resnet_1_1(out)
128 |
129 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
130 | out = self.resnet_2_0(out)
131 | out = self.resnet_2_1(out)
132 |
133 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
134 | out = self.resnet_3_0(out)
135 | out = self.resnet_3_1(out)
136 |
137 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
138 | out = self.resnet_4_0(out)
139 | out = self.resnet_4_1(out)
140 |
141 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
142 | out = self.resnet_5_0(out)
143 | out = self.resnet_5_1(out)
144 |
145 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0)
146 | out = self.fc(actvn(out))
147 |
148 | index = Variable(torch.LongTensor(range(out.size(0))))
149 | if y.is_cuda:
150 | index = index.cuda()
151 | out = out[index, y]
152 |
153 | return out
154 |
155 |
156 | class ResnetBlock(nn.Module):
157 | def __init__(self, fin, fout, fhidden=None, is_bias=True):
158 | super().__init__()
159 | # Attributes
160 | self.is_bias = is_bias
161 | self.learned_shortcut = (fin != fout)
162 | self.fin = fin
163 | self.fout = fout
164 | if fhidden is None:
165 | self.fhidden = min(fin, fout)
166 | else:
167 | self.fhidden = fhidden
168 |
169 | # Submodules
170 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
171 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
172 | if self.learned_shortcut:
173 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
174 |
175 |
176 | def forward(self, x):
177 | x_s = self._shortcut(x)
178 | dx = self.conv_0(actvn(x))
179 | dx = self.conv_1(actvn(dx))
180 | out = x_s + 0.1*dx
181 |
182 | return out
183 |
184 | def _shortcut(self, x):
185 | if self.learned_shortcut:
186 | x_s = self.conv_s(x)
187 | else:
188 | x_s = x
189 | return x_s
190 |
191 |
192 | def actvn(x):
193 | out = F.leaky_relu(x, 2e-1)
194 | return out
195 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/models/resnet3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 | import torch.utils.data
6 | import torch.utils.data.distributed
7 |
8 |
9 | class Generator(nn.Module):
10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
11 | super().__init__()
12 | s0 = self.s0 = size // 64
13 | nf = self.nf = nfilter
14 | self.z_dim = z_dim
15 |
16 | # Submodules
17 | self.embedding = nn.Embedding(nlabels, embed_size)
18 | self.fc = nn.Linear(z_dim + embed_size, 32*nf*s0*s0)
19 |
20 | self.resnet_0_0 = ResnetBlock(32*nf, 16*nf)
21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf)
22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf)
23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf)
24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf)
25 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf)
26 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3)
27 |
28 | def forward(self, z, y):
29 | assert(z.size(0) == y.size(0))
30 | batch_size = z.size(0)
31 |
32 | yembed = self.embedding(y)
33 | yz = torch.cat([z, yembed], dim=1)
34 | out = self.fc(yz)
35 | out = out.view(batch_size, 32*self.nf, self.s0, self.s0)
36 |
37 | out = self.resnet_0_0(out)
38 |
39 | out = F.interpolate(out, scale_factor=2)
40 | out = self.resnet_1_0(out)
41 |
42 | out = F.interpolate(out, scale_factor=2)
43 | out = self.resnet_2_0(out)
44 |
45 | out = F.interpolate(out, scale_factor=2)
46 | out = self.resnet_3_0(out)
47 |
48 | out = F.interpolate(out, scale_factor=2)
49 | out = self.resnet_4_0(out)
50 |
51 | out = F.interpolate(out, scale_factor=2)
52 | out = self.resnet_5_0(out)
53 |
54 | out = F.interpolate(out, scale_factor=2)
55 |
56 | out = self.conv_img(actvn(out))
57 | out = torch.tanh(out)
58 |
59 | return out
60 |
61 |
62 | class Discriminator(nn.Module):
63 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
64 | super().__init__()
65 | self.embed_size = embed_size
66 | s0 = self.s0 = size // 64
67 | nf = self.nf = nfilter
68 |
69 | # Submodules
70 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3)
71 |
72 | self.resnet_0_0 = ResnetBlock(1*nf, 2*nf)
73 | self.resnet_1_0 = ResnetBlock(2*nf, 4*nf)
74 | self.resnet_2_0 = ResnetBlock(4*nf, 8*nf)
75 | self.resnet_3_0 = ResnetBlock(8*nf, 16*nf)
76 | self.resnet_4_0 = ResnetBlock(16*nf, 16*nf)
77 | self.resnet_5_0 = ResnetBlock(16*nf, 32*nf)
78 |
79 | self.fc = nn.Linear(32*nf*s0*s0, nlabels)
80 |
81 | def forward(self, x, y):
82 | assert(x.size(0) == y.size(0))
83 | batch_size = x.size(0)
84 |
85 | out = self.conv_img(x)
86 |
87 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
88 | out = self.resnet_0_0(out)
89 |
90 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
91 | out = self.resnet_1_0(out)
92 |
93 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
94 | out = self.resnet_2_0(out)
95 |
96 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
97 | out = self.resnet_3_0(out)
98 |
99 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
100 | out = self.resnet_4_0(out)
101 |
102 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
103 | out = self.resnet_5_0(out)
104 |
105 | out = out.view(batch_size, 32*self.nf*self.s0*self.s0)
106 | out = self.fc(actvn(out))
107 |
108 | index = Variable(torch.LongTensor(range(out.size(0))))
109 | if y.is_cuda:
110 | index = index.cuda()
111 | out = out[index, y]
112 |
113 | return out
114 |
115 |
116 | class ResnetBlock(nn.Module):
117 | def __init__(self, fin, fout, fhidden=None, is_bias=True):
118 | super().__init__()
119 | # Attributes
120 | self.is_bias = is_bias
121 | self.learned_shortcut = (fin != fout)
122 | self.fin = fin
123 | self.fout = fout
124 | if fhidden is None:
125 | self.fhidden = min(fin, fout)
126 | else:
127 | self.fhidden = fhidden
128 |
129 | # Submodules
130 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
131 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
132 | if self.learned_shortcut:
133 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
134 |
135 | def forward(self, x):
136 | x_s = self._shortcut(x)
137 | dx = self.conv_0(actvn(x))
138 | dx = self.conv_1(actvn(dx))
139 | out = x_s + 0.1*dx
140 |
141 | return out
142 |
143 | def _shortcut(self, x):
144 | if self.learned_shortcut:
145 | x_s = self.conv_s(x)
146 | else:
147 | x_s = x
148 | return x_s
149 |
150 |
151 | def actvn(x):
152 | out = F.leaky_relu(x, 2e-1)
153 | return out
154 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/gan_training/models/resnet4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 | import torch.utils.data
6 | import torch.utils.data.distributed
7 |
8 |
9 | class Generator(nn.Module):
10 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
11 | super().__init__()
12 | s0 = self.s0 = size // 64
13 | nf = self.nf = nfilter
14 | self.z_dim = z_dim
15 |
16 | # Submodules
17 | self.embedding = nn.Embedding(nlabels, embed_size)
18 | self.fc = nn.Linear(z_dim + embed_size, 16*nf*s0*s0)
19 |
20 | self.resnet_0_0 = ResnetBlock(16*nf, 16*nf)
21 | self.resnet_1_0 = ResnetBlock(16*nf, 16*nf)
22 | self.resnet_2_0 = ResnetBlock(16*nf, 8*nf)
23 | self.resnet_3_0 = ResnetBlock(8*nf, 4*nf)
24 | self.resnet_4_0 = ResnetBlock(4*nf, 2*nf)
25 | self.resnet_5_0 = ResnetBlock(2*nf, 1*nf)
26 | self.resnet_6_0 = ResnetBlock(1*nf, 1*nf)
27 | self.conv_img = nn.Conv2d(nf, 3, 7, padding=3)
28 |
29 |
30 | def forward(self, z, y):
31 | assert(z.size(0) == y.size(0))
32 | batch_size = z.size(0)
33 |
34 | yembed = self.embedding(y)
35 | yz = torch.cat([z, yembed], dim=1)
36 | out = self.fc(yz)
37 | out = out.view(batch_size, 16*self.nf, self.s0, self.s0)
38 |
39 | out = self.resnet_0_0(out)
40 |
41 | out = F.interpolate(out, scale_factor=2)
42 | out = self.resnet_1_0(out)
43 |
44 | out = F.interpolate(out, scale_factor=2)
45 | out = self.resnet_2_0(out)
46 |
47 | out = F.interpolate(out, scale_factor=2)
48 | out = self.resnet_3_0(out)
49 |
50 | out = F.interpolate(out, scale_factor=2)
51 | out = self.resnet_4_0(out)
52 |
53 | out = F.interpolate(out, scale_factor=2)
54 | out = self.resnet_5_0(out)
55 |
56 | out = F.interpolate(out, scale_factor=2)
57 | out = self.resnet_6_0(out)
58 | out = self.conv_img(actvn(out))
59 | out = torch.tanh(out)
60 |
61 | return out
62 |
63 |
64 | class Discriminator(nn.Module):
65 | def __init__(self, z_dim, nlabels, size, embed_size=256, nfilter=64, **kwargs):
66 | super().__init__()
67 | self.embed_size = embed_size
68 | s0 = self.s0 = size // 64
69 | nf = self.nf = nfilter
70 |
71 | # Submodules
72 | self.conv_img = nn.Conv2d(3, 1*nf, 7, padding=3)
73 |
74 | self.resnet_0_0 = ResnetBlock(1*nf, 1*nf)
75 | self.resnet_1_0 = ResnetBlock(1*nf, 2*nf)
76 | self.resnet_2_0 = ResnetBlock(2*nf, 4*nf)
77 | self.resnet_3_0 = ResnetBlock(4*nf, 8*nf)
78 | self.resnet_4_0 = ResnetBlock(8*nf, 16*nf)
79 | self.resnet_5_0 = ResnetBlock(16*nf, 16*nf)
80 | self.resnet_6_0 = ResnetBlock(16*nf, 16*nf)
81 |
82 | self.fc = nn.Linear(16*nf*s0*s0, nlabels)
83 |
84 | def forward(self, x, y):
85 | assert(x.size(0) == y.size(0))
86 | batch_size = x.size(0)
87 |
88 | out = self.conv_img(x)
89 | out = self.resnet_0_0(out)
90 |
91 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
92 | out = self.resnet_1_0(out)
93 |
94 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
95 | out = self.resnet_2_0(out)
96 |
97 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
98 | out = self.resnet_3_0(out)
99 |
100 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
101 | out = self.resnet_4_0(out)
102 |
103 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
104 | out = self.resnet_5_0(out)
105 |
106 | out = F.avg_pool2d(out, 3, stride=2, padding=1)
107 | out = self.resnet_6_0(out)
108 |
109 | out = out.view(batch_size, 16*self.nf*self.s0*self.s0)
110 | out = self.fc(actvn(out))
111 |
112 | index = Variable(torch.LongTensor(range(out.size(0))))
113 | if y.is_cuda:
114 | index = index.cuda()
115 | out = out[index, y]
116 |
117 | return out
118 |
119 |
120 | class ResnetBlock(nn.Module):
121 | def __init__(self, fin, fout, fhidden=None, is_bias=True):
122 | super().__init__()
123 | # Attributes
124 | self.is_bias = is_bias
125 | self.learned_shortcut = (fin != fout)
126 | self.fin = fin
127 | self.fout = fout
128 | if fhidden is None:
129 | self.fhidden = min(fin, fout)
130 | else:
131 | self.fhidden = fhidden
132 |
133 | # Submodules
134 | self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1)
135 | self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=is_bias)
136 | if self.learned_shortcut:
137 | self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
138 |
139 | def forward(self, x):
140 | x_s = self._shortcut(x)
141 | dx = self.conv_0(actvn(x))
142 | dx = self.conv_1(actvn(dx))
143 | out = x_s + 0.1*dx
144 |
145 | return out
146 |
147 | def _shortcut(self, x):
148 | if self.learned_shortcut:
149 | x_s = self.conv_s(x)
150 | else:
151 | x_s = x
152 | return x_s
153 |
154 |
155 | def actvn(x):
156 | out = F.leaky_relu(x, 2e-1)
157 | return out
158 |
--------------------------------------------------------------------------------
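The generators in `resnet3.py` and `resnet4.py` differ only in their channel progression and block count; both upsample a `size // 64` feature map six times, so `size` must be a multiple of 64 and the output resolution equals `size`. Below is a minimal shape sanity check for the variant above; it is only a sketch and assumes the `gan_training` package (the GAN_stability submodule) is on the Python path:

```
import torch
from gan_training.models.resnet4 import Generator, Discriminator

z_dim, nlabels, size = 256, 10, 128      # size // 64 = 2, doubled six times -> 128
g = Generator(z_dim=z_dim, nlabels=nlabels, size=size)
d = Discriminator(z_dim=z_dim, nlabels=nlabels, size=size)

z = torch.randn(4, z_dim)
y = torch.randint(nlabels, (4,))
x = g(z, y)        # (4, 3, 128, 128)
out = d(x, y)      # (4,) -- the logit of each sample at its own label
print(x.shape, out.shape)
```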
/submodules/GAN_stability/gan_training/ops.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.nn import Parameter
4 |
5 |
6 | class SpectralNorm(nn.Module):
7 | def __init__(self, module, name='weight', power_iterations=1):
8 | super(SpectralNorm, self).__init__()
9 | self.module = module
10 | self.name = name
11 | self.power_iterations = power_iterations
12 | if not self._made_params():
13 | self._make_params()
14 |
15 | def _update_u_v(self):
16 | u = getattr(self.module, self.name + "_u")
17 | v = getattr(self.module, self.name + "_v")
18 | w = getattr(self.module, self.name + "_bar")
19 |
20 | height = w.data.shape[0]
21 | for _ in range(self.power_iterations):
22 | v.data = l2normalize(
23 | torch.mv(torch.t(w.view(height, -1).data), u.data))
24 | u.data = l2normalize(
25 | torch.mv(w.view(height, -1).data, v.data))
26 |
27 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
28 | sigma = u.dot(w.view(height, -1).mv(v))
29 | setattr(self.module, self.name, w / sigma.expand_as(w))
30 |
31 | def _made_params(self):
32 | made_params = (
33 | hasattr(self.module, self.name + "_u")
34 | and hasattr(self.module, self.name + "_v")
35 | and hasattr(self.module, self.name + "_bar")
36 | )
37 | return made_params
38 |
39 | def _make_params(self):
40 | w = getattr(self.module, self.name)
41 |
42 | height = w.data.shape[0]
43 | width = w.view(height, -1).data.shape[1]
44 |
45 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
46 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
47 | u.data = l2normalize(u.data)
48 | v.data = l2normalize(v.data)
49 | w_bar = Parameter(w.data)
50 |
51 | del self.module._parameters[self.name]
52 |
53 | self.module.register_parameter(self.name + "_u", u)
54 | self.module.register_parameter(self.name + "_v", v)
55 | self.module.register_parameter(self.name + "_bar", w_bar)
56 |
57 | def forward(self, *args):
58 | self._update_u_v()
59 | return self.module.forward(*args)
60 |
61 |
62 | def l2normalize(v, eps=1e-12):
63 | return v / (v.norm() + eps)
64 |
65 |
66 | class CBatchNorm(nn.Module):
67 | def __init__(self, nfilter, nlabels):
68 | super().__init__()
69 | # Attributes
70 | self.nlabels = nlabels
71 | self.nfilter = nfilter
72 | # Submodules
73 | self.alpha_embedding = nn.Embedding(nlabels, nfilter)
74 | self.beta_embedding = nn.Embedding(nlabels, nfilter)
75 | self.bn = nn.BatchNorm2d(nfilter, affine=False)
76 | # Initialize
77 | nn.init.constant_(self.alpha_embedding.weight, 1.)
78 | nn.init.constant_(self.beta_embedding.weight, 0.)
79 |
80 | def forward(self, x, y):
81 | dim = len(x.size())
82 | batch_size = x.size(0)
83 | assert(dim >= 2)
84 | assert(x.size(1) == self.nfilter)
85 |
86 | s = [batch_size, self.nfilter] + [1] * (dim - 2)
87 | alpha = self.alpha_embedding(y)
88 | alpha = alpha.view(s)
89 | beta = self.beta_embedding(y)
90 | beta = beta.view(s)
91 |
92 | out = self.bn(x)
93 | out = alpha * out + beta
94 |
95 | return out
96 |
97 |
98 | class CInstanceNorm(nn.Module):
99 | def __init__(self, nfilter, nlabels):
100 | super().__init__()
101 | # Attributes
102 | self.nlabels = nlabels
103 | self.nfilter = nfilter
104 | # Submodules
105 | self.alpha_embedding = nn.Embedding(nlabels, nfilter)
106 | self.beta_embedding = nn.Embedding(nlabels, nfilter)
107 | self.bn = nn.InstanceNorm2d(nfilter, affine=False)
108 | # Initialize
109 |         nn.init.uniform_(self.alpha_embedding.weight, -1., 1.)
110 | nn.init.constant_(self.beta_embedding.weight, 0.)
111 |
112 | def forward(self, x, y):
113 | dim = len(x.size())
114 | batch_size = x.size(0)
115 | assert(dim >= 2)
116 | assert(x.size(1) == self.nfilter)
117 |
118 | s = [batch_size, self.nfilter] + [1] * (dim - 2)
119 | alpha = self.alpha_embedding(y)
120 | alpha = alpha.view(s)
121 | beta = self.beta_embedding(y)
122 | beta = beta.view(s)
123 |
124 | out = self.bn(x)
125 | out = alpha * out + beta
126 |
127 | return out
128 |
--------------------------------------------------------------------------------
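`SpectralNorm` above is a hand-rolled wrapper (predating `torch.nn.utils.spectral_norm`): it re-registers the wrapped module's weight as `weight_bar` plus the power-iteration vectors `weight_u` and `weight_v`, runs one power iteration per forward pass, and divides the weight by the estimated largest singular value. A short usage sketch, assuming `gan_training.ops` is importable:

```
import torch
from torch import nn
from gan_training.ops import SpectralNorm

conv = SpectralNorm(nn.Conv2d(64, 128, 3, padding=1))  # wrap any module that has a `weight`

x = torch.randn(8, 64, 32, 32)
y = conv(x)          # one power iteration, then the spectrally normalized convolution
print(y.shape)       # torch.Size([8, 128, 32, 32])
```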
/submodules/GAN_stability/gan_training/train.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import torch
3 | from torch.nn import functional as F
4 | import torch.utils.data
5 | import torch.utils.data.distributed
6 | from torch import autograd
7 |
8 |
9 | class Trainer(object):
10 | def __init__(self, generator, discriminator, g_optimizer, d_optimizer,
11 | gan_type, reg_type, reg_param):
12 | self.generator = generator
13 | self.discriminator = discriminator
14 | self.g_optimizer = g_optimizer
15 | self.d_optimizer = d_optimizer
16 |
17 | self.gan_type = gan_type
18 | self.reg_type = reg_type
19 | self.reg_param = reg_param
20 |
21 | def generator_trainstep(self, y, z):
22 | assert(y.size(0) == z.size(0))
23 | toggle_grad(self.generator, True)
24 | toggle_grad(self.discriminator, False)
25 | self.generator.train()
26 | self.discriminator.train()
27 | self.g_optimizer.zero_grad()
28 |
29 | x_fake = self.generator(z, y)
30 | d_fake = self.discriminator(x_fake, y)
31 | gloss = self.compute_loss(d_fake, 1)
32 | gloss.backward()
33 |
34 | self.g_optimizer.step()
35 |
36 | return gloss.item()
37 |
38 | def discriminator_trainstep(self, x_real, y, z):
39 | toggle_grad(self.generator, False)
40 | toggle_grad(self.discriminator, True)
41 | self.generator.train()
42 | self.discriminator.train()
43 | self.d_optimizer.zero_grad()
44 |
45 | # On real data
46 | x_real.requires_grad_()
47 |
48 | d_real = self.discriminator(x_real, y)
49 | dloss_real = self.compute_loss(d_real, 1)
50 |
51 | if self.reg_type == 'real' or self.reg_type == 'real_fake':
52 | dloss_real.backward(retain_graph=True)
53 | reg = self.reg_param * compute_grad2(d_real, x_real).mean()
54 | reg.backward()
55 | else:
56 | dloss_real.backward()
57 |
58 | # On fake data
59 | with torch.no_grad():
60 | x_fake = self.generator(z, y)
61 |
62 | x_fake.requires_grad_()
63 | d_fake = self.discriminator(x_fake, y)
64 | dloss_fake = self.compute_loss(d_fake, 0)
65 |
66 | if self.reg_type == 'fake' or self.reg_type == 'real_fake':
67 | dloss_fake.backward(retain_graph=True)
68 | reg = self.reg_param * compute_grad2(d_fake, x_fake).mean()
69 | reg.backward()
70 | else:
71 | dloss_fake.backward()
72 |
73 | if self.reg_type == 'wgangp':
74 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y)
75 | reg.backward()
76 | elif self.reg_type == 'wgangp0':
77 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y, center=0.)
78 | reg.backward()
79 |
80 | self.d_optimizer.step()
81 |
82 | toggle_grad(self.discriminator, False)
83 |
84 | # Output
85 | dloss = (dloss_real + dloss_fake)
86 |
87 | if self.reg_type == 'none':
88 | reg = torch.tensor(0.)
89 |
90 | return dloss.item(), reg.item()
91 |
92 | def compute_loss(self, d_outs, target):
93 |
94 | d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
95 | loss = 0
96 |
97 | for d_out in d_outs:
98 |
99 | targets = d_out.new_full(size=d_out.size(), fill_value=target)
100 |
101 | if self.gan_type == 'standard':
102 | loss += F.binary_cross_entropy_with_logits(d_out, targets)
103 | elif self.gan_type == 'wgan':
104 | loss += (2*target - 1) * d_out.mean()
105 | else:
106 | raise NotImplementedError
107 |
108 | return loss / len(d_outs)
109 |
110 | def wgan_gp_reg(self, x_real, x_fake, y, center=1.):
111 | batch_size = y.size(0)
112 | eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1)
113 | x_interp = (1 - eps) * x_real + eps * x_fake
114 | x_interp = x_interp.detach()
115 | x_interp.requires_grad_()
116 | d_out = self.discriminator(x_interp, y)
117 |
118 | reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean()
119 |
120 | return reg
121 |
122 |
123 | # Utility functions
124 | def toggle_grad(model, requires_grad):
125 | for p in model.parameters():
126 | p.requires_grad_(requires_grad)
127 |
128 |
129 | def compute_grad2(d_outs, x_in):
130 | d_outs = [d_outs] if not isinstance(d_outs, list) else d_outs
131 | reg = 0
132 | for d_out in d_outs:
133 | batch_size = x_in.size(0)
134 | grad_dout = autograd.grad(
135 | outputs=d_out.sum(), inputs=x_in,
136 | create_graph=True, retain_graph=True, only_inputs=True
137 | )[0]
138 | grad_dout2 = grad_dout.pow(2)
139 | assert(grad_dout2.size() == x_in.size())
140 | reg += grad_dout2.view(batch_size, -1).sum(1)
141 | return reg / len(d_outs)
142 |
143 |
144 | def update_average(model_tgt, model_src, beta):
145 | toggle_grad(model_src, False)
146 | toggle_grad(model_tgt, False)
147 |
148 | param_dict_src = dict(model_src.named_parameters())
149 |
150 | for p_name, p_tgt in model_tgt.named_parameters():
151 | p_src = param_dict_src[p_name]
152 | assert(p_src is not p_tgt)
153 | p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)
154 |
--------------------------------------------------------------------------------
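`compute_grad2` returns the per-sample squared norm of the discriminator gradient with respect to its input; combined with `reg_type='real'`, this is the R1 gradient penalty on real samples. A stripped-down sketch of the same computation outside the `Trainer` (the names here are illustrative, not part of the module):

```
import torch
from torch import autograd

def r1_penalty(d_real, x_real):
    # Squared norm of grad_x D(x), averaged over the batch
    # (what compute_grad2(d_real, x_real).mean() yields above).
    grad = autograd.grad(outputs=d_real.sum(), inputs=x_real,
                         create_graph=True, retain_graph=True, only_inputs=True)[0]
    return grad.pow(2).reshape(grad.size(0), -1).sum(1).mean()

# Hypothetical discriminator step using it:
#   x_real.requires_grad_()
#   d_real = discriminator(x_real, y)
#   loss = loss_real + reg_param * r1_penalty(d_real, x_real)
```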
/submodules/GAN_stability/gan_training/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.utils.data
4 | import torch.utils.data.distributed
5 | import torchvision
6 |
7 |
8 | def save_images(imgs, outfile, nrow=8):
9 | imgs = imgs / 2 + 0.5 # unnormalize
10 | torchvision.utils.save_image(imgs, outfile, nrow=nrow)
11 |
12 |
13 | def get_nsamples(data_loader, N):
14 | x = []
15 | y = []
16 | n = 0
17 | while n < N:
18 | x_next, y_next = next(iter(data_loader))
19 | x.append(x_next)
20 | y.append(y_next)
21 | n += x_next.size(0)
22 | x = torch.cat(x, dim=0)[:N]
23 | y = torch.cat(y, dim=0)[:N]
24 | return x, y
25 |
26 |
27 | def update_average(model_tgt, model_src, beta):
28 | param_dict_src = dict(model_src.named_parameters())
29 |
30 | for p_name, p_tgt in model_tgt.named_parameters():
31 | p_src = param_dict_src[p_name]
32 | assert(p_src is not p_tgt)
33 | p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)
34 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/interpolate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path
4 | import copy
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from gan_training import utils
9 | from gan_training.checkpoints import CheckpointIO
10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere
11 | from gan_training.config import (
12 | load_config, build_models
13 | )
14 |
15 | # Arguments
16 | parser = argparse.ArgumentParser(
17 | description='Create interpolations for a trained GAN.'
18 | )
19 | parser.add_argument('config', type=str, help='Path to config file.')
20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
21 |
22 | args = parser.parse_args()
23 |
24 | config = load_config(args.config, 'configs/default.yaml')
25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
26 |
27 | # Shorthands
28 | nlabels = config['data']['nlabels']
29 | out_dir = config['training']['out_dir']
30 | batch_size = config['test']['batch_size']
31 | sample_size = config['test']['sample_size']
32 | sample_nrow = config['test']['sample_nrow']
33 | checkpoint_dir = path.join(out_dir, 'chkpts')
34 | interp_dir = path.join(out_dir, 'test', 'interp')
35 |
36 | # Create missing directories
37 | if not path.exists(interp_dir):
38 | os.makedirs(interp_dir)
39 |
40 | # Logger
41 | checkpoint_io = CheckpointIO(
42 | checkpoint_dir=checkpoint_dir
43 | )
44 |
45 | # Get model file
46 | model_file = config['test']['model_file']
47 |
48 | # Models
49 | device = torch.device("cuda:0" if is_cuda else "cpu")
50 |
51 | generator, discriminator = build_models(config)
52 | print(generator)
53 | print(discriminator)
54 |
55 | # Put models on gpu if needed
56 | generator = generator.to(device)
57 | discriminator = discriminator.to(device)
58 |
59 | # Use multiple GPUs if possible
60 | generator = nn.DataParallel(generator)
61 | discriminator = nn.DataParallel(discriminator)
62 |
63 | # Register modules to checkpoint
64 | checkpoint_io.register_modules(
65 | generator=generator,
66 | discriminator=discriminator,
67 | )
68 |
69 | # Test generator
70 | if config['test']['use_model_average']:
71 | generator_test = copy.deepcopy(generator)
72 | checkpoint_io.register_modules(generator_test=generator_test)
73 | else:
74 | generator_test = generator
75 |
76 | # Distributions
77 | ydist = get_ydist(nlabels, device=device)
78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
79 | device=device)
80 |
81 |
82 | # Load checkpoint if it exists
83 | load_dict = checkpoint_io.load(model_file)
84 | it = load_dict.get('it', -1)
85 | epoch_idx = load_dict.get('epoch_idx', -1)
86 |
87 | # Interpolations
88 | print('Creating interpolations...')
89 | nsteps = config['interpolations']['nzs']
90 | nsubsteps = config['interpolations']['nsubsteps']
91 |
92 | y = ydist.sample((sample_size,))
93 | zs = [zdist.sample((sample_size,)) for i in range(nsteps)]
94 | ts = np.linspace(0, 1, nsubsteps)
95 |
96 | it = 0
97 | for z1, z2 in zip(zs, zs[1:] + [zs[0]]):
98 | for t in ts:
99 | z = interpolate_sphere(z1, z2, float(t))
100 | with torch.no_grad():
101 | x = generator_test(z, y)
102 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it),
103 | nrow=sample_nrow)
104 | it += 1
105 | print('%d/%d done!' % (it, nsteps * nsubsteps))
106 |
--------------------------------------------------------------------------------
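`interpolate_sphere` (imported from `gan_training/distributions.py`, which is not reproduced in this listing) blends two latent batches along a spherical path rather than a straight line, which keeps the interpolated codes close to the typical set of a Gaussian prior. A generic slerp sketch of that idea, as an illustration only and not necessarily the exact implementation:

```
import torch

def slerp(z1, z2, t):
    # Spherical interpolation between two latent batches for t in [0, 1].
    z1_n = z1 / z1.norm(dim=-1, keepdim=True)
    z2_n = z2 / z2.norm(dim=-1, keepdim=True)
    cos = (z1_n * z2_n).sum(-1, keepdim=True).clamp(-1 + 1e-7, 1 - 1e-7)
    omega = torch.acos(cos)
    return (torch.sin((1 - t) * omega) * z1 + torch.sin(t * omega) * z2) / torch.sin(omega)

z1, z2 = torch.randn(4, 256), torch.randn(4, 256)
print(slerp(z1, z2, 0.5).shape)   # torch.Size([4, 256])
```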
/submodules/GAN_stability/interpolate_class.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path
4 | import copy
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from gan_training import utils
9 | from gan_training.checkpoints import CheckpointIO
10 | from gan_training.distributions import get_ydist, get_zdist, interpolate_sphere
11 | from gan_training.config import (
12 | load_config, build_models
13 | )
14 |
15 | # Arguments
16 | parser = argparse.ArgumentParser(
17 | description='Create interpolations for a trained GAN.'
18 | )
19 | parser.add_argument('config', type=str, help='Path to config file.')
20 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
21 |
22 | args = parser.parse_args()
23 |
24 | config = load_config(args.config, 'configs/default.yaml')
25 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
26 |
27 | # Shorthands
28 | nlabels = config['data']['nlabels']
29 | out_dir = config['training']['out_dir']
30 | batch_size = config['test']['batch_size']
31 | sample_size = config['test']['sample_size']
32 | sample_nrow = config['test']['sample_nrow']
33 | checkpoint_dir = path.join(out_dir, 'chkpts')
34 | interp_dir = path.join(out_dir, 'test', 'interp_class')
35 |
36 | # Create missing directories
37 | if not path.exists(interp_dir):
38 | os.makedirs(interp_dir)
39 |
40 | # Logger
41 | checkpoint_io = CheckpointIO(
42 | checkpoint_dir=checkpoint_dir
43 | )
44 |
45 | # Get model file
46 | model_file = config['test']['model_file']
47 |
48 | # Models
49 | device = torch.device("cuda:0" if is_cuda else "cpu")
50 |
51 | generator, discriminator = build_models(config)
52 | print(generator)
53 | print(discriminator)
54 |
55 | # Put models on gpu if needed
56 | generator = generator.to(device)
57 | discriminator = discriminator.to(device)
58 |
59 | # Use multiple GPUs if possible
60 | generator = nn.DataParallel(generator)
61 | discriminator = nn.DataParallel(discriminator)
62 |
63 | # Register modules to checkpoint
64 | checkpoint_io.register_modules(
65 | generator=generator,
66 | discriminator=discriminator,
67 | )
68 |
69 | # Test generator
70 | if config['test']['use_model_average']:
71 | generator_test = copy.deepcopy(generator)
72 | checkpoint_io.register_modules(generator_test=generator_test)
73 | else:
74 | generator_test = generator
75 |
76 | # Distributions
77 | ydist = get_ydist(nlabels, device=device)
78 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
79 | device=device)
80 |
81 |
82 | # Load checkpoint if it exists
83 | load_dict = checkpoint_io.load(model_file)
84 | it = load_dict.get('it', -1)
85 | epoch_idx = load_dict.get('epoch_idx', -1)
86 |
87 | # Interpolations
88 | print('Creating interpolations...')
89 | nsubsteps = config['interpolations']['nsubsteps']
90 | ys = config['interpolations']['ys']
91 |
92 | nsteps = len(ys)
93 | z = zdist.sample((sample_size,))
94 | ts = np.linspace(0, 1, nsubsteps)
95 |
96 | it = 0
97 | for y1, y2 in zip(ys, ys[1:] + [ys[0]]):
98 | for t in ts:
99 | y1_pt = torch.full((sample_size,), y1, dtype=torch.int64, device=device)
100 | y2_pt = torch.full((sample_size,), y2, dtype=torch.int64, device=device)
101 | y1_embed = generator_test.module.embedding(y1_pt)
102 | y2_embed = generator_test.module.embedding(y2_pt)
103 | t = float(t)
104 | y_embed = (1 - t) * y1_embed + t * y2_embed
105 | with torch.no_grad():
106 | x = generator_test(z, y_embed)
107 | utils.save_images(x, path.join(interp_dir, '%04d.png' % it),
108 | nrow=sample_nrow)
109 | it += 1
110 | print('%d/%d done!' % (it, nsteps * nsubsteps))
111 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/create_video.sh:
--------------------------------------------------------------------------------
1 | EXEC=ffmpeg
2 | declare -a FOLDERS=(
3 | "simgd"
4 | "altgd1"
5 | "altgd5"
6 | )
7 | declare -a SUBFOLDERS=(
8 | "gan"
9 | "gan_consensus"
10 | "gan_gradpen"
11 | "gan_gradpen_critical"
12 | "gan_instnoise"
13 | "nsgan"
14 | "nsgan_gradpen"
15 | "wgan"
16 | "wgan_gp"
17 | )
18 | OPTIONS="-y"
19 |
20 | cd ./out
21 | for FOLDER in ${FOLDERS[@]}; do
22 | for SUBFOLDER in ${SUBFOLDERS[@]}; do
23 | INPUT="$FOLDER/animations/$SUBFOLDER/%06d.png"
24 | OUTPUT="$FOLDER/animations/$SUBFOLDER.mp4"
25 | $EXEC -framerate 30 -i $INPUT $OPTIONS $OUTPUT
26 | echo $FOLDER
27 | echo $SUBFOLDER
28 | done
29 |
30 | done
31 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/diracgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/notebooks/diracgan/__init__.py
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/diracgan/gans.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from diracgan.util import sigmoid, clip
3 |
4 |
5 | class VectorField(object):
6 | def __call__(self, theta, psi):
7 | theta_isfloat = isinstance(theta, float)
8 | psi_isfloat = isinstance(psi, float)
9 | if theta_isfloat:
10 | theta = np.array([theta])
11 | if psi_isfloat:
12 | psi = np.array([psi])
13 |
14 | v1, v2 = self._get_vector(theta, psi)
15 |
16 | if theta_isfloat:
17 | v1 = v1[0]
18 | if psi_isfloat:
19 | v2 = v2[0]
20 |
21 | return v1, v2
22 |
23 | def postprocess(self, theta, psi):
24 | theta_isfloat = isinstance(theta, float)
25 | psi_isfloat = isinstance(psi, float)
26 | if theta_isfloat:
27 | theta = np.array([theta])
28 | if psi_isfloat:
29 | psi = np.array([psi])
30 | theta, psi = self._postprocess(theta, psi)
31 | if theta_isfloat:
32 | theta = theta[0]
33 | if psi_isfloat:
34 | psi = psi[0]
35 |
36 | return theta, psi
37 |
38 | def step_sizes(self, h):
39 | return h, h
40 |
41 | def _get_vector(self, theta, psi):
42 |         raise NotImplementedError
43 |
44 | def _postprocess(self, theta, psi):
45 | return theta, psi
46 |
47 |
48 | # GANs
49 | def fp(x):
50 | return sigmoid(-x)
51 |
52 |
53 | def fp2(x):
54 | return -sigmoid(-x) * sigmoid(x)
55 |
56 |
57 | class GAN(VectorField):
58 | def _get_vector(self, theta, psi):
59 | v1 = -psi * fp(psi*theta)
60 | v2 = theta * fp(psi*theta)
61 | return v1, v2
62 |
63 |
64 | class NSGAN(VectorField):
65 | def _get_vector(self, theta, psi):
66 | v1 = -psi * fp(-psi*theta)
67 | v2 = theta * fp(psi*theta)
68 | return v1, v2
69 |
70 |
71 | class WGAN(VectorField):
72 | def __init__(self, clip=0.3):
73 | super().__init__()
74 | self.clip = clip
75 |
76 | def _get_vector(self, theta, psi):
77 | v1 = -psi
78 | v2 = theta
79 |
80 | return v1, v2
81 |
82 | def _postprocess(self, theta, psi):
83 | psi = clip(psi, self.clip)
84 | return theta, psi
85 |
86 |
87 | class WGAN_GP(VectorField):
88 | def __init__(self, reg=1., target=0.3):
89 | super().__init__()
90 | self.reg = reg
91 | self.target = target
92 |
93 | def _get_vector(self, theta, psi):
94 | v1 = -psi
95 | v2 = theta - self.reg * (np.abs(psi) - self.target) * np.sign(psi)
96 | return v1, v2
97 |
98 |
99 | class GAN_InstNoise(VectorField):
100 | def __init__(self, std=1):
101 | self.std = std
102 |
103 | def _get_vector(self, theta, psi):
104 | theta_eps = (
105 | theta + self.std*np.random.randn(*([1000] + list(theta.shape)))
106 | )
107 | x_eps = (
108 | self.std * np.random.randn(*([1000] + list(theta.shape)))
109 | )
110 | v1 = -psi * fp(psi*theta_eps)
111 | v2 = theta_eps * fp(psi*theta_eps) - x_eps * fp(-x_eps * psi)
112 | v1 = v1.mean(axis=0)
113 | v2 = v2.mean(axis=0)
114 | return v1, v2
115 |
116 |
117 | class GAN_GradPenalty(VectorField):
118 | def __init__(self, reg=0.3):
119 | self.reg = reg
120 |
121 | def _get_vector(self, theta, psi):
122 | v1 = -psi * fp(psi*theta)
123 | v2 = +theta * fp(psi*theta) - self.reg * psi
124 | return v1, v2
125 |
126 |
127 | class NSGAN_GradPenalty(VectorField):
128 | def __init__(self, reg=0.3):
129 | self.reg = reg
130 |
131 | def _get_vector(self, theta, psi):
132 | v1 = -psi * fp(-psi*theta)
133 | v2 = theta * fp(psi*theta) - self.reg * psi
134 | return v1, v2
135 |
136 |
137 | class GAN_Consensus(VectorField):
138 | def __init__(self, reg=0.3):
139 | self.reg = reg
140 |
141 | def _get_vector(self, theta, psi):
142 | v1 = -psi * fp(psi*theta)
143 | v2 = +theta * fp(psi*theta)
144 |
145 | # L 0.5*(psi**2 + theta**2)*f(psi*theta)**2
146 | v1reg = (
147 | theta * fp(psi*theta)**2
148 | + 0.5*psi * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta)
149 | )
150 | v2reg = (
151 | psi * fp(psi*theta)**2
152 | + 0.5*theta * (psi**2 + theta**2) * fp(psi*theta)*fp2(psi*theta)
153 | )
154 | v1 -= self.reg * v1reg
155 | v2 -= self.reg * v2reg
156 |
157 | return v1, v2
158 |
159 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/diracgan/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | import matplotlib.patches as patches
4 | import os
5 | from diracgan.gans import WGAN
6 | from diracgan.subplots import vector_field_plot
7 | from tqdm import tqdm
8 |
9 |
10 | def plot_vector(vecfn, theta, psi, outfile, trajectory=None, marker='b^'):
11 | fig, ax = plt.subplots(1, 1)
12 | theta, psi = np.meshgrid(theta, psi)
13 | v1, v2 = vecfn(theta, psi)
14 | if isinstance(vecfn, WGAN):
15 | clip_y = vecfn.clip
16 | else:
17 | clip_y = None
18 | vector_field_plot(theta, psi, v1, v2, trajectory, clip_y=clip_y, marker=marker)
19 | plt.savefig(outfile, bbox_inches='tight')
20 | plt.show()
21 |
22 |
23 | def simulate_trajectories(vecfn, theta, psi, trajectory, outfolder, maxframes=300):
24 | if not os.path.exists(outfolder):
25 | os.makedirs(outfolder)
26 | theta, psi = np.meshgrid(theta, psi)
27 |
28 | N = min(len(trajectory[0]), maxframes)
29 |
30 | v1, v2 = vecfn(theta, psi)
31 | if isinstance(vecfn, WGAN):
32 | clip_y = vecfn.clip
33 | else:
34 | clip_y = None
35 |
36 | for i in tqdm(range(1, N)):
37 | fig, (ax1, ax2) = plt.subplots(1, 2,
38 | subplot_kw=dict(adjustable='box', aspect=0.7))
39 |
40 | plt.sca(ax1)
41 | trajectory_i = [trajectory[0][:i], trajectory[1][:i]]
42 | vector_field_plot(theta, psi, v1, v2, trajectory_i, clip_y=clip_y, marker='b-')
43 | plt.plot(trajectory_i[0][-1], trajectory_i[1][-1], 'bo')
44 |
45 | plt.sca(ax2)
46 | ax2.set_axisbelow(True)
47 | plt.grid()
48 |
49 | x = np.linspace(np.min(theta), np.max(theta))
50 | y = x*trajectory[1][i]
51 | plt.plot(x, y, 'C1')
52 |
53 | ax2.add_patch(patches.Rectangle(
54 | (-0.05, 0), .1, 2.5, facecolor='C2'
55 | ))
56 |
57 | ax2.add_patch(patches.Rectangle(
58 | (trajectory[0][i]-0.05, 0), .1, 2.5, facecolor='C0'
59 | ))
60 |
61 | plt.xlim(np.min(theta), np.max(theta))
62 | plt.ylim(-1, 3.)
63 | plt.xlabel(r'$\theta$')
64 | plt.xticks(np.linspace(np.min(theta), np.max(theta), 5))
65 | ax2.set_yticklabels([])
66 |
67 | plt.savefig(os.path.join(outfolder, '%06d.png' % i), dpi=200, bbox_inches='tight')
68 | plt.close()
69 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/diracgan/simulate.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Simulate
4 | def trajectory_simgd(vec_fn, theta0, psi0,
5 | nsteps=50, hs_g=0.1, hs_d=0.1):
6 | theta, psi = vec_fn.postprocess(theta0, psi0)
7 | thetas, psis = [theta], [psi]
8 |
9 | if isinstance(hs_g, float):
10 | hs_g = [hs_g] * nsteps
11 | if isinstance(hs_d, float):
12 | hs_d = [hs_d] * nsteps
13 | assert(len(hs_g) == nsteps)
14 | assert(len(hs_d) == nsteps)
15 |
16 | for h_g, h_d in zip(hs_g, hs_d):
17 | v1, v2 = vec_fn(theta, psi)
18 | theta += h_g * v1
19 | psi += h_d * v2
20 | theta, psi = vec_fn.postprocess(theta, psi)
21 | thetas.append(theta)
22 | psis.append(psi)
23 |
24 | return thetas, psis
25 |
26 |
27 | def trajectory_altgd(vec_fn, theta0, psi0,
28 | nsteps=50, hs_g=0.1, hs_d=0.1, gsteps=1, dsteps=1):
29 | theta, psi = vec_fn.postprocess(theta0, psi0)
30 | thetas, psis = [theta], [psi]
31 |
32 | if isinstance(hs_g, float):
33 | hs_g = [hs_g] * nsteps
34 | if isinstance(hs_d, float):
35 | hs_d = [hs_d] * nsteps
36 | assert(len(hs_g) == nsteps)
37 | assert(len(hs_d) == nsteps)
38 |
39 | for h_g, h_d in zip(hs_g, hs_d):
40 | for it in range(gsteps):
41 | v1, v2 = vec_fn(theta, psi)
42 | theta += h_g * v1
43 | theta, psi = vec_fn.postprocess(theta, psi)
44 |
45 | for it in range(dsteps):
46 | v1, v2 = vec_fn(theta, psi)
47 | psi += h_d * v2
48 | theta, psi = vec_fn.postprocess(theta, psi)
49 | thetas.append(theta)
50 | psis.append(psi)
51 |
52 | return thetas, psis
53 |
--------------------------------------------------------------------------------
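These trajectory helpers take any `VectorField` from `diracgan/gans.py`: DiracGAN tracks a single generator parameter θ and a single discriminator parameter ψ, and the resulting path shows whether a training method converges to the equilibrium (0, 0). A minimal sketch, assuming the script is run from the `notebooks` directory so that the `diracgan` package is importable:

```
from diracgan.gans import GAN, GAN_GradPenalty
from diracgan.simulate import trajectory_simgd

# Plain GAN under simultaneous gradient descent: the iterates keep orbiting (0, 0)
# instead of converging.
thetas, psis = trajectory_simgd(GAN(), 1.0, 1.0, nsteps=200)
print(thetas[-1], psis[-1])

# With a gradient penalty on the discriminator, the trajectory spirals into (0, 0).
thetas, psis = trajectory_simgd(GAN_GradPenalty(reg=0.3), 1.0, 1.0, nsteps=200)
print(thetas[-1], psis[-1])
```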
/submodules/GAN_stability/notebooks/diracgan/subplots.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 |
5 | def arrow_plot(x, y, color='C1'):
6 | plt.quiver(x[:-1], y[:-1], x[1:]-x[:-1], y[1:]-y[:-1],
7 | color=color, scale_units='xy', angles='xy', scale=1)
8 |
9 |
10 | def vector_field_plot(theta, psi, v1, v2, trajectory=None, clip_y=None, marker='b^'):
11 | plt.quiver(theta, psi, v1, v2)
12 | if clip_y is not None:
13 | plt.axhspan(np.min(psi), -clip_y, facecolor='0.2', alpha=0.5)
14 | plt.plot([np.min(theta), np.max(theta)], [-clip_y, -clip_y], 'k-')
15 | plt.axhspan(clip_y, np.max(psi), facecolor='0.2', alpha=0.5)
16 | plt.plot([np.min(theta), np.max(theta)], [clip_y, clip_y], 'k-')
17 |
18 | if trajectory is not None:
19 | psis, thetas = trajectory
20 | plt.plot(psis, thetas, marker, markerfacecolor='None')
21 | plt.plot(psis[0], thetas[0], 'ro')
22 |
23 | plt.xlim(np.min(theta), np.max(theta))
24 | plt.ylim(np.min(psi), np.max(psi))
25 | plt.xlabel(r'$\theta$')
26 | plt.ylabel(r'$\psi$')
27 | plt.xticks(np.linspace(np.min(theta), np.max(theta), 5))
28 | plt.yticks(np.linspace(np.min(psi), np.max(psi), 5))
29 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/notebooks/diracgan/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def sigmoid(x):
4 | m = np.minimum(0, x)
5 | return np.exp(m)/(np.exp(m) + np.exp(-x + m))
6 |
7 |
8 | def clip(x, clipval=0.3):
9 | x = np.clip(x, -clipval, clipval)
10 | return x
11 |
--------------------------------------------------------------------------------
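The `sigmoid` above is written in its numerically stable form: shifting both exponents by `m = min(0, x)` keeps every exponential at or below 1, so the overflow that the naive `1 / (1 + exp(-x))` hits for large negative `x` cannot occur, while the value is unchanged. A quick check (again assuming `diracgan` is importable):

```
import numpy as np
from diracgan.util import sigmoid

x = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
print(sigmoid(x))                       # ~ [0, 0.269, 0.5, 0.731, 1], no overflow warnings
print(1.0 / (1.0 + np.exp(-x[1:4])))    # matches the naive formula on the safe middle values
```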
/submodules/GAN_stability/results/celebA-HQ.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/celebA-HQ.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/results/imagenet_00.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_00.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/results/imagenet_01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_01.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/results/imagenet_02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_02.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/results/imagenet_03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_03.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/results/imagenet_04.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/GAN_stability/results/imagenet_04.jpg
--------------------------------------------------------------------------------
/submodules/GAN_stability/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path
4 | import copy
5 | from tqdm import tqdm
6 | import torch
7 | from torch import nn
8 | from gan_training import utils
9 | from gan_training.checkpoints import CheckpointIO
10 | from gan_training.distributions import get_ydist, get_zdist
11 | from gan_training.eval import Evaluator
12 | from gan_training.config import (
13 | load_config, build_models
14 | )
15 |
16 | # Arguments
17 | parser = argparse.ArgumentParser(
18 | description='Test a trained GAN and create visualizations.'
19 | )
20 | parser.add_argument('config', type=str, help='Path to config file.')
21 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
22 |
23 | args = parser.parse_args()
24 |
25 | config = load_config(args.config, 'configs/default.yaml')
26 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
27 |
28 | # Shorthands
29 | nlabels = config['data']['nlabels']
30 | out_dir = config['training']['out_dir']
31 | batch_size = config['test']['batch_size']
32 | sample_size = config['test']['sample_size']
33 | sample_nrow = config['test']['sample_nrow']
34 | checkpoint_dir = path.join(out_dir, 'chkpts')
35 | img_dir = path.join(out_dir, 'test', 'img')
36 | img_all_dir = path.join(out_dir, 'test', 'img_all')
37 |
38 | # Create missing directories
39 | if not path.exists(img_dir):
40 | os.makedirs(img_dir)
41 | if not path.exists(img_all_dir):
42 | os.makedirs(img_all_dir)
43 |
44 | # Logger
45 | checkpoint_io = CheckpointIO(
46 | checkpoint_dir=checkpoint_dir
47 | )
48 |
49 | # Get model file
50 | model_file = config['test']['model_file']
51 |
52 | # Models
53 | device = torch.device("cuda:0" if is_cuda else "cpu")
54 |
55 | generator, discriminator = build_models(config)
56 | print(generator)
57 | print(discriminator)
58 |
59 | # Put models on gpu if needed
60 | generator = generator.to(device)
61 | discriminator = discriminator.to(device)
62 |
63 | # Use multiple GPUs if possible
64 | generator = nn.DataParallel(generator)
65 | discriminator = nn.DataParallel(discriminator)
66 |
67 | # Register modules to checkpoint
68 | checkpoint_io.register_modules(
69 | generator=generator,
70 | discriminator=discriminator,
71 | )
72 |
73 | # Test generator
74 | if config['test']['use_model_average']:
75 | generator_test = copy.deepcopy(generator)
76 | checkpoint_io.register_modules(generator_test=generator_test)
77 | else:
78 | generator_test = generator
79 |
80 | # Distributions
81 | ydist = get_ydist(nlabels, device=device)
82 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
83 | device=device)
84 |
85 | # Evaluator
86 | evaluator = Evaluator(generator_test, zdist, ydist,
87 | batch_size=batch_size, device=device)
88 |
89 | # Load checkpoint if it exists
90 | load_dict = checkpoint_io.load(model_file)
91 | it = load_dict.get('it', -1)
92 | epoch_idx = load_dict.get('epoch_idx', -1)
93 |
94 | # Inception score
95 | if config['test']['compute_inception']:
96 | print('Computing inception score...')
97 | inception_mean, inception_std = evaluator.compute_inception_score()
98 | print('Inception score: %.4f +- %.4f' % (inception_mean, inception_std))
99 |
100 | # Samples
101 | print('Creating samples...')
102 | ztest = zdist.sample((sample_size,))
103 | x = evaluator.create_samples(ztest)
104 | utils.save_images(x, path.join(img_all_dir, '%08d.png' % it),
105 | nrow=sample_nrow)
106 | if config['test']['conditional_samples']:
107 | for y_inst in tqdm(range(nlabels)):
108 | x = evaluator.create_samples(ztest, y_inst)
109 | utils.save_images(x, path.join(img_dir, '%04d.png' % y_inst),
110 | nrow=sample_nrow)
111 |
--------------------------------------------------------------------------------
/submodules/GAN_stability/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path
4 | import time
5 | import copy
6 | import torch
7 | from torch import nn
8 | from gan_training import utils
9 | from gan_training.train import Trainer, update_average
10 | from gan_training.logger import Logger
11 | from gan_training.checkpoints import CheckpointIO
12 | from gan_training.inputs import get_dataset
13 | from gan_training.distributions import get_ydist, get_zdist
14 | from gan_training.eval import Evaluator
15 | from gan_training.config import (
16 | load_config, build_models, build_optimizers, build_lr_scheduler,
17 | )
18 |
19 | # Arguments
20 | parser = argparse.ArgumentParser(
21 | description='Train a GAN with different regularization strategies.'
22 | )
23 | parser.add_argument('config', type=str, help='Path to config file.')
24 | parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
25 |
26 | args = parser.parse_args()
27 |
28 | config = load_config(args.config, 'configs/default.yaml')
29 | is_cuda = (torch.cuda.is_available() and not args.no_cuda)
30 |
31 | # Shorthands
32 | batch_size = config['training']['batch_size']
33 | d_steps = config['training']['d_steps']
34 | restart_every = config['training']['restart_every']
35 | inception_every = config['training']['inception_every']
36 | save_every = config['training']['save_every']
37 | backup_every = config['training']['backup_every']
38 | sample_nlabels = config['training']['sample_nlabels']
39 |
40 | out_dir = config['training']['out_dir']
41 | checkpoint_dir = path.join(out_dir, 'chkpts')
42 |
43 | # Create missing directories
44 | if not path.exists(out_dir):
45 | os.makedirs(out_dir)
46 | if not path.exists(checkpoint_dir):
47 | os.makedirs(checkpoint_dir)
48 |
49 | # Logger
50 | checkpoint_io = CheckpointIO(
51 | checkpoint_dir=checkpoint_dir
52 | )
53 |
54 | device = torch.device("cuda:0" if is_cuda else "cpu")
55 |
56 |
57 | # Dataset
58 | train_dataset, nlabels = get_dataset(
59 | name=config['data']['type'],
60 | data_dir=config['data']['train_dir'],
61 | size=config['data']['img_size'],
62 | lsun_categories=config['data']['lsun_categories_train']
63 | )
64 | train_loader = torch.utils.data.DataLoader(
65 | train_dataset,
66 | batch_size=batch_size,
67 | num_workers=config['training']['nworkers'],
68 | shuffle=True, pin_memory=True, sampler=None, drop_last=True
69 | )
70 |
71 | # Number of labels
72 | nlabels = min(nlabels, config['data']['nlabels'])
73 | sample_nlabels = min(nlabels, sample_nlabels)
74 |
75 | # Create models
76 | generator, discriminator = build_models(config)
77 | print(generator)
78 | print(discriminator)
79 |
80 | # Put models on gpu if needed
81 | generator = generator.to(device)
82 | discriminator = discriminator.to(device)
83 |
84 | g_optimizer, d_optimizer = build_optimizers(
85 | generator, discriminator, config
86 | )
87 |
88 | # Use multiple GPUs if possible
89 | generator = nn.DataParallel(generator)
90 | discriminator = nn.DataParallel(discriminator)
91 |
92 | # Register modules to checkpoint
93 | checkpoint_io.register_modules(
94 | generator=generator,
95 | discriminator=discriminator,
96 | g_optimizer=g_optimizer,
97 | d_optimizer=d_optimizer,
98 | )
99 |
100 | # Get model file
101 | model_file = config['training']['model_file']
102 |
103 | # Logger
104 | logger = Logger(
105 | log_dir=path.join(out_dir, 'logs'),
106 | img_dir=path.join(out_dir, 'imgs'),
107 | monitoring=config['training']['monitoring'],
108 | monitoring_dir=path.join(out_dir, 'monitoring')
109 | )
110 |
111 | # Distributions
112 | ydist = get_ydist(nlabels, device=device)
113 | zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
114 | device=device)
115 |
116 | # Save for tests
117 | ntest = batch_size
118 | x_real, ytest = utils.get_nsamples(train_loader, ntest)
119 | ytest.clamp_(None, nlabels-1)
120 | ztest = zdist.sample((ntest,))
121 | utils.save_images(x_real, path.join(out_dir, 'real.png'))
122 |
123 | # Test generator
124 | if config['training']['take_model_average']:
125 | generator_test = copy.deepcopy(generator)
126 | checkpoint_io.register_modules(generator_test=generator_test)
127 | else:
128 | generator_test = generator
129 |
130 | # Evaluator
131 | evaluator = Evaluator(generator_test, zdist, ydist,
132 | batch_size=batch_size, device=device)
133 |
134 | # Train
135 | tstart = t0 = time.time()
136 |
137 | # Load checkpoint if it exists
138 | try:
139 | load_dict = checkpoint_io.load(model_file)
140 | except FileNotFoundError:
141 | it = epoch_idx = -1
142 | else:
143 | it = load_dict.get('it', -1)
144 | epoch_idx = load_dict.get('epoch_idx', -1)
145 | logger.load_stats('stats.p')
146 |
147 | # Reinitialize model average if needed
148 | if (config['training']['take_model_average']
149 | and config['training']['model_average_reinit']):
150 | update_average(generator_test, generator, 0.)
151 |
152 | # Learning rate annealing
153 | g_scheduler = build_lr_scheduler(g_optimizer, config, last_epoch=it)
154 | d_scheduler = build_lr_scheduler(d_optimizer, config, last_epoch=it)
155 |
156 | # Trainer
157 | trainer = Trainer(
158 | generator, discriminator, g_optimizer, d_optimizer,
159 | gan_type=config['training']['gan_type'],
160 | reg_type=config['training']['reg_type'],
161 | reg_param=config['training']['reg_param']
162 | )
163 |
164 | # Training loop
165 | print('Start training...')
166 | while True:
167 | epoch_idx += 1
168 | print('Start epoch %d...' % epoch_idx)
169 |
170 | for x_real, y in train_loader:
171 | it += 1
172 | g_scheduler.step()
173 | d_scheduler.step()
174 |
175 | d_lr = d_optimizer.param_groups[0]['lr']
176 | g_lr = g_optimizer.param_groups[0]['lr']
177 | logger.add('learning_rates', 'discriminator', d_lr, it=it)
178 | logger.add('learning_rates', 'generator', g_lr, it=it)
179 |
180 | x_real, y = x_real.to(device), y.to(device)
181 | y.clamp_(None, nlabels-1)
182 |
183 | # Discriminator updates
184 | z = zdist.sample((batch_size,))
185 | dloss, reg = trainer.discriminator_trainstep(x_real, y, z)
186 | logger.add('losses', 'discriminator', dloss, it=it)
187 | logger.add('losses', 'regularizer', reg, it=it)
188 |
189 |         # Generator updates
190 | if ((it + 1) % d_steps) == 0:
191 | z = zdist.sample((batch_size,))
192 | gloss = trainer.generator_trainstep(y, z)
193 | logger.add('losses', 'generator', gloss, it=it)
194 |
195 | if config['training']['take_model_average']:
196 | update_average(generator_test, generator,
197 | beta=config['training']['model_average_beta'])
198 |
199 | # Print stats
200 | g_loss_last = logger.get_last('losses', 'generator')
201 | d_loss_last = logger.get_last('losses', 'discriminator')
202 | d_reg_last = logger.get_last('losses', 'regularizer')
203 | print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f'
204 | % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last))
205 |
206 | # (i) Sample if necessary
207 | if (it % config['training']['sample_every']) == 0:
208 | print('Creating samples...')
209 | x = evaluator.create_samples(ztest, ytest)
210 | logger.add_imgs(x, 'all', it)
211 | for y_inst in range(sample_nlabels):
212 | x = evaluator.create_samples(ztest, y_inst)
213 | logger.add_imgs(x, '%04d' % y_inst, it)
214 |
215 | # (ii) Compute inception if necessary
216 | if inception_every > 0 and ((it + 1) % inception_every) == 0:
217 | inception_mean, inception_std = evaluator.compute_inception_score()
218 | logger.add('inception_score', 'mean', inception_mean, it=it)
219 | logger.add('inception_score', 'stddev', inception_std, it=it)
220 |
221 | # (iii) Backup if necessary
222 | if ((it + 1) % backup_every) == 0:
223 | print('Saving backup...')
224 | checkpoint_io.save('model_%08d.pt' % it, it=it)
225 | logger.save_stats('stats_%08d.p' % it)
226 |
227 | # (iv) Save checkpoint if necessary
228 | if time.time() - t0 > save_every:
229 | print('Saving checkpoint...')
230 | checkpoint_io.save(model_file, it=it)
231 | logger.save_stats('stats.p')
232 | t0 = time.time()
233 |
234 | if (restart_every > 0 and t0 - tstart > restart_every):
235 | exit(3)
236 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints
2 | **/__pycache__
3 | *.png
4 | *.mp4
5 | *.npy
6 | *.npz
7 | *.dae
8 | data/*
9 | logs/*
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/nerf_pytorch/.gitmodules
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 bmild
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/README.md:
--------------------------------------------------------------------------------
1 | # NeRF-pytorch
2 |
3 |
4 | [NeRF](http://www.matthewtancik.com/nerf) (Neural Radiance Fields) is a method that achieves state-of-the-art results for synthesizing novel views of complex scenes. Here are some videos generated by this repository (pre-trained models are provided below):
5 |
6 | 
7 | 
8 |
9 | This project is a faithful PyTorch implementation of [NeRF](http://www.matthewtancik.com/nerf) that **reproduces** the results while running **1.3 times faster**. The code is based on the authors' TensorFlow implementation [here](https://github.com/bmild/nerf) and has been tested to match it numerically.
10 |
11 | ## Installation
12 |
13 | ```
14 | git clone https://github.com/yenchenlin/nerf-pytorch.git
15 | cd nerf-pytorch
16 | pip install -r requirements.txt
17 | cd torchsearchsorted
18 | pip install .
19 | cd ../
20 | ```
21 |
22 |
23 | Dependencies (click to expand)
24 |
25 | ## Dependencies
26 | - PyTorch 1.4
27 | - matplotlib
28 | - numpy
29 | - imageio
30 | - imageio-ffmpeg
31 | - configargparse
32 |
33 | The LLFF data loader requires ImageMagick.
34 |
35 | You will also need the [LLFF code](http://github.com/fyusion/llff) (and COLMAP) set up to compute poses if you want to run on your own real data.
36 |
37 |
38 |
39 | ## How To Run?
40 |
41 | ### Quick Start
42 |
43 | Download data for two example datasets: `lego` and `fern`
44 | ```
45 | bash download_example_data.sh
46 | ```
47 |
48 | To train a low-res `lego` NeRF:
49 | ```
50 | python run_nerf.py --config configs/config_lego.txt
51 | ```
52 | After training for 100k iterations (~4 hours on a single 2080 Ti), you can find the following video at `logs/lego_test/lego_test_spiral_100000_rgb.mp4`.
53 |
54 | 
55 |
56 | ---
57 |
58 | To train a low-res `fern` NeRF:
59 | ```
60 | python run_nerf.py --config configs/config_fern.txt
61 | ```
62 | After training for 200k iterations (~8 hours on a single 2080 Ti), you can find the following video at `logs/fern_test/fern_test_spiral_200000_rgb.mp4` and `logs/fern_test/fern_test_spiral_200000_disp.mp4`
63 |
64 | 
65 |
66 | ---
67 |
68 | ### More Datasets
69 | To play with other scenes presented in the paper, download the data [here](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). Place the downloaded dataset according to the following directory structure:
70 | ```
71 | ├── configs
72 | │ ├── ...
73 | │
74 | ├── data
75 | │ ├── nerf_llff_data
76 | │ │ └── fern
77 | │ │ └── flower # downloaded llff dataset
78 | │ │ └── horns # downloaded llff dataset
79 | | | └── ...
80 | | ├── nerf_synthetic
81 | | | └── lego
82 | | | └── ship # downloaded synthetic dataset
83 | | | └── ...
84 | ```
85 |
86 | ---
87 |
88 | To train NeRF on different datasets:
89 |
90 | ```
91 | python run_nerf.py --config configs/config_{DATASET}.txt
92 | ```
93 |
94 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
95 |
96 | ---
97 |
98 | To test NeRF trained on different datasets:
99 |
100 | ```
101 | python run_nerf.py --config configs/config_{DATASET}.txt --render_only
102 | ```
103 |
104 | replace `{DATASET}` with `trex` | `horns` | `flower` | `fortress` | `lego` | etc.
105 |
106 |
107 | ### Pre-trained Models
108 |
109 | You can download the pre-trained models [here](https://drive.google.com/drive/folders/1jIr8dkvefrQmv737fFm2isiT6tqpbTbv?usp=sharing). Place the downloaded directory in `./logs` in order to test it later. See the following directory structure for an example:
110 |
111 | ```
112 | ├── logs
113 | │ ├── fern_test
114 | │ ├── flower_test # downloaded logs
115 | │ ├── trex_test # downloaded logs
116 | ```
117 |
118 | ### Reproducibility
119 |
120 | Tests that ensure the results of all functions and the training loop match the official implementation are contained in a separate branch, `reproduce`. One can check it out and run the tests:
121 | ```
122 | git checkout reproduce
123 | py.test
124 | ```
125 |
126 | ## Method
127 |
128 | [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://tancik.com/nerf)
129 | [Ben Mildenhall](https://people.eecs.berkeley.edu/~bmild/)\*1,
130 | [Pratul P. Srinivasan](https://people.eecs.berkeley.edu/~pratul/)\*1,
131 | [Matthew Tancik](http://tancik.com/)\*1,
132 | [Jonathan T. Barron](http://jonbarron.info/)2,
133 | [Ravi Ramamoorthi](http://cseweb.ucsd.edu/~ravir/)3,
134 | [Ren Ng](https://www2.eecs.berkeley.edu/Faculty/Homepages/yirenng.html)1
135 | 1UC Berkeley, 2Google Research, 3UC San Diego
136 | \*denotes equal contribution
137 |
138 |
139 |
140 | > A neural radiance field is a simple fully connected network (weights are ~5MB) trained to reproduce input views of a single scene using a rendering loss. The network directly maps from spatial location and viewing direction (5D input) to color and opacity (4D output), acting as the "volume" so we can use volume rendering to differentiably render new views
141 |
142 |
143 | ## Citation
144 | Kudos to the authors for their amazing results:
145 | ```
146 | @misc{mildenhall2020nerf,
147 | title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
148 | author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
149 | year={2020},
150 | eprint={2003.08934},
151 | archivePrefix={arXiv},
152 | primaryClass={cs.CV}
153 | }
154 | ```
155 |
156 | However, if you find this implementation or the pre-trained models helpful, please consider citing:
157 | ```
158 | @misc{lin2020nerfpytorch,
159 | title={NeRF-pytorch},
160 | author={Yen-Chen, Lin},
161 | howpublished={\url{https://github.com/yenchenlin/nerf-pytorch/}},
162 | year={2020}
163 | }
164 | ```
165 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_fern.txt:
--------------------------------------------------------------------------------
1 | expname = fern_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fern
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_flower.txt:
--------------------------------------------------------------------------------
1 | expname = flower_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/flower
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_fortress.txt:
--------------------------------------------------------------------------------
1 | expname = fortress_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/fortress
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_horns.txt:
--------------------------------------------------------------------------------
1 | expname = horns_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/horns
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_lego.txt:
--------------------------------------------------------------------------------
1 | expname = lego_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_synthetic/lego
4 | dataset_type = blender
5 |
6 | half_res = True
7 |
8 | N_samples = 64
9 | N_importance = 64
10 |
11 | use_viewdirs = True
12 |
13 | white_bkgd = True
14 |
15 | N_rand = 1024
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/configs/config_trex.txt:
--------------------------------------------------------------------------------
1 | expname = trex_test
2 | basedir = ./logs
3 | datadir = ./data/nerf_llff_data/trex
4 | dataset_type = llff
5 |
6 | factor = 8
7 | llffhold = 8
8 |
9 | N_rand = 1024
10 | N_samples = 64
11 | N_importance = 64
12 |
13 | use_viewdirs = True
14 | raw_noise_std = 1e0
15 |
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/download_example_data.sh:
--------------------------------------------------------------------------------
1 | wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
2 | mkdir -p data
3 | cd data
4 | wget https://people.eecs.berkeley.edu/~bmild/nerf/nerf_example_data.zip
5 | unzip nerf_example_data.zip
6 | cd ..
7 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/imgs/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autonomousvision/graf/c50d342fb567aec335b92e3f867c54b4dc4e1d09/submodules/nerf_pytorch/imgs/pipeline.jpg
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/load_blender.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import json
6 | import torch.nn.functional as F
7 | import cv2
8 |
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 |
37 | def load_blender_data(basedir, half_res=False, testskip=1):
38 | splits = ['train', 'val', 'test']
39 | metas = {}
40 | for s in splits:
41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
42 | metas[s] = json.load(fp)
43 |
44 | all_imgs = []
45 | all_poses = []
46 | counts = [0]
47 | for s in splits:
48 | meta = metas[s]
49 | imgs = []
50 | poses = []
51 | if s=='train' or testskip==0:
52 | skip = 1
53 | else:
54 | skip = testskip
55 |
56 | for frame in meta['frames'][::skip]:
57 | fname = os.path.join(basedir, frame['file_path'] + '.png')
58 | imgs.append(imageio.imread(fname))
59 | poses.append(np.array(frame['transform_matrix']))
60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
61 | poses = np.array(poses).astype(np.float32)
62 | counts.append(counts[-1] + imgs.shape[0])
63 | all_imgs.append(imgs)
64 | all_poses.append(poses)
65 |
66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
67 |
68 | imgs = np.concatenate(all_imgs, 0)
69 | poses = np.concatenate(all_poses, 0)
70 |
71 | H, W = imgs[0].shape[:2]
72 | camera_angle_x = float(meta['camera_angle_x'])
73 | focal = .5 * W / np.tan(.5 * camera_angle_x)
74 |
75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
76 |
77 | if half_res:
78 | H = H//2
79 | W = W//2
80 | focal = focal/2.
81 |
82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
83 | for i, img in enumerate(imgs):
84 | imgs_half_res[i] = cv2.resize(img, (H, W), interpolation=cv2.INTER_AREA)
85 | imgs = imgs_half_res
86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
87 |
88 |
89 | return imgs, poses, render_poses, [H, W, focal], i_split
90 |
91 |
92 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/load_deepvoxels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
7 |
8 |
9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
10 | # Get camera intrinsics
11 | with open(filepath, 'r') as file:
12 | f, cx, cy = list(map(float, file.readline().split()))[:3]
13 | grid_barycenter = np.array(list(map(float, file.readline().split())))
14 | near_plane = float(file.readline())
15 | scale = float(file.readline())
16 | height, width = map(float, file.readline().split())
17 |
18 | try:
19 | world2cam_poses = int(file.readline())
20 | except ValueError:
21 | world2cam_poses = None
22 |
23 | if world2cam_poses is None:
24 | world2cam_poses = False
25 |
26 | world2cam_poses = bool(world2cam_poses)
27 |
28 | print(cx,cy,f,height,width)
29 |
30 | cx = cx / width * trgt_sidelength
31 | cy = cy / height * trgt_sidelength
32 | f = trgt_sidelength / height * f
33 |
34 | fx = f
35 | if invert_y:
36 | fy = -f
37 | else:
38 | fy = f
39 |
40 | # Build the intrinsic matrices
41 | full_intrinsic = np.array([[fx, 0., cx, 0.],
42 | [0., fy, cy, 0],
43 | [0., 0, 1, 0],
44 | [0, 0, 0, 1]])
45 |
46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
47 |
48 |
49 | def load_pose(filename):
50 | assert os.path.isfile(filename)
51 | nums = open(filename).read().split()
52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
53 |
54 |
55 | H = 512
56 | W = 512
57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
58 |
59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
61 | focal = full_intrinsic[0,0]
62 | print(H, W, focal)
63 |
64 |
65 | def dir2poses(posedir):
66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
67 | transf = np.array([
68 | [1,0,0,0],
69 | [0,-1,0,0],
70 | [0,0,-1,0],
71 | [0,0,0,1.],
72 | ])
73 | poses = poses @ transf
74 | poses = poses[:,:3,:4].astype(np.float32)
75 | return poses
76 |
77 | posedir = os.path.join(deepvoxels_base, 'pose')
78 | poses = dir2poses(posedir)
79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
80 | testposes = testposes[::testskip]
81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
82 | valposes = valposes[::testskip]
83 |
84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
86 |
87 |
88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene)
89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
91 |
92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
95 |
96 | all_imgs = [imgs, valimgs, testimgs]
97 | counts = [0] + [x.shape[0] for x in all_imgs]
98 | counts = np.cumsum(counts)
99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
100 |
101 | imgs = np.concatenate(all_imgs, 0)
102 | poses = np.concatenate([poses, valposes, testposes], 0)
103 |
104 | render_poses = testposes
105 |
106 | print(poses.shape, imgs.shape)
107 |
108 | return imgs, poses, render_poses, [H,W,focal], i_split
109 |
110 |
111 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.2.1
3 | imageio
4 | imageio-ffmpeg
5 | matplotlib
6 | configargparse
7 | tensorboard==1.14.0
8 | tqdm
9 | opencv-python
10 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Object files
5 | *.o
6 | *.ko
7 | *.obj
8 | *.elf
9 |
10 | # Linker output
11 | *.ilk
12 | *.map
13 | *.exp
14 |
15 | # Precompiled Headers
16 | *.gch
17 | *.pch
18 |
19 | # Libraries
20 | *.lib
21 | *.a
22 | *.la
23 | *.lo
24 |
25 | # Shared objects (inc. Windows DLLs)
26 | *.dll
27 | *.so
28 | *.so.*
29 | *.dylib
30 |
31 | # Executables
32 | *.exe
33 | *.out
34 | *.app
35 | *.i*86
36 | *.x86_64
37 | *.hex
38 |
39 | # Debug files
40 | *.dSYM/
41 | *.su
42 | *.idb
43 | *.pdb
44 |
45 | # Kernel Module Compile Results
46 | *.mod*
47 | *.cmd
48 | .tmp_versions/
49 | modules.order
50 | Module.symvers
51 | Mkfile.old
52 | dkms.conf
53 |
54 |
55 | # Byte-compiled / optimized / DLL files
56 | __pycache__/
57 | *.py[cod]
58 | *$py.class
59 |
60 | # C extensions
61 | *.so
62 |
63 | # Distribution / packaging
64 | .Python
65 | build/
66 | develop-eggs/
67 | dist/
68 | downloads/
69 | eggs/
70 | .eggs/
71 | lib/
72 | lib64/
73 | parts/
74 | sdist/
75 | var/
76 | wheels/
77 | *.egg-info/
78 | .installed.cfg
79 | *.egg
80 | MANIFEST
81 |
82 | # PyInstaller
83 | # Usually these files are written by a python script from a template
84 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
85 | *.manifest
86 | *.spec
87 |
88 | # Installer logs
89 | pip-log.txt
90 | pip-delete-this-directory.txt
91 |
92 | # Unit test / coverage reports
93 | htmlcov/
94 | .tox/
95 | .coverage
96 | .coverage.*
97 | .cache
98 | nosetests.xml
99 | coverage.xml
100 | *.cover
101 | .hypothesis/
102 | .pytest_cache/
103 |
104 | # Translations
105 | *.mo
106 | *.pot
107 |
108 | # Django stuff:
109 | *.log
110 | local_settings.py
111 | db.sqlite3
112 |
113 | # Flask stuff:
114 | instance/
115 | .webassets-cache
116 |
117 | # Scrapy stuff:
118 | .scrapy
119 |
120 | # Sphinx documentation
121 | docs/_build/
122 |
123 | # PyBuilder
124 | target/
125 |
126 | # Jupyter Notebook
127 | .ipynb_checkpoints
128 |
129 | # pyenv
130 | .python-version
131 |
132 | # celery beat schedule file
133 | celerybeat-schedule
134 |
135 | # SageMath parsed files
136 | *.sage.py
137 |
138 | # Environments
139 | .env
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Inria (Antoine Liutkus)
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch Custom CUDA kernel for searchsorted
2 |
3 | This repository is an implementation of the searchsorted function for pytorch CUDA tensors. It was initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but has been completely rewritten since then because building C extensions is no longer supported as of pytorch 1.0.
4 |
5 |
6 | > Warnings:
7 | > * only works with pytorch > v1.3 and CUDA >= v10.1
8 | > * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications.
9 |
10 | ## Description
11 |
12 | Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices.
13 | * `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this).
14 | * `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this).
15 | * `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this).
16 | * `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help is welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7).
17 |
18 | The output is of shape `(nrows, ncols_v)`. If all input tensors are on the GPU, the CUDA version is called; otherwise, the CPU version is used.
19 |
20 |
21 | **Disclaimers**
22 |
23 | * This function has not been heavily tested. Use at your own risk.
24 | * When `a` is not sorted, the results differ from numpy's version, but I decided not to worry about this because the function should not be called in that case.
25 | * In some cases the results differ from numpy's version. However, as far as I could see, this only happens when values are exactly equal, in which case the order of the returned indices does not matter, so I decided not to worry about this either.
26 | * Tensors have to be contiguous for torchsearchsorted to give consistent results. Use `.contiguous()` on all tensor arguments before calling.
27 |
28 |
29 | ## Installation
30 |
31 | Just run `pip install .` in the root folder of this repo. This will compile
32 | and install the torchsearchsorted module.
33 |
34 | Be careful: sometimes `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is).
35 |
36 | For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8.
37 | So I had to do:
38 |
39 | > sudo apt-get install g++-8 gcc-8
40 | > sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc
41 | > sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++
42 |
43 | Note that pytorch needs to be installed on your system. The code was tested with pytorch v1.3.
44 |
45 | ## Usage
46 |
47 | Just import the torchsearchsorted package after installation. I typically do:
48 |
49 | ```
50 | from torchsearchsorted import searchsorted
51 | ```
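
A minimal call could then look like the following sketch (the shapes below are made up for illustration; remember the contiguity note above):

```
import torch
from torchsearchsorted import searchsorted

a = torch.sort(torch.rand(2, 300), dim=1)[0].contiguous()  # 2 rows of sorted values
v = torch.rand(2, 100).contiguous()                         # values to locate in each row
out = searchsorted(a, v, side='left')                       # LongTensor of shape (2, 100)
```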
52 |
53 |
54 | ## Testing
55 |
56 | Under the `examples` subfolder, you may:
57 |
58 | 1. Try `python test.py` with `torch` available.
59 |
60 | ```
61 | Looking for 50000x1000 values in 50000x300 entries
62 | NUMPY: searchsorted in 4851.592ms
63 | CPU: searchsorted in 4805.432ms
64 | difference between CPU and NUMPY: 0.000
65 | GPU: searchsorted in 1.055ms
66 | difference between GPU and NUMPY: 0.000
67 |
68 | Looking for 50000x1000 values in 50000x300 entries
69 | NUMPY: searchsorted in 4333.964ms
70 | CPU: searchsorted in 4753.958ms
71 | difference between CPU and NUMPY: 0.000
72 | GPU: searchsorted in 0.391ms
73 | difference between GPU and NUMPY: 0.000
74 | ```
75 | The first run includes the allocation time, while the second one does not.
76 |
77 | 2. You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), which benchmarks `searchsorted` over many runs:
78 |
79 | ```
80 | Benchmark searchsorted:
81 | - a [5000 x 300]
82 | - v [5000 x 100]
83 | - reporting fastest time of 20 runs
84 | - each run executes searchsorted 100 times
85 |
86 | Numpy: 4.6302046799100935
87 | CPU: 5.041533078998327
88 | CUDA: 0.0007955809123814106
89 | ```
90 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/examples/benchmark.py:
--------------------------------------------------------------------------------
1 | import timeit
2 |
3 | import torch
4 | import numpy as np
5 | from torchsearchsorted import searchsorted, numpy_searchsorted
6 |
7 | B = 5_000
8 | A = 300
9 | V = 100
10 |
11 | repeats = 20
12 | number = 100
13 |
14 | print(
15 | f'Benchmark searchsorted:',
16 | f'- a [{B} x {A}]',
17 | f'- v [{B} x {V}]',
18 | f'- reporting fastest time of {repeats} runs',
19 | f'- each run executes searchsorted {number} times',
20 | sep='\n',
21 | end='\n\n'
22 | )
23 |
24 |
25 | def get_arrays():
26 | a = np.sort(np.random.randn(B, A), axis=1)
27 | v = np.random.randn(B, V)
28 | out = np.empty_like(v, dtype=np.long)
29 | return a, v, out
30 |
31 |
32 | def get_tensors(device):
33 | a = torch.sort(torch.randn(B, A, device=device), dim=1)[0]
34 | v = torch.randn(B, V, device=device)
35 | out = torch.empty(B, V, device=device, dtype=torch.long)
36 | if torch.cuda.is_available():
37 | torch.cuda.synchronize()
38 | return a, v, out
39 |
40 | def searchsorted_synchronized(a,v,out=None,side='left'):
41 | out = searchsorted(a,v,out,side)
42 | torch.cuda.synchronize()
43 | return out
44 |
45 | numpy = timeit.repeat(
46 | stmt="numpy_searchsorted(a, v, side='left')",
47 | setup="a, v, out = get_arrays()",
48 | globals=globals(),
49 | repeat=repeats,
50 | number=number
51 | )
52 | print('Numpy: ', min(numpy), sep='\t')
53 |
54 | cpu = timeit.repeat(
55 | stmt="searchsorted(a, v, out, side='left')",
56 | setup="a, v, out = get_tensors(device='cpu')",
57 | globals=globals(),
58 | repeat=repeats,
59 | number=number
60 | )
61 | print('CPU: ', min(cpu), sep='\t')
62 |
63 | if torch.cuda.is_available():
64 | gpu = timeit.repeat(
65 | stmt="searchsorted_synchronized(a, v, out, side='left')",
66 | setup="a, v, out = get_tensors(device='cuda')",
67 | globals=globals(),
68 | repeat=repeats,
69 | number=number
70 | )
71 | print('CUDA: ', min(gpu), sep='\t')
72 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/examples/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchsearchsorted import searchsorted, numpy_searchsorted
3 | import time
4 |
5 | if __name__ == '__main__':
6 | # defining the number of tests
7 | ntests = 2
8 |
9 | # defining the problem dimensions
10 | nrows_a = 50000
11 | nrows_v = 50000
12 | nsorted_values = 300
13 | nvalues = 1000
14 |
15 | # defines the variables. The first run will comprise allocation, the
16 | # further ones will not
17 | test_GPU = None
18 | test_CPU = None
19 |
20 | for ntest in range(ntests):
21 | print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues,
22 | nrows_a,
23 | nsorted_values))
24 |
25 | side = 'right'
26 | # generate a matrix with sorted rows
27 | a = torch.randn(nrows_a, nsorted_values, device='cpu')
28 | a = torch.sort(a, dim=1)[0]
29 | # generate a matrix of values to searchsort
30 | v = torch.randn(nrows_v, nvalues, device='cpu')
31 |
32 | # a = torch.tensor([[0., 1.]])
33 | # v = torch.tensor([[1.]])
34 |
35 | t0 = time.time()
36 | test_NP = torch.tensor(numpy_searchsorted(a, v, side))
37 | print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
38 | t0 = time.time()
39 | test_CPU = searchsorted(a, v, test_CPU, side)
40 | print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
41 | # compute the difference between both
42 | error_CPU = torch.norm(test_NP.double()
43 | - test_CPU.double()).numpy()
44 | if error_CPU:
45 | import ipdb; ipdb.set_trace()
46 | print(' difference between CPU and NUMPY: %0.3f' % error_CPU)
47 |
48 | if not torch.cuda.is_available():
49 | print('CUDA is not available on this machine, cannot go further.')
50 | continue
51 | else:
52 | # now do the GPU version
53 | a = a.to('cuda')
54 | v = v.to('cuda')
55 | torch.cuda.synchronize()
56 | # launch searchsorted on those
57 | t0 = time.time()
58 | test_GPU = searchsorted(a, v, test_GPU, side)
59 | torch.cuda.synchronize()
60 | print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0)))
61 |
62 | # compute the difference between both
63 | error_CUDA = torch.norm(test_NP.to('cuda').double()
64 | - test_GPU.double()).cpu().numpy()
65 |
66 | print(' difference between GPU and NUMPY: %0.3f' % error_CUDA)
67 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME
3 | from torch.utils.cpp_extension import CppExtension, CUDAExtension
4 |
5 | # In any case, include the CPU version
6 | modules = [
7 | CppExtension('torchsearchsorted.cpu',
8 | ['src/cpu/searchsorted_cpu_wrapper.cpp']),
9 | ]
10 |
11 | # If nvcc is available, add the CUDA extension
12 | if CUDA_HOME:
13 | modules.append(
14 | CUDAExtension('torchsearchsorted.cuda',
15 | ['src/cuda/searchsorted_cuda_wrapper.cpp',
16 | 'src/cuda/searchsorted_cuda_kernel.cu'])
17 | )
18 |
19 | tests_require = [
20 | 'pytest',
21 | ]
22 |
23 | # Now proceed to setup
24 | setup(
25 | name='torchsearchsorted',
26 | version='1.1',
27 | description='A searchsorted implementation for pytorch',
28 | keywords='searchsorted',
29 | author='Antoine Liutkus',
30 | author_email='antoine.liutkus@inria.fr',
31 | packages=find_packages(where='src'),
32 | package_dir={"": "src"},
33 | ext_modules=modules,
34 | tests_require=tests_require,
35 | extras_require={
36 | 'test': tests_require,
37 | },
38 | cmdclass={
39 | 'build_ext': BuildExtension
40 | }
41 | )
42 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.cpp:
--------------------------------------------------------------------------------
1 | #include "searchsorted_cpu_wrapper.h"
2 | #include <stdio.h>
3 |
4 | template<typename scalar_t>
5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
6 | {
7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/
8 |
9 | if (col == ncol - 1)
10 | {
11 | // special case: we are on the right border
12 | if (a[row * ncol + col] <= val){
13 | return 1;}
14 | else {
15 | return -1;}
16 | }
17 | bool is_lower;
18 | bool is_next_higher;
19 |
20 | if (side_left) {
21 | // a[row, col] < v <= a[row, col+1]
22 | is_lower = (a[row * ncol + col] < val);
23 | is_next_higher = (a[row*ncol + col + 1] >= val);
24 | } else {
25 | // a[row, col] <= v < a[row, col+1]
26 | is_lower = (a[row * ncol + col] <= val);
27 | is_next_higher = (a[row * ncol + col + 1] > val);
28 | }
29 | if (is_lower && is_next_higher) {
30 | // we found the right spot
31 | return 0;
32 | } else if (is_lower) {
33 | // answer is on the right side
34 | return 1;
35 | } else {
36 | // answer is on the left side
37 | return -1;
38 | }
39 | }
40 |
41 | template<typename scalar_t>
42 | int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
43 | {
44 | /* Look for the value `val` within row `row` of matrix `a`, which
45 | has `ncol` columns.
46 |
47 | the `a` matrix is assumed sorted in increasing order, row-wise
48 |
49 | returns:
50 | * -1 if `val` is smaller than the smallest value found within that row of `a`
51 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
52 | * Otherwise, return the column index `res` such that:
53 | - a[row, col] < val <= a[row, col+1] (if side_left), or
54 | - a[row, col] <= val < a[row, col+1] (if not side_left).
55 | */
56 |
57 | //start with left at 0 and right at number of columns of a
58 | int64_t right = ncol;
59 | int64_t left = 0;
60 |
61 | while (right >= left) {
62 | // take the midpoint of current left and right cursors
63 | int64_t mid = left + (right-left)/2;
64 |
65 | // check the relative position of val: are we good here ?
66 | int rel_pos = eval(val, a, row, mid, ncol, side_left);
67 | // we found the point
68 | if(rel_pos == 0) {
69 | return mid;
70 | } else if (rel_pos > 0) {
71 | if (mid==ncol-1){return ncol-1;}
72 | // the answer is on the right side
73 | left = mid;
74 | } else {
75 | if (mid==0){return -1;}
76 | right = mid;
77 | }
78 | }
79 | return -1;
80 | }
81 |
82 | void searchsorted_cpu_wrapper(
83 | at::Tensor a,
84 | at::Tensor v,
85 | at::Tensor res,
86 | bool side_left)
87 | {
88 |
89 | // Get the dimensions
90 | auto nrow_a = a.size(/*dim=*/0);
91 | auto ncol_a = a.size(/*dim=*/1);
92 | auto nrow_v = v.size(/*dim=*/0);
93 | auto ncol_v = v.size(/*dim=*/1);
94 |
95 | auto nrow_res = fmax(nrow_a, nrow_v);
96 |
97 | //auto acc_v = v.accessor();
98 | //auto acc_res = res.accessor();
99 |
100 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] {
101 |
102 | scalar_t* a_data = a.data_ptr<scalar_t>();
103 | scalar_t* v_data = v.data_ptr<scalar_t>();
104 | int64_t* res_data = res.data<int64_t>();
105 |
106 | for (int64_t row = 0; row < nrow_res; row++)
107 | {
108 | for (int64_t col = 0; col < ncol_v; col++)
109 | {
110 | // get the value to look for
111 | int64_t row_in_v = (nrow_v == 1) ? 0 : row;
112 | int64_t row_in_a = (nrow_a == 1) ? 0 : row;
113 |
114 | int64_t idx_in_v = row_in_v * ncol_v + col;
115 | int64_t idx_in_res = row * ncol_v + col;
116 |
117 | // apply binary search
118 | res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1);
119 | }
120 | }
121 | });
122 | }
123 |
124 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
125 | m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)");
126 | }
127 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.h:
--------------------------------------------------------------------------------
1 | #ifndef _SEARCHSORTED_CPU
2 | #define _SEARCHSORTED_CPU
3 |
4 | #include <torch/extension.h>
5 |
6 | void searchsorted_cpu_wrapper(
7 | at::Tensor a,
8 | at::Tensor v,
9 | at::Tensor res,
10 | bool side_left);
11 |
12 | #endif
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "searchsorted_cuda_kernel.h"
2 |
3 | template<typename scalar_t>
4 | __device__
5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left)
6 | {
7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/
8 |
9 | if (col == ncol - 1)
10 | {
11 | // special case: we are on the right border
12 | if (a[row * ncol + col] <= val){
13 | return 1;}
14 | else {
15 | return -1;}
16 | }
17 | bool is_lower;
18 | bool is_next_higher;
19 |
20 | if (side_left) {
21 | // a[row, col] < v <= a[row, col+1]
22 | is_lower = (a[row * ncol + col] < val);
23 | is_next_higher = (a[row*ncol + col + 1] >= val);
24 | } else {
25 | // a[row, col] <= v < a[row, col+1]
26 | is_lower = (a[row * ncol + col] <= val);
27 | is_next_higher = (a[row * ncol + col + 1] > val);
28 | }
29 | if (is_lower && is_next_higher) {
30 | // we found the right spot
31 | return 0;
32 | } else if (is_lower) {
33 | // answer is on the right side
34 | return 1;
35 | } else {
36 | // answer is on the left side
37 | return -1;
38 | }
39 | }
40 |
41 | template<typename scalar_t>
42 | __device__
43 | int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left)
44 | {
45 | /* Look for the value `val` within row `row` of matrix `a`, which
46 | has `ncol` columns.
47 |
48 | the `a` matrix is assumed sorted in increasing order, row-wise
49 |
50 | Returns
51 | * -1 if `val` is smaller than the smallest value found within that row of `a`
52 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a`
53 | * Otherwise, return the column index `res` such that:
54 | - a[row, col] < val <= a[row, col+1] (if side_left), or
55 | - a[row, col] <= val < a[row, col+1] (if not side_left).
56 | */
57 |
58 | //start with left at 0 and right at number of columns of a
59 | int64_t right = ncol;
60 | int64_t left = 0;
61 |
62 | while (right >= left) {
63 | // take the midpoint of current left and right cursors
64 | int64_t mid = left + (right-left)/2;
65 |
66 | // check the relative position of val: are we good here ?
67 | int rel_pos = eval(val, a, row, mid, ncol, side_left);
68 | // we found the point
69 | if(rel_pos == 0) {
70 | return mid;
71 | } else if (rel_pos > 0) {
72 | if (mid==ncol-1){return ncol-1;}
73 | // the answer is on the right side
74 | left = mid;
75 | } else {
76 | if (mid==0){return -1;}
77 | right = mid;
78 | }
79 | }
80 | return -1;
81 | }
82 |
83 | template<typename scalar_t>
84 | __global__
85 | void searchsorted_kernel(
86 | int64_t *res,
87 | scalar_t *a,
88 | scalar_t *v,
89 | int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left)
90 | {
91 | // get current row and column
92 | int64_t row = blockIdx.y*blockDim.y+threadIdx.y;
93 | int64_t col = blockIdx.x*blockDim.x+threadIdx.x;
94 |
95 | // check whether we are outside the bounds of what needs be computed.
96 | if ((row >= nrow_res) || (col >= ncol_v)) {
97 | return;}
98 |
99 | // get the value to look for
100 | int64_t row_in_v = (nrow_v==1) ? 0: row;
101 | int64_t row_in_a = (nrow_a==1) ? 0: row;
102 | int64_t idx_in_v = row_in_v*ncol_v+col;
103 | int64_t idx_in_res = row*ncol_v+col;
104 |
105 | // apply binary search
106 | res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1;
107 | }
108 |
109 |
110 | void searchsorted_cuda(
111 | at::Tensor a,
112 | at::Tensor v,
113 | at::Tensor res,
114 | bool side_left){
115 |
116 | // Get the dimensions
117 | auto nrow_a = a.size(/*dim=*/0);
118 | auto nrow_v = v.size(/*dim=*/0);
119 | auto ncol_a = a.size(/*dim=*/1);
120 | auto ncol_v = v.size(/*dim=*/1);
121 |
122 | auto nrow_res = fmax(double(nrow_a), double(nrow_v));
123 |
124 | // prepare the kernel configuration
125 | dim3 threads(ncol_v, nrow_res);
126 | dim3 blocks(1, 1);
127 | if (nrow_res*ncol_v > 1024){
128 | threads.x = int(fmin(double(1024), double(ncol_v)));
129 | threads.y = floor(1024/threads.x);
130 | blocks.x = ceil(double(ncol_v)/double(threads.x));
131 | blocks.y = ceil(double(nrow_res)/double(threads.y));
132 | }
133 |
134 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] {
135 | searchsorted_kernel<scalar_t><<<blocks, threads>>>(
136 | res.data<int64_t>(),
137 | a.data<scalar_t>(),
138 | v.data<scalar_t>(),
139 | nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left);
140 | }));
141 |
142 | }
143 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _SEARCHSORTED_CUDA_KERNEL
2 | #define _SEARCHSORTED_CUDA_KERNEL
3 |
4 | #include <ATen/ATen.h>
5 |
6 | void searchsorted_cuda(
7 | at::Tensor a,
8 | at::Tensor v,
9 | at::Tensor res,
10 | bool side_left);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.cpp:
--------------------------------------------------------------------------------
1 | #include "searchsorted_cuda_wrapper.h"
2 |
3 | // C++ interface
4 |
5 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
6 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
8 |
9 | void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left)
10 | {
11 | CHECK_INPUT(a);
12 | CHECK_INPUT(v);
13 | CHECK_INPUT(res);
14 |
15 | searchsorted_cuda(a, v, res, side_left);
16 | }
17 |
18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
19 | m.def("searchsorted_cuda_wrapper", &searchsorted_cuda_wrapper, "searchsorted (CUDA)");
20 | }
21 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.h:
--------------------------------------------------------------------------------
1 | #ifndef _SEARCHSORTED_CUDA_WRAPPER
2 | #define _SEARCHSORTED_CUDA_WRAPPER
3 |
4 | #include <torch/extension.h>
5 | #include "searchsorted_cuda_kernel.h"
6 |
7 | void searchsorted_cuda_wrapper(
8 | at::Tensor a,
9 | at::Tensor v,
10 | at::Tensor res,
11 | bool side_left);
12 |
13 | #endif
14 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/__init__.py:
--------------------------------------------------------------------------------
1 | from .searchsorted import searchsorted
2 | from .utils import numpy_searchsorted
3 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/searchsorted.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | # trying to import the CPU searchsorted
6 | SEARCHSORTED_CPU_AVAILABLE = True
7 | try:
8 | from torchsearchsorted.cpu import searchsorted_cpu_wrapper
9 | except ImportError:
10 | SEARCHSORTED_CPU_AVAILABLE = False
11 |
12 | # trying to import the CUDA searchsorted
13 | SEARCHSORTED_GPU_AVAILABLE = True
14 | try:
15 | from torchsearchsorted.cuda import searchsorted_cuda_wrapper
16 | except ImportError:
17 | SEARCHSORTED_GPU_AVAILABLE = False
18 |
19 |
20 | def searchsorted(a: torch.Tensor, v: torch.Tensor,
21 | out: Optional[torch.LongTensor] = None,
22 | side='left') -> torch.LongTensor:
23 | assert len(a.shape) == 2, "input `a` must be 2-D."
24 | assert len(v.shape) == 2, "input `v` must be 2-D."
25 | assert (a.shape[0] == v.shape[0]
26 | or a.shape[0] == 1
27 | or v.shape[0] == 1), ("`a` and `v` must have the same number of "
28 | "rows or one of them must have only one ")
29 | assert a.device == v.device, '`a` and `v` must be on the same device'
30 |
31 | result_shape = (max(a.shape[0], v.shape[0]), v.shape[1])
32 | if out is not None:
33 | assert out.device == a.device, "`out` must be on the same device as `a`"
34 | assert out.dtype == torch.long, "out.dtype must be torch.long"
35 | assert out.shape == result_shape, ("If the output tensor is provided, "
36 | "its shape must be correct.")
37 | else:
38 | out = torch.empty(result_shape, device=v.device, dtype=torch.long)
39 |
40 | if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE:
41 | raise Exception('torchsearchsorted was asked to run on a CUDA tensor, but the CUDA '
42 | 'extension does not seem to be available. Please install it.')
43 | if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE:
44 | raise Exception('torchsearchsorted on CPU is not available. '
45 | 'Please install it.')
46 |
47 | left_side = 1 if side=='left' else 0
48 | if a.is_cuda:
49 | searchsorted_cuda_wrapper(a, v, out, left_side)
50 | else:
51 | searchsorted_cpu_wrapper(a, v, out, left_side)
52 |
53 | return out
54 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/src/torchsearchsorted/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'):
5 | """Numpy version of searchsorted that works batch-wise on pytorch tensors
6 | """
7 | nrows_a = a.shape[0]
8 | (nrows_v, ncols_v) = v.shape
9 | nrows_out = max(nrows_a, nrows_v)
10 | out = np.empty((nrows_out, ncols_v), dtype=np.long)
11 | def sel(data, row):
12 | return data[0] if data.shape[0] == 1 else data[row]
13 | for row in range(nrows_out):
14 | out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side)
15 | return out
16 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/test/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | devices = {'cpu': torch.device('cpu')}
5 | if torch.cuda.is_available():
6 | devices['cuda'] = torch.device('cuda:0')
7 |
8 |
9 | @pytest.fixture(params=devices.values(), ids=devices.keys())
10 | def device(request):
11 | return request.param
12 |
--------------------------------------------------------------------------------
/submodules/nerf_pytorch/torchsearchsorted/test/test_searchsorted.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import numpy as np
5 | from torchsearchsorted import searchsorted, numpy_searchsorted
6 | from itertools import product, repeat
7 |
8 |
9 | def test_searchsorted_output_dtype(device):
10 | B = 100
11 | A = 50
12 | V = 12
13 |
14 | a = torch.sort(torch.rand(B, V, device=device), dim=1)[0]
15 | v = torch.rand(B, A, device=device)
16 |
17 | out = searchsorted(a, v)
18 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy())
19 | assert out.dtype == torch.long
20 | np.testing.assert_array_equal(out.cpu().numpy(), out_np)
21 |
22 | out = torch.empty(v.shape, dtype=torch.long, device=device)
23 | searchsorted(a, v, out)
24 | assert out.dtype == torch.long
25 | np.testing.assert_array_equal(out.cpu().numpy(), out_np)
26 |
27 | Ba_val = [1, 100, 200]
28 | Bv_val = [1, 100, 200]
29 | A_val = [1, 50, 500]
30 | V_val = [1, 12, 120]
31 | side_val = ['left', 'right']
32 | nrepeat = 100
33 |
34 | @pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val))
35 | def test_searchsorted_correct(Ba, Bv, A, V, side, device):
36 | if Ba > 1 and Bv > 1 and Ba != Bv:
37 | return
38 | for test in range(nrepeat):
39 | a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0]
40 | v = torch.rand(Bv, V, device=device)
41 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(),
42 | side=side)
43 | out = searchsorted(a, v, side=side).cpu().numpy()
44 | np.testing.assert_array_equal(out, out_np)
45 |
--------------------------------------------------------------------------------