├── .github ├── CODE_OF_CONDUCT.md └── CONTRIBUTING.md ├── INSTALL.md ├── KITTI.md ├── LICENSE.md ├── MP3D.md ├── QUICKSTART.md ├── README.md ├── REALESTATE.md ├── data ├── create_rgb_dataset.py ├── files │ ├── eval_cached_cameras_mp3d.txt │ ├── eval_cached_cameras_replica.txt │ ├── kitti.txt │ └── realestate.txt ├── habitat_data.py ├── kitti.py ├── realestate10k.py └── scene_episodes │ ├── mp3d_test │ └── dataset_one_ep_per_scene.json.gz │ ├── mp3d_train │ └── dataset_one_ep_per_scene.json.gz │ ├── mp3d_val │ └── dataset_one_ep_per_scene.json.gz │ ├── replica_test │ └── dataset_one_ep_per_scene.json.gz │ └── replica_train │ └── dataset_one_ep_per_scene.json.gz ├── demos ├── Interactive Demo.ipynb ├── Simple Demo.ipynb └── im.jpg ├── download_models.sh ├── evaluation ├── RESULTS.md ├── eval.py ├── eval_kitti.py ├── eval_realestate.py ├── evaluate_perceptualsim.py └── metrics.py ├── geometry └── camera_transformations.py ├── models ├── base_model.py ├── depth_model.py ├── encoderdecoder.py ├── layers │ ├── blocks.py │ ├── normalization.py │ └── z_buffer_layers.py ├── losses │ ├── gan_loss.py │ ├── ssim.py │ └── synthesis.py ├── networks │ ├── architectures.py │ ├── configs.py │ ├── discriminators.py │ ├── pretrained_networks.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ └── utilities.py ├── projection │ ├── depth_manipulator.py │ └── z_buffer_manipulator.py └── z_buffermodel.py ├── options ├── options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── submit.sh ├── submit_slurm_synsin.sh ├── train.py ├── train.sh └── utils ├── geometry.py └── jitter.py /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to detectron2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## License 30 | By contributing to synsin, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | 2 |

# Setup and Installation

3 | 4 | Install the requirements. 5 | - These are located in `./requirements.txt`. 6 | - They can be installed with `conda create --name synsin_env --file requirements.txt`. (Note that this file is missing the tensorboardX dependency which needs to be installed separately if you wish to train a model yourself. It is not necessary for running the demos or evaluation.) 7 | 8 | Or install the requirements yourself. Requirements: 9 | - pytorch=1.4 10 | - torchvision=0.5 11 | - opencv-python (pip) 12 | - [pytorch3d](https://github.com/facebookresearch/pytorch3d). Make sure it's the most recent version, 13 | as the older version does not have the required libraries. 14 | - tensorboardX (pip) 15 | - jupyter; matplotlib 16 | - [habitat-api](https://github.com/facebookresearch/habitat-api) (if you want to use Matterport or Replica datasets) 17 | - [habitat-sim](https://github.com/facebookresearch/habitat-sim) (if you want to use Matterport or Replica datasets) 18 | -------------------------------------------------------------------------------- /KITTI.md: -------------------------------------------------------------------------------- 1 |

# KITTI

2 | 3 |

## Download

4 | 5 | Download from [continuous_view_synthesis](https://github.com/xuchen-ethz/continuous_view_synthesis). 6 | Store the files in `${KITTI_HOME}/dataset_kitti`. 7 | 8 | ## Train 9 | 10 | ### Update options 11 | 12 | Update the paths in `./options/options.py` for the dataset being used. 13 | 14 | ### Training scripts 15 | Use the `./train.sh` to train one of the models on a single GPU node. 16 | 17 | You can also look at `./submit_slurm_synsin.sh` to see how to modify parameters in the renderer 18 | and run on a slurm cluster. 19 | 20 | ## Evaluate 21 | 22 | To evaluate, we run the following script. This gives us a bunch of generated vs ground truth images. 23 | 24 | ```bash 25 | export KITTI=${KITTI_HOME}/dataset_kitti/images 26 | python evaluation/eval_kitti.py --old_model ${OLD_MODEL} --result_folder ${TEST_FOLDER} 27 | ``` 28 | 29 | We then compare the generated to ground truth images using the evaluation script. 30 | 31 | ```bash 32 | python evaluation/evaluate_perceptualsim.py \ 33 | --folder ${TEST_FOLDER} \ 34 | --pred_image im_B.png \ 35 | --target_image im_res.png \ 36 | --output_file kitti_results 37 | ``` 38 | 39 | The results we get for each model is given in [RESULTS.md](https://github.com/fairinternal/synsin_public/tree/master/evaluation/RESULTS.md). 40 | 41 | If you do not get approximately the same results (some models use noise as input, so there is some randomness), then there is probably an error in your setup: 42 | - Check the libraries. 43 | - Check the data setup is indeed correct. 44 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2020, Facebook 3 | All rights reserved. 4 | 5 | --------------------------- LICENSE FOR SPADE -------------------------------- 6 | Copyright (C) 2019 NVIDIA Corporation. 7 | 8 | All rights reserved. 9 | Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**) 10 | 11 | The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](researchinquiries@nvidia.com). 12 | 13 | --------------------------- LICENSE FOR pereptual similarity -------------------------------- 14 | 15 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 16 | All rights reserved. 17 | 18 | Redistribution and use in source and binary forms, with or without 19 | modification, are permitted provided that the following conditions are met: 20 | 21 | * Redistributions of source code must retain the above copyright notice, this 22 | list of conditions and the following disclaimer. 23 | 24 | * Redistributions in binary form must reproduce the above copyright notice, 25 | this list of conditions and the following disclaimer in the documentation 26 | and/or other materials provided with the distribution. 27 | 28 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 29 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 30 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 31 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 32 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 33 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 34 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 35 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 36 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 37 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 38 | 39 | --------------------------- LICENSE FOR ssim-pytorch -------------------------------- 40 | 41 | MIT License 42 | 43 | Copyright (c) 2017 Po-Hsun-Su 44 | 45 | Permission is hereby granted, free of charge, to any person obtaining a copy 46 | of this software and associated documentation files (the "Software"), to deal 47 | in the Software without restriction, including without limitation the rights 48 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 49 | copies of the Software, and to permit persons to whom the Software is 50 | furnished to do so, subject to the following conditions: 51 | 52 | The above copyright notice and this permission notice shall be included in all 53 | copies or substantial portions of the Software. 54 | 55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 56 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 57 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 58 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 59 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 60 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 61 | SOFTWARE. 62 | --------------------------- LICENSE FOR BigGAN-pytorch -------------------------------- 63 | 64 | MIT License 65 | 66 | Copyright (c) 2019 Andy Brock 67 | 68 | Permission is hereby granted, free of charge, to any person obtaining a copy 69 | of this software and associated documentation files (the "Software"), to deal 70 | in the Software without restriction, including without limitation the rights 71 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 72 | copies of the Software, and to permit persons to whom the Software is 73 | furnished to do so, subject to the following conditions: 74 | 75 | The above copyright notice and this permission notice shall be included in all 76 | copies or substantial portions of the Software. 77 | 78 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 79 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 80 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 81 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 82 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 83 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 84 | SOFTWARE. 85 | 86 | --------------------------- LICENSE FOR continuous-view-synthesis -------------------------------- 87 | 88 | Copyright (c) 2019, Xu Chen 89 | All rights reserved. 90 | 91 | Redistribution and use in source and binary forms, with or without 92 | modification, are permitted provided that the following conditions are met: 93 | 94 | * Redistributions of source code must retain the above copyright notice, this 95 | list of conditions and the following disclaimer. 
96 | 97 | * Redistributions in binary form must reproduce the above copyright notice, 98 | this list of conditions and the following disclaimer in the documentation 99 | and/or other materials provided with the distribution. 100 | 101 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 102 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 103 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 104 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 105 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 106 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 107 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 108 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 109 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 110 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 111 | 112 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix -------------------------------- 113 | BSD License 114 | 115 | For pytorch-CycleGAN-and-pix2pix software 116 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 117 | All rights reserved. 118 | 119 | Redistribution and use in source and binary forms, with or without 120 | modification, are permitted provided that the following conditions are met: 121 | 122 | * Redistributions of source code must retain the above copyright notice, this 123 | list of conditions and the following disclaimer. 124 | 125 | * Redistributions in binary form must reproduce the above copyright notice, 126 | this list of conditions and the following disclaimer in the documentation 127 | and/or other materials provided with the distribution. 128 | 129 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 130 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 131 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 132 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 133 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 134 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 135 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 136 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 137 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 138 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 139 | 140 | --------------------------- LICENSE FOR pix2pix -------------------------------- 141 | BSD License 142 | 143 | For pix2pix software 144 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 145 | All rights reserved. 146 | 147 | Redistribution and use in source and binary forms, with or without 148 | modification, are permitted provided that the following conditions are met: 149 | 150 | * Redistributions of source code must retain the above copyright notice, this 151 | list of conditions and the following disclaimer. 152 | 153 | * Redistributions in binary form must reproduce the above copyright notice, 154 | this list of conditions and the following disclaimer in the documentation 155 | and/or other materials provided with the distribution. 
156 | 157 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 158 | BSD License 159 | 160 | For dcgan.torch software 161 | 162 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 163 | 164 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 165 | 166 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 167 | 168 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 169 | 170 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 171 | 172 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 173 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 174 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 175 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 176 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 177 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 178 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 179 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 180 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 181 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MP3D.md: -------------------------------------------------------------------------------- 1 |

# MP3D and Replica

2 | 3 |

## Download

4 | 5 |

### Matterport3D

6 | 7 | 1. Download habitat: 8 | - [habitat-api](https://github.com/facebookresearch/habitat-api) (if you want to use Matterport or Replica datasets) 9 | - [habitat-sim](https://github.com/facebookresearch/habitat-sim) (if you want to use Matterport or Replica datasets) 10 | 11 | 2. Download the [point nav datasets](https://github.com/facebookresearch/habitat-api#task-datasets). 12 | 13 | 3. Download [MP3D](https://niessner.github.io/Matterport/). 14 | 15 |
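Once these downloads are done, an import-and-data check along the following lines can confirm that habitat and the episode files bundled with this repo are visible. This is a minimal sketch, not part of the repo; it assumes the `.json.gz` episode archives are gzipped JSON with an `episodes` list, which you may want to verify against your habitat version.

```python
import gzip
import json

# Both packages must import cleanly before the Matterport/Replica data loaders can run.
import habitat       # habitat-api
import habitat_sim   # habitat-sim

# One episode per scene is bundled with this repo for each split.
path = "data/scene_episodes/mp3d_train/dataset_one_ep_per_scene.json.gz"
with gzip.open(path, "rt") as f:
    dataset = json.load(f)

# Assumption: a standard habitat dataset file, i.e. a dict with an "episodes" list.
episodes = dataset.get("episodes", dataset) if isinstance(dataset, dict) else dataset
print("habitat imports OK; %d episode(s) in %s" % (len(episodes), path))
```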

### Replica

16 | 17 | First follow the download steps for Matterport3D above. 18 | 19 | Then download [Replica](https://github.com/facebookresearch/Replica-Dataset). 20 | 21 | 22 | ## Train 23 | 24 | ### Update options 25 | 26 | Update the paths in `./options/options.py` for the dataset being used. 27 | 28 | ### Training scripts 29 | Use `./train.sh` to train one of the models on a single GPU node. 30 | 31 | You can also look at `./submit_slurm_synsin.sh` to see how to modify parameters in the renderer 32 | and run on a slurm cluster. 33 | 34 | ## Evaluate 35 | 36 | Run the evaluation to obtain both visibility and invisibility scores. 37 | 38 | Run the following bash command. It will output some sample images and save the results to a txt file. Make sure to set the options exactly as below (in particular `--batch_size`, `--num_workers`, and `--images_before_reset`); otherwise the cached-camera check in `evaluation/eval.py` will raise an error and the results won't be comparable to our given results. Pass `--dataset replica` only if you want to evaluate on Replica. 39 | 40 | ```bash 41 | python evaluation/eval.py \ 42 | --result_folder ${TEST_FOLDER} \ 43 | --old_model ${OLD_MODEL} \ 44 | --batch_size 8 --num_workers 10 --images_before_reset 200 \ 45 | --dataset replica 46 | ``` 47 | -------------------------------------------------------------------------------- /QUICKSTART.md: -------------------------------------------------------------------------------- 1 |

# Getting started with SynSin

2 | 3 | To quickly start using the code, first download the pretrained models. Then use the demo code in order to visualise the models in action. 4 | 5 | 6 |

## Download pretrained models

7 | 8 | * Download the models using `bash ./download_models.sh`. This should create a `./modelcheckpoints` directory with a readme.txt explaining the pretrained models. 9 | 10 |
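As a quick sanity check that the download worked, a checkpoint can be inspected directly. The sketch below is not part of the repo; the path is one of the files fetched by `./download_models.sh`, and it assumes the released checkpoints bundle the training options under `'opts'` alongside the weights under `'state_dict'` — verify against the demo notebooks if your keys differ.

```python
import torch

# One of the files fetched by ./download_models.sh; change to whichever model you downloaded.
MODEL_PATH = "./modelcheckpoints/realestate/synsin.pth"

# Assumption: each released checkpoint is a dict holding the training options
# ("opts") alongside the network weights ("state_dict").
checkpoint = torch.load(MODEL_PATH, map_location="cpu")
print(sorted(checkpoint.keys()))
print(checkpoint["opts"])  # the options used to build and train this model

# The demo notebooks rebuild the network from these options (see get_model in
# ./options/options.py) and then load the weights; ./demos/Simple Demo.ipynb
# walks through the full pipeline.
```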

## Running demo code

11 | 12 | To just run a model on a given image and set the view transform, we have included a simple demo notebook `./demos/Simple Demo.ipynb`. 13 | 14 | We have also included an interactive demo notebook in `./demos/Interactive Demo.ipynb`. This shows how to load a model, and update camera/view parameters interactively. 15 | 16 | To use our models on unseen images, we recommend looking at these notebooks. 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# SynSin: End-to-end View Synthesis from a Single Image (CVPR 2020)

2 | 3 | This is the code for the [CVPR 2020 paper](https://arxiv.org/abs/1912.08804). 4 | The code synthesises new views of a scene given a single image of an unseen scene at test time. 5 | It is trained with pairs of views in a self-supervised fashion. 6 | It is trained end to end, using GAN techniques and a new differentiable point cloud renderer. 7 | At test time, a single image of an unseen scene is input to the model, from which new views are generated. 8 | 9 |

10 | 11 | 12 | 13 | 14 | 15 | 16 | **Fig 1: Generated images at new viewpoints using SynSin.** Given the first image in the video, the model generates all subsequent images along the trajectory. The same model is used for all reconstructions. The scenes were not seen at train time. 17 |

18 | 19 | # Usage 20 | 21 | Note that this repository is a large refactoring of the original code to allow for public release 22 | and to integrate with [pytorch3d](https://github.com/facebookresearch/pytorch3d). 23 | Hence the models/datasets are not necessarily the same as those in the paper, as we cannot release 24 | the saved test images we used. 25 | For fair comparison and reproducibility, we recommend comparing against the numbers 26 | and models in [RESULTS.md](./evaluation/RESULTS.md). 27 | 28 | ## Setup and Installation 29 | See [INSTALL](./INSTALL.md). 30 | 31 | ## Quickstart 32 | To quickly start using a pretrained model, see [Quickstart](./QUICKSTART.md). 33 | 34 |

## Training and evaluating your own model

35 | 36 | To download, train, or evaluate a model on a given dataset, please read the corresponding file below. 37 | (Note that we cannot distribute the raw pixels, so each file explains how we downloaded and organised that dataset.) 38 | 39 | - [RealEstate10K](./REALESTATE.md) 40 | - [MP3D and Replica](./MP3D.md) 41 | - [KITTI](./KITTI.md) 42 | 43 |

## Citation

44 | If this work is helpful in your research, please cite: 45 | 46 | ``` 47 | @inproceedings{wiles2020synsin, 48 | author = {Olivia Wiles and Georgia Gkioxari and Richard Szeliski and 49 | Justin Johnson}, 50 | title = {{SynSin}: {E}nd-to-end View Synthesis from a Single Image}, 51 | booktitle = {CVPR}, 52 | year = {2020} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /REALESTATE.md: -------------------------------------------------------------------------------- 1 |

# RealEstate

2 | 3 |

## Download

4 | 5 | Download from [RealEstate10K](https://google.github.io/realestate10k/). 6 | Store the files in the following structure. The `${REAL_ESTATE_10K}/test/` and `${REAL_ESTATE_10K}/train` folders store the original text files. 7 | 8 | The frames need to be extracted based on the text files; we extract them to: `${REAL_ESTATE_10K}/frames`. There may be some missing videos, so we use some additional files as described below. 9 | 10 | We use a file `${REAL_ESTATE_10K}/frames/train/video_loc.txt` and `${REAL_ESTATE_10K}/frames/test/video_loc.txt` to store the location of the extracted videos. Finally, for each extracted video located at `${REAL_ESTATE_10K}/frames/train/${path_totrain_vid1}/*.png`, we create a new text file `${REAL_ESTATE_10K}/frames/train/${path_totrain_vid1}.txt` which stores the metadata for each frame (this is necessary as there may be some errors in the extraction process). The `${REAL_ESTATE_10K}/frames/train/${path_totrain_vid1}.txt` file is in the same structure as the original text file, except all rows containing images that were not extracted, have been removed. 11 | 12 | After following the above, you should have the following structure: 13 | 14 | ```bash 15 | - ${REAL_ESTATE_10K}/test/*.txt 16 | 17 | - ${REAL_ESTATE_10K}/train/*.txt 18 | 19 | - ${REAL_ESTATE_10K}/frames/train/ 20 | - ${REAL_ESTATE_10K}/frames/train/video_loc.txt 21 | - ${REAL_ESTATE_10K}/frames/train/${path_totrain_vid1}/*.png 22 | - ${REAL_ESTATE_10K}/frames/train/${path_totrain_vid1}.txt 23 | ... 24 | - ${REAL_ESTATE_10K}/frames/train/${path_totrain_vidN}/*.png 25 | - ${REAL_ESTATE_10K}/frames/train/${path_totrain_vidN}.txt 26 | 27 | - ${REAL_ESTATE_10K}/frames/test/ 28 | - ${REAL_ESTATE_10K}/frames/test/video_loc.txt 29 | - ${REAL_ESTATE_10K}/frames/test/${path_totest_vid1}/*.png 30 | - ${REAL_ESTATE_10K}/frames/test/${path_totest_vid1}.txt 31 | ... 32 | - ${REAL_ESTATE_10K}/frames/test/${path_totest_vidN}/*.png 33 | - ${REAL_ESTATE_10K}/frames/test/${path_totest_vidN}.txt 34 | ``` 35 | 36 | where `${REAL_ESTATE_10K}/frames/train/video_loc.txt` contains: 37 | 38 | ```bash 39 | ${path_totrain_vid1} 40 | ... 41 | ${path_totrain_vidN} 42 | ``` 43 | 44 | ## Train 45 | 46 | ### Update options 47 | 48 | Update the paths in `./options/options.py` for the dataset being used. 49 | 50 | ### Training scripts 51 | Use the `./train.sh` to train one of the models on a single GPU node. 52 | 53 | You can also look at `./submit_slurm_synsin.sh` to see how to modify parameters in the renderer 54 | and run on a slurm cluster. 55 | 56 | ## Evaluate 57 | 58 | To evaluate, we run the following script. This gives us a bunch of generated vs ground truth images. 59 | 60 | ```bash 61 | export REALESTATE=${REAL_ESTATE_10K}/frames/test/ 62 | python evaluation/eval_realestate.py --old_model ${OLD_MODEL} --result_folder ${TEST_FOLDER} 63 | ``` 64 | 65 | We then compare the generated to ground truth images using the evaluation script. 66 | 67 | ```bash 68 | python evaluation/evaluate_perceptualsim.py \ 69 | --folder ${TEST_FOLDER} \ 70 | --pred_image output_image_.png \ 71 | --target_image tgt_image_.png \ 72 | --output_file ${TEST_FOLDER}/realestate_results \ 73 | --take_every_other # Used for RealEstate10K when comparing to methods that uses 2 images per output 74 | ``` 75 | 76 | The results we get for each model is given in [RESULTS.md](https://github.com/fairinternal/synsin_public/tree/master/evaluation/RESULTS.md). 
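If the data setup is in doubt, the per-video metadata files described in the Download section above can be generated (or re-checked) with a filter along the following lines. This is a minimal sketch, not part of the repo: `some_video_id` is a hypothetical entry from `video_loc.txt`, frames are assumed to have been extracted as `<timestamp>.png` (matching how `data/realestate10k.py` looks them up), and only the numeric camera rows of the original text file are kept, since the loader reads the filtered file with `np.loadtxt`.

```python
import os

REAL_ESTATE_10K = os.environ["REAL_ESTATE_10K"]  # assumed to be exported already
split = "train"
video = "some_video_id"  # hypothetical name taken from video_loc.txt

orig_txt = os.path.join(REAL_ESTATE_10K, split, video + ".txt")
frame_dir = os.path.join(REAL_ESTATE_10K, "frames", split, video)
out_txt = os.path.join(REAL_ESTATE_10K, "frames", split, video + ".txt")

kept = []
with open(orig_txt) as f:
    for row in f.read().splitlines():
        cols = row.split()
        if not cols or not cols[0].isdigit():
            continue  # skip the header/URL line of the original file
        timestamp = str(int(cols[0]))  # frames are saved as <timestamp>.png
        if os.path.exists(os.path.join(frame_dir, timestamp + ".png")):
            kept.append(row)

with open(out_txt, "w") as f:
    f.write("\n".join(kept) + "\n")

print("kept %d camera rows for %s" % (len(kept), video))
```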
77 | 78 | If you do not get approximately the same results (some models use noise as input, so there is some randomness), then there is probably an error in your setup: 79 | - Check the libraries. 80 | - Check the data setup is indeed correct. 81 | 82 | -------------------------------------------------------------------------------- /data/files/eval_cached_cameras_mp3d.txt: -------------------------------------------------------------------------------- 1 | 6.228155493736267090e-01 2 | 1.615185663104057312e-02 3 | 7.822020053863525391e-01 4 | 1.141124343872070312e+01 5 | 6.692571192979812622e-02 6 | 9.950222969055175781e-01 7 | -7.383493334054946899e-02 8 | -2.435735702514648438e+00 9 | -7.795009613037109375e-01 10 | 9.833496809005737305e-02 11 | 6.186343431472778320e-01 12 | 8.109766006469726562e+00 13 | 0.000000000000000000e+00 14 | 0.000000000000000000e+00 15 | 0.000000000000000000e+00 16 | 1.000000000000000000e+00 17 | -2.003425508737564087e-01 18 | 3.667427459731698036e-03 19 | -9.797190427780151367e-01 20 | -1.511021614074707031e+01 21 | 6.621651351451873779e-02 22 | 9.977570772171020508e-01 23 | -9.805651381611824036e-03 24 | -5.176193714141845703e-01 25 | 9.774856567382812500e-01 26 | -6.683807075023651123e-02 27 | -2.001360505819320679e-01 28 | 1.007463550567626953e+01 29 | 0.000000000000000000e+00 30 | 0.000000000000000000e+00 31 | 0.000000000000000000e+00 32 | 1.000000000000000000e+00 33 | 9.151580333709716797e-01 34 | 7.105614524334669113e-03 35 | 4.030326306819915771e-01 36 | 1.145830249786376953e+01 37 | 1.451579481363296509e-02 38 | 9.986152052879333496e-01 39 | -5.056668072938919067e-02 40 | -1.288629293441772461e+00 41 | -4.028338193893432617e-01 42 | 5.212683975696563721e-02 43 | 9.137875437736511230e-01 44 | -7.844190597534179688e-01 45 | 0.000000000000000000e+00 46 | 0.000000000000000000e+00 47 | 0.000000000000000000e+00 48 | 1.000000000000000000e+00 49 | -9.488242268562316895e-01 50 | -1.046648994088172913e-02 51 | 3.156312406063079834e-01 52 | -1.485589122772216797e+01 53 | -2.146627940237522125e-02 54 | 9.992765784263610840e-01 55 | -3.139362484216690063e-02 56 | -2.338595151901245117e+00 57 | -3.150743246078491211e-01 58 | -3.656245768070220947e-02 59 | -9.483624696731567383e-01 60 | -2.555275726318359375e+01 61 | 0.000000000000000000e+00 62 | 0.000000000000000000e+00 63 | 0.000000000000000000e+00 64 | 1.000000000000000000e+00 65 | 2.377092465758323669e-02 66 | -5.511723831295967102e-02 67 | 9.981968998908996582e-01 68 | -2.652645707130432129e-01 69 | 6.205340474843978882e-02 70 | 9.966350793838500977e-01 71 | 5.355326831340789795e-02 72 | -6.461702287197113037e-02 73 | -9.977897405624389648e-01 74 | 6.066850572824478149e-02 75 | 2.711115032434463501e-02 76 | -2.202825355529785156e+01 77 | 0.000000000000000000e+00 78 | 0.000000000000000000e+00 79 | 0.000000000000000000e+00 80 | 1.000000000000000000e+00 81 | 8.768993429839611053e-03 82 | 7.599002867937088013e-02 83 | -9.970700144767761230e-01 84 | -9.225944519042968750e+00 85 | -3.119941242039203644e-02 86 | 9.966437220573425293e-01 87 | 7.568314671516418457e-02 88 | -1.081742167472839355e+00 89 | 9.994747042655944824e-01 90 | 3.044433332979679108e-02 91 | 1.111040636897087097e-02 92 | 1.758421897888183594e+01 93 | 0.000000000000000000e+00 94 | 0.000000000000000000e+00 95 | 0.000000000000000000e+00 96 | 1.000000000000000000e+00 97 | -1.274008750915527344e-01 98 | -6.224572286009788513e-02 99 | 9.898961782455444336e-01 100 | -4.355641841888427734e+00 101 | -5.402780696749687195e-02 102 | 9.969825744628906250e-01 103 | 
5.573787540197372437e-02 104 | -2.676608324050903320e+00 105 | -9.903787374496459961e-01 106 | -4.638086631894111633e-02 107 | -1.303794533014297485e-01 108 | -2.055020904541015625e+01 109 | 0.000000000000000000e+00 110 | 0.000000000000000000e+00 111 | 0.000000000000000000e+00 112 | 1.000000000000000000e+00 113 | 4.623147845268249512e-01 114 | -3.606788069009780884e-02 115 | -8.859820365905761719e-01 116 | -6.229297637939453125e+00 117 | -1.621732115745544434e-02 118 | 9.986613392829895020e-01 119 | -4.911736771464347839e-02 120 | -2.162063121795654297e+00 121 | 8.865675330162048340e-01 122 | 3.707594051957130432e-02 123 | 4.611109793186187744e-01 124 | 1.855968666076660156e+01 125 | 0.000000000000000000e+00 126 | 0.000000000000000000e+00 127 | 0.000000000000000000e+00 128 | 1.000000000000000000e+00 129 | -------------------------------------------------------------------------------- /data/files/eval_cached_cameras_replica.txt: -------------------------------------------------------------------------------- 1 | -8.751150965690612793e-01 2 | -1.854917965829372406e-02 3 | -4.835592210292816162e-01 4 | -5.352131128311157227e-01 5 | -3.498959913849830627e-02 6 | 9.990749955177307129e-01 7 | 2.499777451157569885e-02 8 | 2.149755507707595825e-01 9 | 4.826482534408569336e-01 10 | 3.879547119140625000e-02 11 | -8.749546408653259277e-01 12 | -5.159330844879150391e+00 13 | 0.000000000000000000e+00 14 | 0.000000000000000000e+00 15 | 0.000000000000000000e+00 16 | 1.000000000000000000e+00 17 | 8.664437532424926758e-01 18 | -2.258590981364250183e-02 19 | 4.987635612487792969e-01 20 | 3.569913208484649658e-01 21 | 2.733567170798778534e-02 22 | 9.996238350868225098e-01 23 | -2.220305381342768669e-03 24 | -2.914254367351531982e-02 25 | -4.985257983207702637e-01 26 | 1.555780600756406784e-02 27 | 8.667352199554443359e-01 28 | 3.924350976943969727e+00 29 | 0.000000000000000000e+00 30 | 0.000000000000000000e+00 31 | 0.000000000000000000e+00 32 | 1.000000000000000000e+00 33 | -9.118588566780090332e-01 34 | -1.744104549288749695e-02 35 | 4.101331830024719238e-01 36 | 1.337187170982360840e+00 37 | -2.261161431670188904e-02 38 | 9.997141957283020020e-01 39 | -7.759770844131708145e-03 40 | -7.843595147132873535e-01 41 | -4.098806381225585938e-01 42 | -1.634958945214748383e-02 43 | -9.119926095008850098e-01 44 | -4.242384910583496094e+00 45 | 0.000000000000000000e+00 46 | 0.000000000000000000e+00 47 | 0.000000000000000000e+00 48 | 1.000000000000000000e+00 49 | 4.445236325263977051e-01 50 | 4.226445779204368591e-02 51 | -8.947694897651672363e-01 52 | -5.978088378906250000e+00 53 | 1.511921267956495285e-02 54 | 9.983900189399719238e-01 55 | 5.467023700475692749e-02 56 | 1.173593029379844666e-01 57 | 8.956395387649536133e-01 58 | -3.783041983842849731e-02 59 | 4.431689381599426270e-01 60 | -4.252989768981933594e+00 61 | 0.000000000000000000e+00 62 | 0.000000000000000000e+00 63 | 0.000000000000000000e+00 64 | 1.000000000000000000e+00 65 | -7.398086190223693848e-01 66 | 4.271236434578895569e-02 67 | -6.714602708816528320e-01 68 | 2.654052257537841797e+00 69 | -8.073388598859310150e-03 70 | 9.973475337028503418e-01 71 | 7.233761996030807495e-02 72 | 2.502764761447906494e-01 73 | 6.727689504623413086e-01 74 | 5.893695354461669922e-02 75 | -7.375014424324035645e-01 76 | -5.705560684204101562e+00 77 | 0.000000000000000000e+00 78 | 0.000000000000000000e+00 79 | 0.000000000000000000e+00 80 | 1.000000000000000000e+00 81 | 9.885296821594238281e-01 82 | -1.169542502611875534e-02 83 | 1.505732983350753784e-01 84 | 
1.230404257774353027e+00 85 | 4.616037011146545410e-03 86 | 9.988710284233093262e-01 87 | 4.728017002344131470e-02 88 | -4.308381378650665283e-01 89 | -1.509562581777572632e-01 90 | -4.604279994964599609e-02 91 | 9.874675869941711426e-01 92 | 4.562815189361572266e+00 93 | 0.000000000000000000e+00 94 | 0.000000000000000000e+00 95 | 0.000000000000000000e+00 96 | 1.000000000000000000e+00 97 | 9.749212265014648438e-01 98 | -8.578262291848659515e-03 99 | -2.223847359418869019e-01 100 | -7.106215476989746094e+00 101 | -1.746909925714135170e-03 102 | 9.989311099052429199e-01 103 | -4.619108512997627258e-02 104 | -1.134104877710342407e-01 105 | 2.225432693958282471e-01 106 | 4.542115703225135803e-02 107 | 9.738641381263732910e-01 108 | 1.696477293968200684e+00 109 | 0.000000000000000000e+00 110 | 0.000000000000000000e+00 111 | 0.000000000000000000e+00 112 | 1.000000000000000000e+00 113 | -1.347370594739913940e-01 114 | -5.171212926506996155e-02 115 | 9.895310997962951660e-01 116 | 4.524656772613525391e+00 117 | 6.843224912881851196e-02 118 | 9.957672953605651855e-01 119 | 6.135593354701995850e-02 120 | 1.193765401840209961e-01 121 | -9.885155558586120605e-01 122 | 7.598275691270828247e-02 123 | -1.306279748678207397e-01 124 | 1.892387628555297852e+00 125 | 0.000000000000000000e+00 126 | 0.000000000000000000e+00 127 | 0.000000000000000000e+00 128 | 1.000000000000000000e+00 129 | -------------------------------------------------------------------------------- /data/habitat_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from data.create_rgb_dataset import RandomImageGenerator 7 | 8 | 9 | class HabitatImageGenerator(torch.utils.data.Dataset): 10 | def __init__(self, split, opts, vectorize=True, seed=0): 11 | self.worker_id = 0 12 | self.split = split 13 | self.opts = opts 14 | 15 | self.num_views = opts.num_views 16 | self.vectorize = vectorize 17 | 18 | self.image_generator = None 19 | 20 | # Part of hacky code to have train/val 21 | self.episodes = None 22 | self.restarted = True 23 | self.train = True 24 | 25 | self.rng = np.random.RandomState(seed) 26 | self.seed = opts.seed 27 | 28 | self.fixed_val_images = [None] * 32 # Keep 32 examples 29 | 30 | def __len__(self): 31 | return 2 ** 31 32 | 33 | def __restart__(self): 34 | if self.vectorize: 35 | self.image_generator = RandomImageGenerator( 36 | self.split, 37 | self.opts.render_ids[ 38 | self.worker_id % len(self.opts.render_ids) 39 | ], 40 | self.opts, 41 | vectorize=self.vectorize, 42 | seed=self.worker_id + self.seed, 43 | ) 44 | self.image_generator.env.reset() 45 | else: 46 | self.image_generator = RandomImageGenerator( 47 | self.split, 48 | self.opts.render_ids[ 49 | self.worker_id % len(self.opts.render_ids) 50 | ], 51 | self.opts, 52 | vectorize=self.vectorize, 53 | seed=torch.randint(100, size=(1,)).item(), 54 | ) 55 | self.image_generator.env.reset() 56 | self.rng = np.random.RandomState( 57 | torch.randint(100, size=(1,)).item() 58 | ) 59 | 60 | if not (self.vectorize): 61 | if self.episodes is None: 62 | self.rng.shuffle(self.image_generator.env.episodes) 63 | self.episodes = self.image_generator.env.episodes 64 | self.image_generator.env.reset() 65 | self.num_samples = 0 66 | 67 | def restart(self, train): 68 | 69 | if not (self.vectorize): 70 | if train: 71 | self.image_generator.env.episodes = self.episodes[ 72 | 0 : int(0.8 * len(self.episodes)) 73 | ] 74 | 
else: 75 | self.image_generator.env.episodes = self.episodes[ 76 | int(0.8 * len(self.episodes)) : 77 | ] 78 | 79 | # randomly choose an environment to start at (as opposed to always 0) 80 | self.image_generator.env._current_episode_index = self.rng.randint( 81 | len(self.episodes) 82 | ) 83 | print( 84 | "EPISODES A ", 85 | self.image_generator.env._current_episode_index, 86 | flush=True, 87 | ) 88 | self.image_generator.env.reset() 89 | print( 90 | "EPISODES B ", 91 | self.image_generator.env._current_episode_index, 92 | flush=True, 93 | ) 94 | 95 | def totrain(self, epoch=0): 96 | self.restarted = True 97 | self.train = True 98 | self.seed = epoch 99 | 100 | def toval(self, epoch=0): 101 | self.restarted = True 102 | self.train = False 103 | self.val_index = 0 104 | self.seed = epoch 105 | 106 | def __getitem__(self, item): 107 | if not (self.train) and (self.val_index < len(self.fixed_val_images)): 108 | if self.fixed_val_images[self.val_index]: 109 | data = self.fixed_val_images[self.val_index] 110 | self.val_index += 1 111 | return data 112 | 113 | if self.image_generator is None: 114 | print( 115 | "Restarting image_generator.... with seed %d in train mode? %s" 116 | % (self.seed, self.train), 117 | flush=True, 118 | ) 119 | self.__restart__() 120 | 121 | if self.restarted: 122 | self.restart(self.train) 123 | self.restarted = False 124 | 125 | # Ignore the item and just generate an image 126 | data = self.image_generator.get_sample(item, self.num_views, self.train) 127 | 128 | if not (self.train) and (self.val_index < len(self.fixed_val_images)): 129 | self.fixed_val_images[self.val_index] = data 130 | 131 | self.val_index += 1 132 | 133 | return data 134 | -------------------------------------------------------------------------------- /data/kitti.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # Based on https://github.com/xuchen-ethz/continuous_view_synthesis/blob/master/data/kitti_data_loader.py 3 | 4 | import torch 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation as ROT 7 | import torch.utils.data as data 8 | import os 9 | import csv 10 | import random 11 | from PIL import Image 12 | 13 | class KITTIDataLoader(data.Dataset): 14 | """ Dataset for loading the RealEstate10K. In this case, images are 15 | chosen within a video. 
16 | """ 17 | 18 | def __init__(self, dataset, opts=None, num_views=2, seed=0, vectorize=False): 19 | super(KITTIDataLoader, self).__init__() 20 | 21 | self.initialize(opts) 22 | 23 | def initialize(self, opt): 24 | self.opt = opt 25 | self.dataroot = opt.train_data_path 26 | 27 | self.opt.bound = 5 28 | with open(os.path.join(self.dataroot, 'id_train.txt'), 'r') as fp: 29 | self.ids_train = [s.strip() for s in fp.readlines() if s] 30 | 31 | self.ids = self.ids_train 32 | self.dataset_size = int(len(self.ids))// (opt.bound*2) 33 | 34 | self.pose_dict = {} 35 | pose_path = os.path.join(self.dataroot, 'poses.txt') 36 | with open(pose_path) as csv_file: 37 | csv_reader = csv.reader(csv_file, delimiter=' ') 38 | for row in csv_reader: 39 | id = row[0] 40 | self.pose_dict[id] = [] 41 | for col in row[1:-1]: 42 | self.pose_dict[id].append(float(col)) 43 | self.pose_dict[id] = np.array(self.pose_dict[id]) 44 | 45 | def __getitem__(self, index): 46 | id = self.ids[index] 47 | id_num = int(id.split('_')[-1]) 48 | while True: 49 | delta = random.choice([x for x in range(-self.opt.bound, self.opt.bound+1) if x != 0] ) 50 | id_target = id.split('_')[0] +'_' + str(id_num + delta).zfill(len(id.split('_')[-1])) 51 | if id_target in self.pose_dict.keys(): break 52 | 53 | B = self.load_image(id) / 255. * 2 - 1 54 | B = torch.from_numpy(B.astype(np.float32)).permute((2,0,1)) 55 | A = self.load_image(id_target) / 255. * 2 - 1 56 | A = torch.from_numpy(A.astype(np.float32)).permute((2,0,1)) 57 | 58 | poseB = self.pose_dict[id] 59 | poseA = self.pose_dict[id_target] 60 | TB = poseB[3:].reshape(3, 1) 61 | RB = ROT.from_euler('xyz',poseB[0:3]).as_dcm() 62 | TA = poseA[3:].reshape(3, 1) 63 | RA = ROT.from_euler('xyz',poseA[0:3]).as_dcm() 64 | T = RA.T.dot(TB-TA)/50. 65 | 66 | mat = np.block( 67 | [ [RA.T@RB, T], 68 | [np.zeros((1,3)), 1] ] ) 69 | 70 | 71 | RT = mat.astype(np.float32) 72 | RTinv = np.linalg.inv(mat).astype(np.float32) 73 | identity = torch.eye(4) 74 | 75 | K = np.array( 76 | [718.9 / 256., 0., 128 / 256., 0, \ 77 | 0., 718.9 / 256., 128 / 256., 0, \ 78 | 0., 0., 1., 0., \ 79 | 0., 0., 0., 1.]).reshape((4, 4)).astype(np.float32) 80 | 81 | offset = np.array( 82 | [[2, 0, -1, 0], [0, -2, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]], # Flip ys to match habitat 83 | dtype=np.float32, 84 | ) # Make z negative to match habitat (which assumes a negative z) 85 | 86 | K = np.matmul(offset, K) 87 | 88 | Kinv = np.linalg.inv(K).astype(np.float32) 89 | 90 | return {'images' : [A, B], 'cameras' : [{'Pinv' : identity, 'P' : identity, 'K' : K, 'Kinv' : Kinv}, 91 | {'Pinv' : RTinv, 'P' : RT, 'K' : K, 'Kinv' : Kinv}] 92 | } 93 | 94 | def load_image(self, id): 95 | image_path = os.path.join(self.dataroot, 'images', id + '.png') 96 | image = np.asarray(Image.open(image_path).convert('RGB')) 97 | return image 98 | 99 | def __len__(self): 100 | return self.dataset_size * 20 101 | 102 | def toval(self, epoch): 103 | pass 104 | 105 | def totrain(self, epoch): 106 | pass -------------------------------------------------------------------------------- /data/realestate10k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import numpy as np 4 | import torch.utils.data as data 5 | from PIL import Image 6 | from torchvision.transforms import Compose, Normalize, Resize, ToTensor 7 | 8 | from utils.geometry import get_deltas 9 | 10 | 11 | class RealEstate10K(data.Dataset): 12 | """ Dataset for loading the RealEstate10K. 
In this case, images are randomly 13 | chosen within a video subject to certain constraints: e.g. they should 14 | be within a number of frames but the angle and translation should 15 | vary as much as possible. 16 | """ 17 | 18 | def __init__( 19 | self, dataset, opts=None, num_views=2, seed=0, vectorize=False 20 | ): 21 | # Now go through the dataset 22 | 23 | self.imageset = np.loadtxt( 24 | opts.train_data_path + "/frames/%s/video_loc.txt" % "train", 25 | dtype=np.str, 26 | ) 27 | 28 | if dataset == "train": 29 | self.imageset = self.imageset[0 : int(0.8 * self.imageset.shape[0])] 30 | else: 31 | self.imageset = self.imageset[int(0.8 * self.imageset.shape[0]) :] 32 | 33 | self.rng = np.random.RandomState(seed) 34 | self.base_file = opts.train_data_path 35 | 36 | self.num_views = num_views 37 | 38 | self.input_transform = Compose( 39 | [ 40 | Resize((opts.W, opts.W)), 41 | ToTensor(), 42 | Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 43 | ] 44 | ) 45 | 46 | self.offset = np.array( 47 | [[2, 0, -1], [0, -2, 1], [0, 0, -1]], # Flip ys to match habitat 48 | dtype=np.float32, 49 | ) # Make z negative to match habitat (which assumes a negative z) 50 | 51 | self.dataset = "train" 52 | 53 | self.K = np.array( 54 | [ 55 | [1.0, 0.0, 0.0, 0.0], 56 | [0, 1.0, 0.0, 0.0], 57 | [0.0, 0.0, 1.0, 0.0], 58 | [0.0, 0.0, 0.0, 1.0], 59 | ], 60 | dtype=np.float32, 61 | ) 62 | 63 | self.invK = np.linalg.inv(self.K) 64 | 65 | self.ANGLE_THRESH = 5 66 | self.TRANS_THRESH = 0.15 67 | 68 | def __len__(self): 69 | return 2 ** 31 70 | 71 | def __getitem_simple__(self, index): 72 | index = self.rng.randint(self.imageset.shape[0]) 73 | # index = index % self.imageset.shape[0] 74 | # Load text file containing frame information 75 | frames = np.loadtxt( 76 | self.base_file 77 | + "/frames/%s/%s.txt" % (self.dataset, self.imageset[index]) 78 | ) 79 | 80 | image_index = self.rng.choice(frames.shape[0], size=(1,))[0] 81 | 82 | rgbs = [] 83 | cameras = [] 84 | for i in range(0, self.num_views): 85 | t_index = max( 86 | min( 87 | image_index + self.rng.randint(16) - 8, frames.shape[0] - 1 88 | ), 89 | 0, 90 | ) 91 | 92 | image = Image.open( 93 | self.base_file 94 | + "/frames/%s/%s/" % (self.dataset, self.imageset[index]) 95 | + str(int(frames[t_index, 0])) 96 | + ".png" 97 | ) 98 | rgbs += [self.input_transform(image)] 99 | 100 | intrinsics = frames[t_index, 1:7] 101 | extrinsics = frames[t_index, 7:] 102 | 103 | origK = np.array( 104 | [ 105 | [intrinsics[0], 0, intrinsics[2]], 106 | [0, intrinsics[1], intrinsics[3]], 107 | [0, 0, 1], 108 | ], 109 | dtype=np.float32, 110 | ) 111 | K = np.matmul(self.offset, origK) 112 | 113 | origP = extrinsics.reshape(3, 4) 114 | P = np.matmul(K, origP) # Merge these together to match habitat 115 | P = np.vstack((P, np.zeros((1, 4), dtype=np.float32))).astype( 116 | np.float32 117 | ) 118 | P[3, 3] = 1 119 | 120 | Pinv = np.linalg.inv(P) 121 | 122 | cameras += [ 123 | { 124 | "P": P, 125 | "OrigP": origP, 126 | "Pinv": Pinv, 127 | "K": self.K, 128 | "Kinv": self.invK, 129 | } 130 | ] 131 | 132 | return {"images": rgbs, "cameras": cameras} 133 | 134 | def __getitem__(self, index): 135 | index = self.rng.randint(self.imageset.shape[0]) 136 | # index = index % self.imageset.shape[0] 137 | # Load text file containing frame information 138 | frames = np.loadtxt( 139 | self.base_file 140 | + "/frames/%s/%s.txt" % (self.dataset, self.imageset[index]) 141 | ) 142 | 143 | image_index = self.rng.choice(frames.shape[0], size=(1,))[0] 144 | 145 | # Chose 15 images within 30 frames of the iniital 
one 146 | image_indices = self.rng.randint(80, size=(30,)) - 40 + image_index 147 | image_indices = np.minimum( 148 | np.maximum(image_indices, 0), frames.shape[0] - 1 149 | ) 150 | 151 | # Look at the change in angle and choose a hard one 152 | angles = [] 153 | translations = [] 154 | for viewpoint in range(0, image_indices.shape[0]): 155 | orig_viewpoint = frames[image_index, 7:].reshape(3, 4) 156 | new_viewpoint = frames[image_indices[viewpoint], 7:].reshape(3, 4) 157 | dang, dtrans = get_deltas(orig_viewpoint, new_viewpoint) 158 | 159 | angles += [dang] 160 | translations += [dtrans] 161 | 162 | angles = np.array(angles) 163 | translations = np.array(translations) 164 | 165 | mask = image_indices[ 166 | (angles > self.ANGLE_THRESH) | (translations > self.TRANS_THRESH) 167 | ] 168 | 169 | rgbs = [] 170 | cameras = [] 171 | for i in range(0, self.num_views): 172 | if i == 0: 173 | t_index = image_index 174 | elif mask.shape[0] > 5: 175 | # Choose a harder angle change 176 | t_index = mask[self.rng.randint(mask.shape[0])] 177 | else: 178 | t_index = image_indices[ 179 | self.rng.randint(image_indices.shape[0]) 180 | ] 181 | 182 | image = Image.open( 183 | self.base_file 184 | + "/frames/%s/%s/" % (self.dataset, self.imageset[index]) 185 | + str(int(frames[t_index, 0])) 186 | + ".png" 187 | ) 188 | rgbs += [self.input_transform(image)] 189 | 190 | intrinsics = frames[t_index, 1:7] 191 | extrinsics = frames[t_index, 7:] 192 | 193 | origK = np.array( 194 | [ 195 | [intrinsics[0], 0, intrinsics[2]], 196 | [0, intrinsics[1], intrinsics[3]], 197 | [0, 0, 1], 198 | ], 199 | dtype=np.float32, 200 | ) 201 | K = np.matmul(self.offset, origK) 202 | 203 | origP = extrinsics.reshape(3, 4) 204 | P = np.matmul(K, origP) # Merge these together to match habitat 205 | P = np.vstack((P, np.zeros((1, 4), dtype=np.float32))).astype( 206 | np.float32 207 | ) 208 | P[3, 3] = 1 209 | 210 | Pinv = np.linalg.inv(P) 211 | 212 | cameras += [ 213 | { 214 | "P": P, 215 | "Pinv": Pinv, 216 | "OrigP": origP, 217 | "K": self.K, 218 | "Kinv": self.invK, 219 | } 220 | ] 221 | 222 | return {"images": rgbs, "cameras": cameras} 223 | 224 | def totrain(self, epoch): 225 | self.imageset = np.loadtxt( 226 | self.base_file + "/frames/%s/video_loc.txt" % "train", dtype=np.str 227 | ) 228 | self.imageset = self.imageset[0 : int(0.8 * self.imageset.shape[0])] 229 | self.rng = np.random.RandomState(epoch) 230 | 231 | def toval(self, epoch): 232 | self.imageset = np.loadtxt( 233 | self.base_file + "/frames/%s/video_loc.txt" % "train", dtype=np.str 234 | ) 235 | self.imageset = self.imageset[int(0.8 * self.imageset.shape[0]) :] 236 | self.rng = np.random.RandomState(epoch) 237 | 238 | 239 | class RealEstate10KConsecutive(data.Dataset): 240 | """ Dataset for loading the RealEstate10K. In this case, images are 241 | consecutive within a video, as opposed to randomly chosen. 
242 | """ 243 | 244 | def __init__( 245 | self, dataset, opts=None, num_views=2, seed=0, vectorize=False 246 | ): 247 | # Now go through the dataset 248 | 249 | self.imageset = np.loadtxt( 250 | opts.train_data_path + "/frames/%s/video_loc.txt" % "train", 251 | dtype=np.str, 252 | ) 253 | 254 | if dataset == "train": 255 | self.imageset = self.imageset[0 : int(0.8 * self.imageset.shape[0])] 256 | else: 257 | self.imageset = self.imageset[int(0.8 * self.imageset.shape[0]) :] 258 | 259 | self.rng = np.random.RandomState(seed) 260 | self.base_file = opts.train_data_path 261 | 262 | self.num_views = num_views 263 | 264 | self.input_transform = Compose( 265 | [ 266 | Resize((opts.W, opts.W)), 267 | ToTensor(), 268 | Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 269 | ] 270 | ) 271 | 272 | self.offset = np.array( 273 | [[2, 0, -1], [0, -2, 1], [0, 0, -1]], # Flip ys to match habitat 274 | dtype=np.float32, 275 | ) # Make z negative to match habitat (which assumes a negative z) 276 | 277 | self.dataset = "train" 278 | 279 | self.K = np.array( 280 | [ 281 | [1.0, 0.0, 0.0, 0.0], 282 | [0, 1.0, 0.0, 0.0], 283 | [0.0, 0.0, 1.0, 0.0], 284 | [0.0, 0.0, 0.0, 1.0], 285 | ], 286 | dtype=np.float32, 287 | ) 288 | 289 | self.invK = np.linalg.inv(self.K) 290 | 291 | self.ANGLE_THRESH = 5 292 | self.TRANS_THRESH = 0.15 293 | 294 | def __len__(self): 295 | return 2 ** 31 296 | 297 | def __getitem__(self, index): 298 | index = self.rng.randint(self.imageset.shape[0]) 299 | # Load text file containing frame information 300 | frames = np.loadtxt( 301 | self.base_file 302 | + "/frames/%s/%s.txt" % (self.dataset, self.imageset[index]) 303 | ) 304 | 305 | image_index = self.rng.choice( 306 | max(1, frames.shape[0] - self.num_views), size=(1,) 307 | )[0] 308 | 309 | image_indices = np.linspace( 310 | image_index, image_index + self.num_views - 1, self.num_views 311 | ).astype(np.int32) 312 | image_indices = np.minimum( 313 | np.maximum(image_indices, 0), frames.shape[0] - 1 314 | ) 315 | 316 | rgbs = [] 317 | cameras = [] 318 | for i in range(0, self.num_views): 319 | t_index = image_indices[i] 320 | image = Image.open( 321 | self.base_file 322 | + "/frames/%s/%s/" % (self.dataset, self.imageset[index]) 323 | + str(int(frames[t_index, 0])) 324 | + ".png" 325 | ) 326 | rgbs += [self.input_transform(image)] 327 | 328 | intrinsics = frames[t_index, 1:7] 329 | extrinsics = frames[t_index, 7:] 330 | 331 | origK = np.array( 332 | [ 333 | [intrinsics[0], 0, intrinsics[2]], 334 | [0, intrinsics[1], intrinsics[3]], 335 | [0, 0, 1], 336 | ], 337 | dtype=np.float32, 338 | ) 339 | K = np.matmul(self.offset, origK) 340 | 341 | origP = extrinsics.reshape(3, 4) 342 | np.set_printoptions(precision=3, suppress=True) 343 | P = np.matmul(K, origP) # Merge these together to match habitat 344 | P = np.vstack((P, np.zeros((1, 4), dtype=np.float32))).astype( 345 | np.float32 346 | ) 347 | P[3, 3] = 1 348 | 349 | Pinv = np.linalg.inv(P) 350 | 351 | cameras += [ 352 | { 353 | "P": P, 354 | "Pinv": Pinv, 355 | "OrigP": origP, 356 | "K": self.K, 357 | "Kinv": self.invK, 358 | } 359 | ] 360 | 361 | return {"images": rgbs, "cameras": cameras} 362 | -------------------------------------------------------------------------------- /data/scene_episodes/mp3d_test/dataset_one_ep_per_scene.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/data/scene_episodes/mp3d_test/dataset_one_ep_per_scene.json.gz 
-------------------------------------------------------------------------------- /data/scene_episodes/mp3d_train/dataset_one_ep_per_scene.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/data/scene_episodes/mp3d_train/dataset_one_ep_per_scene.json.gz -------------------------------------------------------------------------------- /data/scene_episodes/mp3d_val/dataset_one_ep_per_scene.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/data/scene_episodes/mp3d_val/dataset_one_ep_per_scene.json.gz -------------------------------------------------------------------------------- /data/scene_episodes/replica_test/dataset_one_ep_per_scene.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/data/scene_episodes/replica_test/dataset_one_ep_per_scene.json.gz -------------------------------------------------------------------------------- /data/scene_episodes/replica_train/dataset_one_ep_per_scene.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/data/scene_episodes/replica_train/dataset_one_ep_per_scene.json.gz -------------------------------------------------------------------------------- /demos/im.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/synsin/501ec49b11030a41207e7b923b949fab8fd6e1b5/demos/im.jpg -------------------------------------------------------------------------------- /download_models.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | # BASH file for downloading pretrained models. 
4 | 5 | # Download readme 6 | mkdir ./modelcheckpoints/ 7 | cd modelcheckpoints/ 8 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/readme.txt 9 | 10 | # Make dataset directories 11 | mkdir kitti 12 | mkdir realestate 13 | mkdir mp3d 14 | 15 | # Download files 16 | cd ./kitti/ 17 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/kitti/zbufferpts_invdepth.pth 18 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/kitti/synsin_invdepth.pth 19 | 20 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/kitti/viewappearance.pth 21 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/kitti/tatarchenko.pth 22 | 23 | cd ../realestate/ 24 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/realestate/synsin.pth 25 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/realestate/zbufferpts.pth 26 | 27 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/realestate/viewappearance.pth 28 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/realestate/tatarchenko.pth 29 | 30 | cd ../mp3d/ 31 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/mp3d/synsin.pth 32 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/mp3d/viewappearance.pth 33 | wget https://dl.fbaipublicfiles.com/synsin/checkpoints/mp3d/tatarchenko.pth 34 | 35 | cd ../ 36 | -------------------------------------------------------------------------------- /evaluation/RESULTS.md: -------------------------------------------------------------------------------- 1 |

# Results for released models

2 | 3 | Note that this repository is a large refactoring of the original code to allow for public release 4 | and to integrate with [pytorch3d](https://github.com/facebookresearch/pytorch3d). 5 | Hence the models/datasets are not necessarily the same as those in the paper, as we cannot release 6 | the saved test images we used. 7 | For fair comparison and reproducibility, we recommend comparing against the numbers 8 | and models in this repo. 9 | 10 | These models have been trained with the same learning rate (0.0001) and number of epochs. 11 | If you want to train the baselines differently (for example, in the paper we found that the voxel-based methods were highly sensitive to learning rates, so we used a model trained with a lower learning rate), look at the options in `./../submit.sh`. 12 | 13 | You can use these numbers to: 14 | 1. Compare against your models 15 | 2. Verify that your setup is indeed correct 16 |
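For reference, the PSNR column in the tables below follows the standard definition; the sketch is illustrative only and assumes images scaled to [0, 1] — the reported numbers are produced by the `psnr`, `ssim_metric`, and `perceptual_sim` functions in `evaluation/metrics.py`, which may normalise and mask images differently.

```python
import torch

def psnr_reference(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Standard PSNR for images in [0, 1]: 10 * log10(MAX^2 / MSE), with MAX = 1.

    Reference sketch only; see evaluation/metrics.py for the metrics actually
    used to produce the tables in this file.
    """
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(1.0 / mse)
```

Higher PSNR and SSIM are better, while lower Perc SIM (a perceptual distance) is better.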
## Results on RealEstate
18 | 19 | 20 | | | PSNR | SSIM | Perc SIM | 21 | | --------------------------|:-------------:| -----:|:-------------:| 22 | | SynSin | 22.31 | 0.74 | 1.18 | 23 | | SynSin+ | 22.83 | 0.75 | 1.13 | 24 | | ViewAppearance [1] | 17.05 | 0.56 | 2.19 | 25 | | Tatarchenko [2] | 11.35 | 0.33 | 3.95 | 26 | | StereoMag [4] | 25.34 | 0.82 | 1.19 | 27 | | 3DPaper [5] | 21.88 | 0.66 | 1.52 | 28 | 29 |
## Results on Matterport
30 | 31 | | | PSNR | SSIM | Perc SIM | 32 | | -------------------------|:-------------:| -----:|:-------------:| 33 | | SynSin | 20.91 | 0.72 | 1.68 | 34 | | ViewAppearance [1] | 15.87 | 0.53 | 2.99 | 35 | | Tatarchenko [2] | 14.79 | 0.57 | 3.73 | 36 | 37 |
## Results on Replica
38 | 39 | | | PSNR | SSIM | Perc SIM | 40 | | -------------------------|:-------------:| -----:|:-------------:| 41 | | SynSin | 21.94 | 0.81 | 1.55 | 42 | | ViewAppearance [1] | 17.42 | 0.66 | 2.29 | 43 | | Tatarchenko [2] | 14.36 | 0.68 | 3.36 | 44 | 45 |
## Results on KITTI
46 | 47 | | | PSNR | SSIM | Perc SIM | 48 | | --------------------------|:-------------:| -----:|:-------------:| 49 | | SynSin* | 16.70 | 0.52 | 2.07 | 50 | | SynSin+* | 16.73 | 0.52 | 2.05 | 51 | | ViewAppearance [1] | 14.21 | 0.43 | 2.51 | 52 | | Tatarchenko [2] | 10.31 | 0.30 | 3.48 | 53 | | ContView [3] | 16.90 | 0.54 | 2.21 | 54 | 55 | 56 |
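The metric columns above are computed with the helpers in `evaluation/metrics.py` and aggregated by `evaluation/evaluate_perceptualsim.py` over folders of saved results. As a minimal sketch of scoring a single predicted/target pair — the file names and the use of a GPU below are illustrative assumptions, not part of the released scripts:

```python
import torchvision.transforms as transforms
from PIL import Image

from evaluation.metrics import perceptual_sim, psnr, ssim_metric
from models.networks.pretrained_networks import PNet

to_tensor = transforms.ToTensor()

# Hypothetical file names; the evaluation scripts read these from the result folders.
pred = to_tensor(Image.open("pred.png").convert("RGB")).unsqueeze(0).cuda()
target = to_tensor(Image.open("target.png").convert("RGB")).unsqueeze(0).cuda()

vgg16 = PNet().cuda().eval()  # feature network used for the Perc SIM column

print("PSNR:     %0.2f" % psnr(pred, target).item())
print("SSIM:     %0.2f" % ssim_metric(pred, target).item())
print("Perc SIM: %0.2f" % perceptual_sim(pred, target, vgg16).item())
```

Higher is better for PSNR and SSIM; lower is better for Perc SIM.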
## References
57 | The implemented models are based on: 58 | 59 | *: Using inverse depth as opposed to a uniform sampling. This is better if there is a long tail distribution of the true depths (as in the KITTI case). 60 | 61 | +: Leaving the model to run for longer than the paper for a small boost in results. 62 | 63 | [1] Zhou, Tinghui, et al. "View synthesis by appearance flow." ECCV, 2016. 64 | 65 | [2] Dosovitskiy, Alexey, et al. "Learning to generate chairs with convolutional neural networks." CVPR, 2015. 66 | 67 | [3] Chen, Xu, et al. "Monocular Neural Image Based Rendering with Continuous View Control." ICCV, 2019. 68 | 69 | [4] Zhou, Tinghui, et al. "Stereo Magnification: Learning View Synthesis using Multiplane Images." SIGGRAPH, 2018. 70 | 71 | [5] Code based on work by folks at Facebook. The code used is an early version of the [3D Photos work](https://ai.facebook.com/blog/-powered-by-ai-turning-any-2d-photo-into-3d-using-convolutional-neural-nets/). 72 | -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import json 4 | import os 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchvision 15 | from torch.multiprocessing import set_start_method 16 | from torch.utils.data import DataLoader 17 | 18 | from evaluation.metrics import perceptual_sim, psnr, ssim_metric 19 | 20 | from models.base_model import BaseModel 21 | from models.depth_model import Model 22 | from models.networks.pretrained_networks import PNet 23 | from models.networks.sync_batchnorm import convert_model 24 | 25 | from options.options import get_dataset, get_model 26 | from options.test_options import ArgumentParser 27 | from utils.geometry import get_deltas 28 | 29 | torch.manual_seed(0) 30 | torch.backends.cudnn.benchmark = True 31 | 32 | 33 | def worker_init_fn(worker_id): 34 | torch.manual_seed(worker_id) 35 | 36 | 37 | def psnr_mask(pred_imgs, key, invis=False): 38 | mask = pred_imgs["OutputImg"] == pred_imgs["SampledImg"] 39 | mask = mask.float().min(dim=1, keepdim=True)[0] 40 | 41 | if invis: 42 | mask = 1 - mask 43 | 44 | return psnr(pred_imgs["OutputImg"], pred_imgs[key], mask) 45 | 46 | 47 | def ssim_mask(pred_imgs, key, invis=False): 48 | mask = pred_imgs["OutputImg"] == pred_imgs["SampledImg"] 49 | mask = mask.float().min(dim=1, keepdim=True)[0] 50 | 51 | if invis: 52 | mask = 1 - mask 53 | 54 | return ssim_metric(pred_imgs["OutputImg"], pred_imgs[key], mask) 55 | 56 | 57 | def perceptual_sim_mask(pred_imgs, key, vgg16, invis=False): 58 | mask = pred_imgs["OutputImg"] == pred_imgs["SampledImg"] 59 | mask = mask.float().min(dim=1, keepdim=True)[0] 60 | 61 | if invis: 62 | mask = 1 - mask 63 | 64 | return perceptual_sim( 65 | pred_imgs["OutputImg"] * mask, pred_imgs[key] * mask, vgg16 66 | ) 67 | 68 | def check_initial_batch(batch, dataset): 69 | try: 70 | if dataset == 'replica': 71 | np.testing.assert_allclose(batch['cameras'][0]['P'].data.numpy().ravel(), 72 | np.loadtxt('./data/files/eval_cached_cameras_replica.txt')) 73 | else: 74 | np.testing.assert_allclose(batch['cameras'][0]['P'].data.numpy().ravel(), 75 | np.loadtxt('./data/files/eval_cached_cameras_mp3d.txt')) 76 | except Exception as e: 77 | raise Exception("\n \nThere is an error with your setup or options. 
\ 78 | \n\nYour results will NOT be comparable with results in the paper or online.") 79 | 80 | METRICS = { 81 | "PSNR": lambda pred_imgs, key: psnr( 82 | pred_imgs["OutputImg"], pred_imgs[key] 83 | ).clamp(max=100), 84 | "PSNR_invis": lambda pred_imgs, key: psnr_mask( 85 | pred_imgs, key, True 86 | ).clamp(max=100), 87 | "PSNR_vis": lambda pred_imgs, key: psnr_mask( 88 | pred_imgs, key, False 89 | ).clamp(max=100), 90 | "SSIM": lambda pred_imgs, key: ssim_metric( 91 | pred_imgs["OutputImg"], pred_imgs[key] 92 | ), 93 | "SSIM_invis": lambda pred_imgs, key: ssim_mask(pred_imgs, key, True), 94 | "SSIM_vis": lambda pred_imgs, key: ssim_mask(pred_imgs, key, False), 95 | "PercSim": lambda pred_imgs, key: perceptual_sim( 96 | pred_imgs["OutputImg"], pred_imgs[key], vgg16 97 | ), 98 | "PercSim_invis": lambda pred_imgs, key: perceptual_sim_mask( 99 | pred_imgs, key, vgg16, True 100 | ), 101 | "PercSim_vis": lambda pred_imgs, key: perceptual_sim_mask( 102 | pred_imgs, key, vgg16, False 103 | ), 104 | } 105 | 106 | if __name__ == "__main__": 107 | print("STARTING MAIN METHOD...", flush=True) 108 | try: 109 | set_start_method("spawn", force=True) 110 | except RuntimeError: 111 | pass 112 | 113 | test_ops, _ = ArgumentParser().parse() 114 | 115 | # Load model to be tested 116 | MODEL_PATH = test_ops.old_model 117 | BATCH_SIZE = test_ops.batch_size 118 | 119 | opts = torch.load(MODEL_PATH)["opts"] 120 | print("Model is: ", MODEL_PATH) 121 | 122 | opts.image_type = test_ops.image_type 123 | opts.only_high_res = False 124 | 125 | opts.train_depth = False 126 | 127 | if test_ops.dataset: 128 | opts.dataset = test_ops.dataset 129 | 130 | Dataset = get_dataset(opts) 131 | model = get_model(opts) 132 | 133 | # Update parameters 134 | opts.render_ids = test_ops.render_ids 135 | opts.gpu_ids = test_ops.gpu_ids 136 | 137 | opts.images_before_reset = test_ops.images_before_reset 138 | 139 | torch_devices = [int(gpu_id.strip()) for gpu_id in opts.gpu_ids.split(",")] 140 | device = "cuda:" + str(torch_devices[0]) 141 | 142 | if "sync" in opts.norm_G: 143 | model = convert_model(model) 144 | model = nn.DataParallel(model, torch_devices).to(device) 145 | else: 146 | model = nn.DataParallel(model, torch_devices).to(device) 147 | 148 | # Load the original model to be tested 149 | model_to_test = BaseModel(model, opts) 150 | model_to_test.eval() 151 | model_to_test.load_state_dict(torch.load(MODEL_PATH)["state_dict"]) 152 | 153 | # Load VGG16 for feature similarity 154 | vgg16 = PNet().to(device) 155 | vgg16.eval() 156 | 157 | # Create dummy depth model for doing sampling images 158 | sampled_model = Model(opts).to(device) 159 | sampled_model.eval() 160 | 161 | print("Loaded models...") 162 | 163 | model_to_test.eval() 164 | 165 | data = Dataset("test", opts, vectorize=False) 166 | dataloader = DataLoader( 167 | data, 168 | shuffle=False, 169 | drop_last=False, 170 | batch_size=BATCH_SIZE, 171 | num_workers=test_ops.num_workers, 172 | pin_memory=True, 173 | worker_init_fn=worker_init_fn, 174 | ) 175 | iter_data_loader = iter(dataloader) 176 | next(iter_data_loader) 177 | 178 | N = test_ops.images_before_reset * 18 * BATCH_SIZE 179 | 180 | if not os.path.exists(test_ops.result_folder): 181 | os.makedirs(test_ops.result_folder) 182 | 183 | file_to_write = open( 184 | test_ops.result_folder + "/%s_results.txt" % test_ops.short_name, "w" 185 | ) 186 | 187 | # Calculate the metrics and store for each index 188 | # and change in angle and translation 189 | results_all = {} 190 | for i in tqdm(range(0, N // BATCH_SIZE)): 191 
| with torch.no_grad(): 192 | _, pred_imgs, batch = model_to_test( 193 | iter_data_loader, isval=True, return_batch=True 194 | ) 195 | 196 | _, new_imgs = sampled_model(batch) 197 | pred_imgs["SampledImg"] = new_imgs["SampledImg"] * 0.5 + 0.5 198 | 199 | # Check to make sure options were set right and this matches the setup 200 | # we used, so that numbers are comparable. 201 | if i == 0: 202 | check_initial_batch(batch, test_ops.dataset) 203 | 204 | # Obtain the angles and translation 205 | for batch_id in range(0, batch["cameras"][0]["P"].size(0)): 206 | dangles, dtrans = get_deltas( 207 | batch["cameras"][0]["P"][batch_id, 0:3, :], 208 | batch["cameras"][-1]["P"][batch_id, 0:3, :], 209 | ) 210 | 211 | for metric, func in METRICS.items(): 212 | key = "InputImg" if test_ops.test_input_image else "PredImg" 213 | t_results = func(pred_imgs, key) 214 | 215 | if not (metric in results_all.keys()): 216 | results_all[metric] = t_results.sum() 217 | else: 218 | results_all[metric] += t_results.sum() 219 | 220 | if i < 10: 221 | if not os.path.exists(test_ops.result_folder + "/%s" % metric): 222 | os.makedirs(test_ops.result_folder + "/%s/" % metric) 223 | 224 | torchvision.utils.save_image( 225 | pred_imgs["OutputImg"], 226 | test_ops.result_folder 227 | + "/%s/%03d_output_%s.png" 228 | % (metric, i, test_ops.short_name), 229 | pad_value=1, 230 | ) 231 | 232 | torchvision.utils.save_image( 233 | pred_imgs["InputImg"], 234 | test_ops.result_folder 235 | + "/%s/%03d_input_%s.png" 236 | % (metric, i, test_ops.short_name), 237 | pad_value=1, 238 | ) 239 | 240 | if "SampledImg" in pred_imgs.keys(): 241 | torchvision.utils.save_image( 242 | pred_imgs["SampledImg"], 243 | test_ops.result_folder 244 | + "/%s/%03d_sampled_%s.png" 245 | % (metric, i, test_ops.short_name), 246 | pad_value=1, 247 | ) 248 | torchvision.utils.save_image( 249 | (pred_imgs["SampledImg"] == pred_imgs["OutputImg"]).float(), 250 | test_ops.result_folder 251 | + "/%s/%03d_sampledmask_%s.png" 252 | % (metric, i, test_ops.short_name), 253 | pad_value=1, 254 | ) 255 | 256 | if "PredDepth" in pred_imgs.keys(): 257 | torchvision.utils.save_image( 258 | pred_imgs["PredDepth"], 259 | test_ops.result_folder 260 | + "/%s/%03d_depth_%s.png" 261 | % (metric, i, test_ops.short_name), 262 | pad_value=1, 263 | normalize=True, 264 | ) 265 | 266 | predimg = ( 267 | torchvision.utils.make_grid( 268 | pred_imgs["PredImg"], pad_value=1 269 | ) 270 | .clamp(min=0.00001, max=0.9999) 271 | .permute(1, 2, 0) 272 | .cpu() 273 | .numpy() 274 | ) 275 | predimg = cv2.cvtColor( 276 | (predimg * 255).astype(np.uint8), cv2.COLOR_RGB2BGR 277 | ) 278 | for b in range(0, t_results.size(0)): 279 | cv2.putText( 280 | predimg, 281 | "%0.4f" % t_results[b], 282 | org=(258 * b + 10, 250), 283 | fontFace=2, 284 | fontScale=1, 285 | color=(255, 255, 255), 286 | bottomLeftOrigin=False, 287 | ) 288 | cv2.imwrite( 289 | test_ops.result_folder 290 | + "/%s/%03d_p_%s.png" % (metric, i, test_ops.short_name), 291 | predimg, 292 | ) 293 | 294 | for metric, result in results_all.items(): 295 | file_to_write.write( 296 | "%s \t %0.5f \n" % (metric, result / float(BATCH_SIZE * (i + 1))) 297 | ) 298 | 299 | file_to_write.close() 300 | -------------------------------------------------------------------------------- /evaluation/eval_kitti.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
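# Evaluates a trained checkpoint (`test_ops.old_model`) on the KITTI image pairs listed
# in ./data/files/kitti.txt. Frames are read from the directory given by the KITTI
# environment variable, the hard-coded intrinsics are normalised for 256x256 images,
# and for each pair the prediction (im_res.png), target (im_B.png) and input (im_A.png)
# are saved to a per-example folder under `test_ops.result_folder`.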
2 | 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data as data 9 | import torchvision 10 | from PIL import Image 11 | from torch.utils.data import DataLoader 12 | 13 | from models.base_model import BaseModel 14 | from models.networks.sync_batchnorm import convert_model 15 | from options.options import get_dataset, get_model 16 | from options.test_options import ArgumentParser 17 | 18 | torch.backends.cudnn.benchmark = True 19 | torch.manual_seed(0) 20 | 21 | class Dataset(data.Dataset): 22 | def __init__(self): 23 | self.path = os.environ["KITTI"] 24 | 25 | self.files = np.loadtxt('./data/files/kitti.txt', dtype=np.str) 26 | 27 | self.K = np.array( 28 | [718.9 / 256., 0., 128 / 256., 0, \ 29 | 0., 718.9 / 256., 128 / 256., 0, \ 30 | 0., 0., 1., 0., \ 31 | 0., 0., 0., 1.]).reshape((4, 4)).astype(np.float32) 32 | self.invK = np.linalg.inv(self.K) 33 | 34 | def __len__(self): 35 | return len(self.files) 36 | 37 | def load_image(self, image_path): 38 | image = np.asarray(Image.open(image_path).convert('RGB')) 39 | image = image / 255. * 2. - 1. 40 | image = torch.from_numpy(image.astype(np.float32)).permute((2,0,1)) 41 | return image 42 | 43 | def __getitem__(self, index): 44 | imgA = self.path + self.files[index,1] + '.png' 45 | imgB = self.path + self.files[index,0] + '.png' 46 | RT = self.files[index,2:].astype(np.float32) 47 | 48 | B = self.load_image(imgB) 49 | A = self.load_image(imgA) 50 | 51 | RT = RT.astype(np.float32).reshape(4,4) 52 | RTinv = np.linalg.inv(RT).astype(np.float32) 53 | 54 | identity = torch.eye(4) 55 | 56 | offset = np.array( 57 | [[2, 0, -1, 0], [0, -2, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]], # Flip ys to match habitat 58 | dtype=np.float32, 59 | ) # Make z negative to match habitat (which assumes a negative z) 60 | 61 | K = np.matmul(offset, self.K) 62 | 63 | Kinv = np.linalg.inv(K).astype(np.float32) 64 | 65 | return {'images' : [A, B], 'cameras' : [{'Pinv' : identity, 'P' : identity, 'K' : K, 'Kinv' : Kinv}, 66 | {'Pinv' : RTinv, 'P' : RT, 'K' : K, 'Kinv' : Kinv}] 67 | } 68 | 69 | 70 | if __name__ == "__main__": 71 | test_ops, _ = ArgumentParser().parse() 72 | 73 | # Load model to be tested 74 | MODEL_PATH = test_ops.old_model 75 | BATCH_SIZE = test_ops.batch_size 76 | 77 | opts = torch.load(MODEL_PATH)["opts"] 78 | 79 | model = get_model(opts) 80 | 81 | opts.render_ids = test_ops.render_ids 82 | opts.gpu_ids = test_ops.gpu_ids 83 | 84 | torch_devices = [int(gpu_id.strip()) for gpu_id in opts.gpu_ids.split(",")] 85 | print(torch_devices) 86 | device = "cuda:" + str(torch_devices[0]) 87 | 88 | if "sync" in opts.norm_G: 89 | model = convert_model(model) 90 | model = nn.DataParallel(model, torch_devices).to(device) 91 | else: 92 | model = nn.DataParallel(model, torch_devices).to(device) 93 | 94 | # Load the original model to be tested 95 | model_to_test = BaseModel(model, opts) 96 | model_to_test.eval() 97 | 98 | # Allow for different image sizes 99 | state_dict = model_to_test.state_dict() 100 | pretrained_dict = { 101 | k: v 102 | for k, v in torch.load(MODEL_PATH)["state_dict"].items() 103 | if not ("xyzs" in k) and not ("ones" in k) 104 | } 105 | state_dict.update(pretrained_dict) 106 | 107 | model_to_test.load_state_dict(state_dict) 108 | 109 | print(opts) 110 | # Update parameters 111 | opts.render_ids = test_ops.render_ids 112 | opts.gpu_ids = test_ops.gpu_ids 113 | 114 | 115 | print("Loaded models...") 116 | 117 | # Load the dataset which is the set of images that came 118 | # from running the 
baselines' result scripts 119 | data = Dataset() 120 | 121 | model_to_test.eval() 122 | 123 | # Iterate through the dataset, predicting new views 124 | data_loader = DataLoader(data, batch_size=1, shuffle=False) 125 | iter_data_loader = iter(data_loader) 126 | 127 | for i in range(0, len(data_loader)): 128 | print(i, len(data_loader), flush=True) 129 | _, pred_imgs, batch = model_to_test( 130 | iter_data_loader, isval=True, return_batch=True 131 | ) 132 | 133 | if not os.path.exists( 134 | test_ops.result_folder 135 | + "/%d/" % (i) 136 | ): 137 | os.makedirs( 138 | test_ops.result_folder 139 | + "/%d/" % (i) 140 | ) 141 | 142 | torchvision.utils.save_image( 143 | pred_imgs["PredImg"], 144 | test_ops.result_folder 145 | + "/%d/im_res.png" % (i), 146 | ) 147 | torchvision.utils.save_image( 148 | pred_imgs["OutputImg"], 149 | test_ops.result_folder 150 | + "/%d/im_B.png" % (i), 151 | ) 152 | torchvision.utils.save_image( 153 | pred_imgs["InputImg"], 154 | test_ops.result_folder 155 | + "/%d/im_A.png" % (i), 156 | ) 157 | 158 | -------------------------------------------------------------------------------- /evaluation/eval_realestate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import torchvision 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | 11 | from models.base_model import BaseModel 12 | from models.networks.sync_batchnorm import convert_model 13 | from options.options import get_dataset, get_model 14 | from options.test_options import ArgumentParser 15 | 16 | torch.backends.cudnn.benchmark = True 17 | torch.manual_seed(0) 18 | 19 | 20 | class Dataset(data.Dataset): 21 | def __init__(self, W=256): 22 | 23 | self.base_path = os.environ['REALESTATE'] 24 | self.files = np.loadtxt('./data/files/realestate.txt', dtype=np.str) 25 | 26 | self.offset = np.array( 27 | [[2, 0, -1], [0, -2, 1], [0, 0, -1]], dtype=np.float32 28 | ) 29 | 30 | self.K = np.array( 31 | [ 32 | [1.0, 0.0, 0.0, 0.0], 33 | [0, 1.0, 0.0, 0.0], 34 | [0.0, 0.0, 1.0, 0.0], 35 | [0.0, 0.0, 0.0, 1.0], 36 | ], 37 | dtype=np.float32, 38 | ) 39 | self.invK = np.linalg.inv(self.K) 40 | 41 | self.input_transform = torchvision.transforms.Compose( 42 | [ 43 | torchvision.transforms.Resize((W, W)), 44 | torchvision.transforms.ToTensor(), 45 | torchvision.transforms.Normalize( 46 | (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) 47 | ), 48 | ] 49 | ) 50 | 51 | self.W = W 52 | 53 | def __len__(self): 54 | return len(self.files) 55 | 56 | def __getitem__(self, index): 57 | 58 | # Then load the image and generate that 59 | file_name = self.files[index] 60 | 61 | src_image_name = ( 62 | self.base_path 63 | + '/%s/%s.png' % (file_name[0], file_name[1]) 64 | ) 65 | tgt_image_name = ( 66 | self.base_path 67 | + '/%s/%s.png' % (file_name[0], file_name[2]) 68 | ) 69 | 70 | intrinsics = file_name[3:7].astype(np.float32) / float(self.W) 71 | src_pose = file_name[7:19].astype(np.float32).reshape(3, 4) 72 | tgt_pose = file_name[19:].astype(np.float32).reshape(3, 4) 73 | 74 | src_image = self.input_transform(Image.open(src_image_name)) 75 | tgt_image = self.input_transform(Image.open(tgt_image_name)) 76 | 77 | poses = [src_pose, tgt_pose] 78 | cameras = [] 79 | 80 | for pose in poses: 81 | 82 | origK = np.array( 83 | [ 84 | [intrinsics[0], 0, intrinsics[2]], 85 | [0, intrinsics[1], intrinsics[3]], 86 | [0, 0, 1], 87 | ], 88 | dtype=np.float32, 89 | ) 90 | K = 
np.matmul(self.offset, origK) 91 | 92 | P = pose 93 | P = np.matmul(K, P) 94 | # Merge these together to match habitat 95 | P = np.vstack((P, np.zeros((1, 4)))).astype(np.float32) 96 | P[3, 3] = 1 97 | 98 | # Now artificially flip x/ys to match habitat 99 | Pinv = np.linalg.inv(P) 100 | 101 | cameras += [{"P": P, "Pinv": Pinv, "K": self.K, "Kinv": self.invK}] 102 | 103 | return {"images": [src_image, tgt_image], "cameras": cameras} 104 | 105 | 106 | if __name__ == "__main__": 107 | test_ops, _ = ArgumentParser().parse() 108 | 109 | # Load model to be tested 110 | MODEL_PATH = test_ops.old_model 111 | BATCH_SIZE = test_ops.batch_size 112 | 113 | opts = torch.load(MODEL_PATH)["opts"] 114 | opts.isTrain = True 115 | opts.only_high_res = False 116 | opts.lr_d = 0.001 117 | 118 | opts.train_depth = False 119 | print(opts) 120 | 121 | DatasetTrain = get_dataset(opts) 122 | model = get_model(opts) 123 | 124 | opts.render_ids = test_ops.render_ids 125 | opts.gpu_ids = test_ops.gpu_ids 126 | 127 | torch_devices = [int(gpu_id.strip()) for gpu_id in opts.gpu_ids.split(",")] 128 | print(torch_devices) 129 | device = "cuda:" + str(torch_devices[0]) 130 | 131 | if "sync" in opts.norm_G: 132 | model = convert_model(model) 133 | model = nn.DataParallel(model, torch_devices).to(device) 134 | else: 135 | model = nn.DataParallel(model, torch_devices).to(device) 136 | 137 | # Load the original model to be tested 138 | model_to_test = BaseModel(model, opts) 139 | model_to_test.eval() 140 | 141 | # Allow for different image sizes 142 | state_dict = model_to_test.state_dict() 143 | pretrained_dict = { 144 | k: v 145 | for k, v in torch.load(MODEL_PATH)["state_dict"].items() 146 | if not ("xyzs" in k) and not ("ones" in k) 147 | } 148 | state_dict.update(pretrained_dict) 149 | 150 | model_to_test.load_state_dict(state_dict) 151 | 152 | print(opts) 153 | # Update parameters 154 | opts.render_ids = test_ops.render_ids 155 | opts.gpu_ids = test_ops.gpu_ids 156 | 157 | print("Loaded models...") 158 | 159 | # Load the dataset which is the set of images that came 160 | # from running the baselines' result scripts 161 | data = Dataset(W=opts.W) 162 | 163 | model_to_test.eval() 164 | 165 | # Iterate through the dataset, predicting new views 166 | data_loader = DataLoader(data, batch_size=1, shuffle=False) 167 | iter_data_loader = iter(data_loader) 168 | 169 | for i in range(0, len(data_loader)): 170 | print(i, len(data_loader), flush=True) 171 | _, pred_imgs, batch = model_to_test( 172 | iter_data_loader, isval=True, return_batch=True 173 | ) 174 | 175 | if not os.path.exists( 176 | test_ops.result_folder 177 | + "/%04d/" % (i) 178 | ): 179 | os.makedirs( 180 | test_ops.result_folder 181 | + "/%04d/" % (i) 182 | ) 183 | 184 | torchvision.utils.save_image( 185 | pred_imgs["PredImg"], 186 | test_ops.result_folder 187 | + "/%04d/output_image_.png" % (i), 188 | ) 189 | torchvision.utils.save_image( 190 | pred_imgs["OutputImg"], 191 | test_ops.result_folder 192 | + "/%04d/tgt_image_.png" % (i), 193 | ) 194 | torchvision.utils.save_image( 195 | pred_imgs["InputImg"], 196 | test_ops.result_folder 197 | + "/%04d/input_image_.png" % (i), 198 | ) 199 | 200 | print( 201 | pred_imgs["PredImg"].mean().item(), 202 | pred_imgs["PredImg"].std().item(), 203 | flush=True, 204 | ) 205 | -------------------------------------------------------------------------------- /evaluation/evaluate_perceptualsim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. 2 | 3 | import argparse 4 | import glob 5 | import os 6 | from tqdm import tqdm 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | from evaluation.metrics import perceptual_sim, psnr, ssim_metric 14 | from models.networks.pretrained_networks import PNet 15 | 16 | transform = transforms.Compose([transforms.ToTensor()]) 17 | 18 | 19 | def load_img(img_name, size=None): 20 | try: 21 | img = Image.open(img_name) 22 | 23 | if size: 24 | img = img.resize((size, size)) 25 | 26 | img = transform(img).cuda() 27 | img = img.unsqueeze(0) 28 | except Exception as e: 29 | print("Failed at loading %s " % img_name) 30 | print(e) 31 | img = torch.zeros(1, 3, 256, 256).cuda() 32 | return img 33 | 34 | 35 | def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other): 36 | 37 | # Load VGG16 for feature similarity 38 | vgg16 = PNet().to("cuda") 39 | vgg16.eval() 40 | vgg16.cuda() 41 | 42 | values_percsim = [] 43 | values_ssim = [] 44 | values_psnr = [] 45 | folders = os.listdir(folder) 46 | for i, f in tqdm(enumerate(sorted(folders))): 47 | pred_imgs = glob.glob(folder + f + "/" + pred_img) 48 | tgt_imgs = glob.glob(folder + f + "/" + tgt_img) 49 | assert len(tgt_imgs) == 1 50 | 51 | perc_sim = 10000 52 | ssim_sim = -10 53 | psnr_sim = -10 54 | for p_img in pred_imgs: 55 | t_img = load_img(tgt_imgs[0]) 56 | p_img = load_img(p_img, size=t_img.size(2)) 57 | t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() 58 | perc_sim = min(perc_sim, t_perc_sim) 59 | 60 | ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) 61 | psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) 62 | 63 | values_percsim += [perc_sim] 64 | values_ssim += [ssim_sim] 65 | values_psnr += [psnr_sim] 66 | 67 | if take_every_other: 68 | n_valuespercsim = [] 69 | n_valuesssim = [] 70 | n_valuespsnr = [] 71 | for i in range(0, len(values_percsim) // 2): 72 | n_valuespercsim += [ 73 | min(values_percsim[2 * i], values_percsim[2 * i + 1]) 74 | ] 75 | n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] 76 | n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] 77 | 78 | values_percsim = n_valuespercsim 79 | values_ssim = n_valuesssim 80 | values_psnr = n_valuespsnr 81 | 82 | avg_percsim = np.mean(np.array(values_percsim)) 83 | std_percsim = np.std(np.array(values_percsim)) 84 | 85 | avg_psnr = np.mean(np.array(values_psnr)) 86 | std_psnr = np.std(np.array(values_psnr)) 87 | 88 | avg_ssim = np.mean(np.array(values_ssim)) 89 | std_ssim = np.std(np.array(values_ssim)) 90 | 91 | return { 92 | "Perceptual similarity": (avg_percsim, std_percsim), 93 | "PSNR": (avg_psnr, std_psnr), 94 | "SSIM": (avg_ssim, std_ssim), 95 | } 96 | 97 | 98 | if __name__ == "__main__": 99 | args = argparse.ArgumentParser() 100 | args.add_argument("--folder", type=str, default="") 101 | args.add_argument("--pred_image", type=str, default="") 102 | args.add_argument("--target_image", type=str, default="") 103 | args.add_argument("--take_every_other", action="store_true", default=False) 104 | args.add_argument("--output_file", type=str, default="") 105 | 106 | opts = args.parse_args() 107 | 108 | folder = opts.folder 109 | pred_img = opts.pred_image 110 | tgt_img = opts.target_image 111 | 112 | results = compute_perceptual_similarity( 113 | folder, pred_img, tgt_img, opts.take_every_other 114 | ) 115 | 116 | f = open(opts.output_file, 'w') 117 | for key in results: 118 | print("%s for %s: \n" % (key, opts.folder)) 119 | print( 120 
| "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) 121 | ) 122 | 123 | f.write("%s for %s: \n" % (key, opts.folder)) 124 | f.write( 125 | "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) 126 | ) 127 | 128 | f.close() 129 | -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from models.losses.ssim import ssim 4 | 5 | # The SSIM metric 6 | def ssim_metric(img1, img2, mask=None): 7 | return ssim(img1, img2, mask=mask, size_average=False) 8 | 9 | 10 | # The PSNR metric 11 | def psnr(img1, img2, mask=None): 12 | b = img1.size(0) 13 | if not (mask is None): 14 | b = img1.size(0) 15 | mse_err = (img1 - img2).pow(2) * mask 16 | mse_err = mse_err.view(b, -1).sum(dim=1) / ( 17 | 3 * mask.view(b, -1).sum(dim=1).clamp(min=1) 18 | ) 19 | else: 20 | mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) 21 | 22 | psnr = 10 * (1 / mse_err).log10() 23 | return psnr 24 | 25 | 26 | # The perceptual similarity metric 27 | def perceptual_sim(img1, img2, vgg16): 28 | # First extract features 29 | dist = vgg16(img1 * 2 - 1, img2 * 2 - 1) 30 | 31 | return dist 32 | -------------------------------------------------------------------------------- /geometry/camera_transformations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def invert_RT(RT): 8 | """ Given an RT matrix (e.g. [R | T]) matrix where R is 9 | indeed valid, then inverts this. 10 | """ 11 | R = RT[:, 0:3, 0:3] 12 | T = RT[:, :, 3:] 13 | 14 | # Get inverse of the rotation matrix 15 | Rinv = R.permute(0, 2, 1) 16 | Tinv = -Rinv.bmm(T) 17 | 18 | RTinv = torch.cat((Rinv, Tinv), 2) 19 | 20 | return RTinv 21 | 22 | 23 | def invert_K(K): 24 | """ Given a K matrix (an intrinsic matrix) of the form 25 | [f 0 px] 26 | [0 f py] 27 | [0 0 1], inverts it. 28 | """ 29 | K_inv = ( 30 | torch.eye(K.size(1)).to(K.device).unsqueeze(0).repeat(K.size(0), 1, 1) 31 | ) 32 | 33 | K_inv[:, 0, 0] = 1 / K[:, 0, 0] 34 | K_inv[:, 0, 2] = -K[:, 0, 2] / K[:, 0, 0] 35 | K_inv[:, 1, 1] = 1 / K[:, 1, 1] 36 | K_inv[:, 1, 2] = -K[:, 1, 2] / K[:, 1, 1] 37 | 38 | return K_inv 39 | 40 | 41 | def get_camera_matrices(position, rotation): 42 | 43 | Pinv = np.eye(4) 44 | Pinv[0:3, 0:3] = rotation 45 | Pinv[0:3, 3] = position 46 | 47 | P = np.linalg.inv(Pinv) 48 | 49 | return P.astype(np.float32), Pinv.astype(np.float32) 50 | 51 | 52 | if __name__ == "__main__": 53 | # Test the inversion code 54 | # 1. 
Test the RT 55 | 56 | rotate = np.linalg.qr(np.random.randn(3, 3))[0] 57 | R = torch.Tensor(rotate) 58 | translation = np.random.randn(3, 1) 59 | T = torch.Tensor(translation) 60 | 61 | RT = torch.cat((R, T), 1).unsqueeze(0) 62 | 63 | RTinv = invert_RT(RT) 64 | print(RT[:, 0:3, 0:3].bmm(RTinv[:, 0:3, 0:3])) 65 | x = torch.randn(1, 4, 1) 66 | x[0, 3, 0] = 1 67 | xp = RT.bmm(x) 68 | xp = torch.cat((xp, torch.ones((1, 1, 1))), 1) 69 | print(RTinv.bmm(xp) - x[:, 0:3, :]) 70 | 71 | K = torch.eye(3).unsqueeze(0).repeat(2, 1, 1) 72 | K[0, 0, 0] = torch.randn(1) 73 | K[0, 1, 1] = torch.randn(1) 74 | K[0, 0, 2] = torch.randn(1) 75 | K[1, 1, 2] = torch.randn(1) 76 | K[1, 0, 0] = torch.randn(1) 77 | K[1, 1, 1] = torch.randn(1) 78 | K[1, 0, 2] = torch.randn(1) 79 | K[1, 1, 2] = torch.randn(1) 80 | 81 | Kinv = invert_K(K) 82 | print(Kinv.bmm(K)) 83 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | 7 | from models.losses.gan_loss import DiscriminatorLoss 8 | 9 | class BaseModel(nn.Module): 10 | def __init__(self, model, opt): 11 | super().__init__() 12 | self.model = model 13 | 14 | self.opt = opt 15 | 16 | if opt.discriminator_losses: 17 | self.use_discriminator = True 18 | 19 | self.netD = DiscriminatorLoss(opt) 20 | 21 | if opt.isTrain: 22 | self.optimizer_D = torch.optim.Adam( 23 | list(self.netD.parameters()), 24 | lr=opt.lr_d, 25 | betas=(opt.beta1, opt.beta2), 26 | ) 27 | self.optimizer_G = torch.optim.Adam( 28 | list(self.model.parameters()), 29 | lr=opt.lr_g, 30 | betas=(opt.beta1, opt.beta2), 31 | ) 32 | else: 33 | self.use_discriminator = False 34 | self.optimizer_G = torch.optim.Adam( 35 | list(self.model.parameters()), 36 | lr=opt.lr_g, 37 | betas=(0.99, opt.beta2), 38 | ) 39 | 40 | if opt.isTrain: 41 | self.old_lr = opt.lr 42 | 43 | if opt.init: 44 | self.init_weights() 45 | 46 | def init_weights(self, gain=0.02): 47 | def init_func(m): 48 | classname = m.__class__.__name__ 49 | if hasattr(m, "weight") and ( 50 | classname.find("Conv") != -1 or classname.find("Linear") != -1 51 | ): 52 | if self.opt.init == "normal": 53 | init.normal_(m.weight.data, 0.0, gain) 54 | elif self.opt.init == "xavier": 55 | init.xavier_normal_(m.weight.data, gain=gain) 56 | elif self.opt.init == "xavier_uniform": 57 | init.xavier_uniform_(m.weight.data, gain=1.0) 58 | elif self.opt.init == "kaiming": 59 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 60 | elif self.opt.init == "orthogonal": 61 | init.orthogonal_(m.weight.data) 62 | elif self.opt.init == "": # uses pytorch's default init method 63 | m.reset_parameters() 64 | else: 65 | raise NotImplementedError( 66 | "initialization method [%s] is not implemented" 67 | % self.opt.init 68 | ) 69 | if hasattr(m, "bias") and m.bias is not None: 70 | init.constant_(m.bias.data, 0.0) 71 | 72 | self.apply(init_func) 73 | 74 | # propagate to children 75 | for m in self.children(): 76 | if hasattr(m, "init_weights"): 77 | m.init_weights(self.opt.init, gain) 78 | 79 | def __call__( 80 | self, dataloader, isval=False, num_steps=1, return_batch=False 81 | ): 82 | """ 83 | Main function call 84 | - dataloader: The sampler that choose data samples. 85 | - isval: Whether to train the discriminator etc. 
86 | - num steps: not fully implemented but is number of steps in the discriminator for 87 | each in the generator 88 | - return_batch: Whether to return the input values 89 | """ 90 | weight = 1.0 / float(num_steps) 91 | if isval: 92 | batch = next(dataloader) 93 | t_losses, output_images = self.model(batch) 94 | 95 | if self.opt.normalize_image: 96 | for k in output_images.keys(): 97 | if "Img" in k: 98 | output_images[k] = 0.5 * output_images[k] + 0.5 99 | 100 | if return_batch: 101 | return t_losses, output_images, batch 102 | return t_losses, output_images 103 | 104 | self.optimizer_G.zero_grad() 105 | if self.use_discriminator: 106 | all_output_images = [] 107 | for j in range(0, num_steps): 108 | t_losses, output_images = self.model(next(dataloader)) 109 | g_losses = self.netD.run_generator_one_step( 110 | output_images["PredImg"], output_images["OutputImg"] 111 | ) 112 | ( 113 | g_losses["Total Loss"] / weight 114 | + t_losses["Total Loss"] / weight 115 | ).mean().backward() 116 | all_output_images += [output_images] 117 | self.optimizer_G.step() 118 | 119 | self.optimizer_D.zero_grad() 120 | for step in range(0, num_steps): 121 | d_losses = self.netD.run_discriminator_one_step( 122 | all_output_images[step]["PredImg"], 123 | all_output_images[step]["OutputImg"], 124 | ) 125 | (d_losses["Total Loss"] / weight).mean().backward() 126 | # Apply orthogonal regularization from BigGan 127 | self.optimizer_D.step() 128 | 129 | g_losses.pop("Total Loss") 130 | d_losses.pop("Total Loss") 131 | t_losses.update(g_losses) 132 | t_losses.update(d_losses) 133 | else: 134 | for step in range(0, num_steps): 135 | t_losses, output_images = self.model(next(dataloader)) 136 | (t_losses["Total Loss"] / weight).mean().backward() 137 | self.optimizer_G.step() 138 | 139 | if self.opt.normalize_image: 140 | for k in output_images.keys(): 141 | if "Img" in k: 142 | output_images[k] = 0.5 * output_images[k] + 0.5 143 | 144 | return t_losses, output_images 145 | -------------------------------------------------------------------------------- /models/depth_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from models.projection import depth_manipulator 8 | 9 | EPS = 1e-4 10 | 11 | 12 | class Model(nn.Module): 13 | """ 14 | Model for transforming by depth. Used only in evaluation, not training. 15 | As a result, use a brute force method for splatting points. This is 16 | improved in the z_buffermodel.py code to allow for differentiable rendering. 17 | """ 18 | def __init__(self, opt): 19 | super(Model, self).__init__() 20 | self.use_z = opt.use_z 21 | self.use_alpha = opt.use_alpha 22 | self.opt = opt 23 | 24 | # REFINER 25 | # Refine the projected depth 26 | 27 | num_inputs = 3 28 | 29 | opt.num_inputs = num_inputs 30 | 31 | # PROJECTION 32 | # Project according to the predicted depth 33 | self.projector = depth_manipulator.DepthManipulator(W=opt.W) 34 | 35 | 36 | def forward(self, batch): 37 | """ Forward pass of a view synthesis model with a predicted depth. 
38 | """ 39 | # Input values 40 | input_img = batch["images"][0] 41 | depth_img = batch["depths"][0] 42 | output_img = batch["images"][-1] 43 | 44 | # Camera parameters 45 | K = batch["cameras"][0]["K"] 46 | K_inv = batch["cameras"][0]["Kinv"] 47 | 48 | input_RTinv = batch["cameras"][0]["Pinv"] 49 | output_RT = batch["cameras"][-1]["P"] 50 | 51 | if torch.cuda.is_available(): 52 | input_img = input_img.cuda() 53 | depth_img = depth_img.cuda() 54 | output_img = output_img.cuda() 55 | 56 | K = K.cuda() 57 | K_inv = K_inv.cuda() 58 | 59 | input_RTinv = input_RTinv.cuda() 60 | output_RT = output_RT.cuda() 61 | 62 | # Transform the image according to intrinsic parameters 63 | # and rotation and depth 64 | sampled_image = self.transform_perfimage( 65 | input_img, output_img, depth_img, K, K_inv, input_RTinv, output_RT 66 | ) 67 | 68 | mask = (batch["depths"][1] < 10).float() * ( 69 | batch["depths"][1] > EPS 70 | ).float() 71 | 72 | return ( 73 | 0, 74 | { 75 | "InputImg": input_img, 76 | "OutputImg": output_img, 77 | "Mask": mask.float(), 78 | "SampledImg": sampled_image[:, 0:3, :, :], 79 | "Diff Sampled": ( 80 | sampled_image[:, 0:3, :, :] - output_img 81 | ).abs(), 82 | "Depth": depth_img, 83 | }, 84 | ) 85 | 86 | def transform_perfimage( 87 | self, input_img, output_img, depth_img, K, K_inv, RTinv_cam1, RT_cam2 88 | ): 89 | """ Create a new view of an input image. 90 | Transform according to the output rotation/translation. 91 | """ 92 | # Transform according to the depth projection 93 | sampler, _ = self.projector.project_zbuffer( 94 | depth=depth_img, 95 | K=K, 96 | K_inv=K_inv, 97 | RTinv_cam1=RTinv_cam1, 98 | RT_cam2=RT_cam2, 99 | ) 100 | 101 | # Sample image according to this sampler 102 | 103 | mask = ((sampler > -1).float() * (sampler < 1).float()).min( 104 | dim=1, keepdim=True 105 | )[0] 106 | mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=1) 107 | mask = (mask > 0).float() 108 | 109 | sampled_image = output_img * mask 110 | 111 | return sampled_image 112 | -------------------------------------------------------------------------------- /models/encoderdecoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from models.losses.synthesis import SynthesisLoss 8 | 9 | 10 | class CollapseLayer(nn.Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | class UnCollapseLayer(nn.Module): 16 | def __init__(self, C, W, H): 17 | super().__init__() 18 | self.C = C 19 | self.W = W 20 | self.H = H 21 | 22 | def forward(self, input): 23 | return input.view(input.size(0), self.C, self.W, self.H) 24 | 25 | 26 | class ViewAppearanceFlow(nn.Module): 27 | """ 28 | View Appearance Flow based on the corresponding paper. 
29 | """ 30 | 31 | def __init__(self, opt): 32 | super().__init__() 33 | 34 | self.encoder = nn.Sequential( 35 | nn.Conv2d(3, 16, 3, 2, padding=1), 36 | nn.ReLU(), 37 | nn.BatchNorm2d(16), # 128 38 | nn.Conv2d(16, 32, 3, 2, padding=1), 39 | nn.ReLU(), 40 | nn.BatchNorm2d(32), # 64 41 | nn.Conv2d(32, 64, 3, 2, padding=1), 42 | nn.ReLU(), 43 | nn.BatchNorm2d(64), # 32 44 | nn.Conv2d(64, 128, 3, 2, padding=1), 45 | nn.ReLU(), 46 | nn.BatchNorm2d(128), # 16 47 | nn.Conv2d(128, 256, 3, 2, padding=1), 48 | nn.ReLU(), 49 | nn.BatchNorm2d(256), # 8 50 | nn.Conv2d(256, 512, 3, 2, padding=1), 51 | nn.ReLU(), 52 | nn.BatchNorm2d(512), # 4 53 | CollapseLayer(), 54 | nn.Linear(8192, 4096), 55 | nn.ReLU(), 56 | nn.BatchNorm2d(4096), 57 | nn.Linear(4096, 4096), 58 | nn.ReLU(), 59 | nn.BatchNorm2d(4096), 60 | ) 61 | 62 | self.decoder = nn.Sequential( 63 | nn.Linear(4096 + 256, 4096), 64 | nn.ReLU(), 65 | nn.BatchNorm2d(4096), 66 | nn.Linear(4096, 4096), 67 | nn.ReLU(), 68 | nn.BatchNorm2d(4096), 69 | UnCollapseLayer(64, 8, 8), 70 | nn.Conv2d(64, 256, 3, 1, padding=1), 71 | nn.ReLU(), 72 | nn.BatchNorm2d(256), 73 | nn.Upsample(scale_factor=2), 74 | nn.Conv2d(256, 128, 3, 1, padding=1), 75 | nn.ReLU(), 76 | nn.BatchNorm2d(128), 77 | nn.Upsample(scale_factor=2), 78 | nn.Conv2d(128, 64, 3, 1, padding=1), 79 | nn.ReLU(), 80 | nn.BatchNorm2d(64), 81 | nn.Upsample(scale_factor=2), 82 | nn.Conv2d(64, 32, 3, 1, padding=1), 83 | nn.ReLU(), 84 | nn.BatchNorm2d(32), 85 | nn.Upsample(scale_factor=2), 86 | nn.Conv2d(32, 16, 3, 1, padding=1), 87 | nn.ReLU(), 88 | nn.BatchNorm2d(16), 89 | nn.Upsample(scale_factor=2), 90 | nn.Conv2d(16, 2, 3, 1, padding=1), 91 | nn.Tanh(), 92 | ) 93 | 94 | self.angle_transformer = nn.Sequential( 95 | nn.Linear(12, 128), 96 | nn.ReLU(), 97 | nn.BatchNorm1d(128), 98 | nn.Linear(128, 256), 99 | nn.ReLU(), 100 | nn.BatchNorm1d(256), 101 | ) 102 | 103 | self.loss_function = SynthesisLoss(opt=opt) 104 | 105 | self.opt = opt 106 | 107 | def forward(self, batch): 108 | input_img = batch["images"][0] 109 | output_img = batch["images"][-1] 110 | 111 | input_RTinv = batch["cameras"][0]["Pinv"] 112 | output_RT = batch["cameras"][-1]["P"] 113 | 114 | if torch.cuda.is_available(): 115 | input_img = input_img.cuda() 116 | output_img = output_img.cuda() 117 | 118 | input_RTinv = input_RTinv.cuda() 119 | output_RT = output_RT.cuda() 120 | 121 | RT = input_RTinv.bmm(output_RT)[:, 0:3, :] 122 | 123 | # Now transform the change in angle 124 | fs = self.encoder(input_img) 125 | fs_angle = self.angle_transformer(RT.view(RT.size(0), -1)) 126 | 127 | # And concatenate 128 | fs = torch.cat((fs, fs_angle), 1) 129 | sampler = self.decoder(fs) 130 | gen_img = F.grid_sample(input_img, sampler.permute(0, 2, 3, 1)) 131 | 132 | # And the loss 133 | loss = self.loss_function(gen_img, output_img) 134 | 135 | # And return 136 | return ( 137 | loss, 138 | { 139 | "InputImg": input_img, 140 | "OutputImg": output_img, 141 | "PredImg": gen_img, 142 | }, 143 | ) 144 | 145 | def forward_angle(self, batch, RTs, return_depth=False): 146 | # Input values 147 | input_img = batch["images"][0] 148 | 149 | # Camera parameters 150 | K = batch["cameras"][0]["K"] 151 | K_inv = batch["cameras"][0]["Kinv"] 152 | 153 | if torch.cuda.is_available(): 154 | input_img = input_img.cuda() 155 | 156 | K = K.cuda() 157 | K_inv = K_inv.cuda() 158 | 159 | RTs = [RT[:, 0:3, :].cuda() for RT in RTs] 160 | 161 | fs = self.encoder(input_img) 162 | # Now rotate 163 | gen_imgs = [] 164 | for i, RT in enumerate(RTs): 165 | torch.manual_seed( 166 | 0 167 | ) 
# Reset seed each time so that noise vectors are the same 168 | fs_angle = self.angle_transformer(RT.view(RT.size(0), -1)) 169 | 170 | # And concatenate 171 | fs_new = torch.cat((fs, fs_angle), 1) 172 | sampler = self.decoder(fs_new) 173 | gen_img = F.grid_sample(input_img, sampler.permute(0, 2, 3, 1)) 174 | 175 | gen_imgs += [gen_img] 176 | 177 | if return_depth: 178 | return gen_imgs, torch.zeros(fs.size(0), 1, 256, 256) 179 | 180 | return gen_imgs 181 | 182 | 183 | class Tatarchenko(nn.Module): 184 | def __init__(self, opt): 185 | super().__init__() 186 | 187 | self.encoder = nn.Sequential( 188 | nn.Conv2d(3, 16, 3, 2, padding=1), 189 | nn.LeakyReLU(0.2), 190 | nn.BatchNorm2d(16), # 128 191 | nn.Conv2d(16, 32, 3, 2, padding=1), 192 | nn.LeakyReLU(0.2), 193 | nn.BatchNorm2d(32), # 64 194 | nn.Conv2d(32, 64, 3, 2, padding=1), 195 | nn.LeakyReLU(0.2), 196 | nn.BatchNorm2d(64), # 32 197 | nn.Conv2d(64, 128, 3, 2, padding=1), 198 | nn.LeakyReLU(0.2), 199 | nn.BatchNorm2d(128), # 16 200 | nn.Conv2d(128, 256, 3, 2, padding=1), 201 | nn.LeakyReLU(0.2), 202 | nn.BatchNorm2d(256), # 8 203 | nn.Conv2d(256, 512, 3, 2, padding=1), 204 | nn.LeakyReLU(0.2), 205 | nn.BatchNorm2d(512), # 4 206 | CollapseLayer(), 207 | nn.Linear(8192, 4096), 208 | nn.LeakyReLU(0.2), 209 | nn.BatchNorm2d(4096), 210 | nn.Linear(4096, 4096), 211 | nn.LeakyReLU(0.2), 212 | nn.BatchNorm2d(4096), 213 | ) 214 | 215 | self.decoder = nn.Sequential( 216 | nn.Linear(4096 + 64, 4096), 217 | nn.LeakyReLU(0.2), 218 | nn.BatchNorm2d(4096), 219 | nn.Linear(4096, 4096), 220 | nn.LeakyReLU(0.2), 221 | nn.BatchNorm2d(4096), 222 | UnCollapseLayer(64, 8, 8), 223 | nn.Conv2d(64, 256, 3, 1, padding=1), 224 | nn.ReLU(), 225 | nn.BatchNorm2d(256), 226 | nn.Upsample(scale_factor=2), 227 | nn.Conv2d(256, 128, 3, 1, padding=1), 228 | nn.ReLU(), 229 | nn.BatchNorm2d(128), 230 | nn.Upsample(scale_factor=2), 231 | nn.Conv2d(128, 64, 3, 1, padding=1), 232 | nn.ReLU(), 233 | nn.BatchNorm2d(64), 234 | nn.Upsample(scale_factor=2), 235 | nn.Conv2d(64, 32, 3, 1, padding=1), 236 | nn.ReLU(), 237 | nn.BatchNorm2d(32), 238 | nn.Upsample(scale_factor=2), 239 | nn.Conv2d(32, 16, 3, 1, padding=1), 240 | nn.ReLU(), 241 | nn.BatchNorm2d(16), 242 | nn.Upsample(scale_factor=2), 243 | nn.Conv2d(16, 3, 3, 1, padding=1), 244 | nn.Tanh(), 245 | ) 246 | 247 | self.angle_transformer = nn.Sequential( 248 | nn.Linear(12, 64), 249 | nn.LeakyReLU(0.2), 250 | nn.BatchNorm1d(64), 251 | nn.Linear(64, 64), 252 | nn.LeakyReLU(0.2), 253 | nn.BatchNorm1d(64), 254 | ) 255 | 256 | self.loss_function = SynthesisLoss(opt=opt) 257 | 258 | self.opt = opt 259 | 260 | def forward(self, batch): 261 | input_img = batch["images"][0] 262 | output_img = batch["images"][-1] 263 | 264 | input_RTinv = batch["cameras"][0]["Pinv"] 265 | output_RT = batch["cameras"][-1]["P"] 266 | 267 | if torch.cuda.is_available(): 268 | input_img = input_img.cuda() 269 | output_img = output_img.cuda() 270 | 271 | input_RTinv = input_RTinv.cuda() 272 | output_RT = output_RT.cuda() 273 | 274 | RT = input_RTinv.bmm(output_RT)[:, 0:3, :] 275 | 276 | # Now transform the change in angle 277 | fs = self.encoder(input_img) 278 | fs_angle = self.angle_transformer(RT.view(RT.size(0), -1)) 279 | 280 | # And concatenate 281 | fs = torch.cat((fs, fs_angle), 1) 282 | gen_img = self.decoder(fs) 283 | 284 | loss = self.loss_function(gen_img, output_img) 285 | 286 | # And return 287 | return ( 288 | loss, 289 | { 290 | "InputImg": input_img, 291 | "OutputImg": output_img, 292 | "PredImg": gen_img, 293 | }, 294 | ) 295 | 
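# Illustrative sketch: the ViewAppearanceFlow decoder regresses a 2-channel sampler in
# [-1, 1] that is applied to the input image with F.grid_sample (Tatarchenko regresses
# the output image directly). The dummy tensors below only demonstrate the expected
# shapes; they are not used anywhere in training.
if __name__ == "__main__":
    dummy_img = torch.randn(2, 3, 256, 256)
    dummy_sampler = torch.tanh(torch.randn(2, 2, 256, 256))  # stands in for the decoder output
    warped = F.grid_sample(dummy_img, dummy_sampler.permute(0, 2, 3, 1))  # grid shape (N, H, W, 2)
    print(warped.shape)  # torch.Size([2, 3, 256, 256])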
-------------------------------------------------------------------------------- /models/layers/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch.nn as nn 4 | from models.layers.normalization import LinearNoiseLayer 5 | 6 | 7 | def spectral_conv_function(in_c, out_c, k, p, s): 8 | return nn.utils.spectral_norm( 9 | nn.Conv2d(in_c, out_c, kernel_size=k, padding=p, stride=s) 10 | ) 11 | 12 | 13 | def conv_function(in_c, out_c, k, p, s): 14 | return nn.Conv2d(in_c, out_c, kernel_size=k, padding=p, stride=s) 15 | 16 | 17 | def get_conv_layer(opt): 18 | if "spectral" in opt.norm_G: 19 | conv_layer_base = spectral_conv_function 20 | else: 21 | conv_layer_base = conv_function 22 | 23 | return conv_layer_base 24 | 25 | 26 | # Convenience passthrough function 27 | class Identity(nn.Module): 28 | def forward(self, input): 29 | return input 30 | 31 | 32 | # ResNet Blocks 33 | class ResNet_Block(nn.Module): 34 | def __init__(self, in_c, in_o, opt, downsample=None): 35 | super().__init__() 36 | bn_noise1 = LinearNoiseLayer(opt, output_sz=in_c) 37 | bn_noise2 = LinearNoiseLayer(opt, output_sz=in_o) 38 | 39 | conv_layer = get_conv_layer(opt) 40 | 41 | conv_aa = conv_layer(in_c, in_o, 3, 1, 1) 42 | conv_ab = conv_layer(in_o, in_o, 3, 1, 1) 43 | 44 | conv_b = conv_layer(in_c, in_o, 1, 0, 1) 45 | 46 | if downsample == "Down": 47 | norm_downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 48 | elif downsample == "Up": 49 | norm_downsample = nn.Upsample(scale_factor=2, mode="bilinear") 50 | elif downsample: 51 | norm_downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 52 | else: 53 | norm_downsample = Identity() 54 | 55 | self.ch_a = nn.Sequential( 56 | bn_noise1, 57 | nn.ReLU(), 58 | conv_aa, 59 | bn_noise2, 60 | nn.ReLU(), 61 | conv_ab, 62 | norm_downsample, 63 | ) 64 | 65 | if downsample or (in_c != in_o): 66 | self.ch_b = nn.Sequential(conv_b, norm_downsample) 67 | else: 68 | self.ch_b = Identity() 69 | 70 | def forward(self, x): 71 | x_a = self.ch_a(x) 72 | x_b = self.ch_b(x) 73 | 74 | return x_a + x_b 75 | -------------------------------------------------------------------------------- /models/layers/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.utils.spectral_norm as spectral_norm 6 | from torch.nn.parameter import Parameter 7 | 8 | def get_linear_layer(opt, bias=False): 9 | if "spectral" in opt.norm_G: 10 | linear_layer_base = lambda in_c, out_c: nn.utils.spectral_norm( 11 | nn.Linear(in_c, out_c, bias=bias) 12 | ) 13 | else: 14 | linear_layer_base = lambda in_c, out_c: nn.Linear( 15 | in_c, out_c, bias=bias 16 | ) 17 | 18 | return linear_layer_base 19 | 20 | 21 | class LinearNoiseLayer(nn.Module): 22 | def __init__(self, opt, noise_sz=20, output_sz=32): 23 | """ 24 | Class for adding in noise to the batch normalisation layer. 25 | Based on the idea from BigGAN. 
26 | """ 27 | super().__init__() 28 | self.noise_sz = noise_sz 29 | 30 | linear_layer = get_linear_layer(opt, bias=False) 31 | 32 | self.gain = linear_layer(noise_sz, output_sz) 33 | self.bias = linear_layer(noise_sz, output_sz) 34 | 35 | self.bn = bn(output_sz) 36 | 37 | self.noise_sz = noise_sz 38 | 39 | def forward(self, x): 40 | noise = torch.randn(x.size(0), self.noise_sz).to(x.device) 41 | 42 | # Predict biases/gains for this layer from the noise 43 | gain = (1 + self.gain(noise)).view(noise.size(0), -1, 1, 1) 44 | bias = self.bias(noise).view(noise.size(0), -1, 1, 1) 45 | 46 | xp = self.bn(x, gain=gain, bias=bias) 47 | return xp 48 | 49 | 50 | # Returns a function that creates a normalization function 51 | # that does not condition on semantic map 52 | def get_D_norm_layer(opt, norm_type="instance"): 53 | # helper function to get # output channels of the previous layer 54 | def get_out_channel(layer): 55 | if hasattr(layer, "out_channels"): 56 | return getattr(layer, "out_channels") 57 | return layer.weight.size(0) 58 | 59 | # this function will be returned 60 | def add_norm_layer(layer): 61 | nonlocal norm_type 62 | if norm_type.startswith("spectral"): 63 | layer = spectral_norm(layer) 64 | subnorm_type = norm_type[len("spectral") :] 65 | 66 | if subnorm_type == "none" or len(subnorm_type) == 0: 67 | return layer 68 | 69 | # remove bias in the previous layer, which is meaningless 70 | # since it has no effect after normalization 71 | if getattr(layer, "bias", None) is not None: 72 | delattr(layer, "bias") 73 | layer.register_parameter("bias", None) 74 | 75 | if subnorm_type == "batch": 76 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 77 | 78 | elif subnorm_type == "instance": 79 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 80 | else: 81 | raise ValueError( 82 | "normalization layer %s is not recognized" % subnorm_type 83 | ) 84 | 85 | return nn.Sequential(layer, norm_layer) 86 | 87 | return add_norm_layer 88 | 89 | 90 | # BatchNorm layers are taken from the BigGAN code base. 91 | # https://github.com/ajbrock/BigGAN-PyTorch/blob/a5557079924c3070b39e67f2eaea3a52c0fb72ab/layers.py 92 | # Distributed under the MIT licence. 
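# Note: fused_bn / manual_bn further below implement standard batch normalisation,
#     out = (x - mean) / sqrt(var + eps) * gain + bias,
# with gain and bias folded into a single scale-and-shift (scale = gain * rsqrt(var + eps),
# shift = mean * scale - bias), so the forward pass reduces to x * scale - shift.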
93 | 94 | # Normal, non-class-conditional BN 95 | class BatchNorm_StandingStats(nn.Module): 96 | def __init__(self, output_size, eps=1e-5, momentum=0.1): 97 | super().__init__() 98 | self.output_size = output_size 99 | # Prepare gain and bias layers 100 | self.register_parameter("gain", Parameter(torch.ones(output_size))) 101 | self.register_parameter("bias", Parameter(torch.zeros(output_size))) 102 | # epsilon to avoid dividing by 0 103 | self.eps = eps 104 | # Momentum 105 | self.momentum = momentum 106 | 107 | self.bn = bn(output_size, self.eps, self.momentum) 108 | 109 | def forward(self, x, y=None): 110 | gain = self.gain.view(1, -1, 1, 1) 111 | bias = self.bias.view(1, -1, 1, 1) 112 | return self.bn(x, gain=gain, bias=bias) 113 | 114 | class bn(nn.Module): 115 | def __init__(self, num_channels, eps=1e-5, momentum=0.1): 116 | super().__init__() 117 | 118 | # momentum for updating stats 119 | self.momentum = momentum 120 | self.eps = eps 121 | 122 | self.register_buffer("stored_mean", torch.zeros(num_channels)) 123 | self.register_buffer("stored_var", torch.ones(num_channels)) 124 | self.register_buffer("accumulation_counter", torch.zeros(1)) 125 | # Accumulate running means and vars 126 | self.accumulate_standing = False 127 | 128 | # reset standing stats 129 | def reset_stats(self): 130 | self.stored_mean[:] = 0 131 | self.stored_var[:] = 0 132 | self.accumulation_counter[:] = 0 133 | 134 | def forward(self, x, gain, bias): 135 | if self.training: 136 | out, mean, var = manual_bn( 137 | x, gain, bias, return_mean_var=True, eps=self.eps 138 | ) 139 | # If accumulating standing stats, increment them 140 | with torch.no_grad(): 141 | if self.accumulate_standing: 142 | self.stored_mean[:] = self.stored_mean + mean.data 143 | self.stored_var[:] = self.stored_var + var.data 144 | self.accumulation_counter += 1.0 145 | # If not accumulating standing stats, take running averages 146 | else: 147 | self.stored_mean[:] = ( 148 | self.stored_mean * (1 - self.momentum) 149 | + mean * self.momentum 150 | ) 151 | self.stored_var[:] = ( 152 | self.stored_var * (1 - self.momentum) + var * self.momentum 153 | ) 154 | return out 155 | # If not in training mode, use the stored statistics 156 | else: 157 | mean = self.stored_mean.view(1, -1, 1, 1) 158 | var = self.stored_var.view(1, -1, 1, 1) 159 | # If using standing stats, divide them by the accumulation counter 160 | if self.accumulate_standing: 161 | mean = mean / self.accumulation_counter 162 | var = var / self.accumulation_counter 163 | return fused_bn(x, mean, var, gain, bias, self.eps) 164 | 165 | 166 | # Fused batchnorm op 167 | def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): 168 | # Apply scale and shift--if gain and bias are provided, fuse them here 169 | # Prepare scale 170 | scale = torch.rsqrt(var + eps) 171 | # If a gain is provided, use it 172 | if gain is not None: 173 | scale = scale * gain 174 | # Prepare shift 175 | shift = mean * scale 176 | # If bias is provided, use it 177 | if bias is not None: 178 | shift = shift - bias 179 | return x * scale - shift 180 | 181 | 182 | # Manual BN 183 | # Calculate means and variances using mean-of-squares minus mean-squared 184 | def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): 185 | # Cast x to float32 if necessary 186 | float_x = x.float() 187 | # Calculate expected value of x (m) and expected value of x**2 (m2) 188 | # Mean of x 189 | m = torch.mean(float_x, [0, 2, 3], keepdim=True) 190 | # Mean of x squared 191 | m2 = torch.mean(float_x ** 2, [0, 2, 3], 
keepdim=True) 192 | # Calculate variance as mean of squared minus mean squared. 193 | var = m2 - m ** 2 194 | # Cast back to float 16 if necessary 195 | var = var.type(x.type()) 196 | m = m.type(x.type()) 197 | # Return mean and variance for updating stored mean/var if requested 198 | if return_mean_var: 199 | return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() 200 | else: 201 | return fused_bn(x, m, var, gain, bias, eps) 202 | -------------------------------------------------------------------------------- /models/layers/z_buffer_layers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from pytorch3d.structures import Pointclouds 7 | from pytorch3d.renderer import compositing 8 | from pytorch3d.renderer.points import rasterize_points 9 | 10 | torch.manual_seed(42) 11 | 12 | class RasterizePointsXYsBlending(nn.Module): 13 | """ 14 | Rasterizes a set of points using a differentiable renderer. Points are 15 | accumulated in a z-buffer using an accumulation function 16 | defined in opts.accumulation and are normalised with a value M=opts.M. 17 | Inputs: 18 | - pts3D: the 3D points to be projected 19 | - src: the corresponding features 20 | - C: size of feature 21 | - learn_feature: whether to learn the default feature filled in when 22 | none project 23 | - radius: where pixels project to (in pixels) 24 | - size: size of the image being created 25 | - points_per_pixel: number of values stored in z-buffer per pixel 26 | - opts: additional options 27 | 28 | Outputs: 29 | - transformed_src_alphas: features projected and accumulated 30 | in the new view 31 | """ 32 | 33 | def __init__( 34 | self, 35 | C=64, 36 | learn_feature=True, 37 | radius=1.5, 38 | size=256, 39 | points_per_pixel=8, 40 | opts=None, 41 | ): 42 | super().__init__() 43 | if learn_feature: 44 | default_feature = nn.Parameter(torch.randn(1, C, 1)) 45 | self.register_parameter("default_feature", default_feature) 46 | else: 47 | default_feature = torch.zeros(1, C, 1) 48 | self.register_buffer("default_feature", default_feature) 49 | 50 | self.radius = radius 51 | self.size = size 52 | self.points_per_pixel = points_per_pixel 53 | self.opts = opts 54 | 55 | def forward(self, pts3D, src): 56 | bs = src.size(0) 57 | if len(src.size()) > 3: 58 | bs, c, w, _ = src.size() 59 | image_size = w 60 | 61 | pts3D = pts3D.permute(0, 2, 1) 62 | src = src.unsqueeze(2).repeat(1, 1, w, 1, 1).view(bs, c, -1) 63 | else: 64 | bs = src.size(0) 65 | image_size = self.size 66 | 67 | # Make sure these have been arranged in the same way 68 | assert pts3D.size(2) == 3 69 | assert pts3D.size(1) == src.size(2) 70 | 71 | pts3D[:,:,1] = - pts3D[:,:,1] 72 | pts3D[:,:,0] = - pts3D[:,:,0] 73 | 74 | # Add on the default feature to the end of the src 75 | # src = torch.cat((src, self.default_feature.repeat(bs, 1, 1)), 2) 76 | 77 | radius = float(self.radius) / float(image_size) * 2.0 78 | 79 | pts3D = Pointclouds(points=pts3D, features=src.permute(0,2,1)) 80 | points_idx, _, dist = rasterize_points( 81 | pts3D, image_size, radius, self.points_per_pixel 82 | ) 83 | 84 | if os.environ["DEBUG"]: 85 | print("Max dist: ", dist.max(), pow(radius, self.opts.rad_pow)) 86 | 87 | dist = dist / pow(radius, self.opts.rad_pow) 88 | 89 | if os.environ["DEBUG"]: 90 | print("Max dist: ", dist.max()) 91 | 92 | alphas = ( 93 | (1 - dist.clamp(max=1, min=1e-3).pow(0.5)) 94 | .pow(self.opts.tau) 95 | .permute(0, 3, 1, 2) 96 | ) 97 | 98 | if self.opts.accumulation 
== 'alphacomposite': 99 | transformed_src_alphas = compositing.alpha_composite( 100 | points_idx.permute(0, 3, 1, 2).long(), 101 | alphas, 102 | pts3D.features_packed().permute(1,0), 103 | ) 104 | elif self.opts.accumulation == 'wsum': 105 | transformed_src_alphas = compositing.weighted_sum( 106 | points_idx.permute(0, 3, 1, 2).long(), 107 | alphas, 108 | pts3D.features_packed().permute(1,0), 109 | ) 110 | elif self.opts.accumulation == 'wsumnorm': 111 | transformed_src_alphas = compositing.weighted_sum_norm( 112 | points_idx.permute(0, 3, 1, 2).long(), 113 | alphas, 114 | pts3D.features_packed().permute(1,0), 115 | ) 116 | 117 | return transformed_src_alphas 118 | -------------------------------------------------------------------------------- /models/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | Based on https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import models.networks.discriminators as discriminators 14 | 15 | 16 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 17 | # When LSGAN is used, it is basically same as MSELoss, 18 | # but it abstracts away the need to create the target label tensor 19 | # that has the same size as the input 20 | class GANLoss(nn.Module): 21 | def __init__( 22 | self, 23 | gan_mode, 24 | target_real_label=1.0, 25 | target_fake_label=0.0, 26 | tensor=torch.FloatTensor, 27 | opt=None, 28 | ): 29 | super(GANLoss, self).__init__() 30 | self.real_label = target_real_label 31 | self.fake_label = target_fake_label 32 | self.real_label_tensor = None 33 | self.fake_label_tensor = None 34 | self.zero_tensor = None 35 | self.Tensor = tensor 36 | self.gan_mode = gan_mode 37 | self.opt = opt 38 | if gan_mode == "ls": 39 | pass 40 | elif gan_mode == "original": 41 | pass 42 | elif gan_mode == "w": 43 | pass 44 | elif gan_mode == "hinge": 45 | pass 46 | else: 47 | raise ValueError("Unexpected gan_mode {}".format(gan_mode)) 48 | 49 | def get_target_tensor(self, input, target_is_real): 50 | if target_is_real: 51 | if self.real_label_tensor is None: 52 | self.real_label_tensor = ( 53 | self.Tensor(1).fill_(self.real_label).to(input.device) 54 | ) 55 | self.real_label_tensor.requires_grad_(False) 56 | return self.real_label_tensor.expand_as(input) 57 | else: 58 | if self.fake_label_tensor is None: 59 | self.fake_label_tensor = ( 60 | self.Tensor(1).fill_(self.fake_label).to(input.device) 61 | ) 62 | self.fake_label_tensor.requires_grad_(False) 63 | return self.fake_label_tensor.expand_as(input) 64 | 65 | def get_zero_tensor(self, input): 66 | if self.zero_tensor is None: 67 | self.zero_tensor = self.Tensor(1).fill_(0) 68 | self.zero_tensor.requires_grad_(False) 69 | 70 | self.zero_tensor = self.zero_tensor.to(input.device) 71 | return self.zero_tensor.expand_as(input) 72 | 73 | def loss(self, input, target_is_real, for_discriminator=True): 74 | if self.gan_mode == "original": # cross entropy loss 75 | target_tensor = self.get_target_tensor(input, target_is_real) 76 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 77 | return loss 78 | elif self.gan_mode == "ls": 79 | target_tensor = self.get_target_tensor(input, target_is_real) 80 | return F.mse_loss(input, target_tensor) 81 
| elif self.gan_mode == "hinge": 82 | if for_discriminator: 83 | if target_is_real: 84 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 85 | loss = -torch.mean(minval) 86 | else: 87 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 88 | loss = -torch.mean(minval) 89 | else: 90 | assert ( 91 | target_is_real 92 | ), "The generator's hinge loss must be aiming for real" 93 | loss = -torch.mean(input) 94 | return loss 95 | else: 96 | # wgan 97 | if target_is_real: 98 | return -input.mean() 99 | else: 100 | return input.mean() 101 | 102 | def __call__(self, input, target_is_real, for_discriminator=True): 103 | # computing loss is a bit complicated because |input| may not be 104 | # a tensor, but list of tensors in case of multiscale discriminator 105 | if isinstance(input, list): 106 | loss = 0 107 | for pred_i in input: 108 | if isinstance(pred_i, list): 109 | pred_i = pred_i[-1] 110 | loss_tensor = self.loss( 111 | pred_i, target_is_real, for_discriminator 112 | ) 113 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 114 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 115 | loss += new_loss 116 | return loss / len(input) 117 | else: 118 | return self.loss(input, target_is_real, for_discriminator) 119 | 120 | 121 | class BaseDiscriminator(nn.Module): 122 | def __init__(self, opt, name): 123 | super().__init__() 124 | 125 | if name == "pix2pixHD": 126 | self.netD = discriminators.define_D(opt) 127 | self.criterionGAN = GANLoss( 128 | opt.gan_mode, tensor=torch.FloatTensor, opt=opt 129 | ) 130 | self.criterionFeat = torch.nn.L1Loss() 131 | self.opt = opt 132 | 133 | self.FloatTensor = ( 134 | torch.cuda.FloatTensor 135 | if torch.cuda.is_available() 136 | else torch.FloatTensor 137 | ) 138 | 139 | # Given fake and real image, return the prediction of discriminator 140 | # for each fake and real image. 141 | 142 | def discriminate(self, fake_image, real_image): 143 | 144 | # In Batch Normalization, the fake and real images are 145 | # recommended to be in the same batch to avoid disparate 146 | # statistics in fake and real images. 147 | # So both fake and real images are fed to D all at once. 
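# Because the fake images come first in the concatenation, divide_pred() below
# recovers the two predictions by splitting every output tensor at size(0) // 2.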
148 | fake_and_real = torch.cat([fake_image, real_image], dim=0) 149 | 150 | discriminator_out = self.netD(fake_and_real) 151 | 152 | pred_fake, pred_real = self.divide_pred(discriminator_out) 153 | 154 | return pred_fake, pred_real 155 | 156 | # Take the prediction of fake and real images from the combined batch 157 | def divide_pred(self, pred): 158 | # the prediction contains the intermediate outputs of multiscale GAN, 159 | # so it's usually a list 160 | if type(pred) == list: 161 | fake = [] 162 | real = [] 163 | for p in pred: 164 | fake.append([tensor[: tensor.size(0) // 2] for tensor in p]) 165 | real.append([tensor[tensor.size(0) // 2 :] for tensor in p]) 166 | else: 167 | fake = pred[: pred.size(0) // 2] 168 | real = pred[pred.size(0) // 2 :] 169 | 170 | return fake, real 171 | 172 | def compute_discrimator_loss(self, fake_image, real_image): 173 | D_losses = {} 174 | with torch.no_grad(): 175 | fake_image = fake_image.detach() 176 | fake_image.requires_grad_() 177 | 178 | pred_fake, pred_real = self.discriminate(fake_image, real_image) 179 | 180 | D_losses["D_Fake"] = self.criterionGAN( 181 | pred_fake, False, for_discriminator=True 182 | ) 183 | D_losses["D_real"] = self.criterionGAN( 184 | pred_real, True, for_discriminator=True 185 | ) 186 | 187 | D_losses["Total Loss"] = sum(D_losses.values()).mean() 188 | 189 | return D_losses 190 | 191 | def compute_generator_loss(self, fake_image, real_image): 192 | G_losses = {} 193 | pred_fake, pred_real = self.discriminate(fake_image, real_image) 194 | 195 | G_losses["GAN"] = self.criterionGAN( 196 | pred_fake, True, for_discriminator=False 197 | ) 198 | 199 | if not self.opt.no_ganFeat_loss: 200 | num_D = len(pred_fake) 201 | GAN_Feat_loss = self.FloatTensor(1).fill_(0) 202 | for i in range(num_D): # for each discriminator 203 | # last output is the final prediction, so we exclude it 204 | num_intermediate_outputs = len(pred_fake[i]) - 1 205 | for j in range( 206 | num_intermediate_outputs 207 | ): # for each layer output 208 | unweighted_loss = self.criterionFeat( 209 | pred_fake[i][j], pred_real[i][j].detach() 210 | ) 211 | GAN_Feat_loss += ( 212 | unweighted_loss * self.opt.lambda_feat / num_D 213 | ) 214 | G_losses["GAN_Feat"] = GAN_Feat_loss 215 | 216 | G_losses["Total Loss"] = sum(G_losses.values()).mean() 217 | 218 | return G_losses, fake_image 219 | 220 | def forward(self, fake_image, real_image, mode="generator"): 221 | if mode == "generator": 222 | g_loss, generated = self.compute_generator_loss( 223 | fake_image, real_image 224 | ) 225 | return g_loss 226 | 227 | elif mode == "discriminator": 228 | d_loss = self.compute_discrimator_loss(fake_image, real_image) 229 | return d_loss 230 | 231 | def update_learning_rate(self, curr_epoch): 232 | restart, new_lrs = self.netD.update_learning_rate(curr_epoch) 233 | 234 | return restart, new_lrs 235 | 236 | 237 | class DiscriminatorLoss(nn.Module): 238 | def __init__(self, opt): 239 | super().__init__() 240 | self.opt = opt 241 | 242 | # Get the losses 243 | loss_name = opt.discriminator_losses 244 | 245 | self.netD = self.get_loss_from_name(loss_name) 246 | 247 | def get_optimizer(self): 248 | optimizerD = torch.optim.Adam( 249 | list(self.netD.parameters()), lr=self.opt.lr * 2, betas=(0, 0.9) 250 | ) 251 | return optimizerD 252 | 253 | def get_loss_from_name(self, name): 254 | netD = BaseDiscriminator(self.opt, name=name) 255 | 256 | if torch.cuda.is_available(): 257 | return netD.cuda() 258 | 259 | return netD 260 | 261 | def forward(self, pred_img, gt_img): 262 | losses = [ 263 
| loss(pred_img, gt_img, mode="discriminator") for loss in self.losses 264 | ] 265 | 266 | loss_dir = {} 267 | for i, l in enumerate(losses): 268 | if "Total Loss" in l.keys(): 269 | if "Total Loss" in loss_dir.keys(): 270 | loss_dir["Total Loss"] = ( 271 | loss_dir["Total Loss"] 272 | + l["Total Loss"] * self.lambdas[i] 273 | ) 274 | else: 275 | loss_dir["Total Loss"] = l["Total Loss"] 276 | 277 | loss_dir = dict(l, **loss_dir) # Have loss_dir override l 278 | 279 | return loss_dir 280 | 281 | def run_generator_one_step(self, pred_img, gt_img): 282 | return self.netD(pred_img, gt_img, mode="generator") 283 | 284 | def run_discriminator_one_step(self, pred_img, gt_img): 285 | return self.netD(pred_img, gt_img, mode="discriminator") 286 | 287 | def update_learning_rate(self, curr_epoch): 288 | restart, new_lrs = self.netD.update_learning_rate(curr_epoch) 289 | 290 | return restart, new_lrs 291 | -------------------------------------------------------------------------------- /models/losses/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, 
_, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /models/losses/synthesis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.losses.ssim import ssim 7 | from models.networks.architectures import VGG19 8 | 9 | 10 | class SynthesisLoss(nn.Module): 11 | def __init__(self, opt): 12 | super().__init__() 13 | self.opt = opt 14 | 15 | # Get the losses 16 | print(opt.losses) 17 | print(zip(*[l.split("_") for l in opt.losses])) 18 | lambdas, loss_names = zip(*[l.split("_") for l in opt.losses]) 19 | lambdas = [float(l) for l in lambdas] 20 | 21 | loss_names += ("PSNR", "SSIM") 22 | 23 | self.lambdas = lambdas 24 | self.losses = nn.ModuleList( 25 | [self.get_loss_from_name(loss_name) for loss_name in loss_names] 26 | ) 27 | 28 | def get_loss_from_name(self, name): 29 | if name == "l1": 30 | loss = L1LossWrapper() 31 | elif name == "content": 32 | loss = PerceptualLoss(self.opt) 33 | elif name == "PSNR": 34 | loss = PSNR() 35 | elif name == "SSIM": 36 | loss = SSIM() 37 | 38 | if torch.cuda.is_available(): 39 | return loss.cuda() 40 | 41 | def forward(self, pred_img, gt_img): 42 | losses = [loss(pred_img, gt_img) for loss in self.losses] 43 | 44 | loss_dir = {} 45 | for i, l in enumerate(losses): 46 | if "Total Loss" in l.keys(): 47 | if "Total Loss" in loss_dir.keys(): 48 | loss_dir["Total Loss"] = ( 49 | loss_dir["Total Loss"] 50 | + l["Total Loss"] * self.lambdas[i] 51 | ) 52 | else: 53 | loss_dir["Total Loss"] = l["Total Loss"] 54 | 55 | loss_dir = dict(l, **loss_dir) # Have loss_dir override l 56 | 57 | return loss_dir 58 | 59 | 60 | class PSNR(nn.Module): 61 | def forward(self, pred_img, gt_img): 62 | bs = pred_img.size(0) 63 | mse_err = (pred_img - gt_img).pow(2).sum(dim=1).view(bs, -1).mean(dim=1) 64 | 65 | psnr = 10 * (1 / mse_err).log10() 66 | return {"psnr": psnr.mean()} 67 | 68 | 69 | class SSIM(nn.Module): 70 | def forward(self, pred_img, gt_img): 71 | return {"ssim": ssim(pred_img, gt_img)} 72 | 73 | 74 | # Wrapper of the L1Loss so that the format matches what is expected 75 | class L1LossWrapper(nn.Module): 76 | def forward(self, pred_img, gt_img): 77 | err = nn.L1Loss()(pred_img, gt_img) 78 | return {"L1": err, "Total Loss": err} 79 | 80 | 81 | # Implementation of the perceptual loss to enforce that a 82 | # generated image matches the given image. 
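# Features are taken from a frozen VGG19 and compared with an L1 loss at multiple
# layers, with deeper layers weighted more heavily (see self.weights below).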
83 | # Adapted from SPADE's implementation 84 | # (https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py) 85 | class PerceptualLoss(nn.Module): 86 | def __init__(self, opt): 87 | super().__init__() 88 | self.model = VGG19( 89 | requires_grad=False 90 | ) # Set to false so that this part of the network is frozen 91 | self.criterion = nn.L1Loss() 92 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 93 | 94 | def forward(self, pred_img, gt_img): 95 | gt_fs = self.model(gt_img) 96 | pred_fs = self.model(pred_img) 97 | 98 | # Collect the losses at multiple layers (need unsqueeze in 99 | # order to concatenate these together) 100 | loss = 0 101 | for i in range(0, len(gt_fs)): 102 | loss += self.weights[i] * self.criterion(pred_fs[i], gt_fs[i]) 103 | 104 | return {"Perceptual": loss, "Total Loss": loss} 105 | -------------------------------------------------------------------------------- /models/networks/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | def get_resnet_arch(model_type, opt, in_channels=3): 4 | setup = model_type.split("_")[1] 5 | 6 | if setup == "256W8UpDown": 7 | arch = { 8 | "layers_enc": [ 9 | in_channels, 10 | opt.ngf // 2, 11 | opt.ngf // 2, 12 | opt.ngf // 2, 13 | opt.ngf, 14 | opt.ngf, 15 | opt.ngf, 16 | opt.ngf, 17 | 64, 18 | ], 19 | "downsample": [ 20 | False, 21 | False, 22 | False, 23 | False, 24 | False, 25 | False, 26 | False, 27 | False, 28 | ], 29 | "layers_dec": [ 30 | 128, 31 | opt.ngf, 32 | opt.ngf * 2, 33 | opt.ngf * 4, 34 | opt.ngf * 4, 35 | opt.ngf * 2, 36 | opt.ngf * 2, 37 | opt.ngf * 2, 38 | 3, 39 | ], 40 | "upsample": [ 41 | False, 42 | "Down", 43 | "Down", 44 | False, 45 | "Up", 46 | "Up", 47 | False, 48 | False, 49 | ], 50 | "non_local": False, 51 | "non_local_index": 1, 52 | } 53 | 54 | elif setup == "256W8UpDown64": 55 | arch = { 56 | "layers_enc": [ 57 | in_channels, 58 | opt.ngf // 2, 59 | opt.ngf // 2, 60 | opt.ngf // 2, 61 | opt.ngf, 62 | opt.ngf, 63 | opt.ngf, 64 | opt.ngf, 65 | 64, 66 | ], 67 | "downsample": [ 68 | False, 69 | False, 70 | False, 71 | False, 72 | False, 73 | False, 74 | False, 75 | False, 76 | ], 77 | "layers_dec": [ 78 | 64, 79 | opt.ngf, 80 | opt.ngf * 2, 81 | opt.ngf * 4, 82 | opt.ngf * 4, 83 | opt.ngf * 2, 84 | opt.ngf * 2, 85 | opt.ngf * 2, 86 | 3, 87 | ], 88 | "upsample": [ 89 | False, 90 | "Down", 91 | "Down", 92 | False, 93 | "Up", 94 | "Up", 95 | False, 96 | False, 97 | ], 98 | "non_local": False, 99 | "non_local_index": 1, 100 | } 101 | 102 | elif setup == "256W8UpDownDV": 103 | arch = { 104 | "layers_enc": [ 105 | in_channels, 106 | opt.ngf // 2, 107 | opt.ngf // 2, 108 | opt.ngf // 2, 109 | opt.ngf, 110 | opt.ngf, 111 | opt.ngf, 112 | opt.ngf, 113 | 64, 114 | ], 115 | "downsample": [ 116 | False, 117 | False, 118 | False, 119 | False, 120 | False, 121 | False, 122 | False, 123 | False, 124 | ], 125 | "layers_dec": [ 126 | 64, 127 | opt.ngf, 128 | opt.ngf * 2, 129 | opt.ngf * 4, 130 | opt.ngf * 4, 131 | opt.ngf * 2, 132 | opt.ngf * 2, 133 | opt.ngf * 2, 134 | 3, 135 | ], 136 | "upsample": [ 137 | False, 138 | "Down", 139 | "Down", 140 | False, 141 | "Up", 142 | "Up", 143 | False, 144 | False, 145 | ], 146 | "non_local": False, 147 | "non_local_index": 1, 148 | } 149 | 150 | elif setup == "256W8UpDownRGB": 151 | arch = { 152 | "layers_enc": [ 153 | in_channels, 154 | opt.ngf // 2, 155 | opt.ngf // 2, 156 | opt.ngf // 2, 157 | opt.ngf, 158 | opt.ngf, 159 | opt.ngf, 160 | 
opt.ngf, 161 | 64, 162 | ], 163 | "downsample": [ 164 | False, 165 | False, 166 | False, 167 | False, 168 | False, 169 | False, 170 | False, 171 | False, 172 | ], 173 | "layers_dec": [ 174 | 6, 175 | opt.ngf, 176 | opt.ngf * 2, 177 | opt.ngf * 4, 178 | opt.ngf * 4, 179 | opt.ngf * 2, 180 | opt.ngf * 2, 181 | opt.ngf * 2, 182 | 3, 183 | ], 184 | "upsample": [ 185 | False, 186 | "Down", 187 | "Down", 188 | False, 189 | "Up", 190 | "Up", 191 | False, 192 | False, 193 | ], 194 | "non_local": False, 195 | "non_local_index": 1, 196 | } 197 | 198 | elif setup == "256W8UpDown3": 199 | arch = { 200 | "layers_enc": [ 201 | in_channels, 202 | opt.ngf // 2, 203 | opt.ngf // 2, 204 | opt.ngf // 2, 205 | opt.ngf, 206 | opt.ngf, 207 | opt.ngf, 208 | opt.ngf, 209 | 64, 210 | ], 211 | "downsample": [ 212 | False, 213 | False, 214 | False, 215 | False, 216 | False, 217 | False, 218 | False, 219 | False, 220 | ], 221 | "layers_dec": [ 222 | 3, 223 | opt.ngf, 224 | opt.ngf * 2, 225 | opt.ngf * 4, 226 | opt.ngf * 4, 227 | opt.ngf * 2, 228 | opt.ngf * 2, 229 | opt.ngf * 2, 230 | 3, 231 | ], 232 | "upsample": [ 233 | False, 234 | "Down", 235 | "Down", 236 | False, 237 | "Up", 238 | "Up", 239 | False, 240 | False, 241 | ], 242 | "non_local": False, 243 | "non_local_index": 1, 244 | } 245 | 246 | elif setup == "256W8": 247 | arch = { 248 | "layers_enc": [ 249 | in_channels, 250 | opt.ngf, 251 | opt.ngf, 252 | opt.ngf * 2, 253 | opt.ngf * 2, 254 | opt.ngf * 2, 255 | opt.ngf * 4, 256 | opt.ngf * 4, 257 | 64, 258 | ], 259 | "downsample": [ 260 | True, 261 | False, 262 | False, 263 | False, 264 | True, 265 | False, 266 | False, 267 | False, 268 | ], 269 | "layers_dec": [ 270 | 64, 271 | opt.ngf, 272 | opt.ngf, 273 | opt.ngf * 2, 274 | opt.ngf * 2, 275 | opt.ngf * 2, 276 | opt.ngf * 4, 277 | opt.ngf * 4, 278 | 3, 279 | ], 280 | "upsample": [False, False, True, False, False, False, True, False], 281 | "non_local": False, 282 | "non_local_index": 1, 283 | } 284 | 285 | return arch 286 | -------------------------------------------------------------------------------- /models/networks/discriminators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | Based on https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | 14 | from models.layers.normalization import get_D_norm_layer 15 | 16 | 17 | class BaseNetwork(nn.Module): 18 | def __init__(self): 19 | super(BaseNetwork, self).__init__() 20 | 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | return parser 24 | 25 | def print_network(self): 26 | if isinstance(self, list): 27 | self = self[0] 28 | num_params = 0 29 | for param in self.parameters(): 30 | num_params += param.numel() 31 | print( 32 | "Network [%s] was created." 33 | + "Total number of parameters: %.1f million. " 34 | "To see the architecture, do print(network)." 
35 | % (type(self).__name__, num_params / 1000000) 36 | ) 37 | 38 | def init_weights(self, init_type="normal", gain=0.02): 39 | def init_func(m): 40 | classname = m.__class__.__name__ 41 | if classname.find("BatchNorm2d") != -1: 42 | if hasattr(m, "weight") and m.weight is not None: 43 | init.normal_(m.weight.data, 1.0, gain) 44 | if hasattr(m, "bias") and m.bias is not None: 45 | init.constant_(m.bias.data, 0.0) 46 | elif hasattr(m, "weight") and ( 47 | classname.find("Conv") != -1 or classname.find("Linear") != -1 48 | ): 49 | if init_type == "normal": 50 | init.normal_(m.weight.data, 0.0, gain) 51 | elif init_type == "xavier": 52 | init.xavier_normal_(m.weight.data, gain=gain) 53 | elif init_type == "xavier_uniform": 54 | init.xavier_uniform_(m.weight.data, gain=1.0) 55 | elif init_type == "kaiming": 56 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 57 | elif init_type == "orthogonal": 58 | init.orthogonal_(m.weight.data, gain=gain) 59 | elif init_type == "none": # uses pytorch's default init method 60 | m.reset_parameters() 61 | else: 62 | raise NotImplementedError( 63 | "initialization method [%s] is not implemented" 64 | % init_type 65 | ) 66 | if hasattr(m, "bias") and m.bias is not None: 67 | init.constant_(m.bias.data, 0.0) 68 | 69 | self.apply(init_func) 70 | 71 | # propagate to children 72 | for m in self.children(): 73 | if hasattr(m, "init_weights"): 74 | m.init_weights(init_type, gain) 75 | 76 | 77 | # Defines the PatchGAN discriminator with the specified arguments. 78 | class NLayerDiscriminator(BaseNetwork): 79 | def __init__(self, opt): 80 | super().__init__() 81 | opt.n_layers_D = 4 82 | self.opt = opt 83 | 84 | kw = 4 85 | padw = int(np.ceil((kw - 1.0) / 2)) 86 | nf = opt.ndf 87 | input_nc = self.compute_D_input_nc(opt) 88 | 89 | norm_layer = get_D_norm_layer(opt, opt.norm_D) 90 | sequence = [ 91 | [ 92 | nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 93 | nn.LeakyReLU(0.2, False), 94 | ] 95 | ] 96 | 97 | for n in range(1, opt.n_layers_D): 98 | nf_prev = nf 99 | nf = min(nf * 2, 512) 100 | stride = 1 if n == opt.n_layers_D - 1 else 2 101 | sequence += [ 102 | [ 103 | norm_layer( 104 | nn.Conv2d( 105 | nf_prev, 106 | nf, 107 | kernel_size=kw, 108 | stride=stride, 109 | padding=padw, 110 | ) 111 | ), 112 | nn.LeakyReLU(0.2, False), 113 | ] 114 | ] 115 | 116 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 117 | 118 | # We divide the layers into groups to extract intermediate layer outputs 119 | for n in range(len(sequence)): 120 | self.add_module("model" + str(n), nn.Sequential(*sequence[n])) 121 | 122 | def compute_D_input_nc(self, opt): 123 | # if opt.concat_discriminators: 124 | # input_nc = opt.output_nc * 2 125 | # else: 126 | input_nc = opt.output_nc 127 | return input_nc 128 | 129 | def forward(self, input): 130 | results = [input] 131 | for submodel in self.children(): 132 | intermediate_output = submodel(results[-1]) 133 | results.append(intermediate_output) 134 | 135 | get_intermediate_features = not self.opt.no_ganFeat_loss 136 | if get_intermediate_features: 137 | return results[1:] 138 | else: 139 | return results[-1] 140 | 141 | 142 | class MultiscaleDiscriminator(BaseNetwork): 143 | def __init__(self, opt): 144 | super().__init__() 145 | opt.netD_subarch = "n_layer" 146 | opt.num_D = 2 147 | self.opt = opt 148 | 149 | for i in range(opt.num_D): 150 | subnetD = self.create_single_discriminator(opt) 151 | self.add_module("discriminator_%d" % i, subnetD) 152 | if opt.isTrain: 153 | self.old_lr = opt.lr 154 
| 155 | def create_single_discriminator(self, opt): 156 | subarch = opt.netD_subarch 157 | if subarch == "n_layer": 158 | netD = NLayerDiscriminator(opt) 159 | 160 | if torch.cuda.is_available(): 161 | netD = netD.cuda() 162 | else: 163 | raise ValueError( 164 | "unrecognized discriminator subarchitecture %s" % subarch 165 | ) 166 | return netD 167 | 168 | def downsample(self, input): 169 | return F.avg_pool2d( 170 | input, 171 | kernel_size=3, 172 | stride=2, 173 | padding=[1, 1], 174 | count_include_pad=False, 175 | ) 176 | 177 | def update_learning_rate(self, epoch): 178 | if epoch > self.opt.niter: 179 | lrd = self.opt.lr / self.opt.niter_decay 180 | new_lr = self.old_lr - lrd 181 | else: 182 | new_lr = self.old_lr 183 | 184 | if new_lr != self.old_lr: 185 | print("update learning rate: %f -> %f" % (self.old_lr, new_lr)) 186 | self.old_lr = new_lr 187 | new_lr_G = new_lr / 2 188 | new_lr_D = new_lr * 2 189 | return False, {"lr_D": new_lr_D, "lr_G": new_lr_G} 190 | 191 | else: 192 | return False, {"lr_D": new_lr, "lr_G": new_lr} 193 | 194 | # Returns list of lists of discriminator outputs. 195 | # The final result is of size opt.num_D x opt.n_layers_D 196 | def forward(self, input): 197 | result = [] 198 | get_intermediate_features = not self.opt.no_ganFeat_loss 199 | 200 | for name, D in self.named_children(): 201 | out = D(input) 202 | if not get_intermediate_features: 203 | out = [out] 204 | result.append(out) 205 | input = self.downsample(input) 206 | 207 | return result 208 | 209 | 210 | def define_D(opt): 211 | net = MultiscaleDiscriminator(opt) 212 | net.init_weights("xavier", 0.02) 213 | if torch.cuda.is_available(): 214 | net = net.cuda() 215 | return net 216 | -------------------------------------------------------------------------------- /models/networks/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
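# Thin wrappers around torchvision backbones (VGG16, AlexNet, SqueezeNet, ResNet)
# that expose intermediate activations, used for perceptual-similarity evaluation.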
2 | # Heavily based on the code in https://github.com/richzhang/PerceptualSimilarity (BSD Licence) 3 | 4 | from collections import namedtuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import models 9 | 10 | 11 | def normalize_tensor(in_feat, eps=1e-10): 12 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view( 13 | in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3] 14 | ) 15 | return in_feat / (norm_factor.expand_as(in_feat) + eps) 16 | 17 | 18 | def cos_sim(in0, in1): 19 | in0_norm = normalize_tensor(in0) 20 | in1_norm = normalize_tensor(in1) 21 | N = in0.size()[0] 22 | X = in0.size()[2] 23 | Y = in0.size()[3] 24 | 25 | return torch.mean( 26 | torch.mean( 27 | torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2 28 | ).view(N, 1, 1, Y), 29 | dim=3, 30 | ).view(N) 31 | 32 | 33 | # Off-the-shelf deep network 34 | class PNet(nn.Module): 35 | """Pre-trained network with all channels equally weighted by default""" 36 | 37 | def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True): 38 | super(PNet, self).__init__() 39 | 40 | self.use_gpu = use_gpu 41 | 42 | self.pnet_type = pnet_type 43 | self.pnet_rand = pnet_rand 44 | 45 | self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1) 46 | self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1) 47 | 48 | if self.pnet_type in ["vgg", "vgg16"]: 49 | self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False) 50 | elif self.pnet_type == "alex": 51 | self.net = alexnet( 52 | pretrained=not self.pnet_rand, requires_grad=False 53 | ) 54 | elif self.pnet_type[:-2] == "resnet": 55 | self.net = resnet( 56 | pretrained=not self.pnet_rand, 57 | requires_grad=False, 58 | num=int(self.pnet_type[-2:]), 59 | ) 60 | elif self.pnet_type == "squeeze": 61 | self.net = squeezenet( 62 | pretrained=not self.pnet_rand, requires_grad=False 63 | ) 64 | 65 | self.L = self.net.N_slices 66 | 67 | if use_gpu: 68 | self.net.cuda() 69 | self.shift = self.shift.cuda() 70 | self.scale = self.scale.cuda() 71 | 72 | def forward(self, in0, in1, retPerLayer=False): 73 | in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) 74 | in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) 75 | 76 | outs0 = self.net.forward(in0_sc) 77 | outs1 = self.net.forward(in1_sc) 78 | 79 | if retPerLayer: 80 | all_scores = [] 81 | for (kk, out0) in enumerate(outs0): 82 | cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk]) 83 | if kk == 0: 84 | val = 1.0 * cur_score 85 | else: 86 | val = val + cur_score 87 | if retPerLayer: 88 | all_scores += [cur_score] 89 | 90 | if retPerLayer: 91 | return (val, all_scores) 92 | else: 93 | return val 94 | 95 | 96 | class squeezenet(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(squeezenet, self).__init__() 99 | pretrained_features = models.squeezenet1_1( 100 | pretrained=pretrained 101 | ).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.slice6 = torch.nn.Sequential() 108 | self.slice7 = torch.nn.Sequential() 109 | self.N_slices = 7 110 | for x in range(2): 111 | self.slice1.add_module(str(x), pretrained_features[x]) 112 | for x in range(2, 5): 113 | self.slice2.add_module(str(x), pretrained_features[x]) 114 | for x in range(5, 8): 115 | self.slice3.add_module(str(x), pretrained_features[x]) 116 | for x in range(8, 10): 117 | 
self.slice4.add_module(str(x), pretrained_features[x]) 118 | for x in range(10, 11): 119 | self.slice5.add_module(str(x), pretrained_features[x]) 120 | for x in range(11, 12): 121 | self.slice6.add_module(str(x), pretrained_features[x]) 122 | for x in range(12, 13): 123 | self.slice7.add_module(str(x), pretrained_features[x]) 124 | if not requires_grad: 125 | for param in self.parameters(): 126 | param.requires_grad = False 127 | 128 | def forward(self, X): 129 | h = self.slice1(X) 130 | h_relu1 = h 131 | h = self.slice2(h) 132 | h_relu2 = h 133 | h = self.slice3(h) 134 | h_relu3 = h 135 | h = self.slice4(h) 136 | h_relu4 = h 137 | h = self.slice5(h) 138 | h_relu5 = h 139 | h = self.slice6(h) 140 | h_relu6 = h 141 | h = self.slice7(h) 142 | h_relu7 = h 143 | vgg_outputs = namedtuple( 144 | "SqueezeOutputs", 145 | ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], 146 | ) 147 | out = vgg_outputs( 148 | h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7 149 | ) 150 | 151 | return out 152 | 153 | 154 | class alexnet(torch.nn.Module): 155 | def __init__(self, requires_grad=False, pretrained=True): 156 | super(alexnet, self).__init__() 157 | alexnet_pretrained_features = models.alexnet( 158 | pretrained=pretrained 159 | ).features 160 | self.slice1 = torch.nn.Sequential() 161 | self.slice2 = torch.nn.Sequential() 162 | self.slice3 = torch.nn.Sequential() 163 | self.slice4 = torch.nn.Sequential() 164 | self.slice5 = torch.nn.Sequential() 165 | self.N_slices = 5 166 | for x in range(2): 167 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 168 | for x in range(2, 5): 169 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 170 | for x in range(5, 8): 171 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 172 | for x in range(8, 10): 173 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 174 | for x in range(10, 12): 175 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 176 | if not requires_grad: 177 | for param in self.parameters(): 178 | param.requires_grad = False 179 | 180 | def forward(self, X): 181 | h = self.slice1(X) 182 | h_relu1 = h 183 | h = self.slice2(h) 184 | h_relu2 = h 185 | h = self.slice3(h) 186 | h_relu3 = h 187 | h = self.slice4(h) 188 | h_relu4 = h 189 | h = self.slice5(h) 190 | h_relu5 = h 191 | alexnet_outputs = namedtuple( 192 | "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] 193 | ) 194 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 195 | 196 | return out 197 | 198 | 199 | class vgg16(torch.nn.Module): 200 | def __init__(self, requires_grad=False, pretrained=True): 201 | super(vgg16, self).__init__() 202 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 203 | self.slice1 = torch.nn.Sequential() 204 | self.slice2 = torch.nn.Sequential() 205 | self.slice3 = torch.nn.Sequential() 206 | self.slice4 = torch.nn.Sequential() 207 | self.slice5 = torch.nn.Sequential() 208 | self.N_slices = 5 209 | for x in range(4): 210 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 211 | for x in range(4, 9): 212 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 213 | for x in range(9, 16): 214 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 215 | for x in range(16, 23): 216 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 217 | for x in range(23, 30): 218 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 219 | if not requires_grad: 220 | for param in self.parameters(): 221 | 
param.requires_grad = False 222 | 223 | def forward(self, X): 224 | h = self.slice1(X) 225 | h_relu1_2 = h 226 | h = self.slice2(h) 227 | h_relu2_2 = h 228 | h = self.slice3(h) 229 | h_relu3_3 = h 230 | h = self.slice4(h) 231 | h_relu4_3 = h 232 | h = self.slice5(h) 233 | h_relu5_3 = h 234 | vgg_outputs = namedtuple( 235 | "VggOutputs", 236 | ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"], 237 | ) 238 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 239 | 240 | return out 241 | 242 | 243 | class resnet(torch.nn.Module): 244 | def __init__(self, requires_grad=False, pretrained=True, num=18): 245 | super(resnet, self).__init__() 246 | if num == 18: 247 | self.net = models.resnet18(pretrained=pretrained) 248 | elif num == 34: 249 | self.net = models.resnet34(pretrained=pretrained) 250 | elif num == 50: 251 | self.net = models.resnet50(pretrained=pretrained) 252 | elif num == 101: 253 | self.net = models.resnet101(pretrained=pretrained) 254 | elif num == 152: 255 | self.net = models.resnet152(pretrained=pretrained) 256 | self.N_slices = 5 257 | 258 | self.conv1 = self.net.conv1 259 | self.bn1 = self.net.bn1 260 | self.relu = self.net.relu 261 | self.maxpool = self.net.maxpool 262 | self.layer1 = self.net.layer1 263 | self.layer2 = self.net.layer2 264 | self.layer3 = self.net.layer3 265 | self.layer4 = self.net.layer4 266 | 267 | def forward(self, X): 268 | h = self.conv1(X) 269 | h = self.bn1(h) 270 | h = self.relu(h) 271 | h_relu1 = h 272 | h = self.maxpool(h) 273 | h = self.layer1(h) 274 | h_conv2 = h 275 | h = self.layer2(h) 276 | h_conv3 = h 277 | h = self.layer3(h) 278 | h_conv4 = h 279 | h = self.layer4(h) 280 | h_conv5 = h 281 | 282 | outputs = namedtuple( 283 | "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"] 284 | ) 285 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 286 | 287 | return out 288 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import ( 12 | SynchronizedBatchNorm1d, 13 | SynchronizedBatchNorm2d, 14 | SynchronizedBatchNorm3d, 15 | convert_model, 16 | patch_sync_batchnorm, 17 | ) 18 | from .replicate import DataParallelWithCallback, patch_replication_callback 19 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ["BatchNormReimpl"] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
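Statistics are computed per channel from raw sums in forward(); normalisation
uses the biased variance, while the running buffers store momentum-averaged
estimates (with the unbiased variance).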
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | 28 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 29 | super().__init__() 30 | 31 | self.num_features = num_features 32 | self.eps = eps 33 | self.momentum = momentum 34 | self.weight = nn.Parameter(torch.empty(num_features)) 35 | self.bias = nn.Parameter(torch.empty(num_features)) 36 | self.register_buffer("running_mean", torch.zeros(num_features)) 37 | self.register_buffer("running_var", torch.ones(num_features)) 38 | self.reset_parameters() 39 | 40 | def reset_running_stats(self): 41 | self.running_mean.zero_() 42 | self.running_var.fill_(1) 43 | 44 | def reset_parameters(self): 45 | self.reset_running_stats() 46 | init.uniform_(self.weight) 47 | init.zeros_(self.bias) 48 | 49 | def forward(self, input_): 50 | batchsize, channels, height, width = input_.size() 51 | numel = batchsize * height * width 52 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 53 | sum_ = input_.sum(1) 54 | sum_of_square = input_.pow(2).sum(1) 55 | mean = sum_ / numel 56 | sumvar = sum_of_square - sum_ * mean 57 | 58 | self.running_mean = ( 59 | 1 - self.momentum 60 | ) * self.running_mean + self.momentum * mean.detach() 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | 1 - self.momentum 64 | ) * self.running_var + self.momentum * unbias_var.detach() 65 | 66 | bias_var = sumvar / numel 67 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 68 | output = (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze( 69 | 1 70 | ) * self.weight.unsqueeze(1) + self.bias.unsqueeze(1) 71 | 72 | return ( 73 | output.view(channels, batchsize, height, width) 74 | .permute(1, 0, 2, 3) 75 | .contiguous() 76 | ) 77 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | import queue 13 | import threading 14 | 15 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, "Previous result has't been fetched." 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"]) 43 | _SlavePipeBase = collections.namedtuple( 44 | "_SlavePipeBase", ["identifier", "queue", "result"] 45 | ) 46 | 47 | 48 | class SlavePipe(_SlavePipeBase): 49 | """Pipe for master-slave communication.""" 50 | 51 | def run_slave(self, msg): 52 | self.queue.put((self.identifier, msg)) 53 | ret = self.result.get() 54 | self.queue.put(True) 55 | return ret 56 | 57 | 58 | class SyncMaster(object): 59 | """An abstract `SyncMaster` object. 
60 | 61 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 62 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 63 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 64 | and passed to a registered callback. 65 | - After receiving the messages, the master device should gather the information and determine to message passed 66 | back to each slave devices. 67 | """ 68 | 69 | def __init__(self, master_callback): 70 | """ 71 | 72 | Args: 73 | master_callback: a callback to be invoked after having collected messages from slave devices. 74 | """ 75 | self._master_callback = master_callback 76 | self._queue = queue.Queue() 77 | self._registry = collections.OrderedDict() 78 | self._activated = False 79 | 80 | def __getstate__(self): 81 | return {"master_callback": self._master_callback} 82 | 83 | def __setstate__(self, state): 84 | self.__init__(state["master_callback"]) 85 | 86 | def register_slave(self, identifier): 87 | """ 88 | Register an slave device. 89 | 90 | Args: 91 | identifier: an identifier, usually is the device id. 92 | 93 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 94 | 95 | """ 96 | if self._activated: 97 | assert ( 98 | self._queue.empty() 99 | ), "Queue is not clean before next initialization." 100 | self._activated = False 101 | self._registry.clear() 102 | future = FutureResult() 103 | self._registry[identifier] = _MasterRegistry(future) 104 | return SlavePipe(identifier, self._queue, future) 105 | 106 | def run_master(self, master_msg): 107 | """ 108 | Main entry for the master device in each forward pass. 109 | The messages were first collected from each devices (including the master device), and then 110 | an callback will be invoked to compute the message to be sent back to each devices 111 | (including the master device). 112 | 113 | Args: 114 | master_msg: the message that the master want to send to itself. This will be placed as the first 115 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 116 | 117 | Returns: the message to be sent back to the master device. 118 | 119 | """ 120 | self._activated = True 121 | 122 | intermediates = [(0, master_msg)] 123 | for i in range(self.nr_slaves): 124 | intermediates.append(self._queue.get()) 125 | 126 | results = self._master_callback(intermediates) 127 | assert ( 128 | results[0][0] == 0 129 | ), "The first result should belongs to the master." 130 | 131 | for i, res in results: 132 | if i == 0: 133 | continue 134 | self._registry[i].result.put(res) 135 | 136 | for i in range(self.nr_slaves): 137 | assert self._queue.get() is True 138 | 139 | return results[0][1] 140 | 141 | @property 142 | def nr_slaves(self): 143 | return len(self._registry) 144 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
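# Provides DataParallelWithCallback and patch_replication_callback, which invoke the
# `__data_parallel_replicate__` hook on every replicated module so that the
# synchronized batch-norm layers can set up their cross-GPU communication.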
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | "CallbackContext", 17 | "execute_replication_callbacks", 18 | "DataParallelWithCallback", 19 | "patch_replication_callback", 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, "__data_parallel_replicate__"): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate( 66 | module, device_ids 67 | ) 68 | execute_replication_callbacks(modules) 69 | return modules 70 | 71 | 72 | def patch_replication_callback(data_parallel): 73 | """ 74 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 75 | Useful when you have customized `DataParallel` implementation. 76 | 77 | Examples: 78 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 79 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 80 | > patch_replication_callback(sync_bn) 81 | # this is equivalent to 82 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 83 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 84 | """ 85 | 86 | assert isinstance(data_parallel, DataParallel) 87 | 88 | old_replicate = data_parallel.replicate 89 | 90 | @functools.wraps(old_replicate) 91 | def new_replicate(module, device_ids): 92 | modules = old_replicate(module, device_ids) 93 | execute_replication_callbacks(modules) 94 | return modules 95 | 96 | data_parallel.replicate = new_replicate 97 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import torch 14 | 15 | 16 | class TorchTestCase(unittest.TestCase): 17 | def assertTensorClose(self, x, y): 18 | adiff = float((x - y).abs().max()) 19 | if (y == 0).all(): 20 | rdiff = "NaN" 21 | else: 22 | rdiff = float((adiff / y).abs().max()) 23 | 24 | message = ( 25 | "Tensor close check failed\n" "adiff={}\n" "rdiff={}\n" 26 | ).format(adiff, rdiff) 27 | self.assertTrue(torch.allclose(x, y), message) 28 | -------------------------------------------------------------------------------- /models/networks/utilities.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch.nn as nn 4 | 5 | from models.networks.architectures import ( 6 | ResNetDecoder, 7 | ResNetEncoder, 8 | UNetDecoder64, 9 | UNetEncoder64, 10 | ) 11 | 12 | EPS = 1e-2 13 | 14 | def get_encoder(opt, downsample=True): 15 | if opt.refine_model_type == "unet": 16 | encoder = UNetEncoder64(opt, channels_in=3, channels_out=64) 17 | elif "resnet" in opt.refine_model_type: 18 | print("RESNET encoder") 19 | encoder = ResNetEncoder( 20 | opt, channels_in=3, channels_out=64, downsample=downsample 21 | ) 22 | 23 | return encoder 24 | 25 | 26 | def get_decoder(opt): 27 | if opt.refine_model_type == "unet": 28 | decoder = UNetDecoder64(opt, channels_in=64, channels_out=3) 29 | elif "resnet" in opt.refine_model_type: 30 | print("RESNET decoder") 31 | decoder = ResNetDecoder(opt, channels_in=3, channels_out=3) 32 | 33 | return decoder 34 | -------------------------------------------------------------------------------- /models/projection/depth_manipulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | EPS = 1e-2 8 | 9 | 10 | class DepthManipulator(nn.Module): 11 | """ 12 | Use depth in order to naively manipulate an image. Simply splatter the 13 | depth into the new image such that the nearest point colours a pixel. 14 | 15 | Is not used for training; just for evaluation to determine visible/invisible 16 | regions. 17 | """ 18 | def __init__(self, W=256): 19 | super(DepthManipulator, self).__init__() 20 | # Set up default grid using clip coordinates (e.g. between [-1, 1]) 21 | xs, ys = np.meshgrid(np.linspace(-1, 1, W), np.linspace(1, -1, W)) 22 | xs = xs.reshape(1, W, W) 23 | ys = ys.reshape(1, W, W) 24 | xys = np.vstack((xs, ys, -np.ones(xs.shape), np.ones(xs.shape))) 25 | 26 | self.grid = torch.Tensor(xys).unsqueeze(0) 27 | 28 | if torch.cuda.is_available(): 29 | self.grid = self.grid.cuda() 30 | 31 | def homogenize(self, xys): 32 | assert xys.size(1) <= 3 33 | ones = torch.ones(xys.size(0), 1, xys.size(2)).to(xys.device) 34 | 35 | return torch.cat((xys, ones), 1) 36 | 37 | def project_zbuffer(self, depth, K, K_inv, RTinv_cam1, RT_cam2): 38 | """ Determine the sampler that comes from projecting 39 | the given depth according to the given camera parameters. 
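Returns a bilinear sampler of normalised (x, y) coordinates for the target
view (unfilled or invalid pixels are pushed outside the [-1, 1] range) and
the depth of each projected point in the target view.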
40 | """ 41 | bs, _, w, h = depth.size() 42 | 43 | # Obtain unprojected coordinates 44 | orig_xys = self.grid.to(depth.device).repeat(bs, 1, 1, 1).detach() 45 | xys = orig_xys * depth 46 | xys[:, -1, :] = 1 47 | 48 | xys = xys.view(bs, 4, -1) 49 | 50 | # Transform into camera coordinate of the first view 51 | cam1_X = K_inv.bmm(xys) 52 | 53 | # Transform into world coordinates 54 | RT = RT_cam2.bmm(RTinv_cam1) 55 | wrld_X = RT.bmm(cam1_X) 56 | 57 | # And intrinsics 58 | xy_proj = K.bmm(wrld_X) 59 | 60 | # And finally we project to get the final result 61 | mask = (xy_proj[:, 2:3, :].abs() < EPS).detach() 62 | sampler = xy_proj[:, 0:2, :] / -xy_proj[:, 2:3, :] 63 | sampler[mask.repeat(1, 2, 1)] = -10 64 | sampler[:, 1, :] = -sampler[:, 1, :] 65 | sampler[:, 0, :] = sampler[:, 0, :] 66 | 67 | with torch.no_grad(): 68 | print( 69 | "Warning : not backpropagating through the " 70 | + "projection -- is this what you want??" 71 | ) 72 | tsampler = (sampler + 1) * 128 73 | tsampler = tsampler.view(bs, 2, -1) 74 | zs, sampler_inds = xy_proj[:, 2:3, :].sort( 75 | dim=2, descending=True 76 | ) # Hack for how it's going to be understood by scatter: enforces that 77 | # nearer points are the ones rendered. 78 | bsinds = ( 79 | torch.linspace(0, bs - 1, bs) 80 | .long() 81 | .unsqueeze(1) 82 | .repeat(1, w * h) 83 | .to(sampler.device) 84 | .unsqueeze(1) 85 | ) 86 | 87 | xs = tsampler[bsinds, 0, sampler_inds].long() 88 | ys = tsampler[bsinds, 1, sampler_inds].long() 89 | mask = (tsampler < 0) | (tsampler > 255) 90 | mask = mask.float().max(dim=1, keepdim=True)[0] * 4 91 | 92 | xs = xs.clamp(min=0, max=255) 93 | ys = ys.clamp(min=0, max=255) 94 | 95 | bilinear_sampler = torch.zeros(bs, 2, w, h).to(sampler.device) - 2 96 | orig_xys = orig_xys[:, :2, :, :].view((bs, 2, -1)) 97 | bilinear_sampler[bsinds, 0, ys, xs] = ( 98 | orig_xys[bsinds, 0, sampler_inds] + mask 99 | ) 100 | bilinear_sampler[bsinds, 1, ys, xs] = ( 101 | -orig_xys[bsinds, 1, sampler_inds] + mask 102 | ) 103 | 104 | return bilinear_sampler, -xy_proj[:, 2:3, :].view(bs, 1, w, h) 105 | -------------------------------------------------------------------------------- /models/projection/z_buffer_manipulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from pytorch3d.structures import Pointclouds 7 | 8 | EPS = 1e-2 9 | 10 | 11 | def get_splatter( 12 | name, depth_values, opt=None, size=256, C=64, points_per_pixel=8 13 | ): 14 | if name == "xyblending": 15 | from models.layers.z_buffer_layers import RasterizePointsXYsBlending 16 | 17 | return RasterizePointsXYsBlending( 18 | C, 19 | learn_feature=opt.learn_default_feature, 20 | radius=opt.radius, 21 | size=size, 22 | points_per_pixel=points_per_pixel, 23 | opts=opt, 24 | ) 25 | else: 26 | raise NotImplementedError() 27 | 28 | 29 | class PtsManipulator(nn.Module): 30 | def __init__(self, W, C=64, opt=None): 31 | super().__init__() 32 | self.opt = opt 33 | 34 | self.splatter = get_splatter( 35 | opt.splatter, None, opt, size=W, C=C, points_per_pixel=opt.pp_pixel 36 | ) 37 | 38 | xs = torch.linspace(0, W - 1, W) / float(W - 1) * 2 - 1 39 | ys = torch.linspace(0, W - 1, W) / float(W - 1) * 2 - 1 40 | 41 | xs = xs.view(1, 1, 1, W).repeat(1, 1, W, 1) 42 | ys = ys.view(1, 1, W, 1).repeat(1, 1, 1, W) 43 | 44 | xyzs = torch.cat( 45 | (xs, -ys, -torch.ones(xs.size()), torch.ones(xs.size())), 1 46 | ).view(1, 4, -1) 47 | 48 | self.register_buffer("xyzs", xyzs) 49 | 50 | def project_pts( 51 | self, pts3D, K, K_inv, RT_cam1, RTinv_cam1, RT_cam2, RTinv_cam2 52 | ): 53 | # PERFORM PROJECTION 54 | # Project the world points into the new view 55 | projected_coors = self.xyzs * pts3D 56 | projected_coors[:, -1, :] = 1 57 | 58 | # Transform into camera coordinate of the first view 59 | cam1_X = K_inv.bmm(projected_coors) 60 | 61 | # Transform into world coordinates 62 | RT = RT_cam2.bmm(RTinv_cam1) 63 | 64 | wrld_X = RT.bmm(cam1_X) 65 | 66 | # And intrinsics 67 | xy_proj = K.bmm(wrld_X) 68 | 69 | # And finally we project to get the final result 70 | mask = (xy_proj[:, 2:3, :].abs() < EPS).detach() 71 | 72 | # Remove invalid zs that cause nans 73 | zs = xy_proj[:, 2:3, :] 74 | zs[mask] = EPS 75 | 76 | sampler = torch.cat((xy_proj[:, 0:2, :] / -zs, xy_proj[:, 2:3, :]), 1) 77 | sampler[mask.repeat(1, 3, 1)] = -10 78 | # Flip the ys 79 | sampler = sampler * torch.Tensor([1, -1, -1]).unsqueeze(0).unsqueeze( 80 | 2 81 | ).to(sampler.device) 82 | 83 | return sampler 84 | 85 | def forward_justpts( 86 | self, src, pred_pts, K, K_inv, RT_cam1, RTinv_cam1, RT_cam2, RTinv_cam2 87 | ): 88 | # Now project these points into a new view 89 | bs, c, w, h = src.size() 90 | 91 | if len(pred_pts.size()) > 3: 92 | # reshape into the right positioning 93 | pred_pts = pred_pts.view(bs, 1, -1) 94 | src = src.view(bs, c, -1) 95 | 96 | pts3D = self.project_pts( 97 | pred_pts, K, K_inv, RT_cam1, RTinv_cam1, RT_cam2, RTinv_cam2 98 | ) 99 | pointcloud = pts3D.permute(0, 2, 1).contiguous() 100 | result = self.splatter(pointcloud, src) 101 | 102 | return result 103 | 104 | def forward( 105 | self, 106 | alphas, 107 | src, 108 | pred_pts, 109 | K, 110 | K_inv, 111 | RT_cam1, 112 | RTinv_cam1, 113 | RT_cam2, 114 | RTinv_cam2, 115 | ): 116 | # Now project these points into a new view 117 | bs, c, w, h = src.size() 118 | 119 | if len(pred_pts.size()) > 3: 120 | # reshape into the right positioning 121 | pred_pts = pred_pts.view(bs, 1, -1) 122 | src = src.view(bs, c, -1) 123 | alphas = alphas.view(bs, 1, -1).permute(0, 2, 1).contiguous() 124 | 125 | pts3D = self.project_pts( 126 | pred_pts, K, K_inv, RT_cam1, RTinv_cam1, RT_cam2, RTinv_cam2 127 | ) 128 | result = self.splatter(pts3D.permute(0, 2, 1).contiguous(), alphas, src) 129 | 130 | return result 131 | 
-------------------------------------------------------------------------------- /models/z_buffermodel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.losses.synthesis import SynthesisLoss 7 | from models.networks.architectures import Unet 8 | from models.networks.utilities import get_decoder, get_encoder 9 | from models.projection.z_buffer_manipulator import PtsManipulator 10 | 11 | 12 | class ZbufferModelPts(nn.Module): 13 | def __init__(self, opt): 14 | super().__init__() 15 | 16 | self.opt = opt 17 | 18 | # ENCODER 19 | # Encode features to a given resolution 20 | self.encoder = get_encoder(opt) 21 | 22 | # POINT CLOUD TRANSFORMER 23 | # REGRESS 3D POINTS 24 | self.pts_regressor = Unet(channels_in=3, channels_out=1, opt=opt) 25 | 26 | if "modifier" in self.opt.depth_predictor_type: 27 | self.modifier = Unet(channels_in=64, channels_out=64, opt=opt) 28 | 29 | # 3D Points transformer 30 | if self.opt.use_rgb_features: 31 | self.pts_transformer = PtsManipulator(opt.W, C=3, opt=opt) 32 | else: 33 | self.pts_transformer = PtsManipulator(opt.W, opt=opt) 34 | 35 | self.projector = get_decoder(opt) 36 | 37 | # LOSS FUNCTION 38 | # Module to abstract away the loss function complexity 39 | self.loss_function = SynthesisLoss(opt=opt) 40 | 41 | self.min_tensor = self.register_buffer("min_z", torch.Tensor([0.1])) 42 | self.max_tensor = self.register_buffer( 43 | "max_z", torch.Tensor([self.opt.max_z]) 44 | ) 45 | self.discretized = self.register_buffer( 46 | "discretized_zs", 47 | torch.linspace(self.opt.min_z, self.opt.max_z, self.opt.voxel_size), 48 | ) 49 | 50 | def forward(self, batch): 51 | """ Forward pass of a view synthesis model with a voxel latent field. 52 | """ 53 | # Input values 54 | input_img = batch["images"][0] 55 | output_img = batch["images"][-1] 56 | 57 | if "depths" in batch.keys(): 58 | depth_img = batch["depths"][0] 59 | 60 | # Camera parameters 61 | K = batch["cameras"][0]["K"] 62 | K_inv = batch["cameras"][0]["Kinv"] 63 | 64 | input_RT = batch["cameras"][0]["P"] 65 | input_RTinv = batch["cameras"][0]["Pinv"] 66 | output_RT = batch["cameras"][-1]["P"] 67 | output_RTinv = batch["cameras"][-1]["Pinv"] 68 | 69 | if torch.cuda.is_available(): 70 | input_img = input_img.cuda() 71 | output_img = output_img.cuda() 72 | if "depths" in batch.keys(): 73 | depth_img = depth_img.cuda() 74 | 75 | K = K.cuda() 76 | K_inv = K_inv.cuda() 77 | 78 | input_RT = input_RT.cuda() 79 | input_RTinv = input_RTinv.cuda() 80 | output_RT = output_RT.cuda() 81 | output_RTinv = output_RTinv.cuda() 82 | 83 | if self.opt.use_rgb_features: 84 | fs = input_img 85 | else: 86 | fs = self.encoder(input_img) 87 | 88 | # Regressed points 89 | if not (self.opt.use_gt_depth): 90 | if not('use_inverse_depth' in self.opt) or not(self.opt.use_inverse_depth): 91 | regressed_pts = ( 92 | nn.Sigmoid()(self.pts_regressor(input_img)) 93 | * (self.opt.max_z - self.opt.min_z) 94 | + self.opt.min_z 95 | ) 96 | 97 | else: 98 | # Use the inverse for datasets with landscapes, where there 99 | # is a long tail on the depth distribution 100 | depth = self.pts_regressor(input_img) 101 | regressed_pts = 1. 
/ (nn.Sigmoid()(depth) * 10 + 0.01) 102 | else: 103 | regressed_pts = depth_img 104 | 105 | gen_fs = self.pts_transformer.forward_justpts( 106 | fs, 107 | regressed_pts, 108 | K, 109 | K_inv, 110 | input_RT, 111 | input_RTinv, 112 | output_RT, 113 | output_RTinv, 114 | ) 115 | 116 | if "modifier" in self.opt.depth_predictor_type: 117 | gen_fs = self.modifier(gen_fs) 118 | 119 | gen_img = self.projector(gen_fs) 120 | 121 | # And the loss 122 | loss = self.loss_function(gen_img, output_img) 123 | 124 | if self.opt.train_depth: 125 | depth_loss = nn.L1Loss()(regressed_pts, depth_img) 126 | loss["Total Loss"] += depth_loss 127 | loss["depth_loss"] = depth_loss 128 | 129 | return ( 130 | loss, 131 | { 132 | "InputImg": input_img, 133 | "OutputImg": output_img, 134 | "PredImg": gen_img, 135 | "PredDepth": regressed_pts, 136 | }, 137 | ) 138 | 139 | def forward_angle(self, batch, RTs, return_depth=False): 140 | # Input values 141 | input_img = batch["images"][0] 142 | 143 | # Camera parameters 144 | K = batch["cameras"][0]["K"] 145 | K_inv = batch["cameras"][0]["Kinv"] 146 | 147 | if torch.cuda.is_available(): 148 | input_img = input_img.cuda() 149 | 150 | K = K.cuda() 151 | K_inv = K_inv.cuda() 152 | 153 | RTs = [RT.cuda() for RT in RTs] 154 | identity = ( 155 | torch.eye(4).unsqueeze(0).repeat(input_img.size(0), 1, 1).cuda() 156 | ) 157 | 158 | fs = self.encoder(input_img) 159 | regressed_pts = ( 160 | nn.Sigmoid()(self.pts_regressor(input_img)) 161 | * (self.opt.max_z - self.opt.min_z) 162 | + self.opt.min_z 163 | ) 164 | 165 | # Now rotate 166 | gen_imgs = [] 167 | for RT in RTs: 168 | torch.manual_seed( 169 | 0 170 | ) # Reset seed each time so that noise vectors are the same 171 | gen_fs = self.pts_transformer.forward_justpts( 172 | fs, regressed_pts, K, K_inv, identity, identity, RT, None 173 | ) 174 | 175 | # now create a new image 176 | gen_img = self.projector(gen_fs) 177 | 178 | gen_imgs += [gen_img] 179 | 180 | if return_depth: 181 | return gen_imgs, regressed_pts 182 | 183 | return gen_imgs 184 | -------------------------------------------------------------------------------- /options/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | def get_model(opt): 4 | print("Loading model %s ... ") 5 | if opt.model_type == "zbuffer_pts": 6 | from models.z_buffermodel import ZbufferModelPts 7 | 8 | model = ZbufferModelPts(opt) 9 | elif opt.model_type == "viewappearance": 10 | from models.encoderdecoder import ViewAppearanceFlow 11 | 12 | model = ViewAppearanceFlow(opt) 13 | elif opt.model_type == "tatarchenko": 14 | from models.encoderdecoder import Tatarchenko 15 | 16 | model = Tatarchenko(opt) 17 | 18 | return model 19 | 20 | 21 | def get_dataset(opt): 22 | 23 | print("Loading dataset %s ..." 
% opt.dataset) 24 | if opt.dataset == "mp3d": 25 | opt.train_data_path = ( 26 | "/private/home/ow045820/projects/habitat/" 27 | + "habitat-api/data/datasets/pointnav/mp3d/v1/train/train.json.gz" 28 | ) 29 | opt.val_data_path = ( 30 | "/private/home/ow045820/projects/habitat/" 31 | + "habitat-api/data/datasets/pointnav/mp3d/v1/test/test.json.gz" 32 | ) 33 | opt.test_data_path = ( 34 | "/private/home/ow045820/projects/habitat/" 35 | + "habitat-api/data/datasets/pointnav/mp3d/v1/val/val.json.gz" 36 | ) 37 | opt.scenes_dir = "/checkpoint/ow045820/data/" # this should store mp3d 38 | elif opt.dataset == "habitat": 39 | opt.train_data_path = ( 40 | "/private/home/ow045820/projects/habitat/habitat-api/" 41 | + "data/datasets/pointnav/habitat-test-scenes/v1/train/train.json.gz" 42 | ) 43 | opt.val_data_path = ( 44 | "/private/home/ow045820/projects/habitat/habitat-api/" 45 | + "data/datasets/pointnav/habitat-test-scenes/v1/val/val.json.gz" 46 | ) 47 | opt.test_data_path = ( 48 | "/private/home/ow045820/projects/habitat/habitat-api/" 49 | + "data/datasets/pointnav/habitat-test-scenes/v1/test/test.json.gz" 50 | ) 51 | opt.scenes_dir = "/private/home/ow045820/projects/habitat/habitat-api//data/scene_datasets" 52 | elif opt.dataset == "replica": 53 | opt.train_data_path = ( 54 | "/private/home/ow045820/projects/habitat/habitat-api/" 55 | + "data/datasets/pointnav/replica/v1/train/train.json.gz" 56 | ) 57 | opt.val_data_path = ( 58 | "/private/home/ow045820/projects/habitat/habitat-api/" 59 | + "data/datasets/pointnav/replica/v1/val/val.json.gz" 60 | ) 61 | opt.test_data_path = ( 62 | "/private/home/ow045820/projects/habitat/habitat-api/" 63 | + "data/datasets/pointnav/replica/v1/test/test.json.gz" 64 | ) 65 | opt.scenes_dir = "/checkpoint/ow045820/data/replica/" 66 | elif opt.dataset == "realestate": 67 | opt.min_z = 1.0 68 | opt.max_z = 100.0 69 | opt.train_data_path = ( 70 | "/checkpoint/ow045820/data/realestate10K/RealEstate10K/" 71 | ) 72 | from data.realestate10k import RealEstate10K 73 | 74 | return RealEstate10K 75 | elif opt.dataset == 'kitti': 76 | opt.min_z = 1.0 77 | opt.max_z = 50.0 78 | opt.train_data_path = ( 79 | '/private/home/ow045820/projects/code/continuous_view_synthesis/datasets/dataset_kitti' 80 | ) 81 | from data.kitti import KITTIDataLoader 82 | 83 | return KITTIDataLoader 84 | 85 | from data.habitat_data import HabitatImageGenerator as Dataset 86 | 87 | return Dataset 88 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
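# Module overview: evaluation-time ArgumentParser. Checkpoint, result and
# dataset paths, batch/view counts and rendering flags are collected in an
# "eval" argument group; parse() returns both the flat namespace and a
# per-group dict of the parsed values.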
2 | 3 | import argparse 4 | 5 | 6 | class ArgumentParser: 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.add_eval_parameters() 10 | 11 | def add_eval_parameters(self): 12 | eval_params = self.parser.add_argument_group("eval") 13 | 14 | eval_params.add_argument("--old_model", type=str, default="") 15 | eval_params.add_argument("--short_name", type=str, default="") 16 | eval_params.add_argument("--result_folder", type=str, default="") 17 | eval_params.add_argument("--test_folder", type=str, default="") 18 | eval_params.add_argument("--gt_folder", type=str, default="") 19 | eval_params.add_argument("--batch_size", type=int, default=4) 20 | eval_params.add_argument("--num_views", type=int, default=2) 21 | eval_params.add_argument("--num_workers", type=int, default=1) 22 | eval_params.add_argument( 23 | "--render_ids", type=int, nargs="+", default=[1] 24 | ) 25 | eval_params.add_argument( 26 | "--image_type", 27 | type=str, 28 | default="both", 29 | choices=( 30 | "both" 31 | ), 32 | ) 33 | eval_params.add_argument("--gpu_ids", type=str, default="0") 34 | eval_params.add_argument("--images_before_reset", type=int, default=100) 35 | eval_params.add_argument( 36 | "--test_input_image", action="store_true", default=False 37 | ) 38 | eval_params.add_argument( 39 | "--use_videos", action="store_true", default=False 40 | ) 41 | eval_params.add_argument( 42 | "--auto_regressive", action="store_true", default=False 43 | ) 44 | eval_params.add_argument("--use_gt", action="store_true", default=False) 45 | eval_params.add_argument("--save_data", action="store_true", default=False) 46 | eval_params.add_argument("--dataset", type=str, default="") 47 | eval_params.add_argument( 48 | "--use_higher_res", action="store_true", default=False 49 | ) 50 | 51 | def parse(self, arg_str=None): 52 | if arg_str is None: 53 | args = self.parser.parse_args() 54 | else: 55 | args = self.parser.parse_args(arg_str.split()) 56 | 57 | arg_groups = {} 58 | for group in self.parser._action_groups: 59 | group_dict = { 60 | a.dest: getattr(args, a.dest, None) 61 | for a in group._group_actions 62 | } 63 | arg_groups[group.title] = group_dict 64 | 65 | return (args, arg_groups) 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | name: pytorch3d_src 2 | channels: 3 | - pytorch 4 | - fvcore 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - portalocker=1.5.2=py37hc8dfbb8_1 9 | - python_abi=3.7=1_cp37m 10 | - pyyaml=5.3.1=py37h8f50634_0 11 | - termcolor=1.1.0=py_2 12 | - tqdm=4.43.0=py_0 13 | - yacs=0.1.6=py_0 14 | - yaml=0.2.2=h516909a_1 15 | - _libgcc_mutex=0.1=main 16 | - attrs=19.3.0=py_0 17 | - backcall=0.1.0=py37_0 18 | - blas=1.0=mkl 19 | - bleach=3.1.0=py_0 20 | - ca-certificates=2020.1.1=0 21 | - certifi=2019.11.28=py37_0 22 | - cudatoolkit=10.0.130=0 23 | - dbus=1.13.12=h746ee38_0 24 | - decorator=4.4.2=py_0 25 | - defusedxml=0.6.0=py_0 26 | - entrypoints=0.3=py37_0 27 | - expat=2.2.6=he6710b0_0 28 | - fontconfig=2.13.0=h9420a91_0 29 | - freetype=2.9.1=h8a8886c_1 30 | - glib=2.63.1=h5a9c865_0 31 | - gmp=6.1.2=h6c8ec71_1 32 | - gst-plugins-base=1.14.0=hbbd80ab_1 33 | - gstreamer=1.14.0=hb453b48_1 34 | - icu=58.2=h9c2bf20_1 35 | - importlib_metadata=1.5.0=py37_0 36 | - intel-openmp=2020.0=166 37 | - ipykernel=5.1.4=py37h39e3cac_0 38 | - ipython=7.13.0=py37h5ca1d4c_0 39 | - ipython_genutils=0.2.0=py37_0 40 | - ipywidgets=7.5.1=py_0 41 | - jedi=0.16.0=py37_0 
42 | - jinja2=2.11.1=py_0 43 | - jpeg=9b=h024ee3a_2 44 | - jsonschema=3.2.0=py37_0 45 | - jupyter=1.0.0=py37_7 46 | - jupyter_client=6.0.0=py_0 47 | - jupyter_console=6.1.0=py_0 48 | - jupyter_core=4.6.1=py37_0 49 | - ld_impl_linux-64=2.33.1=h53a641e_7 50 | - libedit=3.1.20181209=hc058e9b_0 51 | - libffi=3.2.1=hd88cf55_4 52 | - libgcc-ng=9.1.0=hdf63c60_0 53 | - libgfortran-ng=7.3.0=hdf63c60_0 54 | - libpng=1.6.37=hbc83047_0 55 | - libsodium=1.0.16=h1bed415_0 56 | - libstdcxx-ng=9.1.0=hdf63c60_0 57 | - libtiff=4.1.0=h2733197_0 58 | - libuuid=1.0.3=h1bed415_2 59 | - libxcb=1.13=h1bed415_1 60 | - libxml2=2.9.9=hea5a465_1 61 | - markupsafe=1.1.1=py37h7b6447c_0 62 | - mistune=0.8.4=py37h7b6447c_0 63 | - mkl=2020.0=166 64 | - mkl-service=2.3.0=py37he904b0f_0 65 | - mkl_fft=1.0.15=py37ha843d7b_0 66 | - mkl_random=1.1.0=py37hd6b4f25_0 67 | - nbconvert=5.6.1=py37_0 68 | - nbformat=5.0.4=py_0 69 | - ncurses=6.2=he6710b0_0 70 | - ninja=1.9.0=py37hfd86e86_0 71 | - notebook=6.0.3=py37_0 72 | - numpy=1.18.1=py37h4f9e942_0 73 | - numpy-base=1.18.1=py37hde5b4d6_1 74 | - olefile=0.46=py_0 75 | - openssl=1.1.1e=h7b6447c_0 76 | - pandoc=2.2.3.2=0 77 | - pandocfilters=1.4.2=py37_1 78 | - parso=0.6.2=py_0 79 | - pcre=8.43=he6710b0_0 80 | - pexpect=4.8.0=py37_0 81 | - pickleshare=0.7.5=py37_0 82 | - pillow=7.0.0=py37hb39fc2d_0 83 | - pip=20.0.2=py37_1 84 | - prometheus_client=0.7.1=py_0 85 | - prompt_toolkit=3.0.3=py_0 86 | - ptyprocess=0.6.0=py37_0 87 | - pygments=2.6.1=py_0 88 | - pyqt=5.9.2=py37h05f1152_2 89 | - pyrsistent=0.15.7=py37h7b6447c_0 90 | - python=3.7.6=h0371630_2 91 | - python-dateutil=2.8.1=py_0 92 | - pyzmq=18.1.1=py37he6710b0_0 93 | - qt=5.9.7=h5867ecd_1 94 | - qtconsole=4.7.1=py_0 95 | - qtpy=1.9.0=py_0 96 | - readline=7.0=h7b6447c_5 97 | - send2trash=1.5.0=py37_0 98 | - setuptools=46.0.0=py37_0 99 | - sip=4.19.8=py37hf484d3e_0 100 | - six=1.14.0=py37_0 101 | - sqlite=3.31.1=h7b6447c_0 102 | - terminado=0.8.3=py37_0 103 | - testpath=0.4.4=py_0 104 | - tk=8.6.8=hbc83047_0 105 | - tornado=6.0.4=py37h7b6447c_1 106 | - traitlets=4.3.3=py37_0 107 | - wcwidth=0.1.8=py_0 108 | - webencodings=0.5.1=py37_1 109 | - wheel=0.34.2=py37_0 110 | - widgetsnbextension=3.5.1=py37_0 111 | - xz=5.2.4=h14c3975_4 112 | - zeromq=4.3.1=he6710b0_3 113 | - zipp=2.2.0=py_0 114 | - zlib=1.2.11=h7b6447c_3 115 | - zstd=1.3.7=h0b5b093_0 116 | - fvcore=0.1.dev200114=py37_0 117 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 118 | - torchvision=0.5.0=py37_cu100 119 | prefix: /private/home/ow045820/.conda/envs/pytorch3d_src 120 | 121 | -------------------------------------------------------------------------------- /submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 4 | 5 | # This is an explanatory file for how to run our code on a SLURM cluster. It is not 6 | # meant to be run out of the box but should be useful if this is your use case. 
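# Variables exported to submit_slurm_synsin.sh below (names follow the train.py flags):
#   radius / radpow   - splatter point radius (--radius) and radial falloff power (--rad_pow)
#   pppixel           - points rendered per pixel by the rasteriser (--pp_pixel)
#   accumulation      - compositing mode (--accumulation): alphacomposite, wsum or wsumnorm
#   modeltype / refinemodeltype - --model_type and --refine_model_type
#   additionalops     - extra dataset-specific flags appended verbatim to the train command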
7 | 8 | timestamp="$(date +"%Y-%m-%d--%H:%M:%S")" 9 | mkdir /checkpoint/ow045820/code/$timestamp/ 10 | rsync -r --exclude '.ipynb_checkpoints/' --exclude '.vscode/' \ 11 | --exclude 'mp3dtatarchenko' --exclude 'mp3dviewapp' --exclude 'mp3ddeepvoxelsunet' --exclude 'mp3dzbufferpts' \ 12 | --exclude 'method_figure' --exclude 'real_estate_results' --exclude 'checkpoints' --exclude 'modelcheckpoints' \ 13 | --exclude 'data/.ipynb_checkpoints' --exclude 'temp/' --exclude 'data_preprocessing/temp/' \ 14 | --exclude 'results/' --exclude '.git/' ~/projects/code/synsin_public /checkpoint/ow045820/code/$timestamp/ 15 | rsync -r ~/projects/code/torch3d_fork/torch3d /checkpoint/ow045820/code/$timestamp/ 16 | 17 | cd /checkpoint/ow045820/code/$timestamp/synsin_public/ 18 | 19 | TIMESTAMP=timestamp 20 | 21 | chmod +x ./submit_slurm_synsin.sh 22 | 23 | radius=4 24 | radpow=2 25 | 26 | ################################################################################################# 27 | ### Setting up parameters for different datasets 28 | ################################################################################################# 29 | 30 | # KITTI 31 | #additionalops="--num_workers 10 --dataset kitti --use_inv_z --lr 0.0001 --use_inverse_depth" 32 | 33 | # Matterport 34 | additionalops=" --num_workers 0 --lr 0.0001 " 35 | 36 | # RealEstate10K 37 | #additionalops=" --num_workers 10 --dataset realestate --use_inv_z --lr 0.0001 " 38 | 39 | 40 | ################################################################################################# 41 | ### Ablation for different hyperparameters when compositing 42 | ################################################################################################# 43 | 44 | #pppixel=4 45 | #modeltype='zbuffer_pts' 46 | #accumulation="alphacomposite" 47 | #refinemodeltype='resnet_256W8UpDown64' 48 | #suffix="_accum${accumulation}_${radius}_${radpow}_redo_ppp4" 49 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 50 | 51 | # pppixel=8 52 | # modeltype='zbuffer_pts' 53 | # accumulation="alphacomposite" 54 | # refinemodeltype='resnet_256W8UpDown64' 55 | # suffix="_accum${accumulation}_${radius}_${radpow}_redo_ppp8_maxdisparity" 56 | # sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 57 | 58 | pppixel=128 59 | modeltype='zbuffer_pts' 60 | accumulation="alphacomposite" 61 | refinemodeltype='resnet_256W8UpDown64' 62 | suffix="_accum${accumulation}_${radius}_${radpow}_redo_maxdisparity" 63 | sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 64 | 65 | 66 | 67 | ##pppixel=1 68 | #modeltype='zbuffer_pts' 69 | #accumulation="alphacomposite" 70 | #refinemodeltype='resnet_256W8UpDown64' 71 | #suffix="_accum${accumulation}_${radius}_${radpow}_ppp${ppppixel}_redo" 72 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 73 | 74 
| 75 | #pppixel=128 76 | #radius=0.5 77 | #modeltype='zbuffer_pts' 78 | #accumulation="alphacomposite" 79 | #refinemodeltype='resnet_256W8UpDown64' 80 | #suffix="_accum${accumulation}_${radius}_${radpow}_redo" 81 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 82 | 83 | 84 | #pppixel=128 85 | #radius=4 86 | #modeltype='zbuffer_pts' 87 | #accumulation="alphacomposite" 88 | #refinemodeltype='resnet_256W8UpDown3' 89 | #suffix="_accum${accumulation}_${radius}_${radpow}_rgb_redo" 90 | #additionalops="$additionalops --use_rgb_features" 91 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 92 | 93 | #pppixel=128 94 | #radius=4 95 | #modeltype='zbuffer_pts' 96 | #accumulation="alphacomposite" 97 | #refinemodeltype='resnet_256W8UpDown64' 98 | #suffix="_accum${accumulation}_${radius}_${radpow}_gtdepth_redo" 99 | #additionalops="$additonalops --use_gt_depth" 100 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 101 | 102 | #pppixel=128 103 | #radius=4 104 | #modeltype='zbuffer_pts' 105 | #accumulation="alphacomposite" 106 | #refinemodeltype='resnet_256W8UpDown64' 107 | #suffix="_accum${accumulation}_${radius}_${radpow}_traindepth_redo" 108 | #additionalops="$additonalops --train_depth" 109 | #sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 110 | 111 | ################################################################################################# 112 | ### Different accumulation Functions 113 | ################################################################################################# 114 | 115 | 116 | # pppixel=128 117 | # modeltype='zbuffer_pts' 118 | # accumulation="wsum" 119 | # refinemodeltype='resnet_256W8UpDown64' 120 | # suffix="_accum${accumulation}_${radius}_${radpow}_redo" 121 | # additionalops="" 122 | # sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 123 | 124 | # pppixel=128 125 | # modeltype='zbuffer_pts' 126 | # accumulation="wsumnorm" 127 | # refinemodeltype='resnet_256W8UpDown64' 128 | # suffix="_accum${accumulation}_${radius}_${radpow}_redo" 129 | # additionalops="" 130 | # sbatch --export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 131 | 132 | # pppixel=128 133 | # modeltype='zbuffer_pts' 134 | # accumulation="alphacomposite" 135 | # refinemodeltype='resnet_256W8UpDown64' 136 | # suffix="_accum${accumulation}_${radius}_${radpow}_redo" 137 | # additionalops="" 138 | # sbatch 
--export=ALL,radpow=$radpow,accumulation=$accumulation,radius=$radius,additionalops="$additionalops",pppixel=$pppixel,taugumbel=$taugumbel,modeltype=$modeltype,refinemodeltype=$refinemodeltype,suffix=$suffix ./submit_slurm_synsin.sh 139 | 140 | 141 | -------------------------------------------------------------------------------- /submit_slurm_synsin.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 4 | 5 | # This is an explanatory file for how to run our code on a SLURM cluster. It is not 6 | # meant to be run out of the box but should be useful if this is your use case. 7 | 8 | ################################################################################# 9 | # File Name : submit_synsin.sh 10 | # Created By : Olivia Wiles 11 | # Description : submit view synthesis jobs 12 | ################################################################################# 13 | 14 | #SBATCH --job-name=vs3d 15 | 16 | #SBATCH --output=/checkpoint/%u/jobs/sample-%j.out 17 | 18 | #SBATCH --error=/checkpoint/%u/jobs/sample-%j.err 19 | 20 | #SBATCH --nodes=1 -C volta32gb 21 | 22 | #SBATCH --partition=dev 23 | 24 | #SBATCH --ntasks-per-node=1 25 | 26 | #SBATCH --gres=gpu:volta:4 27 | 28 | #SBATCH --cpus-per-task=40 29 | 30 | #SBATCH --mem=250G 31 | 32 | #SBATCH --signal=USR1@600 33 | 34 | #SBATCH --open-mode=append 35 | 36 | #SBATCH --time=72:00:00 37 | 38 | # The ENV below are only used in distributed training with env:// initialization 39 | export MASTER_ADDR=${SLURM_NODELIST:0:9}${SLURM_NODELIST:10:3} 40 | export MASTER_PORT=29500 41 | 42 | unset PYTHONPATH 43 | export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/nvidia-opengl:${LD_LIBRARY_PATH} 44 | export USE_SLURM=1 45 | 46 | source activate synsin_env 47 | 48 | export DEBUG=0 49 | echo $additionalops 50 | 51 | echo "Starting model with... radius $radius^$radpow pp $pppixel with model type $modeltype, accumulation $accumulation and $refinemodeltype saving to $suffix" 52 | srun --label python train.py --batch-size 32 --folder 'final2' \ 53 | --pp_pixel $pppixel --radius $radius --resume --accumulation $accumulation --rad_pow $radpow \ 54 | --model_type $modeltype --refine_model_type $refinemodeltype $additionalops \ 55 | --norm_G 'sync:spectral_batch' --gpu_ids 0,1,2 --render_ids 3 \ 56 | --suffix $suffix --normalize_image #--W 512 \ 57 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
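# Single-machine training recipes: the RealEstate10K command below is active;
# the commented blocks give the equivalent KITTI and Matterport3D invocations.
# Adjust --gpu_ids / --render_ids to match the GPUs available on your machine.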
2 | 3 | source activate synsin_env 4 | 5 | export DEBUG=0 6 | export USE_SLURM=0 7 | 8 | # How to run on RealEstate10K 9 | python train.py --batch-size 32 --folder 'temp' --num_workers 4 \ 10 | --resume --dataset 'realestate' --use_inv_z --accumulation 'alphacomposite' \ 11 | --model_type 'zbuffer_pts' --refine_model_type 'resnet_256W8UpDown64' \ 12 | --norm_G 'sync:spectral_batch' --gpu_ids 0,1 --render_ids 1 \ 13 | --suffix '' --normalize_image --lr 0.0001 14 | 15 | # How to run on KITTI 16 | # python train.py --batch-size 32 --folder 'temp' --num_workers 4 \ 17 | # --resume --dataset 'kitti' --use_inv_z --use_inverse_depth --accumulation 'alphacomposite' \ 18 | # --model_type 'zbuffer_pts' --refine_model_type 'resnet_256W8UpDown64' \ 19 | # --norm_G 'sync:spectral_batch' --gpu_ids 0,1 --render_ids 1 \ 20 | # --suffix '' --normalize_image --lr 0.0001 21 | 22 | # # How to run on Matterport3D 23 | # python train.py --batch-size 32 --folder 'temp' --num_workers 0 \ 24 | # --resume --accumulation 'alphacomposite' \ 25 | # --model_type 'zbuffer_pts' --refine_model_type 'resnet_256W8UpDown64' \ 26 | # --norm_G 'sync:spectral_batch' --gpu_ids 0 --render_ids 1 \ 27 | # --suffix '' --normalize_image --lr 0.0001 28 | -------------------------------------------------------------------------------- /utils/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | from math import sqrt 4 | 5 | import numpy as np 6 | 7 | 8 | def get_deltas(mat1, mat2): 9 | mat1 = np.vstack((mat1, np.array([0, 0, 0, 1]))) 10 | mat2 = np.vstack((mat2, np.array([0, 0, 0, 1]))) 11 | 12 | dMat = np.matmul(np.linalg.inv(mat1), mat2) 13 | dtrans = dMat[0:3, 3] ** 2 14 | dtrans = sqrt(dtrans.sum()) 15 | 16 | origVec = np.array([[0], [0], [1]]) 17 | rotVec = np.matmul(dMat[0:3, 0:3], origVec) 18 | arccos = (rotVec * origVec).sum() / sqrt((rotVec ** 2).sum()) 19 | dAngle = np.arccos(arccos) * 180.0 / np.pi 20 | 21 | return dAngle, dtrans 22 | -------------------------------------------------------------------------------- /utils/jitter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | 3 | import numpy as np 4 | import quaternion 5 | 6 | def jitter_quaternions(original_quaternion, rnd, angle=30.0): 7 | original_euler = quaternion.as_euler_angles(original_quaternion) 8 | euler_angles = np.array( 9 | [ 10 | (rnd.rand() - 0.5) * np.pi * angle / 180.0 + original_euler[0], 11 | (rnd.rand() - 0.5) * np.pi * angle / 180.0 + original_euler[1], 12 | (rnd.rand() - 0.5) * np.pi * angle / 180.0 + original_euler[2], 13 | ] 14 | ) 15 | quaternions = quaternion.from_euler_angles(euler_angles) 16 | 17 | return quaternions 18 | --------------------------------------------------------------------------------
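The two utilities above are small viewpoint helpers: get_deltas measures how far apart two camera extrinsics are (rotation angle in degrees and translation norm), and jitter_quaternions perturbs an orientation by a bounded random Euler offset. A short usage sketch with made-up poses, assuming the repository root is on PYTHONPATH:

import numpy as np
import quaternion  # numpy-quaternion

from utils.geometry import get_deltas
from utils.jitter import jitter_quaternions

# Two made-up 3x4 extrinsics: the identity pose, and a pose rotated 10 degrees
# about the y-axis and shifted 0.5 along x.
theta = np.deg2rad(10.0)
R = np.array([[np.cos(theta), 0.0, np.sin(theta)],
              [0.0, 1.0, 0.0],
              [-np.sin(theta), 0.0, np.cos(theta)]])
mat1 = np.hstack((np.eye(3), np.zeros((3, 1))))
mat2 = np.hstack((R, np.array([[0.5], [0.0], [0.0]])))

dAngle, dtrans = get_deltas(mat1, mat2)
print(dAngle, dtrans)  # roughly 10.0 degrees and a translation of 0.5

# Perturb an identity orientation by up to +/-15 degrees per Euler angle
# (half of the angle argument, as in the implementation above).
rnd = np.random.RandomState(0)
q = quaternion.from_euler_angles(np.zeros(3))
print(jitter_quaternions(q, rnd, angle=30.0))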