├── .gitignore
├── README.md
├── build_toolbox.sh
├── clean_toolbox_build.sh
├── datasets
│   ├── __init__.py
│   ├── shapenet.py
│   └── test.py
├── downloads
│   ├── data
│   │   └── test
│   │       ├── genre
│   │       │   ├── 03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png
│   │       │   ├── 03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png
│   │       │   ├── 04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png
│   │       │   ├── 04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png
│   │       │   ├── 04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png
│   │       │   ├── 04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png
│   │       │   ├── 04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png
│   │       │   └── 04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png
│   │       └── shapehd
│   │           ├── 0044_mask.png
│   │           ├── 0044_rgb.png
│   │           ├── 0503_mask.png
│   │           ├── 0503_rgb.jpg
│   │           ├── 1209_mask.png
│   │           └── 1209_rgb.jpg
│   └── results
│       ├── genre.png
│       └── shapehd.png
├── environment.yml
├── install_trimesh.sh
├── loggers
│   ├── Progbar.py
│   ├── __init__.py
│   └── loggers.py
├── models
│   ├── __init__.py
│   ├── depth_pred_with_sph_inpaint.py
│   ├── genre_full_model.py
│   ├── marrnet.py
│   ├── marrnet1.py
│   ├── marrnet2.py
│   ├── marrnetbase.py
│   ├── netinterface.py
│   ├── shapehd.py
│   └── wgangp.py
├── networks
│   ├── __init__.py
│   ├── networks.py
│   ├── revresnet.py
│   └── uresnet.py
├── options
│   ├── __init__.py
│   ├── options_test.py
│   └── options_train.py
├── scripts
│   ├── finetune_marrnet.sh
│   ├── finetune_shapehd.sh
│   ├── test_genre.sh
│   ├── test_marrnet.sh
│   ├── test_shapehd.sh
│   ├── train_full_genre.sh
│   ├── train_inpaint.sh
│   ├── train_marrnet1.sh
│   ├── train_marrnet2.sh
│   └── train_wgangp.sh
├── test.py
├── toolbox
│   ├── __init__.py
│   ├── calc_prob
│   │   ├── build.py
│   │   ├── calc_prob
│   │   │   ├── __init__.py
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   └── calc_prob.py
│   │   │   └── src
│   │   │       ├── calc_prob.c
│   │   │       ├── calc_prob.h
│   │   │       ├── calc_prob_kernel.cu
│   │   │       └── calc_prob_kernel.h
│   │   ├── clean.sh
│   │   ├── setup.py
│   │   └── setup.sh
│   ├── cam_bp
│   │   ├── build.py
│   │   ├── cam_bp
│   │   │   ├── __init__.py
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cam_back_projection.py
│   │   │   │   ├── get_surface_mask.py
│   │   │   │   └── sperical_to_tdf.py
│   │   │   ├── modules
│   │   │   │   ├── Spherical_backproj.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── camera_backprojection_module.py
│   │   │   └── src
│   │   │       ├── _cam_bp_lib.abi3.so
│   │   │       ├── back_projection.c
│   │   │       ├── back_projection.h
│   │   │       ├── back_projection_kernel.cu
│   │   │       └── back_projection_kernel.h
│   │   ├── clean.sh
│   │   ├── setup.py
│   │   └── setup.sh
│   ├── nndistance
│   │   ├── README.md
│   │   ├── build.py
│   │   ├── clean.sh
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   └── nnd.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── nnd.py
│   │   ├── setup.sh
│   │   ├── src
│   │   │   ├── my_lib.c
│   │   │   ├── my_lib.h
│   │   │   ├── my_lib_cuda.c
│   │   │   ├── my_lib_cuda.h
│   │   │   ├── nnd_cuda.cu
│   │   │   └── nnd_cuda.h
│   │   └── test.py
│   └── spherical_proj.py
├── train.py
├── util
│   ├── __init__.py
│   ├── util_cam_para.py
│   ├── util_camera.py
│   ├── util_img.py
│   ├── util_io.py
│   ├── util_loadlib.py
│   ├── util_print.py
│   ├── util_reproj.py
│   ├── util_sph.py
│   ├── util_voxel.py
│   └── util_xml_to_cam_params.py
└── visualize
    ├── config.json
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | !.gitignore
2 |
3 | **/__pycache__/
4 | **/*.pyc
5 | **/*.swp
6 |
7 | code_test
8 | private
9 | downloads/data/shapenet_cars_chairs_planes_20views.tar
10 | downloads/data/shapenet
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generalizable Reconstruction (GenRe) and ShapeHD
2 |
3 |
4 | ## Papers
5 |
6 | This is a repo covering the following three papers. If you find the code useful, please cite the paper(s).
7 |
8 | 1. Generalizable Reconstruction (GenRe)
9 | **Learning to Reconstruct Shapes from Unseen Classes**
10 | [Xiuming Zhang](http://people.csail.mit.edu/xiuming/)*, [Zhoutong Zhang](https://ztzhang.info)*, [Chengkai Zhang](https://www.csail.mit.edu/person/chengkai-zhang), [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html), [William T. Freeman](https://billf.mit.edu/), and [Jiajun Wu](https://jiajunwu.com/)
11 | *NeurIPS 2018 (Oral)*
12 | [Paper](http://genre.csail.mit.edu/papers/genre_nips.pdf) | [BibTeX](http://genre.csail.mit.edu/bibtex/genre_nips.bib) | [Project](http://genre.csail.mit.edu/)
13 |
14 | * indicates equal contribution.
15 |
16 | 1. ShapeHD
17 | **Learning Shape Priors for Single-View 3D Completion and Reconstruction**
18 | [Jiajun Wu](https://jiajunwu.com/)*, [Chengkai Zhang](https://www.csail.mit.edu/person/chengkai-zhang)*, [Xiuming Zhang](http://people.csail.mit.edu/xiuming/), [Zhoutong Zhang](https://ztzhang.info), [William T. Freeman](https://billf.mit.edu/), and [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html)
19 | *ECCV 2018*
20 | [Paper](http://shapehd.csail.mit.edu/papers/shapehd_eccv.pdf) | [BibTeX](http://shapehd.csail.mit.edu/bibtex/shapehd_eccv.bib) | [Project](http://shapehd.csail.mit.edu/)
21 |
22 | 1. MarrNet
23 | **MarrNet: 3D Shape Reconstruction via 2.5D Sketches**
24 | [Jiajun Wu](https://jiajunwu.com/)*, [Yifan Wang](https://homes.cs.washington.edu/~yifan1/)*, [Tianfan Xue](https://people.csail.mit.edu/tfxue/), [Xingyuan Sun](http://people.csail.mit.edu/xingyuan/), [William T. Freeman](https://billf.mit.edu/), and [Joshua B. Tenenbaum](http://web.mit.edu/cocosci/josh.html)
25 | *NeurIPS 2017*
26 | [Paper](http://marrnet.csail.mit.edu/papers/marrnet_nips.pdf) | [BibTeX](http://marrnet.csail.mit.edu/bibtex/marrnet_nips.bib) | [Project](http://marrnet.csail.mit.edu/)
27 |
28 |
29 |
30 | ## Environment Setup
31 |
32 | All code was built and tested on Ubuntu 16.04.5 LTS with Python 3.6, PyTorch 0.4.1, and CUDA 9.0. Versions for other packages can be found in `environment.yml`.
33 |
34 | 1. Clone this repo with
35 | ```
36 | # cd to the directory you want to work in
37 | git clone https://github.com/xiumingzhang/GenRe-ShapeHD.git
38 | cd GenRe-ShapeHD
39 | ```
40 | The code below assumes you are at the repo root.
41 |
42 | 1. Create a conda environment named `shaperecon` with necessary dependencies specified in `environment.yml`. In order to make sure trimesh is installed correctly, please run `install_trimesh.sh` after setting up the conda environment.
43 | ```
44 | conda env create -f environment.yml
45 | ./install_trimesh.sh
46 | ```
47 | The TensorFlow dependency in `environment.yml` is for using TensorBoard only. Remove it if you do not want to monitor your training with TensorBoard.
48 |
49 | 1. The instructions below assume you have activated this environment and built the CUDA extensions with
50 | ```
51 | source activate shaperecon
52 | ./build_toolbox.sh
53 | ```
54 | Note that PyTorch dropped its cffi extension support in 1.0, so this build step only works with PyTorch 0.4.1.
55 |
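Before building, it can help to confirm that the active environment really matches the requirements stated above. The snippet below is only an illustrative check, not part of the repo.

```python
# Hypothetical pre-build sanity check; run inside the `shaperecon` environment.
import torch

# The toolbox extensions are built with cffi, which PyTorch dropped in 1.0,
# so anything newer than 0.4.1 will not work here.
assert torch.__version__.startswith('0.4.1'), torch.__version__
# The extensions are CUDA kernels, so a visible GPU is required to use them.
assert torch.cuda.is_available(), 'CUDA is not available'
print('Environment looks OK for building the toolbox.')
```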
56 |
57 | ## Downloading Our Trained Models and Training Data
58 |
59 | ### Models
60 |
61 | To download our trained GenRe and ShapeHD models (1 GB in total), run
62 | ```
63 | wget http://genre.csail.mit.edu/downloads/genre_shapehd_models.tar -P downloads/models/
64 | tar -xvf downloads/models/genre_shapehd_models.tar -C downloads/models/
65 | ```
66 |
67 | * GenRe: `depth_pred_with_inpaint.pt` and `full_model.pt`
68 | * ShapeHD: `marrnet1_with_minmax.pt` and `shapehd.pt`
69 |
70 | ### Data
71 |
72 | This repo comes with a few [Pix3D](http://pix3d.csail.mit.edu/) images and [ShapeNet](https://www.shapenet.org/) renderings, located in `downloads/data/test`, for testing purposes.
73 |
74 | For training, we make available our RGB and 2.5D sketch renderings, paired with their corresponding 3D shapes, for ShapeNet cars, chairs, and airplanes, with each object captured in 20 random views. Note that this `.tar` is 143 GB.
75 | ```
76 | wget http://genre.csail.mit.edu/downloads/shapenet_cars_chairs_planes_20views.tar -P downloads/data/
77 | mkdir downloads/data/shapenet/
78 | tar -xvf downloads/data/shapenet_cars_chairs_planes_20views.tar -C downloads/data/shapenet/
79 | ```
80 |
81 | **New (Oct. 20, 2019)**
82 |
83 | For training, in addition to the renderings already included in the initial release, we now also release the Mitsuba scene `.xml` files used to produce these renderings. [This download link](http://genre.csail.mit.edu/downloads/training_xml.zip) is a `.zip` (394 MB) covering the three training classes: cars, chairs, and airplanes. Among other scene parameters, camera poses can now be retrieved from these `.xml` files, which we hope will be useful for tasks like camera/object pose estimation.
84 |
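As a rough illustration of how a camera pose might be read out of one of these scene files, here is a minimal sketch using Python's standard `xml.etree.ElementTree`. The element path (`sensor/transform/lookat`) is an assumption about the Mitsuba scene layout, not something verified against the released files; see also `util/util_xml_to_cam_params.py` in this repo.

```python
# Hedged sketch: parse a camera <lookat> from a Mitsuba scene .xml.
# The exact element layout is assumed, not taken from the released files.
import xml.etree.ElementTree as ET

def read_lookat(xml_path):
    """Return the (origin, target, up) attribute strings of the sensor's <lookat>, if present."""
    root = ET.parse(xml_path).getroot()
    lookat = root.find('./sensor/transform/lookat')  # assumed path
    if lookat is None:
        return None
    return lookat.get('origin'), lookat.get('target'), lookat.get('up')
```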
85 | For testing, we release the data of the unseen categories shown in Table 1 of the paper. [This download link](http://genre.csail.mit.edu/downloads/shapenet_unseen.tar) is a `.tar` (44 GB) consisting of, for each of the unseen classes, the 500 random shapes we used for testing GenRe. Right now, nine classes are included, as we are tracking down the 10th.
86 |
87 |
88 | ## Testing with Our Models
89 |
90 | We provide `.sh` wrappers to perform testing for GenRe, ShapeHD, and MarrNet (without the reprojection consistency part).
91 |
92 | ### GenRe
93 |
94 | See `scripts/test_genre.sh`.
95 |
96 |
97 |
98 |
99 |
100 | We updated our entire pipeline to support fully differentiable end-to-end finetuning. In our NeurIPS submission, the projection from depth images to spherical maps was not implemented in a differentiable way. As a result of both the pipeline and PyTorch version upgrades, the model performance is slightly different from what was reported in the original paper.
101 |
102 | Below we tabulate the original vs. updated Chamfer distances (CD) across different Pix3D classes. The "Original" row is from Table 2 of the paper.
103 |
104 | | |Chair | Bed | Bookcase | Desk | Sofa | Table | Wardrobe |
105 | |----------|:----:|:---:|:---:|:---:|:---:|:---:|:---:|
106 | | **Updated** | .094 | .117 | .104 | .110 | .086 | .114 | .106 |
107 | | **Original** | .093 | .113 | .101 | .109 | .083 | .116 | .109 |
108 |
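For reference, the Chamfer distance between two point clouds is the mean nearest-neighbor distance, taken in both directions. The sketch below only illustrates the metric; the numbers above come from the Pix3D evaluation code, whose point sampling and normalization may differ.

```python
# Illustrative symmetric Chamfer distance between two (N, 3) point clouds.
# Not the official evaluation script -- see the Pix3D repo for that.
import numpy as np
from scipy.spatial import cKDTree

def chamfer_distance(pts_a, pts_b):
    d_ab, _ = cKDTree(pts_b).query(pts_a)  # for each point in A, distance to its nearest point in B
    d_ba, _ = cKDTree(pts_a).query(pts_b)  # and vice versa
    return d_ab.mean() + d_ba.mean()
```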
109 | ### ShapeHD
110 |
111 | See `scripts/test_shapehd.sh`.
112 |
113 |
114 |
115 |
116 |
117 | After ECCV, we upgraded our entire pipeline and re-trained ShapeHD with this new pipeline. The models released here are newly trained, producing quantitative results slightly better than what was reported in the ECCV paper. If you use [the Pix3D repo](https://github.com/xingyuansun/pix3d) to evaluate the model released here, you will get an average CD of 0.122 for the 1,552 untruncated, unoccluded chair images (whose in-plane rotation is < 5°). The average CD on Pix3D chairs reported in the paper was 0.123.
118 |
119 | ### MarrNet w/o Reprojection Consistency
120 |
121 | See `scripts/test_marrnet.sh`.
122 |
123 | The architectures in this implementation of MarrNet differ from those presented in the original NeurIPS 2017 paper. For instance, the reprojection consistency is not implemented here. MarrNet-1, which predicts 2.5D sketches from RGB inputs, is now a U-ResNet, different from its original architecture. That said, the idea remains the same: predicting 2.5D sketches as an intermediate step toward the final 3D voxel predictions.
124 |
125 | If you want to test with the original MarrNet, see [the MarrNet repo](https://github.com/jiajunwu/marrnet) for the pretrained models.
126 |
127 |
128 | ## Training Your Own Models
129 |
130 | This repo allows you to train your own models from scratch, possibly with data different from our training data provided above. You can monitor your training with TensorBoard. For that, make sure to include `--tensorboard` while running `train.py`, and then run
131 | ```
132 | python -m tensorboard.main --logdir="$logdir"/tensorboard
133 | ```
134 | to visualize your losses.
135 |
136 | ### GenRe
137 |
138 | Follow these steps to train the GenRe model.
139 | 1. Train the depth estimator with `scripts/train_marrnet1.sh`
140 | 1. Train the spherical inpainting network with `scripts/train_inpaint.sh`
141 | 1. Train the full model with `scripts/train_full_genre.sh`
142 |
143 | ### ShapeHD
144 |
145 | Follow these steps to train the ShapeHD model.
146 | 1. Train the 2.5D sketch estimator with `scripts/train_marrnet1.sh`
147 | 1. Train the 2.5D-to-3D network with `scripts/train_marrnet2.sh`
148 | 1. Train a 3D-GAN with `scripts/train_wgangp.sh`
149 | 1. Finetune the 2.5D-to-3D network with perceptual losses provided by the discriminator of the 3D-GAN, using `scripts/finetune_shapehd.sh`
150 |
151 | ### MarrNet w/o Reprojection Consistency
152 |
153 | Follow these steps to train the MarrNet model, excluding the reprojection consistency.
154 | 1. Train the 2.5D sketch estimator with `scripts/train_marrnet1.sh`
155 | 1. Train the 2.5D-to-3D network with `scripts/train_marrnet2.sh`
156 | 1. Finetune the 2.5D-to-3D network with `scripts/finetune_marrnet.sh`
157 |
158 |
159 | ## Questions
160 |
161 | Please open an issue if you encounter any problem. You will likely get a quicker response than via email.
162 |
163 |
164 | ## Changelog
165 |
166 | * Dec. 28, 2018: Initial release
167 | * Oct. 20, 2019: Added testing data of the unseen categories, and all `.xml` scene files used to render training data
168 |
--------------------------------------------------------------------------------
/build_toolbox.sh:
--------------------------------------------------------------------------------
1 | cd toolbox/calc_prob
2 | bash setup.sh script
3 | cd ../../
4 | cd toolbox/nndistance
5 | bash setup.sh script
6 | cd ../../
7 | cd toolbox/cam_bp
8 | bash setup.sh script
9 |
--------------------------------------------------------------------------------
/clean_toolbox_build.sh:
--------------------------------------------------------------------------------
1 | cd toolbox/calc_prob
2 | bash clean.sh
3 | cd ../../
4 | cd toolbox/nndistance
5 | bash clean.sh
6 | cd ../../
7 | cd toolbox/cam_bp
8 | bash clean.sh
9 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 |
4 | def get_dataset(alias):
5 | dataset_module = importlib.import_module('datasets.' + alias.lower())
6 | return dataset_module.Dataset
7 |
--------------------------------------------------------------------------------
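A minimal usage sketch for the dataset factory above; `opt` and `model` are hypothetical stand-ins for the parsed options object (`options/options_train.py`) and a model returned by `models.get_model`.

```python
# Hypothetical usage of datasets.get_dataset.
from datasets import get_dataset

DatasetClass = get_dataset('shapenet')  # imports datasets.shapenet and returns its Dataset class
# dataset = DatasetClass(opt, mode='train', model=model)
```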
/datasets/shapenet.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import random
3 | import numpy as np
4 | from scipy.io import loadmat
5 | import torch.utils.data as data
6 | import util.util_img
7 |
8 |
9 | class Dataset(data.Dataset):
10 | data_root = './downloads/data/shapenet'
11 | list_root = join(data_root, 'status')
12 | status_and_suffix = {
13 | 'rgb': {
14 | 'status': 'rgb.txt',
15 | 'suffix': '_rgb.png',
16 | },
17 | 'depth': {
18 | 'status': 'depth.txt',
19 | 'suffix': '_depth.png',
20 | },
21 | 'depth_minmax': {
22 | 'status': 'depth_minmax.txt',
23 | 'suffix': '.npy',
24 | },
25 | 'silhou': {
26 | 'status': 'silhou.txt',
27 | 'suffix': '_silhouette.png',
28 | },
29 | 'normal': {
30 | 'status': 'normal.txt',
31 | 'suffix': '_normal.png'
32 | },
33 | 'voxel': {
34 | 'status': 'vox_rot.txt',
35 | 'suffix': '_gt_rotvox_samescale_128.npz'
36 | },
37 | 'spherical': {
38 | 'status': 'spherical.txt',
39 | 'suffix': '_spherical.npz'
40 | },
41 | 'voxel_canon': {
42 | 'status': 'vox_canon.txt',
43 | 'suffix': '_voxel_normalized_128.mat'
44 | },
45 | }
46 | class_aliases = {
47 | 'drc': '03001627+02691156+02958343',
48 | 'chair': '03001627',
49 | 'table': '04379243',
50 | 'sofa': '04256520',
51 | 'couch': '04256520',
52 | 'cabinet': '03337140',
53 | 'bed': '02818832',
54 | 'plane': '02691156',
55 | 'car': '02958343',
56 | 'bench': '02828884',
57 | 'monitor': '03211117',
58 | 'lamp': '03636649',
59 | 'speaker': '03691459',
60 | 'firearm': '03948459+04090263',
61 | 'cellphone': '02992529+04401088',
62 | 'watercraft': '04530566',
63 | 'hat': '02954340',
64 | 'pot': '03991062',
65 | 'rocket': '04099429',
66 | 'train': '04468005',
67 | 'bus': '02924116',
68 | 'pistol': '03948459',
69 | 'faucet': '03325088',
70 | 'helmet': '03513137',
71 | 'clock': '03046257',
72 | 'phone': '04401088',
73 | 'display': '03211117',
74 | 'vessel': '04530566',
75 | 'rifle': '04090263',
76 | 'small': '03001627+04379243+02933112+04256520+02958343+03636649+02691156+04530566',
77 | 'all-but-table': '02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03001627+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04401088+04460130+04468005+04530566+04554684',
78 | 'all-but-chair': '02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04379243+04401088+04460130+04468005+04530566+04554684',
79 | 'all': '02691156+02747177+02773838+02801938+02808440+02818832+02828884+02843684+02871439+02876657+02880940+02924116+02933112+02942699+02946921+02954340+02958343+02992529+03001627+03046257+03085013+03207941+03211117+03261776+03325088+03337140+03467517+03513137+03593526+03624134+03636649+03642806+03691459+03710193+03759954+03761084+03790512+03797390+03928116+03938244+03948459+03991062+04004475+04074963+04090263+04099429+04225987+04256520+04330267+04379243+04401088+04460130+04468005+04530566+04554684',
80 | }
81 | class_list = class_aliases['all'].split('+')
82 |
83 | @classmethod
84 | def add_arguments(cls, parser):
85 | return parser, set()
86 |
87 | @classmethod
88 | def read_bool_status(cls, status_file):
89 | with open(join(cls.list_root, status_file)) as f:
90 | lines = f.read()
91 | return [x == 'True' for x in lines.split('\n')[:-1]]
92 |
93 | def __init__(self, opt, mode='train', model=None):
94 | assert mode in ('train', 'vali')
95 | self.mode = mode
96 | if model is None:
97 | required = ['rgb']
98 | self.preproc = None
99 | else:
100 | required = model.requires
101 | self.preproc = model.preprocess
102 |
103 | # Parse classes
104 | classes = [] # alias to real for locating data
105 | class_str = '' # real to alias for logging
106 | for c in opt.classes.split('+'):
107 | class_str += c + '+'
108 | if c in self.class_aliases: # nickname given
109 | classes += self.class_aliases[c].split('+')
110 | else:
111 | classes += c.split('+')  # accumulate rather than overwrite earlier classes
112 | self._class_str = class_str[:-1]  # removes the final +; returned by get_classes()
113 | classes = sorted(list(set(classes)))
114 |
115 | # Load items and train-test split
116 | with open(join(self.list_root, 'items_all.txt')) as f:
117 | lines = f.read()
118 | item_list = lines.split('\n')[:-1]
119 | is_train = self.read_bool_status('is_train.txt')
120 | assert len(item_list) == len(is_train)
121 |
122 | # Load status the network requires
123 | has = {}
124 | for data_type in required:
125 | assert data_type in self.status_and_suffix.keys(), \
126 | "%s required, but unspecified in status_and_suffix" % data_type
127 | has[data_type] = self.read_bool_status(
128 | self.status_and_suffix[data_type]['status']
129 | )
130 | assert len(has[data_type]) == len(item_list)
131 |
132 | # Pack paths into a dict
133 | samples = []
134 | for i, item in enumerate(item_list):
135 | class_id, _ = item.split('/')[:2]
136 | item_in_split = ((self.mode == 'train') == is_train[i])
137 | if item_in_split and class_id in classes:
138 | # Look up subclass_id for this item
139 | sample_dict = {'item': join(self.data_root, item)}
140 | # As long as a type is required, it appears as a key
141 | # If it doesn't exist, its value will be None
142 | for data_type in required:
143 | suffix = self.status_and_suffix[data_type]['suffix']
144 | k = data_type + '_path'
145 | if data_type == 'voxel_canon':
146 | # All different views share the same canonical voxel
147 | sample_dict[k] = join(self.data_root, item.split('_view')[0] + suffix) \
148 | if has[data_type][i] else None
149 | else:
150 | sample_dict[k] = join(self.data_root, item + suffix) \
151 | if has[data_type][i] else None
152 | if None not in sample_dict.values():
153 | # All that are required exist
154 | samples.append(sample_dict)
155 |
156 | # If validation, dataloader shuffle will be off, so need to DETERMINISTICALLY
157 | # shuffle here to have a bit of every class
158 | if self.mode == 'vali':
159 | if opt.manual_seed:
160 | seed = opt.manual_seed
161 | else:
162 | seed = 0
163 | random.Random(seed).shuffle(samples)
164 | self.samples = samples
165 |
166 | def __getitem__(self, i):
167 | sample_loaded = {}
168 | for k, v in self.samples[i].items():
169 | sample_loaded[k] = v # as-is
170 | if k.endswith('_path'):
171 | if v.endswith('.png'):
172 | im = util.util_img.imread_wrapper(
173 | v, util.util_img.IMREAD_UNCHANGED,
174 | output_channel_order='RGB')
175 | # Normalize to [0, 1] floats
176 | im = im.astype(float) / float(np.iinfo(im.dtype).max)
177 | sample_loaded[k[:-5]] = im
178 | elif v.endswith('.npy'):
179 | # Right now .npy must be depth_minmax
180 | sample_loaded['depth_minmax'] = np.load(v)
181 | elif v.endswith('_128.npz'):
182 | sample_loaded['voxel'] = np.load(v)['voxel'][None, ...]
183 | elif v.endswith('_spherical.npz'):
184 | spherical_data = np.load(v)
185 | sample_loaded['spherical_object'] = spherical_data['obj_spherical'][None, ...]
186 | sample_loaded['spherical_depth'] = spherical_data['depth_spherical'][None, ...]
187 | elif v.endswith('.mat'):
188 | # Right now .mat must be voxel_canon
189 | sample_loaded['voxel_canon'] = loadmat(v)['voxel'][None, ...]
190 | else:
191 | raise NotImplementedError(v)
192 | # Three identical channels for grayscale images
193 | if self.preproc is not None:
194 | sample_loaded = self.preproc(sample_loaded, mode=self.mode)
195 | # convert all types to float32 for better copy speed
196 | self.convert_to_float32(sample_loaded)
197 | return sample_loaded
198 |
199 | @staticmethod
200 | def convert_to_float32(sample_loaded):
201 | for k, v in sample_loaded.items():
202 | if isinstance(v, np.ndarray):
203 | if v.dtype != np.float32:
204 | sample_loaded[k] = v.astype(np.float32)
205 |
206 | def __len__(self):
207 | return len(self.samples)
208 |
209 | def get_classes(self):
210 | return self._class_str
211 |
--------------------------------------------------------------------------------
/datasets/test.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import numpy as np
3 | import torch.utils.data as data
4 | import util.util_img
5 |
6 |
7 | class Dataset(data.Dataset):
8 | @classmethod
9 | def add_arguments(cls, parser):
10 | return parser, set()
11 |
12 | def __init__(self, opt, model):
13 | # Get required keys and preprocessing from the model
14 | required = model.requires
15 | self.preproc = model.preprocess_wrapper
16 | # Wrapper usually crops and resizes the input image (so that it's just
17 | # like our renders) before sending it to the actual preprocessing
18 |
19 | # Associate each data type required by the model with input paths
20 | type2filename = {}
21 | for k in required:
22 | type2filename[k] = getattr(opt, 'input_' + k)
23 |
24 | # Generate a sorted filelist for each data type
25 | type2files = {}
26 | for k, v in type2filename.items():
27 | type2files[k] = sorted(glob(v))
28 | ns = [len(x) for x in type2files.values()]
29 | assert len(set(ns)) == 1, \
30 | ("Filelists for different types must be of the same length "
31 | "(1-to-1 correspondance)")
32 | self.length = ns[0]
33 |
34 | samples = []
35 | for i in range(self.length):
36 | sample = {}
37 | for k, v in type2files.items():
38 | sample[k + '_path'] = v[i]
39 | samples.append(sample)
40 | self.samples = samples
41 |
42 | def __len__(self):
43 | return self.length
44 |
45 | def __getitem__(self, i):
46 | sample = self.samples[i]
47 |
48 | # Actually loading the item
49 | sample_loaded = {}
50 | for k, v in sample.items():
51 | sample_loaded[k] = v # as-is
52 | if k == 'rgb_path':
53 | im = util.util_img.imread_wrapper(
54 | v, util.util_img.IMREAD_COLOR, output_channel_order='RGB')
55 | # Normalize to [0, 1] floats
56 | im = im.astype(float) / float(np.iinfo(im.dtype).max)
57 | sample_loaded['rgb'] = im
58 | elif k == 'mask_path':
59 | im = util.util_img.imread_wrapper(
60 | v, util.util_img.IMREAD_GRAYSCALE)
61 | # Normalize to [0, 1] floats
62 | im = im.astype(float) / float(np.iinfo(im.dtype).max)
63 | sample_loaded['silhou'] = im
64 | else:
65 | raise NotImplementedError(v)
66 |
67 | # Preprocessing specified by the model
68 | sample_loaded = self.preproc(sample_loaded)
69 | # Convert all types to float32 for faster copying
70 | self.convert_to_float32(sample_loaded)
71 | return sample_loaded
72 |
73 | @staticmethod
74 | def convert_to_float32(sample_loaded):
75 | for k, v in sample_loaded.items():
76 | if isinstance(v, np.ndarray):
77 | if v.dtype != np.float32:
78 | sample_loaded[k] = v.astype(np.float32)
79 |
--------------------------------------------------------------------------------
/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_rgb.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/03001627_10c08a28cae054e53a762233fffc49ea_view000_silhouette.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_rgb.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2c6dcb7184bfed32599dcc439b161a52_view010_silhouette.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_rgb.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04256520_2d987393f7f7c5d1f51f77a6d7299806_view001_silhouette.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_rgb.png
--------------------------------------------------------------------------------
/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/genre/04379243_133d7c9a1f79b01ad0176f9a144100cd_view000_silhouette.png
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/0044_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0044_mask.png
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/0044_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0044_rgb.png
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/0503_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0503_mask.png
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/0503_rgb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/0503_rgb.jpg
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/1209_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/1209_mask.png
--------------------------------------------------------------------------------
/downloads/data/test/shapehd/1209_rgb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/data/test/shapehd/1209_rgb.jpg
--------------------------------------------------------------------------------
/downloads/results/genre.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/results/genre.png
--------------------------------------------------------------------------------
/downloads/results/shapehd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/downloads/results/shapehd.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: shaperecon
2 | channels:
3 | - anaconda
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - python=3.6
8 | - numpy=1.15.4
9 | - pandas=0.23.4
10 | - tqdm=4.28.1
11 | - scikit-image=0.14.0
12 | - numba=0.41.0
13 | - opencv=3.4.2
14 | - pytorch=0.4.1
15 | - torchvision=0.2.1
16 | - tensorflow=1.5.1
17 | - trimesh=2.35.47
18 | - rtree=0.8.3
19 | - scikit-learn=0.20.1
20 |
--------------------------------------------------------------------------------
/install_trimesh.sh:
--------------------------------------------------------------------------------
1 | source activate shaperecon
2 | conda config --add channels conda-forge
3 | conda install shapely rtree pyembree
4 | conda install -c conda-forge scikit-image
5 | conda install "pillow<7"
6 | pip install trimesh[all]==2.35.47
7 |
--------------------------------------------------------------------------------
/loggers/Progbar.py:
--------------------------------------------------------------------------------
1 | # taken from Keras (https://github.com/fchollet/keras/blob/d687c6eda4d9cb58756822fd77402274db309da8/keras/utils/generic_utils.py)
2 | import sys
3 | import time
4 | import numpy as np
5 |
6 |
7 | class Progbar(object):
8 | """Displays a progress bar.
9 | # Arguments
10 | target: Total number of steps expected, None if unknown.
11 | interval: Minimum visual progress update interval (in seconds).
12 | """
13 |
14 | def __init__(self, target, width=30, verbose=1, interval=0.05):
15 | self.width = width
16 | if target is None:
17 | target = -1
18 | self.target = target
19 | self.sum_values = {}
20 | self.unique_values = []
21 | self.start = time.time()
22 | self.last_update = 0
23 | self.interval = interval
24 | self.total_width = 0
25 | self.seen_so_far = 0
26 | self.verbose = verbose
27 |
28 | def update(self, current, values=None, force=False):
29 | """Updates the progress bar.
30 | # Arguments
31 | current: Index of current step.
32 | values: List of tuples (name, value_for_last_step).
33 | The progress bar will display averages for these values.
34 | force: Whether to force visual progress update.
35 | """
36 | values = values or []
37 | for k, v in values:
38 | if k not in self.sum_values:
39 | self.sum_values[k] = [v * (current - self.seen_so_far),
40 | current - self.seen_so_far]
41 | self.unique_values.append(k)
42 | else:
43 | self.sum_values[k][0] += v * (current - self.seen_so_far)
44 | self.sum_values[k][1] += (current - self.seen_so_far)
45 | self.seen_so_far = current
46 |
47 | now = time.time()
48 | if self.verbose == 1:
49 | if not force and (now - self.last_update) < self.interval:
50 | return
51 |
52 | prev_total_width = self.total_width
53 | sys.stdout.write('\b' * prev_total_width)
54 | sys.stdout.write('\r')
55 |
56 | if self.target != -1:
57 | numdigits = int(np.floor(np.log10(self.target))) + 1
58 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
59 | bar = barstr % (current, self.target)
60 | prog = float(current) / self.target
61 | prog_width = int(self.width * prog)
62 | if prog_width > 0:
63 | bar += ('=' * (prog_width - 1))
64 | if current < self.target:
65 | bar += '>'
66 | else:
67 | bar += '='
68 | bar += ('.' * (self.width - prog_width))
69 | bar += ']'
70 | sys.stdout.write(bar)
71 | self.total_width = len(bar)
72 |
73 | if current:
74 | time_per_unit = (now - self.start) / current
75 | else:
76 | time_per_unit = 0
77 | eta = time_per_unit * (self.target - current)
78 | info = ''
79 | if current < self.target and self.target != -1:
80 | info += ' - ETA: %ds' % eta
81 | else:
82 | info += ' - %ds' % (now - self.start)
83 | for k in self.unique_values:
84 | info += ' - %s:' % k
85 | if isinstance(self.sum_values[k], list):
86 | avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1]))
87 | if abs(avg) > 1e-3:
88 | info += ' %.4f' % avg
89 | else:
90 | info += ' %.4e' % avg
91 | else:
92 | info += ' %s' % self.sum_values[k]
93 |
94 | self.total_width += len(info)
95 | if prev_total_width > self.total_width:
96 | info += ((prev_total_width - self.total_width) * ' ')
97 |
98 | sys.stdout.write(info)
99 | sys.stdout.flush()
100 |
101 | if current >= self.target:
102 | sys.stdout.write('\n')
103 |
104 | if self.verbose == 2:
105 | if current >= self.target:
106 | info = '%ds' % (now - self.start)
107 | for k in self.unique_values:
108 | info += ' - %s:' % k
109 | avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1]))
110 | if avg > 1e-3:
111 | info += ' %.4f' % avg
112 | else:
113 | info += ' %.4e' % avg
114 | sys.stdout.write(info + "\n")
115 |
116 | self.last_update = now
117 |
118 | def add(self, n, values=None):
119 | self.update(self.seen_so_far + n, values)
120 |
--------------------------------------------------------------------------------
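A minimal usage sketch of the progress bar above; values passed to `update` are displayed as running averages.

```python
# Hypothetical usage of loggers.Progbar.Progbar.
from loggers.Progbar import Progbar

bar = Progbar(target=100)  # 100 steps expected
for step in range(1, 101):
    loss = 1.0 / step  # placeholder metric
    bar.update(step, values=[('loss', loss)])
```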
/loggers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/loggers/__init__.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 |
4 | def get_model(alias, test=False):
5 | module = importlib.import_module('models.' + alias)
6 | if test:
7 | return module.Model_test
8 | return module.Model
9 |
--------------------------------------------------------------------------------
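A minimal usage sketch for the model factory above; the aliases map to module names under `models/` (e.g., `marrnet`, `marrnet1`, `genre_full_model`).

```python
# Hypothetical usage of models.get_model.
from models import get_model

TrainModel = get_model('marrnet1')           # -> models.marrnet1.Model
TestModel = get_model('marrnet', test=True)  # -> models.marrnet.Model_test
```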
/models/depth_pred_with_sph_inpaint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from models.marrnet1 import Model as DepthModel
5 | from models.marrnet1 import Net as Net1
6 | from networks.uresnet import Net_inpaint as Uresnet
7 | from toolbox.cam_bp.cam_bp.modules.camera_backprojection_module import Camera_back_projection_layer
8 | from toolbox.spherical_proj import render_spherical, sph_pad
9 | import torch.nn.functional as F
10 |
11 |
12 | class Model(DepthModel):
13 | @classmethod
14 | def add_arguments(cls, parser):
15 | parser.add_argument('--pred_depth_minmax', action='store_true', default=True,
16 | help="GenRe needs minmax prediction")
17 | parser.add_argument('--load_offline', action='store_true',
18 | help="load offline prediction results")
19 | parser.add_argument('--joint_train', action='store_true',
20 | help="joint train net1 and net2")
21 | parser.add_argument('--net1_path', default=None, type=str,
22 | help="path to pretrained net1")
23 | parser.add_argument('--padding_margin', default=16, type=int,
24 | help="padding margin for spherical maps")
25 | unique_params = {'joint_train'}
26 | return parser, unique_params
27 |
28 | def __init__(self, opt, logger):
29 | super(Model, self).__init__(opt, logger)
30 | self.joint_train = opt.joint_train
31 | if not self.joint_train:
32 | self.requires = ['silhou', 'rgb', 'spherical']
33 | self.gt_names = ['spherical_object']
34 | self._metrics = ['spherical']
35 | else:
36 | self.requires.append('spherical')
37 | self.gt_names = ['depth', 'silhou', 'normal', 'depth_minmax', 'spherical_object']
38 | self._metrics.append('spherical')
39 | self.input_names = ['rgb', 'silhou', 'spherical_depth']
40 | self.net = Net(opt, Model)
41 | self.optimizer = self.adam(
42 | self.net.parameters(),
43 | lr=opt.lr,
44 | **self.optim_params
45 | )
46 | self._nets = [self.net]
47 | self._optimizers = [self.optimizer]
48 | self.init_vars(add_path=True)
49 | self.init_weight(self.net.net2)
50 |
51 | def __str__(self):
52 | string = "Depth Prediction with Spherical Refinement"
53 | if self.joint_train:
54 | string += ' Jointly training all the modules.'
55 | else:
56 | string += ' Only training the inpainting module.'
57 | return string
58 |
59 | def compute_loss(self, pred):
60 | loss_data = {}
61 | loss = 0
62 | if self.joint_train:
63 | loss, loss_data = super(Model, self).compute_loss(pred)
64 | sph_loss = F.mse_loss(pred['pred_sph_full'], self._gt.spherical_object)
65 | loss_data['spherical'] = sph_loss.mean().item()
66 | loss += sph_loss
67 | loss_data['loss'] = loss.mean().item()
68 | return loss, loss_data
69 |
70 | def pack_output(self, pred, batch, add_gt=True):
71 | pack = {}
72 | if self.joint_train:
73 | pack = super(Model, self).pack_output(pred, batch, add_gt=False)
74 | pack['pred_spherical_full'] = pred['pred_sph_full'].data.cpu().numpy()
75 | pack['pred_spherical_partial'] = pred['pred_sph_partial'].data.cpu().numpy()
76 | pack['proj_depth'] = pred['proj_depth'].data.cpu().numpy()
77 | pack['rgb_path'] = batch['rgb_path']
78 | if add_gt:
79 | pack['gt_spherical_full'] = batch['spherical_object'].numpy()
80 | return pack
81 |
82 | @classmethod
83 | def preprocess(cls, data, mode='train'):
84 | dataout = DepthModel.preprocess(data, mode)
85 | if 'spherical_object' in dataout.keys():
86 | val = dataout['spherical_object']
87 | assert(val.shape[1] == val.shape[2])
88 | assert(val.shape[1] == 128)
89 | sph_padded = np.pad(val, ((0, 0), (0, 0), (16, 16)), 'wrap')
90 | sph_padded = np.pad(sph_padded, ((0, 0), (16, 16), (0, 0)), 'edge')
91 | dataout['spherical_object'] = sph_padded
92 | return dataout
93 |
94 |
95 | class Net(nn.Module):
96 | def __init__(self, opt, base_class=Model):
97 | super().__init__()
98 | self.net1 = Net1(
99 | [3, 1, 1],
100 | ['normal', 'depth', 'silhou'],
101 | pred_depth_minmax=True)
102 | self.net2 = Uresnet([1], ['spherical'], input_planes=1)
103 | self.base_class = base_class
104 | self.proj_depth = Camera_back_projection_layer()
105 | self.render_spherical = render_spherical()
106 | self.joint_train = opt.joint_train
107 | self.load_offline = opt.load_offline
108 | self.padding_margin = opt.padding_margin
109 | if opt.net1_path:
110 | state_dicts = torch.load(opt.net1_path)
111 | self.net1.load_state_dict(state_dicts['nets'][0])
112 |
113 | def forward(self, input_struct):
114 | if not self.joint_train:
115 | with torch.no_grad():
116 | out_1 = self.net1(input_struct)
117 | else:
118 | out_1 = self.net1(input_struct)
119 | pred_abs_depth = self.get_abs_depth(out_1, input_struct)
120 | proj = self.proj_depth(pred_abs_depth)
121 | if self.load_offline:
122 | sph_in = input_struct.spherical_depth
123 | else:
124 | sph_in = self.render_spherical(torch.clamp(proj * 50, 1e-5, 1 - 1e-5))
125 | # pad sph_in to approximate boundary conditions
126 | sph_in = sph_pad(sph_in, self.padding_margin)
127 | out_2 = self.net2(sph_in)
128 | out_1['proj_depth'] = proj * 50
129 | out_1['pred_sph_partial'] = sph_in
130 | out_1['pred_sph_full'] = out_2['spherical']
131 | return out_1
132 |
133 | def get_abs_depth(self, pred, input_struct):
134 | pred_depth = pred['depth']
135 | pred_depth = self.base_class.postprocess(pred_depth)
136 | pred_depth_minmax = pred['depth_minmax'].detach()
137 | pred_abs_depth = self.base_class.to_abs_depth(1 - pred_depth, pred_depth_minmax)
138 | silhou = self.base_class.postprocess(input_struct.silhou).detach()
139 | pred_abs_depth[silhou < 0.5] = 0
140 | pred_abs_depth = pred_abs_depth.permute(0, 1, 3, 2)
141 | pred_abs_depth = torch.flip(pred_abs_depth, [2])
142 | return pred_abs_depth
143 |
--------------------------------------------------------------------------------
/models/marrnet.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from util import util_img
7 | from .marrnet1 import Net as Marrnet1
8 | from .marrnet2 import Net as Marrnet2, Model as Marrnet2_model
9 |
10 |
11 | class Model(Marrnet2_model):
12 | @classmethod
13 | def add_arguments(cls, parser):
14 | parser.add_argument(
15 | '--canon_sup',
16 | action='store_true',
17 | help="Use canonical-pose voxels as supervision"
18 | )
19 | parser.add_argument(
20 | '--marrnet1',
21 | type=str, default=None,
22 | help="Path to pretrained MarrNet-1"
23 | )
24 | parser.add_argument(
25 | '--marrnet2',
26 | type=str, default=None,
27 | help="Path to pretrained MarrNet-2 (to be finetuned)"
28 | )
29 | return parser, set()
30 |
31 | def __init__(self, opt, logger):
32 | super().__init__(opt, logger)
33 | pred_silhou_thres = self.pred_silhou_thres * self.scale_25d
34 | self.requires = ['rgb', self.voxel_key]
35 | self.net = Net(opt.marrnet1, opt.marrnet2, pred_silhou_thres)
36 | self._nets = [self.net]
37 | self.optimizer = self.adam(
38 | self.net.marrnet2.parameters(),
39 | lr=opt.lr,
40 | **self.optim_params
41 | ) # just finetune MarrNet-2
42 | self._optimizers[-1] = self.optimizer
43 | self.input_names = ['rgb']
44 | self.init_vars(add_path=True)
45 |
46 | def __str__(self):
47 | return "Finetuning MarrNet-2 with MarrNet-1 predictions"
48 |
49 | def pack_output(self, pred, batch, add_gt=True):
50 | pred_normal = pred['normal'].detach().cpu()
51 | pred_silhou = pred['silhou'].detach().cpu()
52 | pred_depth = pred['depth'].detach().cpu()
53 | out = {}
54 | out['rgb_path'] = batch['rgb_path']
55 | out['rgb'] = util_img.denormalize_colors(batch['rgb'].detach().numpy())
56 | pred_silhou = self.postprocess(pred_silhou)
57 | pred_silhou = torch.clamp(pred_silhou, 0, 1)
58 | pred_silhou[pred_silhou < 0] = 0
59 | out['pred_silhou'] = pred_silhou.numpy()
60 | out['pred_normal'] = self.postprocess(
61 | pred_normal, bg=1.0, input_mask=pred_silhou
62 | ).numpy()
63 | out['pred_depth'] = self.postprocess(
64 | pred_depth, bg=0.0, input_mask=pred_silhou
65 | ).numpy()
66 | out['pred_voxel'] = pred['voxel'].detach().cpu().numpy()
67 | if add_gt:
68 | out['gt_voxel'] = batch[self.voxel_key].numpy()
69 | return out
70 |
71 | def compute_loss(self, pred):
72 | loss = self.criterion(
73 | pred['voxel'],
74 | getattr(self._gt, self.voxel_key)
75 | )
76 | loss_data = {}
77 | loss_data['loss'] = loss.mean().item()
78 | return loss, loss_data
79 |
80 |
81 | class Net(nn.Module):
82 | """
83 | MarrNet-1 MarrNet-2
84 | RGB ------> 2.5D ------> 3D
85 | fixed finetuned
86 | """
87 |
88 | def __init__(self, marrnet1_path=None, marrnet2_path=None, pred_silhou_thres=0.3):
89 | super().__init__()
90 | # Init MarrNet-1 and load weights
91 | self.marrnet1 = Marrnet1(
92 | [3, 1, 1],
93 | ['normal', 'depth', 'silhou'],
94 | pred_depth_minmax=True, # not used in MarrNet
95 | )
96 | if marrnet1_path:
97 | state_dict = torch.load(marrnet1_path)['nets'][0]
98 | self.marrnet1.load_state_dict(state_dict)
99 | # Init MarrNet-2 and load weights
100 | self.marrnet2 = Marrnet2(4)
101 | if marrnet2_path:
102 | state_dict = torch.load(marrnet2_path)['nets'][0]
103 | self.marrnet2.load_state_dict(state_dict)
104 | # Fix MarrNet-1, but finetune 2
105 | for p in self.marrnet1.parameters():
106 | p.requires_grad = False
107 | for p in self.marrnet2.parameters():
108 | p.requires_grad = True
109 | self.pred_silhou_thres = pred_silhou_thres
110 |
111 | def forward(self, input_struct):
112 | # Predict 2.5D sketches
113 | with torch.no_grad():
114 | pred = self.marrnet1(input_struct)
115 | depth = pred['depth']
116 | normal = pred['normal']
117 | silhou = pred['silhou']
118 | # Mask
119 | is_bg = silhou < self.pred_silhou_thres
120 | depth[is_bg] = 0
121 | normal[is_bg.repeat(1, 3, 1, 1)] = 0
122 | x = torch.cat((depth, normal), 1)
123 | # Forward
124 | latent_vec = self.marrnet2.encoder(x)
125 | vox = self.marrnet2.decoder(latent_vec)
126 | pred['voxel'] = vox
127 | return pred
128 |
129 |
130 | class Model_test(Model):
131 | def __init__(self, opt, logger):
132 | super().__init__(opt, logger)
133 | self.requires = ['rgb', 'mask'] # mask for bbox cropping only
134 | self.load_state_dict(opt.net_file, load_optimizer='auto')
135 | self.input_names = ['rgb']
136 | self.init_vars(add_path=True)
137 | self.output_dir = opt.output_dir
138 |
139 | def __str__(self):
140 | return "Testing MarrNet"
141 |
142 | @classmethod
143 | def preprocess_wrapper(cls, in_dict):
144 | silhou_thres = 0.95
145 | in_size = 480
146 | pad = 85
147 | im = in_dict['rgb']
148 | mask = in_dict['silhou']
149 | bbox = util_img.get_bbox(mask, th=silhou_thres)
150 | im_crop = util_img.crop(im, bbox, in_size, pad, pad_zero=False)
151 | in_dict['rgb'] = im_crop
152 | del in_dict['silhou'] # just for cropping -- done its job
153 | # Now the image is just like those we rendered
154 | out_dict = cls.preprocess(in_dict, mode='test')
155 | return out_dict
156 |
157 | def test_on_batch(self, batch_i, batch):
158 | outdir = join(self.output_dir, 'batch%04d' % batch_i)
159 | makedirs(outdir, exist_ok=True)
160 | pred = self.predict(batch, load_gt=False, no_grad=True)
161 | output = self.pack_output(pred, batch, add_gt=False)
162 | self.visualizer.visualize(output, batch_i, outdir)
163 | np.savez(outdir + '.npz', **output)
164 |
--------------------------------------------------------------------------------
/models/marrnet1.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from networks.networks import ViewAsLinear
7 | from networks.uresnet import Net as Uresnet
8 | from .marrnetbase import MarrnetBaseModel
9 |
10 |
11 | class Model(MarrnetBaseModel):
12 | @classmethod
13 | def add_arguments(cls, parser):
14 | parser.add_argument(
15 | '--pred_depth_minmax',
16 | action='store_true',
17 | help="Also predicts depth minmax (for GenRe)",
18 | )
19 | return parser, set()
20 |
21 | def __init__(self, opt, logger):
22 | super(Model, self).__init__(opt, logger)
23 | self.requires = ['rgb', 'depth', 'silhou', 'normal']
24 | if opt.pred_depth_minmax:
25 | self.requires.append('depth_minmax')
26 | self.net = Net(
27 | [3, 1, 1],
28 | ['normal', 'depth', 'silhou'],
29 | pred_depth_minmax=opt.pred_depth_minmax,
30 | )
31 | self.criterion = nn.functional.mse_loss
32 | self.optimizer = self.adam(
33 | self.net.parameters(),
34 | lr=opt.lr,
35 | **self.optim_params
36 | )
37 | self._nets = [self.net]
38 | self._optimizers.append(self.optimizer)
39 | self.input_names = ['rgb']
40 | self.gt_names = ['depth', 'silhou', 'normal']
41 | if opt.pred_depth_minmax:
42 | self.gt_names.append('depth_minmax')
43 | self.init_vars(add_path=True)
44 | self._metrics = ['loss', 'depth', 'silhou', 'normal']
45 | if opt.pred_depth_minmax:
46 | self._metrics.append('depth_minmax')
47 | self.init_weight(self.net)
48 |
49 | def __str__(self):
50 | return "MarrNet-1 predicting 2.5D sketches"
51 |
52 | def _train_on_batch(self, epoch, batch_idx, batch):
53 | self.net.zero_grad()
54 | pred = self.predict(batch)
55 | loss, loss_data = self.compute_loss(pred)
56 | loss.backward()
57 | self.optimizer.step()
58 | batch_size = len(batch['rgb_path'])
59 | batch_log = {'size': batch_size, **loss_data}
60 | return batch_log
61 |
62 | def _vali_on_batch(self, epoch, batch_idx, batch):
63 | pred = self.predict(batch, no_grad=True)
64 | _, loss_data = self.compute_loss(pred)
65 | if np.mod(epoch, self.opt.vis_every_vali) == 0:
66 | if batch_idx < self.opt.vis_batches_vali:
67 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch)
68 | makedirs(outdir, exist_ok=True)
69 | output = self.pack_output(pred, batch)
70 | self.visualizer.visualize(output, batch_idx, outdir)
71 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output)
72 | batch_size = len(batch['rgb_path'])
73 | batch_log = {'size': batch_size, **loss_data}
74 | return batch_log
75 |
76 | def pack_output(self, pred, batch, add_gt=True):
77 | pred_normal = pred['normal'].detach().cpu()
78 | pred_silhou = pred['silhou'].detach().cpu()
79 | pred_depth = pred['depth'].detach().cpu()
80 | gt_silhou = self.postprocess(batch['silhou'])
81 | out = {}
82 | out['rgb_path'] = batch['rgb_path']
83 | out['pred_normal'] = self.postprocess(pred_normal, bg=1.0, input_mask=gt_silhou).numpy()
84 | out['pred_silhou'] = self.postprocess(pred_silhou).numpy()
85 | pred_depth = self.postprocess(pred_depth, bg=0.0, input_mask=gt_silhou)
86 | out['pred_depth'] = pred_depth.numpy()
87 | if self.opt.pred_depth_minmax:
88 | pred_depth_minmax = pred['depth_minmax'].detach()
89 | pred_abs_depth = self.to_abs_depth(
90 | (1 - pred_depth).to(torch.device('cuda')),
91 | pred_depth_minmax
92 | ) # background is max now
93 | pred_abs_depth[gt_silhou < 1] = 0 # set background to 0
94 | out['proj_depth'] = self.proj_depth(pred_abs_depth).cpu().numpy()
95 | out['pred_depth_minmax'] = pred_depth_minmax.cpu().numpy()
96 | if add_gt:
97 | out['normal_path'] = batch['normal_path']
98 | out['silhou_path'] = batch['silhou_path']
99 | out['depth_path'] = batch['depth_path']
100 | if self.opt.pred_depth_minmax:
101 | out['gt_depth_minmax'] = batch['depth_minmax'].numpy()
102 | return out
103 |
104 | def compute_loss(self, pred):
105 | """
106 | TODO: we should add normal and depth consistency loss here in the future.
107 | """
108 | pred_normal = pred['normal']
109 | pred_depth = pred['depth']
110 | pred_silhou = pred['silhou']
111 | is_fg = self._gt.silhou != 0 # excludes background
112 | is_fg_full = is_fg.expand_as(pred_normal)
113 | loss_normal = self.criterion(
114 | pred_normal[is_fg_full], self._gt.normal[is_fg_full]
115 | )
116 | loss_depth = self.criterion(
117 | pred_depth[is_fg], self._gt.depth[is_fg]
118 | )
119 | loss_silhou = self.criterion(pred_silhou, self._gt.silhou)
120 | loss = loss_normal + loss_depth + loss_silhou
121 | loss_data = {}
122 | loss_data['loss'] = loss.mean().item()
123 | loss_data['normal'] = loss_normal.mean().item()
124 | loss_data['depth'] = loss_depth.mean().item()
125 | loss_data['silhou'] = loss_silhou.mean().item()
126 | if self.opt.pred_depth_minmax:
127 | w_minmax = (256 ** 2) / 2 # matching scale of pixel predictions very roughly
128 | loss_depth_minmax = w_minmax * self.criterion(
129 | pred['depth_minmax'],
130 | self._gt.depth_minmax
131 | )
132 | loss += loss_depth_minmax
133 | loss_data['depth_minmax'] = loss_depth_minmax.mean().item()
134 | return loss, loss_data
135 |
136 |
137 | class Net(Uresnet):
138 | def __init__(self, *args, pred_depth_minmax=True):
139 | super().__init__(*args)
140 | self.pred_depth_minmax = pred_depth_minmax
141 | if self.pred_depth_minmax:
142 | module_list = nn.Sequential(
143 | nn.Conv2d(512, 512, 2, stride=2),
144 | nn.Conv2d(512, 512, 4, stride=1),
145 | ViewAsLinear(),
146 | nn.Linear(512, 256),
147 | nn.BatchNorm1d(256),
148 | nn.ReLU(inplace=True),
149 | nn.Linear(256, 128),
150 | nn.BatchNorm1d(128),
151 | nn.ReLU(inplace=True),
152 | nn.Linear(128, 2)
153 | )
154 | self.decoder_minmax = module_list
155 |
156 | def forward(self, input_struct):
157 | x = input_struct.rgb
158 | out_dict = super().forward(x)
159 | if self.pred_depth_minmax:
160 | out_dict['depth_minmax'] = self.decoder_minmax(self.encoder_out)
161 | return out_dict
162 |
--------------------------------------------------------------------------------
/models/marrnet2.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from networks.networks import ImageEncoder, VoxelDecoder
7 | from .marrnetbase import MarrnetBaseModel
8 |
9 |
10 | class Model(MarrnetBaseModel):
11 | @classmethod
12 | def add_arguments(cls, parser):
13 | parser.add_argument(
14 | '--canon_sup',
15 | action='store_true',
16 | help="Use canonical-pose voxels as supervision"
17 | )
18 | return parser, set()
19 |
20 | def __init__(self, opt, logger):
21 | super(Model, self).__init__(opt, logger)
22 | if opt.canon_sup:
23 | voxel_key = 'voxel_canon'
24 | else:
25 | voxel_key = 'voxel'
26 | self.voxel_key = voxel_key
27 | self.requires = ['rgb', 'depth', 'normal', 'silhou', voxel_key]
28 | self.net = Net(4)
29 | self.criterion = nn.BCEWithLogitsLoss(reduction='elementwise_mean')
30 | self.optimizer = self.adam(
31 | self.net.parameters(),
32 | lr=opt.lr,
33 | **self.optim_params
34 | )
35 | self._nets = [self.net]
36 | self._optimizers.append(self.optimizer)
37 | self.input_names = ['depth', 'normal', 'silhou']
38 | self.gt_names = [voxel_key]
39 | self.init_vars(add_path=True)
40 | self._metrics = ['loss']
41 | self.init_weight(self.net)
42 |
43 | def __str__(self):
44 | return "MarrNet-2 predicting voxels from 2.5D sketches"
45 |
46 | def _train_on_batch(self, epoch, batch_idx, batch):
47 | self.net.zero_grad()
48 | pred = self.predict(batch)
49 | loss, loss_data = self.compute_loss(pred)
50 | loss.backward()
51 | self.optimizer.step()
52 | batch_size = len(batch['rgb_path'])
53 | batch_log = {'size': batch_size, **loss_data}
54 | return batch_log
55 |
56 | def _vali_on_batch(self, epoch, batch_idx, batch):
57 | pred = self.predict(batch, no_grad=True)
58 | _, loss_data = self.compute_loss(pred)
59 | if np.mod(epoch, self.opt.vis_every_vali) == 0:
60 | if batch_idx < self.opt.vis_batches_vali:
61 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch)
62 | makedirs(outdir, exist_ok=True)
63 | output = self.pack_output(pred, batch)
64 | self.visualizer.visualize(output, batch_idx, outdir)
65 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output)
66 | batch_size = len(batch['rgb_path'])
67 | batch_log = {'size': batch_size, **loss_data}
68 | return batch_log
69 |
70 | def pack_output(self, pred, batch, add_gt=True):
71 | out = {}
72 | out['rgb_path'] = batch['rgb_path']
73 | out['pred_voxel'] = pred.detach().cpu().numpy()
74 | if add_gt:
75 | out['gt_voxel'] = batch[self.voxel_key].numpy()
76 | out['normal_path'] = batch['normal_path']
77 | out['depth_path'] = batch['depth_path']
78 | out['silhou_path'] = batch['silhou_path']
79 | return out
80 |
81 | def compute_loss(self, pred):
82 | loss = self.criterion(pred, getattr(self._gt, self.voxel_key))
83 | loss_data = {}
84 | loss_data['loss'] = loss.mean().item()
85 | return loss, loss_data
86 |
87 |
88 | class Net(nn.Module):
89 | """
90 | 2.5D maps to 3D voxel
91 | """
92 |
93 | def __init__(self, in_planes, encode_dims=200, silhou_thres=0):
94 | super().__init__()
95 | self.encoder = ImageEncoder(in_planes, encode_dims=encode_dims)
96 | self.decoder = VoxelDecoder(n_dims=encode_dims, nf=512)
97 | self.silhou_thres = silhou_thres
98 |
99 | def forward(self, input_struct):
100 | depth = input_struct.depth
101 | normal = input_struct.normal
102 | silhou = input_struct.silhou
103 | # Mask
104 | is_bg = silhou <= self.silhou_thres
105 | depth[is_bg] = 0
106 | normal[is_bg.repeat(1, 3, 1, 1)] = 0 # NOTE: if old net2, set to white (100),
107 | x = torch.cat((depth, normal), 1) # and swap depth and normal
108 | # Forward
109 | latent_vec = self.encoder(x)
110 | vox = self.decoder(latent_vec)
111 | return vox
112 |
--------------------------------------------------------------------------------
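
For reference, the background masking in Net.forward above reduces to the following standalone sketch; the shapes and the scale factor of 100 are assumptions taken from MarrnetBaseModel.im_size and scale_25d in marrnetbase.py below.

import torch

depth = torch.rand(2, 1, 256, 256) * 100    # hypothetical 2.5D depth, scaled by 100
normal = torch.rand(2, 3, 256, 256) * 100   # hypothetical surface normals, scaled by 100
silhou = torch.rand(2, 1, 256, 256) * 100   # hypothetical silhouette, scaled by 100

silhou_thres = 0
is_bg = silhou <= silhou_thres               # background where silhouette is at/below threshold
depth[is_bg] = 0
normal[is_bg.repeat(1, 3, 1, 1)] = 0
x = torch.cat((depth, normal), 1)            # (2, 4, 256, 256), matching Net(4)
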
/models/marrnetbase.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from models.netinterface import NetInterface
7 | from toolbox.cam_bp.cam_bp.functions import CameraBackProjection
8 | import util.util_img
9 |
10 |
11 | class MarrnetBaseModel(NetInterface):
12 | im_size = 256
13 | rgb_jitter_d = 0.4
14 | rgb_light_noise = 0.1
15 | silhou_thres = 0.999
16 | pred_silhou_thres = 0.3
17 | scale_25d = 100
18 |
19 | def __init__(self, opt, logger):
20 | super(MarrnetBaseModel, self).__init__(opt, logger)
21 | self.opt = opt
22 | self.n_batches_per_epoch = opt.epoch_batches
23 | self.n_batches_to_vis_train = opt.vis_batches_train
24 | self.n_batches_to_vis_vali = opt.vis_batches_vali
25 | self.full_logdir = opt.full_logdir
26 | self._metrics = []
27 | self.batches_to_vis = {}
28 | self.dataset = opt.dataset
29 | self._nets = []
30 | self._optimizers = []
31 | self._moveable_vars = []
32 | self.cam_bp = Camera_back_projection_layer(128)
33 | if opt.log_time:
34 | self._metrics += ['batch_time', 'data_time']
35 | # Parameters for different optimization methods
36 | self.optim_params = dict()
37 | if opt.optim == 'adam':
38 | self.optim_params['betas'] = (opt.adam_beta1, opt.adam_beta2)
39 | elif opt.optim == 'sgd':
40 | self.optim_params['momentum'] = opt.sgd_momentum
41 | self.optim_params['dampening'] = opt.sgd_dampening
42 | self.optim_params['weight_decay'] = opt.sgd_wdecay
43 | else:
44 | raise NotImplementedError(opt.optim)
45 |
46 | def _train_on_batch(self, batch_idx, batch):
47 | self.net.zero_grad()
48 | pred = self.predict(batch)
49 | loss, loss_data = self.compute_loss(pred)
50 | loss.backward()
51 | self.optimizer.step()
52 | batch_size = len(batch['rgb_path'])
53 | batch_log = {'size': batch_size, **loss_data}
54 | self.record_batch(batch_idx, batch)
55 | return batch_log
56 |
57 | def _vali_on_batch(self, epoch, batch_idx, batch):
58 | pred = self.predict(batch, no_grad=True)
59 | _, loss_data = self.compute_loss(pred)
60 | if np.mod(epoch, self.opt.vis_every_vali) == 0:
61 | if batch_idx < self.opt.vis_batches_vali:
62 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch)
63 | makedirs(outdir, exist_ok=True)
64 | output = self.pack_output(pred, batch)
65 | self.visualizer.visualize(output, batch_idx, outdir)
66 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output)
67 | batch_size = len(batch['rgb_path'])
68 | batch_log = {'size': batch_size, **loss_data}
69 | return batch_log
70 |
71 | @classmethod
72 | def preprocess(cls, data, mode='train'):
73 | """
74 | This function should be applied to [0, 1] floats, except absolute depth
75 | """
76 | data_proc = {}
77 | for key, val in data.items():
78 | if key == 'rgb':
79 | im = val
80 | # H x W x 3
81 | im = util.util_img.resize(im, cls.im_size, 'horizontal')
82 | if mode == 'train':
83 | im = util.util_img.jitter_colors(
84 | im,
85 | d_brightness=cls.rgb_jitter_d,
86 | d_contrast=cls.rgb_jitter_d,
87 | d_saturation=cls.rgb_jitter_d
88 | )
89 | im = util.util_img.add_lighting_noise(
90 | im, cls.rgb_light_noise)
91 | im = util.util_img.normalize_colors(im)
92 | val = im.transpose(2, 0, 1)
93 |
94 | elif key == 'depth':
95 | im = val
96 | if im.ndim == 3:
97 | im = im[:, :, 0]
98 | im = util.util_img.resize(
99 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max()))
100 | im *= cls.scale_25d
101 | val = im[np.newaxis, :, :]
102 | # 1 x H x W, scaled
103 |
104 | elif key == 'silhou':
105 | im = val
106 | if im.ndim == 3:
107 | im = im[:, :, 0]
108 | im = util.util_img.resize(
109 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max()))
110 | im = util.util_img.binarize(
111 | im, cls.silhou_thres, gt_is_1=True)
112 | im *= cls.scale_25d
113 | val = im[np.newaxis, :, :]
114 | # 1 x H x W, binarized, scaled
115 |
116 | elif key == 'normal':
117 | # H x W x 3
118 | im = val
119 | im = util.util_img.resize(
120 | im, cls.im_size, 'horizontal', clamp=(im.min(), im.max()))
121 | im *= cls.scale_25d
122 | val = im.transpose(2, 0, 1)
123 | # 3 x H x W, scaled
124 |
125 | data_proc[key] = val
126 | return data_proc
127 |
128 | @staticmethod
129 | def mask(input_image, input_mask, bg=1.0):
130 | assert isinstance(bg, (int, float))
131 | assert (input_mask >= 0).all() and (input_mask <= 1).all()
132 | input_mask = input_mask.expand_as(input_image)
133 | bg = bg * input_image.new_ones(input_image.size())
134 | output = input_mask * input_image + (1 - input_mask) * bg
135 | return output
136 |
137 | @classmethod
138 | def postprocess(cls, tensor, bg=1.0, input_mask=None):
139 | scaled = tensor / cls.scale_25d
140 | if input_mask is not None:
141 | return cls.mask(scaled, input_mask, bg=bg)
142 | return scaled
143 |
144 | @staticmethod
145 | def to_abs_depth(rel_depth, depth_minmax):
146 | bmin = depth_minmax[:, 0]
147 | bmax = depth_minmax[:, 1]
148 | depth_min = bmin.view(-1, 1, 1, 1)
149 | depth_max = bmax.view(-1, 1, 1, 1)
150 | abs_depth = rel_depth * (depth_max - depth_min + 1e-4) + depth_min
151 | return abs_depth
152 |
153 | def proj_depth(self, abs_depth):
154 | proj_depth = self.cam_bp(abs_depth)
155 | return self.cam_bp.shift_tdf(proj_depth)
156 |
157 |
158 | class Camera_back_projection_layer(nn.Module):
159 | def __init__(self, res):
160 | super(Camera_back_projection_layer, self).__init__()
161 | self.res = res
162 |
163 | def forward(self, depth_t, fl=784.4645406, cam_dist=2.2):
164 | # print(cam_dist)
165 | n = depth_t.size(0)
166 | if isinstance(fl, float):
167 | fl_v = fl
168 | fl = torch.FloatTensor(n, 1).cuda()
169 | fl.fill_(fl_v)
170 | if isinstance(cam_dist, float):
171 | cmd_v = cam_dist
172 | cam_dist = torch.FloatTensor(n, 1).cuda()
173 | cam_dist.fill_(cmd_v)
174 | return CameraBackProjection.apply(depth_t, fl, cam_dist, self.res)
175 |
176 | @staticmethod
177 | def shift_tdf(input_tdf, res=128):
178 | out_tdf = 1 - res * input_tdf
179 | return out_tdf
180 |
--------------------------------------------------------------------------------
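
The two small transforms at the end of marrnetbase.py can be written out as a self-contained sketch (shapes and values assumed): to_abs_depth rescales a relative depth map in [0, 1] into the per-sample [min, max] range, and shift_tdf maps a projected TDF to 1 - res * tdf.

import torch

rel_depth = torch.rand(2, 1, 256, 256)                  # relative depth in [0, 1]
depth_minmax = torch.tensor([[0.8, 1.2], [0.9, 1.5]])   # per-sample (min, max), made up

bmin = depth_minmax[:, 0].view(-1, 1, 1, 1)
bmax = depth_minmax[:, 1].view(-1, 1, 1, 1)
abs_depth = rel_depth * (bmax - bmin + 1e-4) + bmin     # same as to_abs_depth

tdf = torch.rand(2, 1, 128, 128, 128) / 128             # hypothetical projected TDF
shifted = 1 - 128 * tdf                                 # same as shift_tdf with res=128
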
/models/shapehd.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from util import util_img
7 | from .wgangp import D
8 | from .marrnet2 import Net as Marrnet2, Model as Marrnet2_model
9 | from .marrnet1 import Model as Marrnet1_model
10 |
11 |
12 | class Model(Marrnet2_model):
13 | @classmethod
14 | def add_arguments(cls, parser):
15 | parser.add_argument(
16 | '--canon_sup',
17 | action='store_true',
18 | help="Use canonical-pose voxels as supervision"
19 | )
20 | parser.add_argument(
21 | '--marrnet2',
22 | type=str, default=None,
23 | help="Path to pretrained MarrNet-2 (to be finetuned)"
24 | )
25 | parser.add_argument(
26 | '--gan',
27 | type=str, default=None,
28 | help="Path to pretrained WGANGP"
29 | )
30 | parser.add_argument(
31 | '--w_gan_loss',
32 | type=float, default=0,
33 | help="Weight for perceptual loss relative to supervised loss"
34 | )
35 | return parser, set()
36 |
37 | def __init__(self, opt, logger):
38 | super().__init__(opt, logger)
39 | assert opt.canon_sup, "ShapeHD uses canonical-pose voxels"
40 | self.net = Net(opt.marrnet2, opt.gan)
41 | self._nets = [self.net]
42 | self.optimizer = self.adam(
43 | self.net.marrnet2.parameters(),
44 | lr=opt.lr,
45 | **self.optim_params
46 | ) # just finetune MarrNet-2
47 | self._optimizers[-1] = self.optimizer
48 | self._metrics += ['sup', 'gan']
49 | self.init_vars(add_path=True)
50 | assert opt.w_gan_loss >= 0
51 |
52 | def __str__(self):
53 | return "Finetuning 3D estimator of ShapeHD with GAN loss"
54 |
55 | def pack_output(self, pred, batch, add_gt=True):
56 | out = {}
57 | out['rgb_path'] = batch['rgb_path']
58 | out['pred_voxel_noft'] = pred['voxel_noft'].detach().cpu().numpy()
59 | out['pred_voxel'] = pred['voxel'].detach().cpu().numpy()
60 | if add_gt:
61 | out['gt_voxel'] = batch[self.voxel_key].numpy()
62 | out['normal_path'] = batch['normal_path']
63 | out['depth_path'] = batch['depth_path']
64 | out['silhou_path'] = batch['silhou_path']
65 | return out
66 |
67 | def compute_loss(self, pred):
68 | loss_sup = self.criterion(
69 | pred['voxel'], # will be sigmoid'ed
70 | getattr(self._gt, self.voxel_key)
71 | )
72 | loss_gan = -pred['is_real'].mean() # negate to maximize
73 | loss_gan *= self.opt.w_gan_loss
74 | loss = loss_sup + loss_gan
75 | loss_data = {}
76 | loss_data['sup'] = loss_sup.item()
77 | loss_data['gan'] = loss_gan.item()
78 | loss_data['loss'] = loss.item()
79 | return loss, loss_data
80 |
81 |
82 | class Net(nn.Module):
83 | """
84 | 3D Estimator D of GAN
85 | 2.5D --------> 3D --------> real/fake
86 | finetuned fixed
87 | """
88 |
89 | def __init__(self, marrnet2_path=None, gan_path=None):
90 | super().__init__()
91 | # Init MarrNet-2 and load weights
92 | self.marrnet2 = Marrnet2(4)
93 | self.marrnet2_noft = Marrnet2(4)
94 | if marrnet2_path:
95 | state_dicts = torch.load(marrnet2_path)
96 | state_dict = state_dicts['nets'][0]
97 | self.marrnet2.load_state_dict(state_dict)
98 | self.marrnet2_noft.load_state_dict(state_dict)
99 | # Init discriminator and load weights
100 | self.d = D()
101 | if gan_path:
102 | state_dicts = torch.load(gan_path)
103 | self.d.load_state_dict(state_dicts['nets'][1])
104 | # Fix D, but finetune MarrNet-2
105 | for p in self.d.parameters():
106 | p.requires_grad = False
107 | for p in self.marrnet2_noft.parameters():
108 | p.requires_grad = False
109 | for p in self.marrnet2.parameters():
110 | p.requires_grad = True
111 | self.sigmoid = nn.Sigmoid()
112 |
113 | def forward(self, input_struct):
114 | pred = {}
115 | pred['voxel_noft'] = self.marrnet2_noft(input_struct) # unfinetuned
116 | pred['voxel'] = self.marrnet2(input_struct)
117 | pred['is_real'] = self.d(self.sigmoid(pred['voxel']))
118 | return pred
119 |
120 |
121 | class Model_test(Model):
122 | @classmethod
123 | def add_arguments(cls, parser):
124 | parser, unique_params = Model.add_arguments(parser)
125 | parser.add_argument(
126 | '--marrnet1_file',
127 | type=str, required=True,
128 | help="Path to pretrained MarrNet-1"
129 | )
130 | return parser, unique_params
131 |
132 | def __init__(self, opt, logger):
133 | opt.canon_sup = True # dummy, for network init only
134 | super().__init__(opt, logger)
135 | self.requires = ['rgb', 'mask'] # mask for bbox cropping only
136 | self.input_names = ['rgb']
137 | self.init_vars(add_path=True)
138 | self.output_dir = opt.output_dir
139 | # Load MarrNet-2 and D (though unused at test time)
140 | self.load_state_dict(opt.net_file, load_optimizer='auto')
141 | # Load MarrNet-1 whose outputs are inputs to D-tuned MarrNet-2
142 | opt.pred_depth_minmax = True # dummy
143 | self.marrnet1 = Marrnet1_model(opt, logger)
144 | self.marrnet1.load_state_dict(opt.marrnet1_file)
145 | self._nets.append(self.marrnet1.net)
146 |
147 | def __str__(self):
148 | return "Testing ShapeHD"
149 |
150 | @classmethod
151 | def preprocess_wrapper(cls, in_dict):
152 | silhou_thres = 0.95
153 | in_size = 480
154 | pad = 85
155 | im = in_dict['rgb']
156 | mask = in_dict['silhou']
157 | bbox = util_img.get_bbox(mask, th=silhou_thres)
158 | im_crop = util_img.crop(im, bbox, in_size, pad, pad_zero=False)
159 | in_dict['rgb'] = im_crop
160 | del in_dict['silhou'] # just for cropping -- done its job
161 | # Now the image is just like those we rendered
162 | out_dict = cls.preprocess(in_dict, mode='test')
163 | return out_dict
164 |
165 | def test_on_batch(self, batch_i, batch):
166 | outdir = join(self.output_dir, 'batch%04d' % batch_i)
167 | makedirs(outdir, exist_ok=True)
168 | # Forward MarrNet-1
169 | pred1 = self.marrnet1.predict(batch, load_gt=False, no_grad=True)
170 | # Forward MarrNet-2
171 | for net_name in ('marrnet2', 'marrnet2_noft'):
172 | net = getattr(self.net, net_name)
173 | net.silhou_thres = self.pred_silhou_thres * self.scale_25d
174 | self.input_names = ['depth', 'normal', 'silhou']
175 | pred2 = self.predict(pred1, load_gt=False, no_grad=True)
176 | # Pack, visualize, and save outputs
177 | output = self.pack_output(pred1, pred2, batch)
178 | self.visualizer.visualize(output, batch_i, outdir)
179 | np.savez(outdir + '.npz', **output)
180 |
181 | def pack_output(self, pred1, pred2, batch):
182 | out = {}
183 | # MarrNet-1 outputs
184 | pred_normal = pred1['normal'].detach().cpu()
185 | pred_silhou = pred1['silhou'].detach().cpu()
186 | pred_depth = pred1['depth'].detach().cpu()
187 | out['rgb_path'] = batch['rgb_path']
188 | out['rgb'] = util_img.denormalize_colors(batch['rgb'].detach().numpy())
189 | pred_silhou = self.postprocess(pred_silhou)
190 | pred_silhou = torch.clamp(pred_silhou, 0, 1)
191 | pred_silhou[pred_silhou < 0] = 0
192 | out['pred_silhou'] = pred_silhou.numpy()
193 | out['pred_normal'] = self.postprocess(
194 | pred_normal, bg=1.0, input_mask=pred_silhou
195 | ).numpy()
196 | out['pred_depth'] = self.postprocess(
197 | pred_depth, bg=0.0, input_mask=pred_silhou
198 | ).numpy()
199 | # D-tuned MarrNet-2 outputs
200 | out['pred_voxel'] = pred2['voxel'].detach().cpu().numpy()
201 | out['pred_voxel_noft'] = pred2['voxel_noft'].detach().cpu().numpy()
202 | return out
203 |
--------------------------------------------------------------------------------
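
The ShapeHD finetuning objective in compute_loss above boils down to a supervised voxel loss plus a weighted, negated critic score. A hedged sketch with made-up tensors (the real model uses 128^3 voxel grids and the pretrained discriminator D):

import torch
import torch.nn as nn

pred_voxel = torch.randn(2, 1, 32, 32, 32)                   # logits from finetuned MarrNet-2 (grid shrunk here)
gt_voxel = torch.randint(0, 2, (2, 1, 32, 32, 32)).float()
is_real = torch.randn(2)                                     # critic scores from the fixed D
w_gan_loss = 1e-3                                            # value used in scripts/finetune_shapehd.sh

loss_sup = nn.BCEWithLogitsLoss()(pred_voxel, gt_voxel)
loss_gan = -is_real.mean() * w_gan_loss                      # negated so that a higher realism score lowers the loss
loss = loss_sup + loss_gan
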
/models/wgangp.py:
--------------------------------------------------------------------------------
1 | from os import makedirs
2 | from os.path import join
3 | from time import time
4 | import numpy as np
5 | import torch
6 | from networks.networks import VoxelGenerator, VoxelDiscriminator
7 | from .netinterface import NetInterface
8 |
9 |
10 | class Model(NetInterface):
11 | @classmethod
12 | def add_arguments(cls, parser):
13 | parser.add_argument(
14 | '--canon_voxel',
15 | action='store_true',
16 | help="Generate/discriminate canonical-pose voxels"
17 | )
18 | parser.add_argument(
19 | '--wgangp_lambda',
20 | type=float,
21 | default=10,
22 | help="WGANGP gradient penalty coefficient"
23 | )
24 | parser.add_argument(
25 | '--wgangp_norm',
26 | type=float,
27 | default=1,
28 | help="WGANGP gradient penalty norm"
29 | )
30 | parser.add_argument(
31 | '--gan_d_iter',
32 | type=int,
33 | default=1,
34 | help="# iterations D is trained per G's iteration"
35 | )
36 | return parser, set()
37 |
38 | def __init__(self, opt, logger):
39 | super().__init__(opt, logger)
40 | assert opt.canon_voxel, "GAN requires canonical-pose voxels to work"
41 | self.requires = ['voxel_canon']
42 | self.nz = 200
43 | self.net_g = G(self.nz)
44 | self.net_d = D()
45 | self._nets = [self.net_g, self.net_d]
46 | # Optimizers
47 | self.optim_params = dict()
48 | self.optim_params['betas'] = (opt.adam_beta1, opt.adam_beta2)
49 | self.optimizer_g = self.adam(
50 | self.net_g.parameters(),
51 | lr=opt.lr,
52 | **self.optim_params
53 | )
54 | self.optimizer_d = self.adam(
55 | self.net_d.parameters(),
56 | lr=opt.lr,
57 | **self.optim_params
58 | )
59 | self._optimizers = [self.optimizer_g, self.optimizer_d]
60 | #
61 | self.opt = opt
62 | self.preprocess = None
63 | self._metrics = ['err_d_real', 'err_d_fake', 'err_d_gp', 'err_d', 'err_g', 'loss']
64 | if opt.log_time:
65 | self._metrics += ['t_d_real', 't_d_fake', 't_d_grad', 't_g']
66 | self.input_names = ['voxel_canon']
67 | self.aux_names = ['one', 'neg_one']
68 | self.init_vars(add_path=True)
69 | self.init_weight(self.net_d)
70 | self.init_weight(self.net_g)
71 | self._last_err_g = None
72 |
73 | def __str__(self):
74 | s = "3D-WGANGP"
75 | return s
76 |
77 | def _train_on_batch(self, epoch, batch_idx, batch):
78 | net_d, net_g = self.net_d, self.net_g
79 | opt_d, opt_g = self.optimizer_d, self.optimizer_g
80 | one = self._aux.one
81 | neg_one = self._aux.neg_one
82 | real = batch['voxel_canon'].cuda()
83 | batch_size = real.shape[0]
84 | batch_log = {'size': batch_size}
85 |
86 | # Train D ...
87 | net_d.zero_grad()
88 | for p in net_d.parameters():
89 | p.requires_grad = True
90 | for p in net_g.parameters():
91 | p.requires_grad = False
92 | # with real
93 | t0 = time()
94 | err_d_real = self.net_d(real).mean()
95 | err_d_real.backward(neg_one)
96 | batch_log['err_d_real'] = -err_d_real.item()
97 | d_real_t = time() - t0
98 | # with fake
99 | t0 = time()
100 | with torch.no_grad():
101 | _, fake = self.net_g(batch_size)
102 | err_d_fake = self.net_d(fake).mean()
103 | err_d_fake.backward(one)
104 | batch_log['err_d_fake'] = err_d_fake.item()
105 | d_fake_t = time() - t0
106 | # with grad penalty
107 | t0 = time()
108 | if self.opt.wgangp_lambda > 0:
109 | grad_penalty = self.calc_grad_penalty(real, fake)
110 | grad_penalty.backward()
111 | batch_log['err_d_gp'] = grad_penalty.item()
112 | else:
113 | batch_log['err_d_gp'] = 0
114 | batch_log['err_d'] = batch_log['err_d_fake'] + batch_log['err_d_real'] \
115 | + batch_log['err_d_gp']
116 | d_grad_t = time() - t0
117 | opt_d.step()
118 |
119 | # Train G
120 | t0 = time()
121 | for p in net_d.parameters():
122 | p.requires_grad = False
123 | for p in net_g.parameters():
124 | p.requires_grad = True
125 | net_g.zero_grad()
126 | if batch_idx % self.opt.gan_d_iter == 0:
127 | _, gen = self.net_g(batch_size)
128 | err_g = self.net_d(gen).mean()
129 | err_g.backward(neg_one)
130 | opt_g.step()
131 | batch_log['err_g'] = -err_g.item()
132 | self._last_err_g = batch_log['err_g']
133 | else:
134 | batch_log['err_g'] = self._last_err_g
135 | g_t = time() - t0
136 |
137 | if self.opt.log_time:
138 | batch_log['t_d_real'] = d_real_t
139 | batch_log['t_d_fake'] = d_fake_t
140 | batch_log['t_d_grad'] = d_grad_t
141 | batch_log['t_g'] = g_t
142 | return batch_log
143 |
144 | def calc_grad_penalty(self, real, fake):
145 | alpha = torch.rand(real.shape[0], 1)
146 | alpha = alpha.expand(
147 | real.shape[0], real.nelement() // real.shape[0]
148 | ).contiguous().view(*real.shape).cuda()
149 | inter = alpha * real + (1 - alpha) * fake
150 | inter.requires_grad = True
151 | err_d_inter = self.net_d(inter)
152 | grads = torch.autograd.grad(
153 | outputs=err_d_inter,
154 | inputs=inter,
155 | grad_outputs=torch.ones(err_d_inter.size()).cuda(),
156 | create_graph=True,
157 | retain_graph=True,
158 | only_inputs=True
159 | )[0]
160 | grads = grads.view(grads.size(0), -1)
161 | grad_penalty = (
162 | ((grads + 1e-16).norm(2, dim=1) - self.opt.wgangp_norm) ** 2
163 | ).mean() * self.opt.wgangp_lambda
164 | return grad_penalty
165 |
166 | def _vali_on_batch(self, epoch, batch_idx, batch):
167 | batch_size = batch['voxel_canon'].shape[0]
168 | batch_log = {'size': batch_size}
169 | with torch.no_grad():
170 | noise, gen = self.net_g(batch_size)
171 | disc = self.net_d(gen)
172 | batch_log['loss'] = -disc.mean().item()
173 | # Save and visualize
174 | if np.mod(epoch, self.opt.vis_every_train) == 0:
175 | if batch_idx < self.opt.vis_batches_train:
176 | outdir = join(self.full_logdir, 'epoch%04d_vali' % epoch)
177 | makedirs(outdir, exist_ok=True)
178 | output = self.pack_output(noise, gen, disc)
179 | self.visualizer.visualize(output, batch_idx, outdir)
180 | np.savez(join(outdir, 'batch%04d' % batch_idx), **output)
181 | return batch_log
182 |
183 | @staticmethod
184 | def pack_output(noise, gen, disc):
185 | out = {
186 | 'noise': noise.cpu().numpy(),
187 | 'gen_voxel': gen.cpu().numpy(),
188 | 'disc': disc.cpu().numpy(),
189 | }
190 | return out
191 |
192 |
193 | class G(VoxelGenerator):
194 | def __init__(self, nz):
195 | super().__init__(nz=nz, nf=64, bias=False, res=128)
196 | self.nz = nz
197 | self.noise = torch.FloatTensor().cuda()
198 |
199 | def forward(self, batch_size):
200 | x = self.noise
201 | x.resize_(batch_size, self.nz, 1, 1, 1).normal_(0, 1)
202 | y = super().forward(x)
203 | return x, y
204 |
205 |
206 | class D(VoxelDiscriminator):
207 | def __init__(self):
208 | super().__init__(nf=64, bias=False, res=128)
209 |
210 | def forward(self, x):
211 | if x.dim() == 4:
212 | x.unsqueeze_(1)
213 | y = super().forward(x)
214 | return y
215 |
--------------------------------------------------------------------------------
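
The gradient penalty in calc_grad_penalty above is the standard WGAN-GP term, lambda * E[(||grad_x_hat D(x_hat)||_2 - norm)^2] with x_hat interpolated between real and fake samples. A CPU-only sketch with a stand-in critic (the real model uses CUDA tensors and the D network):

import torch

real = torch.rand(2, 1, 64, 64, 64)
fake = torch.rand(2, 1, 64, 64, 64)
net_d = lambda v: v.view(v.shape[0], -1).sum(dim=1)   # stand-in critic, not the real D

alpha = torch.rand(real.shape[0], 1, 1, 1, 1)
inter = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
score = net_d(inter)
grads = torch.autograd.grad(score, inter, grad_outputs=torch.ones_like(score),
                            create_graph=True)[0]
grads = grads.view(grads.size(0), -1)
grad_penalty = 10 * ((grads.norm(2, dim=1) - 1) ** 2).mean()   # wgangp_lambda=10, wgangp_norm=1 defaults
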
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/networks/__init__.py
--------------------------------------------------------------------------------
/networks/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .revresnet import resnet18
3 | from torch import cat
4 |
5 |
6 | class ImageEncoder(nn.Module):
7 | """
8 | Used for 2.5D maps to 3D voxels
9 | """
10 |
11 | def __init__(self, input_nc, encode_dims=200):
12 | super().__init__()
13 | resnet_m = resnet18(pretrained=True)
14 | resnet_m.conv1 = nn.Conv2d(
15 | input_nc, 64, 7, stride=2, padding=3, bias=False
16 | )
17 | resnet_m.avgpool = nn.AdaptiveAvgPool2d(1)
18 | resnet_m.fc = nn.Linear(512, encode_dims)
19 | self.main = nn.Sequential(resnet_m)
20 |
21 | def forward(self, x):
22 | return self.main(x)
23 |
24 |
25 | class VoxelDecoder(nn.Module):
26 | """
27 | Used for 2.5D maps to 3D voxels
28 | """
29 |
30 | def __init__(self, n_dims=200, nf=512):
31 | super().__init__()
32 | self.main = nn.Sequential(
33 | # volconv1
34 | deconv3d_add3(n_dims, nf, True),
35 | batchnorm3d(nf),
36 | relu(),
37 | # volconv2
38 | deconv3d_2x(nf, nf // 2, True),
39 | batchnorm3d(nf // 2),
40 | relu(),
41 | # volconv3
42 | nn.Sequential(), # NOTE: no-op for backward compatibility; consider removing
43 | nn.Sequential(), # NOTE
44 | deconv3d_2x(nf // 2, nf // 4, True),
45 | batchnorm3d(nf // 4),
46 | relu(),
47 | # volconv4
48 | deconv3d_2x(nf // 4, nf // 8, True),
49 | batchnorm3d(nf // 8),
50 | relu(),
51 | # volconv5
52 | deconv3d_2x(nf // 8, nf // 16, True),
53 | batchnorm3d(nf // 16),
54 | relu(),
55 | # volconv6
56 | deconv3d_2x(nf // 16, 1, True)
57 | )
58 |
59 | def forward(self, x):
60 | x_vox = x.view(x.size(0), -1, 1, 1, 1)
61 | return self.main(x_vox)
62 |
63 |
64 | class VoxelGenerator(nn.Module):
65 | def __init__(self, nz=200, nf=64, bias=False, res=128):
66 | super().__init__()
67 | layers = [
68 | # nzx1x1x1
69 | deconv3d_add3(nz, nf * 8, bias),
70 | batchnorm3d(nf * 8),
71 | relu(),
72 | # (nf*8)x4x4x4
73 | deconv3d_2x(nf * 8, nf * 4, bias),
74 | batchnorm3d(nf * 4),
75 | relu(),
76 | # (nf*4)x8x8x8
77 | deconv3d_2x(nf * 4, nf * 2, bias),
78 | batchnorm3d(nf * 2),
79 | relu(),
80 | # (nf*2)x16x16x16
81 | deconv3d_2x(nf * 2, nf, bias),
82 | batchnorm3d(nf),
83 | relu(),
84 | # nfx32x32x32
85 | ]
86 | if res == 64:
87 | layers.append(deconv3d_2x(nf, 1, bias))
88 | # 1x64x64x64
89 | elif res == 128:
90 | layers += [
91 | deconv3d_2x(nf, nf, bias),
92 | batchnorm3d(nf),
93 | relu(),
94 | # nfx64x64x64
95 | deconv3d_2x(nf, 1, bias),
96 | # 1x128x128x128
97 | ]
98 | else:
99 | raise NotImplementedError(res)
100 | layers.append(nn.Sigmoid())
101 | self.main = nn.Sequential(*layers)
102 |
103 | def forward(self, x):
104 | return self.main(x)
105 |
106 |
107 | class VoxelDiscriminator(nn.Module):
108 | def __init__(self, nf=64, bias=False, res=128):
109 | super().__init__()
110 | layers = [
111 | # 1x64x64x64
112 | conv3d_half(1, nf, bias),
113 | relu_leaky(),
114 | # nfx32x32x32
115 | conv3d_half(nf, nf * 2, bias),
116 | # batchnorm3d(nf * 2),
117 | relu_leaky(),
118 | # (nf*2)x16x16x16
119 | conv3d_half(nf * 2, nf * 4, bias),
120 | # batchnorm3d(nf * 4),
121 | relu_leaky(),
122 | # (nf*4)x8x8x8
123 | conv3d_half(nf * 4, nf * 8, bias),
124 | # batchnorm3d(nf * 8),
125 | relu_leaky(),
126 |             # (nf*8)x4x4x4
127 |             conv3d_minus3(nf * 8, 1, bias),
128 |             # 1x1x1x1
129 | ]
130 | if res == 64:
131 | pass
132 | elif res == 128:
133 | extra_layers = [
134 | conv3d_half(nf, nf, bias),
135 | relu_leaky(),
136 | ]
137 | layers = layers[:2] + extra_layers + layers[2:]
138 | else:
139 | raise NotImplementedError(res)
140 | self.main = nn.Sequential(*layers)
141 |
142 | def forward(self, x):
143 | y = self.main(x)
144 | return y.view(-1, 1).squeeze(1)
145 |
146 |
147 | class Unet_3D(nn.Module):
148 | def __init__(self, nf=20, in_channel=2, no_linear=False):
149 | super(Unet_3D, self).__init__()
150 | self.nf = nf
151 | self.enc1 = Conv3d_block(in_channel, nf, 8, 2, 3) # =>64
152 | self.enc2 = Conv3d_block(nf, 2 * nf, 4, 2, 1) # =>32
153 | self.enc3 = Conv3d_block(2 * nf, 4 * nf, 4, 2, 1) # =>16
154 | self.enc4 = Conv3d_block(4 * nf, 8 * nf, 4, 2, 1) # =>8
155 | self.enc5 = Conv3d_block(8 * nf, 16 * nf, 4, 2, 1) # =>4
156 | self.enc6 = Conv3d_block(16 * nf, 32 * nf, 4, 1, 0) # =>1
157 | self.full_conv_block = nn.Sequential(
158 | nn.Linear(32 * nf, 32 * nf),
159 | nn.LeakyReLU(),
160 | )
161 | self.dec1 = Deconv3d_skip(32 * 2 * nf, 16 * nf, 4, 1, 0, 0) # =>4
162 | self.dec2 = Deconv3d_skip(16 * 2 * nf, 8 * nf, 4, 2, 1, 0) # =>8
163 | self.dec3 = Deconv3d_skip(8 * 2 * nf, 4 * nf, 4, 2, 1, 0) # =>16
164 | self.dec4 = Deconv3d_skip(4 * 2 * nf, 2 * nf, 4, 2, 1, 0) # =>32
165 | self.dec5 = Deconv3d_skip(4 * nf, nf, 8, 2, 3, 0) # =>64
166 | self.dec6 = Deconv3d_skip(
167 | 2 * nf, 1, 4, 2, 1, 0, is_activate=False) # =>128
168 | self.no_linear = no_linear
169 |
170 | def forward(self, x):
171 | enc1 = self.enc1(x)
172 | enc2 = self.enc2(enc1)
173 | enc3 = self.enc3(enc2)
174 | enc4 = self.enc4(enc3)
175 | enc5 = self.enc5(enc4)
176 | enc6 = self.enc6(enc5)
177 | # print(enc6.size())
178 | if not self.no_linear:
179 | flatten = enc6.view(enc6.size()[0], self.nf * 32)
180 | bottleneck = self.full_conv_block(flatten)
181 | bottleneck = bottleneck.view(enc6.size()[0], self.nf * 32, 1, 1, 1)
182 | dec1 = self.dec1(bottleneck, enc6)
183 | else:
184 | dec1 = self.dec1(enc6, enc6)
185 | dec2 = self.dec2(dec1, enc5)
186 | dec3 = self.dec3(dec2, enc4)
187 | dec4 = self.dec4(dec3, enc3)
188 | dec5 = self.dec5(dec4, enc2)
189 | out = self.dec6(dec5, enc1)
190 | return out
191 |
192 |
193 | class Conv3d_block(nn.Module):
194 | def __init__(self, ncin, ncout, kernel_size, stride, pad, dropout=False):
195 | super().__init__()
196 | self.net = nn.Sequential(
197 | nn.Conv3d(ncin, ncout, kernel_size, stride, pad),
198 | nn.BatchNorm3d(ncout),
199 | nn.LeakyReLU()
200 | )
201 |
202 | def forward(self, x):
203 | return self.net(x)
204 |
205 |
206 | class Deconv3d_skip(nn.Module):
207 | def __init__(self, ncin, ncout, kernel_size, stride, pad, extra=0, is_activate=True):
208 | super(Deconv3d_skip, self).__init__()
209 | if is_activate:
210 | self.net = nn.Sequential(
211 | nn.ConvTranspose3d(ncin, ncout, kernel_size,
212 | stride, pad, extra),
213 | nn.BatchNorm3d(ncout),
214 | nn.LeakyReLU()
215 | )
216 | else:
217 | self.net = nn.ConvTranspose3d(
218 | ncin, ncout, kernel_size, stride, pad, extra)
219 |
220 | def forward(self, x, skip_in):
221 | y = cat((x, skip_in), dim=1)
222 | return self.net(y)
223 |
224 |
225 | class ViewAsLinear(nn.Module):
226 | @staticmethod
227 | def forward(x):
228 | return x.view(x.shape[0], -1)
229 |
230 |
231 | def relu():
232 | return nn.ReLU(inplace=True)
233 |
234 |
235 | def relu_leaky():
236 | return nn.LeakyReLU(0.2, inplace=True)
237 |
238 |
239 | def maxpool():
240 | return nn.MaxPool2d(3, stride=2, padding=0)
241 |
242 |
243 | def dropout():
244 | return nn.Dropout(p=0.5, inplace=False)
245 |
246 |
247 | def conv3d_half(n_ch_in, n_ch_out, bias):
248 | return nn.Conv3d(
249 | n_ch_in, n_ch_out, 4, stride=2, padding=1, dilation=1, groups=1, bias=bias
250 | )
251 |
252 |
253 | def deconv3d_2x(n_ch_in, n_ch_out, bias):
254 | return nn.ConvTranspose3d(
255 | n_ch_in, n_ch_out, 4, stride=2, padding=1, dilation=1, groups=1, bias=bias
256 | )
257 |
258 |
259 | def conv3d_minus3(n_ch_in, n_ch_out, bias):
260 | return nn.Conv3d(
261 | n_ch_in, n_ch_out, 4, stride=1, padding=0, dilation=1, groups=1, bias=bias
262 | )
263 |
264 |
265 | def deconv3d_add3(n_ch_in, n_ch_out, bias):
266 | return nn.ConvTranspose3d(
267 | n_ch_in, n_ch_out, 4, stride=1, padding=0, dilation=1, groups=1, bias=bias
268 | )
269 |
270 |
271 | def batchnorm1d(n_feat):
272 | return nn.BatchNorm1d(n_feat, eps=1e-5, momentum=0.1, affine=True)
273 |
274 |
275 | def batchnorm(n_feat):
276 | return nn.BatchNorm2d(n_feat, eps=1e-5, momentum=0.1, affine=True)
277 |
278 |
279 | def batchnorm3d(n_feat):
280 | return nn.BatchNorm3d(n_feat, eps=1e-5, momentum=0.1, affine=True)
281 |
282 |
283 | def fc(n_in, n_out):
284 | return nn.Linear(n_in, n_out, bias=True)
285 |
--------------------------------------------------------------------------------
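
As a sanity check on the 3D decoder/generator/discriminator resolutions above (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 through the transposed convolutions, and the reverse for the critic), the following sketch assumes the repository root is on PYTHONPATH:

import torch
from networks.networks import VoxelDecoder, VoxelGenerator, VoxelDiscriminator

dec = VoxelDecoder(n_dims=200, nf=512)
print(dec(torch.randn(2, 200)).shape)                     # (2, 1, 128, 128, 128)

gen = VoxelGenerator(nz=200, nf=64, res=128)
print(gen(torch.randn(2, 200, 1, 1, 1)).shape)            # (2, 1, 128, 128, 128)

disc = VoxelDiscriminator(nf=64, res=128)
print(disc(torch.randn(2, 1, 128, 128, 128)).shape)       # (2,)
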
/networks/revresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | This is an implementation of a U-Net using ResNet-18 blocks
3 | """
4 | import torch
5 | from torch import nn
6 | from torchvision.models import resnet18
7 |
8 |
9 | def deconv3x3(in_planes, out_planes, stride=1, output_padding=0):
10 | return nn.ConvTranspose2d(
11 | in_planes,
12 | out_planes,
13 | kernel_size=3,
14 | stride=stride,
15 | padding=1,
16 | bias=False,
17 | output_padding=output_padding
18 | )
19 |
20 |
21 | class RevBasicBlock(nn.Module):
22 | expansion = 1
23 |
24 | def __init__(self, inplanes, planes, stride=1, upsample=None):
25 | super(RevBasicBlock, self).__init__()
26 | self.deconv1 = deconv3x3(inplanes, planes, stride=1)
27 |         # Note that in ResNet's BasicBlock, the stride is on the first conv;
28 |         # here, as the mirrored block, we put it on the second (last) deconv
29 | self.bn1 = nn.BatchNorm2d(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.deconv2 = deconv3x3(planes, planes, stride=stride,
32 | output_padding=1 if stride > 1 else 0)
33 | self.bn2 = nn.BatchNorm2d(planes)
34 | self.upsample = upsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | residual = x
39 | out = self.deconv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 | out = self.deconv2(out)
43 | out = self.bn2(out)
44 | if self.upsample is not None:
45 | residual = self.upsample(x)
46 | out += residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | class RevBottleneck(nn.Module):
52 | expansion = 4
53 |
54 | def __init__(self, inplanes, planes, stride=1, upsample=None):
55 | super(RevBottleneck, self).__init__()
56 | bottleneck_planes = int(inplanes / 4)
57 | self.deconv1 = nn.ConvTranspose2d(
58 | inplanes,
59 | bottleneck_planes,
60 | kernel_size=1,
61 | bias=False,
62 | stride=1
63 | ) # conv and deconv are the same when kernel size is 1
64 | self.bn1 = nn.BatchNorm2d(bottleneck_planes)
65 | self.deconv2 = nn.ConvTranspose2d(
66 | bottleneck_planes,
67 | bottleneck_planes,
68 | kernel_size=3,
69 | stride=1,
70 | padding=1,
71 | bias=False
72 | )
73 | self.bn2 = nn.BatchNorm2d(bottleneck_planes)
74 | self.deconv3 = nn.ConvTranspose2d(
75 | bottleneck_planes,
76 | planes,
77 | kernel_size=1,
78 | bias=False,
79 | stride=stride,
80 |             output_padding=1 if stride > 1 else 0
81 | )
82 | self.bn3 = nn.BatchNorm2d(planes)
83 | self.relu = nn.ReLU(inplace=True)
84 | self.upsample = upsample
85 | self.stride = stride
86 |
87 | def forward(self, x):
88 | residual = x
89 | out = self.deconv1(x)
90 | out = self.bn1(out)
91 | out = self.relu(out)
92 | out = self.deconv2(out)
93 | out = self.bn2(out)
94 | out = self.relu(out)
95 | out = self.deconv3(out)
96 | out = self.bn3(out)
97 | if self.upsample is not None:
98 | residual = self.upsample(x)
99 | out += residual
100 | out = self.relu(out)
101 | return out
102 |
103 |
104 | class RevResNet(nn.Module):
105 | def __init__(self, block, layers, planes, inplanes=None, out_planes=5):
106 | """
107 | planes: # output channels for each block
108 | inplanes: # input channels for the input at each layer
109 | If missing, it will be inferred.
110 | """
111 | if inplanes is None:
112 | inplanes = [512]
113 | self.inplanes = inplanes[0]
114 | super(RevResNet, self).__init__()
115 | inplanes_after_blocks = inplanes[4] if len(inplanes) > 4 else planes[3]
116 | self.deconv1 = nn.ConvTranspose2d(
117 | inplanes_after_blocks,
118 | planes[3],
119 | kernel_size=3,
120 | stride=2,
121 | padding=1,
122 | output_padding=1
123 | )
124 | self.deconv2 = nn.ConvTranspose2d(
125 | planes[3],
126 | out_planes,
127 | kernel_size=7,
128 | stride=2,
129 | padding=3,
130 | bias=False,
131 | output_padding=1
132 | )
133 | self.bn1 = nn.BatchNorm2d(planes[3])
134 | self.relu = nn.ReLU(inplace=True)
135 | self.layer1 = self._make_layer(block, planes[0], layers[0], stride=2)
136 | if len(inplanes) > 1:
137 | self.inplanes = inplanes[1]
138 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2)
139 | if len(inplanes) > 2:
140 | self.inplanes = inplanes[2]
141 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2)
142 | if len(inplanes) > 3:
143 | self.inplanes = inplanes[3]
144 | self.layer4 = self._make_layer(block, planes[3], layers[3])
145 |
146 | def _make_layer(self, block, planes, blocks, stride=1):
147 | upsample = None
148 | if stride != 1 or self.inplanes != planes:
149 | upsample = nn.Sequential(
150 | nn.ConvTranspose2d(
151 | self.inplanes,
152 | planes,
153 | kernel_size=1,
154 | stride=stride,
155 | bias=False,
156 | output_padding=1 if stride > 1 else 0
157 | ),
158 | nn.BatchNorm2d(planes),
159 | )
160 | layers = []
161 | layers.append(block(self.inplanes, planes, stride, upsample))
162 | self.inplanes = planes
163 | for _ in range(1, blocks):
164 | layers.append(block(self.inplanes, planes))
165 | return nn.Sequential(*layers)
166 |
167 | def forward(self, x):
168 | x = self.layer1(x)
169 | x = self.layer2(x)
170 | x = self.layer3(x)
171 | x = self.layer4(x)
172 | x = self.deconv1(x)
173 | x = self.bn1(x)
174 | x = self.relu(x)
175 | x = self.deconv2(x)
176 | return x
177 |
178 |
179 | def revresnet18(**kwargs):
180 | model = RevResNet(
181 | RevBasicBlock,
182 | [2, 2, 2, 2],
183 | [512, 256, 128, 64],
184 | **kwargs
185 | )
186 | return model
187 |
188 |
189 | def revuresnet18(**kwargs):
190 | """
191 | Reverse ResNet-18 compatible with the U-Net setting
192 | """
193 | model = RevResNet(
194 | RevBasicBlock,
195 | [2, 2, 2, 2],
196 | [256, 128, 64, 64],
197 | inplanes=[512, 512, 256, 128, 128],
198 | **kwargs
199 | )
200 | return model
201 |
202 |
203 | def _num_parameters(net):
204 | return sum([
205 | x.numel() for x in list(net.parameters())
206 | ])
207 |
208 |
209 | def main():
210 | net = resnet18()
211 | revnet = revresnet18()
212 | net.avgpool = nn.AvgPool2d(kernel_size=8)
213 | for name, mod in net.named_children():
214 | mod.__name = name
215 | mod.register_forward_hook(
216 | lambda mod, input, output: print(mod.__name, output.shape)
217 | )
218 | for name, mod in revnet.named_children():
219 | mod.__name = name
220 | mod.register_forward_hook(
221 | lambda mod, input, output: print(mod.__name, output.shape)
222 | )
223 | # print(net)
224 | print('resnet', _num_parameters(net))
225 | net(torch.zeros(2, 3, 256, 256))
226 | print('')
227 | print('revresnet', _num_parameters(revnet))
228 | # print(revnet)
229 | revnet(torch.zeros(2, 512, 8, 8))
230 | print('')
231 | revunet = RevResNet(RevBasicBlock, [2, 2, 2, 2], [512, 512, 256, 128])
232 | print('revunet', _num_parameters(revunet))
233 |
234 |
235 | if __name__ == '__main__':
236 | main()
237 |
--------------------------------------------------------------------------------
/networks/uresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from networks.revresnet import revuresnet18, resnet18
4 |
5 |
6 | class Net(nn.Module):
7 | """
8 | Used for RGB to 2.5D maps
9 | """
10 |
11 | def __init__(self, out_planes, layer_names, input_planes=3):
12 | super().__init__()
13 |
14 | # Encoder
15 | module_list = list()
16 | resnet = resnet18(pretrained=True)
17 | in_conv = nn.Conv2d(input_planes, 64, kernel_size=7, stride=2, padding=3,
18 | bias=False)
19 | module_list.append(
20 | nn.Sequential(
21 | resnet.conv1 if input_planes == 3 else in_conv,
22 | resnet.bn1,
23 | resnet.relu,
24 | resnet.maxpool
25 | )
26 | )
27 | module_list.append(resnet.layer1)
28 | module_list.append(resnet.layer2)
29 | module_list.append(resnet.layer3)
30 | module_list.append(resnet.layer4)
31 | self.encoder = nn.ModuleList(module_list)
32 | self.encoder_out = None
33 |
34 | # Decoder
35 | self.decoders = {}
36 | for out_plane, layer_name in zip(out_planes, layer_names):
37 | module_list = list()
38 | revresnet = revuresnet18(out_planes=out_plane)
39 | module_list.append(revresnet.layer1)
40 | module_list.append(revresnet.layer2)
41 | module_list.append(revresnet.layer3)
42 | module_list.append(revresnet.layer4)
43 | module_list.append(
44 | nn.Sequential(
45 | revresnet.deconv1,
46 | revresnet.bn1,
47 | revresnet.relu,
48 | revresnet.deconv2
49 | )
50 | )
51 | module_list = nn.ModuleList(module_list)
52 | setattr(self, 'decoder_' + layer_name, module_list)
53 | self.decoders[layer_name] = module_list
54 |
55 | def forward(self, im):
56 | # Encode
57 | feat = im
58 | feat_maps = list()
59 | for f in self.encoder:
60 | feat = f(feat)
61 | feat_maps.append(feat)
62 | self.encoder_out = feat_maps[-1]
63 | # Decode
64 | outputs = {}
65 | for layer_name, decoder in self.decoders.items():
66 | x = feat_maps[-1]
67 | for idx, f in enumerate(decoder):
68 | x = f(x)
69 | if idx < len(decoder) - 1:
70 | feat_map = feat_maps[-(idx + 2)]
71 | assert feat_map.shape[2:4] == x.shape[2:4]
72 | x = torch.cat((x, feat_map), dim=1)
73 | outputs[layer_name] = x
74 | return outputs
75 |
76 |
77 | class Net_inpaint(nn.Module):
78 | """
79 |     Used for RGB to 2.5D maps; same as Net, except all decoders share a single-channel final deconvolution
80 | """
81 |
82 | def __init__(self, out_planes, layer_names, input_planes=3):
83 | super().__init__()
84 |
85 | # Encoder
86 | module_list = list()
87 | resnet = resnet18(pretrained=True)
88 | in_conv = nn.Conv2d(input_planes, 64, kernel_size=7, stride=2, padding=3,
89 | bias=False)
90 | module_list.append(
91 | nn.Sequential(
92 | resnet.conv1 if input_planes == 3 else in_conv,
93 | resnet.bn1,
94 | resnet.relu,
95 | resnet.maxpool
96 | )
97 | )
98 | module_list.append(resnet.layer1)
99 | module_list.append(resnet.layer2)
100 | module_list.append(resnet.layer3)
101 | module_list.append(resnet.layer4)
102 | self.encoder = nn.ModuleList(module_list)
103 | self.encoder_out = None
104 | self.deconv2 = nn.ConvTranspose2d(64, 1, kernel_size=8, stride=2, padding=3, bias=False, output_padding=0)
105 | # Decoder
106 | self.decoders = {}
107 | for out_plane, layer_name in zip(out_planes, layer_names):
108 | module_list = list()
109 | revresnet = revuresnet18(out_planes=out_plane)
110 | module_list.append(revresnet.layer1)
111 | module_list.append(revresnet.layer2)
112 | module_list.append(revresnet.layer3)
113 | module_list.append(revresnet.layer4)
114 | module_list.append(
115 | nn.Sequential(
116 | revresnet.deconv1,
117 | revresnet.bn1,
118 | revresnet.relu,
119 | self.deconv2
120 | )
121 | )
122 | module_list = nn.ModuleList(module_list)
123 | setattr(self, 'decoder_' + layer_name, module_list)
124 | self.decoders[layer_name] = module_list
125 |
126 | def forward(self, im):
127 | # Encode
128 | feat = im
129 | feat_maps = list()
130 | for f in self.encoder:
131 | feat = f(feat)
132 | feat_maps.append(feat)
133 | self.encoder_out = feat_maps[-1]
134 | # Decode
135 | outputs = {}
136 | for layer_name, decoder in self.decoders.items():
137 | x = feat_maps[-1]
138 | for idx, f in enumerate(decoder):
139 | x = f(x)
140 | if idx < len(decoder) - 1:
141 | feat_map = feat_maps[-(idx + 2)]
142 | assert feat_map.shape[2:4] == x.shape[2:4]
143 | x = torch.cat((x, feat_map), dim=1)
144 | outputs[layer_name] = x
145 | return outputs
146 |
--------------------------------------------------------------------------------
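
A quick end-to-end shape check of Net above, as a hedged sketch: it assumes the repository root is on PYTHONPATH and that torchvision can fetch the pretrained ResNet-18 weights; the head names and channel counts are illustrative, not necessarily those used by the models in this repo.

import torch
from networks.uresnet import Net

net = Net(out_planes=[1, 3, 1], layer_names=['depth', 'normal', 'silhou'])
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))
for name, t in out.items():
    print(name, tuple(t.shape))
# depth (1, 1, 256, 256), normal (1, 3, 256, 256), silhou (1, 1, 256, 256):
# every decoder upsamples back to the input resolution via the U-Net skip connections.
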
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/options/__init__.py
--------------------------------------------------------------------------------
/options/options_test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from datasets import get_dataset
4 | from models import get_model
5 | from options import options_train
6 |
7 |
8 | def add_general_arguments(parser):
9 | parser, _ = options_train.add_general_arguments(parser)
10 |
11 | # Dataset IO
12 | parser.add_argument('--input_rgb', type=str, required=True,
13 | help="Input RGB filename patterns, e.g., '/path/to/images/*_rgb.png'")
14 | parser.add_argument('--input_mask', type=str, required=True,
15 | help=("Corresponding mask filename patterns, e.g., '/path/to/images/*_mask.png'. "
16 | "For MarrNet/ShapeHD, masks are not required, so used only for bbox cropping. "
17 | "For GenRe, masks are input together with RGB"))
18 |
19 | # Network
20 | parser.add_argument('--net_file', type=str, required=True,
21 | help="Path to the trained network")
22 |
23 | # Output
24 | parser.add_argument('--output_dir', type=str, required=True,
25 | help="Output directory")
26 | parser.add_argument('--overwrite', action='store_true',
27 | help="Whether to overwrite the output folder if it exists")
28 |
29 | return parser
30 |
31 |
32 | def parse(add_additional_arguments=None):
33 | parser = argparse.ArgumentParser()
34 | parser = add_general_arguments(parser)
35 | if add_additional_arguments:
36 | parser, _ = add_additional_arguments(parser)
37 | opt_general, _ = parser.parse_known_args()
38 | net_name = opt_general.net
39 | del opt_general
40 | dataset_name = 'test'
41 |
42 | # Add parsers depending on dataset and models
43 | parser, _ = get_dataset(dataset_name).add_arguments(parser)
44 | parser, _ = get_model(net_name, test=True).add_arguments(parser)
45 |
46 | # Manually add '-h' after adding all parser arguments
47 | if '--printhelp' in sys.argv:
48 | sys.argv.append('-h')
49 |
50 | opt = parser.parse_args()
51 | return opt
52 |
--------------------------------------------------------------------------------
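
options_test.py above (like options_train.py below) uses a two-stage parsing pattern: general arguments are read first with parse_known_args() so that the chosen dataset/model can attach its own arguments before the final parse_args(). A minimal, self-contained illustration (the real code dispatches through get_dataset/get_model; the argument values below are made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--net', type=str, required=True)
argv = ['--net', 'shapehd', '--w_gan_loss', '1e-3']

known, _ = parser.parse_known_args(argv)                      # only --net is understood so far
# ...in the real code, get_model(known.net).add_arguments(parser) runs here
parser.add_argument('--w_gan_loss', type=float, default=0)    # stand-in for a model-specific argument
opt = parser.parse_args(argv)
print(opt.net, opt.w_gan_loss)                                # shapehd 0.001
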
/options/options_train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import torch
4 | from util.util_print import str_warning
5 | from datasets import get_dataset
6 | from models import get_model
7 |
8 |
9 | def add_general_arguments(parser):
10 | # Parameters that will NOT be overwritten when resuming
11 | unique_params = {'gpu', 'resume', 'epoch', 'workers', 'batch_size', 'save_net', 'epoch_batches', 'logdir'}
12 |
13 | parser.add_argument('--gpu', default='0', type=str,
14 | help='gpu to use')
15 | parser.add_argument('--manual_seed', type=int, default=None,
16 | help='manual seed for randomness')
17 | parser.add_argument('--resume', type=int, default=0,
18 | help='resume training by loading checkpoint.pt or best.pt. Use 0 for training from scratch, -1 for last and -2 for previous best. Use positive number for a specific epoch. \
19 |                         Most options will be overwritten to resume training with exactly the same environment')
20 | parser.add_argument(
21 | '--suffix', default='', type=str,
22 | help="Suffix for `logdir` that will be formatted with `opt`, e.g., '{classes}_lr{lr}'"
23 | )
24 | parser.add_argument('--epoch', type=int, default=0,
25 | help='number of epochs to train')
26 |
27 | # Dataset IO
28 | parser.add_argument('--dataset', type=str, default=None,
29 | help='dataset to use')
30 | parser.add_argument('--workers', type=int, default=4,
31 | help='number of data loading workers')
32 | parser.add_argument('--classes', default='chair', type=str,
33 | help='class to use')
34 | parser.add_argument('--batch_size', type=int, default=16,
35 | help='training batch size')
36 | parser.add_argument('--epoch_batches', default=None, type=int, help='number of batches used per epoch')
37 | parser.add_argument('--eval_batches', default=None,
38 | type=int, help='max number of batches used for evaluation per epoch')
39 | parser.add_argument('--eval_at_start', action='store_true',
40 | help='run evaluation before starting to train')
41 | parser.add_argument('--log_time', action='store_true', help='adding time log')
42 |
43 | # Network name
44 | parser.add_argument('--net', type=str, required=True,
45 | help='network type to use')
46 |
47 | # Optimizer
48 | parser.add_argument('--optim', type=str, default='adam',
49 | help='optimizer to use')
50 | parser.add_argument('--lr', type=float, default=1e-4,
51 | help='learning rate')
52 | parser.add_argument('--adam_beta1', type=float, default=0.5,
53 | help='beta1 of adam')
54 | parser.add_argument('--adam_beta2', type=float, default=0.9,
55 | help='beta2 of adam')
56 | parser.add_argument('--sgd_momentum', type=float, default=0.9,
57 | help="momentum factor of SGD")
58 | parser.add_argument('--sgd_dampening', type=float, default=0,
59 | help="dampening for momentum of SGD")
60 | parser.add_argument('--wdecay', type=float, default=0.0,
61 | help='weight decay')
62 |
63 | # Logging and visualization
64 | parser.add_argument('--logdir', type=str, default=None,
65 | help='Root directory for logging. Actual dir is [logdir]/[net_classes_dataset]/[expr_id]')
66 | parser.add_argument('--log_batch', action='store_true',
67 | help='Log batch loss')
68 | parser.add_argument('--expr_id', type=int, default=0,
69 |                         help='Experiment index. Non-positive ones are overwritten by default. Use 0 for code test.')
70 | parser.add_argument('--save_net', type=int, default=1,
71 | help='Period of saving network weights')
72 | parser.add_argument('--save_net_opt', action='store_true',
73 | help='Save optimizer state in regular network saving')
74 | parser.add_argument('--vis_every_vali', default=1, type=int,
75 | help="Visualize every N epochs during validation")
76 | parser.add_argument('--vis_every_train', default=1, type=int,
77 | help="Visualize every N epochs during training")
78 | parser.add_argument('--vis_batches_vali', type=int, default=10,
79 | help="# batches to visualize during validation")
80 | parser.add_argument('--vis_batches_train', type=int, default=10,
81 | help="# batches to visualize during training")
82 | parser.add_argument('--tensorboard', action='store_true',
83 | help='Use tensorboard for logging. If enabled, the output log will be at [logdir]/[tensorboard]/[net_classes_dataset]/[expr_id]')
84 | parser.add_argument('--vis_workers', default=4, type=int, help="# workers for the visualizer")
85 | parser.add_argument('--vis_param_f', default=None, type=str,
86 | help="Parameter file read by the visualizer on every batch; defaults to 'visualize/config.json'")
87 |
88 | return parser, unique_params
89 |
90 |
91 | def overwrite(opt, opt_f_old, unique_params):
92 | opt_dict = vars(opt)
93 | opt_dict_old = torch.load(opt_f_old)
94 | for k, v in opt_dict_old.items():
95 | if k in opt_dict:
96 | if (k not in unique_params) and (opt_dict[k] != v):
97 | print(str_warning, "Overwriting %s for resuming training: %s -> %s"
98 | % (k, str(opt_dict[k]), str(v)))
99 | opt_dict[k] = v
100 | else:
101 | print(str_warning, "Ignoring %s, an old option that no longer exists" % k)
102 | opt = argparse.Namespace(**opt_dict)
103 | return opt
104 |
105 |
106 | def parse(add_additional_arguments=None):
107 | parser = argparse.ArgumentParser()
108 | parser, unique_params = add_general_arguments(parser)
109 | if add_additional_arguments is not None:
110 | parser, unique_params_additional = add_additional_arguments(parser)
111 | unique_params = unique_params.union(unique_params_additional)
112 | opt_general, _ = parser.parse_known_args()
113 | dataset_name, net_name = opt_general.dataset, opt_general.net
114 | del opt_general
115 |
116 | # Add parsers depending on dataset and models
117 | parser, unique_params_dataset = get_dataset(dataset_name).add_arguments(parser)
118 | parser, unique_params_model = get_model(net_name).add_arguments(parser)
119 |
120 | # Manually add '-h' after adding all parser arguments
121 | if '--printhelp' in sys.argv:
122 | sys.argv.append('-h')
123 |
124 | opt = parser.parse_args()
125 | unique_params = unique_params.union(unique_params_dataset)
126 | unique_params = unique_params.union(unique_params_model)
127 | return opt, unique_params
128 |
--------------------------------------------------------------------------------
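
The --suffix help above refers to a template that is formatted with the parsed options (the actual expansion happens in train.py, which is not shown in this section). A hedged sketch of what that expansion amounts to:

opt = {'classes': 'chair', 'lr': 1e-3}           # made-up option values
suffix = '{classes}_lr{lr}'
print(suffix.format(**opt))                      # chair_lr0.001
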
/scripts/finetune_marrnet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Finetune MarrNet-2 with MarrNet-1 predictions
4 |
5 | outdir=./output/marrnet
6 | class=drc
7 | marrnet1=/path/to/marrnet1.pt
8 | marrnet2=/path/to/marrnet2.pt
9 |
10 | if [ $# -lt 1 ]; then
11 | echo "Usage: $0 gpu[ ...]"
12 | exit 1
13 | fi
14 | gpu="$1"
15 | shift # shift the remaining arguments
16 |
17 | set -e
18 |
19 | source activate shaperecon
20 |
21 | python train.py \
22 | --net marrnet \
23 | --marrnet1 "$marrnet1" \
24 | --marrnet2 "$marrnet2" \
25 | --dataset shapenet \
26 | --classes "$class" \
27 | --batch_size 4 \
28 | --epoch_batches 2500 \
29 | --eval_batches 5 \
30 | --optim adam \
31 | --lr 1e-3 \
32 | --epoch 1000 \
33 | --vis_batches_vali 10 \
34 | --gpu "$gpu" \
35 | --save_net 10 \
36 | --workers 4 \
37 | --logdir "$outdir" \
38 | --suffix '{classes}' \
39 | --tensorboard \
40 | $*
41 |
42 | source deactivate
43 |
--------------------------------------------------------------------------------
/scripts/finetune_shapehd.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Finetune ShapeHD 3D estimator with GAN losses
4 |
5 | outdir=./output/shapehd
6 | class=drc
7 | marrnet2=/path/to/marrnet2.pt
8 | gan=/path/to/gan.pt
9 |
10 | if [ $# -lt 1 ]; then
11 | echo "Usage: $0 gpu[ ...]"
12 | exit 1
13 | fi
14 | gpu="$1"
15 | shift # shift the remaining arguments
16 |
17 | set -e
18 |
19 | source activate shaperecon
20 |
21 | python train.py \
22 | --net shapehd \
23 | --marrnet2 "$marrnet2" \
24 | --gan "$gan" \
25 | --dataset shapenet \
26 | --classes "$class" \
27 | --canon_sup \
28 | --w_gan_loss 1e-3 \
29 | --batch_size 4 \
30 | --epoch_batches 1000 \
31 | --eval_batches 10 \
32 | --optim adam \
33 | --lr 1e-3 \
34 | --epoch 1000 \
35 | --vis_batches_vali 10 \
36 | --gpu "$gpu" \
37 | --save_net 1 \
38 | --workers 4 \
39 | --logdir "$outdir" \
40 | --suffix '{classes}_w_ganloss{w_gan_loss}' \
41 | --tensorboard \
42 | $*
43 |
44 | source deactivate
45 |
--------------------------------------------------------------------------------
/scripts/test_genre.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Test GenRe
4 |
5 | out_dir="./output/test"
6 | fullmodel=./downloads/models/full_model.pt
7 | rgb_pattern='./downloads/data/test/genre/*_rgb.*'
8 | mask_pattern='./downloads/data/test/genre/*_silhouette.*'
9 |
10 | if [ $# -lt 1 ]; then
11 | echo "Usage: $0 gpu[ ...]"
12 | exit 1
13 | fi
14 | gpu="$1"
15 | shift # shift the remaining arguments
16 |
17 | set -e
18 |
19 | source activate shaperecon
20 |
21 | python 'test.py' \
22 | --net genre_full_model \
23 | --net_file "$fullmodel" \
24 | --input_rgb "$rgb_pattern" \
25 | --input_mask "$mask_pattern" \
26 | --output_dir "$out_dir" \
27 | --suffix '{net}' \
28 | --overwrite \
29 | --workers 0 \
30 | --batch_size 1 \
31 | --vis_workers 4 \
32 | --gpu "$gpu" \
33 | $*
34 |
35 | source deactivate
36 |
--------------------------------------------------------------------------------
/scripts/test_marrnet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Test MarrNet
4 |
5 | out_dir="./output/test"
6 | marrnet=/path/to/marrnet.pt
7 | rgb_pattern='./downloads/data/test/shapehd/*_rgb.*'
8 | mask_pattern='./downloads/data/test/shapehd/*_mask.*'
9 |
10 | if [ $# -lt 1 ]; then
11 | echo "Usage: $0 gpu[ ...]"
12 | exit 1
13 | fi
14 | gpu="$1"
15 | shift # shift the remaining arguments
16 |
17 | set -e
18 |
19 | source activate shaperecon
20 |
21 | python 'test.py' \
22 | --net marrnet \
23 | --net_file "$marrnet" \
24 | --input_rgb "$rgb_pattern" \
25 | --input_mask "$mask_pattern" \
26 | --output_dir "$out_dir" \
27 | --suffix '{net}' \
28 | --overwrite \
29 | --workers 1 \
30 | --batch_size 1 \
31 | --vis_workers 4 \
32 | --gpu "$gpu" \
33 | $*
34 |
35 | source deactivate
36 |
--------------------------------------------------------------------------------
/scripts/test_shapehd.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Test ShapeHD
4 |
5 | out_dir="./output/test"
6 | net1=./downloads/models/marrnet1_with_minmax.pt
7 | net2=./downloads/models/shapehd.pt
8 | rgb_pattern='./downloads/data/test/shapehd/*_rgb.*'
9 | mask_pattern='./downloads/data/test/shapehd/*_mask.*'
10 |
11 | if [ $# -lt 1 ]; then
12 | echo "Usage: $0 gpu[ ...]"
13 | exit 1
14 | fi
15 | gpu="$1"
16 | shift # shift the remaining arguments
17 |
18 | set -e
19 |
20 |
21 | source activate shaperecon
22 |
23 | python 'test.py' \
24 | --net shapehd \
25 | --net_file "$net2" \
26 | --marrnet1_file "$net1" \
27 | --input_rgb "$rgb_pattern" \
28 | --input_mask "$mask_pattern" \
29 | --output_dir "$out_dir" \
30 | --suffix '{net}' \
31 | --overwrite \
32 | --workers 1 \
33 | --batch_size 1 \
34 | --vis_workers 4 \
35 | --gpu "$gpu" \
36 | $*
37 |
38 | source deactivate
39 |
--------------------------------------------------------------------------------
/scripts/train_full_genre.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | outdir=./output/genre
4 | inpaint_path=/path/to/trained/inpaint.pt
5 |
6 | if [ $# -lt 2 ]; then
7 | echo "Usage: $0 gpu class[ ...]"
8 | exit 1
9 | fi
10 | gpu="$1"
11 | class="$2"
12 | shift # shift the remaining arguments
13 | shift
14 |
15 | set -e
16 |
17 | source activate shaperecon
18 |
19 | python train.py \
20 | --net genre_full_model \
21 | --pred_depth_minmax \
22 | --dataset shapenet \
23 | --classes "$class" \
24 | --batch_size 4 \
25 | --epoch_batches 1000 \
26 | --eval_batches 30 \
27 | --log_time \
28 | --optim adam \
29 | --lr 1e-4 \
30 | --epoch 1000 \
31 | --vis_batches_vali 10 \
32 | --gpu "$gpu" \
33 | --save_net 10 \
34 | --workers 4 \
35 | --logdir "$outdir" \
36 | --suffix '{classes}' \
37 | --tensorboard \
38 | --surface_weight 10 \
39 | --inpaint_path "$inpaint_path" \
40 | $*
41 |
42 | source deactivate
43 |
--------------------------------------------------------------------------------
/scripts/train_inpaint.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | outdir=./output/inpaint
4 | net1_path=/path/to/trained/marrnet1.pt
5 |
6 | if [ $# -lt 2 ]; then
7 | echo "Usage: $0 gpu class[ ...]"
8 | exit 1
9 | fi
10 | gpu="$1"
11 | class="$2"
12 | shift # shift the remaining arguments
13 | shift
14 |
15 | set -e
16 |
17 | source activate shaperecon
18 |
19 | python train.py \
20 | --net depth_pred_with_sph_inpaint \
21 | --pred_depth_minmax \
22 | --dataset shapenet \
23 | --classes "$class" \
24 | --batch_size 4 \
25 | --epoch_batches 2000 \
26 | --eval_batches 10 \
27 | --log_time \
28 | --optim adam \
29 | --lr 1e-4 \
30 | --epoch 1000 \
31 | --vis_batches_vali 10 \
32 | --gpu "$gpu" \
33 | --save_net 10 \
34 | --workers 4 \
35 | --logdir "$outdir" \
36 | --suffix '{classes}' \
37 | --tensorboard \
38 | --net1_path "$net1_path" \
39 | $*
40 |
41 | source deactivate
42 |
--------------------------------------------------------------------------------
/scripts/train_marrnet1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | outdir=./output/marrnet1
4 |
5 | if [ $# -lt 2 ]; then
6 | echo "Usage: $0 gpu class[ ...]"
7 | exit 1
8 | fi
9 | gpu="$1"
10 | class="$2"
11 | shift # shift the remaining arguments
12 | shift
13 |
14 | set -e
15 |
16 | source activate shaperecon
17 |
18 | python train.py \
19 | --net marrnet1 \
20 | --pred_depth_minmax \
21 | --dataset shapenet \
22 | --classes "$class" \
23 | --batch_size 4 \
24 | --epoch_batches 2500 \
25 | --eval_batches 5 \
26 | --log_time \
27 | --optim adam \
28 | --lr 1e-3 \
29 | --epoch 1000 \
30 | --vis_batches_vali 10 \
31 | --gpu "$gpu" \
32 | --save_net 10 \
33 | --workers 4 \
34 | --logdir "$outdir" \
35 | --suffix '{classes}' \
36 | --tensorboard \
37 | $*
38 |
39 | source deactivate
40 |
--------------------------------------------------------------------------------
/scripts/train_marrnet2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | outdir=./output/marrnet2
4 |
5 | if [ $# -lt 2 ]; then
6 | echo "Usage: $0 gpu class[ ...]"
7 | exit 1
8 | fi
9 | gpu="$1"
10 | class="$2"
11 | shift # shift the remaining arguments
12 | shift
13 |
14 | set -e
15 |
16 | source activate shaperecon
17 |
18 | python train.py \
19 | --net marrnet2 \
20 | --dataset shapenet \
21 | --classes "$class" \
22 | --canon_sup \
23 | --batch_size 4 \
24 | --epoch_batches 2500 \
25 | --eval_batches 5 \
26 | --optim adam \
27 | --lr 1e-3 \
28 | --epoch 1000 \
29 | --vis_batches_vali 10 \
30 | --gpu "$gpu" \
31 | --save_net 10 \
32 | --workers 4 \
33 | --logdir "$outdir" \
34 | --suffix '{classes}_canon-{canon_sup}' \
35 | --tensorboard \
36 | $*
37 |
38 | source deactivate
39 |
--------------------------------------------------------------------------------
/scripts/train_wgangp.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | outdir=./output/wgangp
4 |
5 | if [ $# -lt 2 ]; then
6 | echo "Usage: $0 gpu class[ ...]"
7 | exit 1
8 | fi
9 | gpu="$1"
10 | class="$2"
11 | shift # shift the remaining arguments
12 | shift
13 |
14 | set -e
15 |
16 | source activate shaperecon
17 |
18 | python train.py \
19 | --net wgangp \
20 | --canon_voxel \
21 | --dataset shapenet \
22 | --classes "$class" \
23 | --batch_size 4 \
24 | --epoch_batches 2500 \
25 | --eval_batches 5 \
26 | --log_time \
27 | --optim adam \
28 | --lr 1e-4 \
29 | --epoch 1000 \
30 | --vis_batches_vali 10 \
31 | --gpu "$gpu" \
32 | --save_net 10 \
33 | --workers 4 \
34 | --logdir "$outdir" \
35 | --suffix '{classes}' \
36 | --tensorboard \
37 | $*
38 |
39 | source deactivate
40 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from shutil import rmtree
4 | from tqdm import tqdm
5 | import torch
6 | from options import options_test
7 | import datasets
8 | import models
9 | from util.util_print import str_error, str_stage, str_verbose
10 | import util.util_loadlib as loadlib
11 | from loggers import loggers
12 |
13 |
14 | print("Testing Pipeline")
15 |
16 | ###################################################
17 |
18 | print(str_stage, "Parsing arguments")
19 | opt = options_test.parse()
20 | opt.full_logdir = None
21 | print(opt)
22 |
23 | ###################################################
24 |
25 | print(str_stage, "Setting device")
26 | if opt.gpu == '-1':
27 | device = torch.device('cpu')
28 | else:
29 | loadlib.set_gpu(opt.gpu)
30 | device = torch.device('cuda')
31 | if opt.manual_seed is not None:
32 | loadlib.set_manual_seed(opt.manual_seed)
33 |
34 | ###################################################
35 |
36 | print(str_stage, "Setting up output directory")
37 | output_dir = opt.output_dir
38 | output_dir += ('_' + opt.suffix.format(**vars(opt))) \
39 | if opt.suffix != '' else ''
40 | opt.output_dir = output_dir
41 |
42 | if os.path.isdir(output_dir):
43 | if opt.overwrite:
44 | rmtree(output_dir)
45 | else:
46 | raise ValueError(str_error +
47 | " %s already exists, but no overwrite flag"
48 | % output_dir)
49 | os.makedirs(output_dir)
50 |
51 | ###################################################
52 |
53 | print(str_stage, "Setting up loggers")
54 | logger_list = [
55 | loggers.TerminateOnNaN(),
56 | ]
57 | logger = loggers.ComposeLogger(logger_list)
58 |
59 | ###################################################
60 |
61 | print(str_stage, "Setting up models")
62 | Model = models.get_model(opt.net, test=True)
63 | model = Model(opt, logger)
64 | model.to(device)
65 | model.eval()
66 | print(model)
67 | print("# model parameters: {:,d}".format(model.num_parameters()))
68 |
69 | ###################################################
70 |
71 | print(str_stage, "Setting up data loaders")
72 | start_time = time.time()
73 | Dataset = datasets.get_dataset('test')
74 | dataset = Dataset(opt, model=model)
75 | dataloader = torch.utils.data.DataLoader(
76 | dataset,
77 | batch_size=opt.batch_size,
78 | num_workers=opt.workers,
79 | pin_memory=True,
80 | drop_last=False,
81 | shuffle=False
82 | )
83 | n_batches = len(dataloader)
84 | dataiter = iter(dataloader)
85 | print(str_verbose, "Time spent in data IO initialization: %.2fs" %
86 | (time.time() - start_time))
87 | print(str_verbose, "# test points: " + str(len(dataset)))
88 | print(str_verbose, "# test batches: " + str(n_batches))
89 |
90 | ###################################################
91 |
92 | print(str_stage, "Testing")
93 | for i in tqdm(range(n_batches)):
94 | batch = next(dataiter)
95 | model.test_on_batch(i, batch)
96 |
--------------------------------------------------------------------------------
/toolbox/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/__init__.py
--------------------------------------------------------------------------------
/toolbox/calc_prob/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from torch.utils.ffi import create_extension
5 |
6 |
7 | this_file = os.path.dirname(os.path.realpath(__file__))
8 | print(this_file)
9 |
10 | extra_compile_args = list()
11 |
12 |
13 | extra_objects = list()
14 | assert(torch.cuda.is_available())
15 | sources = ['calc_prob/src/calc_prob.c']
16 | headers = ['calc_prob/src/calc_prob.h']
17 | defines = [('WITH_CUDA', True)]
18 | with_cuda = True
19 |
20 | extra_objects = ['calc_prob/src/calc_prob_kernel.cu.o']
21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
22 |
23 | ffi_params = {
24 | 'headers': headers,
25 | 'sources': sources,
26 | 'define_macros': defines,
27 | 'relative_to': __file__,
28 | 'with_cuda': with_cuda,
29 | 'extra_objects': extra_objects,
30 | 'include_dirs': [os.path.join(this_file, 'calc_prob/src')],
31 | 'extra_compile_args': extra_compile_args,
32 | }
33 |
34 |
35 | if __name__ == '__main__':
36 | ext = create_extension(
37 | 'calc_prob._ext.calc_prob_lib',
38 | package=False,
39 | **ffi_params)
40 | #from setuptools import setup
41 | # setup()
42 | ext.build()
43 |
44 | # ffi.build()
45 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/calc_prob/calc_prob/__init__.py
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .calc_prob import CalcStopProb
2 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/functions/calc_prob.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.autograd.function import once_differentiable
4 | from .._ext import calc_prob_lib
5 | from cffi import FFI
6 | ffi = FFI()
7 |
8 |
9 | class CalcStopProb(Function):
10 | @staticmethod
11 | def forward(ctx, prob_in):
12 | assert prob_in.dim() == 5
13 | assert prob_in.dtype == torch.float32
14 | assert prob_in.is_cuda
15 | stop_prob = prob_in.new(prob_in.shape)
16 | stop_prob.zero_()
17 | calc_prob_lib.calc_prob_forward(prob_in, stop_prob)
18 | ctx.save_for_backward(prob_in, stop_prob)
19 | return stop_prob
20 |
21 | @staticmethod
22 | @once_differentiable
23 | def backward(ctx, grad_in):
24 | prob_in, stop_prob = ctx.saved_tensors
25 | grad_out = grad_in.new(grad_in.shape)
26 | grad_out.zero_()
27 | stop_prob_weighted = stop_prob * grad_in
28 | calc_prob_lib.calc_prob_backward(prob_in, stop_prob_weighted, grad_out)
29 | return grad_out
30 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/src/calc_prob.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <stdio.h>
3 | #include <math.h>
4 | #include "calc_prob.h"
5 | #include "calc_prob_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | int calc_prob_forward(THCudaTensor* prob_in, THCudaTensor* prob_out){
10 | int success = 0;
11 | success = calc_prob_forward_wrap(state, prob_in, prob_out);
12 | // check for errors
13 | if (!success) {
14 | THError("aborting");
15 | }
16 | return 1;
17 | }
18 | int calc_prob_backward(THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out){
19 | int success = 0;
20 | success = calc_prob_backward_wrap(state, prob_in, stop_prob_weighted, grad_out);
21 | // check for errors
22 | if (!success) {
23 | THError("aborting");
24 | }
25 | return 1;
26 | }
27 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/src/calc_prob.h:
--------------------------------------------------------------------------------
1 | int calc_prob_forward(THCudaTensor* prob_in, THCudaTensor* prob_out);
2 | int calc_prob_backward(THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out);
3 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/calc_prob/src/calc_prob_kernel.h:
--------------------------------------------------------------------------------
1 | #ifdef __cplusplus
2 | extern "C" {
3 | #endif
4 | int calc_prob_forward_wrap(THCState* state, THCudaTensor* prob_in, THCudaTensor* prob_out);
5 | int calc_prob_backward_wrap(THCState* state, THCudaTensor* prob_in, THCudaTensor* stop_prob_weighted, THCudaTensor* grad_out);
6 | #ifdef __cplusplus
7 | }
8 | #endif
9 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/clean.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ANSI color codes
3 | RS="\033[0m" # reset
4 | HC="\033[1m" # hicolor
5 | UL="\033[4m" # underline
6 | INV="\033[7m" # inverse background and foreground
7 | FBLK="\033[30m" # foreground black
8 | FRED="\033[31m" # foreground red
9 | FGRN="\033[32m" # foreground green
10 | FYEL="\033[33m" # foreground yellow
11 | FBLE="\033[34m" # foreground blue
12 | FMAG="\033[35m" # foreground magenta
13 | FCYN="\033[36m" # foreground cyan
14 | FWHT="\033[37m" # foreground white
15 | BBLK="\033[40m" # background black
16 | BRED="\033[41m" # background red
17 | BGRN="\033[42m" # background green
18 | BYEL="\033[43m" # background yellow
19 | BBLE="\033[44m" # background blue
20 | BMAG="\033[45m" # background magenta
21 | BCYN="\033[46m" # background cyan
22 | BWHT="\033[47m" # background white
23 |
24 | function rm_if_exist() {
25 | if [ -f "$1" ]; then
26 | rm "$1";
27 | echo -e "${FGRN}File $1 removed${RS}"
28 | elif [ -d "$1" ]; then
29 | rm -r "$1";
30 | echo -e "${FBLE}Directory $1 removed${RS}"
31 | #else
32 | # echo -e "${FRED}$1 not found${RS}"
33 | fi
34 | }
35 |
36 | rm_if_exist "calc_prob/src/calc_prob_kernel.cu.o"
37 | rm_if_exist "__pycache__"
38 | rm_if_exist "dist"
39 | rm_if_exist "build"
40 | rm_if_exist "pytorch_calc_stop_problility.egg-info"
41 | rm_if_exist ".cache"
42 | rm_if_exist "calc_prob/__pycache__"
43 | rm_if_exist "calc_prob/_ext"
44 | rm_if_exist "calc_prob/functions/__pycache__"
45 | rm_if_exist "calc_prob/modules/__pycache__"
46 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from setuptools import setup, find_packages
5 |
6 | import build
7 |
8 | this_file = os.path.dirname(__file__)
9 |
10 | setup(
11 | name="pytorch_calc_stop_problility",
12 | version="0.1.0",
13 | description="PyTorch extension for calculating ray stop probability",
14 | url="https://bluhbluhbluh",
15 | author="Zhoutong Zhang",
16 | author_email="ztzhang@mit.edu",
17 | # Require cffi.
18 | install_requires=["cffi>=1.0.0"],
19 | setup_requires=["cffi>=1.0.0"],
20 | # Exclude the build files.
21 | packages=find_packages(exclude=["build", "test"]),
22 | # Package where to put the extensions. Has to be a prefix of build.py.
23 | ext_package="",
24 | # Extensions to compile.
25 | cffi_modules=[
26 | os.path.join(this_file, "build.py:ffi")
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------
/toolbox/calc_prob/setup.sh:
--------------------------------------------------------------------------------
1 | echo "Add -gencode to match all the GPU architectures you have."
2 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for the list of supported architectures."
3 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for how to compile for each architecture."
4 |
5 | # GPU architecture short list:
6 | # GTX 650M: 30
7 | # GTX Titan: 35
8 | # GTX Titan Black: 35
9 | # Tesla K40c: 35
10 | # GTX Titan X: 52
11 | # Titan X (Pascal): 61
12 | # GTX 1080: 61
13 | # Titan Xp: 61
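# For example, for a GPU with compute capability 7.0 (e.g., a Titan V or V100),
# append the following flag to the nvcc command below, assuming your CUDA
# toolkit supports sm_70:
#   -gencode arch=compute_70,code=sm_70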
14 |
15 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
16 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())")
17 |
18 | if [ "$HAS_CUDA" == "True" ]; then
19 | if ! type nvcc >/dev/null 2>&1 ; then
20 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. '
21 | exit 1
22 | fi
23 | cd calc_prob/src
24 | HERE=$(pwd -P)
25 | cmd="nvcc -c -o calc_prob_kernel.cu.o calc_prob_kernel.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC -I ${HERE} \
26 | -gencode arch=compute_30,code=sm_30 \
27 | -gencode arch=compute_35,code=sm_35 \
28 | -gencode arch=compute_52,code=sm_52 \
29 | -gencode arch=compute_61,code=sm_61 "
30 | echo "$cmd"
31 | eval "$cmd"
32 | cd ../../
33 | fi
34 | if [ "$1" = "package" ]; then
35 | # for install
36 | python setup.py install
37 | elif [ "$1" = "script" ]; then
38 | # for build
39 | python build.py
40 | else
41 | echo "Shouldn't be here."
42 | fi
43 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from torch.utils.ffi import create_extension
5 |
6 | this_file = os.path.dirname(os.path.realpath(__file__))
7 | print(this_file)
8 |
9 | extra_compile_args = list()
10 |
11 | extra_objects = list()
12 | assert(torch.cuda.is_available())
13 | sources = ['cam_bp/src/back_projection.c']
14 | headers = ['cam_bp/src/back_projection.h']
15 | defines = [('WITH_CUDA', True)]
16 | with_cuda = True
17 |
18 | extra_objects = ['cam_bp/src/back_projection_kernel.cu.o']
19 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
20 |
21 | ffi_params = {
22 | 'headers': headers,
23 | 'sources': sources,
24 | 'define_macros': defines,
25 | 'relative_to': __file__,
26 | 'with_cuda': with_cuda,
27 | 'extra_objects': extra_objects,
28 | 'include_dirs': [os.path.join(this_file, 'cam_bp/src')],
29 | 'extra_compile_args': extra_compile_args,
30 | }
31 |
32 | ffi = create_extension(
33 | 'cam_bp._ext.cam_bp_lib',
34 | package=True,
35 | **ffi_params
36 | )
37 |
38 | if __name__ == '__main__':
39 | ffi = create_extension(
40 | 'cam_bp._ext.cam_bp_lib',
41 | package=False,
42 | **ffi_params)
43 | ffi.build()
44 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/__init__.py
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .cam_back_projection import CameraBackProjection
2 | from .get_surface_mask import get_surface_mask
3 | from .sperical_to_tdf import SphericalBackProjection
4 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/functions/cam_back_projection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.autograd.function import once_differentiable
4 | from .._ext import cam_bp_lib
5 | from cffi import FFI
6 | ffi = FFI()
7 |
8 |
9 | class CameraBackProjection(Function):
10 |
11 | @staticmethod
12 | def forward(ctx, depth_t, fl, cam_dist, res=128):
13 | assert depth_t.dim() == 4
14 | assert fl.dim() == 2 and fl.size(1) == depth_t.size(1)
15 | assert cam_dist.dim() == 2 and cam_dist.size(1) == depth_t.size(1)
16 | assert cam_dist.size(0) == depth_t.size(0)
17 | assert fl.size(0) == depth_t.size(0)
18 | assert depth_t.is_cuda
19 | assert fl.is_cuda
20 | assert cam_dist.is_cuda
21 | in_shape = depth_t.shape
22 | cnt = depth_t.new(in_shape[0], in_shape[1], res, res, res).zero_()
23 | tdf = depth_t.new(in_shape[0], in_shape[1],
24 | res, res, res).zero_() + 1 / res
25 | cam_bp_lib.back_projection_forward(depth_t, cam_dist, fl, tdf, cnt)
26 | # print(cnt)
27 | ctx.save_for_backward(depth_t, fl, cam_dist)
28 | ctx.cnt_forward = cnt
29 | ctx.depth_shape = in_shape
30 | return tdf
31 |
32 | @staticmethod
33 | @once_differentiable
34 | def backward(ctx, grad_output):
35 | assert grad_output.is_cuda
36 | # print(grad_output.type())
37 | depth_t, fl, cam_dist = ctx.saved_tensors
38 | cnt = ctx.cnt_forward
39 | grad_depth = grad_output.new(ctx.depth_shape).zero_()
40 | grad_fl = grad_output.new(
41 | ctx.depth_shape[0], ctx.depth_shape[1]).zero_()
42 | grad_camdist = grad_output.new(
43 | ctx.depth_shape[0], ctx.depth_shape[1]).zero_()
44 | cam_bp_lib.back_projection_backward(
45 | depth_t, fl, cam_dist, cnt, grad_output, grad_depth, grad_camdist, grad_fl)
46 | return grad_depth, grad_fl, grad_camdist, None
47 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/functions/get_surface_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .._ext import cam_bp_lib
4 | from cffi import FFI
5 | ffi = FFI()
6 |
7 |
8 | def get_vox_surface_cnt(depth_t, fl, cam_dist, res=128):
9 | assert depth_t.dim() == 4
10 | assert fl.dim() == 2 and fl.size(1) == depth_t.size(1)
11 | assert cam_dist.dim() == 2 and cam_dist.size(1) == depth_t.size(1)
12 | assert cam_dist.size(0) == depth_t.size(0)
13 | assert fl.size(0) == depth_t.size(0)
14 | assert depth_t.is_cuda
15 | assert fl.is_cuda
16 | assert cam_dist.is_cuda
17 | in_shape = depth_t.shape
18 | cnt = depth_t.new(in_shape[0], in_shape[1], res, res, res).zero_()
19 | tdf = depth_t.new(in_shape[0], in_shape[1], res,
20 | res, res).zero_() + 1 / res
21 | cam_bp_lib.back_projection_forward(depth_t, cam_dist, fl, tdf, cnt)
22 | return cnt
23 |
24 |
25 | def get_surface_mask(depth_t, fl=784.4645406, cam_dist=2.0, res=128):
26 | n = depth_t.size(0)
27 | nc = depth_t.size(1)
28 | if type(fl) == float:
29 | fl_v = fl
30 | fl = torch.FloatTensor(n, nc).cuda()
31 | fl.fill_(fl_v)
32 | if type(cam_dist) == float:
33 | cmd_v = cam_dist
34 | cam_dist = torch.FloatTensor(n, nc).cuda()
35 | cam_dist.fill_(cmd_v)
36 | cnt = get_vox_surface_cnt(depth_t, fl, cam_dist, res)
37 | mask = cnt.new(n, nc, res, res, res).zero_()
38 | cam_bp_lib.get_surface_mask(depth_t, cam_dist, fl, cnt, mask)
39 | surface_vox = torch.clamp(cnt, min=0.0, max=1.0)
40 | return surface_vox, mask
41 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/functions/sperical_to_tdf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from .._ext import cam_bp_lib
6 | from cffi import FFI
7 | ffi = FFI()
8 |
9 |
10 | class SphericalBackProjection(Function):
11 |
12 | @staticmethod
13 | def forward(ctx, spherical, grid, res=128):
14 | assert spherical.dim() == 4
15 | assert grid.dim() == 5
16 | assert spherical.size(0) == grid.size(0)
17 | assert spherical.size(1) == grid.size(1)
18 | assert spherical.size(2) == grid.size(2)
19 | assert spherical.size(3) == grid.size(3)
20 | assert grid.size(4) == 3
21 | assert spherical.is_cuda
22 | assert grid.is_cuda
23 | in_shape = spherical.shape
24 | cnt = spherical.new(in_shape[0], in_shape[1], res, res, res).zero_()
25 | tdf = spherical.new(in_shape[0], in_shape[1],
26 | res, res, res).zero_()
27 | cam_bp_lib.spherical_back_proj_forward(spherical, grid, tdf, cnt)
28 | # print(cnt)
29 | ctx.save_for_backward(spherical.detach(), grid, cnt)
30 | ctx.depth_shape = in_shape
31 | return tdf, cnt
32 |
33 | @staticmethod
34 | @once_differentiable
35 | def backward(ctx, grad_output, grad_phony):
36 | assert grad_output.is_cuda
37 | assert not np.isnan(torch.sum(grad_output.detach()))
38 | spherical, grid, cnt = ctx.saved_tensors
39 | grad_depth = grad_output.new(ctx.depth_shape).zero_()
40 | cam_bp_lib.spherical_back_proj_backward(
41 | spherical, grid, cnt, grad_output, grad_depth)
42 | try:
43 | assert not np.isnan(torch.sum(grad_depth))
44 | except:
45 | import pdb
46 | pdb.set_trace()
47 | return grad_depth, None, None
48 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/modules/Spherical_backproj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from ..functions import SphericalBackProjection
4 | from torch.autograd import Variable
5 |
6 |
7 | class spherical_backprojection(nn.Module):
8 |
9 | def __init__(self, grid, vox_res=128):
10 | super(spherical_backprojection, self).__init__()
11 | self.vox_res = vox_res
12 | self.backprojection_layer = SphericalBackProjection()
13 | assert type(grid) == torch.FloatTensor
14 | self.grid = Variable(grid.cuda())
15 |
16 | def forward(self, spherical):
17 | return self.backprojection_layer(spherical, self.grid, self.vox_res)
18 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/modules/__init__.py
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/modules/camera_backprojection_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from ..functions import CameraBackProjection
3 | import torch
4 |
5 |
6 | class Camera_back_projection_layer(nn.Module):
7 | def __init__(self, res=128):
8 | super(Camera_back_projection_layer, self).__init__()
9 | assert res == 128
10 | self.res = 128
11 |
12 | def forward(self, depth_t, fl=418.3, cam_dist=2.2, shift=True):
13 | n = depth_t.size(0)
14 | if type(fl) == float:
15 | fl_v = fl
16 | fl = torch.FloatTensor(n, 1).cuda()
17 | fl.fill_(fl_v)
18 | if type(cam_dist) == float:
19 | cmd_v = cam_dist
20 | cam_dist = torch.FloatTensor(n, 1).cuda()
21 | cam_dist.fill_(cmd_v)
22 | df = CameraBackProjection.apply(depth_t, fl, cam_dist, self.res)
23 | return self.shift_tdf(df) if shift else df
24 |
25 | @staticmethod
26 | def shift_tdf(input_tdf, res=128):
27 | out_tdf = 1 - res * (input_tdf)
28 | return out_tdf
29 |
30 |
31 | class camera_backprojection(nn.Module):
32 |
33 | def __init__(self, vox_res=128):
34 | super(camera_backprojection, self).__init__()
35 | self.vox_res = vox_res
36 | self.backprojection_layer = CameraBackProjection()
37 |
38 | def forward(self, depth, fl, camdist):
39 | return self.backprojection_layer(depth, fl, camdist, self.vox_res)
40 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/src/_cam_bp_lib.abi3.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/cam_bp/cam_bp/src/_cam_bp_lib.abi3.so
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/src/back_projection.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <stdio.h>
3 | #include <math.h>
4 | #include "back_projection.h"
5 | #include "back_projection_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | int back_projection_forward(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt){
10 | int success = 0;
11 | success = back_projection_forward_wrap(state, depth, camdist, fl, voxel, cnt);
12 | // check for errors
13 | if (!success) {
14 | THError("aborting");
15 | }
16 | return 1;
17 | }
18 |
19 |
20 | int back_projection_backward(THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl){
21 | int success = 0;
22 | success = back_projection_backward_wrap(state, depth, fl, camdist, cnt, grad_in, grad_depth, grad_camdist, grad_fl);
23 | // check for errors
24 | if (!success) {
25 | THError("aborting");
26 | }
27 | return 1;
28 | }
29 |
30 | int get_surface_mask(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask){
31 | int success = 0;
32 | success = get_surface_mask_wrap(state, depth, camdist, fl, cnt, mask);
33 | // check for errors
34 | if (!success) {
35 | THError("aborting");
36 | }
37 | return 1;
38 | }
39 |
40 | int spherical_back_proj_forward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt){
41 | int success = 0;
42 | success = spherical_back_proj_forward_wrap(state, depth, grid_in, voxel, cnt);
43 | // check for errors
44 | if (!success) {
45 | THError("aborting");
46 | }
47 | return 1;
48 | }
49 | int spherical_back_proj_backward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth){
50 | int success = 0;
51 | success = spherical_back_proj_backward_wrap(state, depth, grid_in,cnt,grad_in,grad_depth);
52 | // check for errors
53 | if (!success) {
54 | THError("aborting");
55 | }
56 | return 1;
57 | }
58 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/src/back_projection.h:
--------------------------------------------------------------------------------
1 | int back_projection_forward(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt);
2 | int back_projection_backward(THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl);
3 | int get_surface_mask(THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask);
4 | int spherical_back_proj_forward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt);
5 | int spherical_back_proj_backward(THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth);
6 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/cam_bp/src/back_projection_kernel.h:
--------------------------------------------------------------------------------
1 |
2 | #ifdef __cplusplus
3 | extern "C" {
4 | #endif
5 |
6 | int back_projection_forward_wrap (THCState* state, THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* voxel, THCudaTensor* cnt);
7 | int back_projection_backward_wrap (THCState* state, THCudaTensor* depth, THCudaTensor* fl, THCudaTensor* camdist, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth, THCudaTensor* grad_camdist, THCudaTensor* grad_fl);
8 | int get_surface_mask_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* camdist, THCudaTensor* fl, THCudaTensor* cnt, THCudaTensor* mask);
9 | int spherical_back_proj_forward_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* voxel, THCudaTensor* cnt);
10 | int spherical_back_proj_backward_wrap(THCState* state, THCudaTensor* depth, THCudaTensor* grid_in, THCudaTensor* cnt, THCudaTensor* grad_in, THCudaTensor* grad_depth);
11 | #ifdef __cplusplus
12 | }
13 | #endif
14 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/clean.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ANSI color codes
3 | RS="\033[0m" # reset
4 | HC="\033[1m" # hicolor
5 | UL="\033[4m" # underline
6 | INV="\033[7m" # inverse background and foreground
7 | FBLK="\033[30m" # foreground black
8 | FRED="\033[31m" # foreground red
9 | FGRN="\033[32m" # foreground green
10 | FYEL="\033[33m" # foreground yellow
11 | FBLE="\033[34m" # foreground blue
12 | FMAG="\033[35m" # foreground magenta
13 | FCYN="\033[36m" # foreground cyan
14 | FWHT="\033[37m" # foreground white
15 | BBLK="\033[40m" # background black
16 | BRED="\033[41m" # background red
17 | BGRN="\033[42m" # background green
18 | BYEL="\033[43m" # background yellow
19 | BBLE="\033[44m" # background blue
20 | BMAG="\033[45m" # background magenta
21 | BCYN="\033[46m" # background cyan
22 | BWHT="\033[47m" # background white
23 |
24 | function rm_if_exist() {
25 | if [ -f "$1" ]; then
26 | rm "$1";
27 | echo -e "${FGRN}File $1 removed${RS}"
28 | elif [ -d "$1" ]; then
29 | rm -r "$1";
30 | echo -e "${FBLE}Directory $1 removed${RS}"
31 | else
32 | echo -e "${FRED}$1 not found${RS}"
33 | fi
34 | }
35 |
36 | rm_if_exist "cam_bp/src/back_projection_kernel.cu.o"
37 | rm_if_exist "__pycache__"
38 | rm_if_exist "dist"
39 | rm_if_exist "build"
40 | rm_if_exist "pytorch_camera_back_projection.egg-info"
41 | rm_if_exist ".cache"
42 | rm_if_exist "cam_bp/__pycache__"
43 | rm_if_exist "cam_bp/_ext"
44 | rm_if_exist "cam_bp/functions/__pycache__"
45 | rm_if_exist "cam_bp/modules/__pycache__"
46 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from setuptools import setup, find_packages
5 |
6 | import build
7 |
8 | this_file = os.path.dirname(__file__)
9 |
10 | setup(
11 | name="pytorch_camera_back_projection",
12 | version="0.1.0",
13 | description="PyTorch extension for back-projecting depth maps",
14 | url="https://bluhbluhbluh",
15 | author="Zhoutong Zhang",
16 | author_email="ztzhang@mit.edu",
17 | # Require cffi.
18 | install_requires=["cffi>=1.0.0"],
19 | setup_requires=["cffi>=1.0.0"],
20 | # Exclude the build files.
21 | packages=find_packages(exclude=["build", "test"]),
22 | # Package where to put the extensions. Has to be a prefix of build.py.
23 | ext_package="",
24 | # Extensions to compile.
25 | cffi_modules=[
26 | os.path.join(this_file, "build.py:ffi")
27 | ],
28 | )
29 |
--------------------------------------------------------------------------------
/toolbox/cam_bp/setup.sh:
--------------------------------------------------------------------------------
1 | echo "Add -gencode to match all the GPU architectures you have."
2 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for the list of supported architectures."
3 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for how to compile for each architecture."
4 |
5 | # GPU architecture short list:
6 | # GTX 650M: 30
7 | # GTX Titan: 35
8 | # GTX Titan Black: 35
9 | # Tesla K40c: 35
10 | # GTX Titan X: 52
11 | # Titan X (Pascal): 61
12 | # GTX 1080: 61
13 | # Titan Xp: 61
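# For example, for a GPU with compute capability 7.0 (e.g., a Titan V or V100),
# append the following flag to the nvcc command below, assuming your CUDA
# toolkit supports sm_70:
#   -gencode arch=compute_70,code=sm_70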
14 |
15 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
16 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())")
17 |
18 | if [ "$HAS_CUDA" == "True" ]; then
19 | if ! type nvcc >/dev/null 2>&1 ; then
20 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. '
21 | exit 1
22 | fi
23 | cd cam_bp/src
24 | HERE=$(pwd -P)
25 | cmd="nvcc -c -o back_projection_kernel.cu.o back_projection_kernel.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include -I ${TORCH}/lib/include/THC -I ${HERE} -I ${TORCH}/lib/include\
26 | -gencode arch=compute_30,code=sm_30 \
27 | -gencode arch=compute_35,code=sm_35 \
28 | -gencode arch=compute_52,code=sm_52 \
29 | -gencode arch=compute_61,code=sm_61"
30 | echo "$cmd"
31 | eval "$cmd"
32 | cd ..
33 | fi
34 | cd ..
35 | pwd
36 | if [ "$1" = "package" ]; then
37 | # for install
38 | python setup.py install
39 | elif [ "$1" = "script" ]; then
40 | # for build
41 | python build.py
42 | else
43 | echo "Shouldn't be here."
44 | fi
45 |
--------------------------------------------------------------------------------
/toolbox/nndistance/README.md:
--------------------------------------------------------------------------------
1 | # Chamfer Distance for PyTorch
2 | Modified from [pointGAN](https://github.com/fxia22/pointGAN)
3 |
4 | ## Requirements
5 | Tested on Pytorch 0.3.1
6 | Due to syntax change in Pytorch 0.4.0, this implementation probably won't work on Pytorch 0.4.0
7 |
8 | ## Install
9 | ```bash
10 | ./clean.sh
11 | ./setup.sh script
12 | ```
13 | Note that the code currently only supports building as a script, so you'll need to put this directory under your code's root directory, from which you can import it with `import nndistance`.
14 |
15 | ## Example
16 | Run `test.py` as an example:
17 |
18 | ```bash
19 | cp test.py ..
20 | python test.py
21 | ```
22 |
23 | ## Usage
24 | - The function `nndistance.functions.nndistance(pts1, pts2)` returns two tensors of distances: the distance from each point in `pts1` to its nearest neighbor in point cloud `pts2`, and the distance from each point in `pts2` to its nearest neighbor in point cloud `pts1` (see the usage sketch at the end of this README).
25 | - For convenience, the distance here is defined as `(x1-x2)*(x1-x2) + (y1-y2)*(y1-y2) + (z1-z2)*(z1-z2)`, **without taking the square root**.
26 | - If you want to take the square root, keep in mind that in Pytorch, **the gradient of `sqrt(0)` is `nan`**, so you'll probably want to add a small `eps` before taking sqrt.
27 | - The function `nndistance.functions.nndistance_score(pts1, pts2)` returns one Chamfer-style score per batch element.
28 |
29 |
30 | Internal note: this implementation gives the same result as our previously used TensorFlow implementation.
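
A minimal usage sketch, assuming the extension has been built with `./setup.sh script` and that this directory is importable as `nndistance`; the batch size and point counts below are only illustrative:

```python
import torch
from torch.autograd import Variable

from nndistance.functions import nndistance, nndistance_score

# Two batches of 3-D point clouds with shape (batch, num_points, 3), float32.
# Use CPU tensors instead of .cuda() if the extension was built without CUDA.
pts1 = Variable(torch.rand(1, 1024, 3).cuda())
pts2 = Variable(torch.rand(1, 2048, 3).cuda())

# Squared nearest-neighbor distances in both directions:
# dist1 has shape (1, 1024), dist2 has shape (1, 2048).
dist1, dist2 = nndistance(pts1, pts2)

# Symmetric (squared) Chamfer distance, one value per batch element.
chamfer = dist1.mean(1) + dist2.mean(1)

# Convenience score: mean of sqrt distances in both directions, shape (1,).
score = nndistance_score(pts1, pts2)
```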
--------------------------------------------------------------------------------
/toolbox/nndistance/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.ffi import create_extension
4 |
5 | this_file = os.path.dirname(__file__)
6 |
7 | sources = ['src/my_lib.c']
8 | headers = ['src/my_lib.h']
9 | defines = []
10 | with_cuda = False
11 |
12 | if torch.cuda.is_available():
13 | print('Including CUDA code.')
14 | sources += ['src/my_lib_cuda.c']
15 | headers += ['src/my_lib_cuda.h']
16 | defines += [('WITH_CUDA', None)]
17 | with_cuda = True
18 |
19 | this_file = os.path.dirname(os.path.realpath(__file__))
20 | print(this_file)
21 | extra_objects = ['src/nnd_cuda.cu.o']
22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
23 |
24 | ffi = create_extension(
25 | '_ext.my_lib',
26 | headers=headers,
27 | sources=sources,
28 | define_macros=defines,
29 | relative_to=__file__,
30 | with_cuda=with_cuda,
31 | extra_objects=extra_objects
32 | )
33 |
34 | if __name__ == '__main__':
35 | ffi.build()
36 |
--------------------------------------------------------------------------------
/toolbox/nndistance/clean.sh:
--------------------------------------------------------------------------------
1 | rm -rf _ext
2 | rm -f src/*.o
3 | rm -rf */__pycache__
--------------------------------------------------------------------------------
/toolbox/nndistance/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .nnd import nndistance, nndistance_w_idx, nndistance_score
2 |
--------------------------------------------------------------------------------
/toolbox/nndistance/functions/nnd.py:
--------------------------------------------------------------------------------
1 | # functions/add.py
2 | import torch
3 | from torch.autograd import Function
4 | from torch.autograd.function import once_differentiable
5 | from nndistance._ext import my_lib
6 |
7 |
8 | class NNDFunction(Function):
9 | @staticmethod
10 | def forward(ctx, xyz1, xyz2):
11 | assert xyz1.dim() == 3 and xyz2.dim() == 3
12 | assert xyz1.size(0) == xyz2.size(0)
13 | assert xyz1.size(2) == 3 and xyz2.size(2) == 3
14 | assert xyz1.is_cuda == xyz2.is_cuda
15 | assert xyz1.type().endswith('FloatTensor') and xyz2.type().endswith('FloatTensor'), 'only FloatTensor are supported for NNDistance'
16 | assert xyz1.is_contiguous() and xyz2.is_contiguous() # the CPU and GPU code are not robust and will break if the storage is not contiguous
17 | ctx.is_cuda = xyz1.is_cuda
18 |
19 | batchsize, n, _ = xyz1.size()
20 | _, m, _ = xyz2.size()
21 | dist1 = torch.zeros(batchsize, n)
22 | dist2 = torch.zeros(batchsize, m)
23 |
24 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
25 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
26 |
27 | if not xyz1.is_cuda:
28 | my_lib.nnd_forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
29 | else:
30 | dist1 = dist1.cuda()
31 | dist2 = dist2.cuda()
32 | idx1 = idx1.cuda()
33 | idx2 = idx2.cuda()
34 | my_lib.nnd_forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
35 |
36 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
37 | return dist1, dist2, idx1, idx2
38 |
39 | @staticmethod
40 | @once_differentiable
41 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
42 | """
43 | Note that this function needs gradidx placeholders
44 | """
45 | assert ctx.is_cuda == graddist1.is_cuda and ctx.is_cuda == graddist2.is_cuda
46 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
47 | graddist1 = graddist1.contiguous()
48 | graddist2 = graddist2.contiguous()
49 | assert xyz1.is_contiguous()
50 | assert xyz2.is_contiguous()
51 | assert idx1.is_contiguous()
52 | assert idx2.is_contiguous()
53 | assert graddist1.type().endswith('FloatTensor') and graddist2.type().endswith('FloatTensor'), 'only FloatTensor are supported for NNDistance'
54 |
55 | gradxyz1 = xyz1.new(xyz1.size())
56 | gradxyz2 = xyz1.new(xyz2.size())
57 |
58 | if not graddist1.is_cuda:
59 | my_lib.nnd_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
60 | else:
61 | my_lib.nnd_backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
62 |
63 | return gradxyz1, gradxyz2
64 |
65 |
66 | def nndistance_w_idx(xyz1, xyz2):
67 | xyz1 = xyz1.contiguous()
68 | xyz2 = xyz2.contiguous()
69 | return NNDFunction.apply(xyz1, xyz2)
70 |
71 |
72 | def nndistance(xyz1, xyz2):
73 | if xyz1.size(2) != 3:
74 | xyz1 = xyz1.transpose(1, 2)
75 | if xyz2.size(2) != 3:
76 | xyz2 = xyz2.transpose(1, 2)
77 | xyz1 = xyz1.contiguous()
78 | xyz2 = xyz2.contiguous()
79 | dist1, dist2, _, _ = NNDFunction.apply(xyz1, xyz2)
80 | return dist1, dist2
81 |
82 |
83 | def nndistance_score(xyz1, xyz2, eps=1e-10):
84 | dist1, dist2 = nndistance(xyz1, xyz2)
85 | return torch.sqrt(dist1 + eps).mean(1) + torch.sqrt(dist2 + eps).mean(1)
86 |
--------------------------------------------------------------------------------
/toolbox/nndistance/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/toolbox/nndistance/modules/__init__.py
--------------------------------------------------------------------------------
/toolbox/nndistance/modules/nnd.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 | from nndistance.functions.nnd import nndistance
3 |
4 |
5 | class NNDModule(Module):
6 | def forward(self, input1, input2):
7 | return nndistance(input1, input2)
8 |
--------------------------------------------------------------------------------
/toolbox/nndistance/setup.sh:
--------------------------------------------------------------------------------
1 | if [[ "$#" -ne 1 || ( "$1" != "script") ]]; then
2 | echo "Usage: ./setup.sh mode"
3 | echo "mode: script (package mode is not supported for now)"
4 | echo "package: build and install as a pip package"
5 | echo "script: build and use as a script. Must be present in local directory for import"
6 | exit 1
7 | fi
8 |
9 | echo "Add -gencode to match all the GPU architectures you have."
10 | echo "Check 'https://en.wikipedia.org/wiki/CUDA#GPUs_supported' for the list of supported architectures."
11 | echo "Check 'http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html' for how to compile for each architecture."
12 |
13 | # GPU architecture short list:
14 | # GTX 650M: 30
15 | # GTX Titan: 35
16 | # GTX Titan Black: 35
17 | # Tesla K40c: 35
18 | # GTX Titan X: 52
19 | # Titan X (Pascal): 61
20 | # GTX 1080: 61
21 | # Titan Xp: 61
22 | # Titan V: 70
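# For example, to also target a Titan V (compute capability 7.0), append
#   -gencode arch=compute_70,code=sm_70
# to the nvcc command below (assuming your CUDA toolkit supports sm_70).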
23 |
24 | TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
25 | HAS_CUDA=$(python -c "import torch; print(torch.cuda.is_available())")
26 |
27 | if [ "$HAS_CUDA" == "True" ]; then
28 | if ! type nvcc >/dev/null 2>&1 ; then
29 | echo 'cuda available but nvcc not found. Please add nvcc to $PATH. '
30 | exit 1
31 | fi
32 | cd src
33 | HERE=$(pwd -P)
34 | cmd="nvcc -c -o nnd_cuda.cu.o nnd_cuda.cu -x cu -Xcompiler -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC -I ${HERE} -I ${TORCH}/lib/include\
35 | -gencode arch=compute_30,code=sm_30 \
36 | -gencode arch=compute_35,code=sm_35 \
37 | -gencode arch=compute_52,code=sm_52 \
38 | -gencode arch=compute_61,code=sm_61"
39 | echo "$cmd"
40 | eval "$cmd"
41 | cd ..
42 | fi
43 |
44 | if [ "$1" = "package" ]; then
45 | # for install
46 | python setup.py install
47 | elif [ "$1" = "script" ]; then
48 | # for build
49 | python build.py
50 | else
51 | echo "Shouldn't be here."
52 | fi
53 |
--------------------------------------------------------------------------------
/toolbox/nndistance/src/my_lib.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <THC/THC.h>
3 |
4 | extern THCState *state;
5 |
6 | void nnsearch(int b,int n,int m,const float * xyz1,const float * xyz2,float * dist,int * idx){
7 | for (int i=0;isize[0];
32 | int batchsize = THCudaTensor_size(state, xyz1, 0);
33 | int n = THCudaTensor_size(state, xyz1, 1);
34 | int m = THCudaTensor_size(state, xyz2, 1);
35 | // int n = xyz1->size[1];
36 | // int m = xyz2->size[1];
37 |
38 | float *xyz1_data = THFloatTensor_data(xyz1);
39 | float *xyz2_data = THFloatTensor_data(xyz2);
40 | float *dist1_data = THFloatTensor_data(dist1);
41 | float *dist2_data = THFloatTensor_data(dist2);
42 | int *idx1_data = THIntTensor_data(idx1);
43 | int *idx2_data = THIntTensor_data(idx2);
44 |
45 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
46 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
47 |
48 | return 1;
49 | }
50 |
51 |
52 | int nnd_backward(THFloatTensor *xyz1, THFloatTensor *xyz2, THFloatTensor *gradxyz1, THFloatTensor *gradxyz2, THFloatTensor *graddist1, THFloatTensor *graddist2, THIntTensor *idx1, THIntTensor *idx2) {
53 |
54 | int b = THCudaTensor_size(state, xyz1, 0);
55 | int n = THCudaTensor_size(state, xyz1, 1);
56 | int m = THCudaTensor_size(state, xyz2, 1);
57 |
58 | // int b = xyz1->size[0];
59 | // int n = xyz1->size[1];
60 | // int m = xyz2->size[1];
61 |
62 | //printf("%d %d %d\n", batchsize, n, m);
63 |
64 | float *xyz1_data = THFloatTensor_data(xyz1);
65 | float *xyz2_data = THFloatTensor_data(xyz2);
66 | float *gradxyz1_data = THFloatTensor_data(gradxyz1);
67 | float *gradxyz2_data = THFloatTensor_data(gradxyz2);
68 | float *graddist1_data = THFloatTensor_data(graddist1);
69 | float *graddist2_data = THFloatTensor_data(graddist2);
70 | int *idx1_data = THIntTensor_data(idx1);
71 | int *idx2_data = THIntTensor_data(idx2);
72 |
73 |
74 | for (int i=0;i
--------------------------------------------------------------------------------
/toolbox/nndistance/src/my_lib_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include "nnd_cuda.h"
3 |
4 |
5 |
6 | extern THCState *state;
7 |
8 |
9 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) {
10 | int success = 0;
11 | success = NmDistanceKernelLauncher(THCudaTensor_size(state, xyz1,0),
12 | THCudaTensor_size(state, xyz1,1),
13 | THCudaTensor_data(state, xyz1),
14 | THCudaTensor_size(state, xyz2,1),
15 | THCudaTensor_data(state, xyz2),
16 | THCudaTensor_data(state, dist1),
17 | THCudaIntTensor_data(state, idx1),
18 | THCudaTensor_data(state, dist2),
19 | THCudaIntTensor_data(state, idx2),
20 | THCState_getCurrentStream(state)
21 | );
22 | //int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream)
23 | if (!success) {
24 | THError("aborting");
25 | }
26 | return 1;
27 | }
28 |
29 |
30 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1,
31 | THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2) {
32 | int success = 0;
33 | success = NmDistanceGradKernelLauncher(
34 | THCudaTensor_size(state, xyz1,0),
35 | THCudaTensor_size(state, xyz1,1),
36 | THCudaTensor_data(state, xyz1),
37 | THCudaTensor_size(state, xyz2,1),
38 | THCudaTensor_data(state, xyz2),
39 | THCudaTensor_data(state, graddist1),
40 | THCudaIntTensor_data(state, idx1),
41 | THCudaTensor_data(state, graddist2),
42 | THCudaIntTensor_data(state, idx2),
43 | THCudaTensor_data(state, gradxyz1),
44 | THCudaTensor_data(state, gradxyz2),
45 | THCState_getCurrentStream(state)
46 | );
47 | //int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream)
48 |
49 | if (!success) {
50 | THError("aborting");
51 | }
52 |
53 | return 1;
54 | }
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/toolbox/nndistance/src/my_lib_cuda.h:
--------------------------------------------------------------------------------
1 | int nnd_forward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *dist1, THCudaTensor *dist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2);
2 |
3 |
4 | int nnd_backward_cuda(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *gradxyz1, THCudaTensor *gradxyz2, THCudaTensor *graddist1, THCudaTensor *graddist2, THCudaIntTensor *idx1, THCudaIntTensor *idx2);
5 |
6 |
--------------------------------------------------------------------------------
/toolbox/nndistance/src/nnd_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include "nnd_cuda.h"
3 |
4 |
5 |
6 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
7 | const int batch=512;
8 | __shared__ float buf[batch*3];
9 | for (int i=blockIdx.x;ibest){
121 | result[(i*n+j)]=best;
122 | result_i[(i*n+j)]=best_i;
123 | }
124 | }
125 | __syncthreads();
126 | }
127 | }
128 | }
129 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
130 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i);
131 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i);
132 |
133 | cudaError_t err = cudaGetLastError();
134 | if (err != cudaSuccess) {
135 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
136 | //THError("aborting");
137 | return 0;
138 | }
139 | return 1;
140 |
141 |
142 | }
143 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
144 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
167 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
168 |
169 | cudaError_t err = cudaGetLastError();
170 | if (err != cudaSuccess) {
171 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
172 | //THError("aborting");
173 | return 0;
174 | }
175 | return 1;
176 |
177 | }
178 |
179 |
--------------------------------------------------------------------------------
/toolbox/nndistance/src/nnd_cuda.h:
--------------------------------------------------------------------------------
1 | #ifdef __cplusplus
2 | extern "C" {
3 | #endif
4 |
5 | int NmDistanceKernelLauncher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
6 |
7 | int NmDistanceGradKernelLauncher(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
8 |
9 | #ifdef __cplusplus
10 | }
11 | #endif
--------------------------------------------------------------------------------
/toolbox/nndistance/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | try:
5 | from nndistance.modules.nnd import NNDModule
6 | except ImportError as err:
7 | raise ImportError('This file should be copied to its parent directory for import to work properly.')
8 |
9 | dist = NNDModule()
10 |
11 | p1 = torch.rand(1,50,3)*20
12 | p2 = torch.rand(1,50,3)*20
13 | # p1 = p1.int()
14 | # p1.random_(0,2)
15 | # p1 = p1.float()
16 | # p2 = p2.int()
17 | # p2.random_(0,2)
18 | p2 = p2.float()
19 | # print(p1)
20 | # print(p2)
21 |
22 | print('cpu')
23 | points1 = Variable(p1, requires_grad=True)
24 | points2 = Variable(p2, requires_grad=True)
25 | dist1, dist2 = dist(points1, points2)
26 | print(dist1, dist2)
27 | loss = torch.sum(dist1)
28 | print(loss)
29 | loss.backward()
30 | print(points1.grad, points2.grad)
31 |
32 | print('gpu')
33 | points1_cuda = Variable(p1.cuda(), requires_grad=True)
34 | points2_cuda = Variable(p2.cuda(), requires_grad=True)
35 | dist1_cuda, dist2_cuda = dist(points1_cuda, points2_cuda)
36 | print(dist1_cuda, dist2_cuda)
37 | loss_cuda = torch.sum(dist1_cuda)
38 | print(loss_cuda)
39 | loss_cuda.backward()
40 | print(points1_cuda.grad, points2_cuda.grad)
41 |
42 | print('stats:')
43 | print('loss:', loss.data[0], loss_cuda.data[0])
44 | print('loss diff:', loss.data[0] - loss_cuda.data[0])
45 | print('grad diff:', (points1.grad.data.cpu() - points1_cuda.grad.data.cpu()).abs().max(), (points2.grad.data.cpu() - points2_cuda.grad.data.cpu()).abs().max())
46 |
47 | from nndistance.functions.nnd import nndistance_score
48 | print('total score:', nndistance_score(points1, points2))
49 |
--------------------------------------------------------------------------------
/toolbox/spherical_proj.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .calc_prob.calc_prob.functions.calc_prob import CalcStopProb
4 |
5 |
6 | def gen_sph_grid(res=128):
7 | pi = np.pi
8 | phi = np.linspace(0, 180, res * 2 + 1)[1::2]
9 | theta = np.linspace(0, 360, res + 1)[:-1]
10 | grid = np.zeros([res, res, 3])
11 | for idp, p in enumerate(phi):
12 | for idt, t in enumerate(theta):
13 | grid[idp, idt, 2] = np.cos((p * pi / 180))
14 | proj = np.sin((p * pi / 180))
15 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180)
16 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180)
17 | grid = np.reshape(grid, (1, 1, res, res, 3))
18 | return torch.from_numpy(grid).float()
19 |
20 |
21 | def sph_pad(sph_tensor, padding_margin=16):
22 | F = torch.nn.functional
23 | pad2d = (padding_margin, padding_margin, padding_margin, padding_margin)
24 | rep_padded_sph = F.pad(sph_tensor, pad2d, mode='replicate')
25 | _, _, h, w = rep_padded_sph.shape
26 | rep_padded_sph[:, :, :, 0:padding_margin] = rep_padded_sph[:, :, :, w - 2 * padding_margin:w - padding_margin]
27 | rep_padded_sph[:, :, :, w - padding_margin:] = rep_padded_sph[:, :, :, padding_margin:2 * padding_margin]
28 | return rep_padded_sph
29 |
30 |
31 | class render_spherical(torch.nn.Module):
32 | def __init__(self, sph_res=128, z_res=256):
33 | super().__init__()
34 | self.sph_res = sph_res
35 | self.z_res = z_res
36 | self.gen_grid()
37 | self.calc_stop_prob = CalcStopProb().apply
38 |
39 | def gen_grid(self):
40 | res = self.sph_res
41 | z_res = self.z_res
42 | pi = np.pi
43 | phi = np.linspace(0, 180, res * 2 + 1)[1::2]
44 | theta = np.linspace(0, 360, res + 1)[:-1]
45 | grid = np.zeros([res, res, 3])
46 | for idp, p in enumerate(phi):
47 | for idt, t in enumerate(theta):
48 | grid[idp, idt, 2] = np.cos((p * pi / 180))
49 | proj = np.sin((p * pi / 180))
50 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180)
51 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180)
52 | grid = np.reshape(grid * 2, (res, res, 3))
53 | alpha = np.zeros([1, 1, z_res, 1])
54 | alpha[0, 0, :, 0] = np.linspace(0, 1, z_res)
55 | grid = grid[:, :, np.newaxis, :]
56 | grid = grid * (1 - alpha)
57 | grid = torch.from_numpy(grid).float()
58 | depth_weight = torch.linspace(0, 1, self.z_res)
59 | self.register_buffer('depth_weight', depth_weight)
60 | self.register_buffer('grid', grid)
61 |
62 | def forward(self, vox):
63 | grid = self.grid.expand(vox.shape[0], -1, -1, -1, -1)
64 | vox = vox.permute(0, 1, 4, 3, 2)
65 | prob_sph = torch.nn.functional.grid_sample(vox, grid)
66 | prob_sph = torch.clamp(prob_sph, 1e-5, 1 - 1e-5)
67 | sph_stop_prob = self.calc_stop_prob(prob_sph)
68 | exp_depth = torch.matmul(sph_stop_prob, self.depth_weight)
69 | back_groud_prob = torch.prod(1.0 - prob_sph, dim=4)
70 | back_groud_prob = back_groud_prob * 1.0
71 | exp_depth = exp_depth + back_groud_prob
72 | return exp_depth
73 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import time
4 | import pandas as pd
5 | import torch
6 | from options import options_train
7 | import datasets
8 | import models
9 | from loggers import loggers
10 | from util.util_print import str_error, str_stage, str_verbose, str_warning
11 | from util import util_loadlib as loadlib
12 |
13 |
14 | ###################################################
15 |
16 | print(str_stage, "Parsing arguments")
17 | opt, unique_opt_params = options_train.parse()
18 | # Get all parse done, including subparsers
19 | print(opt)
20 |
21 | ###################################################
22 |
23 | print(str_stage, "Setting device")
24 | if opt.gpu == '-1':
25 | device = torch.device('cpu')
26 | else:
27 | loadlib.set_gpu(opt.gpu)
28 | device = torch.device('cuda')
29 | if opt.manual_seed is not None:
30 | loadlib.set_manual_seed(opt.manual_seed)
31 |
32 | ###################################################
33 |
34 | print(str_stage, "Setting up logging directory")
35 | exprdir = '{}_{}_{}'.format(opt.net, opt.dataset, opt.lr)
36 | exprdir += ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
37 | logdir = os.path.join(opt.logdir, exprdir, str(opt.expr_id))
38 |
39 | if opt.resume == 0:
40 | if os.path.isdir(logdir):
41 | if opt.expr_id <= 0:
42 | print(
43 | str_warning, (
44 | "Will remove Experiment %d at\n\t%s\n"
45 | "Do you want to continue? (y/n)"
46 | ) % (opt.expr_id, logdir)
47 | )
48 | need_input = True
49 | while need_input:
50 | response = input().lower()
51 | if response in ('y', 'n'):
52 | need_input = False
53 | if response == 'n':
54 | print(str_stage, "User decides to quit")
55 | sys.exit()
56 | os.system('rm -rf ' + logdir)
57 | else:
58 | raise ValueError(str_error +
59 | " Refuse to remove positive expr_id")
60 | os.system('mkdir -p ' + logdir)
61 | else:
62 | assert os.path.isdir(logdir)
63 | opt_f_old = os.path.join(logdir, 'opt.pt')
64 | opt = options_train.overwrite(opt, opt_f_old, unique_opt_params)
65 |
66 | # Save opt
67 | torch.save(vars(opt), os.path.join(logdir, 'opt.pt'))
68 | with open(os.path.join(logdir, 'opt.txt'), 'w') as fout:
69 | for k, v in vars(opt).items():
70 | fout.write('%20s\t%-20s\n' % (k, v))
71 |
72 | opt.full_logdir = logdir
73 | print(str_verbose, "Logging directory set to: %s" % logdir)
74 |
75 | ###################################################
76 |
77 | print(str_stage, "Setting up loggers")
78 | if opt.resume != 0 and os.path.isfile(os.path.join(logdir, 'best.pt')):
79 | try:
80 | prev_best_data = torch.load(os.path.join(logdir, 'best.pt'))
81 | prev_best = prev_best_data['loss_eval']
82 | del prev_best_data
83 | except KeyError:
84 | prev_best = None
85 | else:
86 | prev_best = None
87 | best_model_logger = loggers.ModelSaveLogger(
88 | os.path.join(logdir, 'best.pt'),
89 | period=1,
90 | save_optimizer=True,
91 | save_best=True,
92 | prev_best=prev_best
93 | )
94 | logger_list = [
95 | loggers.TerminateOnNaN(),
96 | loggers.ProgbarLogger(allow_unused_fields='all'),
97 | loggers.CsvLogger(
98 | os.path.join(logdir, 'epoch_loss.csv'),
99 | allow_unused_fields='all'
100 | ),
101 | loggers.ModelSaveLogger(
102 | os.path.join(logdir, 'nets', '{epoch:04d}.pt'),
103 | period=opt.save_net,
104 | save_optimizer=opt.save_net_opt
105 | ),
106 | loggers.ModelSaveLogger(
107 | os.path.join(logdir, 'checkpoint.pt'),
108 | period=1,
109 | save_optimizer=True
110 | ),
111 | best_model_logger,
112 | ]
113 | if opt.log_batch:
114 | logger_list.append(
115 | loggers.BatchCsvLogger(
116 | os.path.join(logdir, 'batch_loss.csv'),
117 | allow_unused_fields='all'
118 | )
119 | )
120 | if opt.tensorboard:
121 | tf_logdir = os.path.join(
122 | opt.logdir, 'tensorboard', exprdir, str(opt.expr_id))
123 | if os.path.isdir(tf_logdir) and opt.resume == 0:
124 | os.system('rm -r ' + tf_logdir) # remove previous tensorboard log if overwriting
125 | if not os.path.isdir(os.path.join(logdir, 'tensorboard')):
126 | os.symlink(tf_logdir, os.path.join(logdir, 'tensorboard'))
127 | logger_list.append(
128 | loggers.TensorBoardLogger(
129 | tf_logdir,
130 | allow_unused_fields='all'
131 | )
132 | )
133 | logger = loggers.ComposeLogger(logger_list)
134 |
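# --- Editor's note (not part of the original file) ----------------------------
# Artifacts written by the loggers above, all under <opt.logdir>/<exprdir>/<expr_id>/:
#   epoch_loss.csv    per-epoch losses
#   nets/NNNN.pt      periodic snapshots (every opt.save_net epochs)
#   checkpoint.pt     rolling checkpoint, saved every epoch
#   best.pt           best model by evaluation loss
#   batch_loss.csv    per-batch losses (only when opt.log_batch is set)
#   tensorboard/      symlink to the TensorBoard run (only when opt.tensorboard is set)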
135 | ###################################################
136 |
137 | print(str_stage, "Setting up models")
138 | Model = models.get_model(opt.net)
139 | model = Model(opt, logger)
140 | model.to(device)
141 | print(model)
142 | print("# model parameters: {:,d}".format(model.num_parameters()))
143 |
144 | initial_epoch = 1
145 | if opt.resume != 0:
146 | if opt.resume == -1:
147 | net_filename = os.path.join(logdir, 'checkpoint.pt')
148 | elif opt.resume == -2:
149 | net_filename = os.path.join(logdir, 'best.pt')
150 | else:
151 | net_filename = os.path.join(
152 | logdir, 'nets', '{epoch:04d}.pt').format(epoch=opt.resume)
153 | if not os.path.isfile(net_filename):
154 | print(str_warning, ("Network file not found for opt.resume=%d. "
155 | "Starting from scratch") % opt.resume)
156 | else:
157 | additional_values = model.load_state_dict(net_filename, load_optimizer='auto')
158 | try:
159 | initial_epoch += additional_values['epoch']
160 | except KeyError as err:
161 | # Old saved model does not have epoch as additional values
162 | epoch_loss_csv = os.path.join(logdir, 'epoch_loss.csv')
163 | if opt.resume == -1:
164 | try:
165 | initial_epoch += pd.read_csv(epoch_loss_csv)['epoch'].max()
166 | except pd.errors.ParserError:
167 | with open(epoch_loss_csv, 'r') as f:
168 | lines = f.readlines()
169 | initial_epoch += max([int(l.split(',')[0]) for l in lines[1:]])
170 | else:
171 | initial_epoch += opt.resume
172 |
173 | ###################################################
174 |
175 | print(str_stage, "Setting up data loaders")
176 | start_time = time.time()
177 | dataset = datasets.get_dataset(opt.dataset)
178 | dataset_train = dataset(opt, mode='train', model=model)
179 | dataset_vali = dataset(opt, mode='vali', model=model)
180 | dataloader_train = torch.utils.data.DataLoader(
181 | dataset_train,
182 | batch_size=opt.batch_size,
183 | shuffle=True,
184 | num_workers=opt.workers,
185 | pin_memory=True,
186 | drop_last=True
187 | )
188 | dataloader_vali = torch.utils.data.DataLoader(
189 | dataset_vali,
190 | batch_size=opt.batch_size,
191 | num_workers=opt.workers,
192 | pin_memory=True,
193 | drop_last=True,
194 | shuffle=False
195 | )
196 | print(str_verbose, "Time spent in data IO initialization: %.2fs" %
197 | (time.time() - start_time))
198 | print(str_verbose, "# training points: " + str(len(dataset_train)))
199 | print(str_verbose, "# training batches per epoch: " + str(len(dataloader_train)))
200 | print(str_verbose, "# validation batches: " + str(len(dataloader_vali)))
201 |
202 | ###################################################
203 |
204 | if opt.epoch > 0:
205 | print(str_stage, "Training")
206 | model.train_epoch(
207 | dataloader_train,
208 | dataloader_eval=dataloader_vali,
209 | max_batches_per_train=opt.epoch_batches,
210 | epochs=opt.epoch,
211 | initial_epoch=initial_epoch,
212 | max_batches_per_eval=opt.eval_batches,
213 | eval_at_start=opt.eval_at_start
214 | )
215 |
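# --- Editor's note (not part of the original file) ----------------------------
# opt.resume controls where training starts: 0 trains from scratch (prompting
# before wiping an existing run with a non-positive expr_id), -1 resumes from
# checkpoint.pt, -2 from best.pt, and a positive N loads nets/<N zero-padded to
# 4 digits>.pt.  The scripts/ directory (e.g. scripts/train_marrnet1.sh,
# scripts/train_full_genre.sh) wraps this file with preset arguments; the
# available flags are defined in options/options_train.py.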
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiumingzhang/GenRe-ShapeHD/ee42add2707de509b5914ab444ae91b832f75981/util/__init__.py
--------------------------------------------------------------------------------
/util/util_cam_para.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def read_cam_para_from_xml(xml_name):
5 | # azi ele only
6 | import xml.etree.ElementTree
7 | e = xml.etree.ElementTree.parse(xml_name).getroot()
8 |
9 | assert len(e.findall('sensor')) == 1
10 | for x in e.findall('sensor'):
11 | assert len(x.findall('transform')) == 1
12 | for y in x.findall('transform'):
13 | assert len(y.findall('lookAt')) == 1
14 | for z in y.findall('lookAt'):
15 | origin = np.array(z.get('origin').split(','), dtype=np.float32)
16 | # up = np.array(z.get('up').split(','), dtype=np.float32)
17 |
18 | x, y, z = origin
19 | elevation = np.arctan2(y, np.sqrt(x ** 2 + z ** 2))
20 | azimuth = np.arctan2(x, z) + np.pi
21 | if azimuth >= np.pi:
22 | azimuth -= 2 * np.pi
23 | assert azimuth >= -np.pi and azimuth <= np.pi
24 | assert elevation >= -np.pi / 2. and elevation <= np.pi / 2.
25 | return azimuth, elevation
26 |
27 |
28 | def raw_camparam_from_xml(path, pose="lookAt"):
29 | import xml.etree.ElementTree as ET
30 | tree = ET.parse(path)
31 | elm = tree.find("./sensor/transform/" + pose)
32 | camparam = elm.attrib
33 | origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
34 | target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
35 | up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
36 | height = int(
37 | tree.find("./sensor/film/integer[@name='height']").attrib['value'])
38 | width = int(
39 | tree.find("./sensor/film/integer[@name='width']").attrib['value'])
40 |
41 | camparam = dict()
42 | camparam['origin'] = origin
43 | camparam['up'] = up
44 | camparam['target'] = target
45 | camparam['height'] = height
46 | camparam['width'] = width
47 | return camparam
48 |
49 |
50 | def get_object_rotation(xml_path, style='zup'):
51 | style_set = ['yup', 'zup', 'spherical_proj']
52 | assert(style in style_set)
53 | camparam = raw_camparam_from_xml(xml_path)
54 | if style == 'zup':
55 | Rx = camparam['target'] - camparam['origin']
56 | up = camparam['up']
57 | Rz = np.cross(Rx, up)
58 | Ry = np.cross(Rz, Rx)
59 | Rx /= np.linalg.norm(Rx)
60 | Ry /= np.linalg.norm(Ry)
61 | Rz /= np.linalg.norm(Rz)
62 | R = np.array([Rx, Ry, Rz])
63 | R_coord = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
64 | R = R_coord @ R
65 | R = R @ R_coord.transpose()
66 | elif style == 'yup':
67 | Rx = camparam['target'] - camparam['origin']
68 | up = camparam['up']
69 | Rz = np.cross(Rx, up)
70 | Ry = np.cross(Rz, Rx)
71 | Rx /= np.linalg.norm(Rx)
72 | Ry /= np.linalg.norm(Ry)
73 | Rz /= np.linalg.norm(Rz)
74 | #print(Rx, Ry, Rz)
75 | # no transpose needed!
76 | R = np.array([Rx, Ry, Rz])
77 | elif style == 'spherical_proj':
78 | Rx = camparam['target'] - camparam['origin']
79 | up = camparam['up']
80 | Rz = np.cross(Rx, up)
81 | Ry = np.cross(Rz, Rx)
82 | Rx /= np.linalg.norm(Rx)
83 | Ry /= np.linalg.norm(Ry)
84 | Rz /= np.linalg.norm(Rz)
85 | #print(Rx, Ry, Rz)
86 | # no transpose needed!
87 | R = np.array([Rx, Ry, Rz])
88 |
89 | raise NotImplementedError
90 | return R
91 |
92 |
93 | def get_object_rotation_translation(xml_path, style='zup'):
94 | pass
95 |
96 |
97 | def _devide_into_section(angle, num_section, lower_bound, upper_bound):
98 | rst = np.zeros(num_section)
99 | per_section_size = (upper_bound - lower_bound) / num_section
100 | angle -= per_section_size / 2
101 | if angle < lower_bound:
102 | angle += upper_bound - lower_bound
103 | idx = int((angle - lower_bound) / per_section_size)
104 | rst[idx] = 1
105 | return rst
106 |
107 |
108 | def _section_to_angle(idx, num_section, lower_bound, upper_bound):
109 | per_section_size = (upper_bound - lower_bound) / num_section
110 |
111 | angle = (idx + 0.5) * per_section_size + lower_bound
112 | angle += per_section_size / 2
113 | if angle > upper_bound:
114 | angle -= upper_bound - lower_bound
115 | return angle
116 |
117 |
118 | def azimuth_to_onehot(azimuth, num_azimuth):
119 | return _devide_into_section(azimuth, num_azimuth, -np.pi, np.pi)
120 |
121 |
122 | def elevation_to_onehot(elevation, num_elevation):
123 | return _devide_into_section(elevation, num_elevation, -np.pi / 2., np.pi / 2.)
124 |
125 |
126 | def onehot_to_azimuth(v, num_azimuth):
127 | idx = np.argmax(v)
128 | return _section_to_angle(idx, num_azimuth, -np.pi, np.pi)
129 |
130 |
131 | def onehot_to_elevation(v, num_elevation):
132 | idx = np.argmax(v)
133 | return _section_to_angle(idx, num_elevation, -np.pi / 2., np.pi / 2.)
134 |
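# --- Editor's note (not part of the original file) ----------------------------
# Worked example of the binning above: with num_azimuth = 24 over [-pi, pi),
# each section spans 2*pi/24 (15 degrees).  _devide_into_section shifts the
# angle by half a section before flooring into a bin, and _section_to_angle
# shifts it back, so onehot_to_azimuth(azimuth_to_onehot(a, 24), 24) recovers
# a to within one section width, which is exactly what the commented-out test
# in __main__ below checks.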
135 |
136 | if __name__ == '__main__':
137 | num_azimuth = 24
138 | num_elevation = 12
139 | for i in range(num_azimuth):
140 | rst = np.zeros(num_azimuth)
141 | rst[i] = 1
142 | print(onehot_to_azimuth(rst, num_azimuth))
143 |
144 | '''
145 | for i in range(100):
146 | angle = (np.random.rand() - 0.5) * np.pi * 2
147 | print(angle, np.argmax(azimuth_to_onehot(angle, 24)), onehot_to_azimuth(azimuth_to_onehot(angle, 24), 24))
148 | assert np.abs(angle - onehot_to_azimuth(azimuth_to_onehot(angle, 24), 24)) < 2 * np.pi / 24
149 | '''
150 |
--------------------------------------------------------------------------------
/util/util_camera.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.misc import imresize  # NOTE: removed in SciPy >= 1.3; this module requires an older SciPy
3 | from numba import jit
4 |
5 |
6 | @jit
7 | def calc_ptnum(triangle, density):
8 | pt_num_tr = np.zeros(len(triangle)).astype(int)
9 | pt_num_total = 0
10 | for tr_id, tr in enumerate(triangle):
11 | a = np.linalg.norm(np.cross(tr[1] - tr[0], tr[2] - tr[0])) / 2
12 | ptnum = max(int(a * density), 1)
13 | pt_num_tr[tr_id] = ptnum
14 | pt_num_total += ptnum
15 | return pt_num_tr, pt_num_total
16 |
17 |
18 | class Camera():
19 | # camera coordinates: y up, z forward, x right.
20 | # consistent with blender definitions.
21 | # res = [w,h]
22 | def __init__(self):
23 | self.position = np.array([1.6, 0, 0])
24 | self.rx = np.array([0, 1, 0])
25 | self.ry = np.array([0, 0, 1])
26 | self.rz = np.array([1, 0, 0])
27 | self.res = [800, 600]
28 | self.focal_length = 0.05
29 |         # set the diagonal to the 35mm film diagonal
30 | self.set_diagal((0.036**2 + 0.024**2)**0.5)
31 |
32 | def rotate(self, rot_mat):
33 | self.rx = rot_mat[:, 0]
34 | self.ry = rot_mat[:, 1]
35 | self.rz = rot_mat[:, 2]
36 |
37 | def move_cam(self, new_pos):
38 | self.position = new_pos
39 |
40 | def set_pose(self, inward, up):
41 | self.rx = np.cross(up, inward)
42 | self.ry = np.array(up)
43 | self.rz = np.array(inward)
44 | self.rx /= np.linalg.norm(self.rx)
45 | self.ry /= np.linalg.norm(self.ry)
46 | self.rz /= np.linalg.norm(self.rz)
47 |
48 | def set_diagal(self, diag):
49 | h_relative = self.res[1] / self.res[0]
50 | self.sensor_width = np.sqrt(diag**2 / (1 + h_relative**2))
51 |
52 | def lookat(self, orig, target, up):
53 | self.position = np.array(orig)
54 | target = np.array(target)
55 | inward = self.position - target
56 | right = np.cross(up, inward)
57 | up = np.cross(inward, right)
58 | self.set_pose(inward, up)
59 |
60 |     def set_cam_from_mitsuba(self, path):
61 |         # NOTE: the original called the undefined `util.cam_from_mitsuba`; raw_camparam_from_xml (util/util_cam_para.py) returns the same dict
62 |         from util.util_cam_para import raw_camparam_from_xml
63 |         camparam = raw_camparam_from_xml(path)
64 |         self.lookat(orig=camparam['origin'], up=camparam['up'], target=camparam['target'])
65 |         self.res = [camparam['width'], camparam['height']]
66 |         self.focal_length = 0.05
67 |         self.set_diagal((0.036**2 + 0.024**2)**0.5)  # set the diagonal to the 35mm film diagonal
68 |
69 | def project_point(self, pt):
70 | # project global point to image coordinates in pixels (float not
71 | # integer).
72 | res = self.res
73 | rel = np.array(pt) - self.position
74 | depth = -np.dot(rel, self.rz)
75 | if rel.ndim != 1:
76 | depth = depth.reshape([np.size(depth, axis=0), 1])
77 | rel_plane = rel * self.focal_length / depth
78 | rel_width = np.dot(rel_plane, self.rx)
79 | rel_height = np.dot(rel_plane, self.ry)
80 | topleft = np.array([-self.sensor_width / 2,
81 | self.sensor_width * (res[1] / res[0]) / 2])
82 | pix_size = self.sensor_width / res[0]
83 | topleft += np.array([pix_size / 2, -pix_size / 2])
84 | im_pix_x = (topleft[1] - rel_height) / pix_size
85 | im_pix_y = (rel_width - topleft[0]) / pix_size
86 | return im_pix_x, im_pix_y
87 |
88 | def project_depth(self, pt, depth_type='ray'):
89 | if depth_type == 'ray':
90 | if np.array(pt).ndim == 1:
91 | return np.linalg.norm(pt - self.position)
92 | return np.linalg.norm(pt - self.position, axis=1)
93 | else:
94 | return np.dot(pt - self.position, -self.rz)
95 |
96 | def pack(self):
97 | params = []
98 | params += self.res
99 | params += [self.sensor_width]
100 | params += self.position.tolist()
101 | params += self.rx.tolist()
102 | params += self.ry.tolist()
103 | params += self.rz.tolist()
104 | params += [self.focal_length]
105 | return params
106 |
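# --- Editor's sketch (not part of the original file) --------------------------
# Basic use of the Camera class: place the camera with lookat() and project a
# world-space point to sub-pixel image coordinates.
#
#     cam = Camera()
#     cam.lookat(orig=[2, 0, 0], target=[0, 0, 0], up=[0, 1, 0])
#     px, py = cam.project_point([0, 0.1, 0.1])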
107 |
108 | class tsdf_renderer:
109 | def __init__(self):
110 | self.camera = Camera()
111 | self.depth = []
112 |
113 | def load_depth_map_npy(self, path):
114 | self.depth = np.load(path)
115 |
116 | def back_project_ptcloud(self, upsample=1.0, depth_type='ray'):
117 | if not self.check_valid():
118 | return
119 | mask = np.where(self.depth < 0, 0, 1)
120 | depth = imresize(self.depth, upsample, mode='F', interp='bilinear')
121 | up_mask = imresize(mask, upsample, mode='F', interp='bilinear')
122 | up_mask = np.where(up_mask < 1, 0, 1)
123 | ind = np.where(up_mask == 0)
124 | depth[ind] = -1
125 | # res = self.camera.res
126 | res = np.array([0, 0])
127 | res[0] = np.shape(depth)[1] # width
128 | res[1] = np.shape(depth)[0] # height
129 | self.check_depth = np.zeros([res[1], res[0]], dtype=np.float32) - 1
130 | pt_pos = np.where(up_mask == 1)
131 | ptnum = len(pt_pos[0])
132 | ptcld = np.zeros([ptnum, 3])
133 | half_width = self.camera.sensor_width / 2
134 | half_height = half_width * res[1] / res[0]
135 | pix_size = self.camera.sensor_width / res[0]
136 | top_left = self.camera.position \
137 | - self.camera.focal_length * self.camera.rz\
138 | - half_width * self.camera.rx\
139 | + half_height * self.camera.ry
140 |
141 | for x in range(ptnum):
142 | height_id = pt_pos[0][x]
143 | width_id = pt_pos[1][x]
144 | pix_depth = depth[height_id, width_id]
145 | pix_coord = - (height_id + 0.5) * pix_size * self.camera.ry\
146 | + (width_id + 0.5) * pix_size * self.camera.rx\
147 | + top_left
148 | pix_rel = pix_coord - self.camera.position
149 | if depth_type == 'plane':
150 | ptcld_pos = (pix_rel)\
151 | * (pix_depth / self.camera.focal_length) \
152 | + self.camera.position
153 | back_project_depth = -np.dot(pix_rel, self.camera.rz)
154 | else:
155 | ptcld_pos = (pix_rel / np.linalg.norm(pix_rel))\
156 | * (pix_depth) + self.camera.position
157 | back_project_depth = np.linalg.norm(
158 | ptcld_pos - self.camera.position)
159 | ptcld[x, :] = ptcld_pos
160 | self.check_depth[height_id, width_id] = back_project_depth
161 | self.ptcld = ptcld
162 | self.pt_pos = pt_pos
163 |
164 | def check_valid(self, warning=True):
165 |         if np.size(self.depth) == 0:  # `self.depth == []` misbehaves once a numpy array is loaded
166 | print('No depth map available!')
167 | return False
168 | shape = np.shape(self.depth)
169 | if warning and (shape[0] != self.camera.res[1] or shape[1] != self.camera.res[0]):
170 | print('depth map and camera resolution mismatch!')
171 | print('camera: {}'.format(self.camera.res))
172 | print('depth: {}'.format(shape))
173 | return True
174 | return True
175 |
--------------------------------------------------------------------------------
/util/util_loadlib.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from .util_print import str_warning, str_verbose
3 |
4 |
5 | def set_gpu(gpu, check=True):
6 | import os
7 | _check_gpu(gpu)
8 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | cudnn.benchmark = True
12 | if check:
13 | if not _check_gpu_setting_in_use(gpu):
14 | print('[Warning] gpu setting overwritten. torch.cuda may be initialized before running this function.')
15 |
16 |
17 | def _check_gpu_setting_in_use(gpu):
18 | '''
19 | check that CUDA_VISIBLE_DEVICES is actually working
20 | by starting a clean thread with the same CUDA_VISIBLE_DEVICES
21 | '''
22 | import subprocess
23 | output = subprocess.check_output('CUDA_VISIBLE_DEVICES=%s python -c "import torch; print(torch.cuda.device_count())"' % gpu, shell=True)
24 | output = output.decode().strip()
25 | import torch
26 | return torch.cuda.device_count() == int(output)
27 |
28 |
29 | def _check_gpu(gpu):
30 | msg = subprocess.check_output('nvidia-smi --query-gpu=index,utilization.gpu,memory.used --format=csv,nounits,noheader -i %s' % (gpu,), shell=True)
31 | msg = msg.decode('utf-8')
32 | all_ok = True
33 | for line in msg.split('\n'):
34 | if line == '':
35 | break
36 | stats = [x.strip() for x in line.split(',')]
37 | gpu = stats[0]
38 | util = int(stats[1])
39 | mem_used = int(stats[2])
40 | if util > 10 or mem_used > 1000: # util in percentage and mem_used in MiB
41 | print(str_warning, 'Designated GPU in use: id=%s, util=%d%%, memory in use: %d MiB' % (gpu, util, mem_used))
42 | all_ok = False
43 | if all_ok:
44 | print(str_verbose, 'All designated GPU(s) free to use. ')
45 |
46 |
47 | def set_manual_seed(seed):
48 | import random
49 | random.seed(seed)
50 | try:
51 | import numpy as np
52 | np.random.seed(seed)
53 | except ImportError as err:
54 | print('Numpy not found. Random seed for numpy not set. ')
55 | try:
56 | import torch
57 | torch.manual_seed(seed)
58 | torch.cuda.manual_seed_all(seed)
59 | except ImportError as err:
60 | print('Pytorch not found. Random seed for pytorch not set. ')
61 |
--------------------------------------------------------------------------------
/util/util_print.py:
--------------------------------------------------------------------------------
1 | class bcolors:
2 | HEADER = '\033[95m'
3 | OKBLUE = '\033[94m'
4 | OKGREEN = '\033[92m'
5 | WARNING = '\033[93m'
6 | FAIL = '\033[91m'
7 | ENDC = '\033[0m'
8 | BOLD = '\033[1m'
9 | UNDERLINE = '\033[4m'
10 |
11 |
12 | str_stage = bcolors.OKBLUE + '==>' + bcolors.ENDC
13 | str_verbose = bcolors.OKGREEN + '[Verbose]' + bcolors.ENDC
14 | str_warning = bcolors.WARNING + '[Warning]' + bcolors.ENDC
15 | str_error = bcolors.FAIL + '[Error]' + bcolors.ENDC
16 |
--------------------------------------------------------------------------------
/util/util_reproj.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 |
6 |
7 | def cross_prod(u, v):
8 | # Cross pruduct between a set of vectors and a vector
9 | if len(u.size()) == 2:
10 | i = u[:, 1] * v[2] - u[:, 2] * v[1]
11 | j = u[:, 2] * v[0] - u[:, 0] * v[2]
12 | k = u[:, 0] * v[1] - u[:, 1] * v[0]
13 | return torch.stack((i, j, k), 1)
14 | elif len(u.size()) == 3:
15 | i = u[:, :, 1] * v[2] - u[:, :, 2] * v[1]
16 | j = u[:, :, 2] * v[0] - u[:, :, 0] * v[2]
17 | k = u[:, :, 0] * v[1] - u[:, :, 1] * v[0]
18 | return torch.stack((i, j, k), 2)
19 | raise Exception()
20 |
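# --- Editor's note (not part of the original file) ----------------------------
# cross_prod takes a batch (or grid) of vectors u and a single 3-vector v and
# returns the row-wise cross product u x v; it agrees with numpy, e.g. (sketch):
#
#     u = torch.rand(5, 3)
#     v = torch.rand(3)
#     np.allclose(cross_prod(u, v).numpy(), np.cross(u.numpy(), v.numpy()))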
21 |
22 | def criterion_single(v, x, x_0, n_0, l, alpha=np.sqrt(2) / 2, beta=1, gamma=1.):
23 | v = v.view(-1)
24 | x = x.view(-1, 3)
25 | n_0 /= torch.sum(n_0 ** 2)
26 |
27 | # Find the voxel which is nearest to x_0
28 | _, index = torch.min(torch.sum((x - x_0) ** 2, dim=1), dim=0)
29 | i_0 = index.data.cpu().numpy()[0]
30 |
31 | # loss for (i_0, j_0, k_0)
32 | loss_1 = (1 - v[i_0]) ** 2
33 |
34 | # loss for others
35 | d = torch.sum(cross_prod((x - x_0), n_0) ** 2, dim=1) ** 0.5
36 | mask_1 = (d < alpha * l).float()
37 | mask_2 = torch.ones(*v.size())
38 | mask_2[i_0] = 0
39 | mask_2 = Variable(mask_2.cuda())
40 | loss_2 = torch.sum((gamma * (1 - d / (alpha * l)) ** beta * v ** 2) * mask_1 * mask_2)
41 |
42 | return loss_1 + loss_2
43 |
44 |
45 | def criterion(v, x, x_0, n_0, l, alpha=np.sqrt(2) / 2, beta=1, gamma=1.):
46 | n_sample = x_0.size(0)
47 | v = v.view(-1)
48 | x = x.view(-1, 3)
49 | n_0 /= torch.sum(n_0 ** 2)
50 |
51 | # Find the voxel which is nearest to x_0
52 | x_repeat = x.view(x.size(0), 1, x.size(1)).repeat(1, n_sample, 1)
53 | x_sub = x_repeat - x_0
54 | _, index = torch.min(torch.sum(x_sub ** 2, dim=2), dim=0)
55 | i_0 = index.data.cpu().numpy()
56 |
57 | # loss for (i_0, j_0, k_0)
58 | loss_1 = Variable(torch.zeros(1).cuda())
59 | for i in range(n_sample):
60 | loss_1 += (1 - v[i_0[i]]) ** 2
61 |
62 | # loss for others
63 | d = torch.sum(cross_prod(x_sub, n_0) ** 2, dim=2) ** 0.5
64 | mask_1 = (d < alpha * l).float()
65 | mask_2 = torch.ones(v.size(0), n_sample)
66 | for i in range(n_sample):
67 | mask_2[i_0[i]][i] = 0
68 | mask_2 = Variable(mask_2.cuda())
69 | v_repeat = v.view(v.size(0), 1).repeat(1, n_sample)
70 | loss_2 = torch.sum((gamma * (1 - d / (alpha * l)) ** beta * v_repeat ** 2) * mask_1 * mask_2)
71 | return loss_2
72 |
73 |
74 | if __name__ == '__main__':
75 | torch.manual_seed(70)
76 | n_sample = 90
77 | N = 128
78 | l = 1.
79 | v = Variable(torch.rand(N, N, N).cuda(), requires_grad=True)
80 | x = Variable(torch.rand(N, N, N, 3).cuda())
81 | x_0 = Variable(torch.rand(n_sample, 3).cuda())
82 | n_0 = Variable(torch.rand(3).cuda())
83 |
84 | start = time.time()
85 |
86 | loss = criterion(v, x, x_0, n_0, l)
87 |
88 | '''
89 | loss = Variable(torch.zeros(1).cuda())
90 | for i in range(n_sample):
91 | loss += criterion_single(v, x, x_0[i], n_0, l)
92 | '''
93 |
94 | loss.backward()
95 | print(v.grad[0, 0, 0])
96 |
97 | end = time.time()
98 | print(end - start)
99 |
100 |
101 | u = Variable(torch.rand(N, 3).cuda())
102 | v = Variable(torch.rand(3).cuda())
103 | # print(cross_prod(u, v))
104 |
105 | # print(np.cross(u.data.cpu().numpy()[0], v.data.cpu().numpy()))
106 |
--------------------------------------------------------------------------------
/util/util_sph.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | from util.util_img import depth_to_mesh_df, resize
3 | from skimage import measure
4 | import numpy as np
5 |
6 |
7 | def render_model(mesh, sgrid):
8 | index_tri, index_ray, loc = mesh.ray.intersects_id(
9 | ray_origins=sgrid, ray_directions=-sgrid, multiple_hits=False, return_locations=True)
10 | loc = loc.reshape((-1, 3))
11 |
12 | grid_hits = sgrid[index_ray]
13 | dist = np.linalg.norm(grid_hits - loc, axis=-1)
14 | dist_im = np.ones(sgrid.shape[0])
15 | dist_im[index_ray] = dist
16 | im = dist_im
17 | return im
18 |
19 |
20 | def make_sgrid(b, alpha, beta, gamma):
21 | res = b * 2
22 | pi = np.pi
23 | phi = np.linspace(0, 180, res * 2 + 1)[1::2]
24 | theta = np.linspace(0, 360, res + 1)[:-1]
25 | grid = np.zeros([res, res, 3])
26 | for idp, p in enumerate(phi):
27 | for idt, t in enumerate(theta):
28 | grid[idp, idt, 2] = np.cos((p * pi / 180))
29 | proj = np.sin((p * pi / 180))
30 | grid[idp, idt, 0] = proj * np.cos(t * pi / 180)
31 | grid[idp, idt, 1] = proj * np.sin(t * pi / 180)
32 | grid = np.reshape(grid, (res * res, 3))
33 | return grid
34 |
35 |
36 | def render_spherical(data, mask, obj_path=None, debug=False):
37 | depth_im = data['depth'][0, 0, :, :]
38 | th = data['depth_minmax']
39 | depth_im = resize(depth_im, 480, 'vertical')
40 | im = resize(mask, 480, 'vertical')
41 | gt_sil = np.where(im > 0.95, 1, 0)
42 | depth_im = depth_im * gt_sil
43 | depth_im = depth_im[:, :, np.newaxis]
44 | b = 64
45 | tdf = depth_to_mesh_df(depth_im, th, False, 1.0, 2.2)
46 | try:
47 | verts, faces, normals, values = measure.marching_cubes_lewiner(
48 | tdf, 0.999 / 128, spacing=(1 / 128, 1 / 128, 1 / 128))
49 | mesh = trimesh.Trimesh(vertices=verts - 0.5, faces=faces)
50 | sgrid = make_sgrid(b, 0, 0, 0)
51 | im_depth = render_model(mesh, sgrid)
52 | im_depth = im_depth.reshape(2 * b, 2 * b)
53 | im_depth = np.where(im_depth > 1, 1, im_depth)
54 |     except Exception:  # e.g. marching cubes fails on a degenerate distance field
55 | im_depth = np.ones([128, 128])
56 | return im_depth
57 | return im_depth
58 |
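# --- Editor's note (not part of the original file) ----------------------------
# Pipeline of render_spherical() above: the masked, resized depth map is lifted
# to a truncated distance field (depth_to_mesh_df), meshed with marching cubes,
# and the resulting mesh is ray-cast from a 128x128 spherical grid (make_sgrid)
# to obtain a spherical depth image; if any step fails (e.g. an empty mesh),
# an all-ones 128x128 map is returned instead.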
--------------------------------------------------------------------------------
/util/util_xml_to_cam_params.py:
--------------------------------------------------------------------------------
1 |
2 | from glob import glob
3 | import re
4 | import argparse
5 | import numpy as np
6 | from pathlib import Path
7 | import os
8 |
9 | def raw_camparam_from_xml(path, pose="lookAt"):
10 | import xml.etree.ElementTree as ET
11 | tree = ET.parse(path)
12 | elm = tree.find("./sensor/transform/" + pose)
13 | camparam = elm.attrib
14 | origin = np.fromstring(camparam['origin'], dtype=np.float32, sep=',')
15 | target = np.fromstring(camparam['target'], dtype=np.float32, sep=',')
16 | up = np.fromstring(camparam['up'], dtype=np.float32, sep=',')
17 | height = int(
18 | tree.find("./sensor/film/integer[@name='height']").attrib['value'])
19 | width = int(
20 | tree.find("./sensor/film/integer[@name='width']").attrib['value'])
21 |
22 | camparam = dict()
23 | camparam['origin'] = origin
24 | camparam['up'] = up
25 | camparam['target'] = target
26 | camparam['height'] = height
27 | camparam['width'] = width
28 | return camparam
29 |
30 | def get_cam_pos(origin, target, up):
31 | inward = origin - target
32 | right = np.cross(up, inward)
33 | up = np.cross(inward, right)
34 | rx = np.cross(up, inward)
35 | ry = np.array(up)
36 | rz = np.array(inward)
37 | rx /= np.linalg.norm(rx)
38 | ry /= np.linalg.norm(ry)
39 | rz /= np.linalg.norm(rz)
40 |
41 | rot = np.stack([
42 | rx,
43 | ry,
44 | -rz
45 | ], axis=0)
46 |
47 |
48 | aff = np.concatenate([
49 | np.eye(3), -origin[:,None]
50 | ], axis=1)
51 |
52 |
53 | ext = np.matmul(rot, aff)
54 |
55 | result = np.concatenate(
56 | [ext, np.array([[0,0,0,1]])], axis=0
57 | )
58 |
59 |
60 |
61 | return result
62 |
63 |
64 |
65 | def convert_cam_params_all_views(datapoint_dir, dataroot, camera_param_dir):
66 | depths = sorted(glob(os.path.join(datapoint_dir, '*depth.png')))
67 | cam_ext = ['_'.join(re.sub(dataroot.strip('/'), camera_param_dir.strip('/'), f).split('_')[:-1])+'.xml' for f in depths]
68 |
69 |
70 | for i, (f, pth) in enumerate(zip(cam_ext, depths)):
71 | if not os.path.exists(f):
72 | continue
73 | params=raw_camparam_from_xml(f)
74 | origin, target, up, width, height = params['origin'], params['target'], params['up'],\
75 | params['width'], params['height']
76 |
77 | ext_matrix = get_cam_pos(origin, target, up)
78 |
79 | #####
80 | diag = (0.036 ** 2 + 0.024 ** 2) ** 0.5
81 | focal_length = 0.05
82 | res = [480, 480]
83 | h_relative = (res[1] / res[0])
84 | sensor_width = np.sqrt(diag ** 2 / (1 + h_relative ** 2))
85 | pix_size = sensor_width / res[0]
86 |
87 | K = np.array([
88 | [focal_length / pix_size, 0, (sensor_width / pix_size - 1) / 2],
89 | [0, -focal_length / pix_size, (sensor_width * (res[1] / res[0]) / pix_size - 1) / 2],
90 | [0, 0, 1]
91 | ])
92 |
93 | np.savez(pth.split('depth.png')[0]+ 'cam_params.npz', extr=ext_matrix, intr=K)
94 |
95 |
96 | def main(opt):
97 | dataroot_dir = Path(opt.dataroot)
98 |
99 | leaf_subdirs = []
100 |
101 | for dirpath, dirnames, filenames in os.walk(dataroot_dir):
102 | if (not dirnames) and opt.mitsuba_xml_root not in dirpath:
103 | leaf_subdirs.append(dirpath)
104 |
105 |
106 |
107 | for k, dir_ in enumerate(leaf_subdirs):
108 | print('Processing dir {}/{}: {}'.format(k, len(leaf_subdirs), dir_))
109 |
110 | convert_cam_params_all_views(dir_, opt.dataroot, opt.mitsuba_xml_root)
111 |
112 |
113 |
114 |
115 | if __name__ == '__main__':
116 | args = argparse.ArgumentParser()
117 |     args.add_argument('--dataroot', type=str, help='GenRe data root. Absolute path is recommended.')
118 | # e.g. '/root/.../data/shapenet/'
119 |     args.add_argument('--mitsuba_xml_root', type=str, help='XML directory root. Absolute path is recommended.')
120 | # e.g. '/root/.../data/genre-xml_v2/'
121 | opt = args.parse_args()
122 |
123 | main(opt)
124 |
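# --- Editor's note (not part of the original file) ----------------------------
# Example invocation (paths are placeholders; see the comments above for the
# expected layout of the GenRe data and Mitsuba XML roots):
#
#     python util/util_xml_to_cam_params.py \
#         --dataroot /path/to/data/shapenet \
#         --mitsuba_xml_root /path/to/data/genre-xml_v2
#
# For every *_depth.png with a matching XML file, this writes a sibling
# *_cam_params.npz containing the 4x4 extrinsics ('extr') and the 3x3
# intrinsics ('intr') computed above.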
--------------------------------------------------------------------------------
/visualize/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "voxel": {
3 | "isosurf_thres": 0.3
4 | }
5 | }
6 |
--------------------------------------------------------------------------------
/visualize/visualizer.py:
--------------------------------------------------------------------------------
1 | from os.path import join, dirname
2 | from os import makedirs
3 | from shutil import copyfile
4 | from multiprocessing import Pool
5 | import atexit
6 | import json
7 | import numpy as np
8 | from skimage import measure
9 | from util.util_img import imwrite_wrapper
10 |
11 |
12 | class Visualizer():
13 | """
14 |     Unified Visualization Worker
15 | """
16 | paths = [
17 | 'rgb_path',
18 | 'silhou_path',
19 | 'depth_path',
20 | 'normal_path',
21 | ]
22 | imgs = [
23 | 'rgb',
24 | 'pred_depth',
25 | 'pred_silhou',
26 | 'pred_normal',
27 | ]
28 | voxels = [
29 | 'pred_voxel_noft',
30 | 'pred_voxel',
31 | 'gen_voxel',
32 | ] # will go through sigmoid
33 | txts = [
34 | 'gt_depth_minmax',
35 | 'pred_depth_minmax',
36 | 'disc',
37 | 'scores'
38 | ]
39 | sphmaps = [
40 | 'pred_spherical_full',
41 | 'pred_spherical_partial',
42 | 'gt_spherical_full',
43 | ]
44 | voxels_gt = [
45 | 'pred_proj_depth',
46 | 'gt_voxel',
47 | 'pred_proj_sph_full',
48 | ]
49 |
50 | def __init__(self, n_workers=4, param_f=None):
51 | if n_workers == 0:
52 | pool = None
53 | elif n_workers > 0:
54 | pool = Pool(n_workers)
55 | else:
56 | raise ValueError(n_workers)
57 | self.pool = pool
58 | if param_f:
59 | self.param_f = param_f
60 | else:
61 | self.param_f = join(dirname(__file__), 'config.json')
62 |
63 | def cleanup():
64 | if pool:
65 | pool.close()
66 | pool.join()
67 | atexit.register(cleanup)
68 |
69 | def visualize(self, pack, batch_idx, outdir):
70 | if self.pool:
71 | self.pool.apply_async(
72 | self._visualize,
73 | [pack, batch_idx, self.param_f, outdir],
74 | error_callback=self._error_callback
75 | )
76 | else:
77 | self._visualize(pack, batch_idx, self.param_f, outdir)
78 |
79 | @classmethod
80 | def _visualize(cls, pack, batch_idx, param_f, outdir):
81 | makedirs(outdir, exist_ok=True)
82 |
83 |         # NOTE: dynamic parameter reading from config.json is currently disabled;
84 |         # the isosurface threshold below overrides param_dict['voxel']['isosurf_thres'].
85 |         voxel_isosurf_th = 0.25  # param_dict = cls._read_params(param_f)
86 |
87 | batch_size = cls._get_batch_size(pack)
88 | instance_cnt = batch_idx * batch_size
89 | counter = 0
90 | for k in cls.paths:
91 | prefix = '{:04d}_%02d_' % counter + k.split('_')[0] + '.png'
92 | cls._cp_img(pack.get(k), join(outdir, prefix), instance_cnt)
93 | counter += 1
94 | for k in cls.imgs:
95 | prefix = '{:04d}_%02d_' % counter + k + '.png'
96 | cls._vis_img(pack.get(k), join(outdir, prefix), instance_cnt)
97 | counter += 1
98 | for k in cls.voxels_gt:
99 | prefix = '{:04d}_%02d_' % counter + k + '.obj'
100 | cls._vis_voxel(pack.get(k), join(outdir, prefix), instance_cnt,
101 | voxel_isosurf_th, False)
102 | counter += 1
103 | for k in cls.voxels:
104 | prefix = '{:04d}_%02d_' % counter + k + '.obj'
105 | cls._vis_voxel(pack.get(k), join(outdir, prefix), instance_cnt,
106 | voxel_isosurf_th)
107 | counter += 1
108 | for k in cls.txts:
109 | prefix = '{:04d}_%02d_' % counter + k + '.txt'
110 | cls._vis_txt(pack.get(k), join(outdir, prefix), instance_cnt)
111 | counter += 1
112 | for k in cls.sphmaps:
113 | prefix = '{:04d}_%02d_' % counter + k + '.png'
114 | cls._vis_sph(pack.get(k), join(outdir, prefix), instance_cnt)
115 | counter += 1
116 |
117 | @staticmethod
118 | def _read_params(param_f):
119 | with open(param_f, 'r') as h:
120 | param_dict = json.load(h)
121 | return param_dict
122 |
123 | @staticmethod
124 | def _get_batch_size(pack):
125 | batch_size = None
126 | for v in pack.values():
127 | if hasattr(v, 'shape'):
128 | if batch_size is None or batch_size == 0:
129 | batch_size = v.shape[0]
130 | else:
131 | assert batch_size == v.shape[0]
132 | return batch_size
133 |
134 | @staticmethod
135 | def _sigmoid(x):
136 | return 1 / (1 + np.exp(-x))
137 |
138 | @staticmethod
139 | def _to_obj_str(verts, faces):
140 | text = ""
141 | for p in verts:
142 | text += "v "
143 | for x in p:
144 | text += "{} ".format(x)
145 | text += "\n"
146 | for f in faces:
147 | text += "f "
148 | for x in f:
149 | text += "{} ".format(x + 1)
150 | text += "\n"
151 | return text
152 |
153 | @classmethod
154 | def _save_iso_obj(cls, df, path, th, shift=True):
155 | if th < np.min(df):
156 | df[0, 0, 0] = th - 1
157 | if th > np.max(df):
158 | df[-1, -1, -1] = th + 1
159 | spacing = (1 / 128, 1 / 128, 1 / 128)
160 | verts, faces, _, _ = measure.marching_cubes_lewiner(
161 | df, th, spacing=spacing)
162 | if shift:
163 | verts -= np.array([0.5, 0.5, 0.5])
164 | obj_str = cls._to_obj_str(verts, faces)
165 | with open(path, 'w') as f:
166 | f.write(obj_str)
167 |
168 | @staticmethod
169 | def _vis_img(img, output_pattern, counter=0):
170 | if img is not None and not isinstance(img, str):
171 | assert img.shape[0] != 0
172 | img = np.clip(img * 255, 0, 255).astype(int)
173 | img = np.transpose(img, (0, 2, 3, 1))
174 | bsize = img.shape[0]
175 | for batch_id in range(bsize):
176 | im = img[batch_id, :, :, :]
177 | imwrite_wrapper(output_pattern.format(counter + batch_id), im)
178 |
179 | @staticmethod
180 | def _vis_sph(img, output_pattern, counter=0):
181 | if img is not None and not isinstance(img, str):
182 | assert img.shape[0] != 0
183 | img = np.transpose(img, (0, 2, 3, 1))
184 | bsize = img.shape[0]
185 | for batch_id in range(bsize):
186 | im = img[batch_id, :, :, 0]
187 | im = im / im.max()
188 | im = np.clip(im * 255, 0, 255).astype(int)
189 | imwrite_wrapper(output_pattern.format(counter + batch_id), im)
190 |
191 | @staticmethod
192 | def _cp_img(paths, output_pattern, counter=0):
193 | if paths is not None:
194 | for batch_id, path in enumerate(paths):
195 | copyfile(path, output_pattern.format(counter + batch_id))
196 |
197 | @classmethod
198 | def _vis_voxel(cls, voxels, output_pattern, counter=0, th=0.5, use_sigmoid=True):
199 | if voxels is not None:
200 | assert voxels.shape[0] != 0
201 | for batch_id, voxel in enumerate(voxels):
202 | if voxel.ndim == 4:
203 | voxel = voxel[0, ...]
204 | voxel = cls._sigmoid(voxel) if use_sigmoid else voxel
205 | cls._save_iso_obj(voxel, output_pattern.format(counter + batch_id), th=th)
206 |
207 | @staticmethod
208 | def _vis_txt(txts, output_pattern, counter=0):
209 | if txts is not None:
210 | for batch_id, txt in enumerate(txts):
211 | with open(output_pattern.format(counter + batch_id), 'w') as h:
212 | h.write("%s\n" % txt)
213 |
214 | @staticmethod
215 | def _error_callback(e):
216 | print(str(e))
217 |
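# --- Editor's sketch (not part of the original file) --------------------------
# Expected driver-side usage: `pack` is a dict keyed by the names listed in
# paths / imgs / voxels / txts / sphmaps / voxels_gt above; missing keys are
# simply skipped.
#
#     vis = Visualizer(n_workers=0)   # 0 disables the worker Pool (synchronous)
#     vis.visualize(pack, batch_idx=0, outdir='./vis_out')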
--------------------------------------------------------------------------------
|