├── .gitignore ├── README.md ├── build_toolbox.sh ├── clean_toolbox_build.sh ├── datasets ├── __init__.py ├── shapenet.py └── test.py ├── downloads ├── data │ └── test │ │ ├── genre │ │ ├── 03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png │ │ ├── 03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png │ │ ├── 04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png │ │ ├── 04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png │ │ ├── 04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png │ │ ├── 04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png │ │ ├── 04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png │ │ └── 04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png │ │ └── shapehd │ │ ├── 0044_mask.png │ │ ├── 0044_rgb.png │ │ ├── 0503_mask.png │ │ ├── 0503_rgb.jpg │ │ ├── 1209_mask.png │ │ └── 1209_rgb.jpg └── results │ ├── genre.png │ └── shapehd.png ├── environment.yml ├── install_trimesh.sh ├── loggers ├── Progbar.py ├── __init__.py └── loggers.py ├── models ├── __init__.py ├── depth_pred_with_sph_inpaint.py ├── genre_full_model.py ├── marrnet.py ├── marrnet1.py ├── marrnet2.py ├── marrnetbase.py ├── netinterface.py ├── shapehd.py └── wgangp.py ├── networks ├── __init__.py ├── networks.py ├── revresnet.py └── uresnet.py ├── options ├── __init__.py ├── options_test.py └── options_train.py ├── scripts ├── finetune_marrnet.sh ├── finetune_shapehd.sh ├── test_genre.sh ├── test_marrnet.sh ├── test_shapehd.sh ├── train_full_genre.sh ├── train_inpaint.sh ├── train_marrnet1.sh ├── train_marrnet2.sh └── train_wgangp.sh ├── test.py ├── toolbox ├── __init__.py ├── calc_prob │ ├── build.py │ ├── calc_prob │ │ ├── __init__.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── calc_prob.py │ │ └── src │ │ │ ├── calc_prob.c │ │ │ ├── calc_prob.h │ │ │ ├── calc_prob_kernel.cu │ │ │ └── calc_prob_kernel.h │ ├── clean.sh │ ├── setup.py │ └── setup.sh ├── cam_bp │ ├── build.py │ ├── cam_bp │ │ ├── __init__.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ ├── cam_back_projection.py │ │ │ ├── get_surface_mask.py │ │ │ └── sperical_to_tdf.py │ │ ├── modules │ │ │ ├── Spherical_backproj.py │ │ │ ├── __init__.py │ │ │ └── camera_backprojection_module.py │ │ └── src │ │ │ ├── _cam_bp_lib.abi3.so │ │ │ ├── back_projection.c │ │ │ ├── back_projection.h │ │ │ ├── back_projection_kernel.cu │ │ │ └── back_projection_kernel.h │ ├── clean.sh │ ├── setup.py │ └── setup.sh ├── nndistance │ ├── README.md │ ├── build.py │ ├── clean.sh │ ├── functions │ │ ├── __init__.py │ │ └── nnd.py │ ├── modules │ │ ├── __init__.py │ │ └── nnd.py │ ├── setup.sh │ ├── src │ │ ├── my_lib.c │ │ ├── my_lib.h │ │ ├── my_lib_cuda.c │ │ ├── my_lib_cuda.h │ │ ├── nnd_cuda.cu │ │ └── nnd_cuda.h │ └── test.py └── spherical_proj.py ├── train.py ├── util ├── __init__.py ├── util_cam_para.py ├── util_camera.py ├── util_img.py ├── util_io.py ├── util_loadlib.py ├── util_print.py ├── util_reproj.py ├── util_sph.py ├── util_voxel.py └── util_xml_to_cam_params.py └── visualize ├── config.json └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | !.gitignore 2 | 3 | **/__pycache__/ 4 | **/*.pyc 5 | **/*.swp 6 | 7 | code_test 8 | private 9 | downloads/data/shapenet_cars_chairs_planes_20views.tar 10 | downloads/data/shapenet 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generalizable 
Reconstruction (GenRe) and ShapeHD 2 | 3 | 4 | ## Papers 5 | 6 | This is a repo covering the following three papers. If you find the code useful, please cite the paper(s). 7 | 8 | 1. Generalizable Reconstruction (GenRe)
9 | **Learning to Reconstruct Shapes from Unseen Classes**
10 | [Xiuming Zhang](http://people.csail.mit.edu/xiuming/)*, [Zhoutong Zhang](https://ztzhang.info)*, [Chengkai Zhang](https://www.csail.mit.edu/person/chengkai-zhang), [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html), [William T. Freeman](https://billf.mit.edu/), and [Jiajun Wu](https://jiajunwu.com/)
11 | *NeurIPS 2018 (Oral)*
12 | [Paper](http://genre.csail.mit.edu/papers/genre_nips.pdf)   |   [BibTeX](http://genre.csail.mit.edu/bibtex/genre_nips.bib)   |   [Project](http://genre.csail.mit.edu/) 13 | 14 | * indicates equal contribution. 15 | 16 | 1. ShapeHD
17 | **Learning Shape Priors for Single-View 3D Completion and Reconstruction**
18 | [Jiajun Wu](https://jiajunwu.com/)*, [Chengkai Zhang](https://www.csail.mit.edu/person/chengkai-zhang)*, [Xiuming Zhang](http://people.csail.mit.edu/xiuming/), [Zhoutong Zhang](https://ztzhang.info), [William T. Freeman](https://billf.mit.edu/), and [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html)
19 | *ECCV 2018*
20 | [Paper](http://shapehd.csail.mit.edu/papers/shapehd_eccv.pdf)   |   [BibTeX](http://shapehd.csail.mit.edu/bibtex/shapehd_eccv.bib)   |   [Project](http://shapehd.csail.mit.edu/) 21 | 22 | 1. MarrNet
23 | **MarrNet: 3D Shape Reconstruction via 2.5D Sketches**
24 | [Jiajun Wu](https://jiajunwu.com/)*, [Yifan Wang](https://homes.cs.washington.edu/~yifan1/)*, [Tianfan Xue](https://people.csail.mit.edu/tfxue/), [Xingyuan Sun](http://people.csail.mit.edu/xingyuan/), [William T. Freeman](https://billf.mit.edu/), and [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html)
25 | *NeurIPS 2017*
26 | [Paper](http://marrnet.csail.mit.edu/papers/marrnet_nips.pdf)   |   [BibTeX](http://marrnet.csail.mit.edu/bibtex/marrnet_nips.bib)   |   [Project](http://marrnet.csail.mit.edu/) 27 | 28 | 29 | 30 | ## Environment Setup 31 | 32 | All code was built and tested on Ubuntu 16.04.5 LTS with Python 3.6, PyTorch 0.4.1, and CUDA 9.0. Versions for other packages can be found in `environment.yml`. 33 | 34 | 1. Clone this repo with 35 | ``` 36 | # cd to the directory you want to work in 37 | git clone https://github.com/xiumingzhang/GenRe-ShapeHD.git 38 | cd GenRe-ShapeHD 39 | ``` 40 | The code below assumes you are at the repo root. 41 | 42 | 1. Create a conda environment named `shaperecon` with necessary dependencies specified in `environment.yml`. In order to make sure trimesh is installed correctly, please run `install_trimesh.sh` after setting up the conda environment. 43 | ``` 44 | conda env create -f environment.yml 45 | ./install_trimesh.sh 46 | ``` 47 | The TensorFlow dependency in `environment.yml` is for using TensorBoard only. Remove it if you do not want to monitor your training with TensorBoard. 48 | 49 | 1. The instructions below assume you have activated this environment and built the cuda extension with 50 | ``` 51 | source activate shaperecon 52 | ./build_toolbox.sh 53 | ``` 54 | Note that due to the deprecation of cffi from pytorch 1.0 and on, this only works for pytorch 0.4.1. 55 | 56 | 57 | ## Downloading Our Trained Models and Training Data 58 | 59 | ### Models 60 | 61 | To download our trained GenRe and ShapeHD models (1 GB in total), run 62 | ``` 63 | wget http://genre.csail.mit.edu/downloads/genre_shapehd_models.tar -P downloads/models/ 64 | tar -xvf downloads/models/genre_shapehd_models.tar -C downloads/models/ 65 | ``` 66 | 67 | * GenRe: `depth_pred_with_inpaint.pt` and `full_model.pt` 68 | * ShapeHD: `marrnet1_with_minmax.pt` and `shapehd.pt` 69 | 70 | ### Data 71 | 72 | This repo comes with a few [Pix3D](http://pix3d.csail.mit.edu/) images and [ShapeNet](https://www.shapenet.org/) renderings, located in `downloads/data/test`, for testing purposes. 73 | 74 | For training, we make available our RGB and 2.5D sketch renderings, paired with their corresponding 3D shapes, for ShapeNet cars, chairs, and airplanes, with each object captured in 20 random views. Note that this `.tar` is 143 GB. 75 | ``` 76 | wget http://genre.csail.mit.edu/downloads/shapenet_cars_chairs_planes_20views.tar -P downloads/data/ 77 | mkdir downloads/data/shapenet/ 78 | tar -xvf downloads/data/shapenet_cars_chairs_planes_20views.tar -C downloads/data/shapenet/ 79 | ``` 80 | 81 | **New (Oct. 20, 2019)** 82 | 83 | For training, in addition to the renderings already included in the initial release, we now also release the Mitsuba scene `.xml` files used to produce these renderings. [This download link](http://genre.csail.mit.edu/downloads/training_xml.zip) is a `.zip` (394 MB) consisting of the three training classes: cars, chairs, and airplanes. Among other scene parameters, camera poses can now be retrieved from these `.xml` files, which we hope would be useful for tasks like camera/object pose estimation. 84 | 85 | For testing, we release the data of the unseen categories shown in Table 1 of the paper. [This download link](http://genre.csail.mit.edu/downloads/shapenet_unseen.tar) is a `.tar` (44 GB) consisting of, for each of the unseen classes, the 500 random shapes we used for testing GenRe. Right now, nine classes are included, as we are tracking down the 10th. 
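On the Mitsuba scene `.xml` files mentioned above: the repo includes `util/util_xml_to_cam_params.py` for converting them into camera parameters. If you only want a quick look at a camera pose, below is a minimal, unofficial sketch. It assumes the standard Mitsuba `<sensor>/<transform>/<lookat>` layout with `origin`, `target`, and `up` attributes, and the example path is hypothetical; check both against the released files (or use the repo's own converter).
```
# Unofficial sketch: read the camera pose out of one of the released Mitsuba
# scene .xml files. Assumes a standard <sensor><transform name="toWorld">
# <lookat origin=... target=... up=.../> layout; util/util_xml_to_cam_params.py
# is the repo's authoritative converter.
import numpy as np
import xml.etree.ElementTree as ET

def read_lookat(xml_path):
    root = ET.parse(xml_path).getroot()
    lookat = root.find('./sensor/transform/lookat')
    if lookat is None:
        raise ValueError('no <lookat> found; inspect the file manually')
    to_vec = lambda s: np.array([float(v) for v in s.replace(',', ' ').split()])
    return {k: to_vec(lookat.get(k)) for k in ('origin', 'target', 'up')}

# cam = read_lookat('training_xml/03001627/some_model_view000.xml')  # hypothetical path
```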
86 | 87 | 88 | ## Testing with Our Models 89 | 90 | We provide `.sh` wrappers to perform testing for GenRe, ShapeHD, and MarrNet (without the reprojection consistency part). 91 | 92 | ### GenRe 93 | 94 | See `scripts/test_genre.sh`. 95 | 96 |
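The test models shown in this repo save each batch's packed outputs with `np.savez` under the chosen output directory (see, e.g., `test_on_batch` in `models/marrnet.py`); assuming the GenRe test model does the same, a quick way to inspect a result file is the minimal sketch below. The path is hypothetical, and the exact keys depend on the model's `pack_output()` (e.g. `pred_voxel` logits for MarrNet/ShapeHD, plus spherical maps for GenRe).
```
# Unofficial sketch: list what a test run saved. The .npz path is hypothetical
# (use whatever output directory the wrapper was pointed at), and the keys
# depend on the model's pack_output().
import numpy as np

out = np.load('output/genre_test/batch0000.npz')  # hypothetical path
for key in out.files:
    print(key, out[key].shape, out[key].dtype)
```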

97 | ![GenRe sample results](downloads/results/genre.png) 98 |

99 | 100 | We updated our entire pipeline to support fully differentiable end-to-end finetuning. In our NeurIPS submission, the projection from depth images to spherical maps was not implemented in a differentiable way. As a result of both the pipeline and PyTorch version upgrades, the model performance is slightly different from what was reported in the original paper. 101 | 102 | Below we tabulate the original vs. updated Chamfer distances (CD) across different Pix3D classes. The "Original" row is from Table 2 of the paper. 103 | 104 | | | Chair | Bed | Bookcase | Desk | Sofa | Table | Wardrobe | 105 | |----------|:----:|:---:|:---:|:---:|:---:|:---:|:---:| 106 | | **Updated** | .094 | .117 | .104 | .110 | .086 | .114 | .106 | 107 | | **Original** | .093 | .113 | .101 | .109 | .083 | .116 | .109 | 108 | 109 | ### ShapeHD 110 | 111 | See `scripts/test_shapehd.sh`. 112 | 113 |

114 | ![ShapeHD sample results](downloads/results/shapehd.png) 115 |
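The Chamfer distances (CD) quoted in this section, both in the table above and in the Pix3D comparison below, are bidirectional point-to-point distances between sampled surface points. The sketch below only illustrates the metric; the reported numbers come from the Pix3D evaluation code, which has its own point sampling and normalization.
```
# Illustration only: bidirectional Chamfer distance between two point clouds,
# given as N x 3 and M x 3 arrays. The numbers reported in this README follow
# the Pix3D evaluation protocol, not this sketch.
from scipy.spatial import cKDTree

def chamfer_distance(points_a, points_b):
    d_ab, _ = cKDTree(points_b).query(points_a)  # each point in a -> nearest in b
    d_ba, _ = cKDTree(points_a).query(points_b)  # each point in b -> nearest in a
    return d_ab.mean() + d_ba.mean()  # some works average the two terms instead
```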

116 | 117 | After ECCV, we upgraded our entire pipeline and re-trained ShapeHD with this new pipeline. The models released here are newly trained, producing quantative results slightly better than what was reported in the ECCV paper. If you use [the Pix3D repo](https://github.com/xingyuansun/pix3d) to evaluate the model released here, you will get an average CD of 0.122 for the 1,552 untruncated, unoccluded chair images (whose inplane rotation < 5°). The average CD on Pix3D chairs reported in the paper was 0.123. 118 | 119 | ### MarrNet w/o Reprojection Consistency 120 | 121 | See `scripts/test_marrnet.sh`. 122 | 123 | The architectures in this implementation of MarrNet are different from those presented in the original NeurIPS 2017 paper. For instance, the reprojection consistency is not implemented here. MarrNet-1 that predicts 2.5D sketches from RGB inputs is now a U-ResNet, different from its original architecture. That said, the idea remains the same: predicting 2.5D sketches as an intermediate step to the final 3D voxel predictions. 124 | 125 | If you want to test with the original MarrNet, see [the MarrNet repo](https://github.com/jiajunwu/marrnet) for the pretrained models. 126 | 127 | 128 | ## Training Your Own Models 129 | 130 | This repo allows you to train your own models from scratch, possibly with data different from our training data provided above. You can monitor your training with TensorBoard. For that, make sure to include `--tensorboard` while running `train.py`, and then run 131 | ``` 132 | python -m tensorboard.main --logdir="$logdir"/tensorboard 133 | ``` 134 | to visualize your losses. 135 | 136 | ### GenRe 137 | 138 | Follow these steps to train the GenRe model. 139 | 1. Train the depth estimator with `scripts/train_marrnet1.sh` 140 | 1. Train the spherical inpainting network with `scripts/train_inpaint.sh` 141 | 1. Train the full model with `scripts/train_full_genre.sh` 142 | 143 | ### ShapeHD 144 | 145 | Follow these steps to train the ShapeHD model. 146 | 1. Train the 2.5D sketch estimator with `scripts/train_marrnet1.sh` 147 | 1. Train the 2.5D-to-3D network with `scripts/train_marrnet2.sh` 148 | 1. Train a 3D-GAN with `scripts/train_wgangp.sh` 149 | 1. Finetune the 2.5D-to-3D network with perceptual losses provided by the discriminator of the 3D-GAN, using `scripts/finetune_shapehd.sh` 150 | 151 | ### MarrNet w/o Reprojection Consistency 152 | 153 | Follow these steps to train the MarrNet model, excluding the reprojection consistency. 154 | 1. Train the 2.5D sketch estimator with `scripts/train_marrnet1.sh` 155 | 1. Train the 2.5D-to-3D network with `scripts/train_marrnet2.sh` 156 | 1. Finetune the 2.5D-to-3D network with `scripts/finetune_marrnet.sh` 157 | 158 | 159 | ## Questions 160 | 161 | Please open an issue if you encounter any problem. You will likely get a quicker response than via email. 162 | 163 | 164 | ## Changelog 165 | 166 | * Dec. 28, 2018: Initial release 167 | * Oct. 
20, 2019: Added testing data of the unseen categories, and all `.xml` scene files used to render training data 168 | -------------------------------------------------------------------------------- /build_toolbox.sh: -------------------------------------------------------------------------------- 1 | cd toolbox/calc_prob 2 | bash setup.sh script 3 | cd ../../ 4 | cd toolbox/nndistance 5 | bash setup.sh script 6 | cd ../../ 7 | cd toolbox/cam_bp 8 | bash setup.sh script 9 | -------------------------------------------------------------------------------- /clean_toolbox_build.sh: -------------------------------------------------------------------------------- 1 | cd toolbox/calc_prob 2 | bash clean.sh 3 | cd ../../ 4 | cd toolbox/nndistance 5 | bash clean.sh 6 | cd ../../ 7 | cd toolbox/cam_bp 8 | bash clean.sh 9 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def get_dataset(alias): 5 | dataset_module = importlib.import_module('datasets.' + alias.lower()) 6 | return dataset_module.Dataset 7 | -------------------------------------------------------------------------------- /datasets/shapenet.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import random 3 | import numpy as np 4 | from scipy.io import loadmat 5 | import torch.utils.data as data 6 | import util.util_img 7 | 8 | 9 | class Dataset(data.Dataset): 10 | data_root = './downloads/data/shapenet' 11 | list_root = join(data_root, 'status') 12 | status_and_suffix = { 13 | 'rgb': { 14 | 'status': 'rgb.txt', 15 | 'suffix': '_rgb.png', 16 | }, 17 | 'depth': { 18 | 'status': 'depth.txt', 19 | 'suffix': '_depth.png', 20 | }, 21 | 'depth_minmax': { 22 | 'status': 'depth_minmax.txt', 23 | 'suffix': '.npy', 24 | }, 25 | 'silhou': { 26 | 'status': 'silhou.txt', 27 | 'suffix': '_silhouette.png', 28 | }, 29 | 'normal': { 30 | 'status': 'normal.txt', 31 | 'suffix': '_normal.png' 32 | }, 33 | 'voxel': { 34 | 'status': 'vox_rot.txt', 35 | 'suffix': '_gt_rotvox_samescale_128.npz' 36 | }, 37 | 'spherical': { 38 | 'status': 'spherical.txt', 39 | 'suffix': '_spherical.npz' 40 | }, 41 | 'voxel_canon': { 42 | 'status': 'vox_canon.txt', 43 | 'suffix': '_voxel_normalized_128.mat' 44 | }, 45 | } 46 | class_aliases = { 47 | 'drc': '03001627+02691156+02958343', 48 | 'chair': '03001627', 49 | 'table': '04379243', 50 | 'sofa': '04256520', 51 | 'couch': '04256520', 52 | 'cabinet': '03337140', 53 | 'bed': '02818832', 54 | 'plane': '02691156', 55 | 'car': '02958343', 56 | 'bench': '02828884', 57 | 'monitor': '03211117', 58 | 'lamp': '03636649', 59 | 'speaker': '03691459', 60 | 'firearm': '03948459+04090263', 61 | 'cellphone': '02992529+04401088', 62 | 'watercraft': '04530566', 63 | 'hat': '02954340', 64 | 'pot': '03991062', 65 | 'rocket': '04099429', 66 | 'train': '04468005', 67 | 'bus': '02924116', 68 | 'pistol': '03948459', 69 | 'faucet': '03325088', 70 | 'helmet': '03513137', 71 | 'clock': '03046257', 72 | 'phone': '04401088', 73 | 'display': '03211117', 74 | 'vessel': '04530566', 75 | 'rifle': '04090263', 76 | 'small': '03001627+04379243+02933112+04256520+02958343+03636649+02691156+04530566', 77 | 'all-but-table': 
'02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03001627+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04401088+04460130+04468005+04530566+04554684', 78 | 'all-but-chair': '02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04379243+04401088+04460130+04468005+04530566+04554684', 79 | 'all': '02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03001627+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04379243+04401088+04460130+04468005+04530566+04554684', 80 | } 81 | class_list = class_aliases['all'].split('+') 82 | 83 | @classmethod 84 | def add_arguments(cls, parser): 85 | return parser, set() 86 | 87 | @classmethod 88 | def read_bool_status(cls, status_file): 89 | with open(join(cls.list_root, status_file)) as f: 90 | lines = f.read() 91 | return [x == 'True' for x in lines.split('\n')[:-1]] 92 | 93 | def __init__(self, opt, mode='train', model=None): 94 | assert mode in ('train', 'vali') 95 | self.mode = mode 96 | if model is None: 97 | required = ['rgb'] 98 | self.preproc = None 99 | else: 100 | required = model.requires 101 | self.preproc = model.preprocess 102 | 103 | # Parse classes 104 | classes = [] # alias to real for locating data 105 | class_str = '' # real to alias for logging 106 | for c in opt.classes.split('+'): 107 | class_str += c + '+' 108 | if c in self.class_aliases: # nickname given 109 | classes += self.class_aliases[c].split('+') 110 | else: 111 | classes = c.split('+') 112 | class_str = class_str[:-1] # removes the final + 113 | classes = sorted(list(set(classes))) 114 | 115 | # Load items and train-test split 116 | with open(join(self.list_root, 'items_all.txt')) as f: 117 | lines = f.read() 118 | item_list = lines.split('\n')[:-1] 119 | is_train = self.read_bool_status('is_train.txt') 120 | assert len(item_list) == len(is_train) 121 | 122 | # Load status the network requires 123 | has = {} 124 | for data_type in required: 125 | assert data_type in self.status_and_suffix.keys(), \ 126 | "%s required, but unspecified in status_and_suffix" % data_type 127 | has[data_type] = self.read_bool_status( 128 | self.status_and_suffix[data_type]['status'] 129 | ) 130 | assert len(has[data_type]) == len(item_list) 131 | 132 | # Pack paths into a dict 133 | samples = [] 134 | for i, item in enumerate(item_list): 135 | class_id, _ = item.split('/')[:2] 136 | item_in_split = ((self.mode == 'train') == is_train[i]) 137 | if item_in_split and class_id in classes: 138 | # Look up subclass_id for this item 139 | sample_dict = {'item': 
join(self.data_root, item)} 140 | # As long as a type is required, it appears as a key 141 | # If it doens't exist, its value will be None 142 | for data_type in required: 143 | suffix = self.status_and_suffix[data_type]['suffix'] 144 | k = data_type + '_path' 145 | if data_type == 'voxel_canon': 146 | # All different views share the same canonical voxel 147 | sample_dict[k] = join(self.data_root, item.split('_view')[0] + suffix) \ 148 | if has[data_type][i] else None 149 | else: 150 | sample_dict[k] = join(self.data_root, item + suffix) \ 151 | if has[data_type][i] else None 152 | if None not in sample_dict.values(): 153 | # All that are required exist 154 | samples.append(sample_dict) 155 | 156 | # If validation, dataloader shuffle will be off, so need to DETERMINISTICALLY 157 | # shuffle here to have a bit of every class 158 | if self.mode == 'vali': 159 | if opt.manual_seed: 160 | seed = opt.manual_seed 161 | else: 162 | seed = 0 163 | random.Random(seed).shuffle(samples) 164 | self.samples = samples 165 | 166 | def __getitem__(self, i): 167 | sample_loaded = {} 168 | for k, v in self.samples[i].items(): 169 | sample_loaded[k] = v # as-is 170 | if k.endswith('_path'): 171 | if v.endswith('.png'): 172 | im = util.util_img.imread_wrapper( 173 | v, util.util_img.IMREAD_UNCHANGED, 174 | output_channel_order='RGB') 175 | # Normalize to [0, 1] floats 176 | im = im.astype(float) / float(np.iinfo(im.dtype).max) 177 | sample_loaded[k[:-5]] = im 178 | elif v.endswith('.npy'): 179 | # Right now .npy must be depth_minmax 180 | sample_loaded['depth_minmax'] = np.load(v) 181 | elif v.endswith('_128.npz'): 182 | sample_loaded['voxel'] = np.load(v)['voxel'][None, ...] 183 | elif v.endswith('_spherical.npz'): 184 | spherical_data = np.load(v) 185 | sample_loaded['spherical_object'] = spherical_data['obj_spherical'][None, ...] 186 | sample_loaded['spherical_depth'] = spherical_data['depth_spherical'][None, ...] 187 | elif v.endswith('.mat'): 188 | # Right now .mat must be voxel_canon 189 | sample_loaded['voxel_canon'] = loadmat(v)['voxel'][None, ...] 
190 | else: 191 | raise NotImplementedError(v) 192 | # Three identical channels for grayscale images 193 | if self.preproc is not None: 194 | sample_loaded = self.preproc(sample_loaded, mode=self.mode) 195 | # convert all types to float32 for better copy speed 196 | self.convert_to_float32(sample_loaded) 197 | return sample_loaded 198 | 199 | @staticmethod 200 | def convert_to_float32(sample_loaded): 201 | for k, v in sample_loaded.items(): 202 | if isinstance(v, np.ndarray): 203 | if v.dtype != np.float32: 204 | sample_loaded[k] = v.astype(np.float32) 205 | 206 | def __len__(self): 207 | return len(self.samples) 208 | 209 | def get_classes(self): 210 | return self._class_str 211 | -------------------------------------------------------------------------------- /datasets/test.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import numpy as np 3 | import torch.utils.data as data 4 | import util.util_img 5 | 6 | 7 | class Dataset(data.Dataset): 8 | @classmethod 9 | def add_arguments(cls, parser): 10 | return parser, set() 11 | 12 | def __init__(self, opt, model): 13 | # Get required keys and preprocessing from the model 14 | required = model.requires 15 | self.preproc = model.preprocess_wrapper 16 | # Wrapper usually crops and resizes the input image (so that it's just 17 | # like our renders) before sending it to the actual preprocessing 18 | 19 | # Associate each data type required by the model with input paths 20 | type2filename = {} 21 | for k in required: 22 | type2filename[k] = getattr(opt, 'input_' + k) 23 | 24 | # Generate a sorted filelist for each data type 25 | type2files = {} 26 | for k, v in type2filename.items(): 27 | type2files[k] = sorted(glob(v)) 28 | ns = [len(x) for x in type2files.values()] 29 | assert len(set(ns)) == 1, \ 30 | ("Filelists for different types must be of the same length " 31 | "(1-to-1 correspondance)") 32 | self.length = ns[0] 33 | 34 | samples = [] 35 | for i in range(self.length): 36 | sample = {} 37 | for k, v in type2files.items(): 38 | sample[k + '_path'] = v[i] 39 | samples.append(sample) 40 | self.samples = samples 41 | 42 | def __len__(self): 43 | return self.length 44 | 45 | def __getitem__(self, i): 46 | sample = self.samples[i] 47 | 48 | # Actually loading the item 49 | sample_loaded = {} 50 | for k, v in sample.items(): 51 | sample_loaded[k] = v # as-is 52 | if k == 'rgb_path': 53 | im = util.util_img.imread_wrapper( 54 | v, util.util_img.IMREAD_COLOR, output_channel_order='RGB') 55 | # Normalize to [0, 1] floats 56 | im = im.astype(float) / float(np.iinfo(im.dtype).max) 57 | sample_loaded['rgb'] = im 58 | elif k == 'mask_path': 59 | im = util.util_img.imread_wrapper( 60 | v, util.util_img.IMREAD_GRAYSCALE) 61 | # Normalize to [0, 1] floats 62 | im = im.astype(float) / float(np.iinfo(im.dtype).max) 63 | sample_loaded['silhou'] = im 64 | else: 65 | raise NotImplementedError(v) 66 | 67 | # Preprocessing specified by the model 68 | sample_loaded = self.preproc(sample_loaded) 69 | # Convert all types to float32 for faster copying 70 | self.convert_to_float32(sample_loaded) 71 | return sample_loaded 72 | 73 | @staticmethod 74 | def convert_to_float32(sample_loaded): 75 | for k, v in sample_loaded.items(): 76 | if isinstance(v, np.ndarray): 77 | if v.dtype != np.float32: 78 | sample_loaded[k] = v.astype(np.float32) 79 | -------------------------------------------------------------------------------- /downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png -------------------------------------------------------------------------------- /downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png -------------------------------------------------------------------------------- /downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png -------------------------------------------------------------------------------- /downloads/data/test/shapehd/0044_mask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0044_mask.png -------------------------------------------------------------------------------- /downloads/data/test/shapehd/0044_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0044_rgb.png -------------------------------------------------------------------------------- /downloads/data/test/shapehd/0503_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0503_mask.png -------------------------------------------------------------------------------- /downloads/data/test/shapehd/0503_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0503_rgb.jpg -------------------------------------------------------------------------------- /downloads/data/test/shapehd/1209_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/1209_mask.png -------------------------------------------------------------------------------- /downloads/data/test/shapehd/1209_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/1209_rgb.jpg -------------------------------------------------------------------------------- /downloads/results/genre.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/results/genre.png -------------------------------------------------------------------------------- /downloads/results/shapehd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/results/shapehd.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: shaperecon 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - python=3.6 8 | - numpy=1.15.4 9 | - pandas=0.23.4 10 | - tqdm=4.28.1 11 | - scikit-image=0.14.0 12 | - numba=0.41.0 13 | - opencv=3.4.2 14 | - pytorch=0.4.1 15 | - torchvision=0.2.1 16 | - tensorflow=1.5.1 17 | - trimesh=2.35.47 18 | - rtree=0.8.3 19 | - scikit-learn=0.20.1 20 | -------------------------------------------------------------------------------- /install_trimesh.sh: -------------------------------------------------------------------------------- 1 | source activate shaperecon 2 | conda config --add channels conda-forge 3 | conda install shapely rtree pyembree 4 | conda install -c conda-forge scikit-image 5 | conda install 
"pillow<7" 6 | pip install trimesh[all]==2.35.47 7 | -------------------------------------------------------------------------------- /loggers/Progbar.py: -------------------------------------------------------------------------------- 1 | # taken from Keras (https://github.com/fchollet/keras/blob/d687c6eda4d9cb58756822fd77402274db309da8/keras/utils/generic_utils.py) 2 | import sys 3 | import time 4 | import numpy as np 5 | 6 | 7 | class Progbar(object): 8 | """Displays a progress bar. 9 | # Arguments 10 | target: Total number of steps expected, None if unknown. 11 | interval: Minimum visual progress update interval (in seconds). 12 | """ 13 | 14 | def __init__(self, target, width=30, verbose=1, interval=0.05): 15 | self.width = width 16 | if target is None: 17 | target = -1 18 | self.target = target 19 | self.sum_values = {} 20 | self.unique_values = [] 21 | self.start = time.time() 22 | self.last_update = 0 23 | self.interval = interval 24 | self.total_width = 0 25 | self.seen_so_far = 0 26 | self.verbose = verbose 27 | 28 | def update(self, current, values=None, force=False): 29 | """Updates the progress bar. 30 | # Arguments 31 | current: Index of current step. 32 | values: List of tuples (name, value_for_last_step). 33 | The progress bar will display averages for these values. 34 | force: Whether to force visual progress update. 35 | """ 36 | values = values or [] 37 | for k, v in values: 38 | if k not in self.sum_values: 39 | self.sum_values[k] = [v * (current - self.seen_so_far), 40 | current - self.seen_so_far] 41 | self.unique_values.append(k) 42 | else: 43 | self.sum_values[k][0] += v * (current - self.seen_so_far) 44 | self.sum_values[k][1] += (current - self.seen_so_far) 45 | self.seen_so_far = current 46 | 47 | now = time.time() 48 | if self.verbose == 1: 49 | if not force and (now - self.last_update) < self.interval: 50 | return 51 | 52 | prev_total_width = self.total_width 53 | sys.stdout.write('\b' * prev_total_width) 54 | sys.stdout.write('\r') 55 | 56 | if self.target is not -1: 57 | numdigits = int(np.floor(np.log10(self.target))) + 1 58 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 59 | bar = barstr % (current, self.target) 60 | prog = float(current) / self.target 61 | prog_width = int(self.width * prog) 62 | if prog_width > 0: 63 | bar += ('=' * (prog_width - 1)) 64 | if current < self.target: 65 | bar += '>' 66 | else: 67 | bar += '=' 68 | bar += ('.' 
* (self.width - prog_width)) 69 | bar += ']' 70 | sys.stdout.write(bar) 71 | self.total_width = len(bar) 72 | 73 | if current: 74 | time_per_unit = (now - self.start) / current 75 | else: 76 | time_per_unit = 0 77 | eta = time_per_unit * (self.target - current) 78 | info = '' 79 | if current < self.target and self.target is not -1: 80 | info += ' - ETA: %ds' % eta 81 | else: 82 | info += ' - %ds' % (now - self.start) 83 | for k in self.unique_values: 84 | info += ' - %s:' % k 85 | if isinstance(self.sum_values[k], list): 86 | avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) 87 | if abs(avg) > 1e-3: 88 | info += ' %.4f' % avg 89 | else: 90 | info += ' %.4e' % avg 91 | else: 92 | info += ' %s' % self.sum_values[k] 93 | 94 | self.total_width += len(info) 95 | if prev_total_width > self.total_width: 96 | info += ((prev_total_width - self.total_width) * ' ') 97 | 98 | sys.stdout.write(info) 99 | sys.stdout.flush() 100 | 101 | if current >= self.target: 102 | sys.stdout.write('\n') 103 | 104 | if self.verbose == 2: 105 | if current >= self.target: 106 | info = '%ds' % (now - self.start) 107 | for k in self.unique_values: 108 | info += ' - %s:' % k 109 | avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) 110 | if avg > 1e-3: 111 | info += ' %.4f' % avg 112 | else: 113 | info += ' %.4e' % avg 114 | sys.stdout.write(info + "\n") 115 | 116 | self.last_update = now 117 | 118 | def add(self, n, values=None): 119 | self.update(self.seen_so_far + n, values) 120 | -------------------------------------------------------------------------------- /loggers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/loggers/__init__.py -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def get_model(alias, test=False): 5 | module = importlib.import_module('models.' 
+ alias) 6 | if test: 7 | return module.Model_test 8 | return module.Model 9 | -------------------------------------------------------------------------------- /models/depth_pred_with_sph_inpaint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from models.marrnet1 import Model as DepthModel 5 | from models.marrnet1 import Net as Net1 6 | from networks.uresnet import Net_inpaint as Uresnet 7 | from toolbox.cam_bp.cam_bp.modules.camera_backprojection_module import Camera_back_projection_layer 8 | from toolbox.spherical_proj import render_spherical, sph_pad 9 | import torch.nn.functional as F 10 | 11 | 12 | class Model(DepthModel): 13 | @classmethod 14 | def add_arguments(cls, parser): 15 | parser.add_argument('--pred_depth_minmax', action='store_true', default=True, 16 | help="GenRe needs minmax prediction") 17 | parser.add_argument('--load_offline', action='store_true', 18 | help="load offline prediction results") 19 | parser.add_argument('--joint_train', action='store_true', 20 | help="joint train net1 and net2") 21 | parser.add_argument('--net1_path', default=None, type=str, 22 | help="path to pretrained net1") 23 | parser.add_argument('--padding_margin', default=16, type=int, 24 | help="padding margin for spherical maps") 25 | unique_params = {'joint_train'} 26 | return parser, unique_params 27 | 28 | def __init__(self, opt, logger): 29 | super(Model, self).__init__(opt, logger) 30 | self.joint_train = opt.joint_train 31 | if not self.joint_train: 32 | self.requires = ['silhou', 'rgb', 'spherical'] 33 | self.gt_names = ['spherical_object'] 34 | self._metrics = ['spherical'] 35 | else: 36 | self.requires.append('spherical') 37 | self.gt_names = ['depth', 'silhou', 'normal', 'depth_minmax', 'spherical_object'] 38 | self._metrics.append('spherical') 39 | self.input_names = ['rgb', 'silhou', 'spherical_depth'] 40 | self.net = Net(opt, Model) 41 | self.optimizer = self.adam( 42 | self.net.parameters(), 43 | lr=opt.lr, 44 | **self.optim_params 45 | ) 46 | self._nets = [self.net] 47 | self._optimizers = [self.optimizer] 48 | self.init_vars(add_path=True) 49 | self.init_weight(self.net.net2) 50 | 51 | def __str__(self): 52 | string = "Depth Prediction with Spherical Refinement" 53 | if self.joint_train: 54 | string += ' Jointly training all the modules.' 55 | else: 56 | string += ' Only training the inpainting module.' 
57 | return string 58 | 59 | def compute_loss(self, pred): 60 | loss_data = {} 61 | loss = 0 62 | if self.joint_train: 63 | loss, loss_data = super(Model, self).compute_loss(pred) 64 | sph_loss = F.mse_loss(pred['pred_sph_full'], self._gt.spherical_object) 65 | loss_data['spherical'] = sph_loss.mean().item() 66 | loss += sph_loss 67 | loss_data['loss'] = loss.mean().item() 68 | return loss, loss_data 69 | 70 | def pack_output(self, pred, batch, add_gt=True): 71 | pack = {} 72 | if self.joint_train: 73 | pack = super(Model, self).pack_output(pred, batch, add_gt=False) 74 | pack['pred_spherical_full'] = pred['pred_sph_full'].data.cpu().numpy() 75 | pack['pred_spherical_partial'] = pred['pred_sph_partial'].data.cpu().numpy() 76 | pack['proj_depth'] = pred['proj_depth'].data.cpu().numpy() 77 | pack['rgb_path'] = batch['rgb_path'] 78 | if add_gt: 79 | pack['gt_spherical_full'] = batch['spherical_object'].numpy() 80 | return pack 81 | 82 | @classmethod 83 | def preprocess(cls, data, mode='train'): 84 | dataout = DepthModel.preprocess(data, mode) 85 | if 'spherical_object' in dataout.keys(): 86 | val = dataout['spherical_object'] 87 | assert(val.shape[1] == val.shape[2]) 88 | assert(val.shape[1] == 128) 89 | sph_padded = np.pad(val, ((0, 0), (0, 0), (16, 16)), 'wrap') 90 | sph_padded = np.pad(sph_padded, ((0, 0), (16, 16), (0, 0)), 'edge') 91 | dataout['spherical_object'] = sph_padded 92 | return dataout 93 | 94 | 95 | class Net(nn.Module): 96 | def __init__(self, opt, base_class=Model): 97 | super().__init__() 98 | self.net1 = Net1( 99 | [3, 1, 1], 100 | ['normal', 'depth', 'silhou'], 101 | pred_depth_minmax=True) 102 | self.net2 = Uresnet([1], ['spherical'], input_planes=1) 103 | self.base_class = base_class 104 | self.proj_depth = Camera_back_projection_layer() 105 | self.render_spherical = render_spherical() 106 | self.joint_train = opt.joint_train 107 | self.load_offline = opt.load_offline 108 | self.padding_margin = opt.padding_margin 109 | if opt.net1_path: 110 | state_dicts = torch.load(opt.net1_path) 111 | self.net1.load_state_dict(state_dicts['nets'][0]) 112 | 113 | def forward(self, input_struct): 114 | if not self.joint_train: 115 | with torch.no_grad(): 116 | out_1 = self.net1(input_struct) 117 | else: 118 | out_1 = self.net1(input_struct) 119 | pred_abs_depth = self.get_abs_depth(out_1, input_struct) 120 | proj = self.proj_depth(pred_abs_depth) 121 | if self.load_offline: 122 | sph_in = input_struct.spherical_depth 123 | else: 124 | sph_in = self.render_spherical(torch.clamp(proj * 50, 1e-5, 1 - 1e-5)) 125 | # pad sph_in to approximate boundary conditions 126 | sph_in = sph_pad(sph_in, self.padding_margin) 127 | out_2 = self.net2(sph_in) 128 | out_1['proj_depth'] = proj * 50 129 | out_1['pred_sph_partial'] = sph_in 130 | out_1['pred_sph_full'] = out_2['spherical'] 131 | return out_1 132 | 133 | def get_abs_depth(self, pred, input_struct): 134 | pred_depth = pred['depth'] 135 | pred_depth = self.base_class.postprocess(pred_depth) 136 | pred_depth_minmax = pred['depth_minmax'].detach() 137 | pred_abs_depth = self.base_class.to_abs_depth(1 - pred_depth, pred_depth_minmax) 138 | silhou = self.base_class.postprocess(input_struct.silhou).detach() 139 | pred_abs_depth[silhou < 0.5] = 0 140 | pred_abs_depth = pred_abs_depth.permute(0, 1, 3, 2) 141 | pred_abs_depth = torch.flip(pred_abs_depth, [2]) 142 | return pred_abs_depth 143 | -------------------------------------------------------------------------------- /models/marrnet.py: 
-------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from util import util_img 7 | from .marrnet1 import Net as Marrnet1 8 | from .marrnet2 import Net as Marrnet2, Model as Marrnet2_model 9 | 10 | 11 | class Model(Marrnet2_model): 12 | @classmethod 13 | def add_arguments(cls, parser): 14 | parser.add_argument( 15 | '--canon_sup', 16 | action='store_true', 17 | help="Use canonical-pose voxels as supervision" 18 | ) 19 | parser.add_argument( 20 | '--marrnet1', 21 | type=str, default=None, 22 | help="Path to pretrained MarrNet-1" 23 | ) 24 | parser.add_argument( 25 | '--marrnet2', 26 | type=str, default=None, 27 | help="Path to pretrained MarrNet-2 (to be finetuned)" 28 | ) 29 | return parser, set() 30 | 31 | def __init__(self, opt, logger): 32 | super().__init__(opt, logger) 33 | pred_silhou_thres = self.pred_silhou_thres * self.scale_25d 34 | self.requires = ['rgb', self.voxel_key] 35 | self.net = Net(opt.marrnet1, opt.marrnet2, pred_silhou_thres) 36 | self._nets = [self.net] 37 | self.optimizer = self.adam( 38 | self.net.marrnet2.parameters(), 39 | lr=opt.lr, 40 | **self.optim_params 41 | ) # just finetune MarrNet-2 42 | self._optimizers[-1] = self.optimizer 43 | self.input_names = ['rgb'] 44 | self.init_vars(add_path=True) 45 | 46 | def __str__(self): 47 | return "Finetuning MarrNet-2 with MarrNet-1 predictions" 48 | 49 | def pack_output(self, pred, batch, add_gt=True): 50 | pred_normal = pred['normal'].detach().cpu() 51 | pred_silhou = pred['silhou'].detach().cpu() 52 | pred_depth = pred['depth'].detach().cpu() 53 | out = {} 54 | out['rgb_path'] = batch['rgb_path'] 55 | out['rgb'] = util_img.denormalize_colors(batch['rgb'].detach().numpy()) 56 | pred_silhou = self.postprocess(pred_silhou) 57 | pred_silhou = torch.clamp(pred_silhou, 0, 1) 58 | pred_silhou[pred_silhou < 0] = 0 59 | out['pred_silhou'] = pred_silhou.numpy() 60 | out['pred_normal'] = self.postprocess( 61 | pred_normal, bg=1.0, input_mask=pred_silhou 62 | ).numpy() 63 | out['pred_depth'] = self.postprocess( 64 | pred_depth, bg=0.0, input_mask=pred_silhou 65 | ).numpy() 66 | out['pred_voxel'] = pred['voxel'].detach().cpu().numpy() 67 | if add_gt: 68 | out['gt_voxel'] = batch[self.voxel_key].numpy() 69 | return out 70 | 71 | def compute_loss(self, pred): 72 | loss = self.criterion( 73 | pred['voxel'], 74 | getattr(self._gt, self.voxel_key) 75 | ) 76 | loss_data = {} 77 | loss_data['loss'] = loss.mean().item() 78 | return loss, loss_data 79 | 80 | 81 | class Net(nn.Module): 82 | """ 83 | MarrNet-1 MarrNet-2 84 | RGB ------> 2.5D ------> 3D 85 | fixed finetuned 86 | """ 87 | 88 | def __init__(self, marrnet1_path=None, marrnet2_path=None, pred_silhou_thres=0.3): 89 | super().__init__() 90 | # Init MarrNet-1 and load weights 91 | self.marrnet1 = Marrnet1( 92 | [3, 1, 1], 93 | ['normal', 'depth', 'silhou'], 94 | pred_depth_minmax=True, # not used in MarrNet 95 | ) 96 | if marrnet1_path: 97 | state_dict = torch.load(marrnet1_path)['nets'][0] 98 | self.marrnet1.load_state_dict(state_dict) 99 | # Init MarrNet-2 and load weights 100 | self.marrnet2 = Marrnet2(4) 101 | if marrnet2_path: 102 | state_dict = torch.load(marrnet2_path)['nets'][0] 103 | self.marrnet2.load_state_dict(state_dict) 104 | # Fix MarrNet-1, but finetune 2 105 | for p in self.marrnet1.parameters(): 106 | p.requires_grad = False 107 | for p in self.marrnet2.parameters(): 108 | p.requires_grad = True 109 | 
self.pred_silhou_thres = pred_silhou_thres 110 | 111 | def forward(self, input_struct): 112 | # Predict 2.5D sketches 113 | with torch.no_grad(): 114 | pred = self.marrnet1(input_struct) 115 | depth = pred['depth'] 116 | normal = pred['normal'] 117 | silhou = pred['silhou'] 118 | # Mask 119 | is_bg = silhou < self.pred_silhou_thres 120 | depth[is_bg] = 0 121 | normal[is_bg.repeat(1, 3, 1, 1)] = 0 122 | x = torch.cat((depth, normal), 1) 123 | # Forward 124 | latent_vec = self.marrnet2.encoder(x) 125 | vox = self.marrnet2.decoder(latent_vec) 126 | pred['voxel'] = vox 127 | return pred 128 | 129 | 130 | class Model_test(Model): 131 | def __init__(self, opt, logger): 132 | super().__init__(opt, logger) 133 | self.requires = ['rgb', 'mask'] # mask for bbox cropping only 134 | self.load_state_dict(opt.net_file, load_optimizer='auto') 135 | self.input_names = ['rgb'] 136 | self.init_vars(add_path=True) 137 | self.output_dir = opt.output_dir 138 | 139 | def __str__(self): 140 | return "Testing MarrNet" 141 | 142 | @classmethod 143 | def preprocess_wrapper(cls, in_dict): 144 | silhou_thres = 0.95 145 | in_size = 480 146 | pad = 85 147 | im = in_dict['rgb'] 148 | mask = in_dict['silhou'] 149 | bbox = util_img.get_bbox(mask, th=silhou_thres) 150 | im_crop = util_img.crop(im, bbox, in_size, pad, pad_zero=False) 151 | in_dict['rgb'] = im_crop 152 | del in_dict['silhou'] # just for cropping -- done its job 153 | # Now the image is just like those we rendered 154 | out_dict = cls.preprocess(in_dict, mode='test') 155 | return out_dict 156 | 157 | def test_on_batch(self, batch_i, batch): 158 | outdir = join(self.output_dir, 'batch%04d' % batch_i) 159 | makedirs(outdir, exist_ok=True) 160 | pred = self.predict(batch, load_gt=False, no_grad=True) 161 | output = self.pack_output(pred, batch, add_gt=False) 162 | self.visualizer.visualize(output, batch_i, outdir) 163 | np.savez(outdir + '.npz', **output) 164 | -------------------------------------------------------------------------------- /models/marrnet1.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from networks.networks import ViewAsLinear 7 | from networks.uresnet import Net as Uresnet 8 | from .marrnetbase import MarrnetBaseModel 9 | 10 | 11 | class Model(MarrnetBaseModel): 12 | @classmethod 13 | def add_arguments(cls, parser): 14 | parser.add_argument( 15 | '--pred_depth_minmax', 16 | action='store_true', 17 | help="Also predicts depth minmax (for GenRe)", 18 | ) 19 | return parser, set() 20 | 21 | def __init__(self, opt, logger): 22 | super(Model, self).__init__(opt, logger) 23 | self.requires = ['rgb', 'depth', 'silhou', 'normal'] 24 | if opt.pred_depth_minmax: 25 | self.requires.append('depth_minmax') 26 | self.net = Net( 27 | [3, 1, 1], 28 | ['normal', 'depth', 'silhou'], 29 | pred_depth_minmax=opt.pred_depth_minmax, 30 | ) 31 | self.criterion = nn.functional.mse_loss 32 | self.optimizer = self.adam( 33 | self.net.parameters(), 34 | lr=opt.lr, 35 | **self.optim_params 36 | ) 37 | self._nets = [self.net] 38 | self._optimizers.append(self.optimizer) 39 | self.input_names = ['rgb'] 40 | self.gt_names = ['depth', 'silhou', 'normal'] 41 | if opt.pred_depth_minmax: 42 | self.gt_names.append('depth_minmax') 43 | self.init_vars(add_path=True) 44 | self._metrics = ['loss', 'depth', 'silhou', 'normal'] 45 | if opt.pred_depth_minmax: 46 | self._metrics.append('depth_minmax') 47 | 
self.init_weight(self.net) 48 | 49 | def __str__(self): 50 | return "MarrNet-1 predicting 2.5D sketches" 51 | 52 | def _train_on_batch(self, epoch, batch_idx, batch): 53 | self.net.zero_grad() 54 | pred = self.predict(batch) 55 | loss, loss_data = self.compute_loss(pred) 56 | loss.backward() 57 | self.optimizer.step() 58 | batch_size = len(batch['rgb_path']) 59 | batch_log = {'size': batch_size, **loss_data} 60 | return batch_log 61 | 62 | def _vali_on_batch(self, epoch, batch_idx, batch): 63 | pred = self.predict(batch, no_grad=True) 64 | _, loss_data = self.compute_loss(pred) 65 | if np.mod(epoch, self.opt.vis_every_vali) == 0: 66 | if batch_idx < self.opt.vis_batches_vali: 67 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch) 68 | makedirs(outdir, exist_ok=True) 69 | output = self.pack_output(pred, batch) 70 | self.visualizer.visualize(output, batch_idx, outdir) 71 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output) 72 | batch_size = len(batch['rgb_path']) 73 | batch_log = {'size': batch_size, **loss_data} 74 | return batch_log 75 | 76 | def pack_output(self, pred, batch, add_gt=True): 77 | pred_normal = pred['normal'].detach().cpu() 78 | pred_silhou = pred['silhou'].detach().cpu() 79 | pred_depth = pred['depth'].detach().cpu() 80 | gt_silhou = self.postprocess(batch['silhou']) 81 | out = {} 82 | out['rgb_path'] = batch['rgb_path'] 83 | out['pred_normal'] = self.postprocess(pred_normal, bg=1.0, input_mask=gt_silhou).numpy() 84 | out['pred_silhou'] = self.postprocess(pred_silhou).numpy() 85 | pred_depth = self.postprocess(pred_depth, bg=0.0, input_mask=gt_silhou) 86 | out['pred_depth'] = pred_depth.numpy() 87 | if self.opt.pred_depth_minmax: 88 | pred_depth_minmax = pred['depth_minmax'].detach() 89 | pred_abs_depth = self.to_abs_depth( 90 | (1 - pred_depth).to(torch.device('cuda')), 91 | pred_depth_minmax 92 | ) # background is max now 93 | pred_abs_depth[gt_silhou < 1] = 0 # set background to 0 94 | out['proj_depth'] = self.proj_depth(pred_abs_depth).cpu().numpy() 95 | out['pred_depth_minmax'] = pred_depth_minmax.cpu().numpy() 96 | if add_gt: 97 | out['normal_path'] = batch['normal_path'] 98 | out['silhou_path'] = batch['silhou_path'] 99 | out['depth_path'] = batch['depth_path'] 100 | if self.opt.pred_depth_minmax: 101 | out['gt_depth_minmax'] = batch['depth_minmax'].numpy() 102 | return out 103 | 104 | def compute_loss(self, pred): 105 | """ 106 | TODO: we should add normal and depth consistency loss here in the future. 
107 | """ 108 | pred_normal = pred['normal'] 109 | pred_depth = pred['depth'] 110 | pred_silhou = pred['silhou'] 111 | is_fg = self._gt.silhou != 0 # excludes background 112 | is_fg_full = is_fg.expand_as(pred_normal) 113 | loss_normal = self.criterion( 114 | pred_normal[is_fg_full], self._gt.normal[is_fg_full] 115 | ) 116 | loss_depth = self.criterion( 117 | pred_depth[is_fg], self._gt.depth[is_fg] 118 | ) 119 | loss_silhou = self.criterion(pred_silhou, self._gt.silhou) 120 | loss = loss_normal + loss_depth + loss_silhou 121 | loss_data = {} 122 | loss_data['loss'] = loss.mean().item() 123 | loss_data['normal'] = loss_normal.mean().item() 124 | loss_data['depth'] = loss_depth.mean().item() 125 | loss_data['silhou'] = loss_silhou.mean().item() 126 | if self.opt.pred_depth_minmax: 127 | w_minmax = (256 ** 2) / 2 # matching scale of pixel predictions very roughly 128 | loss_depth_minmax = w_minmax * self.criterion( 129 | pred['depth_minmax'], 130 | self._gt.depth_minmax 131 | ) 132 | loss += loss_depth_minmax 133 | loss_data['depth_minmax'] = loss_depth_minmax.mean().item() 134 | return loss, loss_data 135 | 136 | 137 | class Net(Uresnet): 138 | def __init__(self, *args, pred_depth_minmax=True): 139 | super().__init__(*args) 140 | self.pred_depth_minmax = pred_depth_minmax 141 | if self.pred_depth_minmax: 142 | module_list = nn.Sequential( 143 | nn.Conv2d(512, 512, 2, stride=2), 144 | nn.Conv2d(512, 512, 4, stride=1), 145 | ViewAsLinear(), 146 | nn.Linear(512, 256), 147 | nn.BatchNorm1d(256), 148 | nn.ReLU(inplace=True), 149 | nn.Linear(256, 128), 150 | nn.BatchNorm1d(128), 151 | nn.ReLU(inplace=True), 152 | nn.Linear(128, 2) 153 | ) 154 | self.decoder_minmax = module_list 155 | 156 | def forward(self, input_struct): 157 | x = input_struct.rgb 158 | out_dict = super().forward(x) 159 | if self.pred_depth_minmax: 160 | out_dict['depth_minmax'] = self.decoder_minmax(self.encoder_out) 161 | return out_dict 162 | -------------------------------------------------------------------------------- /models/marrnet2.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from networks.networks import ImageEncoder, VoxelDecoder 7 | from .marrnetbase import MarrnetBaseModel 8 | 9 | 10 | class Model(MarrnetBaseModel): 11 | @classmethod 12 | def add_arguments(cls, parser): 13 | parser.add_argument( 14 | '--canon_sup', 15 | action='store_true', 16 | help="Use canonical-pose voxels as supervision" 17 | ) 18 | return parser, set() 19 | 20 | def __init__(self, opt, logger): 21 | super(Model, self).__init__(opt, logger) 22 | if opt.canon_sup: 23 | voxel_key = 'voxel_canon' 24 | else: 25 | voxel_key = 'voxel' 26 | self.voxel_key = voxel_key 27 | self.requires = ['rgb', 'depth', 'normal', 'silhou', voxel_key] 28 | self.net = Net(4) 29 | self.criterion = nn.BCEWithLogitsLoss(reduction='elementwise_mean') 30 | self.optimizer = self.adam( 31 | self.net.parameters(), 32 | lr=opt.lr, 33 | **self.optim_params 34 | ) 35 | self._nets = [self.net] 36 | self._optimizers.append(self.optimizer) 37 | self.input_names = ['depth', 'normal', 'silhou'] 38 | self.gt_names = [voxel_key] 39 | self.init_vars(add_path=True) 40 | self._metrics = ['loss'] 41 | self.init_weight(self.net) 42 | 43 | def __str__(self): 44 | return "MarrNet-2 predicting voxels from 2.5D sketches" 45 | 46 | def _train_on_batch(self, epoch, batch_idx, batch): 47 | self.net.zero_grad() 48 | pred = 
self.predict(batch) 49 | loss, loss_data = self.compute_loss(pred) 50 | loss.backward() 51 | self.optimizer.step() 52 | batch_size = len(batch['rgb_path']) 53 | batch_log = {'size': batch_size, **loss_data} 54 | return batch_log 55 | 56 | def _vali_on_batch(self, epoch, batch_idx, batch): 57 | pred = self.predict(batch, no_grad=True) 58 | _, loss_data = self.compute_loss(pred) 59 | if np.mod(epoch, self.opt.vis_every_vali) == 0: 60 | if batch_idx < self.opt.vis_batches_vali: 61 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch) 62 | makedirs(outdir, exist_ok=True) 63 | output = self.pack_output(pred, batch) 64 | self.visualizer.visualize(output, batch_idx, outdir) 65 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output) 66 | batch_size = len(batch['rgb_path']) 67 | batch_log = {'size': batch_size, **loss_data} 68 | return batch_log 69 | 70 | def pack_output(self, pred, batch, add_gt=True): 71 | out = {} 72 | out['rgb_path'] = batch['rgb_path'] 73 | out['pred_voxel'] = pred.detach().cpu().numpy() 74 | if add_gt: 75 | out['gt_voxel'] = batch[self.voxel_key].numpy() 76 | out['normal_path'] = batch['normal_path'] 77 | out['depth_path'] = batch['depth_path'] 78 | out['silhou_path'] = batch['silhou_path'] 79 | return out 80 | 81 | def compute_loss(self, pred): 82 | loss = self.criterion(pred, getattr(self._gt, self.voxel_key)) 83 | loss_data = {} 84 | loss_data['loss'] = loss.mean().item() 85 | return loss, loss_data 86 | 87 | 88 | class Net(nn.Module): 89 | """ 90 | 2.5D maps to 3D voxel 91 | """ 92 | 93 | def __init__(self, in_planes, encode_dims=200, silhou_thres=0): 94 | super().__init__() 95 | self.encoder = ImageEncoder(in_planes, encode_dims=encode_dims) 96 | self.decoder = VoxelDecoder(n_dims=encode_dims, nf=512) 97 | self.silhou_thres = silhou_thres 98 | 99 | def forward(self, input_struct): 100 | depth = input_struct.depth 101 | normal = input_struct.normal 102 | silhou = input_struct.silhou 103 | # Mask 104 | is_bg = silhou <= self.silhou_thres 105 | depth[is_bg] = 0 106 | normal[is_bg.repeat(1, 3, 1, 1)] = 0 # NOTE: if old net2, set to white (100), 107 | x = torch.cat((depth, normal), 1) # and swap depth and normal 108 | # Forward 109 | latent_vec = self.encoder(x) 110 | vox = self.decoder(latent_vec) 111 | return vox 112 | -------------------------------------------------------------------------------- /models/marrnetbase.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from models.netinterface import NetInterface 7 | from toolbox.cam_bp.cam_bp.functions import CameraBackProjection 8 | import util.util_img 9 | 10 | 11 | class MarrnetBaseModel(NetInterface): 12 | im_size = 256 13 | rgb_jitter_d = 0.4 14 | rgb_light_noise = 0.1 15 | silhou_thres = 0.999 16 | pred_silhou_thres = 0.3 17 | scale_25d = 100 18 | 19 | def __init__(self, opt, logger): 20 | super(MarrnetBaseModel, self).__init__(opt, logger) 21 | self.opt = opt 22 | self.n_batches_per_epoch = opt.epoch_batches 23 | self.n_batches_to_vis_train = opt.vis_batches_train 24 | self.n_batches_to_vis_vali = opt.vis_batches_vali 25 | self.full_logdir = opt.full_logdir 26 | self._metrics = [] 27 | self.batches_to_vis = {} 28 | self.dataset = opt.dataset 29 | self._nets = [] 30 | self._optimizers = [] 31 | self._moveable_vars = [] 32 | self.cam_bp = Camera_back_projection_layer(128) 33 | if opt.log_time: 34 | self._metrics += ['batch_time', 'data_time'] 35 | # 
Parameters for different optimization methods 36 | self.optim_params = dict() 37 | if opt.optim == 'adam': 38 | self.optim_params['betas'] = (opt.adam_beta1, opt.adam_beta2) 39 | elif opt.optim == 'sgd': 40 | self.optim_params['momentum'] = opt.sgd_momentum 41 | self.optim_params['dampening'] = opt.sgd_dampening 42 | self.optim_params['weight_decay'] = opt.sgd_wdecay 43 | else: 44 | raise NotImplementedError(opt.optim) 45 | 46 | def _train_on_batch(self, batch_idx, batch): 47 | self.net.zero_grad() 48 | pred = self.predict(batch) 49 | loss, loss_data = self.compute_loss(pred) 50 | loss.backward() 51 | self.optimizer.step() 52 | batch_size = len(batch['rgb_path']) 53 | batch_log = {'size': batch_size, **loss_data} 54 | self.record_batch(batch_idx, batch) 55 | return batch_log 56 | 57 | def _vali_on_batch(self, epoch, batch_idx, batch): 58 | pred = self.predict(batch, no_grad=True) 59 | _, loss_data = self.compute_loss(pred) 60 | if np.mod(epoch, self.opt.vis_every_vali) == 0: 61 | if batch_idx < self.opt.vis_batches_vali: 62 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch) 63 | makedirs(outdir, exist_ok=True) 64 | output = self.pack_output(pred, batch) 65 | self.visualizer.visualize(output, batch_idx, outdir) 66 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output) 67 | batch_size = len(batch['rgb_path']) 68 | batch_log = {'size': batch_size, **loss_data} 69 | return batch_log 70 | 71 | @classmethod 72 | def preprocess(cls, data, mode='train'): 73 | """ 74 | This function should be applied to [0, 1] floats, except absolute depth 75 | """ 76 | data_proc = {} 77 | for key, val in data.items(): 78 | if key == 'rgb': 79 | im = val 80 | # H x W x 3 81 | im = util.util_img.resize(im, cls.im_size, 'horizontal') 82 | if mode == 'train': 83 | im = util.util_img.jitter_colors( 84 | im, 85 | d_brightness=cls.rgb_jitter_d, 86 | d_contrast=cls.rgb_jitter_d, 87 | d_saturation=cls.rgb_jitter_d 88 | ) 89 | im = util.util_img.add_lighting_noise( 90 | im, cls.rgb_light_noise) 91 | im = util.util_img.normalize_colors(im) 92 | val = im.transpose(2, 0, 1) 93 | 94 | elif key == 'depth': 95 | im = val 96 | if im.ndim == 3: 97 | im = im[:, :, 0] 98 | im = util.util_img.resize( 99 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max())) 100 | im *= cls.scale_25d 101 | val = im[np.newaxis, :, :] 102 | # 1 x H x W, scaled 103 | 104 | elif key == 'silhou': 105 | im = val 106 | if im.ndim == 3: 107 | im = im[:, :, 0] 108 | im = util.util_img.resize( 109 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max())) 110 | im = util.util_img.binarize( 111 | im, cls.silhou_thres, gt_is_1=True) 112 | im *= cls.scale_25d 113 | val = im[np.newaxis, :, :] 114 | # 1 x H x W, binarized, scaled 115 | 116 | elif key == 'normal': 117 | # H x W x 3 118 | im = val 119 | im = util.util_img.resize( 120 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max())) 121 | im *= cls.scale_25d 122 | val = im.transpose(2, 0, 1) 123 | # 3 x H x W, scaled 124 | 125 | data_proc[key] = val 126 | return data_proc 127 | 128 | @staticmethod 129 | def mask(input_image, input_mask, bg=1.0): 130 | assert isinstance(bg, (int, float)) 131 | assert (input_mask >= 0).all() and (input_mask <= 1).all() 132 | input_mask = input_mask.expand_as(input_image) 133 | bg = bg * input_image.new_ones(input_image.size()) 134 | output = input_mask * input_image + (1 - input_mask) * bg 135 | return output 136 | 137 | @classmethod 138 | def postprocess(cls, tensor, bg=1.0, input_mask=None): 139 | scaled = tensor / cls.scale_25d 140 | if input_mask is 
not None: 141 | return cls.mask(scaled, input_mask, bg=bg) 142 | return scaled 143 | 144 | @staticmethod 145 | def to_abs_depth(rel_depth, depth_minmax): 146 | bmin = depth_minmax[:, 0] 147 | bmax = depth_minmax[:, 1] 148 | depth_min = bmin.view(-1, 1, 1, 1) 149 | depth_max = bmax.view(-1, 1, 1, 1) 150 | abs_depth = rel_depth * (depth_max - depth_min + 1e-4) + depth_min 151 | return abs_depth 152 | 153 | def proj_depth(self, abs_depth): 154 | proj_depth = self.cam_bp(abs_depth) 155 | return self.cam_bp.shift_tdf(proj_depth) 156 | 157 | 158 | class Camera_back_projection_layer(nn.Module): 159 | def __init__(self, res): 160 | super(Camera_back_projection_layer, self).__init__() 161 | self.res = res 162 | 163 | def forward(self, depth_t, fl=784.4645406, cam_dist=2.2): 164 | # print(cam_dist) 165 | n = depth_t.size(0) 166 | if isinstance(fl, float): 167 | fl_v = fl 168 | fl = torch.FloatTensor(n, 1).cuda() 169 | fl.fill_(fl_v) 170 | if isinstance(cam_dist, float): 171 | cmd_v = cam_dist 172 | cam_dist = torch.FloatTensor(n, 1).cuda() 173 | cam_dist.fill_(cmd_v) 174 | return CameraBackProjection.apply(depth_t, fl, cam_dist, self.res) 175 | 176 | @staticmethod 177 | def shift_tdf(input_tdf, res=128): 178 | out_tdf = 1 - res * input_tdf 179 | return out_tdf 180 | -------------------------------------------------------------------------------- /models/shapehd.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from util import util_img 7 | from .wgangp import D 8 | from .marrnet2 import Net as Marrnet2, Model as Marrnet2_model 9 | from .marrnet1 import Model as Marrnet1_model 10 | 11 | 12 | class Model(Marrnet2_model): 13 | @classmethod 14 | def add_arguments(cls, parser): 15 | parser.add_argument( 16 | '--canon_sup', 17 | action='store_true', 18 | help="Use canonical-pose voxels as supervision" 19 | ) 20 | parser.add_argument( 21 | '--marrnet2', 22 | type=str, default=None, 23 | help="Path to pretrained MarrNet-2 (to be finetuned)" 24 | ) 25 | parser.add_argument( 26 | '--gan', 27 | type=str, default=None, 28 | help="Path to pretrained WGANGP" 29 | ) 30 | parser.add_argument( 31 | '--w_gan_loss', 32 | type=float, default=0, 33 | help="Weight for perceptual loss relative to supervised loss" 34 | ) 35 | return parser, set() 36 | 37 | def __init__(self, opt, logger): 38 | super().__init__(opt, logger) 39 | assert opt.canon_sup, "ShapeHD uses canonical-pose voxels" 40 | self.net = Net(opt.marrnet2, opt.gan) 41 | self._nets = [self.net] 42 | self.optimizer = self.adam( 43 | self.net.marrnet2.parameters(), 44 | lr=opt.lr, 45 | **self.optim_params 46 | ) # just finetune MarrNet-2 47 | self._optimizers[-1] = self.optimizer 48 | self._metrics += ['sup', 'gan'] 49 | self.init_vars(add_path=True) 50 | assert opt.w_gan_loss >= 0 51 | 52 | def __str__(self): 53 | return "Finetuning 3D estimator of ShapeHD with GAN loss" 54 | 55 | def pack_output(self, pred, batch, add_gt=True): 56 | out = {} 57 | out['rgb_path'] = batch['rgb_path'] 58 | out['pred_voxel_noft'] = pred['voxel_noft'].detach().cpu().numpy() 59 | out['pred_voxel'] = pred['voxel'].detach().cpu().numpy() 60 | if add_gt: 61 | out['gt_voxel'] = batch[self.voxel_key].numpy() 62 | out['normal_path'] = batch['normal_path'] 63 | out['depth_path'] = batch['depth_path'] 64 | out['silhou_path'] = batch['silhou_path'] 65 | return out 66 | 67 | def compute_loss(self, pred): 68 | loss_sup = 
self.criterion( 69 | pred['voxel'], # will be sigmoid'ed 70 | getattr(self._gt, self.voxel_key) 71 | ) 72 | loss_gan = -pred['is_real'].mean() # negate to maximize 73 | loss_gan *= self.opt.w_gan_loss 74 | loss = loss_sup + loss_gan 75 | loss_data = {} 76 | loss_data['sup'] = loss_sup.item() 77 | loss_data['gan'] = loss_gan.item() 78 | loss_data['loss'] = loss.item() 79 | return loss, loss_data 80 | 81 | 82 | class Net(nn.Module): 83 | """ 84 | 3D Estimator D of GAN 85 | 2.5D --------> 3D --------> real/fake 86 | finetuned fixed 87 | """ 88 | 89 | def __init__(self, marrnet2_path=None, gan_path=None): 90 | super().__init__() 91 | # Init MarrNet-2 and load weights 92 | self.marrnet2 = Marrnet2(4) 93 | self.marrnet2_noft = Marrnet2(4) 94 | if marrnet2_path: 95 | state_dicts = torch.load(marrnet2_path) 96 | state_dict = state_dicts['nets'][0] 97 | self.marrnet2.load_state_dict(state_dict) 98 | self.marrnet2_noft.load_state_dict(state_dict) 99 | # Init discriminator and load weights 100 | self.d = D() 101 | if gan_path: 102 | state_dicts = torch.load(gan_path) 103 | self.d.load_state_dict(state_dicts['nets'][1]) 104 | # Fix D, but finetune MarrNet-2 105 | for p in self.d.parameters(): 106 | p.requires_grad = False 107 | for p in self.marrnet2_noft.parameters(): 108 | p.requires_grad = False 109 | for p in self.marrnet2.parameters(): 110 | p.requires_grad = True 111 | self.sigmoid = nn.Sigmoid() 112 | 113 | def forward(self, input_struct): 114 | pred = {} 115 | pred['voxel_noft'] = self.marrnet2_noft(input_struct) # unfinetuned 116 | pred['voxel'] = self.marrnet2(input_struct) 117 | pred['is_real'] = self.d(self.sigmoid(pred['voxel'])) 118 | return pred 119 | 120 | 121 | class Model_test(Model): 122 | @classmethod 123 | def add_arguments(cls, parser): 124 | parser, unique_params = Model.add_arguments(parser) 125 | parser.add_argument( 126 | '--marrnet1_file', 127 | type=str, required=True, 128 | help="Path to pretrained MarrNet-1" 129 | ) 130 | return parser, unique_params 131 | 132 | def __init__(self, opt, logger): 133 | opt.canon_sup = True # dummy, for network init only 134 | super().__init__(opt, logger) 135 | self.requires = ['rgb', 'mask'] # mask for bbox cropping only 136 | self.input_names = ['rgb'] 137 | self.init_vars(add_path=True) 138 | self.output_dir = opt.output_dir 139 | # Load MarrNet-2 and D (though unused at test time) 140 | self.load_state_dict(opt.net_file, load_optimizer='auto') 141 | # Load MarrNet-1 whose outputs are inputs to D-tuned MarrNet-2 142 | opt.pred_depth_minmax = True # dummy 143 | self.marrnet1 = Marrnet1_model(opt, logger) 144 | self.marrnet1.load_state_dict(opt.marrnet1_file) 145 | self._nets.append(self.marrnet1.net) 146 | 147 | def __str__(self): 148 | return "Testing ShapeHD" 149 | 150 | @classmethod 151 | def preprocess_wrapper(cls, in_dict): 152 | silhou_thres = 0.95 153 | in_size = 480 154 | pad = 85 155 | im = in_dict['rgb'] 156 | mask = in_dict['silhou'] 157 | bbox = util_img.get_bbox(mask, th=silhou_thres) 158 | im_crop = util_img.crop(im, bbox, in_size, pad, pad_zero=False) 159 | in_dict['rgb'] = im_crop 160 | del in_dict['silhou'] # just for cropping -- done its job 161 | # Now the image is just like those we rendered 162 | out_dict = cls.preprocess(in_dict, mode='test') 163 | return out_dict 164 | 165 | def test_on_batch(self, batch_i, batch): 166 | outdir = join(self.output_dir, 'batch%04d' % batch_i) 167 | makedirs(outdir, exist_ok=True) 168 | # Forward MarrNet-1 169 | pred1 = self.marrnet1.predict(batch, load_gt=False, no_grad=True) 170 | # 
Forward MarrNet-2 171 | for net_name in ('marrnet2', 'marrnet2_noft'): 172 | net = getattr(self.net, net_name) 173 | net.silhou_thres = self.pred_silhou_thres * self.scale_25d 174 | self.input_names = ['depth', 'normal', 'silhou'] 175 | pred2 = self.predict(pred1, load_gt=False, no_grad=True) 176 | # Pack, visualize, and save outputs 177 | output = self.pack_output(pred1, pred2, batch) 178 | self.visualizer.visualize(output, batch_i, outdir) 179 | np.savez(outdir + '.npz', **output) 180 | 181 | def pack_output(self, pred1, pred2, batch): 182 | out = {} 183 | # MarrNet-1 outputs 184 | pred_normal = pred1['normal'].detach().cpu() 185 | pred_silhou = pred1['silhou'].detach().cpu() 186 | pred_depth = pred1['depth'].detach().cpu() 187 | out['rgb_path'] = batch['rgb_path'] 188 | out['rgb'] = util_img.denormalize_colors(batch['rgb'].detach().numpy()) 189 | pred_silhou = self.postprocess(pred_silhou) 190 | pred_silhou = torch.clamp(pred_silhou, 0, 1) 191 | pred_silhou[pred_silhou < 0] = 0 192 | out['pred_silhou'] = pred_silhou.numpy() 193 | out['pred_normal'] = self.postprocess( 194 | pred_normal, bg=1.0, input_mask=pred_silhou 195 | ).numpy() 196 | out['pred_depth'] = self.postprocess( 197 | pred_depth, bg=0.0, input_mask=pred_silhou 198 | ).numpy() 199 | # D-tuned MarrNet-2 outputs 200 | out['pred_voxel'] = pred2['voxel'].detach().cpu().numpy() 201 | out['pred_voxel_noft'] = pred2['voxel_noft'].detach().cpu().numpy() 202 | return out 203 | -------------------------------------------------------------------------------- /models/wgangp.py: -------------------------------------------------------------------------------- 1 | from os import makedirs 2 | from os.path import join 3 | from time import time 4 | import numpy as np 5 | import torch 6 | from networks.networks import VoxelGenerator, VoxelDiscriminator 7 | from .netinterface import NetInterface 8 | 9 | 10 | class Model(NetInterface): 11 | @classmethod 12 | def add_arguments(cls, parser): 13 | parser.add_argument( 14 | '--canon_voxel', 15 | action='store_true', 16 | help="Generate/discriminate canonical-pose voxels" 17 | ) 18 | parser.add_argument( 19 | '--wgangp_lambda', 20 | type=float, 21 | default=10, 22 | help="WGANGP gradient penalty coefficient" 23 | ) 24 | parser.add_argument( 25 | '--wgangp_norm', 26 | type=float, 27 | default=1, 28 | help="WGANGP gradient penalty norm" 29 | ) 30 | parser.add_argument( 31 | '--gan_d_iter', 32 | type=int, 33 | default=1, 34 | help="# iterations D is trained per G's iteration" 35 | ) 36 | return parser, set() 37 | 38 | def __init__(self, opt, logger): 39 | super().__init__(opt, logger) 40 | assert opt.canon_voxel, "GAN requires canonical-pose voxels to work" 41 | self.requires = ['voxel_canon'] 42 | self.nz = 200 43 | self.net_g = G(self.nz) 44 | self.net_d = D() 45 | self._nets = [self.net_g, self.net_d] 46 | # Optimizers 47 | self.optim_params = dict() 48 | self.optim_params['betas'] = (opt.adam_beta1, opt.adam_beta2) 49 | self.optimizer_g = self.adam( 50 | self.net_g.parameters(), 51 | lr=opt.lr, 52 | **self.optim_params 53 | ) 54 | self.optimizer_d = self.adam( 55 | self.net_d.parameters(), 56 | lr=opt.lr, 57 | **self.optim_params 58 | ) 59 | self._optimizers = [self.optimizer_g, self.optimizer_d] 60 | # 61 | self.opt = opt 62 | self.preprocess = None 63 | self._metrics = ['err_d_real', 'err_d_fake', 'err_d_gp', 'err_d', 'err_g', 'loss'] 64 | if opt.log_time: 65 | self._metrics += ['t_d_real', 't_d_fake', 't_d_grad', 't_g'] 66 | self.input_names = ['voxel_canon'] 67 | self.aux_names = ['one', 
'neg_one'] 68 | self.init_vars(add_path=True) 69 | self.init_weight(self.net_d) 70 | self.init_weight(self.net_g) 71 | self._last_err_g = None 72 | 73 | def __str__(self): 74 | s = "3D-WGANGP" 75 | return s 76 | 77 | def _train_on_batch(self, epoch, batch_idx, batch): 78 | net_d, net_g = self.net_d, self.net_g 79 | opt_d, opt_g = self.optimizer_d, self.optimizer_g 80 | one = self._aux.one 81 | neg_one = self._aux.neg_one 82 | real = batch['voxel_canon'].cuda() 83 | batch_size = real.shape[0] 84 | batch_log = {'size': batch_size} 85 | 86 | # Train D ... 87 | net_d.zero_grad() 88 | for p in net_d.parameters(): 89 | p.requires_grad = True 90 | for p in net_g.parameters(): 91 | p.requires_grad = False 92 | # with real 93 | t0 = time() 94 | err_d_real = self.net_d(real).mean() 95 | err_d_real.backward(neg_one) 96 | batch_log['err_d_real'] = -err_d_real.item() 97 | d_real_t = time() - t0 98 | # with fake 99 | t0 = time() 100 | with torch.no_grad(): 101 | _, fake = self.net_g(batch_size) 102 | err_d_fake = self.net_d(fake).mean() 103 | err_d_fake.backward(one) 104 | batch_log['err_d_fake'] = err_d_fake.item() 105 | d_fake_t = time() - t0 106 | # with grad penalty 107 | t0 = time() 108 | if self.opt.wgangp_lambda > 0: 109 | grad_penalty = self.calc_grad_penalty(real, fake) 110 | grad_penalty.backward() 111 | batch_log['err_d_gp'] = grad_penalty.item() 112 | else: 113 | batch_log['err_d_gp'] = 0 114 | batch_log['err_d'] = batch_log['err_d_fake'] + batch_log['err_d_real'] \ 115 | + batch_log['err_d_gp'] 116 | d_grad_t = time() - t0 117 | opt_d.step() 118 | 119 | # Train G 120 | t0 = time() 121 | for p in net_d.parameters(): 122 | p.requires_grad = False 123 | for p in net_g.parameters(): 124 | p.requires_grad = True 125 | net_g.zero_grad() 126 | if batch_idx % self.opt.gan_d_iter == 0: 127 | _, gen = self.net_g(batch_size) 128 | err_g = self.net_d(gen).mean() 129 | err_g.backward(neg_one) 130 | opt_g.step() 131 | batch_log['err_g'] = -err_g.item() 132 | self._last_err_g = batch_log['err_g'] 133 | else: 134 | batch_log['err_g'] = self._last_err_g 135 | g_t = time() - t0 136 | 137 | if self.opt.log_time: 138 | batch_log['t_d_real'] = d_real_t 139 | batch_log['t_d_fake'] = d_fake_t 140 | batch_log['t_d_grad'] = d_grad_t 141 | batch_log['t_g'] = g_t 142 | return batch_log 143 | 144 | def calc_grad_penalty(self, real, fake): 145 | alpha = torch.rand(real.shape[0], 1) 146 | alpha = alpha.expand( 147 | real.shape[0], real.nelement() // real.shape[0] 148 | ).contiguous().view(*real.shape).cuda() 149 | inter = alpha * real + (1 - alpha) * fake 150 | inter.requires_grad = True 151 | err_d_inter = self.net_d(inter) 152 | grads = torch.autograd.grad( 153 | outputs=err_d_inter, 154 | inputs=inter, 155 | grad_outputs=torch.ones(err_d_inter.size()).cuda(), 156 | create_graph=True, 157 | retain_graph=True, 158 | only_inputs=True 159 | )[0] 160 | grads = grads.view(grads.size(0), -1) 161 | grad_penalty = ( 162 | ((grads + 1e-16).norm(2, dim=1) - self.opt.wgangp_norm) ** 2 163 | ).mean() * self.opt.wgangp_lambda 164 | return grad_penalty 165 | 166 | def _vali_on_batch(self, epoch, batch_idx, batch): 167 | batch_size = batch['voxel_canon'].shape[0] 168 | batch_log = {'size': batch_size} 169 | with torch.no_grad(): 170 | noise, gen = self.net_g(batch_size) 171 | disc = self.net_d(gen) 172 | batch_log['loss'] = -disc.mean().item() 173 | # Save and visualize 174 | if np.mod(epoch, self.opt.vis_every_train) == 0: 175 | if batch_idx < self.opt.vis_batches_train: 176 | outdir = join(self.full_logdir, 'epoch%04d_vali' % 
epoch) 177 | makedirs(outdir, exist_ok=True) 178 | output = self.pack_output(noise, gen, disc) 179 | self.visualizer.visualize(output, batch_idx, outdir) 180 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output) 181 | return batch_log 182 | 183 | @staticmethod 184 | def pack_output(noise, gen, disc): 185 | out = { 186 | 'noise': noise.cpu().numpy(), 187 | 'gen_voxel': gen.cpu().numpy(), 188 | 'disc': disc.cpu().numpy(), 189 | } 190 | return out 191 | 192 | 193 | class G(VoxelGenerator): 194 | def __init__(self, nz): 195 | super().__init__(nz=nz, nf=64, bias=False, res=128) 196 | self.nz = nz 197 | self.noise = torch.FloatTensor().cuda() 198 | 199 | def forward(self, batch_size): 200 | x = self.noise 201 | x.resize_(batch_size, self.nz, 1, 1, 1).normal_(0, 1) 202 | y = super().forward(x) 203 | return x, y 204 | 205 | 206 | class D(VoxelDiscriminator): 207 | def __init__(self): 208 | super().__init__(nf=64, bias=False, res=128) 209 | 210 | def forward(self, x): 211 | if x.dim() == 4: 212 | x.unsqueeze_(1) 213 | y = super().forward(x) 214 | return y 215 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/networks/__init__.py -------------------------------------------------------------------------------- /networks/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .revresnet import resnet18 3 | from torch import cat 4 | 5 | 6 | class ImageEncoder(nn.Module): 7 | """ 8 | Used for 2.5D maps to 3D voxels 9 | """ 10 | 11 | def __init__(self, input_nc, encode_dims=200): 12 | super().__init__() 13 | resnet_m = resnet18(pretrained=True) 14 | resnet_m.conv1 = nn.Conv2d( 15 | input_nc, 64, 7, stride=2, padding=3, bias=False 16 | ) 17 | resnet_m.avgpool = nn.AdaptiveAvgPool2d(1) 18 | resnet_m.fc = nn.Linear(512, encode_dims) 19 | self.main = nn.Sequential(resnet_m) 20 | 21 | def forward(self, x): 22 | return self.main(x) 23 | 24 | 25 | class VoxelDecoder(nn.Module): 26 | """ 27 | Used for 2.5D maps to 3D voxels 28 | """ 29 | 30 | def __init__(self, n_dims=200, nf=512): 31 | super().__init__() 32 | self.main = nn.Sequential( 33 | # volconv1 34 | deconv3d_add3(n_dims, nf, True), 35 | batchnorm3d(nf), 36 | relu(), 37 | # volconv2 38 | deconv3d_2x(nf, nf // 2, True), 39 | batchnorm3d(nf // 2), 40 | relu(), 41 | # volconv3 42 | nn.Sequential(), # NOTE: no-op for backward compatibility; consider removing 43 | nn.Sequential(), # NOTE 44 | deconv3d_2x(nf // 2, nf // 4, True), 45 | batchnorm3d(nf // 4), 46 | relu(), 47 | # volconv4 48 | deconv3d_2x(nf // 4, nf // 8, True), 49 | batchnorm3d(nf // 8), 50 | relu(), 51 | # volconv5 52 | deconv3d_2x(nf // 8, nf // 16, True), 53 | batchnorm3d(nf // 16), 54 | relu(), 55 | # volconv6 56 | deconv3d_2x(nf // 16, 1, True) 57 | ) 58 | 59 | def forward(self, x): 60 | x_vox = x.view(x.size(0), -1, 1, 1, 1) 61 | return self.main(x_vox) 62 | 63 | 64 | class VoxelGenerator(nn.Module): 65 | def __init__(self, nz=200, nf=64, bias=False, res=128): 66 | super().__init__() 67 | layers = [ 68 | # nzx1x1x1 69 | deconv3d_add3(nz, nf * 8, bias), 70 | batchnorm3d(nf * 8), 71 | relu(), 72 | # (nf*8)x4x4x4 73 | deconv3d_2x(nf * 8, nf * 4, bias), 74 | batchnorm3d(nf * 4), 75 | relu(), 76 | # (nf*4)x8x8x8 77 | deconv3d_2x(nf * 4, nf * 2, bias), 78 | batchnorm3d(nf * 2), 79 | 
relu(), 80 | # (nf*2)x16x16x16 81 | deconv3d_2x(nf * 2, nf, bias), 82 | batchnorm3d(nf), 83 | relu(), 84 | # nfx32x32x32 85 | ] 86 | if res == 64: 87 | layers.append(deconv3d_2x(nf, 1, bias)) 88 | # 1x64x64x64 89 | elif res == 128: 90 | layers += [ 91 | deconv3d_2x(nf, nf, bias), 92 | batchnorm3d(nf), 93 | relu(), 94 | # nfx64x64x64 95 | deconv3d_2x(nf, 1, bias), 96 | # 1x128x128x128 97 | ] 98 | else: 99 | raise NotImplementedError(res) 100 | layers.append(nn.Sigmoid()) 101 | self.main = nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | return self.main(x) 105 | 106 | 107 | class VoxelDiscriminator(nn.Module): 108 | def __init__(self, nf=64, bias=False, res=128): 109 | super().__init__() 110 | layers = [ 111 | # 1x64x64x64 112 | conv3d_half(1, nf, bias), 113 | relu_leaky(), 114 | # nfx32x32x32 115 | conv3d_half(nf, nf * 2, bias), 116 | # batchnorm3d(nf * 2), 117 | relu_leaky(), 118 | # (nf*2)x16x16x16 119 | conv3d_half(nf * 2, nf * 4, bias), 120 | # batchnorm3d(nf * 4), 121 | relu_leaky(), 122 | # (nf*4)x8x8x8 123 | conv3d_half(nf * 4, nf * 8, bias), 124 | # batchnorm3d(nf * 8), 125 | relu_leaky(), 126 | # (nf*8)x4x4 127 | conv3d_minus3(nf * 8, 1, bias), 128 | # 1x1x1 129 | ] 130 | if res == 64: 131 | pass 132 | elif res == 128: 133 | extra_layers = [ 134 | conv3d_half(nf, nf, bias), 135 | relu_leaky(), 136 | ] 137 | layers = layers[:2] + extra_layers + layers[2:] 138 | else: 139 | raise NotImplementedError(res) 140 | self.main = nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | y = self.main(x) 144 | return y.view(-1, 1).squeeze(1) 145 | 146 | 147 | class Unet_3D(nn.Module): 148 | def __init__(self, nf=20, in_channel=2, no_linear=False): 149 | super(Unet_3D, self).__init__() 150 | self.nf = nf 151 | self.enc1 = Conv3d_block(in_channel, nf, 8, 2, 3) # =>64 152 | self.enc2 = Conv3d_block(nf, 2 * nf, 4, 2, 1) # =>32 153 | self.enc3 = Conv3d_block(2 * nf, 4 * nf, 4, 2, 1) # =>16 154 | self.enc4 = Conv3d_block(4 * nf, 8 * nf, 4, 2, 1) # =>8 155 | self.enc5 = Conv3d_block(8 * nf, 16 * nf, 4, 2, 1) # =>4 156 | self.enc6 = Conv3d_block(16 * nf, 32 * nf, 4, 1, 0) # =>1 157 | self.full_conv_block = nn.Sequential( 158 | nn.Linear(32 * nf, 32 * nf), 159 | nn.LeakyReLU(), 160 | ) 161 | self.dec1 = Deconv3d_skip(32 * 2 * nf, 16 * nf, 4, 1, 0, 0) # =>4 162 | self.dec2 = Deconv3d_skip(16 * 2 * nf, 8 * nf, 4, 2, 1, 0) # =>8 163 | self.dec3 = Deconv3d_skip(8 * 2 * nf, 4 * nf, 4, 2, 1, 0) # =>16 164 | self.dec4 = Deconv3d_skip(4 * 2 * nf, 2 * nf, 4, 2, 1, 0) # =>32 165 | self.dec5 = Deconv3d_skip(4 * nf, nf, 8, 2, 3, 0) # =>64 166 | self.dec6 = Deconv3d_skip( 167 | 2 * nf, 1, 4, 2, 1, 0, is_activate=False) # =>128 168 | self.no_linear = no_linear 169 | 170 | def forward(self, x): 171 | enc1 = self.enc1(x) 172 | enc2 = self.enc2(enc1) 173 | enc3 = self.enc3(enc2) 174 | enc4 = self.enc4(enc3) 175 | enc5 = self.enc5(enc4) 176 | enc6 = self.enc6(enc5) 177 | # print(enc6.size()) 178 | if not self.no_linear: 179 | flatten = enc6.view(enc6.size()[0], self.nf * 32) 180 | bottleneck = self.full_conv_block(flatten) 181 | bottleneck = bottleneck.view(enc6.size()[0], self.nf * 32, 1, 1, 1) 182 | dec1 = self.dec1(bottleneck, enc6) 183 | else: 184 | dec1 = self.dec1(enc6, enc6) 185 | dec2 = self.dec2(dec1, enc5) 186 | dec3 = self.dec3(dec2, enc4) 187 | dec4 = self.dec4(dec3, enc3) 188 | dec5 = self.dec5(dec4, enc2) 189 | out = self.dec6(dec5, enc1) 190 | return out 191 | 192 | 193 | class Conv3d_block(nn.Module): 194 | def __init__(self, ncin, ncout, kernel_size, stride, pad, dropout=False): 195 
| super().__init__() 196 | self.net = nn.Sequential( 197 | nn.Conv3d(ncin, ncout, kernel_size, stride, pad), 198 | nn.BatchNorm3d(ncout), 199 | nn.LeakyReLU() 200 | ) 201 | 202 | def forward(self, x): 203 | return self.net(x) 204 | 205 | 206 | class Deconv3d_skip(nn.Module): 207 | def __init__(self, ncin, ncout, kernel_size, stride, pad, extra=0, is_activate=True): 208 | super(Deconv3d_skip, self).__init__() 209 | if is_activate: 210 | self.net = nn.Sequential( 211 | nn.ConvTranspose3d(ncin, ncout, kernel_size, 212 | stride, pad, extra), 213 | nn.BatchNorm3d(ncout), 214 | nn.LeakyReLU() 215 | ) 216 | else: 217 | self.net = nn.ConvTranspose3d( 218 | ncin, ncout, kernel_size, stride, pad, extra) 219 | 220 | def forward(self, x, skip_in): 221 | y = cat((x, skip_in), dim=1) 222 | return self.net(y) 223 | 224 | 225 | class ViewAsLinear(nn.Module): 226 | @staticmethod 227 | def forward(x): 228 | return x.view(x.shape[0], -1) 229 | 230 | 231 | def relu(): 232 | return nn.ReLU(inplace=True) 233 | 234 | 235 | def relu_leaky(): 236 | return nn.LeakyReLU(0.2, inplace=True) 237 | 238 | 239 | def maxpool(): 240 | return nn.MaxPool2d(3, stride=2, padding=0) 241 | 242 | 243 | def dropout(): 244 | return nn.Dropout(p=0.5, inplace=False) 245 | 246 | 247 | def conv3d_half(n_ch_in, n_ch_out, bias): 248 | return nn.Conv3d( 249 | n_ch_in, n_ch_out, 4, stride=2, padding=1, dilation=1, groups=1, bias=bias 250 | ) 251 | 252 | 253 | def deconv3d_2x(n_ch_in, n_ch_out, bias): 254 | return nn.ConvTranspose3d( 255 | n_ch_in, n_ch_out, 4, stride=2, padding=1, dilation=1, groups=1, bias=bias 256 | ) 257 | 258 | 259 | def conv3d_minus3(n_ch_in, n_ch_out, bias): 260 | return nn.Conv3d( 261 | n_ch_in, n_ch_out, 4, stride=1, padding=0, dilation=1, groups=1, bias=bias 262 | ) 263 | 264 | 265 | def deconv3d_add3(n_ch_in, n_ch_out, bias): 266 | return nn.ConvTranspose3d( 267 | n_ch_in, n_ch_out, 4, stride=1, padding=0, dilation=1, groups=1, bias=bias 268 | ) 269 | 270 | 271 | def batchnorm1d(n_feat): 272 | return nn.BatchNorm1d(n_feat, eps=1e-5, momentum=0.1, affine=True) 273 | 274 | 275 | def batchnorm(n_feat): 276 | return nn.BatchNorm2d(n_feat, eps=1e-5, momentum=0.1, affine=True) 277 | 278 | 279 | def batchnorm3d(n_feat): 280 | return nn.BatchNorm3d(n_feat, eps=1e-5, momentum=0.1, affine=True) 281 | 282 | 283 | def fc(n_in, n_out): 284 | return nn.Linear(n_in, n_out, bias=True) 285 | -------------------------------------------------------------------------------- /networks/revresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an implementation of a U-Net using ResNet-18 blocks 3 | """ 4 | import torch 5 | from torch import nn 6 | from torchvision.models import resnet18 7 | 8 | 9 | def deconv3x3(in_planes, out_planes, stride=1, output_padding=0): 10 | return nn.ConvTranspose2d( 11 | in_planes, 12 | out_planes, 13 | kernel_size=3, 14 | stride=stride, 15 | padding=1, 16 | bias=False, 17 | output_padding=output_padding 18 | ) 19 | 20 | 21 | class RevBasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, upsample=None): 25 | super(RevBasicBlock, self).__init__() 26 | self.deconv1 = deconv3x3(inplanes, planes, stride=1) 27 | # Note that in ResNet, the stride is on the second layer 28 | # Here we put it on the first layer as the mirrored block 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.deconv2 = deconv3x3(planes, planes, stride=stride, 32 | output_padding=1 if stride > 1 else 0) 33 
| self.bn2 = nn.BatchNorm2d(planes) 34 | self.upsample = upsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | out = self.deconv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | out = self.deconv2(out) 43 | out = self.bn2(out) 44 | if self.upsample is not None: 45 | residual = self.upsample(x) 46 | out += residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | class RevBottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, upsample=None): 55 | super(RevBottleneck, self).__init__() 56 | bottleneck_planes = int(inplanes / 4) 57 | self.deconv1 = nn.ConvTranspose2d( 58 | inplanes, 59 | bottleneck_planes, 60 | kernel_size=1, 61 | bias=False, 62 | stride=1 63 | ) # conv and deconv are the same when kernel size is 1 64 | self.bn1 = nn.BatchNorm2d(bottleneck_planes) 65 | self.deconv2 = nn.ConvTranspose2d( 66 | bottleneck_planes, 67 | bottleneck_planes, 68 | kernel_size=3, 69 | stride=1, 70 | padding=1, 71 | bias=False 72 | ) 73 | self.bn2 = nn.BatchNorm2d(bottleneck_planes) 74 | self.deconv3 = nn.ConvTranspose2d( 75 | bottleneck_planes, 76 | planes, 77 | kernel_size=1, 78 | bias=False, 79 | stride=stride, 80 | output_padding=1 if stride > 0 else 0 81 | ) 82 | self.bn3 = nn.BatchNorm2d(planes) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.upsample = upsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | residual = x 89 | out = self.deconv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | out = self.deconv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | out = self.deconv3(out) 96 | out = self.bn3(out) 97 | if self.upsample is not None: 98 | residual = self.upsample(x) 99 | out += residual 100 | out = self.relu(out) 101 | return out 102 | 103 | 104 | class RevResNet(nn.Module): 105 | def __init__(self, block, layers, planes, inplanes=None, out_planes=5): 106 | """ 107 | planes: # output channels for each block 108 | inplanes: # input channels for the input at each layer 109 | If missing, it will be inferred. 
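        (For example, revuresnet18 below passes inplanes=[512, 512, 256, 128, 128]: in the
        U-Net setting of networks/uresnet.py, every stage after the first receives the previous
        stage's output concatenated with an encoder skip feature map, which doubles its input
        channel count.)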
110 | """ 111 | if inplanes is None: 112 | inplanes = [512] 113 | self.inplanes = inplanes[0] 114 | super(RevResNet, self).__init__() 115 | inplanes_after_blocks = inplanes[4] if len(inplanes) > 4 else planes[3] 116 | self.deconv1 = nn.ConvTranspose2d( 117 | inplanes_after_blocks, 118 | planes[3], 119 | kernel_size=3, 120 | stride=2, 121 | padding=1, 122 | output_padding=1 123 | ) 124 | self.deconv2 = nn.ConvTranspose2d( 125 | planes[3], 126 | out_planes, 127 | kernel_size=7, 128 | stride=2, 129 | padding=3, 130 | bias=False, 131 | output_padding=1 132 | ) 133 | self.bn1 = nn.BatchNorm2d(planes[3]) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.layer1 = self._make_layer(block, planes[0], layers[0], stride=2) 136 | if len(inplanes) > 1: 137 | self.inplanes = inplanes[1] 138 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2) 139 | if len(inplanes) > 2: 140 | self.inplanes = inplanes[2] 141 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2) 142 | if len(inplanes) > 3: 143 | self.inplanes = inplanes[3] 144 | self.layer4 = self._make_layer(block, planes[3], layers[3]) 145 | 146 | def _make_layer(self, block, planes, blocks, stride=1): 147 | upsample = None 148 | if stride != 1 or self.inplanes != planes: 149 | upsample = nn.Sequential( 150 | nn.ConvTranspose2d( 151 | self.inplanes, 152 | planes, 153 | kernel_size=1, 154 | stride=stride, 155 | bias=False, 156 | output_padding=1 if stride > 1 else 0 157 | ), 158 | nn.BatchNorm2d(planes), 159 | ) 160 | layers = [] 161 | layers.append(block(self.inplanes, planes, stride, upsample)) 162 | self.inplanes = planes 163 | for _ in range(1, blocks): 164 | layers.append(block(self.inplanes, planes)) 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, x): 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | x = self.deconv1(x) 173 | x = self.bn1(x) 174 | x = self.relu(x) 175 | x = self.deconv2(x) 176 | return x 177 | 178 | 179 | def revresnet18(**kwargs): 180 | model = RevResNet( 181 | RevBasicBlock, 182 | [2, 2, 2, 2], 183 | [512, 256, 128, 64], 184 | **kwargs 185 | ) 186 | return model 187 | 188 | 189 | def revuresnet18(**kwargs): 190 | """ 191 | Reverse ResNet-18 compatible with the U-Net setting 192 | """ 193 | model = RevResNet( 194 | RevBasicBlock, 195 | [2, 2, 2, 2], 196 | [256, 128, 64, 64], 197 | inplanes=[512, 512, 256, 128, 128], 198 | **kwargs 199 | ) 200 | return model 201 | 202 | 203 | def _num_parameters(net): 204 | return sum([ 205 | x.numel() for x in list(net.parameters()) 206 | ]) 207 | 208 | 209 | def main(): 210 | net = resnet18() 211 | revnet = revresnet18() 212 | net.avgpool = nn.AvgPool2d(kernel_size=8) 213 | for name, mod in net.named_children(): 214 | mod.__name = name 215 | mod.register_forward_hook( 216 | lambda mod, input, output: print(mod.__name, output.shape) 217 | ) 218 | for name, mod in revnet.named_children(): 219 | mod.__name = name 220 | mod.register_forward_hook( 221 | lambda mod, input, output: print(mod.__name, output.shape) 222 | ) 223 | # print(net) 224 | print('resnet', _num_parameters(net)) 225 | net(torch.zeros(2, 3, 256, 256)) 226 | print('') 227 | print('revresnet', _num_parameters(revnet)) 228 | # print(revnet) 229 | revnet(torch.zeros(2, 512, 8, 8)) 230 | print('') 231 | revunet = RevResNet(RevBasicBlock, [2, 2, 2, 2], [512, 512, 256, 128]) 232 | print('revunet', _num_parameters(revunet)) 233 | 234 | 235 | if __name__ == '__main__': 236 | main() 237 | 
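# Editor's sketch (not part of the original repository): a minimal, hypothetical
# helper illustrating how revresnet18() mirrors a ResNet-18 encoder. The encoder
# reduces a 3 x 256 x 256 image to a 512 x 8 x 8 feature map; revresnet18()
# decodes such a map back to out_planes x 256 x 256, which is the same shape
# check main() above performs. It relies only on the `torch` import and the
# revresnet18() factory defined in this file, and is never called anywhere.
def _example_revresnet18_shapes():
    decoder = revresnet18(out_planes=5)     # 5 output channels (the default)
    bottleneck = torch.zeros(2, 512, 8, 8)  # batch of 2 encoder bottleneck maps
    out = decoder(bottleneck)               # four reversed blocks + two deconvs upsample 8 -> 256
    assert out.shape == (2, 5, 256, 256)
    return out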
-------------------------------------------------------------------------------- /networks/uresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from networks.revresnet import revuresnet18, resnet18 4 | 5 | 6 | class Net(nn.Module): 7 | """ 8 | Used for RGB to 2.5D maps 9 | """ 10 | 11 | def __init__(self, out_planes, layer_names, input_planes=3): 12 | super().__init__() 13 | 14 | # Encoder 15 | module_list = list() 16 | resnet = resnet18(pretrained=True) 17 | in_conv = nn.Conv2d(input_planes, 64, kernel_size=7, stride=2, padding=3, 18 | bias=False) 19 | module_list.append( 20 | nn.Sequential( 21 | resnet.conv1 if input_planes == 3 else in_conv, 22 | resnet.bn1, 23 | resnet.relu, 24 | resnet.maxpool 25 | ) 26 | ) 27 | module_list.append(resnet.layer1) 28 | module_list.append(resnet.layer2) 29 | module_list.append(resnet.layer3) 30 | module_list.append(resnet.layer4) 31 | self.encoder = nn.ModuleList(module_list) 32 | self.encoder_out = None 33 | 34 | # Decoder 35 | self.decoders = {} 36 | for out_plane, layer_name in zip(out_planes, layer_names): 37 | module_list = list() 38 | revresnet = revuresnet18(out_planes=out_plane) 39 | module_list.append(revresnet.layer1) 40 | module_list.append(revresnet.layer2) 41 | module_list.append(revresnet.layer3) 42 | module_list.append(revresnet.layer4) 43 | module_list.append( 44 | nn.Sequential( 45 | revresnet.deconv1, 46 | revresnet.bn1, 47 | revresnet.relu, 48 | revresnet.deconv2 49 | ) 50 | ) 51 | module_list = nn.ModuleList(module_list) 52 | setattr(self, 'decoder_' + layer_name, module_list) 53 | self.decoders[layer_name] = module_list 54 | 55 | def forward(self, im): 56 | # Encode 57 | feat = im 58 | feat_maps = list() 59 | for f in self.encoder: 60 | feat = f(feat) 61 | feat_maps.append(feat) 62 | self.encoder_out = feat_maps[-1] 63 | # Decode 64 | outputs = {} 65 | for layer_name, decoder in self.decoders.items(): 66 | x = feat_maps[-1] 67 | for idx, f in enumerate(decoder): 68 | x = f(x) 69 | if idx < len(decoder) - 1: 70 | feat_map = feat_maps[-(idx + 2)] 71 | assert feat_map.shape[2:4] == x.shape[2:4] 72 | x = torch.cat((x, feat_map), dim=1) 73 | outputs[layer_name] = x 74 | return outputs 75 | 76 | 77 | class Net_inpaint(nn.Module): 78 | """ 79 | Used for RGB to 2.5D maps 80 | """ 81 | 82 | def __init__(self, out_planes, layer_names, input_planes=3): 83 | super().__init__() 84 | 85 | # Encoder 86 | module_list = list() 87 | resnet = resnet18(pretrained=True) 88 | in_conv = nn.Conv2d(input_planes, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | module_list.append( 91 | nn.Sequential( 92 | resnet.conv1 if input_planes == 3 else in_conv, 93 | resnet.bn1, 94 | resnet.relu, 95 | resnet.maxpool 96 | ) 97 | ) 98 | module_list.append(resnet.layer1) 99 | module_list.append(resnet.layer2) 100 | module_list.append(resnet.layer3) 101 | module_list.append(resnet.layer4) 102 | self.encoder = nn.ModuleList(module_list) 103 | self.encoder_out = None 104 | self.deconv2 = nn.ConvTranspose2d(64, 1, kernel_size=8, stride=2, padding=3, bias=False, output_padding=0) 105 | # Decoder 106 | self.decoders = {} 107 | for out_plane, layer_name in zip(out_planes, layer_names): 108 | module_list = list() 109 | revresnet = revuresnet18(out_planes=out_plane) 110 | module_list.append(revresnet.layer1) 111 | module_list.append(revresnet.layer2) 112 | module_list.append(revresnet.layer3) 113 | module_list.append(revresnet.layer4) 114 | module_list.append( 115 | nn.Sequential( 
116 | revresnet.deconv1, 117 | revresnet.bn1, 118 | revresnet.relu, 119 | self.deconv2 120 | ) 121 | ) 122 | module_list = nn.ModuleList(module_list) 123 | setattr(self, 'decoder_' + layer_name, module_list) 124 | self.decoders[layer_name] = module_list 125 | 126 | def forward(self, im): 127 | # Encode 128 | feat = im 129 | feat_maps = list() 130 | for f in self.encoder: 131 | feat = f(feat) 132 | feat_maps.append(feat) 133 | self.encoder_out = feat_maps[-1] 134 | # Decode 135 | outputs = {} 136 | for layer_name, decoder in self.decoders.items(): 137 | x = feat_maps[-1] 138 | for idx, f in enumerate(decoder): 139 | x = f(x) 140 | if idx < len(decoder) - 1: 141 | feat_map = feat_maps[-(idx + 2)] 142 | assert feat_map.shape[2:4] == x.shape[2:4] 143 | x = torch.cat((x, feat_map), dim=1) 144 | outputs[layer_name] = x 145 | return outputs 146 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/options/__init__.py -------------------------------------------------------------------------------- /options/options_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from datasets import get_dataset 4 | from models import get_model 5 | from options import options_train 6 | 7 | 8 | def add_general_arguments(parser): 9 | parser, _ = options_train.add_general_arguments(parser) 10 | 11 | # Dataset IO 12 | parser.add_argument('--input_rgb', type=str, required=True, 13 | help="Input RGB filename patterns, e.g., '/path/to/images/*_rgb.png'") 14 | parser.add_argument('--input_mask', type=str, required=True, 15 | help=("Corresponding mask filename patterns, e.g., '/path/to/images/*_mask.png'. " 16 | "For MarrNet/ShapeHD, masks are not required, so used only for bbox cropping. 
" 17 | "For GenRe, masks are input together with RGB")) 18 | 19 | # Network 20 | parser.add_argument('--net_file', type=str, required=True, 21 | help="Path to the trained network") 22 | 23 | # Output 24 | parser.add_argument('--output_dir', type=str, required=True, 25 | help="Output directory") 26 | parser.add_argument('--overwrite', action='store_true', 27 | help="Whether to overwrite the output folder if it exists") 28 | 29 | return parser 30 | 31 | 32 | def parse(add_additional_arguments=None): 33 | parser = argparse.ArgumentParser() 34 | parser = add_general_arguments(parser) 35 | if add_additional_arguments: 36 | parser, _ = add_additional_arguments(parser) 37 | opt_general, _ = parser.parse_known_args() 38 | net_name = opt_general.net 39 | del opt_general 40 | dataset_name = 'test' 41 | 42 | # Add parsers depending on dataset and models 43 | parser, _ = get_dataset(dataset_name).add_arguments(parser) 44 | parser, _ = get_model(net_name, test=True).add_arguments(parser) 45 | 46 | # Manually add '-h' after adding all parser arguments 47 | if '--printhelp' in sys.argv: 48 | sys.argv.append('-h') 49 | 50 | opt = parser.parse_args() 51 | return opt 52 | -------------------------------------------------------------------------------- /options/options_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import torch 4 | from util.util_print import str_warning 5 | from datasets import get_dataset 6 | from models import get_model 7 | 8 | 9 | def add_general_arguments(parser): 10 | # Parameters that will NOT be overwritten when resuming 11 | unique_params = {'gpu', 'resume', 'epoch', 'workers', 'batch_size', 'save_net', 'epoch_batches', 'logdir'} 12 | 13 | parser.add_argument('--gpu', default='0', type=str, 14 | help='gpu to use') 15 | parser.add_argument('--manual_seed', type=int, default=None, 16 | help='manual seed for randomness') 17 | parser.add_argument('--resume', type=int, default=0, 18 | help='resume training by loading checkpoint.pt or best.pt. Use 0 for training from scratch, -1 for last and -2 for previous best. Use positive number for a specific epoch. 
\ 19 | Most options will be overwritten to resume training with exactly same environment') 20 | parser.add_argument( 21 | '--suffix', default='', type=str, 22 | help="Suffix for `logdir` that will be formatted with `opt`, e.g., '{classes}_lr{lr}'" 23 | ) 24 | parser.add_argument('--epoch', type=int, default=0, 25 | help='number of epochs to train') 26 | 27 | # Dataset IO 28 | parser.add_argument('--dataset', type=str, default=None, 29 | help='dataset to use') 30 | parser.add_argument('--workers', type=int, default=4, 31 | help='number of data loading workers') 32 | parser.add_argument('--classes', default='chair', type=str, 33 | help='class to use') 34 | parser.add_argument('--batch_size', type=int, default=16, 35 | help='training batch size') 36 | parser.add_argument('--epoch_batches', default=None, type=int, help='number of batches used per epoch') 37 | parser.add_argument('--eval_batches', default=None, 38 | type=int, help='max number of batches used for evaluation per epoch') 39 | parser.add_argument('--eval_at_start', action='store_true', 40 | help='run evaluation before starting to train') 41 | parser.add_argument('--log_time', action='store_true', help='adding time log') 42 | 43 | # Network name 44 | parser.add_argument('--net', type=str, required=True, 45 | help='network type to use') 46 | 47 | # Optimizer 48 | parser.add_argument('--optim', type=str, default='adam', 49 | help='optimizer to use') 50 | parser.add_argument('--lr', type=float, default=1e-4, 51 | help='learning rate') 52 | parser.add_argument('--adam_beta1', type=float, default=0.5, 53 | help='beta1 of adam') 54 | parser.add_argument('--adam_beta2', type=float, default=0.9, 55 | help='beta2 of adam') 56 | parser.add_argument('--sgd_momentum', type=float, default=0.9, 57 | help="momentum factor of SGD") 58 | parser.add_argument('--sgd_dampening', type=float, default=0, 59 | help="dampening for momentum of SGD") 60 | parser.add_argument('--wdecay', type=float, default=0.0, 61 | help='weight decay') 62 | 63 | # Logging and visualization 64 | parser.add_argument('--logdir', type=str, default=None, 65 | help='Root directory for logging. Actual dir is [logdir]/[net_classes_dataset]/[expr_id]') 66 | parser.add_argument('--log_batch', action='store_true', 67 | help='Log batch loss') 68 | parser.add_argument('--expr_id', type=int, default=0, 69 | help='Experiment index. non-positive ones are overwritten by default. Use 0 for code test. ') 70 | parser.add_argument('--save_net', type=int, default=1, 71 | help='Period of saving network weights') 72 | parser.add_argument('--save_net_opt', action='store_true', 73 | help='Save optimizer state in regular network saving') 74 | parser.add_argument('--vis_every_vali', default=1, type=int, 75 | help="Visualize every N epochs during validation") 76 | parser.add_argument('--vis_every_train', default=1, type=int, 77 | help="Visualize every N epochs during training") 78 | parser.add_argument('--vis_batches_vali', type=int, default=10, 79 | help="# batches to visualize during validation") 80 | parser.add_argument('--vis_batches_train', type=int, default=10, 81 | help="# batches to visualize during training") 82 | parser.add_argument('--tensorboard', action='store_true', 83 | help='Use tensorboard for logging. 
If enabled, the output log will be at [logdir]/[tensorboard]/[net_classes_dataset]/[expr_id]') 84 | parser.add_argument('--vis_workers', default=4, type=int, help="# workers for the visualizer") 85 | parser.add_argument('--vis_param_f', default=None, type=str, 86 | help="Parameter file read by the visualizer on every batch; defaults to 'visualize/config.json'") 87 | 88 | return parser, unique_params 89 | 90 | 91 | def overwrite(opt, opt_f_old, unique_params): 92 | opt_dict = vars(opt) 93 | opt_dict_old = torch.load(opt_f_old) 94 | for k, v in opt_dict_old.items(): 95 | if k in opt_dict: 96 | if (k not in unique_params) and (opt_dict[k] != v): 97 | print(str_warning, "Overwriting %s for resuming training: %s -> %s" 98 | % (k, str(opt_dict[k]), str(v))) 99 | opt_dict[k] = v 100 | else: 101 | print(str_warning, "Ignoring %s, an old option that no longer exists" % k) 102 | opt = argparse.Namespace(**opt_dict) 103 | return opt 104 | 105 | 106 | def parse(add_additional_arguments=None): 107 | parser = argparse.ArgumentParser() 108 | parser, unique_params = add_general_arguments(parser) 109 | if add_additional_arguments is not None: 110 | parser, unique_params_additional = add_additional_arguments(parser) 111 | unique_params = unique_params.union(unique_params_additional) 112 | opt_general, _ = parser.parse_known_args() 113 | dataset_name, net_name = opt_general.dataset, opt_general.net 114 | del opt_general 115 | 116 | # Add parsers depending on dataset and models 117 | parser, unique_params_dataset = get_dataset(dataset_name).add_arguments(parser) 118 | parser, unique_params_model = get_model(net_name).add_arguments(parser) 119 | 120 | # Manually add '-h' after adding all parser arguments 121 | if '--printhelp' in sys.argv: 122 | sys.argv.append('-h') 123 | 124 | opt = parser.parse_args() 125 | unique_params = unique_params.union(unique_params_dataset) 126 | unique_params = unique_params.union(unique_params_model) 127 | return opt, unique_params 128 | -------------------------------------------------------------------------------- /scripts/finetune_marrnet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Finetune MarrNet-2 with MarrNet-1 predictions 4 | 5 | outdir=./output/marrnet 6 | class=drc 7 | marrnet1=/path/to/marrnet1.pt 8 | marrnet2=/path/to/marrnet2.pt 9 | 10 | if [ $# -lt 1 ]; then 11 | echo "Usage: $0 gpu[ ...]" 12 | exit 1 13 | fi 14 | gpu="$1" 15 | shift # shift the remaining arguments 16 | 17 | set -e 18 | 19 | source activate shaperecon 20 | 21 | python train.py \ 22 | --net marrnet \ 23 | --marrnet1 "$marrnet1" \ 24 | --marrnet2 "$marrnet2" \ 25 | --dataset shapenet \ 26 | --classes "$class" \ 27 | --batch_size 4 \ 28 | --epoch_batches 2500 \ 29 | --eval_batches 5 \ 30 | --optim adam \ 31 | --lr 1e-3 \ 32 | --epoch 1000 \ 33 | --vis_batches_vali 10 \ 34 | --gpu "$gpu" \ 35 | --save_net 10 \ 36 | --workers 4 \ 37 | --logdir "$outdir" \ 38 | --suffix '{classes}' \ 39 | --tensorboard \ 40 | $* 41 | 42 | source deactivate 43 | -------------------------------------------------------------------------------- /scripts/finetune_shapehd.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Finetune ShapeHD 3D estimator with GAN losses 4 | 5 | outdir=./output/shapehd 6 | class=drc 7 | marrnet2=/path/to/marrnet2.pt 8 | gan=/path/to/gan.pt 9 | 10 | if [ $# -lt 2 ]; then 11 | echo "Usage: $0 gpu[ ...]" 12 | exit 1 13 | fi 14 | gpu="$1" 15 | shift # shift the 
remaining arguments 16 | 17 | set -e 18 | 19 | source activate shaperecon 20 | 21 | python train.py \ 22 | --net shapehd \ 23 | --marrnet2 "$marrnet2" \ 24 | --gan "$gan" \ 25 | --dataset shapenet \ 26 | --classes "$class" \ 27 | --canon_sup \ 28 | --w_gan_loss 1e-3 \ 29 | --batch_size 4 \ 30 | --epoch_batches 1000 \ 31 | --eval_batches 10 \ 32 | --optim adam \ 33 | --lr 1e-3 \ 34 | --epoch 1000 \ 35 | --vis_batches_vali 10 \ 36 | --gpu "$gpu" \ 37 | --save_net 1 \ 38 | --workers 4 \ 39 | --logdir "$outdir" \ 40 | --suffix '{classes}_w_ganloss{w_gan_loss}' \ 41 | --tensorboard \ 42 | $* 43 | 44 | source deactivate 45 | -------------------------------------------------------------------------------- /scripts/test_genre.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Test GenRe 4 | 5 | out_dir="./output/test" 6 | fullmodel=./downloads/models/full_model.pt 7 | rgb_pattern='./downloads/data/test/genre/*_rgb.*' 8 | mask_pattern='./downloads/data/test/genre/*_silhouette.*' 9 | 10 | if [ $# -lt 1 ]; then 11 | echo "Usage: $0 gpu[ ...]" 12 | exit 1 13 | fi 14 | gpu="$1" 15 | shift # shift the remaining arguments 16 | 17 | set -e 18 | 19 | source activate shaperecon 20 | 21 | python 'test.py' \ 22 | --net genre_full_model \ 23 | --net_file "$fullmodel" \ 24 | --input_rgb "$rgb_pattern" \ 25 | --input_mask "$mask_pattern" \ 26 | --output_dir "$out_dir" \ 27 | --suffix '{net}' \ 28 | --overwrite \ 29 | --workers 0 \ 30 | --batch_size 1 \ 31 | --vis_workers 4 \ 32 | --gpu "$gpu" \ 33 | $* 34 | 35 | source deactivate 36 | -------------------------------------------------------------------------------- /scripts/test_marrnet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Test MarrNet 4 | 5 | out_dir="./output/test" 6 | marrnet=/path/to/marrnet.pt 7 | rgb_pattern='./downloads/data/test/shapehd/*_rgb.*' 8 | mask_pattern='./downloads/data/test/shapehd/*_mask.*' 9 | 10 | if [ $# -lt 1 ]; then 11 | echo "Usage: $0 gpu[ ...]" 12 | exit 1 13 | fi 14 | gpu="$1" 15 | shift # shift the remaining arguments 16 | 17 | set -e 18 | 19 | source activate shaperecon 20 | 21 | python 'test.py' \ 22 | --net marrnet \ 23 | --net_file "$marrnet" \ 24 | --input_rgb "$rgb_pattern" \ 25 | --input_mask "$mask_pattern" \ 26 | --output_dir "$out_dir" \ 27 | --suffix '{net}' \ 28 | --overwrite \ 29 | --workers 1 \ 30 | --batch_size 1 \ 31 | --vis_workers 4 \ 32 | --gpu "$gpu" \ 33 | $* 34 | 35 | source deactivate 36 | -------------------------------------------------------------------------------- /scripts/test_shapehd.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Test ShapeHD 4 | 5 | out_dir="./output/test" 6 | net1=./downloads/models/marrnet1_with_minmax.pt 7 | net2=./downloads/models/shapehd.pt 8 | rgb_pattern='./downloads/data/test/shapehd/*_rgb.*' 9 | mask_pattern='./downloads/data/test/shapehd/*_mask.*' 10 | 11 | if [ $# -lt 1 ]; then 12 | echo "Usage: $0 gpu[ ...]" 13 | exit 1 14 | fi 15 | gpu="$1" 16 | shift # shift the remaining arguments 17 | 18 | set -e 19 | 20 | 21 | source activate shaperecon 22 | 23 | python 'test.py' \ 24 | --net shapehd \ 25 | --net_file "$net2" \ 26 | --marrnet1_file "$net1" \ 27 | --input_rgb "$rgb_pattern" \ 28 | --input_mask "$mask_pattern" \ 29 | --output_dir "$out_dir" \ 30 | --suffix '{net}' \ 31 | --overwrite \ 32 | --workers 1 \ 33 | --batch_size 1 \ 34 | --vis_workers 4 \ 35 | 
--gpu "$gpu" \ 36 | $* 37 | 38 | source deactivate 39 | -------------------------------------------------------------------------------- /scripts/train_full_genre.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=./output/genre 4 | inpaint_path=/path/to/trained/inpaint.pt 5 | 6 | if [ $# -lt 2 ]; then 7 | echo "Usage: $0 gpu class[ ...]" 8 | exit 1 9 | fi 10 | gpu="$1" 11 | class="$2" 12 | shift # shift the remaining arguments 13 | shift 14 | 15 | set -e 16 | 17 | source activate shaperecon 18 | 19 | python train.py \ 20 | --net genre_full_model \ 21 | --pred_depth_minmax \ 22 | --dataset shapenet \ 23 | --classes "$class" \ 24 | --batch_size 4 \ 25 | --epoch_batches 1000 \ 26 | --eval_batches 30 \ 27 | --log_time \ 28 | --optim adam \ 29 | --lr 1e-4 \ 30 | --epoch 1000 \ 31 | --vis_batches_vali 10 \ 32 | --gpu "$gpu" \ 33 | --save_net 10 \ 34 | --workers 4 \ 35 | --logdir "$outdir" \ 36 | --suffix '{classes}' \ 37 | --tensorboard \ 38 | --surface_weight 10 \ 39 | --inpaint_path "$inpaint_path" \ 40 | $* 41 | 42 | source deactivate 43 | -------------------------------------------------------------------------------- /scripts/train_inpaint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=./output/inpaint 4 | net1_path=/path/to/trained/marrnet1.pt 5 | 6 | if [ $# -lt 2 ]; then 7 | echo "Usage: $0 gpu class[ ...]" 8 | exit 1 9 | fi 10 | gpu="$1" 11 | class="$2" 12 | shift # shift the remaining arguments 13 | shift 14 | 15 | set -e 16 | 17 | source activate shaperecon 18 | 19 | python train.py \ 20 | --net depth_pred_with_sph_inpaint \ 21 | --pred_depth_minmax \ 22 | --dataset shapenet \ 23 | --classes "$class" \ 24 | --batch_size 4 \ 25 | --epoch_batches 2000 \ 26 | --eval_batches 10 \ 27 | --log_time \ 28 | --optim adam \ 29 | --lr 1e-4 \ 30 | --epoch 1000 \ 31 | --vis_batches_vali 10 \ 32 | --gpu "$gpu" \ 33 | --save_net 10 \ 34 | --workers 4 \ 35 | --logdir "$outdir" \ 36 | --suffix '{classes}' \ 37 | --tensorboard \ 38 | --net1_path "$net1_path" \ 39 | $* 40 | 41 | source deactivate 42 | -------------------------------------------------------------------------------- /scripts/train_marrnet1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=./output/marrnet1 4 | 5 | if [ $# -lt 2 ]; then 6 | echo "Usage: $0 gpu class[ ...]" 7 | exit 1 8 | fi 9 | gpu="$1" 10 | class="$2" 11 | shift # shift the remaining arguments 12 | shift 13 | 14 | set -e 15 | 16 | source activate shaperecon 17 | 18 | python train.py \ 19 | --net marrnet1 \ 20 | --pred_depth_minmax \ 21 | --dataset shapenet \ 22 | --classes "$class" \ 23 | --batch_size 4 \ 24 | --epoch_batches 2500 \ 25 | --eval_batches 5 \ 26 | --log_time \ 27 | --optim adam \ 28 | --lr 1e-3 \ 29 | --epoch 1000 \ 30 | --vis_batches_vali 10 \ 31 | --gpu "$gpu" \ 32 | --save_net 10 \ 33 | --workers 4 \ 34 | --logdir "$outdir" \ 35 | --suffix '{classes}' \ 36 | --tensorboard \ 37 | $* 38 | 39 | source deactivate 40 | -------------------------------------------------------------------------------- /scripts/train_marrnet2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=./output/marrnet2 4 | 5 | if [ $# -lt 2 ]; then 6 | echo "Usage: $0 gpu class[ ...]" 7 | exit 1 8 | fi 9 | gpu="$1" 10 | class="$2" 11 | shift # shift the remaining arguments 12 | shift 13 | 14 | set 
-e 15 | 16 | source activate shaperecon 17 | 18 | python train.py \ 19 | --net marrnet2 \ 20 | --dataset shapenet \ 21 | --classes "$class" \ 22 | --canon_sup \ 23 | --batch_size 4 \ 24 | --epoch_batches 2500 \ 25 | --eval_batches 5 \ 26 | --optim adam \ 27 | --lr 1e-3 \ 28 | --epoch 1000 \ 29 | --vis_batches_vali 10 \ 30 | --gpu "$gpu" \ 31 | --save_net 10 \ 32 | --workers 4 \ 33 | --logdir "$outdir" \ 34 | --suffix '{classes}_canon-{canon_sup}' \ 35 | --tensorboard \ 36 | $* 37 | 38 | source deactivate 39 | -------------------------------------------------------------------------------- /scripts/train_wgangp.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | outdir=./output/wgangp 4 | 5 | if [ $# -lt 2 ]; then 6 | echo "Usage: $0 gpu class[ ...]" 7 | exit 1 8 | fi 9 | gpu="$1" 10 | class="$2" 11 | shift # shift the remaining arguments 12 | shift 13 | 14 | set -e 15 | 16 | source activate shaperecon 17 | 18 | python train.py \ 19 | --net wgangp \ 20 | --canon_voxel \ 21 | --dataset shapenet \ 22 | --classes "$class" \ 23 | --batch_size 4 \ 24 | --epoch_batches 2500 \ 25 | --eval_batches 5 \ 26 | --log_time \ 27 | --optim adam \ 28 | --lr 1e-4 \ 29 | --epoch 1000 \ 30 | --vis_batches_vali 10 \ 31 | --gpu "$gpu" \ 32 | --save_net 10 \ 33 | --workers 4 \ 34 | --logdir "$outdir" \ 35 | --suffix '{classes}' \ 36 | --tensorboard \ 37 | $* 38 | 39 | source deactivate 40 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from shutil import rmtree 4 | from tqdm import tqdm 5 | import torch 6 | from options import options_test 7 | import datasets 8 | import models 9 | from util.util_print import str_error, str_stage, str_verbose 10 | import util.util_loadlib as loadlib 11 | from loggers import loggers 12 | 13 | 14 | print("Testing Pipeline") 15 | 16 | ################################################### 17 | 18 | print(str_stage, "Parsing arguments") 19 | opt = options_test.parse() 20 | opt.full_logdir = None 21 | print(opt) 22 | 23 | ################################################### 24 | 25 | print(str_stage, "Setting device") 26 | if opt.gpu == '-1': 27 | device = torch.device('cpu') 28 | else: 29 | loadlib.set_gpu(opt.gpu) 30 | device = torch.device('cuda') 31 | if opt.manual_seed is not None: 32 | loadlib.set_manual_seed(opt.manual_seed) 33 | 34 | ################################################### 35 | 36 | print(str_stage, "Setting up output directory") 37 | output_dir = opt.output_dir 38 | output_dir += ('_' + opt.suffix.format(**vars(opt))) \ 39 | if opt.suffix != '' else '' 40 | opt.output_dir = output_dir 41 | 42 | if os.path.isdir(output_dir): 43 | if opt.overwrite: 44 | rmtree(output_dir) 45 | else: 46 | raise ValueError(str_error + 47 | " %s already exists, but no overwrite flag" 48 | % output_dir) 49 | os.makedirs(output_dir) 50 | 51 | ################################################### 52 | 53 | print(str_stage, "Setting up loggers") 54 | logger_list = [ 55 | loggers.TerminateOnNaN(), 56 | ] 57 | logger = loggers.ComposeLogger(logger_list) 58 | 59 | ################################################### 60 | 61 | print(str_stage, "Setting up models") 62 | Model = models.get_model(opt.net, test=True) 63 | model = Model(opt, logger) 64 | model.to(device) 65 | model.eval() 66 | print(model) 67 | print("# model parameters: {:,d}".format(model.num_parameters())) 68 | 69 | 
################################################### 70 | 71 | print(str_stage, "Setting up data loaders") 72 | start_time = time.time() 73 | Dataset = datasets.get_dataset('test') 74 | dataset = Dataset(opt, model=model) 75 | dataloader = torch.utils.data.DataLoader( 76 | dataset, 77 | batch_size=opt.batch_size, 78 | num_workers=opt.workers, 79 | pin_memory=True, 80 | drop_last=False, 81 | shuffle=False 82 | ) 83 | n_batches = len(dataloader) 84 | dataiter = iter(dataloader) 85 | print(str_verbose, "Time spent in data IO initialization: %.2fs" % 86 | (time.time() - start_time)) 87 | print(str_verbose, "# test points: " + str(len(dataset))) 88 | print(str_verbose, "# test batches: " + str(n_batches)) 89 | 90 | ################################################### 91 | 92 | print(str_stage, "Testing") 93 | for i in tqdm(range(n_batches)): 94 | batch = next(dataiter) 95 | model.test_on_batch(i, batch) 96 | -------------------------------------------------------------------------------- /toolbox/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/__init__.py -------------------------------------------------------------------------------- /toolbox/calc_prob/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | 7 | this_file = os.path.dirname(os.path.realpath(__file__)) 8 | print(this_file) 9 | 10 | extra_compile_args = list() 11 | 12 | 13 | extra_objects = list() 14 | assert(torch.cuda.is_available()) 15 | sources = ['calc_prob/src/calc_prob.c'] 16 | headers = ['calc_prob/src/calc_prob.h'] 17 | defines = [('WITH_CUDA', True)] 18 | with_cuda = True 19 | 20 | extra_objects = ['calc_prob/src/calc_prob_kernel.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi_params = { 24 | 'headers': headers, 25 | 'sources': sources, 26 | 'define_macros': defines, 27 | 'relative_to': __file__, 28 | 'with_cuda': with_cuda, 29 | 'extra_objects': extra_objects, 30 | 'include_dirs': [os.path.join(this_file, 'calc_prob/src')], 31 | 'extra_compile_args': extra_compile_args, 32 | } 33 | 34 | 35 | if __name__ == '__main__': 36 | ext = create_extension( 37 | 'calc_prob._ext.calc_prob_lib', 38 | package=False, 39 | **ffi_params) 40 | #from setuptools import setup 41 | # setup() 42 | ext.build() 43 | 44 | # ffi.build() 45 | -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/calc_prob/calc_prob/__init__.py -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .calc_prob import CalcStopProb 2 | -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/functions/calc_prob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.autograd.function import once_differentiable 4 | from .._ext import calc_prob_lib 5 | from 
cffi import FFI 6 | ffi = FFI() 7 | 8 | 9 | class CalcStopProb(Function): 10 | @staticmethod 11 | def forward(ctx, prob_in): 12 | assert prob_in.dim() == 5 13 | assert prob_in.dtype == torch.float32 14 | assert prob_in.is_cuda 15 | stop_prob = prob_in.new(prob_in.shape) 16 | stop_prob.zero_() 17 | calc_prob_lib.calc_prob_forward(prob_in, stop_prob) 18 | ctx.save_for_backward(prob_in, stop_prob) 19 | return stop_prob 20 | 21 | @staticmethod 22 | @once_differentiable 23 | def backward(ctx, grad_in): 24 | prob_in, stop_prob = ctx.saved_tensors 25 | grad_out = grad_in.new(grad_in.shape) 26 | grad_out.zero_() 27 | stop_prob_weighted = stop_prob * grad_in 28 | calc_prob_lib.calc_prob_backward(prob_in, stop_prob_weighted, grad_out) 29 | return grad_out 30 | -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/src/calc_prob.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "calc_prob.h" 5 | #include "calc_prob_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | int calc_prob_forward(THCudaTensor* prob_in, THCudaTensor* prob_out){ 10 | int success = 0; 11 | success = calc_prob_forward_wrap(state, prob_in, prob_out); 12 | // check for errors 13 | if (!success) { 14 | THError("aborting"); 15 | } 16 | return 1; 17 | } 18 | int calc_prob_backward(THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out){ 19 | int success = 0; 20 | success = calc_prob_backward_wrap(state, prob_in, stop_prob_weighted, grad_out); 21 | // check for errors 22 | if (!success) { 23 | THError("aborting"); 24 | } 25 | return 1; 26 | } 27 | -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/src/calc_prob.h: -------------------------------------------------------------------------------- 1 | int calc_prob_forward(THCudaTensor* prob_in, THCudaTensor* prob_out); 2 | int calc_prob_backward(THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out); 3 | -------------------------------------------------------------------------------- /toolbox/calc_prob/calc_prob/src/calc_prob_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | int calc_prob_forward_wrap(THCState* state, THCudaTensor* prob_in, THCudaTensor* prob_out); 5 | int calc_prob_backward_wrap(THCState* state, THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out); 6 | #ifdef __cplusplus 7 | } 8 | #endif 9 | -------------------------------------------------------------------------------- /toolbox/calc_prob/clean.sh: -------------------------------------------------------------------------------- 1 | 2 | # ANSI color codes 3 | RS="\033[0m" # reset 4 | HC="\033[1m" # hicolor 5 | UL="\033[4m" # underline 6 | INV="\033[7m" # inverse background and foreground 7 | FBLK="\033[30m" # foreground black 8 | FRED="\033[31m" # foreground red 9 | FGRN="\033[32m" # foreground green 10 | FYEL="\033[33m" # foreground yellow 11 | FBLE="\033[34m" # foreground blue 12 | FMAG="\033[35m" # foreground magenta 13 | FCYN="\033[36m" # foreground cyan 14 | FWHT="\033[37m" # foreground white 15 | BBLK="\033[40m" # background black 16 | BRED="\033[41m" # background red 17 | BGRN="\033[42m" # background green 18 | BYEL="\033[43m" # background yellow 19 | BBLE="\033[44m" # background blue 20 | BMAG="\033[45m" # background magenta 21 | BCYN="\033[46m" # 
background cyan 22 | BWHT="\033[47m" # background white 23 | 24 | function rm_if_exist() { 25 | if [ -f "$1" ]; then 26 | rm "$1"; 27 | echo -e "${FGRN}File $1 removed${RS}" 28 | elif [ -d "$1" ]; then 29 | rm -r "$1"; 30 | echo -e "${FBLE}Directory $1 removed${RS}" 31 | #else 32 | # echo -e "${FRED}$1 not found${RS}" 33 | fi 34 | } 35 | 36 | rm_if_exist "calc_prob/src/calc_prob_kernel.cu.o" 37 | rm_if_exist "__pycache__" 38 | rm_if_exist "dist" 39 | rm_if_exist "build" 40 | rm_if_exist "pytorch_calc_stop_problility.egg-info" 41 | rm_if_exist ".cache" 42 | rm_if_exist "calc_prob/__pycache__" 43 | rm_if_exist "calc_prob/_ext" 44 | rm_if_exist "calc_prob/functions/__pycache__" 45 | rm_if_exist "calc_prob/modules/__pycache__" 46 | -------------------------------------------------------------------------------- /toolbox/calc_prob/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from setuptools import setup, find_packages 5 | 6 | import build 7 | 8 | this_file = os.path.dirname(__file__) 9 | 10 | setup( 11 | name="pytorch_calc_stop_problility", 12 | version="0.1.0", 13 | description="Pytorch extension of calcualting ray stop probability", 14 | url="https://bluhbluhbluh", 15 | author="Zhoutong Zhang", 16 | author_email="ztzhang@mit.edu", 17 | # Require cffi. 18 | install_requires=["cffi>=1.0.0"], 19 | setup_requires=["cffi>=1.0.0"], 20 | # Exclude the build files. 21 | packages=find_packages(exclude=["build", "test"]), 22 | # Package where to put the extensions. Has to be a prefix of build.py. 23 | ext_package="", 24 | # Extensions to compile. 25 | cffi_modules=[ 26 | os.path.join(this_file, "build.py:ffi") 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /toolbox/calc_prob/setup.sh: -------------------------------------------------------------------------------- 1 | echo "Add -gencode to match all the GPU architectures you have." 2 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for list of architecture." 3 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for GPU compilation based on architecture." 4 | 5 | # GPU architecture short list: 6 | # GTX 650M: 30 7 | # GTX Titan: 35 8 | # GTX Titan Black: 35 9 | # Tesla K40c: 35 10 | # GTX Titan X: 52 11 | # Titan X (Pascal): 61 12 | # GTX 1080: 61 13 | # Titan Xp: 61 14 | 15 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 16 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())") 17 | 18 | if [ "$HAS_CUDA" == "True" ]; then 19 | if ! type nvcc >/dev/null 2>&1 ; then 20 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. ' 21 | exit 1 22 | fi 23 | cd calc_prob/src 24 | HERE=$(pwd -P) 25 | cmd="nvcc -c -o calc_prob_kernel.cu.o calc_prob_kernel.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC -I ${HERE} \ 26 | -gencode arch=compute_30,code=sm_30 \ 27 | -gencode arch=compute_35,code=sm_35 \ 28 | -gencode arch=compute_52,code=sm_52 \ 29 | -gencode arch=compute_61,code=sm_61 " 30 | echo "$cmd" 31 | eval "$cmd" 32 | cd ../../ 33 | fi 34 | if [ "$1" = "package" ]; then 35 | # for install 36 | python setup.py install 37 | elif [ "$1" = "script" ]; then 38 | # for build 39 | python build.py 40 | else 41 | echo "Shouldn't be here." 
42 | fi 43 | -------------------------------------------------------------------------------- /toolbox/cam_bp/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | this_file = os.path.dirname(os.path.realpath(__file__)) 7 | print(this_file) 8 | 9 | extra_compile_args = list() 10 | 11 | extra_objects = list() 12 | assert(torch.cuda.is_available()) 13 | sources = ['cam_bp/src/back_projection.c'] 14 | headers = ['cam_bp/src/back_projection.h'] 15 | defines = [('WITH_CUDA', True)] 16 | with_cuda = True 17 | 18 | extra_objects = ['cam_bp/src/back_projection_kernel.cu.o'] 19 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 20 | 21 | ffi_params = { 22 | 'headers': headers, 23 | 'sources': sources, 24 | 'define_macros': defines, 25 | 'relative_to': __file__, 26 | 'with_cuda': with_cuda, 27 | 'extra_objects': extra_objects, 28 | 'include_dirs': [os.path.join(this_file, 'cam_bp/src')], 29 | 'extra_compile_args': extra_compile_args, 30 | } 31 | 32 | ffi = create_extension( 33 | 'cam_bp._ext.cam_bp_lib', 34 | package=True, 35 | **ffi_params 36 | ) 37 | 38 | if __name__ == '__main__': 39 | ffi = create_extension( 40 | 'cam_bp._ext.cam_bp_lib', 41 | package=False, 42 | **ffi_params) 43 | ffi.build() 44 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/__init__.py -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .cam_back_projection import CameraBackProjection 2 | from .get_surface_mask import get_surface_mask 3 | from .sperical_to_tdf import SphericalBackProjection 4 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/functions/cam_back_projection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.autograd.function import once_differentiable 4 | from .._ext import cam_bp_lib 5 | from cffi import FFI 6 | ffi = FFI() 7 | 8 | 9 | class CameraBackProjection(Function): 10 | 11 | @staticmethod 12 | def forward(ctx, depth_t, fl, cam_dist, res=128): 13 | assert depth_t.dim() == 4 14 | assert fl.dim() == 2 and fl.size(1) == depth_t.size(1) 15 | assert cam_dist.dim() == 2 and cam_dist.size(1) == depth_t.size(1) 16 | assert cam_dist.size(0) == depth_t.size(0) 17 | assert fl.size(0) == depth_t.size(0) 18 | assert depth_t.is_cuda 19 | assert fl.is_cuda 20 | assert cam_dist.is_cuda 21 | in_shape = depth_t.shape 22 | cnt = depth_t.new(in_shape[0], in_shape[1], res, res, res).zero_() 23 | tdf = depth_t.new(in_shape[0], in_shape[1], 24 | res, res, res).zero_() + 1 / res 25 | cam_bp_lib.back_projection_forward(depth_t, cam_dist, fl, tdf, cnt) 26 | # print(cnt) 27 | ctx.save_for_backward(depth_t, fl, cam_dist) 28 | ctx.cnt_forward = cnt 29 | ctx.depth_shape = in_shape 30 | return tdf 31 | 32 | @staticmethod 33 | @once_differentiable 34 | def backward(ctx, grad_output): 35 | assert grad_output.is_cuda 36 | # print(grad_output.type()) 37 | depth_t, fl, cam_dist = 
ctx.saved_tensors 38 | cnt = ctx.cnt_forward 39 | grad_depth = grad_output.new(ctx.depth_shape).zero_() 40 | grad_fl = grad_output.new( 41 | ctx.depth_shape[0], ctx.depth_shape[1]).zero_() 42 | grad_camdist = grad_output.new( 43 | ctx.depth_shape[0], ctx.depth_shape[1]).zero_() 44 | cam_bp_lib.back_projection_backward( 45 | depth_t, fl, cam_dist, cnt, grad_output, grad_depth, grad_camdist, grad_fl) 46 | return grad_depth, grad_fl, grad_camdist, None 47 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/functions/get_surface_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .._ext import cam_bp_lib 4 | from cffi import FFI 5 | ffi = FFI() 6 | 7 | 8 | def get_vox_surface_cnt(depth_t, fl, cam_dist, res=128): 9 | assert depth_t.dim() == 4 10 | assert fl.dim() == 2 and fl.size(1) == depth_t.size(1) 11 | assert cam_dist.dim() == 2 and cam_dist.size(1) == depth_t.size(1) 12 | assert cam_dist.size(0) == depth_t.size(0) 13 | assert fl.size(0) == depth_t.size(0) 14 | assert depth_t.is_cuda 15 | assert fl.is_cuda 16 | assert cam_dist.is_cuda 17 | in_shape = depth_t.shape 18 | cnt = depth_t.new(in_shape[0], in_shape[1], res, res, res).zero_() 19 | tdf = depth_t.new(in_shape[0], in_shape[1], res, 20 | res, res).zero_() + 1 / res 21 | cam_bp_lib.back_projection_forward(depth_t, cam_dist, fl, tdf, cnt) 22 | return cnt 23 | 24 | 25 | def get_surface_mask(depth_t, fl=784.4645406, cam_dist=2.0, res=128): 26 | n = depth_t.size(0) 27 | nc = depth_t.size(1) 28 | if type(fl) == float: 29 | fl_v = fl 30 | fl = torch.FloatTensor(n, nc).cuda() 31 | fl.fill_(fl_v) 32 | if type(cam_dist) == float: 33 | cmd_v = cam_dist 34 | cam_dist = torch.FloatTensor(n, nc).cuda() 35 | cam_dist.fill_(cmd_v) 36 | cnt = get_vox_surface_cnt(depth_t, fl, cam_dist, res) 37 | mask = cnt.new(n, nc, res, res, res).zero_() 38 | cam_bp_lib.get_surface_mask(depth_t, cam_dist, fl, cnt, mask) 39 | surface_vox = torch.clamp(cnt, min=0.0, max=1.0) 40 | return surface_vox, mask 41 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/functions/sperical_to_tdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | from .._ext import cam_bp_lib 6 | from cffi import FFI 7 | ffi = FFI() 8 | 9 | 10 | class SphericalBackProjection(Function): 11 | 12 | @staticmethod 13 | def forward(ctx, spherical, grid, res=128): 14 | assert spherical.dim() == 4 15 | assert grid.dim() == 5 16 | assert spherical.size(0) == grid.size(0) 17 | assert spherical.size(1) == grid.size(1) 18 | assert spherical.size(2) == grid.size(2) 19 | assert spherical.size(3) == grid.size(3) 20 | assert grid.size(4) == 3 21 | assert spherical.is_cuda 22 | assert grid.is_cuda 23 | in_shape = spherical.shape 24 | cnt = spherical.new(in_shape[0], in_shape[1], res, res, res).zero_() 25 | tdf = spherical.new(in_shape[0], in_shape[1], 26 | res, res, res).zero_() 27 | cam_bp_lib.spherical_back_proj_forward(spherical, grid, tdf, cnt) 28 | # print(cnt) 29 | ctx.save_for_backward(spherical.detach(), grid, cnt) 30 | ctx.depth_shape = in_shape 31 | return tdf, cnt 32 | 33 | @staticmethod 34 | @once_differentiable 35 | def backward(ctx, grad_output, grad_phony): 36 | assert grad_output.is_cuda 37 | assert not 
np.isnan(torch.sum(grad_output.detach())) 38 | spherical, grid, cnt = ctx.saved_tensors 39 | grad_depth = grad_output.new(ctx.depth_shape).zero_() 40 | cam_bp_lib.spherical_back_proj_backward( 41 | spherical, grid, cnt, grad_output, grad_depth) 42 | try: 43 | assert not np.isnan(torch.sum(grad_depth)) 44 | except: 45 | import pdb 46 | pdb.set_trace() 47 | return grad_depth, None, None 48 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/modules/Spherical_backproj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from ..functions import SphericalBackProjection 4 | from torch.autograd import Variable 5 | 6 | 7 | class spherical_backprojection(nn.Module): 8 | 9 | def __init__(self, grid, vox_res=128): 10 | super(camera_backprojection, self).__init__() 11 | self.vox_res = vox_res 12 | self.backprojection_layer = SphericalBackProjection() 13 | assert type(grid) == torch.FloatTensor 14 | self.grid = Variable(grid.cuda()) 15 | 16 | def forward(self, spherical): 17 | return self.backprojection_layer(spherical, self.grid, self.vox_res) 18 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/modules/__init__.py -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/modules/camera_backprojection_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from ..functions import CameraBackProjection 3 | import torch 4 | 5 | 6 | class Camera_back_projection_layer(nn.Module): 7 | def __init__(self, res=128): 8 | super(Camera_back_projection_layer, self).__init__() 9 | assert res == 128 10 | self.res = 128 11 | 12 | def forward(self, depth_t, fl=418.3, cam_dist=2.2, shift=True): 13 | n = depth_t.size(0) 14 | if type(fl) == float: 15 | fl_v = fl 16 | fl = torch.FloatTensor(n, 1).cuda() 17 | fl.fill_(fl_v) 18 | if type(cam_dist) == float: 19 | cmd_v = cam_dist 20 | cam_dist = torch.FloatTensor(n, 1).cuda() 21 | cam_dist.fill_(cmd_v) 22 | df = CameraBackProjection.apply(depth_t, fl, cam_dist, self.res) 23 | return self.shift_tdf(df) if shift else df 24 | 25 | @staticmethod 26 | def shift_tdf(input_tdf, res=128): 27 | out_tdf = 1 - res * (input_tdf) 28 | return out_tdf 29 | 30 | 31 | class camera_backprojection(nn.Module): 32 | 33 | def __init__(self, vox_res=128): 34 | super(camera_backprojection, self).__init__() 35 | self.vox_res = vox_res 36 | self.backprojection_layer = CameraBackProjection() 37 | 38 | def forward(self, depth, fl, camdist): 39 | return self.backprojection_layer(depth, fl, camdist, self.voxel_res) 40 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/src/_cam_bp_lib.abi3.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/src/_cam_bp_lib.abi3.so -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/src/back_projection.c: -------------------------------------------------------------------------------- 1 | #include 2 
| #include 3 | #include 4 | #include "back_projection.h" 5 | #include "back_projection_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | int back_projection_forward(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt){ 10 | int success = 0; 11 | success = back_projection_forward_wrap(state, depth, camdist, fl, voxel, cnt); 12 | // check for errors 13 | if (!success) { 14 | THError("aborting"); 15 | } 16 | return 1; 17 | } 18 | 19 | 20 | int back_projection_backward(THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl){ 21 | int success = 0; 22 | success = back_projection_backward_wrap(state, depth, fl, camdist, cnt, grad_in, grad_depth, grad_camdist, grad_fl); 23 | // check for errors 24 | if (!success) { 25 | THError("aborting"); 26 | } 27 | return 1; 28 | } 29 | 30 | int get_surface_mask(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask){ 31 | int success = 0; 32 | success = get_surface_mask_wrap(state, depth, camdist, fl, cnt, mask); 33 | // check for errors 34 | if (!success) { 35 | THError("aborting"); 36 | } 37 | return 1; 38 | } 39 | 40 | int spherical_back_proj_forward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt){ 41 | int success = 0; 42 | success = spherical_back_proj_forward_wrap(state, depth, grid_in, voxel, cnt); 43 | // check for errors 44 | if (!success) { 45 | THError("aborting"); 46 | } 47 | return 1; 48 | } 49 | int spherical_back_proj_backward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth){ 50 | int success = 0; 51 | success = spherical_back_proj_backward_wrap(state, depth, grid_in,cnt,grad_in,grad_depth); 52 | // check for errors 53 | if (!success) { 54 | THError("aborting"); 55 | } 56 | return 1; 57 | } 58 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/src/back_projection.h: -------------------------------------------------------------------------------- 1 | int back_projection_forward(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt); 2 | int back_projection_backward(THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl); 3 | int get_surface_mask(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask); 4 | int spherical_back_proj_forward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt); 5 | int spherical_back_proj_backward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth); 6 | -------------------------------------------------------------------------------- /toolbox/cam_bp/cam_bp/src/back_projection_kernel.h: -------------------------------------------------------------------------------- 1 | 2 | #ifdef __cplusplus 3 | extern "C" { 4 | #endif 5 | 6 | int back_projection_forward_wrap (THCState* state, THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt); 7 | int back_projection_backward_wrap (THCState* state, THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* 
grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl); 8 | int get_surface_mask_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask); 9 | int spherical_back_proj_forward_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt); 10 | int spherical_back_proj_backward_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth); 11 | #ifdef __cplusplus 12 | } 13 | #endif 14 | -------------------------------------------------------------------------------- /toolbox/cam_bp/clean.sh: -------------------------------------------------------------------------------- 1 | 2 | # ANSI color codes 3 | RS="\033[0m" # reset 4 | HC="\033[1m" # hicolor 5 | UL="\033[4m" # underline 6 | INV="\033[7m" # inverse background and foreground 7 | FBLK="\033[30m" # foreground black 8 | FRED="\033[31m" # foreground red 9 | FGRN="\033[32m" # foreground green 10 | FYEL="\033[33m" # foreground yellow 11 | FBLE="\033[34m" # foreground blue 12 | FMAG="\033[35m" # foreground magenta 13 | FCYN="\033[36m" # foreground cyan 14 | FWHT="\033[37m" # foreground white 15 | BBLK="\033[40m" # background black 16 | BRED="\033[41m" # background red 17 | BGRN="\033[42m" # background green 18 | BYEL="\033[43m" # background yellow 19 | BBLE="\033[44m" # background blue 20 | BMAG="\033[45m" # background magenta 21 | BCYN="\033[46m" # background cyan 22 | BWHT="\033[47m" # background white 23 | 24 | function rm_if_exist() { 25 | if [ -f "$1" ]; then 26 | rm "$1"; 27 | echo -e "${FGRN}File $1 removed${RS}" 28 | elif [ -d "$1" ]; then 29 | rm -r "$1"; 30 | echo -e "${FBLE}Directory $1 removed${RS}" 31 | else 32 | echo -e "${FRED}$1 not found${RS}" 33 | fi 34 | } 35 | 36 | rm_if_exist "cam_bp/src/back_projection_kernel.cu.o" 37 | rm_if_exist "__pycache__" 38 | rm_if_exist "dist" 39 | rm_if_exist "build" 40 | rm_if_exist "pytorch_camera_back_projection.egg-info" 41 | rm_if_exist ".cache" 42 | rm_if_exist "cam_bp/__pycache__" 43 | rm_if_exist "cam_bp/_ext" 44 | rm_if_exist "cam_bp/functions/__pycache__" 45 | rm_if_exist "cam_bp/modules/__pycache__" 46 | -------------------------------------------------------------------------------- /toolbox/cam_bp/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from setuptools import setup, find_packages 5 | 6 | import build 7 | 8 | this_file = os.path.dirname(__file__) 9 | 10 | setup( 11 | name="pytorch_camera_back_projection", 12 | version="0.1.0", 13 | description="Pytorch extension of back projecting depth", 14 | url="https://bluhbluhbluh", 15 | author="Zhoutong Zhang", 16 | author_email="ztzhang@mit.edu", 17 | # Require cffi. 18 | install_requires=["cffi>=1.0.0"], 19 | setup_requires=["cffi>=1.0.0"], 20 | # Exclude the build files. 21 | packages=find_packages(exclude=["build", "test"]), 22 | # Package where to put the extensions. Has to be a prefix of build.py. 23 | ext_package="", 24 | # Extensions to compile. 25 | cffi_modules=[ 26 | os.path.join(this_file, "build.py:ffi") 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /toolbox/cam_bp/setup.sh: -------------------------------------------------------------------------------- 1 | echo "Add -gencode to match all the GPU architectures you have." 
2 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for list of architecture." 3 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for GPU compilation based on architecture." 4 | 5 | # GPU architecture short list: 6 | # GTX 650M: 30 7 | # GTX Titan: 35 8 | # GTX Titan Black: 35 9 | # Tesla K40c: 35 10 | # GTX Titan X: 52 11 | # Titan X (Pascal): 61 12 | # GTX 1080: 61 13 | # Titan Xp: 61 14 | 15 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 16 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())") 17 | 18 | if [ "$HAS_CUDA" == "True" ]; then 19 | if ! type nvcc >/dev/null 2>&1 ; then 20 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. ' 21 | exit 1 22 | fi 23 | cd cam_bp/src 24 | HERE=$(pwd -P) 25 | cmd="nvcc -c -o back_projection_kernel.cu.o back_projection_kernel.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include -I ${TORCH}/lib/include/THC -I ${HERE} -I ${TORCH}/lib/include\ 26 | -gencode arch=compute_30,code=sm_30 \ 27 | -gencode arch=compute_35,code=sm_35 \ 28 | -gencode arch=compute_52,code=sm_52 \ 29 | -gencode arch=compute_61,code=sm_61" 30 | echo "$cmd" 31 | eval "$cmd" 32 | cd .. 33 | fi 34 | cd .. 35 | pwd 36 | if [ "$1" = "package" ]; then 37 | # for install 38 | python setup.py install 39 | elif [ "$1" = "script" ]; then 40 | # for build 41 | python build.py 42 | else 43 | echo "Shouldn't be here." 44 | fi 45 | -------------------------------------------------------------------------------- /toolbox/nndistance/README.md: -------------------------------------------------------------------------------- 1 | # Chamfer Distance for Pytorch 2 | Modified from [pointGAN](https://github.com/fxia22/pointGAN) 3 | 4 | ## Requirements 5 | Tested on Pytorch 0.3.1 6 | Due to syntax change in Pytorch 0.4.0, this implementation probably won't work on Pytorch 0.4.0 7 | 8 | ## Install 9 | ```bash 10 | ./clean.sh 11 | ./setup.sh script 12 | ``` 13 | Note that currently the code only supports building as script, so you'll need to put this directory under your code's root directory, where you can import using `import nndistance` 14 | 15 | ## Example 16 | Run `test.py` as an example: 17 | 18 | ```bash 19 | cp test.py .. 20 | python test.py 21 | ``` 22 | 23 | ## Usage 24 | - The function `nndistance.functions.nndistance(pts1, pts2)` return two lists of distances - the closest distance for each point in `pts1` to point cloud `pts2`, and the closest distance for each point in `pts2` to point cloud `pts1`. 25 | - For convenience, the distance here is defined as `(x1-x2)*(x1-x2) + (y1-y2)*(y1-y2) + (z1-z2)*(z1-z2)`, **without taking the square root**. 26 | - If you want to take the square root, keep in mind that in Pytorch, **the gradient of `sqrt(0)` is `nan`**, so you'll probably want to add a small `eps` before taking sqrt. 27 | - The function `nndistance.functions.nndistance_score(pts1, pts2)` return a list of scores. 28 | 29 | 30 | Internal note: this implementation gives the same result as our previously used implementation in tensorflow. 
-------------------------------------------------------------------------------- /toolbox/nndistance/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | this_file = os.path.dirname(__file__) 6 | 7 | sources = ['src/my_lib.c'] 8 | headers = ['src/my_lib.h'] 9 | defines = [] 10 | with_cuda = False 11 | 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/my_lib_cuda.c'] 15 | headers += ['src/my_lib_cuda.h'] 16 | defines += [('WITH_CUDA', None)] 17 | with_cuda = True 18 | 19 | this_file = os.path.dirname(os.path.realpath(__file__)) 20 | print(this_file) 21 | extra_objects = ['src/nnd_cuda.cu.o'] 22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 23 | 24 | ffi = create_extension( 25 | '_ext.my_lib', 26 | headers=headers, 27 | sources=sources, 28 | define_macros=defines, 29 | relative_to=__file__, 30 | with_cuda=with_cuda, 31 | extra_objects=extra_objects 32 | ) 33 | 34 | if __name__ == '__main__': 35 | ffi.build() 36 | -------------------------------------------------------------------------------- /toolbox/nndistance/clean.sh: -------------------------------------------------------------------------------- 1 | rm -rf _ext 2 | rm -f src/*.o 3 | rm -rf */__pycache__ -------------------------------------------------------------------------------- /toolbox/nndistance/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .nnd import nndistance, nndistance_w_idx, nndistance_score 2 | -------------------------------------------------------------------------------- /toolbox/nndistance/functions/nnd.py: -------------------------------------------------------------------------------- 1 | # functions/add.py 2 | import torch 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | from nndistance._ext import my_lib 6 | 7 | 8 | class NNDFunction(Function): 9 | @staticmethod 10 | def forward(ctx, xyz1, xyz2): 11 | assert xyz1.dim() == 3 and xyz2.dim() == 3 12 | assert xyz1.size(0) == xyz2.size(0) 13 | assert xyz1.size(2) == 3 and xyz2.size(2) == 3 14 | assert xyz1.is_cuda == xyz2.is_cuda 15 | assert xyz1.type().endswith('FloatTensor') and xyz2.type().endswith('FloatTensor'), 'only FloatTensor are supported for NNDistance' 16 | assert xyz1.is_contiguous() and xyz2.is_contiguous() # the CPU and GPU code are not robust and will break if the storage is not contiguous 17 | ctx.is_cuda = xyz1.is_cuda 18 | 19 | batchsize, n, _ = xyz1.size() 20 | _, m, _ = xyz2.size() 21 | dist1 = torch.zeros(batchsize, n) 22 | dist2 = torch.zeros(batchsize, m) 23 | 24 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 25 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 26 | 27 | if not xyz1.is_cuda: 28 | my_lib.nnd_forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 29 | else: 30 | dist1 = dist1.cuda() 31 | dist2 = dist2.cuda() 32 | idx1 = idx1.cuda() 33 | idx2 = idx2.cuda() 34 | my_lib.nnd_forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 35 | 36 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 37 | return dist1, dist2, idx1, idx2 38 | 39 | @staticmethod 40 | @once_differentiable 41 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 42 | """ 43 | Note that this function needs gradidx placeholders 44 | """ 45 | assert ctx.is_cuda == graddist1.is_cuda and ctx.is_cuda == graddist2.is_cuda 46 | xyz1, xyz2, idx1, idx2 
= ctx.saved_tensors 47 | graddist1 = graddist1.contiguous() 48 | graddist2 = graddist2.contiguous() 49 | assert xyz1.is_contiguous() 50 | assert xyz2.is_contiguous() 51 | assert idx1.is_contiguous() 52 | assert idx2.is_contiguous() 53 | assert graddist1.type().endswith('FloatTensor') and graddist2.type().endswith('FloatTensor'), 'only FloatTensor are supported for NNDistance' 54 | 55 | gradxyz1 = xyz1.new(xyz1.size()) 56 | gradxyz2 = xyz1.new(xyz2.size()) 57 | 58 | if not graddist1.is_cuda: 59 | my_lib.nnd_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 60 | else: 61 | my_lib.nnd_backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 62 | 63 | return gradxyz1, gradxyz2 64 | 65 | 66 | def nndistance_w_idx(xyz1, xyz2): 67 | xyz1 = xyz1.contiguous() 68 | xyz2 = xyz2.contiguous() 69 | return NNDFunction.apply(xyz1, xyz2) 70 | 71 | 72 | def nndistance(xyz1, xyz2): 73 | if xyz1.size(2) != 3: 74 | xyz1 = xyz1.transpose(1, 2) 75 | if xyz2.size(2) != 3: 76 | xyz2 = xyz2.transpose(1, 2) 77 | xyz1 = xyz1.contiguous() 78 | xyz2 = xyz2.contiguous() 79 | dist1, dist2, _, _ = NNDFunction.apply(xyz1, xyz2) 80 | return dist1, dist2 81 | 82 | 83 | def nndistance_score(xyz1, xyz2, eps=1e-10): 84 | dist1, dist2 = nndistance(xyz1, xyz2) 85 | return torch.sqrt(dist1 + eps).mean(1) + torch.sqrt(dist2 + eps).mean(1) 86 | -------------------------------------------------------------------------------- /toolbox/nndistance/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/nndistance/modules/__init__.py -------------------------------------------------------------------------------- /toolbox/nndistance/modules/nnd.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from nndistance.functions.nnd import nndistance 3 | 4 | 5 | class NNDModule(Module): 6 | def forward(self, input1, input2): 7 | return nndistance(input1, input2) 8 | -------------------------------------------------------------------------------- /toolbox/nndistance/setup.sh: -------------------------------------------------------------------------------- 1 | if [[ "$#" -ne 1 || ( "$1" != "script") ]]; then 2 | echo "Usage: ./setup.sh mode" 3 | echo "mode: script (package mode is not supported for now)" 4 | echo "package: build and install as a pip package" 5 | echo "script: build and use as a script. Must be present in local directory for import" 6 | exit 1 7 | fi 8 | 9 | echo "Add -gencode to match all the GPU architectures you have." 10 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for list of architecture." 11 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for GPU compilation based on architecture." 12 | 13 | # GPU architecture short list: 14 | # GTX 650M: 30 15 | # GTX Titan: 35 16 | # GTX Titan Black: 35 17 | # Tesla K40c: 35 18 | # GTX Titan X: 52 19 | # Titan X (Pascal): 61 20 | # GTX 1080: 61 21 | # Titan Xp: 61 22 | # Titan V: 70 23 | 24 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))") 25 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())") 26 | 27 | if [ "$HAS_CUDA" == "True" ]; then 28 | if ! type nvcc >/dev/null 2>&1 ; then 29 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. 
' 30 | exit 1 31 | fi 32 | cd src 33 | HERE=$(pwd -P) 34 | cmd="nvcc -c -o nnd_cuda.cu.o nnd_cuda.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC -I ${HERE} -I ${TORCH}/lib/include\ 35 | -gencode arch=compute_30,code=sm_30 \ 36 | -gencode arch=compute_35,code=sm_35 \ 37 | -gencode arch=compute_52,code=sm_52 \ 38 | -gencode arch=compute_61,code=sm_61" 39 | echo "$cmd" 40 | eval "$cmd" 41 | cd .. 42 | fi 43 | 44 | if [ "$1" = "package" ]; then 45 | # for install 46 | python setup.py install 47 | elif [ "$1" = "script" ]; then 48 | # for build 49 | python build.py 50 | else 51 | echo "Shouldn't be here." 52 | fi 53 | -------------------------------------------------------------------------------- /toolbox/nndistance/src/my_lib.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | extern THCState *state; 5 | 6 | void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){ 7 | for (int i=0;isize[0]; 32 | int batchsize = THCudaTensor_size(state, xyz1, 0); 33 | int n = THCudaTensor_size(state, xyz1, 1); 34 | int m = THCudaTensor_size(state, xyz2, 1); 35 | // int n = xyz1->size[1]; 36 | // int m = xyz2->size[1]; 37 | 38 | float *xyz1_data = THFloatTensor_data(xyz1); 39 | float *xyz2_data = THFloatTensor_data(xyz2); 40 | float *dist1_data = THFloatTensor_data(dist1); 41 | float *dist2_data = THFloatTensor_data(dist2); 42 | int *idx1_data = THIntTensor_data(idx1); 43 | int *idx2_data = THIntTensor_data(idx2); 44 | 45 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 46 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 47 | 48 | return 1; 49 | } 50 | 51 | 52 | int nnd_backward(THFloatTensor *xyz1, THFloatTensor *xyz2, THFloatTensor *gradxyz1, THFloatTensor *gradxyz2, THFloatTensor *graddist1, THFloatTensor *graddist2, THIntTensor *idx1, THIntTensor *idx2) { 53 | 54 | int b = THCudaTensor_size(state, xyz1, 0); 55 | int n = THCudaTensor_size(state, xyz1, 1); 56 | int m = THCudaTensor_size(state, xyz2, 1); 57 | 58 | // int b = xyz1->size[0]; 59 | // int n = xyz1->size[1]; 60 | // int m = xyz2->size[1]; 61 | 62 | //printf("%d %d %d\n", batchsize, n, m); 63 | 64 | float *xyz1_data = THFloatTensor_data(xyz1); 65 | float *xyz2_data = THFloatTensor_data(xyz2); 66 | float *gradxyz1_data = THFloatTensor_data(gradxyz1); 67 | float *gradxyz2_data = THFloatTensor_data(gradxyz2); 68 | float *graddist1_data = THFloatTensor_data(graddist1); 69 | float *graddist2_data = THFloatTensor_data(graddist2); 70 | int *idx1_data = THIntTensor_data(idx1); 71 | int *idx2_data = THIntTensor_data(idx2); 72 | 73 | 74 | for (int i=0;i 2 | #include "nnd_cuda.h" 3 | 4 | 5 | 6 | extern THCState *state; 7 | 8 | 9 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) { 10 | int success = 0; 11 | success = NmDistanceKernelLauncher(THCudaTensor_size(state, xyz1,0), 12 | THCudaTensor_size(state, xyz1,1), 13 | THCudaTensor_data(state, xyz1), 14 | THCudaTensor_size(state, xyz2,1), 15 | THCudaTensor_data(state, xyz2), 16 | THCudaTensor_data(state, dist1), 17 | THCudaIntTensor_data(state, idx1), 18 | THCudaTensor_data(state, dist2), 19 | THCudaIntTensor_data(state, idx2), 20 | THCState_getCurrentStream(state) 21 | ); 22 | //int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, 
cudaStream_t stream) 23 | if (!success) { 24 | THError("aborting"); 25 | } 26 | return 1; 27 | } 28 | 29 | 30 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1, 31 | THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) { 32 | int success = 0; 33 | success = NmDistanceGradKernelLauncher( 34 | THCudaTensor_size(state, xyz1,0), 35 | THCudaTensor_size(state, xyz1,1), 36 | THCudaTensor_data(state, xyz1), 37 | THCudaTensor_size(state, xyz2,1), 38 | THCudaTensor_data(state, xyz2), 39 | THCudaTensor_data(state, graddist1), 40 | THCudaIntTensor_data(state, idx1), 41 | THCudaTensor_data(state, graddist2), 42 | THCudaIntTensor_data(state, idx2), 43 | THCudaTensor_data(state, gradxyz1), 44 | THCudaTensor_data(state, gradxyz2), 45 | THCState_getCurrentStream(state) 46 | ); 47 | //int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream) 48 | 49 | if (!success) { 50 | THError("aborting"); 51 | } 52 | 53 | return 1; 54 | } 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /toolbox/nndistance/src/my_lib_cuda.h: -------------------------------------------------------------------------------- 1 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2); 2 | 3 | 4 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1, THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2); 5 | 6 | -------------------------------------------------------------------------------- /toolbox/nndistance/src/nnd_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "nnd_cuda.h" 3 | 4 | 5 | 6 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 7 | const int batch=512; 8 | __shared__ float buf[batch*3]; 9 | for (int i=blockIdx.x;ibest){ 121 | result[(i*n+j)]=best; 122 | result_i[(i*n+j)]=best_i; 123 | } 124 | } 125 | __syncthreads(); 126 | } 127 | } 128 | } 129 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 130 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 131 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 132 | 133 | cudaError_t err = cudaGetLastError(); 134 | if (err != cudaSuccess) { 135 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 136 | //THError("aborting"); 137 | return 0; 138 | } 139 | return 1; 140 | 141 | 142 | } 143 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 144 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 167 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 168 | 169 | cudaError_t err = cudaGetLastError(); 170 | if (err != cudaSuccess) { 171 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 172 | //THError("aborting"); 173 | return 0; 174 | } 175 | return 1; 176 | 177 | } 178 | 179 | 
-------------------------------------------------------------------------------- /toolbox/nndistance/src/nnd_cuda.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 6 | 7 | int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif -------------------------------------------------------------------------------- /toolbox/nndistance/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | try: 5 | from nndistance.modules.nnd import NNDModule 6 | except ImportError as err: 7 | raise ImportError('This file should be copied to its parent directory for import to work properly.') 8 | 9 | dist = NNDModule() 10 | 11 | p1 = torch.rand(1,50,3)*20 12 | p2 = torch.rand(1,50,3)*20 13 | # p1 = p1.int() 14 | # p1.random_(0,2) 15 | # p1 = p1.float() 16 | # p2 = p2.int() 17 | # p2.random_(0,2) 18 | p2 = p2.float() 19 | # print(p1) 20 | # print(p2) 21 | 22 | print('cpu') 23 | points1 = Variable(p1, requires_grad=True) 24 | points2 = Variable(p2, requires_grad=True) 25 | dist1, dist2 = dist(points1, points2) 26 | print(dist1, dist2) 27 | loss = torch.sum(dist1) 28 | print(loss) 29 | loss.backward() 30 | print(points1.grad, points2.grad) 31 | 32 | print('gpu') 33 | points1_cuda = Variable(p1.cuda(), requires_grad=True) 34 | points2_cuda = Variable(p2.cuda(), requires_grad=True) 35 | dist1_cuda, dist2_cuda = dist(points1_cuda, points2_cuda) 36 | print(dist1_cuda, dist2_cuda) 37 | loss_cuda = torch.sum(dist1_cuda) 38 | print(loss_cuda) 39 | loss_cuda.backward() 40 | print(points1_cuda.grad, points2_cuda.grad) 41 | 42 | print('stats:') 43 | print('loss:', loss.data[0], loss_cuda.data[0]) 44 | print('loss diff:', loss.data[0] - loss_cuda.data[0]) 45 | print('grad diff:', (points1.grad.data.cpu() - points1_cuda.grad.data.cpu()).abs().max(), (points2.grad.data.cpu() - points2_cuda.grad.data.cpu()).abs().max()) 46 | 47 | from nndistance.functions.nnd import nndistance_score 48 | print('total score:', nndistance_score(points1, points2)) 49 | -------------------------------------------------------------------------------- /toolbox/spherical_proj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .calc_prob.calc_prob.functions.calc_prob import CalcStopProb 4 | 5 | 6 | def gen_sph_grid(res=128): 7 | pi = np.pi 8 | phi = np.linspace(0, 180, res * 2 + 1)[1::2] 9 | theta = np.linspace(0, 360, res + 1)[:-1] 10 | grid = np.zeros([res, res, 3]) 11 | for idp, p in enumerate(phi): 12 | for idt, t in enumerate(theta): 13 | grid[idp, idt, 2] = np.cos((p * pi / 180)) 14 | proj = np.sin((p * pi / 180)) 15 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180) 16 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180) 17 | grid = np.reshape(grid, (1, 1, res, res, 3)) 18 | return torch.from_numpy(grid).float() 19 | 20 | 21 | def sph_pad(sph_tensor, padding_margin=16): 22 | F = torch.nn.functional 23 | pad2d = (padding_margin, padding_margin, padding_margin, padding_margin) 24 | rep_padded_sph = 
F.pad(sph_tensor, pad2d, mode='replicate') 25 | _, _, h, w = rep_padded_sph.shape 26 | rep_padded_sph[:, :, :, 0:padding_margin] = rep_padded_sph[:, :, :, w - 2 * padding_margin:w - padding_margin] 27 | rep_padded_sph[:, :, :, h - padding_margin:] = rep_padded_sph[:, :, :, padding_margin:2 * padding_margin] 28 | return rep_padded_sph 29 | 30 | 31 | class render_spherical(torch.nn.Module): 32 | def __init__(self, sph_res=128, z_res=256): 33 | super().__init__() 34 | self.sph_res = sph_res 35 | self.z_res = z_res 36 | self.gen_grid() 37 | self.calc_stop_prob = CalcStopProb().apply 38 | 39 | def gen_grid(self): 40 | res = self.sph_res 41 | z_res = self.z_res 42 | pi = np.pi 43 | phi = np.linspace(0, 180, res * 2 + 1)[1::2] 44 | theta = np.linspace(0, 360, res + 1)[:-1] 45 | grid = np.zeros([res, res, 3]) 46 | for idp, p in enumerate(phi): 47 | for idt, t in enumerate(theta): 48 | grid[idp, idt, 2] = np.cos((p * pi / 180)) 49 | proj = np.sin((p * pi / 180)) 50 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180) 51 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180) 52 | grid = np.reshape(grid * 2, (res, res, 3)) 53 | alpha = np.zeros([1, 1, z_res, 1]) 54 | alpha[0, 0, :, 0] = np.linspace(0, 1, z_res) 55 | grid = grid[:, :, np.newaxis, :] 56 | grid = grid * (1 - alpha) 57 | grid = torch.from_numpy(grid).float() 58 | depth_weight = torch.linspace(0, 1, self.z_res) 59 | self.register_buffer('depth_weight', depth_weight) 60 | self.register_buffer('grid', grid) 61 | 62 | def forward(self, vox): 63 | grid = self.grid.expand(vox.shape[0], -1, -1, -1, -1) 64 | vox = vox.permute(0, 1, 4, 3, 2) 65 | prob_sph = torch.nn.functional.grid_sample(vox, grid) 66 | prob_sph = torch.clamp(prob_sph, 1e-5, 1 - 1e-5) 67 | sph_stop_prob = self.calc_stop_prob(prob_sph) 68 | exp_depth = torch.matmul(sph_stop_prob, self.depth_weight) 69 | back_groud_prob = torch.prod(1.0 - prob_sph, dim=4) 70 | back_groud_prob = back_groud_prob * 1.0 71 | exp_depth = exp_depth + back_groud_prob 72 | return exp_depth 73 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import pandas as pd 5 | import torch 6 | from options import options_train 7 | import datasets 8 | import models 9 | from loggers import loggers 10 | from util.util_print import str_error, str_stage, str_verbose, str_warning 11 | from util import util_loadlib as loadlib 12 | 13 | 14 | ################################################### 15 | 16 | print(str_stage, "Parsing arguments") 17 | opt, unique_opt_params = options_train.parse() 18 | # Get all parse done, including subparsers 19 | print(opt) 20 | 21 | ################################################### 22 | 23 | print(str_stage, "Setting device") 24 | if opt.gpu == '-1': 25 | device = torch.device('cpu') 26 | else: 27 | loadlib.set_gpu(opt.gpu) 28 | device = torch.device('cuda') 29 | if opt.manual_seed is not None: 30 | loadlib.set_manual_seed(opt.manual_seed) 31 | 32 | ################################################### 33 | 34 | print(str_stage, "Setting up logging directory") 35 | exprdir = '{}_{}_{}'.format(opt.net, opt.dataset, opt.lr) 36 | exprdir += ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 37 | logdir = os.path.join(opt.logdir, exprdir, str(opt.expr_id)) 38 | 39 | if opt.resume == 0: 40 | if os.path.isdir(logdir): 41 | if opt.expr_id <= 0: 42 | print( 43 | str_warning, ( 44 | "Will remove Experiment %d at\n\t%s\n" 45 
| "Do you want to continue? (y/n)" 46 | ) % (opt.expr_id, logdir) 47 | ) 48 | need_input = True 49 | while need_input: 50 | response = input().lower() 51 | if response in ('y', 'n'): 52 | need_input = False 53 | if response == 'n': 54 | print(str_stage, "User decides to quit") 55 | sys.exit() 56 | os.system('rm -rf ' + logdir) 57 | else: 58 | raise ValueError(str_error + 59 | " Refuse to remove positive expr_id") 60 | os.system('mkdir -p ' + logdir) 61 | else: 62 | assert os.path.isdir(logdir) 63 | opt_f_old = os.path.join(logdir, 'opt.pt') 64 | opt = options_train.overwrite(opt, opt_f_old, unique_opt_params) 65 | 66 | # Save opt 67 | torch.save(vars(opt), os.path.join(logdir, 'opt.pt')) 68 | with open(os.path.join(logdir, 'opt.txt'), 'w') as fout: 69 | for k, v in vars(opt).items(): 70 | fout.write('%20s\t%-20s\n' % (k, v)) 71 | 72 | opt.full_logdir = logdir 73 | print(str_verbose, "Logging directory set to: %s" % logdir) 74 | 75 | ################################################### 76 | 77 | print(str_stage, "Setting up loggers") 78 | if opt.resume != 0 and os.path.isfile(os.path.join(logdir, 'best.pt')): 79 | try: 80 | prev_best_data = torch.load(os.path.join(logdir, 'best.pt')) 81 | prev_best = prev_best_data['loss_eval'] 82 | del prev_best_data 83 | except KeyError: 84 | prev_best = None 85 | else: 86 | prev_best = None 87 | best_model_logger = loggers.ModelSaveLogger( 88 | os.path.join(logdir, 'best.pt'), 89 | period=1, 90 | save_optimizer=True, 91 | save_best=True, 92 | prev_best=prev_best 93 | ) 94 | logger_list = [ 95 | loggers.TerminateOnNaN(), 96 | loggers.ProgbarLogger(allow_unused_fields='all'), 97 | loggers.CsvLogger( 98 | os.path.join(logdir, 'epoch_loss.csv'), 99 | allow_unused_fields='all' 100 | ), 101 | loggers.ModelSaveLogger( 102 | os.path.join(logdir, 'nets', '{epoch:04d}.pt'), 103 | period=opt.save_net, 104 | save_optimizer=opt.save_net_opt 105 | ), 106 | loggers.ModelSaveLogger( 107 | os.path.join(logdir, 'checkpoint.pt'), 108 | period=1, 109 | save_optimizer=True 110 | ), 111 | best_model_logger, 112 | ] 113 | if opt.log_batch: 114 | logger_list.append( 115 | loggers.BatchCsvLogger( 116 | os.path.join(logdir, 'batch_loss.csv'), 117 | allow_unused_fields='all' 118 | ) 119 | ) 120 | if opt.tensorboard: 121 | tf_logdir = os.path.join( 122 | opt.logdir, 'tensorboard', exprdir, str(opt.expr_id)) 123 | if os.path.isdir(tf_logdir) and opt.resume == 0: 124 | os.system('rm -r ' + tf_logdir) # remove previous tensorboard log if overwriting 125 | if not os.path.isdir(os.path.join(logdir, 'tensorboard')): 126 | os.symlink(tf_logdir, os.path.join(logdir, 'tensorboard')) 127 | logger_list.append( 128 | loggers.TensorBoardLogger( 129 | tf_logdir, 130 | allow_unused_fields='all' 131 | ) 132 | ) 133 | logger = loggers.ComposeLogger(logger_list) 134 | 135 | ################################################### 136 | 137 | print(str_stage, "Setting up models") 138 | Model = models.get_model(opt.net) 139 | model = Model(opt, logger) 140 | model.to(device) 141 | print(model) 142 | print("# model parameters: {:,d}".format(model.num_parameters())) 143 | 144 | initial_epoch = 1 145 | if opt.resume != 0: 146 | if opt.resume == -1: 147 | net_filename = os.path.join(logdir, 'checkpoint.pt') 148 | elif opt.resume == -2: 149 | net_filename = os.path.join(logdir, 'best.pt') 150 | else: 151 | net_filename = os.path.join( 152 | logdir, 'nets', '{epoch:04d}.pt').format(epoch=opt.resume) 153 | if not os.path.isfile(net_filename): 154 | print(str_warning, ("Network file not found for opt.resume=%d. 
" 155 | "Starting from scratch") % opt.resume) 156 | else: 157 | additional_values = model.load_state_dict(net_filename, load_optimizer='auto') 158 | try: 159 | initial_epoch += additional_values['epoch'] 160 | except KeyError as err: 161 | # Old saved model does not have epoch as additional values 162 | epoch_loss_csv = os.path.join(logdir, 'epoch_loss.csv') 163 | if opt.resume == -1: 164 | try: 165 | initial_epoch += pd.read_csv(epoch_loss_csv)['epoch'].max() 166 | except pd.errors.ParserError: 167 | with open(epoch_loss_csv, 'r') as f: 168 | lines = f.readlines() 169 | initial_epoch += max([int(l.split(',')[0]) for l in lines[1:]]) 170 | else: 171 | initial_epoch += opt.resume 172 | 173 | ################################################### 174 | 175 | print(str_stage, "Setting up data loaders") 176 | start_time = time.time() 177 | dataset = datasets.get_dataset(opt.dataset) 178 | dataset_train = dataset(opt, mode='train', model=model) 179 | dataset_vali = dataset(opt, mode='vali', model=model) 180 | dataloader_train = torch.utils.data.DataLoader( 181 | dataset_train, 182 | batch_size=opt.batch_size, 183 | shuffle=True, 184 | num_workers=opt.workers, 185 | pin_memory=True, 186 | drop_last=True 187 | ) 188 | dataloader_vali = torch.utils.data.DataLoader( 189 | dataset_vali, 190 | batch_size=opt.batch_size, 191 | num_workers=opt.workers, 192 | pin_memory=True, 193 | drop_last=True, 194 | shuffle=False 195 | ) 196 | print(str_verbose, "Time spent in data IO initialization: %.2fs" % 197 | (time.time() - start_time)) 198 | print(str_verbose, "# training points: " + str(len(dataset_train))) 199 | print(str_verbose, "# training batches per epoch: " + str(len(dataloader_train))) 200 | print(str_verbose, "# test batches: " + str(len(dataloader_vali))) 201 | 202 | ################################################### 203 | 204 | if opt.epoch > 0: 205 | print(str_stage, "Training") 206 | model.train_epoch( 207 | dataloader_train, 208 | dataloader_eval=dataloader_vali, 209 | max_batches_per_train=opt.epoch_batches, 210 | epochs=opt.epoch, 211 | initial_epoch=initial_epoch, 212 | max_batches_per_eval=opt.eval_batches, 213 | eval_at_start=opt.eval_at_start 214 | ) 215 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/util/__init__.py -------------------------------------------------------------------------------- /util/util_cam_para.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def read_cam_para_from_xml(xml_name): 5 | # azi ele only 6 | import xml.etree.ElementTree 7 | e = xml.etree.ElementTree.parse(xml_name).getroot() 8 | 9 | assert len(e.findall('sensor')) == 1 10 | for x in e.findall('sensor'): 11 | assert len(x.findall('transform')) == 1 12 | for y in x.findall('transform'): 13 | assert len(y.findall('lookAt')) == 1 14 | for z in y.findall('lookAt'): 15 | origin = np.array(z.get('origin').split(','), dtype=np.float32) 16 | # up = np.array(z.get('up').split(','), dtype=np.float32) 17 | 18 | x, y, z = origin 19 | elevation = np.arctan2(y, np.sqrt(x ** 2 + z ** 2)) 20 | azimuth = np.arctan2(x, z) + np.pi 21 | if azimuth >= np.pi: 22 | azimuth -= 2 * np.pi 23 | assert azimuth >= -np.pi and azimuth <= np.pi 24 | assert elevation >= -np.pi / 2. and elevation <= np.pi / 2. 
25 | return azimuth, elevation 26 | 27 | 28 | def raw_camparam_from_xml(path, pose="lookAt"): 29 | import xml.etree.ElementTree as ET 30 | tree = ET.parse(path) 31 | elm = tree.find("./sensor/transform/" + pose) 32 | camparam = elm.attrib 33 | origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',') 34 | target = np.fromstring(camparam['target'], dtype=np.float32, sep=',') 35 | up = np.fromstring(camparam['up'], dtype=np.float32, sep=',') 36 | height = int( 37 | tree.find("./sensor/film/integer[@name='height']").attrib['value']) 38 | width = int( 39 | tree.find("./sensor/film/integer[@name='width']").attrib['value']) 40 | 41 | camparam = dict() 42 | camparam['origin'] = origin 43 | camparam['up'] = up 44 | camparam['target'] = target 45 | camparam['height'] = height 46 | camparam['width'] = width 47 | return camparam 48 | 49 | 50 | def get_object_rotation(xml_path, style='zup'): 51 | style_set = ['yup', 'zup', 'spherical_proj'] 52 | assert(style in style_set) 53 | camparam = raw_camparam_from_xml(xml_path) 54 | if style == 'zup': 55 | Rx = camparam['target'] - camparam['origin'] 56 | up = camparam['up'] 57 | Rz = np.cross(Rx, up) 58 | Ry = np.cross(Rz, Rx) 59 | Rx /= np.linalg.norm(Rx) 60 | Ry /= np.linalg.norm(Ry) 61 | Rz /= np.linalg.norm(Rz) 62 | R = np.array([Rx, Ry, Rz]) 63 | R_coord = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) 64 | R = R_coord @ R 65 | R = R @ R_coord.transpose() 66 | elif style == 'yup': 67 | Rx = camparam['target'] - camparam['origin'] 68 | up = camparam['up'] 69 | Rz = np.cross(Rx, up) 70 | Ry = np.cross(Rz, Rx) 71 | Rx /= np.linalg.norm(Rx) 72 | Ry /= np.linalg.norm(Ry) 73 | Rz /= np.linalg.norm(Rz) 74 | #print(Rx, Ry, Rz) 75 | # no transpose needed! 76 | R = np.array([Rx, Ry, Rz]) 77 | elif style == 'spherical_proj': 78 | Rx = camparam['target'] - camparam['origin'] 79 | up = camparam['up'] 80 | Rz = np.cross(Rx, up) 81 | Ry = np.cross(Rz, Rx) 82 | Rx /= np.linalg.norm(Rx) 83 | Ry /= np.linalg.norm(Ry) 84 | Rz /= np.linalg.norm(Rz) 85 | #print(Rx, Ry, Rz) 86 | # no transpose needed! 87 | R = np.array([Rx, Ry, Rz]) 88 | 89 | raise NotImplementedError 90 | return R 91 | 92 | 93 | def get_object_rotation_translation(xml_path, style='zup'): 94 | pass 95 | 96 | 97 | def _devide_into_section(angle, num_section, lower_bound, upper_bound): 98 | rst = np.zeros(num_section) 99 | per_section_size = (upper_bound - lower_bound) / num_section 100 | angle -= per_section_size / 2 101 | if angle < lower_bound: 102 | angle += upper_bound - lower_bound 103 | idx = int((angle - lower_bound) / per_section_size) 104 | rst[idx] = 1 105 | return rst 106 | 107 | 108 | def _section_to_angle(idx, num_section, lower_bound, upper_bound): 109 | per_section_size = (upper_bound - lower_bound) / num_section 110 | 111 | angle = (idx + 0.5) * per_section_size + lower_bound 112 | angle += per_section_size / 2 113 | if angle > upper_bound: 114 | angle -= upper_bound - lower_bound 115 | return angle 116 | 117 | 118 | def azimuth_to_onehot(azimuth, num_azimuth): 119 | return _devide_into_section(azimuth, num_azimuth, -np.pi, np.pi) 120 | 121 | 122 | def elevation_to_onehot(elevation, num_elevation): 123 | return _devide_into_section(elevation, num_elevation, -np.pi / 2., np.pi / 2.) 124 | 125 | 126 | def onehot_to_azimuth(v, num_azimuth): 127 | idx = np.argmax(v) 128 | return _section_to_angle(idx, num_azimuth, -np.pi, np.pi) 129 | 130 | 131 | def onehot_to_elevation(v, num_elevation): 132 | idx = np.argmax(v) 133 | return _section_to_angle(idx, num_elevation, -np.pi / 2., np.pi / 2.) 
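# The helpers above quantize an angle into num_section equal bins whose boundaries
# are offset by half a bin width; _section_to_angle returns the bin centre, so a
# round trip through the one-hot encoding moves an angle by at most half a bin.
# A minimal sketch of that round trip (illustrative values only):
#     onehot = azimuth_to_onehot(0.3, 24)       # 24 bins over [-pi, pi)
#     approx = onehot_to_azimuth(onehot, 24)    # pi / 12 ~= 0.262
#     assert abs(0.3 - approx) < np.pi / 24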
134 | 135 | 136 | if __name__ == '__main__': 137 | num_azimuth = 24 138 | num_elevation = 12 139 | for i in range(num_azimuth): 140 | rst = np.zeros(num_azimuth) 141 | rst[i] = 1 142 | print(onehot_to_azimuth(rst, num_azimuth)) 143 | 144 | ''' 145 | for i in range(100): 146 | angle = (np.random.rand() - 0.5) * np.pi * 2 147 | print(angle, np.argmax(azimuth_to_onehot(angle, 24)), onehot_to_azimuth(azimuth_to_onehot(angle, 24), 24)) 148 | assert np.abs(angle - onehot_to_azimuth(azimuth_to_onehot(angle, 24), 24)) < 2 * np.pi / 24 149 | ''' 150 | -------------------------------------------------------------------------------- /util/util_camera.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.misc import imresize 3 | from numba import jit 4 | 5 | 6 | @jit 7 | def calc_ptnum(triangle, density): 8 | pt_num_tr = np.zeros(len(triangle)).astype(int) 9 | pt_num_total = 0 10 | for tr_id, tr in enumerate(triangle): 11 | a = np.linalg.norm(np.cross(tr[1] - tr[0], tr[2] - tr[0])) / 2 12 | ptnum = max(int(a * density), 1) 13 | pt_num_tr[tr_id] = ptnum 14 | pt_num_total += ptnum 15 | return pt_num_tr, pt_num_total 16 | 17 | 18 | class Camera(): 19 | # camera coordinates: y up, z forward, x right. 20 | # consistent with blender definitions. 21 | # res = [w,h] 22 | def __init__(self): 23 | self.position = np.array([1.6, 0, 0]) 24 | self.rx = np.array([0, 1, 0]) 25 | self.ry = np.array([0, 0, 1]) 26 | self.rz = np.array([1, 0, 0]) 27 | self.res = [800, 600] 28 | self.focal_length = 0.05 29 | # set the diagnal to be 35mm film's diagnal 30 | self.set_diagal((0.036**2 + 0.024**2)**0.5) 31 | 32 | def rotate(self, rot_mat): 33 | self.rx = rot_mat[:, 0] 34 | self.ry = rot_mat[:, 1] 35 | self.rz = rot_mat[:, 2] 36 | 37 | def move_cam(self, new_pos): 38 | self.position = new_pos 39 | 40 | def set_pose(self, inward, up): 41 | self.rx = np.cross(up, inward) 42 | self.ry = np.array(up) 43 | self.rz = np.array(inward) 44 | self.rx /= np.linalg.norm(self.rx) 45 | self.ry /= np.linalg.norm(self.ry) 46 | self.rz /= np.linalg.norm(self.rz) 47 | 48 | def set_diagal(self, diag): 49 | h_relative = self.res[1] / self.res[0] 50 | self.sensor_width = np.sqrt(diag**2 / (1 + h_relative**2)) 51 | 52 | def lookat(self, orig, target, up): 53 | self.position = np.array(orig) 54 | target = np.array(target) 55 | inward = self.position - target 56 | right = np.cross(up, inward) 57 | up = np.cross(inward, right) 58 | self.set_pose(inward, up) 59 | 60 | def set_cam_from_mitsuba(self, path): 61 | camparam = util.cam_from_mitsuba(path) 62 | self.lookat(orig=camparam['origin'], 63 | up=camparam['up'], target=camparam['target']) 64 | self.res = [camparam['width'], camparam['height']] 65 | self.focal_length = 0.05 66 | # set the diagnal to be 35mm film's diagnal 67 | self.set_diagal((0.036**2 + 0.024**2)**0.5) 68 | 69 | def project_point(self, pt): 70 | # project global point to image coordinates in pixels (float not 71 | # integer). 
72 | res = self.res 73 | rel = np.array(pt) - self.position 74 | depth = -np.dot(rel, self.rz) 75 | if rel.ndim != 1: 76 | depth = depth.reshape([np.size(depth, axis=0), 1]) 77 | rel_plane = rel * self.focal_length / depth 78 | rel_width = np.dot(rel_plane, self.rx) 79 | rel_height = np.dot(rel_plane, self.ry) 80 | topleft = np.array([-self.sensor_width / 2, 81 | self.sensor_width * (res[1] / res[0]) / 2]) 82 | pix_size = self.sensor_width / res[0] 83 | topleft += np.array([pix_size / 2, -pix_size / 2]) 84 | im_pix_x = (topleft[1] - rel_height) / pix_size 85 | im_pix_y = (rel_width - topleft[0]) / pix_size 86 | return im_pix_x, im_pix_y 87 | 88 | def project_depth(self, pt, depth_type='ray'): 89 | if depth_type == 'ray': 90 | if np.array(pt).ndim == 1: 91 | return np.linalg.norm(pt - self.position) 92 | return np.linalg.norm(pt - self.position, axis=1) 93 | else: 94 | return np.dot(pt - self.position, -self.rz) 95 | 96 | def pack(self): 97 | params = [] 98 | params += self.res 99 | params += [self.sensor_width] 100 | params += self.position.tolist() 101 | params += self.rx.tolist() 102 | params += self.ry.tolist() 103 | params += self.rz.tolist() 104 | params += [self.focal_length] 105 | return params 106 | 107 | 108 | class tsdf_renderer: 109 | def __init__(self): 110 | self.camera = Camera() 111 | self.depth = [] 112 | 113 | def load_depth_map_npy(self, path): 114 | self.depth = np.load(path) 115 | 116 | def back_project_ptcloud(self, upsample=1.0, depth_type='ray'): 117 | if not self.check_valid(): 118 | return 119 | mask = np.where(self.depth < 0, 0, 1) 120 | depth = imresize(self.depth, upsample, mode='F', interp='bilinear') 121 | up_mask = imresize(mask, upsample, mode='F', interp='bilinear') 122 | up_mask = np.where(up_mask < 1, 0, 1) 123 | ind = np.where(up_mask == 0) 124 | depth[ind] = -1 125 | # res = self.camera.res 126 | res = np.array([0, 0]) 127 | res[0] = np.shape(depth)[1] # width 128 | res[1] = np.shape(depth)[0] # height 129 | self.check_depth = np.zeros([res[1], res[0]], dtype=np.float32) - 1 130 | pt_pos = np.where(up_mask == 1) 131 | ptnum = len(pt_pos[0]) 132 | ptcld = np.zeros([ptnum, 3]) 133 | half_width = self.camera.sensor_width / 2 134 | half_height = half_width * res[1] / res[0] 135 | pix_size = self.camera.sensor_width / res[0] 136 | top_left = self.camera.position \ 137 | - self.camera.focal_length * self.camera.rz\ 138 | - half_width * self.camera.rx\ 139 | + half_height * self.camera.ry 140 | 141 | for x in range(ptnum): 142 | height_id = pt_pos[0][x] 143 | width_id = pt_pos[1][x] 144 | pix_depth = depth[height_id, width_id] 145 | pix_coord = - (height_id + 0.5) * pix_size * self.camera.ry\ 146 | + (width_id + 0.5) * pix_size * self.camera.rx\ 147 | + top_left 148 | pix_rel = pix_coord - self.camera.position 149 | if depth_type == 'plane': 150 | ptcld_pos = (pix_rel)\ 151 | * (pix_depth / self.camera.focal_length) \ 152 | + self.camera.position 153 | back_project_depth = -np.dot(pix_rel, self.camera.rz) 154 | else: 155 | ptcld_pos = (pix_rel / np.linalg.norm(pix_rel))\ 156 | * (pix_depth) + self.camera.position 157 | back_project_depth = np.linalg.norm( 158 | ptcld_pos - self.camera.position) 159 | ptcld[x, :] = ptcld_pos 160 | self.check_depth[height_id, width_id] = back_project_depth 161 | self.ptcld = ptcld 162 | self.pt_pos = pt_pos 163 | 164 | def check_valid(self, warning=True): 165 | if self.depth == []: 166 | print('No depth map available!') 167 | return False 168 | shape = np.shape(self.depth) 169 | if warning and (shape[0] != self.camera.res[1] or 
shape[1] != self.camera.res[0]): 170 | print('depth map and camera resolution mismatch!') 171 | print('camera: {}'.format(self.camera.res)) 172 | print('depth: {}'.format(shape)) 173 | return True 174 | return True 175 | -------------------------------------------------------------------------------- /util/util_loadlib.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from .util_print import str_warning, str_verbose 3 | 4 | 5 | def set_gpu(gpu, check=True): 6 | import os 7 | _check_gpu(gpu) 8 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | cudnn.benchmark = True 12 | if check: 13 | if not _check_gpu_setting_in_use(gpu): 14 | print('[Warning] gpu setting overwritten. torch.cuda may be initialized before running this function.') 15 | 16 | 17 | def _check_gpu_setting_in_use(gpu): 18 | ''' 19 | check that CUDA_VISIBLE_DEVICES is actually working 20 | by starting a clean thread with the same CUDA_VISIBLE_DEVICES 21 | ''' 22 | import subprocess 23 | output = subprocess.check_output('CUDA_VISIBLE_DEVICES=%s python -c "import torch; print(torch.cuda.device_count())"' % gpu, shell=True) 24 | output = output.decode().strip() 25 | import torch 26 | return torch.cuda.device_count() == int(output) 27 | 28 | 29 | def _check_gpu(gpu): 30 | msg = subprocess.check_output('nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,nounits,noheader -i %s' % (gpu,), shell=True) 31 | msg = msg.decode('utf-8') 32 | all_ok = True 33 | for line in msg.split('\n'): 34 | if line == '': 35 | break 36 | stats = [x.strip() for x in line.split(',')] 37 | gpu = stats[0] 38 | util = int(stats[1]) 39 | mem_used = int(stats[2]) 40 | if util > 10 or mem_used > 1000: # util in percentage and mem_used in MiB 41 | print(str_warning, 'Designated GPU in use: id=%s, util=%d%%, memory in use: %d MiB' % (gpu, util, mem_used)) 42 | all_ok = False 43 | if all_ok: 44 | print(str_verbose, 'All designated GPU(s) free to use. ') 45 | 46 | 47 | def set_manual_seed(seed): 48 | import random 49 | random.seed(seed) 50 | try: 51 | import numpy as np 52 | np.random.seed(seed) 53 | except ImportError as err: 54 | print('Numpy not found. Random seed for numpy not set. ') 55 | try: 56 | import torch 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | except ImportError as err: 60 | print('Pytorch not found. Random seed for pytorch not set. 
') 61 | -------------------------------------------------------------------------------- /util/util_print.py: -------------------------------------------------------------------------------- 1 | class bcolors: 2 | HEADER = '\033[95m' 3 | OKBLUE = '\033[94m' 4 | OKGREEN = '\033[92m' 5 | WARNING = '\033[93m' 6 | FAIL = '\033[91m' 7 | ENDC = '\033[0m' 8 | BOLD = '\033[1m' 9 | UNDERLINE = '\033[4m' 10 | 11 | 12 | str_stage = bcolors.OKBLUE + '==>' + bcolors.ENDC 13 | str_verbose = bcolors.OKGREEN + '[Verbose]' + bcolors.ENDC 14 | str_warning = bcolors.WARNING + '[Warning]' + bcolors.ENDC 15 | str_error = bcolors.FAIL + '[Error]' + bcolors.ENDC 16 | -------------------------------------------------------------------------------- /util/util_reproj.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | def cross_prod(u, v): 8 | # Cross pruduct between a set of vectors and a vector 9 | if len(u.size()) == 2: 10 | i = u[:, 1] * v[2] - u[:, 2] * v[1] 11 | j = u[:, 2] * v[0] - u[:, 0] * v[2] 12 | k = u[:, 0] * v[1] - u[:, 1] * v[0] 13 | return torch.stack((i, j, k), 1) 14 | elif len(u.size()) == 3: 15 | i = u[:, :, 1] * v[2] - u[:, :, 2] * v[1] 16 | j = u[:, :, 2] * v[0] - u[:, :, 0] * v[2] 17 | k = u[:, :, 0] * v[1] - u[:, :, 1] * v[0] 18 | return torch.stack((i, j, k), 2) 19 | raise Exception() 20 | 21 | 22 | def criterion_single(v, x, x_0, n_0, l, alpha=np.sqrt(2) / 2, beta=1, gamma=1.): 23 | v = v.view(-1) 24 | x = x.view(-1, 3) 25 | n_0 /= torch.sum(n_0 ** 2) 26 | 27 | # Find the voxel which is nearest to x_0 28 | _, index = torch.min(torch.sum((x - x_0) ** 2, dim=1), dim=0) 29 | i_0 = index.data.cpu().numpy()[0] 30 | 31 | # loss for (i_0, j_0, k_0) 32 | loss_1 = (1 - v[i_0]) ** 2 33 | 34 | # loss for others 35 | d = torch.sum(cross_prod((x - x_0), n_0) ** 2, dim=1) ** 0.5 36 | mask_1 = (d < alpha * l).float() 37 | mask_2 = torch.ones(*v.size()) 38 | mask_2[i_0] = 0 39 | mask_2 = Variable(mask_2.cuda()) 40 | loss_2 = torch.sum((gamma * (1 - d / (alpha * l)) ** beta * v ** 2) * mask_1 * mask_2) 41 | 42 | return loss_1 + loss_2 43 | 44 | 45 | def criterion(v, x, x_0, n_0, l, alpha=np.sqrt(2) / 2, beta=1, gamma=1.): 46 | n_sample = x_0.size(0) 47 | v = v.view(-1) 48 | x = x.view(-1, 3) 49 | n_0 /= torch.sum(n_0 ** 2) 50 | 51 | # Find the voxel which is nearest to x_0 52 | x_repeat = x.view(x.size(0), 1, x.size(1)).repeat(1, n_sample, 1) 53 | x_sub = x_repeat - x_0 54 | _, index = torch.min(torch.sum(x_sub ** 2, dim=2), dim=0) 55 | i_0 = index.data.cpu().numpy() 56 | 57 | # loss for (i_0, j_0, k_0) 58 | loss_1 = Variable(torch.zeros(1).cuda()) 59 | for i in range(n_sample): 60 | loss_1 += (1 - v[i_0[i]]) ** 2 61 | 62 | # loss for others 63 | d = torch.sum(cross_prod(x_sub, n_0) ** 2, dim=2) ** 0.5 64 | mask_1 = (d < alpha * l).float() 65 | mask_2 = torch.ones(v.size(0), n_sample) 66 | for i in range(n_sample): 67 | mask_2[i_0[i]][i] = 0 68 | mask_2 = Variable(mask_2.cuda()) 69 | v_repeat = v.view(v.size(0), 1).repeat(1, n_sample) 70 | loss_2 = torch.sum((gamma * (1 - d / (alpha * l)) ** beta * v_repeat ** 2) * mask_1 * mask_2) 71 | return loss_2 72 | 73 | 74 | if __name__ == '__main__': 75 | torch.manual_seed(70) 76 | n_sample = 90 77 | N = 128 78 | l = 1. 
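    # Smoke test: a random 128^3 voxel grid v (tracked for gradients), a per-voxel
    # coordinate tensor x, and n_sample surface points x_0 sharing one normal n_0;
    # criterion() is timed and v.grad is printed to confirm gradients flow back.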
79 | v = Variable(torch.rand(N, N, N).cuda(), requires_grad=True) 80 | x = Variable(torch.rand(N, N, N, 3).cuda()) 81 | x_0 = Variable(torch.rand(n_sample, 3).cuda()) 82 | n_0 = Variable(torch.rand(3).cuda()) 83 | 84 | start = time.time() 85 | 86 | loss = criterion(v, x, x_0, n_0, l) 87 | 88 | ''' 89 | loss = Variable(torch.zeros(1).cuda()) 90 | for i in range(n_sample): 91 | loss += criterion_single(v, x, x_0[i], n_0, l) 92 | ''' 93 | 94 | loss.backward() 95 | print(v.grad[0, 0, 0]) 96 | 97 | end = time.time() 98 | print(end - start) 99 | 100 | 101 | u = Variable(torch.rand(N, 3).cuda()) 102 | v = Variable(torch.rand(3).cuda()) 103 | # print(cross_prod(u, v)) 104 | 105 | # print(np.cross(u.data.cpu().numpy()[0], v.data.cpu().numpy())) 106 | -------------------------------------------------------------------------------- /util/util_sph.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | from util.util_img import depth_to_mesh_df, resize 3 | from skimage import measure 4 | import numpy as np 5 | 6 | 7 | def render_model(mesh, sgrid): 8 | index_tri, index_ray, loc = mesh.ray.intersects_id( 9 | ray_origins=sgrid, ray_directions=-sgrid, multiple_hits=False, return_locations=True) 10 | loc = loc.reshape((-1, 3)) 11 | 12 | grid_hits = sgrid[index_ray] 13 | dist = np.linalg.norm(grid_hits - loc, axis=-1) 14 | dist_im = np.ones(sgrid.shape[0]) 15 | dist_im[index_ray] = dist 16 | im = dist_im 17 | return im 18 | 19 | 20 | def make_sgrid(b, alpha, beta, gamma): 21 | res = b * 2 22 | pi = np.pi 23 | phi = np.linspace(0, 180, res * 2 + 1)[1::2] 24 | theta = np.linspace(0, 360, res + 1)[:-1] 25 | grid = np.zeros([res, res, 3]) 26 | for idp, p in enumerate(phi): 27 | for idt, t in enumerate(theta): 28 | grid[idp, idt, 2] = np.cos((p * pi / 180)) 29 | proj = np.sin((p * pi / 180)) 30 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180) 31 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180) 32 | grid = np.reshape(grid, (res * res, 3)) 33 | return grid 34 | 35 | 36 | def render_spherical(data, mask, obj_path=None, debug=False): 37 | depth_im = data['depth'][0, 0, :, :] 38 | th = data['depth_minmax'] 39 | depth_im = resize(depth_im, 480, 'vertical') 40 | im = resize(mask, 480, 'vertical') 41 | gt_sil = np.where(im > 0.95, 1, 0) 42 | depth_im = depth_im * gt_sil 43 | depth_im = depth_im[:, :, np.newaxis] 44 | b = 64 45 | tdf = depth_to_mesh_df(depth_im, th, False, 1.0, 2.2) 46 | try: 47 | verts, faces, normals, values = measure.marching_cubes_lewiner( 48 | tdf, 0.999 / 128, spacing=(1 / 128, 1 / 128, 1 / 128)) 49 | mesh = trimesh.Trimesh(vertices=verts - 0.5, faces=faces) 50 | sgrid = make_sgrid(b, 0, 0, 0) 51 | im_depth = render_model(mesh, sgrid) 52 | im_depth = im_depth.reshape(2 * b, 2 * b) 53 | im_depth = np.where(im_depth > 1, 1, im_depth) 54 | except: 55 | im_depth = np.ones([128, 128]) 56 | return im_depth 57 | return im_depth 58 | -------------------------------------------------------------------------------- /util/util_xml_to_cam_params.py: -------------------------------------------------------------------------------- 1 | 2 | from glob import glob 3 | import re 4 | import argparse 5 | import numpy as np 6 | from pathlib import Path 7 | import os 8 | 9 | def raw_camparam_from_xml(path, pose="lookAt"): 10 | import xml.etree.ElementTree as ET 11 | tree = ET.parse(path) 12 | elm = tree.find("./sensor/transform/" + pose) 13 | camparam = elm.attrib 14 | origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',') 15 | target = 
np.fromstring(camparam['target'], dtype=np.float32, sep=',') 16 | up = np.fromstring(camparam['up'], dtype=np.float32, sep=',') 17 | height = int( 18 | tree.find("./sensor/film/integer[@name='height']").attrib['value']) 19 | width = int( 20 | tree.find("./sensor/film/integer[@name='width']").attrib['value']) 21 | 22 | camparam = dict() 23 | camparam['origin'] = origin 24 | camparam['up'] = up 25 | camparam['target'] = target 26 | camparam['height'] = height 27 | camparam['width'] = width 28 | return camparam 29 | 30 | def get_cam_pos(origin, target, up): 31 | inward = origin - target 32 | right = np.cross(up, inward) 33 | up = np.cross(inward, right) 34 | rx = np.cross(up, inward) 35 | ry = np.array(up) 36 | rz = np.array(inward) 37 | rx /= np.linalg.norm(rx) 38 | ry /= np.linalg.norm(ry) 39 | rz /= np.linalg.norm(rz) 40 | 41 | rot = np.stack([ 42 | rx, 43 | ry, 44 | -rz 45 | ], axis=0) 46 | 47 | 48 | aff = np.concatenate([ 49 | np.eye(3), -origin[:,None] 50 | ], axis=1) 51 | 52 | 53 | ext = np.matmul(rot, aff) 54 | 55 | result = np.concatenate( 56 | [ext, np.array([[0,0,0,1]])], axis=0 57 | ) 58 | 59 | 60 | 61 | return result 62 | 63 | 64 | 65 | def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir): 66 | depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png'))) 67 | cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths] 68 | 69 | 70 | for i, (f, pth) in enumerate(zip(cam_ext, depths)): 71 | if not os.path.exists(f): 72 | continue 73 | params=raw_camparam_from_xml(f) 74 | origin, target, up, width, height = params['origin'], params['target'], params['up'],\ 75 | params['width'], params['height'] 76 | 77 | ext_matrix = get_cam_pos(origin, target, up) 78 | 79 | ##### 80 | diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5 81 | focal_length = 0.05 82 | res = [480, 480] 83 | h_relative = (res[1] / res[0]) 84 | sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2)) 85 | pix_size = sensor_width / res[0] 86 | 87 | K = np.array([ 88 | [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2], 89 | [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2], 90 | [0, 0, 1] 91 | ]) 92 | 93 | np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K) 94 | 95 | 96 | def main(opt): 97 | dataroot_dir = Path(opt.dataroot) 98 | 99 | leaf_subdirs = [] 100 | 101 | for dirpath, dirnames, filenames in os.walk(dataroot_dir): 102 | if (not dirnames) and opt.mitsuba_xml_root not in dirpath: 103 | leaf_subdirs.append(dirpath) 104 | 105 | 106 | 107 | for k, dir_ in enumerate(leaf_subdirs): 108 | print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_)) 109 | 110 | convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root) 111 | 112 | 113 | 114 | 115 | if __name__ == '__main__': 116 | args = argparse.ArgumentParser() 117 | args.add_argument('--dataroot', type=str, help='GenRe data root. Absolute path is recommanded.') 118 | # e.g. '/root/.../data/shapenet/' 119 | args.add_argument('--mitsuba_xml_root', type=str, help='XML directory root. Absolute path is recommanded.') 120 | # e.g. 
'/root/.../data/genre-xml_v2/' 121 | opt = args.parse_args() 122 | 123 | main(opt) 124 | -------------------------------------------------------------------------------- /visualize/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "voxel": { 3 | "isosurf_thres": 0.3 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /visualize/visualizer.py: -------------------------------------------------------------------------------- 1 | from os.path import join, dirname 2 | from os import makedirs 3 | from shutil import copyfile 4 | from multiprocessing import Pool 5 | import atexit 6 | import json 7 | import numpy as np 8 | from skimage import measure 9 | from util.util_img import imwrite_wrapper 10 | 11 | 12 | class Visualizer(): 13 | """ 14 | Unified Visulization Worker 15 | """ 16 | paths = [ 17 | 'rgb_path', 18 | 'silhou_path', 19 | 'depth_path', 20 | 'normal_path', 21 | ] 22 | imgs = [ 23 | 'rgb', 24 | 'pred_depth', 25 | 'pred_silhou', 26 | 'pred_normal', 27 | ] 28 | voxels = [ 29 | 'pred_voxel_noft', 30 | 'pred_voxel', 31 | 'gen_voxel', 32 | ] # will go through sigmoid 33 | txts = [ 34 | 'gt_depth_minmax', 35 | 'pred_depth_minmax', 36 | 'disc', 37 | 'scores' 38 | ] 39 | sphmaps = [ 40 | 'pred_spherical_full', 41 | 'pred_spherical_partial', 42 | 'gt_spherical_full', 43 | ] 44 | voxels_gt = [ 45 | 'pred_proj_depth', 46 | 'gt_voxel', 47 | 'pred_proj_sph_full', 48 | ] 49 | 50 | def __init__(self, n_workers=4, param_f=None): 51 | if n_workers == 0: 52 | pool = None 53 | elif n_workers > 0: 54 | pool = Pool(n_workers) 55 | else: 56 | raise ValueError(n_workers) 57 | self.pool = pool 58 | if param_f: 59 | self.param_f = param_f 60 | else: 61 | self.param_f = join(dirname(__file__), 'config.json') 62 | 63 | def cleanup(): 64 | if pool: 65 | pool.close() 66 | pool.join() 67 | atexit.register(cleanup) 68 | 69 | def visualize(self, pack, batch_idx, outdir): 70 | if self.pool: 71 | self.pool.apply_async( 72 | self._visualize, 73 | [pack, batch_idx, self.param_f, outdir], 74 | error_callback=self._error_callback 75 | ) 76 | else: 77 | self._visualize(pack, batch_idx, self.param_f, outdir) 78 | 79 | @classmethod 80 | def _visualize(cls, pack, batch_idx, param_f, outdir): 81 | makedirs(outdir, exist_ok=True) 82 | 83 | # Dynamically read parameters from disk 84 | #param_dict = cls._read_params(param_f) 85 | voxel_isosurf_th = 0.25 # param_dict['voxel']['isosurf_thres'] 86 | 87 | batch_size = cls._get_batch_size(pack) 88 | instance_cnt = batch_idx * batch_size 89 | counter = 0 90 | for k in cls.paths: 91 | prefix = '{:04d}_%02d_' % counter + k.split('_')[0] + '.png' 92 | cls._cp_img(pack.get(k), join(outdir, prefix), instance_cnt) 93 | counter += 1 94 | for k in cls.imgs: 95 | prefix = '{:04d}_%02d_' % counter + k + '.png' 96 | cls._vis_img(pack.get(k), join(outdir, prefix), instance_cnt) 97 | counter += 1 98 | for k in cls.voxels_gt: 99 | prefix = '{:04d}_%02d_' % counter + k + '.obj' 100 | cls._vis_voxel(pack.get(k), join(outdir, prefix), instance_cnt, 101 | voxel_isosurf_th, False) 102 | counter += 1 103 | for k in cls.voxels: 104 | prefix = '{:04d}_%02d_' % counter + k + '.obj' 105 | cls._vis_voxel(pack.get(k), join(outdir, prefix), instance_cnt, 106 | voxel_isosurf_th) 107 | counter += 1 108 | for k in cls.txts: 109 | prefix = '{:04d}_%02d_' % counter + k + '.txt' 110 | cls._vis_txt(pack.get(k), join(outdir, prefix), instance_cnt) 111 | counter += 1 112 | for k in cls.sphmaps: 113 | prefix = '{:04d}_%02d_' % 
counter + k + '.png' 114 | cls._vis_sph(pack.get(k), join(outdir, prefix), instance_cnt) 115 | counter += 1 116 | 117 | @staticmethod 118 | def _read_params(param_f): 119 | with open(param_f, 'r') as h: 120 | param_dict = json.load(h) 121 | return param_dict 122 | 123 | @staticmethod 124 | def _get_batch_size(pack): 125 | batch_size = None 126 | for v in pack.values(): 127 | if hasattr(v, 'shape'): 128 | if batch_size is None or batch_size == 0: 129 | batch_size = v.shape[0] 130 | else: 131 | assert batch_size == v.shape[0] 132 | return batch_size 133 | 134 | @staticmethod 135 | def _sigmoid(x): 136 | return 1 / (1 + np.exp(-x)) 137 | 138 | @staticmethod 139 | def _to_obj_str(verts, faces): 140 | text = "" 141 | for p in verts: 142 | text += "v " 143 | for x in p: 144 | text += "{} ".format(x) 145 | text += "\n" 146 | for f in faces: 147 | text += "f " 148 | for x in f: 149 | text += "{} ".format(x + 1) 150 | text += "\n" 151 | return text 152 | 153 | @classmethod 154 | def _save_iso_obj(cls, df, path, th, shift=True): 155 | if th < np.min(df): 156 | df[0, 0, 0] = th - 1 157 | if th > np.max(df): 158 | df[-1, -1, -1] = th + 1 159 | spacing = (1 / 128, 1 / 128, 1 / 128) 160 | verts, faces, _, _ = measure.marching_cubes_lewiner( 161 | df, th, spacing=spacing) 162 | if shift: 163 | verts -= np.array([0.5, 0.5, 0.5]) 164 | obj_str = cls._to_obj_str(verts, faces) 165 | with open(path, 'w') as f: 166 | f.write(obj_str) 167 | 168 | @staticmethod 169 | def _vis_img(img, output_pattern, counter=0): 170 | if img is not None and not isinstance(img, str): 171 | assert img.shape[0] != 0 172 | img = np.clip(img * 255, 0, 255).astype(int) 173 | img = np.transpose(img, (0, 2, 3, 1)) 174 | bsize = img.shape[0] 175 | for batch_id in range(bsize): 176 | im = img[batch_id, :, :, :] 177 | imwrite_wrapper(output_pattern.format(counter + batch_id), im) 178 | 179 | @staticmethod 180 | def _vis_sph(img, output_pattern, counter=0): 181 | if img is not None and not isinstance(img, str): 182 | assert img.shape[0] != 0 183 | img = np.transpose(img, (0, 2, 3, 1)) 184 | bsize = img.shape[0] 185 | for batch_id in range(bsize): 186 | im = img[batch_id, :, :, 0] 187 | im = im / im.max() 188 | im = np.clip(im * 255, 0, 255).astype(int) 189 | imwrite_wrapper(output_pattern.format(counter + batch_id), im) 190 | 191 | @staticmethod 192 | def _cp_img(paths, output_pattern, counter=0): 193 | if paths is not None: 194 | for batch_id, path in enumerate(paths): 195 | copyfile(path, output_pattern.format(counter + batch_id)) 196 | 197 | @classmethod 198 | def _vis_voxel(cls, voxels, output_pattern, counter=0, th=0.5, use_sigmoid=True): 199 | if voxels is not None: 200 | assert voxels.shape[0] != 0 201 | for batch_id, voxel in enumerate(voxels): 202 | if voxel.ndim == 4: 203 | voxel = voxel[0, ...] 204 | voxel = cls._sigmoid(voxel) if use_sigmoid else voxel 205 | cls._save_iso_obj(voxel, output_pattern.format(counter + batch_id), th=th) 206 | 207 | @staticmethod 208 | def _vis_txt(txts, output_pattern, counter=0): 209 | if txts is not None: 210 | for batch_id, txt in enumerate(txts): 211 | with open(output_pattern.format(counter + batch_id), 'w') as h: 212 | h.write("%s\n" % txt) 213 | 214 | @staticmethod 215 | def _error_callback(e): 216 | print(str(e)) 217 | --------------------------------------------------------------------------------
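A minimal usage sketch for the Visualizer above, assuming the pinned environment from environment.yml (whose scikit-image still provides marching_cubes_lewiner). The pack keys come from the class lists in visualizer.py; the solid-ball voxel logits and the /tmp output directory are purely illustrative:

    import numpy as np
    from visualize.visualizer import Visualizer

    vis = Visualizer(n_workers=0)  # 0 workers: visualize synchronously, no multiprocessing Pool

    # One RGB image (NCHW, values in [0, 1]) and one grid of voxel logits
    # (a solid ball here); keys listed in Visualizer.voxels go through a sigmoid internally.
    rgb = np.random.rand(1, 3, 128, 128).astype(np.float32)
    zz, yy, xx = np.mgrid[:128, :128, :128]
    ball = (zz - 64) ** 2 + (yy - 64) ** 2 + (xx - 64) ** 2 < 40 ** 2
    voxel_logits = np.where(ball, 10.0, -10.0).astype(np.float32)[None, None]

    vis.visualize({'rgb': rgb, 'pred_voxel': voxel_logits}, batch_idx=0, outdir='/tmp/vis_demo')

This writes a 0000_*_rgb.png image and a 0000_*_pred_voxel.obj isosurface mesh into the output directory; pack keys that are absent are simply skipped by the corresponding _vis_* helpers.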