├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-RESEARCH.md
├── CONTRIBUTING.md
├── LICENSE-APPLE-SAMPLE-CODE
├── README-DATA.md
├── README.md
├── config.json
├── evaluate_psnr.py
├── experiments.py
├── exploration.ipynb
├── imgs
│   ├── all_datasets.gif
│   ├── coordinate-system.png
│   └── example-data
│       ├── chair1.png
│       ├── chair2.png
│       ├── chair3.png
│       └── chair4.png
├── misc
│   ├── __init__.py
│   ├── dataloaders.py
│   ├── quantitative_evaluation.py
│   ├── utils.py
│   └── viz.py
├── models
│   ├── __init__.py
│   ├── layers.py
│   ├── neural_renderer.py
│   ├── rotation_layers.py
│   └── submodels.py
├── requirements.txt
├── trained-models
│   └── chairs.pt
├── training
│   ├── __init__.py
│   └── training.py
└── transforms3d
    ├── __init__.py
    ├── conversions.py
    └── rotations.py
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING-RESEARCH.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | ## Before you get started
6 |
7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE-APPLE-SAMPLE-CODE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2020 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/README-DATA.md:
--------------------------------------------------------------------------------
1 |
2 | README - ENR Data
3 |
4 | The datasets accompanying the ICML 2020 paper Equivariant Neural Rendering are licensed as follows:
5 |
6 |
7 | 1) The cars, chairs, and mugs-hq images are rendered from ShapeNet models (http://www.shapenet.org/); they are licensed under the ShapeNet license - https://www.shapenet.org/terms
8 |
9 | 2) The mountains images are licensed under the CC-BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/legalcode).
10 | Satellite imagery © 2020 Maxar Technologies
11 |
12 |
13 | Additionally, we thank Bernhard Vogl (salzamt@dativ.at) for the environment map used to render the mugs-hq images. This environment map and others can be found at http://dativ.at/lightprobes.
14 |
15 | If you find this code useful in your research, consider citing with:
16 |
17 |
18 | @article{dupont2020equivariant,
19 | title={Equivariant Neural Rendering},
20 |   author={Dupont, Emilien and Bautista, Miguel Angel and Colburn, Alex and Sankar, Aditya and Guestrin, Carlos and Susskind, Josh and Shan, Qi},
21 | journal={arXiv preprint arXiv:2006.07630},
22 | year={2020}
23 | }
24 |
25 |
26 |
27 | ShapeNet Terms of Use
28 |
29 | After registering for a ShapeNet account, you will be considered for account approval and privilege elevation by an administrator. After approval, you (the "Researcher") receive permission to use the *ShapeNet database* (the "Database") *at Princeton University and Stanford University*. In exchange for being able to join the ShapeNet community and receive such permission, Researcher hereby agrees to the following terms and conditions:
30 |
31 | 1. Researcher shall use the Database only for non-commercial research and educational purposes.
32 | 2. Princeton University and Stanford University make no representations or warranties regarding the Database, including but not limited to warranties of non-infringement or fitness for a particular purpose.
33 | 3. Researcher accepts full responsibility for his or her use of the Database and shall defend and indemnify Princeton University and Stanford University, including their employees, Trustees, officers and agents, against any and all claims arising from Researcher's use of the Database, including but not limited to Researcher's use of any copies of copyrighted 3D models that he or she may create from the Database.
34 | 4. Researcher may provide research associates and colleagues with access to the Database provided that they first agree to be bound by these terms and conditions.
35 | 5. Princeton University and Stanford University reserve the right to terminate Researcher's access to the Database at any time.
36 | 6. If Researcher is employed by a for-profit, commercial entity, Researcher's employer shall also be bound by these terms and conditions, and Researcher hereby represents that he or she is fully authorized to enter into this agreement on behalf of such employer.
37 | 7. The law of the State of New Jersey shall apply to all disputes under this agreement.
38 |
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Equivariant Neural Rendering
2 |
3 | This repo contains code to reproduce all experiments in [Equivariant Neural Rendering](https://arxiv.org/abs/2006.07630) by [E. Dupont](https://emiliendupont.github.io), [M. A. Bautista](https://scholar.google.com/citations?user=ZrRs-qoAAAAJ&hl=en), [A. Colburn](https://www.colburn.org), [A. Sankar](https://scholar.google.com/citations?user=6ZDIdEAAAAAJ&hl=en), [C. Guestrin](https://homes.cs.washington.edu/~guestrin/), [J. Susskind](https://scholar.google.com/citations?user=Sv2TGqsAAAAJ&hl=en), [Q. Shan](http://shanqi.github.io), ICML 2020.
4 |
5 |
6 |
7 |
8 |
9 | ### Pre-trained models
10 |
11 | The weights for the trained chairs model are provided in `trained-models/chairs.pt`.
12 |
13 | The other pre-trained models are available at https://icml20-prod.cdn-apple.com/eqn-data/models/pre-trained_models.zip. They should be downloaded and placed into the `trained-models` directory (the small `chairs.pt` model is already included in the repo).
14 |
15 | ## Examples
16 |
17 | ![All datasets](imgs/all_datasets.gif)
18 |
19 | ## Requirements
20 |
21 | The requirements can be installed directly from PyPI with `pip install -r requirements.txt`. Running the code requires Python 3.6 or higher.
22 |
23 | ## Datasets
24 |
25 | - ShapeNet chairs: https://icml20-prod.cdn-apple.com/eqn-data/data/chairs.zip
26 | - ShapeNet cars: https://icml20-prod.cdn-apple.com/eqn-data/data/cars.zip
27 | - MugsHQ: https://icml20-prod.cdn-apple.com/eqn-data/data/mugs.zip
28 | - 3D mountains: https://icml20-prod.cdn-apple.com/eqn-data/data/mountains.zip
29 |
30 |
31 | Each zip file expands into three separate components and a readme, e.g.:
32 | - `cars-train.zip`
33 | - `cars-val.zip`
34 | - `cars-test.zip`
35 | - `readme.txt` containing the license terms.
36 |
37 | A few example images are provided in `imgs/example-data/`.
38 |
39 | The chairs and car datasets were created with the help of [Vincent Sitzmann](https://vsitzmann.github.io).
40 |
41 | Satellite imagery © 2020 Maxar Technologies.
42 |
43 | We thank Bernhard Vogl (salzamt@dativ.at) for the lightmaps. The MugsHQ images were rendered using an environment map available at http://dativ.at/lightprobes.
44 |
45 | ## Usage
46 |
47 | ### Training a model
48 |
49 | To train a model, run the following:
50 |
51 | ```
52 | python experiments.py config.json
53 | ```
54 |
55 | This supports both single and multi-GPU training (see `config.json` for detailed training options). Note that you need to download the datasets before running this command.
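For reference, the sketch below mirrors what `experiments.py` does under the hood, with the values copied from the provided `config.json`; the experiment directory name and dataset path are placeholder choices for this illustration.

```
import os
import torch
from misc.dataloaders import scene_render_dataloader
from models.neural_renderer import NeuralRenderer
from training.training import Trainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the renderer from the architecture options in config.json
model = NeuralRenderer(
    img_shape=(3, 128, 128),
    channels_2d=[64, 64, 128, 128, 128, 128, 256, 256, 128, 128, 128],
    strides_2d=[1, 1, 2, 1, 2, 1, 2, 1, -2, 1, 1],
    channels_3d=[32, 32, 128, 128, 128, 64, 64, 64],
    strides_3d=[1, 1, 2, 1, 1, -2, 1, 1],
    num_channels_inv_projection=[256, 512, 1024],
    num_channels_projection=[512, 256, 256],
    mode='bilinear'
).to(device)

# Optionally wrap the model for multi-GPU training
# model = torch.nn.DataParallel(model)

trainer = Trainer(device, model, lr=2e-4, rendering_loss_type='l1',
                  ssim_loss_weight=0.05)

# The dataloader samples pairs of views of the same scene
dataloader = scene_render_dataloader(path_to_data='chairs-train',
                                     batch_size=16,
                                     img_size=(3, 128, 128),
                                     crop_size=128)

# Directory where losses, generated images and model weights are saved
save_dir = 'chairs-experiment'  # placeholder name for this sketch
os.makedirs(save_dir, exist_ok=True)
trainer.train(dataloader, 100, save_dir=save_dir, save_freq=1,
              test_dataloader=None)
```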
56 |
57 | ### Quantitative evaluation
58 |
59 | To evaluate a model, run the following:
60 |
61 | ```
62 | python evaluate_psnr.py <path-to-model> <path-to-test-data>
63 | ```
64 |
65 | This will measure the performance (in PSNR) of a trained model on a test dataset.
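The evaluation can also be run programmatically. The sketch below is essentially what `evaluate_psnr.py` does; the model and dataset paths are placeholders for wherever you stored the trained weights and the extracted test set.

```
import torch
from misc.dataloaders import scene_render_dataset
from misc.quantitative_evaluation import get_dataset_psnr
from models.neural_renderer import load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load trained weights (path is a placeholder)
model = load_model('trained-models/chairs.pt').to(device)

# Test sets may have an odd number of views per scene, so allow this
dataset = scene_render_dataset(path_to_data='chairs-test',
                               img_size=(3, 128, 128), crop_size=128,
                               allow_odd_num_imgs=True)

# PSNR is averaged over all target views of each scene, with view 64 as source
with torch.no_grad():
    psnrs = get_dataset_psnr(device, model, dataset, source_img_idx_shift=64,
                             batch_size=125, max_num_scenes=None)
print("Mean PSNR: {:.2f}".format(torch.mean(torch.Tensor(psnrs)).item()))
```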
66 |
67 | ### Model exploration and visualization
68 |
69 | The Jupyter notebook `exploration.ipynb` shows how to use a trained model to infer a scene representation from a single image and how to use this representation to render novel views.
70 |
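If you want a starting point outside the notebook, the following sketch shows the same workflow using the helpers in `misc/`: load a trained model, infer a scene from a single view, then render a full azimuth sweep and save it as a gif. The dataset path, view index and number of steps are arbitrary example values.

```
import torch
from misc.dataloaders import scene_render_dataset
from misc.utils import constant_angle_sequence, full_rotation_angle_sequence
from misc.viz import generate_novel_views, save_img_sequence_as_gif
from models.neural_renderer import load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_model('trained-models/chairs.pt').to(device)

# Pick a single source view and its pose (dataset path and index are examples)
dataset = scene_render_dataset(path_to_data='chairs-test',
                               img_size=(3, 128, 128), crop_size=128,
                               allow_odd_num_imgs=True)
data_item = dataset[0]
img_source = data_item["img"].to(device)
azimuth_source = torch.Tensor([data_item["render_params"]["azimuth"]]).to(device)
elevation_source = torch.Tensor([data_item["render_params"]["elevation"]]).to(device)

# Render a full 360 degree azimuth sweep at constant elevation
num_views = 50
azimuth_shifts = full_rotation_angle_sequence(num_views).to(device)
elevation_shifts = constant_angle_sequence(num_views).to(device)
novel_views = generate_novel_views(model, img_source, azimuth_source,
                                   elevation_source, azimuth_shifts,
                                   elevation_shifts)

# Save each rendered view as one frame of a gif
save_img_sequence_as_gif([view for view in novel_views.cpu()], 'rotation.gif')
```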
71 |
72 | ## Coordinate system
73 |
74 | The diagram below details the coordinate system we use for the voxel grid. Due to the manner in which images are stored in arrays and the way PyTorch's `affine_grid` and `grid_sample` functions work, this is a slightly unusual coordinate system. Note that `theta` and `phi` correspond to elevation and azimuth rotations of the **camera** around the scene representation. Note also that these are left handed rotations. Full details of the voxel rotation function can be found in `transforms3d/rotations.py`.
75 |
76 | ![Coordinate system](imgs/coordinate-system.png)
77 |
78 | ## Citing
79 |
80 | If you find this code useful in your research, consider citing with
81 |
82 | ```
83 | @article{dupont2020equivariant,
84 | title={Equivariant Neural Rendering},
85 |   author={Dupont, Emilien and Bautista, Miguel Angel and Colburn, Alex and Sankar, Aditya and Guestrin, Carlos and Susskind, Josh and Shan, Qi},
86 | journal={arXiv preprint arXiv:2006.07630},
87 | year={2020}
88 | }
89 | ```
90 |
91 | ## License
92 |
93 | This project is licensed under the [Apple Sample Code License](LICENSE-APPLE-SAMPLE-CODE).
94 |
95 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "chairs-experiment",
3 | "path_to_data": "chairs-train",
4 | "path_to_test_data": "chairs-val",
5 | "multi_gpu": false,
6 | "batch_size": 16,
7 | "lr": 2e-4,
8 | "epochs": 100,
9 | "loss_type": "l1",
10 | "ssim_loss_weight": 0.05,
11 | "save_freq": 1,
12 | "img_shape": [3, 128, 128],
13 | "channels_2d": [64, 64, 128, 128, 128, 128, 256, 256, 128, 128, 128],
14 | "strides_2d": [1, 1, 2, 1, 2, 1, 2, 1, -2, 1, 1],
15 | "channels_3d": [32, 32, 128, 128, 128, 64, 64, 64],
16 | "strides_3d": [1, 1, 2, 1, 1, -2, 1, 1],
17 | "num_channels_projection": [512, 256, 256],
18 | "num_channels_inv_projection": [256, 512, 1024],
19 | "mode": "bilinear"
20 | }
21 |
--------------------------------------------------------------------------------
/evaluate_psnr.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from misc.dataloaders import scene_render_dataset
4 | from misc.quantitative_evaluation import get_dataset_psnr
5 | from models.neural_renderer import load_model
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 |
9 | # Get model path and data directory from command line arguments
10 | if len(sys.argv) != 3:
11 | raise RuntimeError("Wrong arguments, use: python evaluate_psnr.py <model_path> <data_dir>")
12 | model_path = sys.argv[1]
13 | data_dir = sys.argv[2] # This is usually one of "chairs-test" and "cars-test"
14 |
15 | # Load model
16 | model = load_model(model_path)
17 | model = model.to(device)
18 |
19 | # Initialize dataset
20 | dataset = scene_render_dataset(path_to_data=data_dir, img_size=(3, 128, 128),
21 | crop_size=128, allow_odd_num_imgs=True)
22 |
23 | # Calculate PSNR
24 | with torch.no_grad():
25 | psnrs = get_dataset_psnr(device, model, dataset, source_img_idx_shift=64,
26 | batch_size=125, max_num_scenes=None)
27 |
--------------------------------------------------------------------------------
/experiments.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import time
5 | import torch
6 | from misc.dataloaders import scene_render_dataloader
7 | from models.neural_renderer import NeuralRenderer
8 | from training.training import Trainer
9 |
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 | # Get path to config file from command line arguments
13 | if len(sys.argv) != 2:
14 | raise RuntimeError("Wrong arguments, use: python experiments.py <path_to_config>")
15 | path_to_config = sys.argv[1]
16 |
17 | # Open config file
18 | with open(path_to_config) as file:
19 | config = json.load(file)
20 |
21 | # Set up directory to store experiments
22 | timestamp = time.strftime("%Y-%m-%d_%H-%M")
23 | directory = "{}_{}".format(timestamp, config["id"])
24 | if not os.path.exists(directory):
25 | os.makedirs(directory)
26 |
27 | # Save config file in directory
28 | with open(directory + '/config.json', 'w') as file:
29 | json.dump(config, file)
30 |
31 | # Set up renderer
32 | model = NeuralRenderer(
33 | img_shape=config["img_shape"],
34 | channels_2d=config["channels_2d"],
35 | strides_2d=config["strides_2d"],
36 | channels_3d=config["channels_3d"],
37 | strides_3d=config["strides_3d"],
38 | num_channels_inv_projection=config["num_channels_inv_projection"],
39 | num_channels_projection=config["num_channels_projection"],
40 | mode=config["mode"]
41 | )
42 |
43 | model.print_model_info()
44 |
45 | model = model.to(device)
46 |
47 | if config["multi_gpu"]:
48 | model = torch.nn.DataParallel(model)
49 |
50 | # Set up trainer for renderer
51 | trainer = Trainer(device, model, lr=config["lr"],
52 | rendering_loss_type=config["loss_type"],
53 | ssim_loss_weight=config["ssim_loss_weight"])
54 |
55 | dataloader = scene_render_dataloader(path_to_data=config["path_to_data"],
56 | batch_size=config["batch_size"],
57 | img_size=config["img_shape"],
58 | crop_size=128)
59 |
60 | # Optionally set up test_dataloader
61 | if config["path_to_test_data"]:
62 | test_dataloader = scene_render_dataloader(path_to_data=config["path_to_test_data"],
63 | batch_size=config["batch_size"],
64 | img_size=config["img_shape"],
65 | crop_size=128)
66 | else:
67 | test_dataloader = None
68 |
69 | print("PID: {}".format(os.getpid()))
70 |
71 | # Train renderer, save generated images, losses and model
72 | trainer.train(dataloader, config["epochs"], save_dir=directory,
73 | save_freq=config["save_freq"], test_dataloader=test_dataloader)
74 |
75 | # Print best losses
76 | print("Model id: {}".format(config["id"]))
77 | print("Best train loss: {:.4f}".format(min(trainer.epoch_loss_history["total"])))
78 | print("Best validation loss: {:.4f}".format(min(trainer.val_loss_history["total"])))
79 |
--------------------------------------------------------------------------------
/imgs/all_datasets.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/all_datasets.gif
--------------------------------------------------------------------------------
/imgs/coordinate-system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/coordinate-system.png
--------------------------------------------------------------------------------
/imgs/example-data/chair1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair1.png
--------------------------------------------------------------------------------
/imgs/example-data/chair2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair2.png
--------------------------------------------------------------------------------
/imgs/example-data/chair3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair3.png
--------------------------------------------------------------------------------
/imgs/example-data/chair4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair4.png
--------------------------------------------------------------------------------
/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/misc/__init__.py
--------------------------------------------------------------------------------
/misc/dataloaders.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import torch
4 | from numpy import float32 as np_float32
5 | from PIL import Image
6 | from torch.utils.data import Dataset, DataLoader, Sampler
7 | from torchvision import transforms
8 |
9 |
10 | def scene_render_dataloader(path_to_data='chairs-train', batch_size=16,
11 | img_size=(3, 128, 128), crop_size=128):
12 | """Dataloader for scene render datasets. Returns scene renders in pairs,
13 | i.e. 1st and 2nd images are of some scene, 3rd and 4th are of some different
14 | scene and so on.
15 |
16 | Args:
17 | path_to_data (string): Path to folder containing dataset.
18 | batch_size (int): Batch size for data.
19 | img_size (tuple of ints): Size of output images.
20 | crop_size (int): Size at which to center crop rendered images.
21 |
22 | Notes:
23 | Batch size must be even.
24 | """
25 | assert batch_size % 2 == 0, "Batch size is {} but must be even".format(batch_size)
26 |
27 | dataset = scene_render_dataset(path_to_data, img_size, crop_size)
28 |
29 | sampler = RandomPairSampler(dataset)
30 |
31 | return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
32 | drop_last=True)
33 |
34 |
35 | def scene_render_dataset(path_to_data='chairs-train', img_size=(3, 128, 128),
36 | crop_size=128, allow_odd_num_imgs=False):
37 | """Helper function for creating a scene render dataset.
38 |
39 | Args:
40 | path_to_data (string): Path to folder containing dataset.
41 | img_size (tuple of ints): Size of output images.
42 | crop_size (int): Size at which to center crop rendered images.
43 | allow_odd_num_imgs (bool): If True, allows datasets with an odd number
44 | of views. Such a dataset cannot be used for training, since each
45 | training iteration requires a *pair* of images. Datasets with an odd
46 | number of images are used for PSNR calculations.
47 | """
48 | img_transform = transforms.Compose([
49 | transforms.CenterCrop(crop_size),
50 | transforms.Resize(img_size[1:]),
51 | transforms.ToTensor()
52 | ])
53 |
54 | dataset = SceneRenderDataset(path_to_data=path_to_data,
55 | img_transform=img_transform,
56 | allow_odd_num_imgs=allow_odd_num_imgs)
57 |
58 | return dataset
59 |
60 |
61 | class SceneRenderDataset(Dataset):
62 | """Dataset of rendered scenes and their corresponding camera angles.
63 |
64 | Args:
65 | path_to_data (string): Path to folder containing dataset.
66 | img_transform (torchvision.transform): Transforms to be applied to
67 | images.
68 | allow_odd_num_imgs (bool): If True, allows datasets with an odd number
69 | of views. Such a dataset cannot be used for training, since each
70 | training iteration requires a *pair* of images.
71 |
72 | Notes:
73 | - Image paths must be of the form "XXXXX.png" where XXXXX are *five*
74 | integers indexing the image.
75 | - We assume there are the same number of rendered images for each scene
76 | and that this number is even.
77 | - We assume angles are given in degrees.
78 | """
79 | def __init__(self, path_to_data='chairs-train', img_transform=None,
80 | allow_odd_num_imgs=False):
81 | self.path_to_data = path_to_data
82 | self.img_transform = img_transform
83 | self.allow_odd_num_imgs = allow_odd_num_imgs
84 | self.data = []
85 | # Each folder contains a single scene with different rendering
86 | # parameters and views
87 | self.scene_paths = glob.glob(path_to_data + '/*')
88 | self.scene_paths.sort() # Ensure consistent ordering of scenes
89 | self.num_scenes = len(self.scene_paths)
90 | # Extract number of rendered images per object (which we assume is constant)
91 | self.num_imgs_per_scene = len(glob.glob(self.scene_paths[0] + '/*.png'))
92 | # If number of images per scene is not even, drop last image
93 | if self.num_imgs_per_scene % 2 != 0:
94 | if not self.allow_odd_num_imgs:
95 | self.num_imgs_per_scene -= 1
96 | # For each scene, extract its rendered views and render parameters
97 | for scene_path in self.scene_paths:
98 | # Name of folder defines scene name
99 | scene_name = scene_path.split('/')[-1]
100 |
101 | # Load render parameters
102 | with open(scene_path + '/render_params.json') as f:
103 | render_params = json.load(f)
104 |
105 | # Extract path to rendered images of scene
106 | img_paths = glob.glob(scene_path + '/*.png')
107 | img_paths.sort() # Ensure consistent ordering of images
108 | # Ensure number of image paths is even
109 | img_paths = img_paths[:self.num_imgs_per_scene]
110 |
111 | for img_path in img_paths:
112 | # Extract image filename
113 | img_file = img_path.split('/')[-1]
114 | # Filenames are of the type "XXXXX.png", so extract this
115 | # index to match with render parameters.
116 | img_idx = img_file.split('.')[0][-5:] # This should be a string
117 | # Convert render parameters to float32
118 | img_params = {key: np_float32(value)
119 | for key, value in render_params[img_idx].items()}
120 | self.data.append({
121 | "scene_name": scene_name,
122 | "img_path": img_path,
123 | "render_params": img_params
124 | })
125 |
126 | def __len__(self):
127 | return len(self.data)
128 |
129 | def __getitem__(self, idx):
130 | img_path = self.data[idx]["img_path"]
131 | render_params = self.data[idx]["render_params"]
132 |
133 | img = Image.open(img_path)
134 |
135 | # Transform images
136 | if self.img_transform:
137 | img = self.img_transform(img)
138 |
139 | # Note some images may contain 4 channels (i.e. RGB + alpha), we only
140 | # keep RGB channels
141 | data_item = {
142 | "img": img[:3],
143 | "scene_name": self.data[idx]["scene_name"],
144 | "render_params": self.data[idx]["render_params"]
145 | }
146 |
147 | return data_item
148 |
149 |
150 | class RandomPairSampler(Sampler):
151 | """Samples random elements in pairs. Dataset is assumed to be composed of a
152 | number of scenes, each rendered in a number of views. This sampler returns
153 | rendered image in pairs. I.e. for a batch of size 6, it would return e.g.:
154 |
155 | [object 4 - img 5,
156 | object 4 - img 12,
157 | object 6 - img 3,
158 | object 6 - img 19,
159 | object 52 - img 10,
160 | object 52 - img 3]
161 |
162 |
163 | Arguments:
164 | dataset (Dataset): Dataset to sample from. This will typically be an
165 | instance of SceneRenderDataset.
166 | """
167 |
168 | def __init__(self, dataset):
169 | self.dataset = dataset
170 |
171 | def __iter__(self):
172 | num_scenes = self.dataset.num_scenes
173 | num_imgs_per_scene = self.dataset.num_imgs_per_scene
174 |
175 | # Sample num_imgs_per_scene / 2 permutations of the objects
176 | scene_permutations = [torch.randperm(num_scenes) for _ in range(num_imgs_per_scene // 2)]
177 | # For each scene, sample a permutation of its images
178 | img_permutations = [torch.randperm(num_imgs_per_scene) for _ in range(num_scenes)]
179 |
180 | data_permutation = []
181 |
182 | for i, scene_permutation in enumerate(scene_permutations):
183 | for scene_idx in scene_permutation:
184 | # Extract image permutation for this object
185 | img_permutation = img_permutations[scene_idx]
186 | # Add 2 images of this object to data_permutation
187 | data_permutation.append(scene_idx.item() * num_imgs_per_scene + img_permutation[2*i].item())
188 | data_permutation.append(scene_idx.item() * num_imgs_per_scene + img_permutation[2*i + 1].item())
189 |
190 | return iter(data_permutation)
191 |
192 | def __len__(self):
193 | return len(self.dataset)
194 |
195 |
196 | def create_batch_from_data_list(data_list):
197 | """Given a list of datapoints, create a batch.
198 |
199 | Args:
200 | data_list (list): List of items returned by SceneRenderDataset.
201 | """
202 | imgs = []
203 | azimuths = []
204 | elevations = []
205 | for data_item in data_list:
206 | img, render_params = data_item["img"], data_item["render_params"]
207 | azimuth, elevation = render_params["azimuth"], render_params["elevation"]
208 | imgs.append(img.unsqueeze(0))
209 | azimuths.append(torch.Tensor([azimuth]))
210 | elevations.append(torch.Tensor([elevation]))
211 | imgs = torch.cat(imgs, dim=0)
212 | azimuths = torch.cat(azimuths)
213 | elevations = torch.cat(elevations)
214 | return imgs, azimuths, elevations
215 |
--------------------------------------------------------------------------------
/misc/quantitative_evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from misc.dataloaders import create_batch_from_data_list
4 |
5 |
6 | def get_dataset_psnr(device, model, dataset, source_img_idx_shift=64,
7 | batch_size=10, max_num_scenes=None):
8 | """Returns PSNR for each scene in a dataset by comparing the view predicted
9 | by a model and the ground truth view.
10 |
11 | Args:
12 | device (torch.device): Device to perform PSNR calculation on.
13 | model (models.neural_renderer.NeuralRenderer): Model to evaluate.
14 | dataset (misc.dataloaders.SceneRenderDataset): Dataset to evaluate model
15 | performance on. Should be one of "chairs-test" or "cars-test".
16 | source_img_idx_shift (int): Index of source image for each scene. For
17 | example if 00064.png is the source view, then
18 | source_img_idx_shift = 64.
19 | batch_size (int): Batch size to use when generating predictions. This
20 | should be a divisor of the number of images per scene.
21 | max_num_scenes (None or int): Optionally limit the maximum number of
22 | scenes to calculate PSNR for.
23 |
24 | Notes:
25 | This function should be used with the ShapeNet chairs and cars *test*
26 | sets.
27 | """
28 | num_imgs_per_scene = dataset.num_imgs_per_scene
29 | # Set number of scenes to calculate
30 | num_scenes = dataset.num_scenes
31 | if max_num_scenes is not None:
32 | num_scenes = min(max_num_scenes, num_scenes)
33 | # Calculate number of batches per scene
34 | assert (num_imgs_per_scene - 1) % batch_size == 0, "Batch size {} must divide the number of images per scene minus one ({}).".format(batch_size, num_imgs_per_scene - 1)
35 | # Comparisons are made against all images except the source image (hence
36 | # the subtraction of 1 from the total number of images)
37 | batches_per_scene = (num_imgs_per_scene - 1) // batch_size
38 | # Initialize psnr values
39 | psnrs = []
40 | for i in range(num_scenes):
41 | # Extract source view
42 | source_img_idx = i * num_imgs_per_scene + source_img_idx_shift
43 | img_source = dataset[source_img_idx]["img"].unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device)
44 | render_params = dataset[source_img_idx]["render_params"]
45 | azimuth_source = torch.Tensor([render_params["azimuth"]]).repeat(batch_size).to(device)
46 | elevation_source = torch.Tensor([render_params["elevation"]]).repeat(batch_size).to(device)
47 | # Infer source scene
48 | scenes = model.inverse_render(img_source)
49 |
50 | # Iterate over all other views of scene
51 | num_points_in_batch = 0
52 | data_list = []
53 | scene_psnr = 0.
54 | for j in range(num_imgs_per_scene):
55 | if j == source_img_idx_shift:
56 | continue # Do not compare against same image
57 | # Add new image to list of images we want to compare to
58 | data_list.append(dataset[i * num_imgs_per_scene + j])
59 | num_points_in_batch += 1
60 | # If we have filled up a batch, make psnr calculation
61 | if num_points_in_batch == batch_size:
62 | # Create batch for target data
63 | img_target, azimuth_target, elevation_target = create_batch_from_data_list(data_list)
64 | img_target = img_target.to(device)
65 | azimuth_target = azimuth_target.to(device)
66 | elevation_target = elevation_target.to(device)
67 | # Rotate scene and render image
68 | rotated = model.rotate_source_to_target(scenes, azimuth_source,
69 | elevation_source, azimuth_target,
70 | elevation_target)
71 | img_predicted = model.render(rotated).detach()
72 | scene_psnr += get_psnr(img_predicted, img_target)
73 | data_list = []
74 | num_points_in_batch = 0
75 |
76 | psnrs.append(scene_psnr / batches_per_scene)
77 |
78 | print("{}/{}: Current - {:.3f}, Mean - {:.4f}".format(i + 1,
79 | num_scenes,
80 | psnrs[-1],
81 | torch.mean(torch.Tensor(psnrs))))
82 |
83 | return psnrs
84 |
85 |
86 | def get_psnr(prediction, target):
87 | """Returns PSNR between a batch of predictions and a batch of targets.
88 |
89 | Args:
90 | prediction (torch.Tensor): Shape (batch_size, channels, height, width).
91 | target (torch.Tensor): Shape (batch_size, channels, height, width).
92 | """
93 | batch_size = prediction.shape[0]
94 | mse_per_pixel = F.mse_loss(prediction, target, reduction='none')
95 | mse_per_img = mse_per_pixel.view(batch_size, -1).mean(dim=1)
96 | psnr = 10 * torch.log10(1 / mse_per_img)
97 | return torch.mean(psnr).item()
98 |
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
1 | import models.layers
2 | import models.submodels
3 | import torch
4 | from math import pi
5 |
6 |
7 | def full_rotation_angle_sequence(num_steps):
8 | """Returns a sequence of angles corresponding to a full 360 degree rotation.
9 | Useful for generating gifs.
10 |
11 | Args:
12 | num_steps (int): Number of steps in sequence.
13 | """
14 | return torch.linspace(0., 360. - 360. / num_steps, num_steps)
15 |
16 |
17 | def constant_angle_sequence(num_steps, value=0.):
18 | """Returns a sequence of constant angles. Useful for generating gifs.
19 |
20 | Args:
21 | num_steps (int): Number of steps in sequence.
22 | value (float): Constant angle value.
23 | """
24 | return value * torch.ones(num_steps)
25 |
26 |
27 | def back_and_forth_angle_sequence(num_steps, start, end):
28 | """Returns a sequence of angles linearly increasing from start to end and
29 | back.
30 |
31 | Args:
32 | num_steps (int): Number of steps in sequence.
33 | start (float): Angle at which to start (in degrees).
34 | end (float): Angle at which to end (in degrees).
35 | """
36 | half_num_steps = int(num_steps / 2)
37 | # Increase angle from start to end
38 | first = torch.linspace(start, end - end / half_num_steps, half_num_steps)
39 | # Decrease angle from end to start
40 | second = torch.linspace(end, start - start / half_num_steps, half_num_steps)
41 | # Return combined sequence of increasing and decreasing angles
42 | return torch.cat([first, second], dim=0)
43 |
44 |
45 | def sinusoidal_angle_sequence(num_steps, minimum, maximum):
46 | """Returns a sequence of angles sinusoidally varying between minimum and
47 | maximum.
48 |
49 | Args:
50 | num_steps (int): Number of steps in sequence.
51 | minimum (float): Minimum angle of the sequence (in degrees).
52 | maximum (float): Maximum angle of the sequence (in degrees).
53 | """
54 | period = 2 * pi * torch.linspace(0., 1. - 1. / num_steps, num_steps)
55 | return .5 * (minimum + maximum + (maximum - minimum) * torch.sin(period))
56 |
57 |
58 | def sine_squared_angle_sequence(num_steps, start, end):
59 | """Returns a sequence of angles increasing from start to end and back as the
60 | sine squared function.
61 |
62 | Args:
63 | num_steps (int): Number of steps in sequence.
64 | start (float): Angle at which to start (in degrees).
65 | end (float): Angle at which to end (in degrees).
66 | """
67 | half_period = pi * torch.linspace(0., 1. - 1. / num_steps, num_steps)
68 | return start + (end - start) * torch.sin(half_period) ** 2
69 |
70 |
71 | def count_parameters(model):
72 | """Returns number of trainable parameters in a model."""
73 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
74 |
75 |
76 | def get_layers_info(model):
77 | """Returns information about input shapes, output shapes and number of
78 | parameters in every block of model.
79 |
80 | Args:
81 | model (torch.nn.Module): Model to analyse. This will typically be a
82 | submodel of models.neural_renderer.NeuralRenderer.
83 | """
84 | in_shape = model.input_shape
85 | layers_info = []
86 |
87 | if isinstance(model, models.submodels.Projection):
88 | out_shape = (in_shape[0] * in_shape[1], *in_shape[2:])
89 | layer_info = {"name": "Reshape", "in_shape": in_shape,
90 | "out_shape": out_shape, "num_params": 0}
91 | layers_info.append(layer_info)
92 | in_shape = out_shape
93 |
94 | for layer in model.forward_layers:
95 | if isinstance(layer, torch.nn.Conv2d):
96 | if layer.stride[0] == 1:
97 | out_shape = (layer.out_channels, *in_shape[1:])
98 | elif layer.stride[0] == 2:
99 | out_shape = (layer.out_channels, in_shape[1] // 2, in_shape[2] // 2)
100 | name = "Conv2D"
101 | elif isinstance(layer, torch.nn.ConvTranspose2d):
102 | if layer.stride[0] == 1:
103 | out_shape = (layer.out_channels, *in_shape[1:])
104 | elif layer.stride[0] == 2:
105 | out_shape = (layer.out_channels, in_shape[1] * 2, in_shape[2] * 2)
106 | name = "ConvTr2D"
107 | elif isinstance(layer, models.layers.ResBlock2d):
108 | out_shape = in_shape
109 | name = "ResBlock2D"
110 | elif isinstance(layer, torch.nn.Conv3d):
111 | if layer.stride[0] == 1:
112 | out_shape = (layer.out_channels, *in_shape[1:])
113 | elif layer.stride[0] == 2:
114 | out_shape = (layer.out_channels, in_shape[1] // 2, in_shape[2] // 2, in_shape[3] // 2)
115 | name = "Conv3D"
116 | elif isinstance(layer, torch.nn.ConvTranspose3d):
117 | if layer.stride[0] == 1:
118 | out_shape = (layer.out_channels, *in_shape[1:])
119 | elif layer.stride[0] == 2:
120 | out_shape = (layer.out_channels, in_shape[1] * 2, in_shape[2] * 2, in_shape[3] * 2)
121 | name = "ConvTr3D"
122 | elif isinstance(layer, models.layers.ResBlock3d):
123 | out_shape = in_shape
124 | name = "ResBlock3D"
125 | else:
126 | # If layer is just an activation layer, skip
127 | continue
128 |
129 | num_params = count_parameters(layer)
130 | layer_info = {"name": name, "in_shape": in_shape,
131 | "out_shape": out_shape, "num_params": num_params}
132 | layers_info.append(layer_info)
133 |
134 | in_shape = out_shape
135 |
136 | if isinstance(model, models.submodels.InverseProjection):
137 | layer_info = {"name": "Reshape", "in_shape": in_shape,
138 | "out_shape": model.output_shape, "num_params": 0}
139 | layers_info.append(layer_info)
140 |
141 | return layers_info
142 |
143 |
144 | def pretty_print_layers_info(model, title):
145 | """Prints information about a model.
146 |
147 | Args:
148 | model (see get_layers_info)
149 | title (string): Title of model.
150 | """
151 | # Extract layers info for model
152 | layers_info = get_layers_info(model)
153 | # Print information in a nice format
154 | print(title)
155 | print("-" * len(title))
156 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format("name", "in_shape", "out_shape", "num_params", "feat_size"))
157 | print("---------------------------------------------------------------------------------------------")
158 |
159 | min_feat_size = 2 ** 20 # Some huge number
160 | for info in layers_info:
161 | feat_size = tuple_product(info["out_shape"])
162 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format(info["name"],
163 | str(info["in_shape"]),
164 | str(info["out_shape"]),
165 | info["num_params"],
166 | feat_size))
167 | if feat_size < min_feat_size:
168 | min_feat_size = feat_size
169 | print("---------------------------------------------------------------------------------------------")
170 | # Only print model info if model is not empty
171 | if len(layers_info):
172 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format("Total",
173 | str(layers_info[0]["in_shape"]),
174 | str(layers_info[-1]["out_shape"]),
175 | count_parameters(model),
176 | min_feat_size))
177 |
178 |
179 | def tuple_product(input_tuple):
180 | """Returns product of elements in a tuple."""
181 | product = 1
182 | for elem in input_tuple:
183 | product *= elem
184 | return product
185 |
--------------------------------------------------------------------------------
/misc/viz.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import torch
3 | import torchvision
4 | from misc.dataloaders import create_batch_from_data_list
5 |
6 |
7 | def generate_novel_views(model, img_source, azimuth_source, elevation_source,
8 | azimuth_shifts, elevation_shifts):
9 | """Generates novel views of an image by inferring its scene representation,
10 | rotating it and rendering novel views. Returns a batch of images
11 | corresponding to the novel views.
12 |
13 | Args:
14 | model (models.neural_renderer.NeuralRenderer): Neural rendering model.
15 | img_source (torch.Tensor): Single image. Shape (channels, height, width).
16 | azimuth_source (torch.Tensor): Azimuth of source image. Shape (1,).
17 | elevation_source (torch.Tensor): Elevation of source image. Shape (1,).
18 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to
19 | generate novel views. Shape (num_views,).
20 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to
21 | generate novel views. Shape (num_views,).
22 | """
23 | # No need to calculate gradients
24 | with torch.no_grad():
25 | num_views = len(azimuth_shifts)
26 | # Batchify image
27 | img_batch = img_source.unsqueeze(0)
28 | # Infer scene
29 | scenes = model.inverse_render(img_batch)
30 | # Copy scene for each target view
31 | scenes_batch = scenes.repeat(num_views, 1, 1, 1, 1)
32 | # Batchify azimuth and elevation source
33 | azimuth_source_batch = azimuth_source.repeat(num_views)
34 | elevation_source_batch = elevation_source.repeat(num_views)
35 | # Calculate azimuth and elevation targets
36 | azimuth_target = azimuth_source_batch + azimuth_shifts
37 | elevation_target = elevation_source_batch + elevation_shifts
38 | # Rotate scenes
39 | rotated = model.rotate_source_to_target(scenes_batch, azimuth_source_batch,
40 | elevation_source_batch,
41 | azimuth_target, elevation_target)
42 | # Render images
43 | return model.render(rotated).detach()
44 |
45 |
46 | def batch_generate_novel_views(model, imgs_source, azimuth_source,
47 | elevation_source, azimuth_shifts,
48 | elevation_shifts):
49 | """Generates novel views for a batch of images. Returns a list of batches of
50 | images, where each item in the list corresponds to a novel view for all
51 | images.
52 |
53 | Args:
54 | model (models.neural_renderer.NeuralRenderer): Neural rendering model.
55 | imgs_source (torch.Tensor): Source images. Shape (batch_size, channels, height, width).
56 | azimuth_source (torch.Tensor): Azimuth of source. Shape (batch_size,).
57 | elevation_source (torch.Tensor): Elevation of source. Shape (batch_size,).
58 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to generate
59 | novel views. Shape (num_views,).
60 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to
61 | generate novel views. Shape (num_views,).
62 | """
63 | num_imgs = imgs_source.shape[0]
64 | num_views = azimuth_shifts.shape[0]
65 |
66 | # Initialize novel views, i.e. a list of length num_views with each item
67 | # containing num_imgs images
68 | all_novel_views = [torch.zeros_like(imgs_source) for _ in range(num_views)]
69 |
70 | for i in range(num_imgs):
71 | # Generate novel views for single image
72 | novel_views = generate_novel_views(model, imgs_source[i],
73 | azimuth_source[i:i+1],
74 | elevation_source[i:i+1],
75 | azimuth_shifts, elevation_shifts).cpu()
76 | # Add to list of all novel_views
77 | for j in range(num_views):
78 | all_novel_views[j][i] = novel_views[j]
79 |
80 | return all_novel_views
81 |
82 |
83 | def dataset_novel_views(device, model, dataset, img_indices, azimuth_shifts,
84 | elevation_shifts):
85 | """Helper function for generating novel views from specific images in a
86 | dataset.
87 |
88 | Args:
89 | device (torch.device):
90 | model (models.neural_renderer.NeuralRenderer):
91 | dataset (misc.dataloaders.SceneRenderDataset):
92 | img_indices (tuple of ints): Indices of images in dataset to use as
93 | source views for novel view synthesis.
94 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to generate
95 | novel views. Shape (num_views,).
96 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to
97 | generate novel views. Shape (num_views,).
98 | """
99 | # Extract image and pose information for all views
100 | data_list = []
101 | for img_idx in img_indices:
102 | data_list.append(dataset[img_idx])
103 | imgs_source, azimuth_source, elevation_source = create_batch_from_data_list(data_list)
104 | imgs_source = imgs_source.to(device)
105 | azimuth_source = azimuth_source.to(device)
106 | elevation_source = elevation_source.to(device)
107 | # Generate novel views
108 | return batch_generate_novel_views(model, imgs_source, azimuth_source,
109 | elevation_source, azimuth_shifts,
110 | elevation_shifts)
111 |
112 |
113 | def shapenet_test_novel_views(device, model, dataset, source_scenes_idx=(0, 1, 2, 3),
114 | source_img_idx_shift=64, subsample_target=5):
115 | """Helper function for generating novel views on an archimedean spiral for
116 | the test images for ShapeNet chairs and cars.
117 |
118 | Args:
119 | device (torch.device):
120 | model (models.neural_renderer.NeuralRenderer):
121 | dataset (misc.dataloaders.SceneRenderDataset): Test dataloader for a
122 | ShapeNet dataset.
123 | source_scenes_idx (tuple of ints): Indices of source scenes to use for
124 | generating novel views.
125 | source_img_idx_shift (int): Index of source image for each scene. For
126 | example if 00064.png is the source view, then
127 | source_img_idx_shift = 64.
128 | subsample_target (int): Amount by which to subsample target views. If
129 | set to 1, uses all 250 target views.
130 | """
131 | num_imgs = len(source_scenes_idx)
132 | # Extract source azimuths and elevations
133 | # Note we can extract this from the first scene since for the shapenet test
134 | # sets, the test poses are the same for all scenes
135 | render_params = dataset[source_img_idx_shift]["render_params"]
136 | azimuth_source = torch.Tensor([render_params["azimuth"]]).to(device)
137 | elevation_source = torch.Tensor([render_params["elevation"]]).to(device)
138 |
139 | # Extract target azimuths and elevations (do not use final view as it is
140 | # slightly off in dataset)
141 | azimuth_target = torch.zeros(dataset.num_imgs_per_scene - 1)
142 | elevation_target = torch.zeros(dataset.num_imgs_per_scene - 1)
143 | for i in range(dataset.num_imgs_per_scene - 1):
144 | render_params = dataset[i]["render_params"]
145 | azimuth_target[i] = torch.Tensor([render_params["azimuth"]])
146 | elevation_target[i] = torch.Tensor([render_params["elevation"]])
147 | # Move to GPU
148 | azimuth_target = azimuth_target.to(device)
149 | elevation_target = elevation_target.to(device)
150 | # Subsample
151 | azimuth_target = azimuth_target[::subsample_target]
152 | elevation_target = elevation_target[::subsample_target]
153 |
154 | # Calculate azimuth and elevation shifts
155 | azimuth_shifts = azimuth_target - azimuth_source
156 | elevation_shifts = elevation_target - elevation_source
157 |
158 | # Ensure source angles have same batch_size as imgs_source
159 | azimuth_source = azimuth_source.repeat(num_imgs)
160 | elevation_source = elevation_source.repeat(num_imgs)
161 |
162 | # Create source image batch
163 | imgs_source = torch.zeros(num_imgs, 3, 128, 128).to(device)
164 | for i in range(num_imgs):
165 | scene_idx = source_scenes_idx[i]
166 | source_img_idx = scene_idx * dataset.num_imgs_per_scene + source_img_idx_shift
167 | imgs_source[i] = dataset[source_img_idx]["img"].to(device)
168 |
169 | return batch_generate_novel_views(model, imgs_source, azimuth_source,
170 | elevation_source, azimuth_shifts,
171 | elevation_shifts)
172 |
173 |
174 | def save_generate_novel_views(filename, model, img_source, azimuth_source,
175 | elevation_source, azimuth_shifts,
176 | elevation_shifts):
177 | """Generates novel views of an image by inferring its scene representation,
178 | rotating it and rendering novel views. Saves the source image and novel
179 | views as png files.
180 |
181 | Args:
182 | filename (string): Filename root for saving images.
183 | model (models.neural_renderer.NeuralRenderer): Neural rendering model.
184 | img_source (torch.Tensor): Single image. Shape (channels, height, width).
185 | azimuth_source (torch.Tensor): Azimuth of source image. Shape (1,).
186 | elevation_source (torch.Tensor): Elevation of source image. Shape (1,).
187 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to
188 | generate novel views. Shape (num_views,).
189 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to
190 | generate novel views. Shape (num_views,).
191 | """
192 | # Generate novel views
193 | novel_views = generate_novel_views(model, img_source, azimuth_source,
194 | elevation_source, azimuth_shifts,
195 | elevation_shifts)
196 | # Save original image
197 | torchvision.utils.save_image(img_source, filename + '.png', padding=4,
198 | pad_value=1.)
199 | # Save novel views (with white padding)
200 | torchvision.utils.save_image(novel_views, filename + '_novel.png',
201 | padding=4, pad_value=1.)
202 |
203 |
204 | def save_img_sequence_as_gif(img_sequence, filename, nrow=4):
205 | """Given a sequence of images as tensors, saves a gif of the images.
206 | If images are in batches, they are converted to a grid before being
207 | saved.
208 |
209 | Args:
210 | img_sequence (list of torch.Tensor): List of images. Tensors should
211 | have shape either (batch_size, channels, height, width) or shape
212 | (channels, height, width). If there is a batch dimension, images
213 | will be converted to a grid.
214 | filename (string): Path where gif will be saved. Should end in '.gif'.
215 | nrow (int): Number of rows in image grid, if image has batch dimension.
216 | """
217 | img_grid_sequence = []
218 | for img in img_sequence:
219 | if len(img.shape) == 4:
220 | img_grid = torchvision.utils.make_grid(img, nrow=nrow)
221 | else:
222 | img_grid = img
223 | # Convert to numpy array and from float in [0, 1] to int in [0, 255]
224 | # which is what imageio expects
225 | img_grid = (img_grid * 255.).byte().cpu().numpy().transpose(1, 2, 0)
226 | img_grid_sequence.append(img_grid)
227 | # Save gif
228 | imageio.mimwrite(filename, img_grid_sequence)
229 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/models/__init__.py
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ConvBlock2d(nn.Module):
6 | """Block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape of input
7 | and output is the same.
8 |
9 | Args:
10 | in_channels (int): Number of channels in input.
11 | num_filters (list of ints): List of two ints with the number of filters
12 | for the first and second conv layers. Third conv layer must have the
13 | same number of input filters as there are channels.
14 | add_groupnorm (bool): If True adds GroupNorm.
15 | """
16 | def __init__(self, in_channels, num_filters, add_groupnorm=True):
17 | super(ConvBlock2d, self).__init__()
18 | if add_groupnorm:
19 | self.forward_layers = nn.Sequential(
20 | nn.GroupNorm(num_channels_to_num_groups(in_channels), in_channels),
21 | nn.LeakyReLU(0.2, True),
22 | nn.Conv2d(in_channels, num_filters[0], kernel_size=1, stride=1,
23 | bias=False),
24 | nn.GroupNorm(num_channels_to_num_groups(num_filters[0]), num_filters[0]),
25 | nn.LeakyReLU(0.2, True),
26 | nn.Conv2d(num_filters[0], num_filters[1], kernel_size=3,
27 | stride=1, padding=1, bias=False),
28 | nn.GroupNorm(num_channels_to_num_groups(num_filters[1]), num_filters[1]),
29 | nn.LeakyReLU(0.2, True),
30 | nn.Conv2d(num_filters[1], in_channels, kernel_size=1, stride=1,
31 | bias=False)
32 | )
33 | else:
34 | self.forward_layers = nn.Sequential(
35 | nn.LeakyReLU(0.2, True),
36 | nn.Conv2d(in_channels, num_filters[0], kernel_size=1, stride=1,
37 | bias=True),
38 | nn.LeakyReLU(0.2, True),
39 | nn.Conv2d(num_filters[0], num_filters[1], kernel_size=3,
40 | stride=1, padding=1, bias=True),
41 | nn.LeakyReLU(0.2, True),
42 | nn.Conv2d(num_filters[1], in_channels, kernel_size=1, stride=1,
43 | bias=True)
44 | )
45 |
46 | def forward(self, inputs):
47 | return self.forward_layers(inputs)
48 |
49 |
50 | class ConvBlock3d(nn.Module):
51 | """Block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape of input
52 | and output is the same.
53 |
54 | Args:
55 | in_channels (int): Number of channels in input.
56 | num_filters (list of ints): List of two ints with the number of filters
57 | for the first and second conv layers. Third conv layer must have the
58 | same number of input filters as there are channels.
59 | add_groupnorm (bool): If True adds GroupNorm.
60 | """
61 | def __init__(self, in_channels, num_filters, add_groupnorm=True):
62 | super(ConvBlock3d, self).__init__()
63 | if add_groupnorm:
64 | self.forward_layers = nn.Sequential(
65 | nn.GroupNorm(num_channels_to_num_groups(in_channels), in_channels),
66 | nn.LeakyReLU(0.2, True),
67 | nn.Conv3d(in_channels, num_filters[0], kernel_size=1, stride=1,
68 | bias=False),
69 | nn.GroupNorm(num_channels_to_num_groups(num_filters[0]), num_filters[0]),
70 | nn.LeakyReLU(0.2, True),
71 | nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3,
72 | stride=1, padding=1, bias=False),
73 | nn.GroupNorm(num_channels_to_num_groups(num_filters[1]), num_filters[1]),
74 | nn.LeakyReLU(0.2, True),
75 | nn.Conv3d(num_filters[1], in_channels, kernel_size=1, stride=1,
76 | bias=False)
77 | )
78 | else:
79 | self.forward_layers = nn.Sequential(
80 | nn.LeakyReLU(0.2, True),
81 | nn.Conv3d(in_channels, num_filters[0], kernel_size=1, stride=1,
82 | bias=True),
83 | nn.LeakyReLU(0.2, True),
84 | nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3,
85 | stride=1, padding=1, bias=True),
86 | nn.LeakyReLU(0.2, True),
87 | nn.Conv3d(num_filters[1], in_channels, kernel_size=1, stride=1,
88 | bias=True)
89 | )
90 |
91 | def forward(self, inputs):
92 | return self.forward_layers(inputs)
93 |
94 |
95 | class ResBlock2d(nn.Module):
96 | """Residual block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape
97 | of input and output is the same.
98 |
99 | Args:
100 | in_channels (int): Number of channels in input.
101 | num_filters (list of ints): List of two ints with the number of filters
102 | for the first and second conv layers. Third conv layer must have the
103 | same number of input filters as there are channels.
104 | add_groupnorm (bool): If True adds GroupNorm.
105 | """
106 | def __init__(self, in_channels, num_filters, add_groupnorm=True):
107 | super(ResBlock2d, self).__init__()
108 | self.residual_layers = ConvBlock2d(in_channels, num_filters,
109 | add_groupnorm)
110 |
111 | def forward(self, inputs):
112 | return inputs + self.residual_layers(inputs)
113 |
114 |
115 | class ResBlock3d(nn.Module):
116 | """Residual block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape
117 | of input and output is the same.
118 |
119 | Args:
120 | in_channels (int): Number of channels in input.
121 | num_filters (list of ints): List of two ints with the number of filters
122 | for the first and second conv layers. Third conv layer must have the
123 | same number of input filters as there are channels.
124 | add_groupnorm (bool): If True adds GroupNorm.
125 | """
126 | def __init__(self, in_channels, num_filters, add_groupnorm=True):
127 | super(ResBlock3d, self).__init__()
128 | self.residual_layers = ConvBlock3d(in_channels, num_filters,
129 | add_groupnorm)
130 |
131 | def forward(self, inputs):
132 | return inputs + self.residual_layers(inputs)
133 |
134 |
135 | def num_channels_to_num_groups(num_channels):
136 | """Returns number of groups to use in a GroupNorm layer with a given number
137 | of channels. Note that these choices are hyperparameters.
138 |
139 | Args:
140 | num_channels (int): Number of channels.
141 | """
142 | if num_channels < 8:
143 | return 1
144 | if num_channels < 32:
145 | return 2
146 | if num_channels < 64:
147 | return 4
148 | if num_channels < 128:
149 | return 8
150 | if num_channels < 256:
151 | return 16
152 | else:
153 | return 32
154 |
--------------------------------------------------------------------------------
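A minimal usage sketch for the blocks above (the batch sizes, channel counts and spatial sizes are illustrative placeholders, not values from the repository's configs): the residual blocks preserve their input shape, and num_channels_to_num_groups picks the GroupNorm group count.

    import torch
    from models.layers import ResBlock2d, ResBlock3d, num_channels_to_num_groups

    x2d = torch.randn(2, 64, 32, 32)       # (batch, channels, height, width)
    block2d = ResBlock2d(in_channels=64, num_filters=[32, 32])
    print(block2d(x2d).shape)              # torch.Size([2, 64, 32, 32]), shape preserved

    x3d = torch.randn(2, 16, 8, 8, 8)      # (batch, channels, depth, height, width)
    block3d = ResBlock3d(in_channels=16, num_filters=[8, 8])
    print(block3d(x3d).shape)              # torch.Size([2, 16, 8, 8, 8])

    print(num_channels_to_num_groups(64))  # 8 groups for a 64-channel layer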
/models/neural_renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from misc.utils import pretty_print_layers_info, count_parameters
4 | from models.submodels import ResNet2d, ResNet3d, Projection, InverseProjection
5 | from models.rotation_layers import SphericalMask, Rotate3d
6 |
7 |
8 | class NeuralRenderer(nn.Module):
9 | """Implements a Neural Renderer with an implicit scene representation that
10 | allows both forward and inverse rendering.
11 |
12 | The forward pass from 3d scene to 2d image is (rendering):
13 | Scene representation (input) -> ResNet3d -> Projection -> ResNet2d ->
14 | Rendered image (output)
15 |
16 | The inverse pass from 2d image to 3d scene is (inverse rendering):
17 | Image (input) -> ResNet2d -> Inverse Projection -> ResNet3d -> Scene
18 | representation (output)
19 |
20 | Args:
21 | img_shape (tuple of ints): Shape of the image input to the model. Should
22 | be of the form (channels, height, width).
23 | channels_2d (tuple of ints): List of channels for 2D layers in inverse
24 | rendering model (image -> scene).
25 | strides_2d (tuple of ints): List of strides for 2D layers in inverse
26 | rendering model (image -> scene).
27 | channels_3d (tuple of ints): List of channels for 3D layers in inverse
28 | rendering model (image -> scene).
29 | strides_3d (tuple of ints): List of strides for 3D layers in inverse
30 | rendering model (image -> scene).
31 | num_channels_inv_projection (tuple of ints): Number of channels in each
32 | layer of inverse projection unit from 2D to 3D.
33 | num_channels_projection (tuple of ints): Number of channels in each
34 | layer of projection unit from 3D to 2D.
35 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode
36 | used when rotating voxel grid.
37 |
38 | Notes:
39 | Given the inverse rendering channels and strides, the model will
40 | automatically build a forward renderer as the transpose of the inverse
41 | renderer.
42 | """
43 | def __init__(self, img_shape, channels_2d, strides_2d, channels_3d,
44 | strides_3d, num_channels_inv_projection, num_channels_projection,
45 | mode='bilinear'):
46 | super(NeuralRenderer, self).__init__()
47 | self.img_shape = img_shape
48 | self.channels_2d = channels_2d
49 | self.strides_2d = strides_2d
50 | self.channels_3d = channels_3d
51 | self.strides_3d = strides_3d
52 | self.num_channels_projection = num_channels_projection
53 | self.num_channels_inv_projection = num_channels_inv_projection
54 | self.mode = mode
55 |
56 | # Initialize layers
57 |
58 | # Inverse pass (image -> scene)
59 | # First transform image into a 2D representation
60 | self.inv_transform_2d = ResNet2d(self.img_shape, channels_2d,
61 | strides_2d)
62 |
63 | # Perform inverse projection from 2D to 3D
64 | input_shape = self.inv_transform_2d.output_shape
65 | self.inv_projection = InverseProjection(input_shape, num_channels_inv_projection)
66 |
67 | # Transform 3D inverse projection into a scene representation
68 | self.inv_transform_3d = ResNet3d(self.inv_projection.output_shape,
69 | channels_3d, strides_3d)
70 | # Add rotation layer
71 | self.rotation_layer = Rotate3d(self.mode)
72 |
73 | # Forward pass (scene -> image)
74 | # Forward renderer is just transpose of inverse renderer, so flip order
75 | # of channels and strides
76 | # Transform scene representation to 3D features
77 | forward_channels_3d = list(reversed(channels_3d))[1:] + [channels_3d[0]]
78 | forward_strides_3d = [-stride if abs(stride) == 2 else 1 for stride in list(reversed(strides_3d[1:]))] + [strides_3d[0]]
79 | self.transform_3d = ResNet3d(self.inv_transform_3d.output_shape,
80 | forward_channels_3d, forward_strides_3d)
81 |
82 | # Layer for projection of 3D representation to 2D representation
83 | self.projection = Projection(self.transform_3d.output_shape,
84 | num_channels_projection)
85 |
86 | # Transform 2D features to rendered image
87 | forward_channels_2d = list(reversed(channels_2d))[1:] + [channels_2d[0]]
88 | forward_strides_2d = [-stride if abs(stride) == 2 else 1 for stride in list(reversed(strides_2d[1:]))] + [strides_2d[0]]
89 | final_conv_channels_2d = img_shape[0]
90 | self.transform_2d = ResNet2d(self.projection.output_shape,
91 | forward_channels_2d, forward_strides_2d,
92 | final_conv_channels_2d)
93 |
94 | # Scene representation shape is output of inverse 3D transformation
95 | self.scene_shape = self.inv_transform_3d.output_shape
96 | # Add spherical mask before scene rotation
97 | self.spherical_mask = SphericalMask(self.scene_shape)
98 |
99 | def render(self, scene):
100 | """Renders a scene to an image.
101 |
102 | Args:
103 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width).
104 | """
105 | features_3d = self.transform_3d(scene)
106 | features_2d = self.projection(features_3d)
107 | return torch.sigmoid(self.transform_2d(features_2d))
108 |
109 | def inverse_render(self, img):
110 | """Maps an image to a (spherical) scene representation.
111 |
112 | Args:
113 | img (torch.Tensor): Shape (batch_size, channels, height, width).
114 | """
115 | # Transform image to 2D features
116 | features_2d = self.inv_transform_2d(img)
117 | # Perform inverse projection
118 | features_3d = self.inv_projection(features_2d)
119 | # Map 3D features to scene representation
120 | scene = self.inv_transform_3d(features_3d)
121 | # Ensure scene is spherical
122 | return self.spherical_mask(scene)
123 |
124 | def rotate(self, scene, rotation_matrix):
125 | """Rotates scene by rotation matrix.
126 |
127 | Args:
128 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width).
129 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape
130 | (batch_size, 3, 3).
131 | """
132 | return self.rotation_layer(scene, rotation_matrix)
133 |
134 | def rotate_source_to_target(self, scene, azimuth_source, elevation_source,
135 | azimuth_target, elevation_target):
136 | """Assuming the scene is being observed by a camera at
137 | (azimuth_source, elevation_source), rotates scene so camera is observing
138 | it at (azimuth_target, elevation_target).
139 |
140 | Args:
141 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width).
142 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source.
143 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of source.
144 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target.
145 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of target.
146 | """
147 | return self.rotation_layer.rotate_source_to_target(scene,
148 | azimuth_source,
149 | elevation_source,
150 | azimuth_target,
151 | elevation_target)
152 |
153 | def forward(self, batch):
154 | """Given a batch of images and poses, infers scene representations,
155 | rotates them into target poses and renders them into images.
156 |
157 | Args:
158 | batch (dict): A batch of images and poses as returned by
159 | misc.dataloaders.scene_render_dataloader.
160 |
161 | Notes:
162 | This *must* be a batch as returned by the scene render dataloader,
163 | i.e. the batch must be composed of pairs of images of the same
164 | scene. Specifically, the first item in the batch should be an image
165 | of scene A and the second item in the batch should be an image of
166 | scene A observed from a different pose. The third item should be an
167 | image of scene B and the fourth item should be an image of scene B
168 | observed from a different pose (and so on).
169 | """
170 | # Slightly hacky way of extracting the model device: the device on which
171 | # the spherical mask is stored is also the device of the model
172 | device = self.spherical_mask.mask.device
173 | imgs = batch["img"].to(device)
174 | params = batch["render_params"]
175 | azimuth = params["azimuth"].to(device)
176 | elevation = params["elevation"].to(device)
177 |
178 | # Infer scenes from images
179 | scenes = self.inverse_render(imgs)
180 |
181 | # Rotate scenes so that for every pair of rendered images, the 1st
182 | # one will be reconstructed as the 2nd and the 2nd will be
183 | # reconstructed as the 1st
184 | swapped_idx = get_swapped_indices(azimuth.shape[0])
185 |
186 | # Each pair of indices in the azimuth vector corresponds to the same
187 | # scene at two different angles. Therefore, after a pairwise swap, the
188 | # first index corresponds to the second index in the original
189 | # vector. Since we want to rotate camera angle 1 to camera angle 2 and
190 | # vice versa, we can use these swapped angles to define a target
191 | # position for the camera
192 | azimuth_swapped = azimuth[swapped_idx]
193 | elevation_swapped = elevation[swapped_idx]
194 | scenes_swapped = \
195 | self.rotate_source_to_target(scenes, azimuth, elevation,
196 | azimuth_swapped, elevation_swapped)
197 |
198 | # Swap scenes so that the rotated scenes match the originally inferred scenes.
199 | # Specifically, we have images x1, x2 from which we inferred the scenes
200 | # z1, z2. We then rotated these scenes into z1' and z2'. Now z1' should
201 | # be almost equal to z2 and z2' should be almost equal to z1, so we swap
202 | # the order of z1', z2' to z2', z1' so we can easily render them to
203 | # x1 and x2.
204 | scenes_rotated = scenes_swapped[swapped_idx]
205 |
206 | # Render scene using model
207 | rendered = self.render(scenes_rotated)
208 |
209 | return imgs, rendered, scenes, scenes_rotated
210 |
211 | def print_model_info(self):
212 | """Prints detailed information about model, such as how input shape is
213 | transformed to output shape and how many parameters are trained in each
214 | block.
215 | """
216 | print("Forward renderer")
217 | print("----------------\n")
218 | pretty_print_layers_info(self.transform_3d, "3D Layers")
219 | print("\n")
220 | pretty_print_layers_info(self.projection, "Projection")
221 | print("\n")
222 | pretty_print_layers_info(self.transform_2d, "2D Layers")
223 | print("\n")
224 |
225 | print("Inverse renderer")
226 | print("----------------\n")
227 | pretty_print_layers_info(self.inv_transform_2d, "Inverse 2D Layers")
228 | print("\n")
229 | pretty_print_layers_info(self.inv_projection, "Inverse Projection")
230 | print("\n")
231 | pretty_print_layers_info(self.inv_transform_3d, "Inverse 3D Layers")
232 | print("\n")
233 |
234 | print("Scene Representation:")
235 | print("\tShape: {}".format(self.scene_shape))
236 | # Size of scene representation corresponds to non-zero entries of
237 | # spherical mask
238 | print("\tSize: {}\n".format(int(self.spherical_mask.mask.sum().item())))
239 |
240 | print("Number of parameters: {}\n".format(count_parameters(self)))
241 |
242 | def get_model_config(self):
243 | """Returns the complete model configuration as a dict."""
244 | return {
245 | "img_shape": self.img_shape,
246 | "channels_2d": self.channels_2d,
247 | "strides_2d": self.strides_2d,
248 | "channels_3d": self.channels_3d,
249 | "strides_3d": self.strides_3d,
250 | "num_channels_inv_projection": self.num_channels_inv_projection,
251 | "num_channels_projection": self.num_channels_projection,
252 | "mode": self.mode
253 | }
254 |
255 | def save(self, filename):
256 | """Saves model and its config.
257 |
258 | Args:
259 | filename (string): Path where model will be saved. Should end with
260 | '.pt' or '.pth'.
261 | """
262 | torch.save({
263 | "config": self.get_model_config(),
264 | "state_dict": self.state_dict()
265 | }, filename)
266 |
267 |
268 | def load_model(filename):
269 | """Loads a NeuralRenderer model from saved model config and weights.
270 |
271 | Args:
272 | filename (string): Path where model was saved.
273 | """
274 | model_dict = torch.load(filename, map_location="cpu")
275 | config = model_dict["config"]
276 | # Initialize a model based on config
277 | model = NeuralRenderer(
278 | img_shape=config["img_shape"],
279 | channels_2d=config["channels_2d"],
280 | strides_2d=config["strides_2d"],
281 | channels_3d=config["channels_3d"],
282 | strides_3d=config["strides_3d"],
283 | num_channels_inv_projection=config["num_channels_inv_projection"],
284 | num_channels_projection=config["num_channels_projection"],
285 | mode=config["mode"]
286 | )
287 | # Load weights into model
288 | model.load_state_dict(model_dict["state_dict"])
289 | return model
290 |
291 |
292 | def get_swapped_indices(length):
293 | """Returns a list of swapped index pairs. For example, if length = 6, then
294 | function returns [1, 0, 3, 2, 5, 4], i.e. every index pair is swapped.
295 |
296 | Args:
297 | length (int): Length of swapped indices.
298 | """
299 | return [i + 1 if i % 2 == 0 else i - 1 for i in range(length)]
300 |
--------------------------------------------------------------------------------
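A hedged end-to-end sketch of the NeuralRenderer API above (inverse render, rotate, render). The channel and stride lists are deliberately tiny placeholders chosen so the shapes line up; the published configuration lives in config.json, and the pre-trained chairs model can be restored with load_model("trained-models/chairs.pt").

    import torch
    from models.neural_renderer import NeuralRenderer, get_swapped_indices

    # Tiny placeholder configuration (not the published one from config.json)
    model = NeuralRenderer(img_shape=(3, 32, 32),
                           channels_2d=(32, 32, 64), strides_2d=(1, 1, 2),
                           channels_3d=(8, 16), strides_3d=(1, 2),
                           num_channels_inv_projection=(64, 128),
                           num_channels_projection=(128, 64),
                           mode='bilinear')

    imgs = torch.rand(4, 3, 32, 32)              # dummy images in [0, 1]
    scenes = model.inverse_render(imgs)          # (4, 16, 8, 8, 8) scene representations
    rotated = model.rotate_source_to_target(scenes,
                                            azimuth_source=torch.zeros(4),
                                            elevation_source=torch.zeros(4),
                                            azimuth_target=torch.full((4,), 90.),
                                            elevation_target=torch.full((4,), 30.))
    rendered = model.render(rotated)             # (4, 3, 32, 32) images in [0, 1]

    # Pairwise swap used by forward() to match each view with its partner view
    print(get_swapped_indices(6))                # [1, 0, 3, 2, 5, 4]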
/models/rotation_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transforms3d.rotations import rotate, rotate_source_to_target
4 |
5 |
6 | class Rotate3d(nn.Module):
7 | """Layer used to rotate 3D feature maps.
8 |
9 | Args:
10 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode
11 | used when resampling rotated values on the grid.
12 | """
13 | def __init__(self, mode='bilinear'):
14 | super(Rotate3d, self).__init__()
15 | self.mode = mode
16 |
17 | def forward(self, volume, rotation_matrix):
18 | """Rotates the volume by the rotation matrix.
19 |
20 | Args:
21 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width).
22 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape
23 | (batch_size, 3, 3).
24 | """
25 | return rotate(volume, rotation_matrix, mode=self.mode)
26 |
27 | def rotate_source_to_target(self, volume, azimuth_source, elevation_source,
28 | azimuth_target, elevation_target):
29 | """Rotates volume from source coordinate frame to target coordinate
30 | frame.
31 |
32 | Args:
33 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width).
34 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of
35 | source view in degrees.
36 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of
37 | source view in degrees.
38 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of
39 | target view in degrees.
40 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of
41 | target view in degrees.
42 | """
43 | return rotate_source_to_target(volume, azimuth_source, elevation_source,
44 | azimuth_target, elevation_target,
45 | mode=self.mode)
46 |
47 |
48 | class SphericalMask(nn.Module):
49 | """Sets all features outside the largest sphere embedded in a cubic tensor
50 | to zero.
51 |
52 | Args:
53 | input_shape (tuple of ints): Shape of 3D feature map. Should have the
54 | form (channels, depth, height, width).
55 | radius_fraction (float): Fraction of the radius to keep non-zero. E.g. if
56 | radius_fraction=0.9, only elements within a sphere of radius 0.9 times
57 | half the cube side length are kept non-zero. Must be in [0., 1.].
58 | """
59 | def __init__(self, input_shape, radius_fraction=1.):
60 | super(SphericalMask, self).__init__()
61 | # Check input
62 | _, depth, height, width = input_shape
63 | assert depth == height, "Depth, height, width are {}, {}, {} but must be equal.".format(depth, height, width)
64 | assert height == width, "Depth, height, width are {}, {}, {} but must be equal.".format(depth, height, width)
65 |
66 | self.input_shape = input_shape
67 |
68 | # Build spherical mask
69 | mask = torch.ones(input_shape)
70 | mask_center = (depth - 1) / 2 # Center of cube (in terms of index)
71 | radius = (depth - 1) / 2 # Distance from center to edge of cube is radius of sphere
72 | for i in range(depth):
73 | for j in range(height):
74 | for k in range(width):
75 | squared_distance = (mask_center - i) ** 2 + (mask_center - j) ** 2 + (mask_center - k) ** 2
76 | if squared_distance > (radius_fraction * radius) ** 2:
77 | mask[:, i, j, k] = 0.
78 |
79 | # Register buffer adds a key to the state dict of the model. This will
80 | # track the attribute without registering it as a learnable parameter.
81 | # This also means mask will be moved to device when calling
82 | # model.to(device)
83 | self.register_buffer('mask', mask)
84 |
85 | def forward(self, volume):
86 | """Applies a spherical mask to input.
87 |
88 | Args:
89 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width).
90 | """
91 | return volume * self.mask
92 |
--------------------------------------------------------------------------------
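A small sketch of the two layers above: mask a random feature volume to its inscribed sphere, then rotate it about the y-axis. The shapes and angles are illustrative placeholders; rotation_matrix_y comes from transforms3d/conversions.py.

    import torch
    from models.rotation_layers import Rotate3d, SphericalMask
    from transforms3d.conversions import rotation_matrix_y

    volume = torch.rand(2, 4, 16, 16, 16)   # (batch, channels, depth, height, width)
    mask = SphericalMask(volume.shape[1:])  # zero out everything outside the inscribed sphere
    masked = mask(volume)

    rotation = Rotate3d(mode='bilinear')
    angles = torch.tensor([45., 90.])       # one rotation angle per batch element, in degrees
    rotated = rotation(masked, rotation_matrix_y(angles))
    print(rotated.shape)                    # torch.Size([2, 4, 16, 16, 16])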
/models/submodels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.layers import ResBlock2d, ResBlock3d, num_channels_to_num_groups
4 |
5 |
6 | class ResNet2d(nn.Module):
7 | """ResNets for 2d inputs.
8 |
9 | Args:
10 | input_shape (tuple of ints): Shape of the input to the model. Should be
11 | of the form (channels, height, width).
12 | channels (tuple of ints): List of number of channels for each layer.
13 | Length of this tuple corresponds to number of layers in network.
14 | strides (tuple of ints): List of strides for each layer. Length of this
15 | tuple corresponds to number of layers in network. If stride is 1, a
16 | residual layer is applied. If stride is 2 a convolution with stride
17 | 2 is applied. If stride is -2 a transpose convolution with stride 2
18 | is applied.
19 | final_conv_channels (int): If not 0, a convolution is added as the final
20 | layer, with the number of output channels specified by this int.
21 | filter_multipliers (tuple of ints): Multipliers for filters in residual
22 | layers.
23 | add_groupnorm (bool): If True, adds GroupNorm layers.
24 |
25 |
26 | Notes:
27 | The first layer of this model is a standard convolution to increase the
28 | number of filters. A convolution can optionally be added at the final
29 | layer.
30 | """
31 | def __init__(self, input_shape, channels, strides, final_conv_channels=0,
32 | filter_multipliers=(1, 1), add_groupnorm=True):
33 | super(ResNet2d, self).__init__()
34 | assert len(channels) == len(strides), "Length of channels tuple is {} and length of strides tuple is {} but " \
35 | "they should be equal".format(len(channels), len(strides))
36 | self.input_shape = input_shape
37 | self.channels = channels
38 | self.strides = strides
39 | self.filter_multipliers = filter_multipliers
40 | self.add_groupnorm = add_groupnorm
41 |
42 | # Calculate output_shape:
43 | # Every layer with stride 2 divides the height and width by 2.
44 | # Similarly, every layer with stride -2 multiplies the height and width
45 | # by 2
46 | output_channels, output_height, output_width = input_shape
47 |
48 | for stride in strides:
49 | if stride == 1:
50 | pass
51 | elif stride == 2:
52 | output_height //= 2
53 | output_width //= 2
54 | elif stride == -2:
55 | output_height *= 2
56 | output_width *= 2
57 |
58 | self.output_shape = (channels[-1], output_height, output_width)
59 |
60 | # Build layers
61 | # First layer to increase number of channels before applying residual
62 | # layers
63 | forward_layers = [
64 | nn.Conv2d(self.input_shape[0], channels[0], kernel_size=1,
65 | stride=1, padding=0)
66 | ]
67 | in_channels = channels[0]
68 | multiplier1x1, multiplier3x3 = filter_multipliers
69 | for out_channels, stride in zip(channels, strides):
70 | if stride == 1:
71 | forward_layers.append(
72 | ResBlock2d(in_channels,
73 | [out_channels * multiplier1x1, out_channels * multiplier3x3],
74 | add_groupnorm=add_groupnorm)
75 | )
76 | if stride == 2:
77 | forward_layers.append(
78 | nn.Conv2d(in_channels, out_channels, kernel_size=4,
79 | stride=2, padding=1)
80 | )
81 | if stride == -2:
82 | forward_layers.append(
83 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
84 | stride=2, padding=1)
85 | )
86 |
87 | # Add non-linearity
88 | if stride == 2 or stride == -2:
89 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels))
90 | forward_layers.append(nn.LeakyReLU(0.2, True))
91 |
92 | in_channels = out_channels
93 |
94 | if final_conv_channels:
95 | forward_layers.append(
96 | nn.Conv2d(in_channels, final_conv_channels, kernel_size=1,
97 | stride=1, padding=0)
98 | )
99 |
100 | self.forward_layers = nn.Sequential(*forward_layers)
101 |
102 | def forward(self, inputs):
103 | """Applies ResNet to image-like features.
104 |
105 | Args:
106 | inputs (torch.Tensor): Image-like tensor, with shape (batch_size,
107 | channels, height, width).
108 | """
109 | return self.forward_layers(inputs)
110 |
111 |
112 | class ResNet3d(nn.Module):
113 | """ResNets for 3d inputs.
114 |
115 | Args:
116 | input_shape (tuple of ints): Shape of the input to the model. Should be
117 | of the form (channels, depth, height, width).
118 | channels (tuple of ints): List of number of channels for each layer.
119 | Length of this tuple corresponds to number of layers in network.
120 | Note that this corresponds to number of *output* channels for each
121 | convolutional layer.
122 | strides (tuple of ints): List of strides for each layer. Length of this
123 | tuple corresponds to number of layers in network. If stride is 1, a
124 | residual layer is applied. If stride is 2 a convolution with stride
125 | 2 is applied. If stride is -2 a transpose convolution with stride 2
126 | is applied.
127 | final_conv_channels (int): If not 0, a convolution is added as the final
128 | layer, with the number of output channels specified by this int.
129 | filter_multipliers (tuple of ints): Multipliers for filters in residual
130 | layers.
131 | add_groupnorm (bool): If True, adds GroupNorm layers.
132 |
133 | Notes:
134 | The first layer of this model is a standard convolution to increase the
135 | number of filters. A convolution can optionally be added at the final
136 | layer.
137 | """
138 | def __init__(self, input_shape, channels, strides, final_conv_channels=0,
139 | filter_multipliers=(1, 1), add_groupnorm=True):
140 | super(ResNet3d, self).__init__()
141 | assert len(channels) == len(strides), "Length of channels tuple is {} and length of strides tuple is {} but they should be equal".format(len(channels), len(strides))
142 | self.input_shape = input_shape
143 | self.channels = channels
144 | self.strides = strides
145 | self.filter_multipliers = filter_multipliers
146 | self.add_groupnorm = add_groupnorm
147 |
148 | # Calculate output_shape
149 | output_channels, output_depth, output_height, output_width = input_shape
150 |
151 | for stride in strides:
152 | if stride == 1:
153 | pass
154 | elif stride == 2:
155 | output_depth //= 2
156 | output_height //= 2
157 | output_width //= 2
158 | elif stride == -2:
159 | output_depth *= 2
160 | output_height *= 2
161 | output_width *= 2
162 |
163 | self.output_shape = (channels[-1], output_depth, output_height, output_width)
164 |
165 | # Build layers
166 | # First layer to increase number of channels before applying residual
167 | # layers
168 | forward_layers = [
169 | nn.Conv3d(self.input_shape[0], channels[0], kernel_size=1,
170 | stride=1, padding=0)
171 | ]
172 | in_channels = channels[0]
173 | multiplier1x1, multiplier3x3 = filter_multipliers
174 | for out_channels, stride in zip(channels, strides):
175 | if stride == 1:
176 | forward_layers.append(
177 | ResBlock3d(in_channels,
178 | [out_channels * multiplier1x1, out_channels * multiplier3x3],
179 | add_groupnorm=add_groupnorm)
180 | )
181 | if stride == 2:
182 | forward_layers.append(
183 | nn.Conv3d(in_channels, out_channels, kernel_size=4,
184 | stride=2, padding=1)
185 | )
186 | if stride == -2:
187 | forward_layers.append(
188 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4,
189 | stride=2, padding=1)
190 | )
191 |
192 | # Add non-linearity
193 | if stride == 2 or stride == -2:
194 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels))
195 | forward_layers.append(nn.LeakyReLU(0.2, True))
196 |
197 | in_channels = out_channels
198 |
199 | if final_conv_channels:
200 | forward_layers.append(
201 | nn.Conv3d(in_channels, final_conv_channels, kernel_size=1,
202 | stride=1, padding=0)
203 | )
204 |
205 | self.forward_layers = nn.Sequential(*forward_layers)
206 |
207 | def forward(self, inputs):
208 | """Applies ResNet to 3D features.
209 |
210 | Args:
211 | inputs (torch.Tensor): Tensor, with shape (batch_size, channels,
212 | depth, height, width).
213 | """
214 | return self.forward_layers(inputs)
215 |
216 |
217 | class Projection(nn.Module):
218 | """Performs a projection from a 3D voxel-like feature map to a 2D image-like
219 | feature map.
220 |
221 | Args:
222 | input_shape (tuple of ints): Shape of 3D input, (channels, depth,
223 | height, width).
224 | num_channels (tuple of ints): Number of channels in each layer of the
225 | projection unit.
226 |
227 | Notes:
228 | This layer is inspired by the Projection Unit from
229 | https://arxiv.org/abs/1806.06575.
230 | """
231 | def __init__(self, input_shape, num_channels):
232 | super(Projection, self).__init__()
233 | self.input_shape = input_shape
234 | self.num_channels = num_channels
235 | self.output_shape = (num_channels[-1],) + input_shape[2:]
236 | # Number of input channels for first 2D convolution is
237 | # channels * depth since we flatten the 3D input
238 | in_channels = self.input_shape[0] * self.input_shape[1]
239 | # Initialize forward pass layers
240 | forward_layers = []
241 | num_layers = len(num_channels)
242 | for i in range(num_layers):
243 | out_channels = num_channels[i]
244 | forward_layers.append(
245 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
246 | )
247 | # Add non-linearities, except for the last layer
248 | if i != num_layers - 1:
249 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels))
250 | forward_layers.append(nn.LeakyReLU(0.2, True))
251 | in_channels = out_channels
252 | # Set up forward layers as model
253 | self.forward_layers = nn.Sequential(*forward_layers)
254 |
255 | def forward(self, inputs):
256 | """Reshapes inputs from 3D -> 2D and applies 1x1 convolutions.
257 |
258 | Args:
259 | inputs (torch.Tensor): Voxel like tensor, with shape (batch_size,
260 | channels, depth, height, width).
261 | """
262 | batch_size, channels, depth, height, width = inputs.shape
263 | # Reshape 3D -> 2D
264 | reshaped = inputs.view(batch_size, channels * depth, height, width)
265 | # 1x1 conv layers
266 | return self.forward_layers(reshaped)
267 |
268 |
269 | class InverseProjection(nn.Module):
270 | """Performs an inverse projection from a 2D feature map to a 3D feature map.
271 |
272 | Args:
273 | input_shape (tuple of ints): Shape of 2D input, (channels, height, width).
274 | num_channels (tuple of ints): Number of channels in each layer of the
275 | projection unit.
276 |
277 | Note:
278 | The depth will be equal to the height and width of the input map.
279 | Therefore, the final number of channels must be divisible by the height
280 | and width of the input.
281 | """
282 | def __init__(self, input_shape, num_channels):
283 | super(InverseProjection, self).__init__()
284 | self.input_shape = input_shape
285 | self.num_channels = num_channels
286 | assert num_channels[-1] % input_shape[-1] == 0, "Number of output channels is {} which is not divisible by " \
287 | "width {} of image".format(num_channels[-1], input_shape[-1])
288 | self.output_shape = (num_channels[-1] // input_shape[-1], input_shape[-1]) + input_shape[1:]
289 |
290 | # Initialize forward pass layers
291 | in_channels = self.input_shape[0]
292 | forward_layers = []
293 | num_layers = len(num_channels)
294 | for i in range(num_layers):
295 | out_channels = num_channels[i]
296 | forward_layers.append(
297 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
298 | padding=0)
299 | )
300 | # Add non-linearities, except for the last layer
301 | if i != num_layers - 1:
302 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels))
303 | forward_layers.append(nn.LeakyReLU(0.2, True))
304 | in_channels = out_channels
305 | # Set up forward layers as model
306 | self.forward_layers = nn.Sequential(*forward_layers)
307 |
308 | def forward(self, inputs):
309 | """Applies convolutions and reshapes outputs from 2D -> 3D.
310 |
311 | Args:
312 | inputs (torch.Tensor): Image like tensor, with shape (batch_size,
313 | channels, height, width).
314 | """
315 | # 1x1 conv layers
316 | features = self.forward_layers(inputs)
317 | # Reshape 2D -> 3D
318 | batch_size = inputs.shape[0]
319 | return features.view(batch_size, *self.output_shape)
320 |
--------------------------------------------------------------------------------
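A shape-only sketch of the submodels above. The channel counts are placeholders chosen so that the InverseProjection divisibility constraint (final channel count divisible by the input width) holds; they are not the values from config.json.

    import torch
    from models.submodels import ResNet2d, Projection, InverseProjection

    resnet = ResNet2d((3, 32, 32), channels=(32, 64), strides=(1, 2))
    features_2d = resnet(torch.rand(2, 3, 32, 32))
    print(resnet.output_shape, features_2d.shape)     # (64, 16, 16), torch.Size([2, 64, 16, 16])

    inv_proj = InverseProjection((64, 16, 16), (128, 256))  # 256 is divisible by the width 16
    features_3d = inv_proj(features_2d)
    print(inv_proj.output_shape, features_3d.shape)   # (16, 16, 16, 16), torch.Size([2, 16, 16, 16, 16])

    proj = Projection(inv_proj.output_shape, (128, 64))
    back_to_2d = proj(features_3d)
    print(proj.output_shape, back_to_2d.shape)        # (64, 16, 16), torch.Size([2, 64, 16, 16])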
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | imageio
3 | Pillow
4 | torch==1.4.0
5 | torchvision
6 | pytorch-msssim
7 |
--------------------------------------------------------------------------------
/trained-models/chairs.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/trained-models/chairs.pt
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/training/__init__.py
--------------------------------------------------------------------------------
/training/training.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import torch.nn as nn
4 | from models.neural_renderer import get_swapped_indices
5 | from pytorch_msssim import SSIM
6 | from torchvision.utils import save_image
7 |
8 |
9 | class Trainer():
10 | """Class used to train neural renderers.
11 |
12 | Args:
13 | device (torch.device): Device to train model on.
14 | model (models.neural_renderer.NeuralRenderer): Model to train.
15 | lr (float): Learning rate.
16 | rendering_loss_type (string): One of 'l1', 'l2'.
17 | ssim_loss_weight (float): Weight assigned to SSIM loss.
18 | """
19 | def __init__(self, device, model, lr=2e-4, rendering_loss_type='l1',
20 | ssim_loss_weight=0.05):
21 | self.device = device
22 | self.model = model
23 | self.lr = lr
24 | self.rendering_loss_type = rendering_loss_type
25 | self.ssim_loss_weight = ssim_loss_weight
26 | self.use_ssim = self.ssim_loss_weight != 0
27 | # If False doesn't save losses in loss history
28 | self.register_losses = True
29 | # Check if model is multi-gpu
30 | self.multi_gpu = isinstance(self.model, nn.DataParallel)
31 |
32 | # Initialize optimizer
33 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
34 |
35 | # Initialize loss functions
36 | # For rendered images
37 | if self.rendering_loss_type == 'l1':
38 | self.loss_func = nn.L1Loss()
39 | elif self.rendering_loss_type == 'l2':
40 | self.loss_func = nn.MSELoss()
41 |
42 | # For SSIM
43 | if self.use_ssim:
44 | self.ssim_loss_func = SSIM(data_range=1.0, size_average=True,
45 | channel=3, nonnegative_ssim=False)
46 |
47 | # Loss histories
48 | self.recorded_losses = ["total", "regression", "ssim"]
49 | self.loss_history = {loss_type: [] for loss_type in self.recorded_losses}
50 | self.epoch_loss_history = {loss_type: [] for loss_type in self.recorded_losses}
51 | self.val_loss_history = {loss_type: [] for loss_type in self.recorded_losses}
52 |
53 | def train(self, dataloader, epochs, save_dir=None, save_freq=1,
54 | test_dataloader=None):
55 | """Trains a neural renderer model on the given dataloader.
56 |
57 | Args:
58 | dataloader (torch.utils.DataLoader): Dataloader for a
59 | misc.dataloaders.SceneRenderDataset instance.
60 | epochs (int): Number of epochs to train for.
61 | save_dir (string or None): If not None, saves model and generated
62 | images to directory described by save_dir. Note that this
63 | directory should already exist.
64 | save_freq (int): Frequency with which to save model.
65 | test_dataloader (torch.utils.DataLoader or None): If not None, will
66 | test model on this dataset after every epoch.
67 | """
68 | if save_dir is not None:
69 | # Extract one batch of data
70 | for batch in dataloader:
71 | break
72 | # Save original images
73 | save_image(batch["img"], save_dir + "/imgs_ground_truth.png", nrow=4)
74 | # Store batch to check how rendered images improve during training
75 | self.fixed_batch = batch
76 | # Render images before any training
77 | rendered = self._render_fixed_img()
78 | save_image(rendered.cpu(),
79 | save_dir + "/imgs_gen_{}.png".format(str(0).zfill(3)), nrow=4)
80 |
81 | for epoch in range(epochs):
82 | print("\nEpoch {}".format(epoch + 1))
83 | self._train_epoch(dataloader)
84 | # Update epoch loss history with mean loss over epoch
85 | for loss_type in self.recorded_losses:
86 | self.epoch_loss_history[loss_type].append(
87 | sum(self.loss_history[loss_type][-len(dataloader):]) / len(dataloader)
88 | )
89 | # Print epoch losses
90 | print("Mean epoch loss:")
91 | self._print_losses(epoch_loss=True)
92 |
93 | # Optionally save generated images, losses and model
94 | if save_dir is not None:
95 | # Save generated images
96 | rendered = self._render_fixed_img()
97 | save_image(rendered.cpu(),
98 | save_dir + "/imgs_gen_{}.png".format(str(epoch + 1).zfill(3)), nrow=4)
99 | # Save losses
100 | with open(save_dir + '/loss_history.json', 'w') as loss_file:
101 | json.dump(self.loss_history, loss_file)
102 | # Save epoch losses
103 | with open(save_dir + '/epoch_loss_history.json', 'w') as loss_file:
104 | json.dump(self.epoch_loss_history, loss_file)
105 | # Save model
106 | if (epoch + 1) % save_freq == 0:
107 | if self.multi_gpu:
108 | self.model.module.save(save_dir + "/model.pt")
109 | else:
110 | self.model.save(save_dir + "/model.pt")
111 |
112 | if test_dataloader is not None:
113 | regression_loss, ssim_loss, total_loss = mean_dataset_loss(self, test_dataloader)
114 | print("Validation:\nRegression: {:.4f}, SSIM: {:.4f}, Total: {:.4f}".format(regression_loss, ssim_loss, total_loss))
115 | self.val_loss_history["regression"].append(regression_loss)
116 | self.val_loss_history["ssim"].append(ssim_loss)
117 | self.val_loss_history["total"].append(total_loss)
118 | if save_dir is not None:
119 | # Save validation losses
120 | with open(save_dir + '/val_loss_history.json', 'w') as loss_file:
121 | json.dump(self.val_loss_history, loss_file)
122 | # If current validation loss is the lowest, save model as best
123 | # model
124 | if min(self.val_loss_history["total"]) == total_loss:
125 | print("New best model!")
126 | if self.multi_gpu:
127 | self.model.module.save(save_dir + "/best_model.pt")
128 | else:
129 | self.model.save(save_dir + "/best_model.pt")
130 |
131 | # Save model after training
132 | if save_dir is not None:
133 | if self.multi_gpu:
134 | self.model.module.save(save_dir + "/model.pt")
135 | else:
136 | self.model.save(save_dir + "/model.pt")
137 |
138 | def _train_epoch(self, dataloader):
139 | """Trains model for a single epoch.
140 |
141 | Args:
142 | dataloader (torch.utils.DataLoader): Dataloader for a
143 | misc.dataloaders.SceneRenderDataset instance.
144 | """
145 | num_iterations = len(dataloader)
146 | for i, batch in enumerate(dataloader):
147 | # Train inverse and forward renderer on batch
148 | self._train_iteration(batch)
149 |
150 | # Print iteration losses
151 | print("{}/{}".format(i + 1, num_iterations))
152 | self._print_losses()
153 |
154 | def _train_iteration(self, batch):
155 | """Trains model for a single iteration.
156 |
157 | Args:
158 | batch (dict): Batch of data as returned by a Dataloader for a
159 | misc.dataloaders.SceneRenderDataset instance.
160 | """
161 | imgs, rendered, scenes, scenes_rotated = self.model(batch)
162 | self._optimizer_step(imgs, rendered)
163 |
164 | def _optimizer_step(self, imgs, rendered):
165 | """Updates weights of neural renderer.
166 |
167 | Args:
168 | imgs (torch.Tensor): Ground truth images. Shape
169 | (batch_size, channels, height, width).
170 | rendered (torch.Tensor): Rendered images. Shape
171 | (batch_size, channels, height, width).
172 | """
173 | self.optimizer.zero_grad()
174 |
175 | loss_regression = self.loss_func(rendered, imgs)
176 | if self.use_ssim:
177 | # We want to maximize SSIM, i.e. minimize -SSIM
178 | loss_ssim = 1. - self.ssim_loss_func(rendered, imgs)
179 | loss_total = loss_regression + self.ssim_loss_weight * loss_ssim
180 | else:
181 | loss_total = loss_regression
182 |
183 | loss_total.backward()
184 | self.optimizer.step()
185 |
186 | # Record total loss
187 | if self.register_losses:
188 | self.loss_history["total"].append(loss_total.item())
189 | self.loss_history["regression"].append(loss_regression.item())
190 | # If SSIM is not used, register 0 in logs
191 | if not self.use_ssim:
192 | self.loss_history["ssim"].append(0.)
193 | else:
194 | self.loss_history["ssim"].append(loss_ssim.item())
195 |
196 | def _render_fixed_img(self):
197 | """Reconstructs fixed batch through neural renderer (by inferring
198 | scenes, rotating them and rerendering).
199 | """
200 | _, rendered, _, _ = self.model(self.fixed_batch)
201 | return rendered
202 |
203 | def _print_losses(self, epoch_loss=False):
204 | """Prints most recent losses."""
205 | loss_info = []
206 | for loss_type in self.recorded_losses:
207 | if epoch_loss:
208 | loss = self.epoch_loss_history[loss_type][-1]
209 | else:
210 | loss = self.loss_history[loss_type][-1]
211 | loss_info += [loss_type, loss]
212 | print("{}: {:.3f}, {}: {:.3f}, {}: {:.3f}".format(*loss_info))
213 |
214 |
215 | def mean_dataset_loss(trainer, dataloader):
216 | """Returns the mean loss of a model across a dataloader.
217 |
218 | Args:
219 | trainer (training.Trainer): Trainer instance containing model to
220 | evaluate.
221 | dataloader (torch.utils.DataLoader): Dataloader for a
222 | misc.dataloaders.SceneRenderDataset instance.
223 | """
224 | # No need to calculate gradients during evaluation, so disable gradients to
225 | # increase performance and reduce memory footprint
226 | with torch.no_grad():
227 | # Ensure calculated losses aren't registered as training losses
228 | trainer.register_losses = False
229 |
230 | regression_loss = 0.
231 | ssim_loss = 0.
232 | total_loss = 0.
233 | for i, batch in enumerate(dataloader):
234 | imgs, rendered, scenes, scenes_rotated = trainer.model(batch)
235 |
236 | # Update losses
237 | # Use loss_func directly here since we only want the regression term
238 | current_regression_loss = trainer.loss_func(rendered, imgs).item()
239 | if trainer.use_ssim:
240 | current_ssim_loss = 1. - trainer.ssim_loss_func(rendered, imgs).item()
241 | else:
242 | current_ssim_loss = 0.
243 | regression_loss += current_regression_loss
244 | ssim_loss += current_ssim_loss
245 | total_loss += current_regression_loss + trainer.ssim_loss_weight * current_ssim_loss
246 |
247 | # Average losses over dataset
248 | regression_loss /= len(dataloader)
249 | ssim_loss /= len(dataloader)
250 | total_loss /= len(dataloader)
251 |
252 | # Reset boolean so we register losses if we continue training
253 | trainer.register_losses = True
254 |
255 | return regression_loss, ssim_loss, total_loss
256 |
--------------------------------------------------------------------------------
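A hedged, self-contained training sketch that exercises the Trainer on synthetic data. RandomSceneDataset below is a hypothetical stand-in that only mimics the batch structure the model reads (an "img" tensor plus "render_params" with "azimuth" and "elevation"); real experiments use misc.dataloaders.scene_render_dataloader and the configuration in config.json. The tiny NeuralRenderer configuration is likewise a placeholder.

    import torch
    from torch.utils.data import DataLoader, Dataset
    from models.neural_renderer import NeuralRenderer
    from training.training import Trainer, mean_dataset_loss

    class RandomSceneDataset(Dataset):
        """Hypothetical stand-in mimicking the batch structure of SceneRenderDataset."""
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"img": torch.rand(3, 32, 32),
                    "render_params": {"azimuth": torch.tensor(45. * idx),
                                      "elevation": torch.tensor(10.)}}

    model = NeuralRenderer(img_shape=(3, 32, 32),
                           channels_2d=(32, 32, 64), strides_2d=(1, 1, 2),
                           channels_3d=(8, 16), strides_3d=(1, 2),
                           num_channels_inv_projection=(64, 128),
                           num_channels_projection=(128, 64))
    trainer = Trainer(torch.device("cpu"), model, lr=2e-4,
                      rendering_loss_type='l1', ssim_loss_weight=0.05)
    # Batch size should be even: the model consumes images in view pairs
    dataloader = DataLoader(RandomSceneDataset(), batch_size=4, shuffle=False)
    trainer.train(dataloader, epochs=1)             # prints iteration and epoch losses
    print(mean_dataset_loss(trainer, dataloader))   # (regression, ssim, total)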
/transforms3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/transforms3d/__init__.py
--------------------------------------------------------------------------------
/transforms3d/conversions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import pi
3 |
4 |
5 | def deg2rad(angles):
6 | return angles * pi / 180.
7 |
8 |
9 | def rad2deg(angles):
10 | return angles * 180. / pi
11 |
12 |
13 | def rotation_matrix_y(angle):
14 | """Returns rotation matrix about y-axis.
15 |
16 | Args:
17 | angle (torch.Tensor): Rotation angle in degrees. Shape (batch_size,).
18 | """
19 | # Initialize rotation matrix
20 | rotation_matrix = torch.zeros(angle.shape[0], 3, 3, device=angle.device)
21 | # Fill out matrix entries
22 | angle_rad = deg2rad(angle)
23 | cos_angle = torch.cos(angle_rad)
24 | sin_angle = torch.sin(angle_rad)
25 | rotation_matrix[:, 0, 0] = cos_angle
26 | rotation_matrix[:, 0, 2] = sin_angle
27 | rotation_matrix[:, 1, 1] = 1.
28 | rotation_matrix[:, 2, 0] = -sin_angle
29 | rotation_matrix[:, 2, 2] = cos_angle
30 | return rotation_matrix
31 |
32 |
33 | def rotation_matrix_z(angle):
34 | """Returns rotation matrix about z-axis.
35 |
36 | Args:
37 | angle (torch.Tensor): Rotation angle in degrees. Shape (batch_size,).
38 | """
39 | # Initialize rotation matrix
40 | rotation_matrix = torch.zeros(angle.shape[0], 3, 3, device=angle.device)
41 | # Fill out matrix entries
42 | angle_rad = deg2rad(angle)
43 | cos_angle = torch.cos(angle_rad)
44 | sin_angle = torch.sin(angle_rad)
45 | rotation_matrix[:, 0, 0] = cos_angle
46 | rotation_matrix[:, 0, 1] = -sin_angle
47 | rotation_matrix[:, 1, 0] = sin_angle
48 | rotation_matrix[:, 1, 1] = cos_angle
49 | rotation_matrix[:, 2, 2] = 1.
50 | return rotation_matrix
51 |
52 |
53 | def azimuth_elevation_to_rotation_matrix(azimuth, elevation):
54 | """Returns rotation matrix matching the default view (i.e. both azimuth and
55 | elevation are zero) to the view defined by the azimuth, elevation pair.
56 |
57 |
58 | Args:
59 | azimuth (torch.Tensor): Shape (batch_size,). Azimuth of camera in
60 | degrees.
61 | elevation (torch.Tensor): Shape (batch_size,). Elevation of camera in
62 | degrees.
63 |
64 | Notes:
65 | The azimuth and elevation refer to the position of the camera. This
66 | function returns the rotation of the *scene representation*, i.e. the
67 | inverse of the camera transformation.
68 | """
69 | # In the coordinate system we define (see README), azimuth rotation
70 | # corresponds to negative rotation about y axis and elevation rotation to a
71 | # negative rotation about z axis
72 | azimuth_matrix = rotation_matrix_y(-azimuth)
73 | elevation_matrix = rotation_matrix_z(-elevation)
74 | # We first perform elevation rotation followed by azimuth when rotating camera
75 | camera_matrix = azimuth_matrix @ elevation_matrix
76 | # Object rotation matrix is inverse (i.e. transpose) of camera rotation matrix
77 | return transpose_matrix(camera_matrix)
78 |
79 |
80 | def rotation_matrix_source_to_target(azimuth_source, elevation_source,
81 | azimuth_target, elevation_target):
82 | """Returns rotation matrix matching two views defined by azimuth, elevation
83 | pairs.
84 |
85 | Args:
86 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source
87 | view in degrees.
88 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of
89 | source view in degrees.
90 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target
91 | view in degrees.
92 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of
93 | target view in degrees.
94 | """
95 | # Calculate rotation matrix for each view
96 | rotation_source = azimuth_elevation_to_rotation_matrix(azimuth_source, elevation_source)
97 | rotation_target = azimuth_elevation_to_rotation_matrix(azimuth_target, elevation_target)
98 | # Calculate rotation matrix bringing source view to target view (note that
99 | # for rotation matrix, inverse is transpose)
100 | return rotation_target @ transpose_matrix(rotation_source)
101 |
102 |
103 | def transpose_matrix(matrix):
104 | """Transposes a batch of matrices.
105 |
106 | Args:
107 | matrix (torch.Tensor): Batch of matrices of shape (batch_size, n, m).
108 | """
109 | return matrix.transpose(1, 2)
110 |
--------------------------------------------------------------------------------
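A quick numeric check of the conversion helpers above; the angles are arbitrary. It verifies that the returned matrices are orthogonal and that mapping a view onto itself yields (numerically) the identity.

    import torch
    from transforms3d.conversions import (azimuth_elevation_to_rotation_matrix,
                                          rotation_matrix_source_to_target,
                                          transpose_matrix)

    azimuth = torch.tensor([90.])
    elevation = torch.tensor([30.])
    R = azimuth_elevation_to_rotation_matrix(azimuth, elevation)  # shape (1, 3, 3)

    # Rotation matrices are orthogonal: R @ R^T is the identity (up to float error)
    identity = torch.eye(3).unsqueeze(0)
    print(torch.allclose(R @ transpose_matrix(R), identity, atol=1e-6))  # expect True

    # Mapping a view onto itself gives the identity rotation
    R_same = rotation_matrix_source_to_target(azimuth, elevation, azimuth, elevation)
    print(torch.allclose(R_same, identity, atol=1e-6))                   # expect True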
/transforms3d/rotations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transforms3d.conversions import (rotation_matrix_source_to_target,
3 | transpose_matrix)
4 |
5 |
6 | def rotate(volume, rotation_matrix, mode='bilinear'):
7 | """Performs 3D rotation of tensor volume by rotation matrix.
8 |
9 | Args:
10 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width).
11 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape
12 | (batch_size, 3, 3).
13 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode
14 | used in grid_sample. Note that the 'bilinear' option actually
15 | performs trilinear interpolation.
16 |
17 | Notes:
18 | We use align_corners=False in grid_sample. See
19 | https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9
20 | for a nice illustration of why this is.
21 | """
22 | # The grid_sample function performs the inverse transformation of the input
23 | # coordinates, so invert matrix to get forward transformation
24 | inverse_rotation_matrix = transpose_matrix(rotation_matrix)
25 | # The grid_sample function swaps x and z (i.e. it assumes the tensor
26 | # dimensions are ordered as z, y, x), therefore we need to flip the rows and
27 | # columns of the matrix (which we can verify is equivalent to multiplying by
28 | # the appropriate permutation matrices)
29 | inverse_rotation_matrix_swap_xz = torch.flip(inverse_rotation_matrix,
30 | dims=(1, 2))
31 | # Apply transformation to grid
32 | affine_grid = get_affine_grid(inverse_rotation_matrix_swap_xz, volume.shape)
33 | # Regrid volume according to transformation grid
34 | return torch.nn.functional.grid_sample(volume, affine_grid, mode=mode,
35 | align_corners=False)
36 |
37 |
38 | def get_affine_grid(matrix, grid_shape):
39 | """Given a matrix and a grid shape, returns the grid transformed by the
40 | matrix (typically a rotation matrix).
41 |
42 | Args:
43 | matrix (torch.Tensor): Batch of matrices of size (batch_size, 3, 3).
44 | grid_shape (torch.size): Shape of returned affine grid. Should be of the
45 | form (batch_size, channels, depth, height, width).
46 |
47 | Notes:
48 | We use align_corners=False in affine_grid. See
49 | https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9
50 | for a nice illustration of why this is.
51 | """
52 | batch_size = matrix.shape[0]
53 | # Last column of affine matrix corresponds to translation which is 0 in our
54 | # case. Therefore pad original matrix with zeros, so shape changes from
55 | # (batch_size, 3, 3) to (batch_size, 3, 4)
56 | translations = torch.zeros(batch_size, 3, 1, device=matrix.device)
57 | affine_matrix = torch.cat([matrix, translations], dim=2)
58 | return torch.nn.functional.affine_grid(affine_matrix, grid_shape,
59 | align_corners=False)
60 |
61 |
62 | def rotate_source_to_target(volume, azimuth_source, elevation_source,
63 | azimuth_target, elevation_target, mode='bilinear'):
64 | """Performs 3D rotation matching two coordinate frames defined by a source
65 | view and a target view.
66 |
67 | Args:
68 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width).
69 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source
70 | view in degrees.
71 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of
72 | source view in degrees.
73 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target
74 | view in degrees.
75 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of
76 | target view in degrees.
77 | """
78 | rotation_matrix = rotation_matrix_source_to_target(azimuth_source,
79 | elevation_source,
80 | azimuth_target,
81 | elevation_target)
82 | return rotate(volume, rotation_matrix, mode=mode)
83 |
--------------------------------------------------------------------------------
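A final sketch of the grid-sample based rotation above: rotating by zero degrees resamples the volume at its own voxel centers (so the input is reproduced up to interpolation rounding), and a source-to-target rotation keeps the tensor shape. The volume and angles are illustrative.

    import torch
    from transforms3d.conversions import rotation_matrix_y
    from transforms3d.rotations import rotate, rotate_source_to_target

    volume = torch.rand(1, 4, 16, 16, 16)    # (batch, channels, depth, height, width)

    # Rotating by 0 degrees should reproduce the volume up to interpolation rounding
    same = rotate(volume, rotation_matrix_y(torch.tensor([0.])))
    print(torch.allclose(same, volume, atol=1e-5))   # expect True

    # Rotate from a source camera pose to a target camera pose
    rotated = rotate_source_to_target(volume,
                                      azimuth_source=torch.tensor([0.]),
                                      elevation_source=torch.tensor([0.]),
                                      azimuth_target=torch.tensor([90.]),
                                      elevation_target=torch.tensor([30.]))
    print(rotated.shape)                             # torch.Size([1, 4, 16, 16, 16])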