├── CODE_OF_CONDUCT.md ├── CONTRIBUTING-RESEARCH.md ├── CONTRIBUTING.md ├── LICENSE-APPLE-SAMPLE-CODE ├── README-DATA.md ├── README.md ├── config.json ├── evaluate_psnr.py ├── experiments.py ├── exploration.ipynb ├── imgs ├── all_datasets.gif ├── coordinate-system.png └── example-data │ ├── chair1.png │ ├── chair2.png │ ├── chair3.png │ └── chair4.png ├── misc ├── __init__.py ├── dataloaders.py ├── quantitative_evaluation.py ├── utils.py └── viz.py ├── models ├── __init__.py ├── layers.py ├── neural_renderer.py ├── rotation_layers.py └── submodels.py ├── requirements.txt ├── trained-models └── chairs.pt ├── training ├── __init__.py └── training.py └── transforms3d ├── __init__.py ├── conversions.py └── rotations.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING-RESEARCH.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | ## Before you get started 6 | 7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE-APPLE-SAMPLE-CODE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2020 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software.
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /README-DATA.md: -------------------------------------------------------------------------------- 1 | 2 | README - ENR Data 3 | 4 | The datasets accompanying the ICML 2020 paper Equivariant Neural Rendering are licensed as follows: 5 | 6 | 7 | 1) The cars, chairs, and mugs-hq images are rendered from ShapeNet models (http://www.shapenet.org/); they are licensed under the ShapeNet license - https://www.shapenet.org/terms 8 | 9 | 2) The mountains images are licensed under the CC-BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/legalcode). 10 | Satellite imagery © 2020 Maxar Technologies 11 | 12 | 13 | Additionally, we thank Bernhard Vogl (salzamt@dativ.at) for the environment map used to render the mugs-hq images. This environment map and others can be found at http://dativ.at/lightprobes. 14 | 15 | If you find this code useful in your research, consider citing with: 16 | 17 | 18 | @article{dupont2020equivariant, 19 | title={Equivariant Neural Rendering}, 20 | author={Dupont, Emilien and Miguel Angel, Bautista and Colburn, Alex and Sankar, Aditya and Guestrin, Carlos and Susskind, Josh and Shan, Qi}, 21 | journal={arXiv preprint arXiv:2006.07630}, 22 | year={2020} 23 | } 24 | 25 | 26 | 27 | ShapeNet Terms of Use 28 | 29 | After registering for a ShapeNet account, you will be considered for account approval and privilege elevation by an administrator.
After approval, you (the "Researcher") receive permission to use the *ShapeNet database* (the "Database") *at Princeton University and Stanford University*. In exchange for being able to join the ShapeNet community and receive such permission, Researcher hereby agrees to the following terms and conditions: 30 | 31 | 1. Researcher shall use the Database only for non-commercial research and educational purposes. 32 | 2. Princeton University and Stanford University make no representations or warranties regarding the Database, including but not limited to warranties of non-infringement or fitness for a particular purpose. 33 | 3. Researcher accepts full responsibility for his or her use of the Database and shall defend and indemnify Princeton University and Stanford University, including their employees, Trustees, officers and agents, against any and all claims arising from Researcher's use of the Database, including but not limited to Researcher's use of any copies of copyrighted 3D models that he or she may create from the Database. 34 | 4. Researcher may provide research associates and colleagues with access to the Database provided that they first agree to be bound by these terms and conditions. 35 | 5. Princeton University and Stanford University reserve the right to terminate Researcher's access to the Database at any time. 36 | 6. If Researcher is employed by a for-profit, commercial entity, Researcher's employer shall also be bound by these terms and conditions, and Researcher hereby represents that he or she is fully authorized to enter into this agreement on behalf of such employer. 37 | 7. The law of the State of New Jersey shall apply to all disputes under this agreement. 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Equivariant Neural Rendering 2 | 3 | This repo contains code to reproduce all experiments in [Equivariant Neural Rendering](https://arxiv.org/abs/2006.07630) by [E. Dupont](https://emiliendupont.github.io), [M. A. Bautista](https://scholar.google.com/citations?user=ZrRs-qoAAAAJ&hl=en), [A. Colburn](https://www.colburn.org), [A. Sankar](https://scholar.google.com/citations?user=6ZDIdEAAAAAJ&hl=en), [C. Guestrin](https://homes.cs.washington.edu/~guestrin/), [J. Susskind](https://scholar.google.com/citations?user=Sv2TGqsAAAAJ&hl=en), [Q. Shan](http://shanqi.github.io), ICML 2020. 4 | 5 | 6 | 7 | 8 | 9 | ### Pre-trained models 10 | 11 | The weights for the trained chairs model are provided in `trained-models/chairs.pt`. 12 | 13 | The other pre-trained models are located at https://icml20-prod.cdn-apple.com/eqn-data/models/pre-trained_models.zip. They should be downloaded and placed into the `trained-models` directory. A small model, `chairs.pt`, is included in the git repo. 14 | 15 | ## Examples 16 | 17 | 18 | 19 | ## Requirements 20 | 21 | The requirements can be installed directly from PyPI with `pip install -r requirements.txt`. Running the code requires `python3.6` or higher.
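Once the requirements are installed, a quick way to sanity-check the setup is to load the bundled chairs model from a Python shell. This is a minimal sketch, assuming it is run from the repository root (see `evaluate_psnr.py` for the full evaluation pipeline):

```
from models.neural_renderer import load_model

# Load the small pre-trained chairs model shipped with the repo
model = load_model("trained-models/chairs.pt")
# Print layer and parameter information for the model
model.print_model_info()
```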
22 | 23 | ## Datasets 24 | 25 | - ShapeNet chairs: https://icml20-prod.cdn-apple.com/eqn-data/data/chairs.zip 26 | - ShapeNet cars: https://icml20-prod.cdn-apple.com/eqn-data/data/cars.zip 27 | - MugsHQ: https://icml20-prod.cdn-apple.com/eqn-data/data/mugs.zip 28 | - 3D mountains: https://icml20-prod.cdn-apple.com/eqn-data/data/mountains.zip 29 | 30 | 31 | Each zip file will expand into three separate components and a readme, e.g.: 32 | - `cars-train.zip` 33 | - `cars-val.zip` 34 | - `cars-test.zip` 35 | - `readme.txt` containing the license terms. 36 | 37 | A few example images are provided in `imgs/example-data/`. 38 | 39 | The chairs and cars datasets were created with the help of [Vincent Sitzmann](https://vsitzmann.github.io). 40 | 41 | Satellite imagery © 2020 Maxar Technologies. 42 | 43 | We thank Bernhard Vogl (salzamt@dativ.at) for the lightmaps. The MugsHQ images were rendered using an environment map located at http://dativ.at/lightprobes. 44 | 45 | ## Usage 46 | 47 | ### Training a model 48 | 49 | To train a model, run the following: 50 | 51 | ``` 52 | python experiments.py config.json 53 | ``` 54 | 55 | This supports both single and multi-GPU training (see `config.json` for detailed training options). Note that you need to download the datasets before running this command. 56 | 57 | ### Quantitative evaluation 58 | 59 | To evaluate a model, run the following: 60 | 61 | ``` 62 | python evaluate_psnr.py <path_to_model> <path_to_data> 63 | ``` 64 | 65 | This will measure the performance (in PSNR) of a trained model on a test dataset. 66 | 67 | ### Model exploration and visualization 68 | 69 | The Jupyter notebook `exploration.ipynb` shows how to use a trained model to infer a scene representation from a single image and how to use this representation to render novel views. 70 | 71 | 72 | ## Coordinate system 73 | 74 | The diagram below details the coordinate system we use for the voxel grid. Due to the manner in which images are stored in arrays and the way PyTorch's `affine_grid` and `grid_sample` functions work, this is a slightly unusual coordinate system. Note that `theta` and `phi` correspond to elevation and azimuth rotations of the **camera** around the scene representation. Note also that these are left-handed rotations. Full details of the voxel rotation function can be found in `transforms3d/rotations.py`.
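As a concrete illustration of these conventions, the sketch below infers a scene from a single image and re-renders it after an azimuth rotation of the camera, mirroring what `exploration.ipynb` does. The image path and angles are placeholders (in practice the source pose comes from the scene's `render_params.json`), and the preprocessing mirrors `misc/dataloaders.py`:

```
import torch
from PIL import Image
from torchvision import transforms
from models.neural_renderer import load_model

model = load_model("trained-models/chairs.pt")

# Preprocess a source view the same way the dataloaders do (crop, resize, keep RGB)
transform = transforms.Compose([transforms.CenterCrop(128),
                                transforms.Resize((128, 128)),
                                transforms.ToTensor()])
img = transform(Image.open("imgs/example-data/chair1.png"))[:3].unsqueeze(0)

# Placeholder camera poses in degrees (azimuth, elevation)
azimuth_source, elevation_source = torch.Tensor([0.]), torch.Tensor([30.])
azimuth_target, elevation_target = torch.Tensor([90.]), torch.Tensor([30.])

with torch.no_grad():
    # Infer the scene representation from the source view
    scene = model.inverse_render(img)
    # Rotate the camera from the source pose to the target pose, then render
    rotated = model.rotate_source_to_target(scene, azimuth_source, elevation_source,
                                            azimuth_target, elevation_target)
    novel_view = model.render(rotated)  # Shape (1, 3, 128, 128)
```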
75 | 76 | 77 | 78 | ## Citing 79 | 80 | If you find this code useful in your research, consider citing with: 81 | 82 | ``` 83 | @article{dupont2020equivariant, 84 | title={Equivariant Neural Rendering}, 85 | author={Dupont, Emilien and Miguel Angel, Bautista and Colburn, Alex and Sankar, Aditya and Guestrin, Carlos and Susskind, Josh and Shan, Qi}, 86 | journal={arXiv preprint arXiv:2006.07630}, 87 | year={2020} 88 | } 89 | ``` 90 | 91 | ## License 92 | 93 | This project is licensed under the Apple Sample Code License (see [LICENSE-APPLE-SAMPLE-CODE](LICENSE-APPLE-SAMPLE-CODE)). 94 | 95 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "chairs-experiment", 3 | "path_to_data": "chairs-train", 4 | "path_to_test_data": "chairs-val", 5 | "multi_gpu": false, 6 | "batch_size": 16, 7 | "lr": 2e-4, 8 | "epochs": 100, 9 | "loss_type": "l1", 10 | "ssim_loss_weight": 0.05, 11 | "save_freq": 1, 12 | "img_shape": [3, 128, 128], 13 | "channels_2d": [64, 64, 128, 128, 128, 128, 256, 256, 128, 128, 128], 14 | "strides_2d": [1, 1, 2, 1, 2, 1, 2, 1, -2, 1, 1], 15 | "channels_3d": [32, 32, 128, 128, 128, 64, 64, 64], 16 | "strides_3d": [1, 1, 2, 1, 1, -2, 1, 1], 17 | "num_channels_projection": [512, 256, 256], 18 | "num_channels_inv_projection": [256, 512, 1024], 19 | "mode": "bilinear" 20 | } 21 | -------------------------------------------------------------------------------- /evaluate_psnr.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from misc.dataloaders import scene_render_dataset 4 | from misc.quantitative_evaluation import get_dataset_psnr 5 | from models.neural_renderer import load_model 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | # Get path to trained model and test data from command line arguments 10 | if len(sys.argv) != 3: 11 | raise(RuntimeError("Wrong arguments, use python evaluate_psnr.py <model_path> <data_dir>")) 12 | model_path = sys.argv[1] 13 | data_dir = sys.argv[2] # This is usually one of "chairs-test" and "cars-test" 14 | 15 | # Load model 16 | model = load_model(model_path) 17 | model = model.to(device) 18 | 19 | # Initialize dataset 20 | dataset = scene_render_dataset(path_to_data=data_dir, img_size=(3, 128, 128), 21 | crop_size=128, allow_odd_num_imgs=True) 22 | 23 | # Calculate PSNR 24 | with torch.no_grad(): 25 | psnrs = get_dataset_psnr(device, model, dataset, source_img_idx_shift=64, 26 | batch_size=125, max_num_scenes=None) 27 | -------------------------------------------------------------------------------- /experiments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | from misc.dataloaders import scene_render_dataloader 7 | from models.neural_renderer import NeuralRenderer 8 | from training.training import Trainer 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | # Get path to config file from command line arguments 13 | if len(sys.argv) != 2: 14 | raise(RuntimeError("Wrong arguments, use python experiments.py <path_to_config>")) 15 | path_to_config = sys.argv[1] 16 | 17 | # Open config file 18 | with open(path_to_config) as file: 19 | config = json.load(file) 20 | 21 | # Set up directory to store experiments 22 | timestamp = time.strftime("%Y-%m-%d_%H-%M") 23 | directory = "{}_{}".format(timestamp, config["id"]) 24 | if not os.path.exists(directory): 25 | os.makedirs(directory)
26 | 27 | # Save config file in directory 28 | with open(directory + '/config.json', 'w') as file: 29 | json.dump(config, file) 30 | 31 | # Set up renderer 32 | model = NeuralRenderer( 33 | img_shape=config["img_shape"], 34 | channels_2d=config["channels_2d"], 35 | strides_2d=config["strides_2d"], 36 | channels_3d=config["channels_3d"], 37 | strides_3d=config["strides_3d"], 38 | num_channels_inv_projection=config["num_channels_inv_projection"], 39 | num_channels_projection=config["num_channels_projection"], 40 | mode=config["mode"] 41 | ) 42 | 43 | model.print_model_info() 44 | 45 | model = model.to(device) 46 | 47 | if config["multi_gpu"]: 48 | model = torch.nn.DataParallel(model) 49 | 50 | # Set up trainer for renderer 51 | trainer = Trainer(device, model, lr=config["lr"], 52 | rendering_loss_type=config["loss_type"], 53 | ssim_loss_weight=config["ssim_loss_weight"]) 54 | 55 | dataloader = scene_render_dataloader(path_to_data=config["path_to_data"], 56 | batch_size=config["batch_size"], 57 | img_size=config["img_shape"], 58 | crop_size=128) 59 | 60 | # Optionally set up test_dataloader 61 | if config["path_to_test_data"]: 62 | test_dataloader = scene_render_dataloader(path_to_data=config["path_to_test_data"], 63 | batch_size=config["batch_size"], 64 | img_size=config["img_shape"], 65 | crop_size=128) 66 | else: 67 | test_dataloader = None 68 | 69 | print("PID: {}".format(os.getpid())) 70 | 71 | # Train renderer, save generated images, losses and model 72 | trainer.train(dataloader, config["epochs"], save_dir=directory, 73 | save_freq=config["save_freq"], test_dataloader=test_dataloader) 74 | 75 | # Print best losses 76 | print("Model id: {}".format(config["id"])) 77 | print("Best train loss: {:.4f}".format(min(trainer.epoch_loss_history["total"]))) 78 | print("Best validation loss: {:.4f}".format(min(trainer.val_loss_history["total"]))) 79 | -------------------------------------------------------------------------------- /imgs/all_datasets.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/all_datasets.gif -------------------------------------------------------------------------------- /imgs/coordinate-system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/coordinate-system.png -------------------------------------------------------------------------------- /imgs/example-data/chair1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair1.png -------------------------------------------------------------------------------- /imgs/example-data/chair2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair2.png -------------------------------------------------------------------------------- /imgs/example-data/chair3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair3.png 
-------------------------------------------------------------------------------- /imgs/example-data/chair4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/imgs/example-data/chair4.png -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/misc/__init__.py -------------------------------------------------------------------------------- /misc/dataloaders.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import torch 4 | from numpy import float32 as np_float32 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader, Sampler 7 | from torchvision import transforms 8 | 9 | 10 | def scene_render_dataloader(path_to_data='chairs-train', batch_size=16, 11 | img_size=(3, 128, 128), crop_size=128): 12 | """Dataloader for scene render datasets. Returns scene renders in pairs, 13 | i.e. 1st and 2nd images are of some scene, 3rd and 4th are of some different 14 | scene and so on. 15 | 16 | Args: 17 | path_to_data (string): Path to folder containing dataset. 18 | batch_size (int): Batch size for data. 19 | img_size (tuple of ints): Size of output images. 20 | crop_size (int): Size at which to center crop rendered images. 21 | 22 | Notes: 23 | Batch size must be even. 24 | """ 25 | assert batch_size % 2 == 0, "Batch size is {} but must be even".format(batch_size) 26 | 27 | dataset = scene_render_dataset(path_to_data, img_size, crop_size) 28 | 29 | sampler = RandomPairSampler(dataset) 30 | 31 | return DataLoader(dataset, batch_size=batch_size, sampler=sampler, 32 | drop_last=True) 33 | 34 | 35 | def scene_render_dataset(path_to_data='chairs-train', img_size=(3, 128, 128), 36 | crop_size=128, allow_odd_num_imgs=False): 37 | """Helper function for creating a scene render dataset. 38 | 39 | Args: 40 | path_to_data (string): Path to folder containing dataset. 41 | img_size (tuple of ints): Size of output images. 42 | crop_size (int): Size at which to center crop rendered images. 43 | allow_odd_num_imgs (int): If True, allows datasets with an odd number 44 | of views. Such a dataset cannot be used for training, since each 45 | training iteration requires a *pair* of images. Datasets with an odd 46 | number of images are used for PSNR calculations. 47 | """ 48 | img_transform = transforms.Compose([ 49 | transforms.CenterCrop(crop_size), 50 | transforms.Resize(img_size[1:]), 51 | transforms.ToTensor() 52 | ]) 53 | 54 | dataset = SceneRenderDataset(path_to_data=path_to_data, 55 | img_transform=img_transform, 56 | allow_odd_num_imgs=allow_odd_num_imgs) 57 | 58 | return dataset 59 | 60 | 61 | class SceneRenderDataset(Dataset): 62 | """Dataset of rendered scenes and their corresponding camera angles. 63 | 64 | Args: 65 | path_to_data (string): Path to folder containing dataset. 66 | img_transform (torchvision.transform): Transforms to be applied to 67 | images. 68 | allow_odd_num_imgs (bool): If True, allows datasets with an odd number 69 | of views. Such a dataset cannot be used for training, since each 70 | training iteration requires a *pair* of images. 
71 | 72 | Notes: 73 | - Image paths must be of the form "XXXXX.png" where XXXXX are *five* 74 | integers indexing the image. 75 | - We assume there are the same number of rendered images for each scene 76 | and that this number is even. 77 | - We assume angles are given in degrees. 78 | """ 79 | def __init__(self, path_to_data='chairs-train', img_transform=None, 80 | allow_odd_num_imgs=False): 81 | self.path_to_data = path_to_data 82 | self.img_transform = img_transform 83 | self.allow_odd_num_imgs = allow_odd_num_imgs 84 | self.data = [] 85 | # Each folder contains a single scene with different rendering 86 | # parameters and views 87 | self.scene_paths = glob.glob(path_to_data + '/*') 88 | self.scene_paths.sort() # Ensure consistent ordering of scenes 89 | self.num_scenes = len(self.scene_paths) 90 | # Extract number of rendered images per object (which we assume is constant) 91 | self.num_imgs_per_scene = len(glob.glob(self.scene_paths[0] + '/*.png')) 92 | # If number of images per scene is not even, drop last image 93 | if self.num_imgs_per_scene % 2 != 0: 94 | if not self.allow_odd_num_imgs: 95 | self.num_imgs_per_scene -= 1 96 | # For each scene, extract its rendered views and render parameters 97 | for scene_path in self.scene_paths: 98 | # Name of folder defines scene name 99 | scene_name = scene_path.split('/')[-1] 100 | 101 | # Load render parameters 102 | with open(scene_path + '/render_params.json') as f: 103 | render_params = json.load(f) 104 | 105 | # Extract path to rendered images of scene 106 | img_paths = glob.glob(scene_path + '/*.png') 107 | img_paths.sort() # Ensure consistent ordering of images 108 | # Ensure number of image paths is even 109 | img_paths = img_paths[:self.num_imgs_per_scene] 110 | 111 | for img_path in img_paths: 112 | # Extract image filename 113 | img_file = img_path.split('/')[-1] 114 | # Filenames are of the type ".png", so extract this 115 | # index to match with render parameters. 116 | img_idx = img_file.split('.')[0][-5:] # This should be a string 117 | # Convert render parameters to float32 118 | img_params = {key: np_float32(value) 119 | for key, value in render_params[img_idx].items()} 120 | self.data.append({ 121 | "scene_name": scene_name, 122 | "img_path": img_path, 123 | "render_params": img_params 124 | }) 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | def __getitem__(self, idx): 130 | img_path = self.data[idx]["img_path"] 131 | render_params = self.data[idx]["render_params"] 132 | 133 | img = Image.open(img_path) 134 | 135 | # Transform images 136 | if self.img_transform: 137 | img = self.img_transform(img) 138 | 139 | # Note some images may contain 4 channels (i.e. RGB + alpha), we only 140 | # keep RGB channels 141 | data_item = { 142 | "img": img[:3], 143 | "scene_name": self.data[idx]["scene_name"], 144 | "render_params": self.data[idx]["render_params"] 145 | } 146 | 147 | return data_item 148 | 149 | 150 | class RandomPairSampler(Sampler): 151 | """Samples random elements in pairs. Dataset is assumed to be composed of a 152 | number of scenes, each rendered in a number of views. This sampler returns 153 | rendered image in pairs. I.e. for a batch of size 6, it would return e.g.: 154 | 155 | [object 4 - img 5, 156 | object 4 - img 12, 157 | object 6 - img 3, 158 | object 6 - img 19, 159 | object 52 - img 10, 160 | object 52 - img 3] 161 | 162 | 163 | Arguments: 164 | dataset (Dataset): Dataset to sample from. This will typically be an 165 | instance of SceneRenderDataset. 
166 | """ 167 | 168 | def __init__(self, dataset): 169 | self.dataset = dataset 170 | 171 | def __iter__(self): 172 | num_scenes = self.dataset.num_scenes 173 | num_imgs_per_scene = self.dataset.num_imgs_per_scene 174 | 175 | # Sample num_imgs_per_scene / 2 permutations of the objects 176 | scene_permutations = [torch.randperm(num_scenes) for _ in range(num_imgs_per_scene // 2)] 177 | # For each scene, sample a permutation of its images 178 | img_permutations = [torch.randperm(num_imgs_per_scene) for _ in range(num_scenes)] 179 | 180 | data_permutation = [] 181 | 182 | for i, scene_permutation in enumerate(scene_permutations): 183 | for scene_idx in scene_permutation: 184 | # Extract image permutation for this object 185 | img_permutation = img_permutations[scene_idx] 186 | # Add 2 images of this object to data_permutation 187 | data_permutation.append(scene_idx.item() * num_imgs_per_scene + img_permutation[2*i].item()) 188 | data_permutation.append(scene_idx.item() * num_imgs_per_scene + img_permutation[2*i + 1].item()) 189 | 190 | return iter(data_permutation) 191 | 192 | def __len__(self): 193 | return len(self.dataset) 194 | 195 | 196 | def create_batch_from_data_list(data_list): 197 | """Given a list of datapoints, create a batch. 198 | 199 | Args: 200 | data_list (list): List of items returned by SceneRenderDataset. 201 | """ 202 | imgs = [] 203 | azimuths = [] 204 | elevations = [] 205 | for data_item in data_list: 206 | img, render_params = data_item["img"], data_item["render_params"] 207 | azimuth, elevation = render_params["azimuth"], render_params["elevation"] 208 | imgs.append(img.unsqueeze(0)) 209 | azimuths.append(torch.Tensor([azimuth])) 210 | elevations.append(torch.Tensor([elevation])) 211 | imgs = torch.cat(imgs, dim=0) 212 | azimuths = torch.cat(azimuths) 213 | elevations = torch.cat(elevations) 214 | return imgs, azimuths, elevations 215 | -------------------------------------------------------------------------------- /misc/quantitative_evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from misc.dataloaders import create_batch_from_data_list 4 | 5 | 6 | def get_dataset_psnr(device, model, dataset, source_img_idx_shift=64, 7 | batch_size=10, max_num_scenes=None): 8 | """Returns PSNR for each scene in a dataset by comparing the view predicted 9 | by a model and the ground truth view. 10 | 11 | Args: 12 | device (torch.device): Device to perform PSNR calculation on. 13 | model (models.neural_renderer.NeuralRenderer): Model to evaluate. 14 | dataset (misc.dataloaders.SceneRenderDataset): Dataset to evaluate model 15 | performance on. Should be one of "chairs-test" or "cars-test". 16 | source_img_idx_shift (int): Index of source image for each scene. For 17 | example if 00064.png is the source view, then 18 | source_img_idx_shift = 64. 19 | batch_size (int): Batch size to use when generating predictions. This 20 | should be a divisor of the number of images per scene. 21 | max_num_scenes (None or int): Optionally limit the maximum number of 22 | scenes to calculate PSNR for. 23 | 24 | Notes: 25 | This function should be used with the ShapeNet chairs and cars *test* 26 | sets. 
27 | """ 28 | num_imgs_per_scene = dataset.num_imgs_per_scene 29 | # Set number of scenes to calculate 30 | num_scenes = dataset.num_scenes 31 | if max_num_scenes is not None: 32 | num_scenes = min(max_num_scenes, num_scenes) 33 | # Calculate number of batches per scene 34 | assert (num_imgs_per_scene - 1) % batch_size == 0, "Batch size {} must divide number of images per scene {}." 35 | # Comparison are made against all images except the source image (and 36 | # therefore subtract 1 from total number of images) 37 | batches_per_scene = (num_imgs_per_scene - 1) // batch_size 38 | # Initialize psnr values 39 | psnrs = [] 40 | for i in range(num_scenes): 41 | # Extract source view 42 | source_img_idx = i * num_imgs_per_scene + source_img_idx_shift 43 | img_source = dataset[source_img_idx]["img"].unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device) 44 | render_params = dataset[source_img_idx]["render_params"] 45 | azimuth_source = torch.Tensor([render_params["azimuth"]]).repeat(batch_size).to(device) 46 | elevation_source = torch.Tensor([render_params["elevation"]]).repeat(batch_size).to(device) 47 | # Infer source scene 48 | scenes = model.inverse_render(img_source) 49 | 50 | # Iterate over all other views of scene 51 | num_points_in_batch = 0 52 | data_list = [] 53 | scene_psnr = 0. 54 | for j in range(num_imgs_per_scene): 55 | if j == source_img_idx_shift: 56 | continue # Do not compare against same image 57 | # Add new image to list of images we want to compare to 58 | data_list.append(dataset[i * num_imgs_per_scene + j]) 59 | num_points_in_batch += 1 60 | # If we have filled up a batch, make psnr calculation 61 | if num_points_in_batch == batch_size: 62 | # Create batch for target data 63 | img_target, azimuth_target, elevation_target = create_batch_from_data_list(data_list) 64 | img_target = img_target.to(device) 65 | azimuth_target = azimuth_target.to(device) 66 | elevation_target = elevation_target.to(device) 67 | # Rotate scene and render image 68 | rotated = model.rotate_source_to_target(scenes, azimuth_source, 69 | elevation_source, azimuth_target, 70 | elevation_target) 71 | img_predicted = model.render(rotated).detach() 72 | scene_psnr += get_psnr(img_predicted, img_target) 73 | data_list = [] 74 | num_points_in_batch = 0 75 | 76 | psnrs.append(scene_psnr / batches_per_scene) 77 | 78 | print("{}/{}: Current - {:.3f}, Mean - {:.4f}".format(i + 1, 79 | num_scenes, 80 | psnrs[-1], 81 | torch.mean(torch.Tensor(psnrs)))) 82 | 83 | return psnrs 84 | 85 | 86 | def get_psnr(prediction, target): 87 | """Returns PSNR between a batch of predictions and a batch of targets. 88 | 89 | Args: 90 | prediction (torch.Tensor): Shape (batch_size, channels, height, width). 91 | target (torch.Tensor): Shape (batch_size, channels, height, width). 92 | """ 93 | batch_size = prediction.shape[0] 94 | mse_per_pixel = F.mse_loss(prediction, target, reduction='none') 95 | mse_per_img = mse_per_pixel.view(batch_size, -1).mean(dim=1) 96 | psnr = 10 * torch.log10(1 / mse_per_img) 97 | return torch.mean(psnr).item() 98 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import models.layers 2 | import models.submodels 3 | import torch 4 | from math import pi 5 | 6 | 7 | def full_rotation_angle_sequence(num_steps): 8 | """Returns a sequence of angles corresponding to a full 360 degree rotation. 9 | Useful for generating gifs. 
10 | 11 | Args: 12 | num_steps (int): Number of steps in sequence. 13 | """ 14 | return torch.linspace(0., 360. - 360. / num_steps, num_steps) 15 | 16 | 17 | def constant_angle_sequence(num_steps, value=0.): 18 | """Returns a sequence of constant angles. Useful for generating gifs. 19 | 20 | Args: 21 | num_steps (int): Number of steps in sequence. 22 | value (float): Constant angle value. 23 | """ 24 | return value * torch.ones(num_steps) 25 | 26 | 27 | def back_and_forth_angle_sequence(num_steps, start, end): 28 | """Returns a sequence of angles linearly increasing from start to end and 29 | back. 30 | 31 | Args: 32 | num_steps (int): Number of steps in sequence. 33 | start (float): Angle at which to start (in degrees). 34 | end (float): Angle at which to end (in degrees). 35 | """ 36 | half_num_steps = int(num_steps / 2) 37 | # Increase angle from start to end 38 | first = torch.linspace(start, end - end / half_num_steps, half_num_steps) 39 | # Decrease angle from end to start 40 | second = torch.linspace(end, start - start / half_num_steps, half_num_steps) 41 | # Return combined sequence of increasing and decreasing angles 42 | return torch.cat([first, second], dim=0) 43 | 44 | 45 | def sinusoidal_angle_sequence(num_steps, minimum, maximum): 46 | """Returns a sequence of angles sinusoidally varying between minimum and 47 | maximum. 48 | 49 | Args: 50 | num_steps (int): Number of steps in sequence. 51 | start (float): Angle at which to start (in degrees). 52 | end (float): Angle at which to end (in degrees). 53 | """ 54 | period = 2 * pi * torch.linspace(0., 1. - 1. / num_steps, num_steps) 55 | return .5 * (minimum + maximum + (maximum - minimum) * torch.sin(period)) 56 | 57 | 58 | def sine_squared_angle_sequence(num_steps, start, end): 59 | """Returns a sequence of angles increasing from start to end and back as the 60 | sine squared function. 61 | 62 | Args: 63 | num_steps (int): Number of steps in sequence. 64 | start (float): Angle at which to start (in degrees). 65 | end (float): Angle at which to end (in degrees). 66 | """ 67 | half_period = pi * torch.linspace(0., 1. - 1. / num_steps, num_steps) 68 | return start + (end - start) * torch.sin(half_period) ** 2 69 | 70 | 71 | def count_parameters(model): 72 | """Returns number of trainable parameters in a model.""" 73 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | 75 | 76 | def get_layers_info(model): 77 | """Returns information about input shapes, output shapes and number of 78 | parameters in every block of model. 79 | 80 | Args: 81 | model (torch.nn.Module): Model to analyse. This will typically be a 82 | submodel of models.neural_renderer.NeuralRenderer. 
83 | """ 84 | in_shape = model.input_shape 85 | layers_info = [] 86 | 87 | if isinstance(model, models.submodels.Projection): 88 | out_shape = (in_shape[0] * in_shape[1], *in_shape[2:]) 89 | layer_info = {"name": "Reshape", "in_shape": in_shape, 90 | "out_shape": out_shape, "num_params": 0} 91 | layers_info.append(layer_info) 92 | in_shape = out_shape 93 | 94 | for layer in model.forward_layers: 95 | if isinstance(layer, torch.nn.Conv2d): 96 | if layer.stride[0] == 1: 97 | out_shape = (layer.out_channels, *in_shape[1:]) 98 | elif layer.stride[0] == 2: 99 | out_shape = (layer.out_channels, in_shape[1] // 2, in_shape[2] // 2) 100 | name = "Conv2D" 101 | elif isinstance(layer, torch.nn.ConvTranspose2d): 102 | if layer.stride[0] == 1: 103 | out_shape = (layer.out_channels, *in_shape[1:]) 104 | elif layer.stride[0] == 2: 105 | out_shape = (layer.out_channels, in_shape[1] * 2, in_shape[2] * 2) 106 | name = "ConvTr2D" 107 | elif isinstance(layer, models.layers.ResBlock2d): 108 | out_shape = in_shape 109 | name = "ResBlock2D" 110 | elif isinstance(layer, torch.nn.Conv3d): 111 | if layer.stride[0] == 1: 112 | out_shape = (layer.out_channels, *in_shape[1:]) 113 | elif layer.stride[0] == 2: 114 | out_shape = (layer.out_channels, in_shape[1] // 2, in_shape[2] // 2, in_shape[3] // 2) 115 | name = "Conv3D" 116 | elif isinstance(layer, torch.nn.ConvTranspose3d): 117 | if layer.stride[0] == 1: 118 | out_shape = (layer.out_channels, *in_shape[1:]) 119 | elif layer.stride[0] == 2: 120 | out_shape = (layer.out_channels, in_shape[1] * 2, in_shape[2] * 2, in_shape[3] * 2) 121 | name = "ConvTr3D" 122 | elif isinstance(layer, models.layers.ResBlock3d): 123 | out_shape = in_shape 124 | name = "ResBlock3D" 125 | else: 126 | # If layer is just an activation layer, skip 127 | continue 128 | 129 | num_params = count_parameters(layer) 130 | layer_info = {"name": name, "in_shape": in_shape, 131 | "out_shape": out_shape, "num_params": num_params} 132 | layers_info.append(layer_info) 133 | 134 | in_shape = out_shape 135 | 136 | if isinstance(model, models.submodels.InverseProjection): 137 | layer_info = {"name": "Reshape", "in_shape": in_shape, 138 | "out_shape": model.output_shape, "num_params": 0} 139 | layers_info.append(layer_info) 140 | 141 | return layers_info 142 | 143 | 144 | def pretty_print_layers_info(model, title): 145 | """Prints information about a model. 146 | 147 | Args: 148 | model (see get_layers_info) 149 | title (string): Title of model. 
150 | """ 151 | # Extract layers info for model 152 | layers_info = get_layers_info(model) 153 | # Print information in a nice format 154 | print(title) 155 | print("-" * len(title)) 156 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format("name", "in_shape", "out_shape", "num_params", "feat_size")) 157 | print("---------------------------------------------------------------------------------------------") 158 | 159 | min_feat_size = 2 ** 20 # Some huge number 160 | for info in layers_info: 161 | feat_size = tuple_product(info["out_shape"]) 162 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format(info["name"], 163 | str(info["in_shape"]), 164 | str(info["out_shape"]), 165 | info["num_params"], 166 | feat_size)) 167 | if feat_size < min_feat_size: 168 | min_feat_size = feat_size 169 | print("---------------------------------------------------------------------------------------------") 170 | # Only print model info if model is not empty 171 | if len(layers_info): 172 | print("{: <12} \t {: <14} \t {: <14} \t {: <10} \t {: <10}".format("Total", 173 | str(layers_info[0]["in_shape"]), 174 | str(layers_info[-1]["out_shape"]), 175 | count_parameters(model), 176 | min_feat_size)) 177 | 178 | 179 | def tuple_product(input_tuple): 180 | """Returns product of elements in a tuple.""" 181 | product = 1 182 | for elem in input_tuple: 183 | product *= elem 184 | return product 185 | -------------------------------------------------------------------------------- /misc/viz.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import torch 3 | import torchvision 4 | from misc.dataloaders import create_batch_from_data_list 5 | 6 | 7 | def generate_novel_views(model, img_source, azimuth_source, elevation_source, 8 | azimuth_shifts, elevation_shifts): 9 | """Generates novel views of an image by inferring its scene representation, 10 | rotating it and rendering novel views. Returns a batch of images 11 | corresponding to the novel views. 12 | 13 | Args: 14 | model (models.neural_renderer.NeuralRenderer): Neural rendering model. 15 | img_source (torch.Tensor): Single image. Shape (channels, height, width). 16 | azimuth_source (torch.Tensor): Azimuth of source image. Shape (1,). 17 | elevation_source (torch.Tensor): Elevation of source image. Shape (1,). 18 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to 19 | generate novel views. Shape (num_views,). 20 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to 21 | generate novel views. Shape (num_views,). 
22 | """ 23 | # No need to calculate gradients 24 | with torch.no_grad(): 25 | num_views = len(azimuth_shifts) 26 | # Batchify image 27 | img_batch = img_source.unsqueeze(0) 28 | # Infer scene 29 | scenes = model.inverse_render(img_batch) 30 | # Copy scene for each target view 31 | scenes_batch = scenes.repeat(num_views, 1, 1, 1, 1) 32 | # Batchify azimuth and elevation source 33 | azimuth_source_batch = azimuth_source.repeat(num_views) 34 | elevation_source_batch = elevation_source.repeat(num_views) 35 | # Calculate azimuth and elevation targets 36 | azimuth_target = azimuth_source_batch + azimuth_shifts 37 | elevation_target = elevation_source_batch + elevation_shifts 38 | # Rotate scenes 39 | rotated = model.rotate_source_to_target(scenes_batch, azimuth_source_batch, 40 | elevation_source_batch, 41 | azimuth_target, elevation_target) 42 | # Render images 43 | return model.render(rotated).detach() 44 | 45 | 46 | def batch_generate_novel_views(model, imgs_source, azimuth_source, 47 | elevation_source, azimuth_shifts, 48 | elevation_shifts): 49 | """Generates novel views for a batch of images. Returns a list of batches of 50 | images, where each item in the list corresponds to a novel view for all 51 | images. 52 | 53 | Args: 54 | model (models.neural_renderer.NeuralRenderer): Neural rendering model. 55 | imgs_source (torch.Tensor): Source images. Shape (batch_size, channels, height, width). 56 | azimuth_source (torch.Tensor): Azimuth of source. Shape (batch_size,). 57 | elevation_source (torch.Tensor): Elevation of source. Shape (batch_size,). 58 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to generate 59 | novel views. Shape (num_views,). 60 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to 61 | generate novel views. Shape (num_views,). 62 | """ 63 | num_imgs = imgs_source.shape[0] 64 | num_views = azimuth_shifts.shape[0] 65 | 66 | # Initialize novel views, i.e. a list of length num_views with each item 67 | # containing num_imgs images 68 | all_novel_views = [torch.zeros_like(imgs_source) for _ in range(num_views)] 69 | 70 | for i in range(num_imgs): 71 | # Generate novel views for single image 72 | novel_views = generate_novel_views(model, imgs_source[i], 73 | azimuth_source[i:i+1], 74 | elevation_source[i:i+1], 75 | azimuth_shifts, elevation_shifts).cpu() 76 | # Add to list of all novel_views 77 | for j in range(num_views): 78 | all_novel_views[j][i] = novel_views[j] 79 | 80 | return all_novel_views 81 | 82 | 83 | def dataset_novel_views(device, model, dataset, img_indices, azimuth_shifts, 84 | elevation_shifts): 85 | """Helper function for generating novel views from specific images in a 86 | dataset. 87 | 88 | Args: 89 | device (torch.device): 90 | model (models.neural_renderer.NeuralRenderer): 91 | dataset (misc.dataloaders.SceneRenderDataset): 92 | img_indices (tuple of ints): Indices of images in dataset to use as 93 | source views for novel view synthesis. 94 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to generate 95 | novel views. Shape (num_views,). 96 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to 97 | generate novel views. Shape (num_views,). 
98 | """ 99 | # Extract image and pose information for all views 100 | data_list = [] 101 | for img_idx in img_indices: 102 | data_list.append(dataset[img_idx]) 103 | imgs_source, azimuth_source, elevation_source = create_batch_from_data_list(data_list) 104 | imgs_source = imgs_source.to(device) 105 | azimuth_source = azimuth_source.to(device) 106 | elevation_source = elevation_source.to(device) 107 | # Generate novel views 108 | return batch_generate_novel_views(model, imgs_source, azimuth_source, 109 | elevation_source, azimuth_shifts, 110 | elevation_shifts) 111 | 112 | 113 | def shapenet_test_novel_views(device, model, dataset, source_scenes_idx=(0, 1, 2, 3), 114 | source_img_idx_shift=64, subsample_target=5): 115 | """Helper function for generating novel views on an archimedean spiral for 116 | the test images for ShapeNet chairs and cars. 117 | 118 | Args: 119 | device (torch.device): 120 | model (models.neural_renderer.NeuralRenderer): 121 | dataset (misc.dataloaders.SceneRenderDataset): Test dataloader for a 122 | ShapeNet dataset. 123 | source_scenes_idx (tuple of ints): Indices of source scenes to use for 124 | generating novel views. 125 | source_img_idx_shift (int): Index of source image for each scene. For 126 | example if 00064.png is the source view, then 127 | source_img_idx_shift = 64. 128 | subsample_target (int): Amount by which to subsample target views. If 129 | set to 1, uses all 250 target views. 130 | """ 131 | num_imgs = len(source_scenes_idx) 132 | # Extract source azimuths and elevations 133 | # Note we can extract this from the first scene since for the shapenet test 134 | # sets, the test poses are the same for all scenes 135 | render_params = dataset[source_img_idx_shift]["render_params"] 136 | azimuth_source = torch.Tensor([render_params["azimuth"]]).to(device) 137 | elevation_source = torch.Tensor([render_params["elevation"]]).to(device) 138 | 139 | # Extract target azimuths and elevations (do not use final view as it is 140 | # slightly off in dataset) 141 | azimuth_target = torch.zeros(dataset.num_imgs_per_scene - 1) 142 | elevation_target = torch.zeros(dataset.num_imgs_per_scene - 1) 143 | for i in range(dataset.num_imgs_per_scene - 1): 144 | render_params = dataset[i]["render_params"] 145 | azimuth_target[i] = torch.Tensor([render_params["azimuth"]]) 146 | elevation_target[i] = torch.Tensor([render_params["elevation"]]) 147 | # Move to GPU 148 | azimuth_target = azimuth_target.to(device) 149 | elevation_target = elevation_target.to(device) 150 | # Subsample 151 | azimuth_target = azimuth_target[::subsample_target] 152 | elevation_target = elevation_target[::subsample_target] 153 | 154 | # Calculate azimuth and elevation shifts 155 | azimuth_shifts = azimuth_target - azimuth_source 156 | elevation_shifts = elevation_target - elevation_source 157 | 158 | # Ensure source angles have same batch_size as imgs_source 159 | azimuth_source = azimuth_source.repeat(num_imgs) 160 | elevation_source = elevation_source.repeat(num_imgs) 161 | 162 | # Create source image batch 163 | imgs_source = torch.zeros(num_imgs, 3, 128, 128).to(device) 164 | for i in range(num_imgs): 165 | scene_idx = source_scenes_idx[i] 166 | source_img_idx = scene_idx * dataset.num_imgs_per_scene + source_img_idx_shift 167 | imgs_source[i] = dataset[source_img_idx]["img"].to(device) 168 | 169 | return batch_generate_novel_views(model, imgs_source, azimuth_source, 170 | elevation_source, azimuth_shifts, 171 | elevation_shifts) 172 | 173 | 174 | def save_generate_novel_views(filename, model, 
img_source, azimuth_source, 175 | elevation_source, azimuth_shifts, 176 | elevation_shifts): 177 | """Generates novel views of an image by inferring its scene representation, 178 | rotating it and rendering novel views. Saves the source image and novel 179 | views as png files. 180 | 181 | Args: 182 | filename (string): Filename root for saving images. 183 | model (models.neural_renderer.NeuralRenderer): Neural rendering model. 184 | img_source (torch.Tensor): Single image. Shape (channels, height, width). 185 | azimuth_source (torch.Tensor): Azimuth of source image. Shape (1,). 186 | elevation_source (torch.Tensor): Elevation of source image. Shape (1,). 187 | azimuth_shifts (torch.Tensor): Batch of angle shifts at which to 188 | generate novel views. Shape (num_views,). 189 | elevation_shifts (torch.Tensor): Batch of angle shifts at which to 190 | generate novel views. Shape (num_views,). 191 | """ 192 | # Generate novel views 193 | novel_views = generate_novel_views(model, img_source, azimuth_source, 194 | elevation_source, azimuth_shifts, 195 | elevation_shifts) 196 | # Save original image 197 | torchvision.utils.save_image(img_source, filename + '.png', padding=4, 198 | pad_value=1.) 199 | # Save novel views (with white padding) 200 | torchvision.utils.save_image(novel_views, filename + '_novel.png', 201 | padding=4, pad_value=1.) 202 | 203 | 204 | def save_img_sequence_as_gif(img_sequence, filename, nrow=4): 205 | """Given a sequence of images as tensors, saves a gif of the images. 206 | If images are in batches, they are converted to a grid before being 207 | saved. 208 | 209 | Args: 210 | img_sequence (list of torch.Tensor): List of images. Tensors should 211 | have shape either (batch_size, channels, height, width) or shape 212 | (channels, height, width). If there is a batch dimension, images 213 | will be converted to a grid. 214 | filename (string): Path where gif will be saved. Should end in '.gif'. 215 | nrow (int): Number of rows in image grid, if image has batch dimension. 216 | """ 217 | img_grid_sequence = [] 218 | for img in img_sequence: 219 | if len(img.shape) == 4: 220 | img_grid = torchvision.utils.make_grid(img, nrow=nrow) 221 | else: 222 | img_grid = img 223 | # Convert to numpy array and from float in [0, 1] to int in [0, 255] 224 | # which is what imageio expects 225 | img_grid = (img_grid * 255.).byte().cpu().numpy().transpose(1, 2, 0) 226 | img_grid_sequence.append(img_grid) 227 | # Save gif 228 | imageio.mimwrite(filename, img_grid_sequence) 229 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/models/__init__.py -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvBlock2d(nn.Module): 6 | """Block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape of input 7 | and output is the same. 8 | 9 | Args: 10 | in_channels (int): Number of channels in input. 11 | num_filters (list of ints): List of two ints with the number of filters 12 | for the first and second conv layers. Third conv layer must have the 13 | same number of input filters as there are channels. 14 | add_groupnorm (bool): If True adds GroupNorm. 
15 | """ 16 | def __init__(self, in_channels, num_filters, add_groupnorm=True): 17 | super(ConvBlock2d, self).__init__() 18 | if add_groupnorm: 19 | self.forward_layers = nn.Sequential( 20 | nn.GroupNorm(num_channels_to_num_groups(in_channels), in_channels), 21 | nn.LeakyReLU(0.2, True), 22 | nn.Conv2d(in_channels, num_filters[0], kernel_size=1, stride=1, 23 | bias=False), 24 | nn.GroupNorm(num_channels_to_num_groups(num_filters[0]), num_filters[0]), 25 | nn.LeakyReLU(0.2, True), 26 | nn.Conv2d(num_filters[0], num_filters[1], kernel_size=3, 27 | stride=1, padding=1, bias=False), 28 | nn.GroupNorm(num_channels_to_num_groups(num_filters[1]), num_filters[1]), 29 | nn.LeakyReLU(0.2, True), 30 | nn.Conv2d(num_filters[1], in_channels, kernel_size=1, stride=1, 31 | bias=False) 32 | ) 33 | else: 34 | self.forward_layers = nn.Sequential( 35 | nn.LeakyReLU(0.2, True), 36 | nn.Conv2d(in_channels, num_filters[0], kernel_size=1, stride=1, 37 | bias=True), 38 | nn.LeakyReLU(0.2, True), 39 | nn.Conv2d(num_filters[0], num_filters[1], kernel_size=3, 40 | stride=1, padding=1, bias=True), 41 | nn.LeakyReLU(0.2, True), 42 | nn.Conv2d(num_filters[1], in_channels, kernel_size=1, stride=1, 43 | bias=True) 44 | ) 45 | 46 | def forward(self, inputs): 47 | return self.forward_layers(inputs) 48 | 49 | 50 | class ConvBlock3d(nn.Module): 51 | """Block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape of input 52 | and output is the same. 53 | 54 | Args: 55 | in_channels (int): Number of channels in input. 56 | num_filters (list of ints): List of two ints with the number of filters 57 | for the first and second conv layers. Third conv layer must have the 58 | same number of input filters as there are channels. 59 | add_groupnorm (bool): If True adds BatchNorm. 60 | """ 61 | def __init__(self, in_channels, num_filters, add_groupnorm=True): 62 | super(ConvBlock3d, self).__init__() 63 | if add_groupnorm: 64 | self.forward_layers = nn.Sequential( 65 | nn.GroupNorm(num_channels_to_num_groups(in_channels), in_channels), 66 | nn.LeakyReLU(0.2, True), 67 | nn.Conv3d(in_channels, num_filters[0], kernel_size=1, stride=1, 68 | bias=False), 69 | nn.GroupNorm(num_channels_to_num_groups(num_filters[0]), num_filters[0]), 70 | nn.LeakyReLU(0.2, True), 71 | nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3, 72 | stride=1, padding=1, bias=False), 73 | nn.GroupNorm(num_channels_to_num_groups(num_filters[1]), num_filters[1]), 74 | nn.LeakyReLU(0.2, True), 75 | nn.Conv3d(num_filters[1], in_channels, kernel_size=1, stride=1, 76 | bias=False) 77 | ) 78 | else: 79 | self.forward_layers = nn.Sequential( 80 | nn.LeakyReLU(0.2, True), 81 | nn.Conv3d(in_channels, num_filters[0], kernel_size=1, stride=1, 82 | bias=True), 83 | nn.LeakyReLU(0.2, True), 84 | nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3, 85 | stride=1, padding=1, bias=True), 86 | nn.LeakyReLU(0.2, True), 87 | nn.Conv3d(num_filters[1], in_channels, kernel_size=1, stride=1, 88 | bias=True) 89 | ) 90 | 91 | def forward(self, inputs): 92 | return self.forward_layers(inputs) 93 | 94 | 95 | class ResBlock2d(nn.Module): 96 | """Residual block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape 97 | of input and output is the same. 98 | 99 | Args: 100 | in_channels (int): Number of channels in input. 101 | num_filters (list of ints): List of two ints with the number of filters 102 | for the first and second conv layers. Third conv layer must have the 103 | same number of input filters as there are channels. 104 | add_groupnorm (bool): If True adds GroupNorm. 
105 | """ 106 | def __init__(self, in_channels, num_filters, add_groupnorm=True): 107 | super(ResBlock2d, self).__init__() 108 | self.residual_layers = ConvBlock2d(in_channels, num_filters, 109 | add_groupnorm) 110 | 111 | def forward(self, inputs): 112 | return inputs + self.residual_layers(inputs) 113 | 114 | 115 | class ResBlock3d(nn.Module): 116 | """Residual block of 1x1, 3x3, 1x1 convolutions with non linearities. Shape 117 | of input and output is the same. 118 | 119 | Args: 120 | in_channels (int): Number of channels in input. 121 | num_filters (list of ints): List of two ints with the number of filters 122 | for the first and second conv layers. Third conv layer must have the 123 | same number of input filters as there are channels. 124 | add_groupnorm (bool): If True adds GroupNorm. 125 | """ 126 | def __init__(self, in_channels, num_filters, add_groupnorm=True): 127 | super(ResBlock3d, self).__init__() 128 | self.residual_layers = ConvBlock3d(in_channels, num_filters, 129 | add_groupnorm) 130 | 131 | def forward(self, inputs): 132 | return inputs + self.residual_layers(inputs) 133 | 134 | 135 | def num_channels_to_num_groups(num_channels): 136 | """Returns number of groups to use in a GroupNorm layer with a given number 137 | of channels. Note that these choices are hyperparameters. 138 | 139 | Args: 140 | num_channels (int): Number of channels. 141 | """ 142 | if num_channels < 8: 143 | return 1 144 | if num_channels < 32: 145 | return 2 146 | if num_channels < 64: 147 | return 4 148 | if num_channels < 128: 149 | return 8 150 | if num_channels < 256: 151 | return 16 152 | else: 153 | return 32 154 | -------------------------------------------------------------------------------- /models/neural_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from misc.utils import pretty_print_layers_info, count_parameters 4 | from models.submodels import ResNet2d, ResNet3d, Projection, InverseProjection 5 | from models.rotation_layers import SphericalMask, Rotate3d 6 | 7 | 8 | class NeuralRenderer(nn.Module): 9 | """Implements a Neural Renderer with an implicit scene representation that 10 | allows both forward and inverse rendering. 11 | 12 | The forward pass from 3d scene to 2d image is (rendering): 13 | Scene representation (input) -> ResNet3d -> Projection -> ResNet2d -> 14 | Rendered image (output) 15 | 16 | The inverse pass from 2d image to 3d scene is (inverse rendering): 17 | Image (input) -> ResNet2d -> Inverse Projection -> ResNet3d -> Scene 18 | representation (output) 19 | 20 | Args: 21 | img_shape (tuple of ints): Shape of the image input to the model. Should 22 | be of the form (channels, height, width). 23 | channels_2d (tuple of ints): List of channels for 2D layers in inverse 24 | rendering model (image -> scene). 25 | strides_2d (tuple of ints): List of strides for 2D layers in inverse 26 | rendering model (image -> scene). 27 | channels_3d (tuple of ints): List of channels for 3D layers in inverse 28 | rendering model (image -> scene). 29 | strides_3d (tuple of ints): List of channels for 3D layers in inverse 30 | rendering model (image -> scene). 31 | num_channels_inv_projection (tuple of ints): Number of channels in each 32 | layer of inverse projection unit from 2D to 3D. 33 | num_channels_projection (tuple of ints): Number of channels in each 34 | layer of projection unit from 2D to 3D. 
35 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode 36 | used when rotating voxel grid. 37 | 38 | Notes: 39 | Given the inverse rendering channels and strides, the model will 40 | automatically build a forward renderer as the transpose of the inverse 41 | renderer. 42 | """ 43 | def __init__(self, img_shape, channels_2d, strides_2d, channels_3d, 44 | strides_3d, num_channels_inv_projection, num_channels_projection, 45 | mode='bilinear'): 46 | super(NeuralRenderer, self).__init__() 47 | self.img_shape = img_shape 48 | self.channels_2d = channels_2d 49 | self.strides_2d = strides_2d 50 | self.channels_3d = channels_3d 51 | self.strides_3d = strides_3d 52 | self.num_channels_projection = num_channels_projection 53 | self.num_channels_inv_projection = num_channels_inv_projection 54 | self.mode = mode 55 | 56 | # Initialize layers 57 | 58 | # Inverse pass (image -> scene) 59 | # First transform image into a 2D representation 60 | self.inv_transform_2d = ResNet2d(self.img_shape, channels_2d, 61 | strides_2d) 62 | 63 | # Perform inverse projection from 2D to 3D 64 | input_shape = self.inv_transform_2d.output_shape 65 | self.inv_projection = InverseProjection(input_shape, num_channels_inv_projection) 66 | 67 | # Transform 3D inverse projection into a scene representation 68 | self.inv_transform_3d = ResNet3d(self.inv_projection.output_shape, 69 | channels_3d, strides_3d) 70 | # Add rotation layer 71 | self.rotation_layer = Rotate3d(self.mode) 72 | 73 | # Forward pass (scene -> image) 74 | # Forward renderer is just transpose of inverse renderer, so flip order 75 | # of channels and strides 76 | # Transform scene representation to 3D features 77 | forward_channels_3d = list(reversed(channels_3d))[1:] + [channels_3d[0]] 78 | forward_strides_3d = [-stride if abs(stride) == 2 else 1 for stride in list(reversed(strides_3d[1:]))] + [strides_3d[0]] 79 | self.transform_3d = ResNet3d(self.inv_transform_3d.output_shape, 80 | forward_channels_3d, forward_strides_3d) 81 | 82 | # Layer for projection of 3D representation to 2D representation 83 | self.projection = Projection(self.transform_3d.output_shape, 84 | num_channels_projection) 85 | 86 | # Transform 2D features to rendered image 87 | forward_channels_2d = list(reversed(channels_2d))[1:] + [channels_2d[0]] 88 | forward_strides_2d = [-stride if abs(stride) == 2 else 1 for stride in list(reversed(strides_2d[1:]))] + [strides_2d[0]] 89 | final_conv_channels_2d = img_shape[0] 90 | self.transform_2d = ResNet2d(self.projection.output_shape, 91 | forward_channels_2d, forward_strides_2d, 92 | final_conv_channels_2d) 93 | 94 | # Scene representation shape is output of inverse 3D transformation 95 | self.scene_shape = self.inv_transform_3d.output_shape 96 | # Add spherical mask before scene rotation 97 | self.spherical_mask = SphericalMask(self.scene_shape) 98 | 99 | def render(self, scene): 100 | """Renders a scene to an image. 101 | 102 | Args: 103 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width). 104 | """ 105 | features_3d = self.transform_3d(scene) 106 | features_2d = self.projection(features_3d) 107 | return torch.sigmoid(self.transform_2d(features_2d)) 108 | 109 | def inverse_render(self, img): 110 | """Maps an image to a (spherical) scene representation. 111 | 112 | Args: 113 | img (torch.Tensor): Shape (batch_size, channels, height, width). 
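        Example:
            Rough sketch (assumes `model` is a NeuralRenderer instance and
            `imgs` is a batch of images matching the model's img_shape):

                scenes = model.inverse_render(imgs)
                # scenes.shape == (batch_size,) + model.scene_shape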
114 | """ 115 | # Transform image to 2D features 116 | features_2d = self.inv_transform_2d(img) 117 | # Perform inverse projection 118 | features_3d = self.inv_projection(features_2d) 119 | # Map 3D features to scene representation 120 | scene = self.inv_transform_3d(features_3d) 121 | # Ensure scene is spherical 122 | return self.spherical_mask(scene) 123 | 124 | def rotate(self, scene, rotation_matrix): 125 | """Rotates scene by rotation matrix. 126 | 127 | Args: 128 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width). 129 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape 130 | (batch_size, 3, 3). 131 | """ 132 | return self.rotation_layer(scene, rotation_matrix) 133 | 134 | def rotate_source_to_target(self, scene, azimuth_source, elevation_source, 135 | azimuth_target, elevation_target): 136 | """Assuming the scene is being observed by a camera at 137 | (azimuth_source, elevation_source), rotates scene so camera is observing 138 | it at (azimuth_target, elevation_target). 139 | 140 | Args: 141 | scene (torch.Tensor): Shape (batch_size, channels, depth, height, width). 142 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source. 143 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of source. 144 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target. 145 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of target. 146 | """ 147 | return self.rotation_layer.rotate_source_to_target(scene, 148 | azimuth_source, 149 | elevation_source, 150 | azimuth_target, 151 | elevation_target) 152 | 153 | def forward(self, batch): 154 | """Given a batch of images and poses, infers scene representations, 155 | rotates them into target poses and renders them into images. 156 | 157 | Args: 158 | batch (dict): A batch of images and poses as returned by 159 | misc.dataloaders.scene_render_dataloader. 160 | 161 | Notes: 162 | This *must* be a batch as returned by the scene render dataloader, 163 | i.e. the batch must be composed of pairs of images of the same 164 | scene. Specifically, the first time in the batch should be an image 165 | of scene A and the second item in the batch should be an image of 166 | scene A observed from a different pose. The third item should be an 167 | image of scene B and the fourth item should be an image scene B 168 | observed from a different pose (and so on). 169 | """ 170 | # Slightly hacky way of extracting model device. Device on which 171 | # spherical is stored is the one where model is too 172 | device = self.spherical_mask.mask.device 173 | imgs = batch["img"].to(device) 174 | params = batch["render_params"] 175 | azimuth = params["azimuth"].to(device) 176 | elevation = params["elevation"].to(device) 177 | 178 | # Infer scenes from images 179 | scenes = self.inverse_render(imgs) 180 | 181 | # Rotate scenes so that for every pair of rendered images, the 1st 182 | # one will be reconstructed as the 2nd and then 2nd will be 183 | # reconstructed as the 1st 184 | swapped_idx = get_swapped_indices(azimuth.shape[0]) 185 | 186 | # Each pair of indices in the azimuth vector corresponds to the same 187 | # scene at two different angles. Therefore performing a pairwise swap, 188 | # the first index will correspond to the second index in the original 189 | # vector. 
Since we want to rotate camera angle 1 to camera angle 2 and 190 | # vice versa, we can use these swapped angles to define a target 191 | # position for the camera 192 | azimuth_swapped = azimuth[swapped_idx] 193 | elevation_swapped = elevation[swapped_idx] 194 | scenes_swapped = \ 195 | self.rotate_source_to_target(scenes, azimuth, elevation, 196 | azimuth_swapped, elevation_swapped) 197 | 198 | # Swap scenes, so rotated scenes match with original inferred scene. 199 | # Specifically, we have images x1, x2 from which we inferred the scenes 200 | # z1, z2. We then rotated these scenes into z1' and z2'. Now z1' should 201 | # be almost equal to z2 and z2' should be almost equal to z1, so we swap 202 | # the order of z1', z2' to z2', z1' so we can easily render them to 203 | # x1 and x2. 204 | scenes_rotated = scenes_swapped[swapped_idx] 205 | 206 | # Render scene using model 207 | rendered = self.render(scenes_rotated) 208 | 209 | return imgs, rendered, scenes, scenes_rotated 210 | 211 | def print_model_info(self): 212 | """Prints detailed information about model, such as how input shape is 213 | transformed to output shape and how many parameters are trained in each 214 | block. 215 | """ 216 | print("Forward renderer") 217 | print("----------------\n") 218 | pretty_print_layers_info(self.transform_3d, "3D Layers") 219 | print("\n") 220 | pretty_print_layers_info(self.projection, "Projection") 221 | print("\n") 222 | pretty_print_layers_info(self.transform_2d, "2D Layers") 223 | print("\n") 224 | 225 | print("Inverse renderer") 226 | print("----------------\n") 227 | pretty_print_layers_info(self.inv_transform_2d, "Inverse 2D Layers") 228 | print("\n") 229 | pretty_print_layers_info(self.inv_projection, "Inverse Projection") 230 | print("\n") 231 | pretty_print_layers_info(self.inv_transform_3d, "Inverse 3D Layers") 232 | print("\n") 233 | 234 | print("Scene Representation:") 235 | print("\tShape: {}".format(self.scene_shape)) 236 | # Size of scene representation corresponds to non zero entries of 237 | # spherical mask 238 | print("\tSize: {}\n".format(int(self.spherical_mask.mask.sum().item()))) 239 | 240 | print("Number of parameters: {}\n".format(count_parameters(self))) 241 | 242 | def get_model_config(self): 243 | """Returns the complete model configuration as a dict.""" 244 | return { 245 | "img_shape": self.img_shape, 246 | "channels_2d": self.channels_2d, 247 | "strides_2d": self.strides_2d, 248 | "channels_3d": self.channels_3d, 249 | "strides_3d": self.strides_3d, 250 | "num_channels_inv_projection": self.num_channels_inv_projection, 251 | "num_channels_projection": self.num_channels_projection, 252 | "mode": self.mode 253 | } 254 | 255 | def save(self, filename): 256 | """Saves model and its config. 257 | 258 | Args: 259 | filename (string): Path where model will be saved. Should end with 260 | '.pt' or '.pth'. 261 | """ 262 | torch.save({ 263 | "config": self.get_model_config(), 264 | "state_dict": self.state_dict() 265 | }, filename) 266 | 267 | 268 | def load_model(filename): 269 | """Loads a NeuralRenderer model from saved model config and weights. 270 | 271 | Args: 272 | filename (string): Path where model was saved. 
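    Example:
        Minimal sketch using the checkpoint shipped with this repository:

            model = load_model('trained-models/chairs.pt')
            model.eval()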
273 | """ 274 | model_dict = torch.load(filename, map_location="cpu") 275 | config = model_dict["config"] 276 | # Initialize a model based on config 277 | model = NeuralRenderer( 278 | img_shape=config["img_shape"], 279 | channels_2d=config["channels_2d"], 280 | strides_2d=config["strides_2d"], 281 | channels_3d=config["channels_3d"], 282 | strides_3d=config["strides_3d"], 283 | num_channels_inv_projection=config["num_channels_inv_projection"], 284 | num_channels_projection=config["num_channels_projection"], 285 | mode=config["mode"] 286 | ) 287 | # Load weights into model 288 | model.load_state_dict(model_dict["state_dict"]) 289 | return model 290 | 291 | 292 | def get_swapped_indices(length): 293 | """Returns a list of swapped index pairs. For example, if length = 6, then 294 | function returns [1, 0, 3, 2, 5, 4], i.e. every index pair is swapped. 295 | 296 | Args: 297 | length (int): Length of swapped indices. 298 | """ 299 | return [i + 1 if i % 2 == 0 else i - 1 for i in range(length)] 300 | -------------------------------------------------------------------------------- /models/rotation_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transforms3d.rotations import rotate, rotate_source_to_target 4 | 5 | 6 | class Rotate3d(nn.Module): 7 | """Layer used to rotate 3D feature maps. 8 | 9 | Args: 10 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode 11 | used when resampling rotated values on the grid. 12 | """ 13 | def __init__(self, mode='bilinear'): 14 | super(Rotate3d, self).__init__() 15 | self.mode = mode 16 | 17 | def forward(self, volume, rotation_matrix): 18 | """Rotates the volume by the rotation matrix. 19 | 20 | Args: 21 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width). 22 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape 23 | (batch_size, 3, 3). 24 | """ 25 | return rotate(volume, rotation_matrix, mode=self.mode) 26 | 27 | def rotate_source_to_target(self, volume, azimuth_source, elevation_source, 28 | azimuth_target, elevation_target): 29 | """Rotates volume from source coordinate frame to target coordinate 30 | frame. 31 | 32 | Args: 33 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width). 34 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of 35 | source view in degrees. 36 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of 37 | source view in degrees. 38 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of 39 | target view in degrees. 40 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of 41 | target view in degrees. 42 | """ 43 | return rotate_source_to_target(volume, azimuth_source, elevation_source, 44 | azimuth_target, elevation_target, 45 | mode=self.mode) 46 | 47 | 48 | class SphericalMask(nn.Module): 49 | """Sets all features outside the largest sphere embedded in a cubic tensor 50 | to zero. 51 | 52 | Args: 53 | input_shape (tuple of ints): Shape of 3D feature map. Should have the 54 | form (channels, depth, height, width). 55 | radius_fraction (float): Fraction of radius to keep as non zero. E.g. 56 | if radius_fraction=0.9, only elements within the sphere of radius 57 | 0.9 of half the cube length will not be zeroed. Must be in [0., 1.]. 
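    Example:
        Illustrative sketch (shapes are arbitrary); features outside the
        inscribed sphere are zeroed, features inside are left unchanged:

            mask = SphericalMask((16, 32, 32, 32))
            out = mask(torch.randn(8, 16, 32, 32, 32))  # -> shape (8, 16, 32, 32, 32)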
58 | """ 59 | def __init__(self, input_shape, radius_fraction=1.): 60 | super(SphericalMask, self).__init__() 61 | # Check input 62 | _, depth, height, width = input_shape 63 | assert depth == height, "Depth, height, width are {}, {}, {} but must be equal.".format(depth, height, width) 64 | assert height == width, "Depth, height, width are {}, {}, {} but must be equal.".format(depth, height, width) 65 | 66 | self.input_shape = input_shape 67 | 68 | # Build spherical mask 69 | mask = torch.ones(input_shape) 70 | mask_center = (depth - 1) / 2 # Center of cube (in terms of index) 71 | radius = (depth - 1) / 2 # Distance from center to edge of cube is radius of sphere 72 | for i in range(depth): 73 | for j in range(height): 74 | for k in range(width): 75 | squared_distance = (mask_center - i) ** 2 + (mask_center - j) ** 2 + (mask_center - k) ** 2 76 | if squared_distance > (radius_fraction * radius) ** 2: 77 | mask[:, i, j, k] = 0. 78 | 79 | # Register buffer adds a key to the state dict of the model. This will 80 | # track the attribute without registering it as a learnable parameter. 81 | # This also means mask will be moved to device when calling 82 | # model.to(device) 83 | self.register_buffer('mask', mask) 84 | 85 | def forward(self, volume): 86 | """Applies a spherical mask to input. 87 | 88 | Args: 89 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width). 90 | """ 91 | return volume * self.mask 92 | -------------------------------------------------------------------------------- /models/submodels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.layers import ResBlock2d, ResBlock3d, num_channels_to_num_groups 4 | 5 | 6 | class ResNet2d(nn.Module): 7 | """ResNets for 2d inputs. 8 | 9 | Args: 10 | input_shape (tuple of ints): Shape of the input to the model. Should be 11 | of the form (channels, height, width). 12 | channels (tuple of ints): List of number of channels for each layer. 13 | Length of this tuple corresponds to number of layers in network. 14 | strides (tuple of ints): List of strides for each layer. Length of this 15 | tuple corresponds to number of layers in network. If stride is 1, a 16 | residual layer is applied. If stride is 2 a convolution with stride 17 | 2 is applied. If stride is -2 a transpose convolution with stride 2 18 | is applied. 19 | final_conv_channels (int): If not 0, a convolution is added as the final 20 | layer, with the number of output channels specified by this int. 21 | filter_multipliers (tuple of ints): Multipliers for filters in residual 22 | layers. 23 | add_groupnorm (bool): If True, adds GroupNorm layers. 24 | 25 | 26 | Notes: 27 | The first layer of this model is a standard convolution to increase the 28 | number of filters. A convolution can optionally be added at the final 29 | layer. 30 | """ 31 | def __init__(self, input_shape, channels, strides, final_conv_channels=0, 32 | filter_multipliers=(1, 1), add_groupnorm=True): 33 | super(ResNet2d, self).__init__() 34 | assert len(channels) == len(strides), "Length of channels tuple is {} and length of strides tuple is {} but " \ 35 | "they should be equal".format(len(channels), len(strides)) 36 | self.input_shape = input_shape 37 | self.channels = channels 38 | self.strides = strides 39 | self.filter_multipliers = filter_multipliers 40 | self.add_groupnorm = add_groupnorm 41 | 42 | # Calculate output_shape: 43 | # Every layer with stride 2 divides the height and width by 2. 
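        # (Illustrative example: a strides tuple of (1, 2, 1, 2) on a 64x64
        # input gives a 16x16 output, since only the two stride-2 layers
        # change the spatial size.)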
44 | # Similarly, every layer with stride -2 multiplies the height and width 45 | # by 2 46 | output_channels, output_height, output_width = input_shape 47 | 48 | for stride in strides: 49 | if stride == 1: 50 | pass 51 | elif stride == 2: 52 | output_height //= 2 53 | output_width //= 2 54 | elif stride == -2: 55 | output_height *= 2 56 | output_width *= 2 57 | 58 | self.output_shape = (channels[-1], output_height, output_width) 59 | 60 | # Build layers 61 | # First layer to increase number of channels before applying residual 62 | # layers 63 | forward_layers = [ 64 | nn.Conv2d(self.input_shape[0], channels[0], kernel_size=1, 65 | stride=1, padding=0) 66 | ] 67 | in_channels = channels[0] 68 | multiplier1x1, multiplier3x3 = filter_multipliers 69 | for out_channels, stride in zip(channels, strides): 70 | if stride == 1: 71 | forward_layers.append( 72 | ResBlock2d(in_channels, 73 | [out_channels * multiplier1x1, out_channels * multiplier3x3], 74 | add_groupnorm=add_groupnorm) 75 | ) 76 | if stride == 2: 77 | forward_layers.append( 78 | nn.Conv2d(in_channels, out_channels, kernel_size=4, 79 | stride=2, padding=1) 80 | ) 81 | if stride == -2: 82 | forward_layers.append( 83 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, 84 | stride=2, padding=1) 85 | ) 86 | 87 | # Add non-linearity 88 | if stride == 2 or stride == -2: 89 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels)) 90 | forward_layers.append(nn.LeakyReLU(0.2, True)) 91 | 92 | in_channels = out_channels 93 | 94 | if final_conv_channels: 95 | forward_layers.append( 96 | nn.Conv2d(in_channels, final_conv_channels, kernel_size=1, 97 | stride=1, padding=0) 98 | ) 99 | 100 | self.forward_layers = nn.Sequential(*forward_layers) 101 | 102 | def forward(self, inputs): 103 | """Applies ResNet to image-like features. 104 | 105 | Args: 106 | inputs (torch.Tensor): Image-like tensor, with shape (batch_size, 107 | channels, height, width). 108 | """ 109 | return self.forward_layers(inputs) 110 | 111 | 112 | class ResNet3d(nn.Module): 113 | """ResNets for 3d inputs. 114 | 115 | Args: 116 | input_shape (tuple of ints): Shape of the input to the model. Should be 117 | of the form (channels, depth, height, width). 118 | channels (tuple of ints): List of number of channels for each layer. 119 | Length of this tuple corresponds to number of layers in network. 120 | Note that this corresponds to number of *output* channels for each 121 | convolutional layer. 122 | strides (tuple of ints): List of strides for each layer. Length of this 123 | tuple corresponds to number of layers in network. If stride is 1, a 124 | residual layer is applied. If stride is 2 a convolution with stride 125 | 2 is applied. If stride is -2 a transpose convolution with stride 2 126 | is applied. 127 | final_conv_channels (int): If not 0, a convolution is added as the final 128 | layer, with the number of output channels specified by this int. 129 | filter_multipliers (tuple of ints): Multipliers for filters in residual 130 | layers. 131 | add_groupnorm (bool): If True, adds GroupNorm layers. 132 | 133 | Notes: 134 | The first layer of this model is a standard convolution to increase the 135 | number of filters. A convolution can optionally be added at the final 136 | layer. 
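    Example:
        A minimal sketch (channels and strides are illustrative); a stride-2
        layer halves depth, height and width:

            net = ResNet3d((32, 8, 8, 8), channels=(32, 64), strides=(1, 2))
            out = net(torch.randn(1, 32, 8, 8, 8))  # -> shape (1, 64, 4, 4, 4)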
137 | """ 138 | def __init__(self, input_shape, channels, strides, final_conv_channels=0, 139 | filter_multipliers=(1, 1), add_groupnorm=True): 140 | super(ResNet3d, self).__init__() 141 | assert len(channels) == len(strides), "Length of channels tuple is {} and length of strides tuple is {} but they should be equal".format(len(channels), len(strides)) 142 | self.input_shape = input_shape 143 | self.channels = channels 144 | self.strides = strides 145 | self.filter_multipliers = filter_multipliers 146 | self.add_groupnorm = add_groupnorm 147 | 148 | # Calculate output_shape 149 | output_channels, output_depth, output_height, output_width = input_shape 150 | 151 | for stride in strides: 152 | if stride == 1: 153 | pass 154 | elif stride == 2: 155 | output_depth //= 2 156 | output_height //= 2 157 | output_width //= 2 158 | elif stride == -2: 159 | output_depth *= 2 160 | output_height *= 2 161 | output_width *= 2 162 | 163 | self.output_shape = (channels[-1], output_depth, output_height, output_width) 164 | 165 | # Build layers 166 | # First layer to increase number of channels before applying residual 167 | # layers 168 | forward_layers = [ 169 | nn.Conv3d(self.input_shape[0], channels[0], kernel_size=1, 170 | stride=1, padding=0) 171 | ] 172 | in_channels = channels[0] 173 | multiplier1x1, multiplier3x3 = filter_multipliers 174 | for out_channels, stride in zip(channels, strides): 175 | if stride == 1: 176 | forward_layers.append( 177 | ResBlock3d(in_channels, 178 | [out_channels * multiplier1x1, out_channels * multiplier3x3], 179 | add_groupnorm=add_groupnorm) 180 | ) 181 | if stride == 2: 182 | forward_layers.append( 183 | nn.Conv3d(in_channels, out_channels, kernel_size=4, 184 | stride=2, padding=1) 185 | ) 186 | if stride == -2: 187 | forward_layers.append( 188 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, 189 | stride=2, padding=1) 190 | ) 191 | 192 | # Add non-linearity 193 | if stride == 2 or stride == -2: 194 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels)) 195 | forward_layers.append(nn.LeakyReLU(0.2, True)) 196 | 197 | in_channels = out_channels 198 | 199 | if final_conv_channels: 200 | forward_layers.append( 201 | nn.Conv3d(in_channels, final_conv_channels, kernel_size=1, 202 | stride=1, padding=0) 203 | ) 204 | 205 | self.forward_layers = nn.Sequential(*forward_layers) 206 | 207 | def forward(self, inputs): 208 | """Applies ResNet to 3D features. 209 | 210 | Args: 211 | inputs (torch.Tensor): Tensor, with shape (batch_size, channels, 212 | depth, height, width). 213 | """ 214 | return self.forward_layers(inputs) 215 | 216 | 217 | class Projection(nn.Module): 218 | """Performs a projection from a 3D voxel-like feature map to a 2D image-like 219 | feature map. 220 | 221 | Args: 222 | input_shape (tuple of ints): Shape of 3D input, (channels, depth, 223 | height, width). 224 | num_channels (tuple of ints): Number of channels in each layer of the 225 | projection unit. 226 | 227 | Notes: 228 | This layer is inspired by the Projection Unit from 229 | https://arxiv.org/abs/1806.06575. 
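    Example:
        Illustrative sketch; the depth dimension is folded into the channel
        dimension and then reduced with 1x1 convolutions:

            proj = Projection((32, 8, 8, 8), num_channels=(128, 64))
            out = proj(torch.randn(1, 32, 8, 8, 8))  # -> shape (1, 64, 8, 8)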
230 | """ 231 | def __init__(self, input_shape, num_channels): 232 | super(Projection, self).__init__() 233 | self.input_shape = input_shape 234 | self.num_channels = num_channels 235 | self.output_shape = (num_channels[-1],) + input_shape[2:] 236 | # Number of input channels for first 2D convolution is 237 | # channels * depth since we flatten the 3D input 238 | in_channels = self.input_shape[0] * self.input_shape[1] 239 | # Initialize forward pass layers 240 | forward_layers = [] 241 | num_layers = len(num_channels) 242 | for i in range(num_layers): 243 | out_channels = num_channels[i] 244 | forward_layers.append( 245 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1) 246 | ) 247 | # Add non linearites, except for last layer 248 | if i != num_layers - 1: 249 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels)) 250 | forward_layers.append(nn.LeakyReLU(0.2, True)) 251 | in_channels = out_channels 252 | # Set up forward layers as model 253 | self.forward_layers = nn.Sequential(*forward_layers) 254 | 255 | def forward(self, inputs): 256 | """Reshapes inputs from 3D -> 2D and applies 1x1 convolutions. 257 | 258 | Args: 259 | inputs (torch.Tensor): Voxel like tensor, with shape (batch_size, 260 | channels, depth, height, width). 261 | """ 262 | batch_size, channels, depth, height, width = inputs.shape 263 | # Reshape 3D -> 2D 264 | reshaped = inputs.view(batch_size, channels * depth, height, width) 265 | # 1x1 conv layers 266 | return self.forward_layers(reshaped) 267 | 268 | 269 | class InverseProjection(nn.Module): 270 | """Performs an inverse projection from a 2D feature map to a 3D feature map. 271 | 272 | Args: 273 | input_shape (tuple of ints): Shape of 2D input, (channels, height, width). 274 | num_channels (tuple of ints): Number of channels in each layer of the 275 | projection unit. 276 | 277 | Note: 278 | The depth will be equal to the height and width of the input map. 279 | Therefore, the final number of channels must be divisible by the height 280 | and width of the input. 281 | """ 282 | def __init__(self, input_shape, num_channels): 283 | super(InverseProjection, self).__init__() 284 | self.input_shape = input_shape 285 | self.num_channels = num_channels 286 | assert num_channels[-1] % input_shape[-1] == 0, "Number of output channels is {} which is not divisible by " \ 287 | "width {} of image".format(num_channels[-1], input_shape[-1]) 288 | self.output_shape = (num_channels[-1] // input_shape[-1], input_shape[-1]) + input_shape[1:] 289 | 290 | # Initialize forward pass layers 291 | in_channels = self.input_shape[0] 292 | forward_layers = [] 293 | num_layers = len(num_channels) 294 | for i in range(num_layers): 295 | out_channels = num_channels[i] 296 | forward_layers.append( 297 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, 298 | padding=0) 299 | ) 300 | # Add non linearites, except for last layer 301 | if i != num_layers - 1: 302 | forward_layers.append(nn.GroupNorm(num_channels_to_num_groups(out_channels), out_channels)) 303 | forward_layers.append(nn.LeakyReLU(0.2, True)) 304 | in_channels = out_channels 305 | # Set up forward layers as model 306 | self.forward_layers = nn.Sequential(*forward_layers) 307 | 308 | def forward(self, inputs): 309 | """Applies convolutions and reshapes outputs from 2D -> 3D. 310 | 311 | Args: 312 | inputs (torch.Tensor): Image like tensor, with shape (batch_size, 313 | channels, height, width). 
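        Example:
            Illustrative sketch (the final channel count must be divisible by
            the input width, here 512 = 64 * 8):

                inv_proj = InverseProjection((128, 8, 8), num_channels=(256, 512))
                out = inv_proj(torch.randn(1, 128, 8, 8))  # -> shape (1, 64, 8, 8, 8)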
314 | """ 315 | # 1x1 conv layers 316 | features = self.forward_layers(inputs) 317 | # Reshape 3D -> 2D 318 | batch_size = inputs.shape[0] 319 | return features.view(batch_size, *self.output_shape) 320 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | imageio 3 | Pillow 4 | torch==1.4.0 5 | torchvision 6 | pytorch-msssim 7 | -------------------------------------------------------------------------------- /trained-models/chairs.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/trained-models/chairs.pt -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/training/__init__.py -------------------------------------------------------------------------------- /training/training.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | from models.neural_renderer import get_swapped_indices 5 | from pytorch_msssim import SSIM 6 | from torchvision.utils import save_image 7 | 8 | 9 | class Trainer(): 10 | """Class used to train neural renderers. 11 | 12 | Args: 13 | device (torch.device): Device to train model on. 14 | model (models.neural_renderer.NeuralRenderer): Model to train. 15 | lr (float): Learning rate. 16 | rendering_loss_type (string): One of 'l1', 'l2'. 17 | ssim_loss_weight (float): Weight assigned to SSIM loss. 18 | """ 19 | def __init__(self, device, model, lr=2e-4, rendering_loss_type='l1', 20 | ssim_loss_weight=0.05): 21 | self.device = device 22 | self.model = model 23 | self.lr = lr 24 | self.rendering_loss_type = rendering_loss_type 25 | self.ssim_loss_weight = ssim_loss_weight 26 | self.use_ssim = self.ssim_loss_weight != 0 27 | # If False doesn't save losses in loss history 28 | self.register_losses = True 29 | # Check if model is multi-gpu 30 | self.multi_gpu = isinstance(self.model, nn.DataParallel) 31 | 32 | # Initialize optimizer 33 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) 34 | 35 | # Initialize loss functions 36 | # For rendered images 37 | if self.rendering_loss_type == 'l1': 38 | self.loss_func = nn.L1Loss() 39 | elif self.rendering_loss_type == 'l2': 40 | self.loss_func = nn.MSELoss() 41 | 42 | # For SSIM 43 | if self.use_ssim: 44 | self.ssim_loss_func = SSIM(data_range=1.0, size_average=True, 45 | channel=3, nonnegative_ssim=False) 46 | 47 | # Loss histories 48 | self.recorded_losses = ["total", "regression", "ssim"] 49 | self.loss_history = {loss_type: [] for loss_type in self.recorded_losses} 50 | self.epoch_loss_history = {loss_type: [] for loss_type in self.recorded_losses} 51 | self.val_loss_history = {loss_type: [] for loss_type in self.recorded_losses} 52 | 53 | def train(self, dataloader, epochs, save_dir=None, save_freq=1, 54 | test_dataloader=None): 55 | """Trains a neural renderer model on the given dataloader. 56 | 57 | Args: 58 | dataloader (torch.utils.DataLoader): Dataloader for a 59 | misc.dataloaders.SceneRenderDataset instance. 60 | epochs (int): Number of epochs to train for. 
61 | save_dir (string or None): If not None, saves model and generated 62 | images to directory described by save_dir. Note that this 63 | directory should already exist. 64 | save_freq (int): Frequency with which to save model. 65 | test_dataloader (torch.utils.DataLoader or None): If not None, will 66 | test model on this dataset after every epoch. 67 | """ 68 | if save_dir is not None: 69 | # Extract one batch of data 70 | for batch in dataloader: 71 | break 72 | # Save original images 73 | save_image(batch["img"], save_dir + "/imgs_ground_truth.png", nrow=4) 74 | # Store batch to check how rendered images improve during training 75 | self.fixed_batch = batch 76 | # Render images before any training 77 | rendered = self._render_fixed_img() 78 | save_image(rendered.cpu(), 79 | save_dir + "/imgs_gen_{}.png".format(str(0).zfill(3)), nrow=4) 80 | 81 | for epoch in range(epochs): 82 | print("\nEpoch {}".format(epoch + 1)) 83 | self._train_epoch(dataloader) 84 | # Update epoch loss history with mean loss over epoch 85 | for loss_type in self.recorded_losses: 86 | self.epoch_loss_history[loss_type].append( 87 | sum(self.loss_history[loss_type][-len(dataloader):]) / len(dataloader) 88 | ) 89 | # Print epoch losses 90 | print("Mean epoch loss:") 91 | self._print_losses(epoch_loss=True) 92 | 93 | # Optionally save generated images, losses and model 94 | if save_dir is not None: 95 | # Save generated images 96 | rendered = self._render_fixed_img() 97 | save_image(rendered.cpu(), 98 | save_dir + "/imgs_gen_{}.png".format(str(epoch + 1).zfill(3)), nrow=4) 99 | # Save losses 100 | with open(save_dir + '/loss_history.json', 'w') as loss_file: 101 | json.dump(self.loss_history, loss_file) 102 | # Save epoch losses 103 | with open(save_dir + '/epoch_loss_history.json', 'w') as loss_file: 104 | json.dump(self.epoch_loss_history, loss_file) 105 | # Save model 106 | if (epoch + 1) % save_freq == 0: 107 | if self.multi_gpu: 108 | self.model.module.save(save_dir + "/model.pt") 109 | else: 110 | self.model.save(save_dir + "/model.pt") 111 | 112 | if test_dataloader is not None: 113 | regression_loss, ssim_loss, total_loss = mean_dataset_loss(self, test_dataloader) 114 | print("Validation:\nRegression: {:.4f}, SSIM: {:.4f}, Total: {:.4f}".format(regression_loss, ssim_loss, total_loss)) 115 | self.val_loss_history["regression"].append(regression_loss) 116 | self.val_loss_history["ssim"].append(ssim_loss) 117 | self.val_loss_history["total"].append(total_loss) 118 | if save_dir is not None: 119 | # Save validation losses 120 | with open(save_dir + '/val_loss_history.json', 'w') as loss_file: 121 | json.dump(self.val_loss_history, loss_file) 122 | # If current validation loss is the lowest, save model as best 123 | # model 124 | if min(self.val_loss_history["total"]) == total_loss: 125 | print("New best model!") 126 | if self.multi_gpu: 127 | self.model.module.save(save_dir + "/best_model.pt") 128 | else: 129 | self.model.save(save_dir + "/best_model.pt") 130 | 131 | # Save model after training 132 | if save_dir is not None: 133 | if self.multi_gpu: 134 | self.model.module.save(save_dir + "/model.pt") 135 | else: 136 | self.model.save(save_dir + "/model.pt") 137 | 138 | def _train_epoch(self, dataloader): 139 | """Trains model for a single epoch. 140 | 141 | Args: 142 | dataloader (torch.utils.DataLoader): Dataloader for a 143 | misc.dataloaders.SceneRenderDataset instance. 
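        Notes:
            This is called once per epoch from train(); a typical outer call
            is sketched below (constructor arguments are illustrative):

                trainer = Trainer(device, model)
                trainer.train(dataloader, epochs=50)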
144 | """ 145 | num_iterations = len(dataloader) 146 | for i, batch in enumerate(dataloader): 147 | # Train inverse and forward renderer on batch 148 | self._train_iteration(batch) 149 | 150 | # Print iteration losses 151 | print("{}/{}".format(i + 1, num_iterations)) 152 | self._print_losses() 153 | 154 | def _train_iteration(self, batch): 155 | """Trains model for a single iteration. 156 | 157 | Args: 158 | batch (dict): Batch of data as returned by a Dataloader for a 159 | misc.dataloaders.SceneRenderDataset instance. 160 | """ 161 | imgs, rendered, scenes, scenes_rotated = self.model(batch) 162 | self._optimizer_step(imgs, rendered) 163 | 164 | def _optimizer_step(self, imgs, rendered): 165 | """Updates weights of neural renderer. 166 | 167 | Args: 168 | imgs (torch.Tensor): Ground truth images. Shape 169 | (batch_size, channels, height, width). 170 | rendered (torch.Tensor): Rendered images. Shape 171 | (batch_size, channels, height, width). 172 | """ 173 | self.optimizer.zero_grad() 174 | 175 | loss_regression = self.loss_func(rendered, imgs) 176 | if self.use_ssim: 177 | # We want to maximize SSIM, i.e. minimize -SSIM 178 | loss_ssim = 1. - self.ssim_loss_func(rendered, imgs) 179 | loss_total = loss_regression + self.ssim_loss_weight * loss_ssim 180 | else: 181 | loss_total = loss_regression 182 | 183 | loss_total.backward() 184 | self.optimizer.step() 185 | 186 | # Record total loss 187 | if self.register_losses: 188 | self.loss_history["total"].append(loss_total.item()) 189 | self.loss_history["regression"].append(loss_regression.item()) 190 | # If SSIM is not used, register 0 in logs 191 | if not self.use_ssim: 192 | self.loss_history["ssim"].append(0.) 193 | else: 194 | self.loss_history["ssim"].append(loss_ssim.item()) 195 | 196 | def _render_fixed_img(self): 197 | """Reconstructs fixed batch through neural renderer (by inferring 198 | scenes, rotating them and rerendering). 199 | """ 200 | _, rendered, _, _ = self.model(self.fixed_batch) 201 | return rendered 202 | 203 | def _print_losses(self, epoch_loss=False): 204 | """Prints most recent losses.""" 205 | loss_info = [] 206 | for loss_type in self.recorded_losses: 207 | if epoch_loss: 208 | loss = self.epoch_loss_history[loss_type][-1] 209 | else: 210 | loss = self.loss_history[loss_type][-1] 211 | loss_info += [loss_type, loss] 212 | print("{}: {:.3f}, {}: {:.3f}, {}: {:.3f}".format(*loss_info)) 213 | 214 | 215 | def mean_dataset_loss(trainer, dataloader): 216 | """Returns the mean loss of a model across a dataloader. 217 | 218 | Args: 219 | trainer (training.Trainer): Trainer instance containing model to 220 | evaluate. 221 | dataloader (torch.utils.DataLoader): Dataloader for a 222 | misc.dataloaders.SceneRenderDataset instance. 223 | """ 224 | # No need to calculate gradients during evaluation, so disable gradients to 225 | # increase performance and reduce memory footprint 226 | with torch.no_grad(): 227 | # Ensure calculated losses aren't registered as training losses 228 | trainer.register_losses = False 229 | 230 | regression_loss = 0. 231 | ssim_loss = 0. 232 | total_loss = 0. 233 | for i, batch in enumerate(dataloader): 234 | imgs, rendered, scenes, scenes_rotated = trainer.model(batch) 235 | 236 | # Update losses 237 | # Use _loss_func here and not _loss_renderer since we only want regression term 238 | current_regression_loss = trainer.loss_func(rendered, imgs).item() 239 | if trainer.use_ssim: 240 | current_ssim_loss = 1. - trainer.ssim_loss_func(rendered, imgs).item() 241 | else: 242 | current_ssim_loss = 0. 
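            # Accumulate batch-level losses; they are averaged over the number
            # of batches (not individual samples) after the loop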
243 | regression_loss += current_regression_loss 244 | ssim_loss += current_ssim_loss 245 | total_loss += current_regression_loss + trainer.ssim_loss_weight * current_ssim_loss 246 | 247 | # Average losses over dataset 248 | regression_loss /= len(dataloader) 249 | ssim_loss /= len(dataloader) 250 | total_loss /= len(dataloader) 251 | 252 | # Reset boolean so we register losses if we continue training 253 | trainer.register_losses = True 254 | 255 | return regression_loss, ssim_loss, total_loss 256 | -------------------------------------------------------------------------------- /transforms3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-equivariant-neural-rendering/e7ecfb1b4f93fc89cd4912af64e8223c6523efa8/transforms3d/__init__.py -------------------------------------------------------------------------------- /transforms3d/conversions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import pi 3 | 4 | 5 | def deg2rad(angles): 6 | return angles * pi / 180. 7 | 8 | 9 | def rad2deg(angles): 10 | return angles * 180. / pi 11 | 12 | 13 | def rotation_matrix_y(angle): 14 | """Returns rotation matrix about y-axis. 15 | 16 | Args: 17 | angle (torch.Tensor): Rotation angle in degrees. Shape (batch_size,). 18 | """ 19 | # Initialize rotation matrix 20 | rotation_matrix = torch.zeros(angle.shape[0], 3, 3, device=angle.device) 21 | # Fill out matrix entries 22 | angle_rad = deg2rad(angle) 23 | cos_angle = torch.cos(angle_rad) 24 | sin_angle = torch.sin(angle_rad) 25 | rotation_matrix[:, 0, 0] = cos_angle 26 | rotation_matrix[:, 0, 2] = sin_angle 27 | rotation_matrix[:, 1, 1] = 1. 28 | rotation_matrix[:, 2, 0] = -sin_angle 29 | rotation_matrix[:, 2, 2] = cos_angle 30 | return rotation_matrix 31 | 32 | 33 | def rotation_matrix_z(angle): 34 | """Returns rotation matrix about z-axis. 35 | 36 | Args: 37 | angle (torch.Tensor): Rotation angle in degrees. Shape (batch_size,). 38 | """ 39 | # Initialize rotation matrix 40 | rotation_matrix = torch.zeros(angle.shape[0], 3, 3, device=angle.device) 41 | # Fill out matrix entries 42 | angle_rad = deg2rad(angle) 43 | cos_angle = torch.cos(angle_rad) 44 | sin_angle = torch.sin(angle_rad) 45 | rotation_matrix[:, 0, 0] = cos_angle 46 | rotation_matrix[:, 0, 1] = -sin_angle 47 | rotation_matrix[:, 1, 0] = sin_angle 48 | rotation_matrix[:, 1, 1] = cos_angle 49 | rotation_matrix[:, 2, 2] = 1. 50 | return rotation_matrix 51 | 52 | 53 | def azimuth_elevation_to_rotation_matrix(azimuth, elevation): 54 | """Returns rotation matrix matching the default view (i.e. both azimuth and 55 | elevation are zero) to the view defined by the azimuth, elevation pair. 56 | 57 | 58 | Args: 59 | azimuth (torch.Tensor): Shape (batch_size,). Azimuth of camera in 60 | degrees. 61 | elevation (torch.Tensor): Shape (batch_size,). Elevation of camera in 62 | degrees. 63 | 64 | Notes: 65 | The azimuth and elevation refer to the position of the camera. This 66 | function returns the rotation of the *scene representation*, i.e. the 67 | inverse of the camera transformation. 
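    Example:
        Sketch for a single camera pose (angles in degrees):

            azimuth = torch.tensor([30.])
            elevation = torch.tensor([45.])
            R = azimuth_elevation_to_rotation_matrix(azimuth, elevation)  # -> shape (1, 3, 3)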
68 | """ 69 | # In the coordinate system we define (see README), azimuth rotation 70 | # corresponds to negative rotation about y axis and elevation rotation to a 71 | # negative rotation about z axis 72 | azimuth_matrix = rotation_matrix_y(-azimuth) 73 | elevation_matrix = rotation_matrix_z(-elevation) 74 | # We first perform elevation rotation followed by azimuth when rotating camera 75 | camera_matrix = azimuth_matrix @ elevation_matrix 76 | # Object rotation matrix is inverse (i.e. transpose) of camera rotation matrix 77 | return transpose_matrix(camera_matrix) 78 | 79 | 80 | def rotation_matrix_source_to_target(azimuth_source, elevation_source, 81 | azimuth_target, elevation_target): 82 | """Returns rotation matrix matching two views defined by azimuth, elevation 83 | pairs. 84 | 85 | Args: 86 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source 87 | view in degrees. 88 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of 89 | source view in degrees. 90 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target 91 | view in degrees. 92 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of 93 | target view in degrees. 94 | """ 95 | # Calculate rotation matrix for each view 96 | rotation_source = azimuth_elevation_to_rotation_matrix(azimuth_source, elevation_source) 97 | rotation_target = azimuth_elevation_to_rotation_matrix(azimuth_target, elevation_target) 98 | # Calculate rotation matrix bringing source view to target view (note that 99 | # for rotation matrix, inverse is transpose) 100 | return rotation_target @ transpose_matrix(rotation_source) 101 | 102 | 103 | def transpose_matrix(matrix): 104 | """Transposes a batch of matrices. 105 | 106 | Args: 107 | matrix (torch.Tensor): Batch of matrices of shape (batch_size, n, m). 108 | """ 109 | return matrix.transpose(1, 2) 110 | -------------------------------------------------------------------------------- /transforms3d/rotations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transforms3d.conversions import (rotation_matrix_source_to_target, 3 | transpose_matrix) 4 | 5 | 6 | def rotate(volume, rotation_matrix, mode='bilinear'): 7 | """Performs 3D rotation of tensor volume by rotation matrix. 8 | 9 | Args: 10 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width). 11 | rotation_matrix (torch.Tensor): Batch of rotation matrices of shape 12 | (batch_size, 3, 3). 13 | mode (string): One of 'bilinear' and 'nearest' for interpolation mode 14 | used in grid_sample. Note that the 'bilinear' option actually 15 | performs trilinear interpolation. 16 | 17 | Notes: 18 | We use align_corners=False in grid_sample. See 19 | https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9 20 | for a nice illustration of why this is. 21 | """ 22 | # The grid_sample function performs the inverse transformation of the input 23 | # coordinates, so invert matrix to get forward transformation 24 | inverse_rotation_matrix = transpose_matrix(rotation_matrix) 25 | # The grid_sample function swaps x and z (i.e. 
it assumes the tensor 26 | # dimensions are ordered as z, y, x), therefore we need to flip the rows and 27 | # columns of the matrix (which we can verify is equivalent to multiplying by 28 | # the appropriate permutation matrices) 29 | inverse_rotation_matrix_swap_xz = torch.flip(inverse_rotation_matrix, 30 | dims=(1, 2)) 31 | # Apply transformation to grid 32 | affine_grid = get_affine_grid(inverse_rotation_matrix_swap_xz, volume.shape) 33 | # Regrid volume according to transformation grid 34 | return torch.nn.functional.grid_sample(volume, affine_grid, mode=mode, 35 | align_corners=False) 36 | 37 | 38 | def get_affine_grid(matrix, grid_shape): 39 | """Given a matrix and a grid shape, returns the grid transformed by the 40 | matrix (typically a rotation matrix). 41 | 42 | Args: 43 | matrix (torch.Tensor): Batch of matrices of size (batch_size, 3, 3). 44 | grid_shape (torch.size): Shape of returned affine grid. Should be of the 45 | form (batch_size, channels, depth, height, width). 46 | 47 | Notes: 48 | We use align_corners=False in affine_grid. See 49 | https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9 50 | for a nice illustration of why this is. 51 | """ 52 | batch_size = matrix.shape[0] 53 | # Last column of affine matrix corresponds to translation which is 0 in our 54 | # case. Therefore pad original matrix with zeros, so shape changes from 55 | # (batch_size, 3, 3) to (batch_size, 3, 4) 56 | translations = torch.zeros(batch_size, 3, 1, device=matrix.device) 57 | affine_matrix = torch.cat([matrix, translations], dim=2) 58 | return torch.nn.functional.affine_grid(affine_matrix, grid_shape, 59 | align_corners=False) 60 | 61 | 62 | def rotate_source_to_target(volume, azimuth_source, elevation_source, 63 | azimuth_target, elevation_target, mode='bilinear'): 64 | """Performs 3D rotation matching two coordinate frames defined by a source 65 | view and a target view. 66 | 67 | Args: 68 | volume (torch.Tensor): Shape (batch_size, channels, depth, height, width). 69 | azimuth_source (torch.Tensor): Shape (batch_size,). Azimuth of source 70 | view in degrees. 71 | elevation_source (torch.Tensor): Shape (batch_size,). Elevation of 72 | source view in degrees. 73 | azimuth_target (torch.Tensor): Shape (batch_size,). Azimuth of target 74 | view in degrees. 75 | elevation_target (torch.Tensor): Shape (batch_size,). Elevation of 76 | target view in degrees. 77 | """ 78 | rotation_matrix = rotation_matrix_source_to_target(azimuth_source, 79 | elevation_source, 80 | azimuth_target, 81 | elevation_target) 82 | return rotate(volume, rotation_matrix, mode=mode) 83 | --------------------------------------------------------------------------------