├── .gitignore ├── README.md ├── env.yml ├── env_sub.yml ├── lib ├── arguments.py ├── data │ ├── base.py │ ├── cub.py │ ├── cub_pseudo_dataset.py │ └── pseudo_dataset.py ├── external │ └── ChamferDistancePytorch │ │ ├── .DS_Store │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── chamfer2D │ │ ├── chamfer2D.cu │ │ ├── chamfer_cuda.cpp │ │ ├── dist_chamfer_2D.py │ │ └── setup.py │ │ ├── chamfer3D │ │ ├── chamfer3D.cu │ │ ├── chamfer_cuda.cpp │ │ ├── dist_chamfer_3D.py │ │ └── setup.py │ │ ├── chamfer5D │ │ ├── chamfer5D.cu │ │ ├── chamfer_cuda.cpp │ │ ├── dist_chamfer_5D.py │ │ └── setup.py │ │ ├── chamfer_python.py │ │ ├── fscore.py │ │ └── unit_test.py ├── mesh_inversion.py ├── mesh_templates │ ├── uvsphere_16rings.obj │ ├── uvsphere_17rings.obj │ ├── uvsphere_31rings.obj │ ├── uvsphere_32rings.obj │ ├── wireframe_16rings.png │ ├── wireframe_17rings.png │ ├── wireframe_31rings.png │ └── wireframe_32rings.png ├── models │ ├── __init__.py │ ├── cmr_mesh_net.py │ ├── cmr_net_blocks.py │ ├── cyclegan_base_model.py │ ├── cyclegan_networks.py │ ├── gan.py │ └── reconstruction.py ├── rendering │ ├── cmr_geom_utils.py │ ├── cmr_mesh.py │ ├── cmr_meshzoo.py │ ├── cmr_nmr_kaolin.py │ ├── fragment_shader.py │ ├── mesh_template.py │ ├── monkey_patches.py │ ├── renderer.py │ └── utils.py ├── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py └── utils │ ├── __init__.py │ ├── cam_pose.py │ ├── common_utils.py │ ├── fid.py │ ├── image.py │ ├── image_pool.py │ ├── inception.py │ ├── inversion_dist.py │ ├── losses.py │ ├── mask_proj.py │ ├── nn_modules.py │ ├── text_functions.py │ ├── tf_visualizer.py │ ├── transformations.py │ └── vgg_feat.py ├── run_evaluation.py ├── run_inversion.py └── run_pretraining.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints_gan/ 2 | checkpoints_recon/ 3 | tensorboard_pretrain/ 4 | tensorboard_inversion/ 5 | datasets/ 6 | outputs/ 7 | lib/utils/vgg19-dcbb9e9d.pth 8 | __pycache__/ 9 | *.zip 10 | *.py[cod] 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Monocular 3D Object Reconstruction with GAN Inversion (ECCV 2022) 2 | 3 | This paper presents a novel GAN Inversion framework for single view 3D object reconstruction. 4 | 5 | * Project page: [link](https://www.mmlab-ntu.com/project/meshinversion/) 6 | * Paper: [link](https://arxiv.org/abs/2207.10061) 7 | * Youtube: [link](https://www.youtube.com/watch?v=13QfxbZqmvM) 8 | 9 | 10 | ## Setup 11 | Install environment: 12 | ``` 13 | conda env create -f env.yml 14 | 15 | # if you couldn't solve the environment: 16 | conda env create -f env_sub.yml 17 | 18 | conda activate mesh_inv 19 | ``` 20 | Install [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) (tested on commit [e7e5131](https://github.com/NVIDIAGameWorks/kaolin/tree/e7e513173bd4159ae45be6b3e156a3ad156a3eb9)). 21 | 22 | Download the [pretrained model](https://drive.google.com/file/d/1TeE_c0V3lWd5y5Ine4Gmesc2O4cfIH9S/view?usp=sharing) and place it under `checkpoints_gan/pretrained`. 
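Before downloading the data, it may help to confirm that the environment imports cleanly. The snippet below is a minimal sanity check, not part of the repository; it only assumes the packages installed above (PyTorch with CUDA, Kaolin, PyTorch3D) are available in the `mesh_inv` environment.

```
# Minimal environment sanity check (not part of the repository).
import torch, kaolin, pytorch3d

print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('kaolin', kaolin.__version__, '| pytorch3d', pytorch3d.__version__)
```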
Download the CUB dataset [CUB_200_2011](http://www.vision.caltech.edu/datasets/cub_200_2011/), [cache](https://drive.google.com/file/d/11PPf-obl-eakPElU6ghcgkje8S8hwFrT/view?usp=sharing), [predicted_mask](https://drive.google.com/file/d/1L-pbvxb6jL7fUEyFPPRgXHNHsK2U01qo/view?usp=sharing), and [PseudoGT](https://drive.google.com/file/d/1wCfVDRx_8DJzfP7aYBX0AQXs4LYxX4rI/view?usp=sharing) for ConvMesh GAN training, and place them under `datasets/cub/`. Alternatively, you can obtain your own predicted masks with PointRend, and you can obtain your own PseudoGT by following [ConvMesh](https://github.com/dariopavllo/convmesh). 23 | 24 | ``` 25 | - datasets 26 | - cub 27 | - CUB_200_2011 28 | - cache 29 | - predicted_mask 30 | - pseudogt_512x512 31 | ``` 32 | 33 | ## Reconstruction 34 | The reconstruction results for the test split are obtained through GAN inversion. 35 | ``` 36 | python run_inversion.py --name author_released --checkpoint_dir pretrained 37 | ``` 38 | 39 | ## Evaluation 40 | Evaluation results can be obtained after GAN inversion. 41 | ``` 42 | python run_evaluation.py --name author_released --eval_option IoU 43 | python run_evaluation.py --name author_released --eval_option FID_1 44 | python run_evaluation.py --name author_released --eval_option FID_12 45 | python run_evaluation.py --name author_released --eval_option FID_10 46 | ``` 47 | 48 | ## Pretraining 49 | You can also pretrain your own GAN from scratch. 50 | ``` 51 | python run_pretraining.py --name self_train --gpu_ids 0,1,2,3 --epochs 600 52 | ``` 53 | 54 | ## Acknowledgement 55 | The code is built in part on [ConvMesh](https://github.com/dariopavllo/convmesh), [ShapeInversion](https://github.com/junzhezhang/shape-inversion) and [CMR](https://github.com/chenyuntc/cmr). In addition, the Chamfer Distance implementation is borrowed from [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch) and is included in the `lib/external` folder for convenience. 
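For reference, the bundled Chamfer distance can be called directly. The sketch below is illustrative only: it assumes the repository root is on `PYTHONPATH` and that the CUDA extension JIT-compiles on first use; the exact import path may differ in your setup.

```
# Illustrative sketch (assumed import path); dist1/dist2 are *squared* distances.
import torch
from lib.external.ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import chamfer_3DDist

cham = chamfer_3DDist()                # GPU tensors only
p1 = torch.rand(8, 1000, 3).cuda()     # (batch, num_points, 3)
p2 = torch.rand(8, 2000, 3).cuda()
dist1, dist2, idx1, idx2 = cham(p1, p2)
chamfer = dist1.mean() + dist2.mean()
```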
56 | 57 | ## Citation 58 | ``` 59 | @inproceedings{zhang2022monocular, 60 | title = {Monocular 3D Object Reconstruction with GAN Inversion}, 61 | author = {Zhang, Junzhe and Ren, Daxuan and Cai, Zhongang and Yeo, Chai Kiat and Dai, Bo and Loy, Chen Change}, 62 | booktitle = {ECCV}, 63 | year = {2022}} 64 | ``` 65 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: mesh_inv 2 | channels: 3 | - pytorch3d 4 | - pytorch 5 | - bottler 6 | - iopath 7 | - fvcore 8 | - conda-forge 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - _openmp_mutex=4.5=1_gnu 13 | - absl-py=1.0.0=pyhd8ed1ab_0 14 | - blas=1.0=mkl 15 | - brotlipy=0.7.0=py37h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.18.1=h7f8727e_0 18 | - ca-certificates=2022.5.18.1=ha878542_0 19 | - certifi=2022.5.18.1=py37h89c1867_0 20 | - cffi=1.15.0=py37hd667e15_1 21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 22 | - colorama=0.4.4=pyh9f0ad1d_0 23 | - cryptography=36.0.0=py37h9ce1e76_0 24 | - cudatoolkit=11.3.1=h2bc3f7f_2 25 | - ffmpeg=4.3.2=hca11adc_0 26 | - freetype=2.11.0=h70c0345_0 27 | - future=0.18.2=py37h89c1867_5 28 | - fvcore=0.1.5.post20210915=py37 29 | - giflib=5.2.1=h7b6447c_0 30 | - gmp=6.2.1=h2531618_2 31 | - gnutls=3.6.15=he1e5248_0 32 | - idna=3.3=pyhd3eb1b0_0 33 | - importlib-metadata=4.11.3=py37h89c1867_1 34 | - intel-openmp=2021.4.0=h06a4308_3561 35 | - iopath=0.1.9=py37 36 | - jpeg=9d=h7f8727e_0 37 | - lame=3.100=h7b6447c_0 38 | - lcms2=2.12=h3be6417_0 39 | - ld_impl_linux-64=2.35.1=h7274673_9 40 | - libffi=3.3=he6710b0_2 41 | - libgcc-ng=9.3.0=h5101ec6_17 42 | - libgomp=9.3.0=h5101ec6_17 43 | - libiconv=1.15=h63c8f33_5 44 | - libidn2=2.3.2=h7f8727e_0 45 | - libpng=1.6.37=hbc83047_0 46 | - libprotobuf=3.15.8=h780b84a_0 47 | - libstdcxx-ng=9.3.0=hd4cf53a_17 48 | - libtasn1=4.16.0=h27cfd23_0 49 | - libtiff=4.2.0=h85742a9_0 50 | - libunistring=0.9.10=h27cfd23_0 51 | - libuv=1.40.0=h7b6447c_0 52 | - libwebp=1.2.2=h55f646e_0 53 | - libwebp-base=1.2.2=h7f8727e_0 54 | - lz4-c=1.9.3=h295c915_1 55 | - mkl=2021.4.0=h06a4308_640 56 | - mkl-service=2.4.0=py37h7f8727e_0 57 | - mkl_fft=1.3.1=py37hd3c417c_0 58 | - mkl_random=1.2.2=py37h51133e4_0 59 | - ncurses=6.3=h7f8727e_2 60 | - nettle=3.7.3=hbbd107a_1 61 | - numpy-base=1.21.5=py37hf524024_1 62 | - nvidiacub=1.10.0=0 63 | - olefile=0.46=pyh9f0ad1d_1 64 | - openh264=2.1.1=h4ff587b_0 65 | - openssl=1.1.1o=h7f8727e_0 66 | - pip=21.2.2=py37h06a4308_0 67 | - portalocker=2.4.0=py37h89c1867_0 68 | - pycparser=2.21=pyhd3eb1b0_0 69 | - pyopenssl=22.0.0=pyhd3eb1b0_0 70 | - pysocks=1.7.1=py37_1 71 | - python=3.7.13=h12debd9_0 72 | - python_abi=3.7=2_cp37m 73 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 74 | - pytorch-mutex=1.0=cuda 75 | - pytorch3d=0.6.2=py37_cu113_pyt1110 76 | - readline=8.1.2=h7f8727e_1 77 | - requests=2.27.1=pyhd3eb1b0_0 78 | - setuptools=61.2.0=py37h06a4308_0 79 | - six=1.16.0=pyhd3eb1b0_1 80 | - sqlite=3.38.2=hc218d9a_0 81 | - tabulate=0.8.9=pyhd8ed1ab_0 82 | - tk=8.6.11=h1ccaba5_0 83 | - torchaudio=0.11.0=py37_cu113 84 | - torchvision=0.12.0=py37_cu113 85 | - typing_extensions=4.1.1=pyh06a4308_0 86 | - urllib3=1.26.8=pyhd3eb1b0_0 87 | - wheel=0.37.1=pyhd3eb1b0_0 88 | - x264=1!161.3030=h7f98852_1 89 | - xz=5.2.5=h7b6447c_0 90 | - yacs=0.1.8=pyhd8ed1ab_0 91 | - yaml=0.2.5=h516909a_0 92 | - zipp=3.8.0=pyhd8ed1ab_0 93 | - zlib=1.2.11=h7f8727e_4 94 | - zstd=1.4.9=haebb681_0 95 | - pip: 96 | - addict==2.4.0 97 | - 
aiohttp==3.8.1 98 | - aiosignal==1.2.0 99 | - antlr4-python3-runtime==4.8 100 | - argon2-cffi==21.3.0 101 | - argon2-cffi-bindings==21.2.0 102 | - asttokens==2.0.5 103 | - async-timeout==4.0.2 104 | - asynctest==0.13.0 105 | - attrs==21.4.0 106 | - backcall==0.2.0 107 | - beautifulsoup4==4.11.1 108 | - bleach==5.0.0 109 | - cachetools==4.2.4 110 | - ccimport==0.3.7 111 | - chumpy==0.70 112 | - click==8.1.3 113 | - cumm-cu113==0.2.8 114 | - cycler==0.11.0 115 | - cython==0.29.28 116 | - dataclasses==0.8 117 | - debugpy==1.6.0 118 | - decorator==5.1.1 119 | - defusedxml==0.7.1 120 | - docker-pycreds==0.4.0 121 | - entrypoints==0.4 122 | - executing==0.8.3 123 | - fastjsonschema==2.15.3 124 | - fire==0.4.0 125 | - fonttools==4.32.0 126 | - freetype-py==2.2.0 127 | - frozenlist==1.3.0 128 | - fsspec==2022.3.0 129 | - gitdb==4.0.9 130 | - gitpython==3.1.27 131 | - google-auth==1.35.0 132 | - google-auth-oauthlib==0.4.6 133 | - gputil==1.4.0 134 | - grpcio==1.44.0 135 | - human-det==0.0.2 136 | - hydra-core==1.0.6 137 | - icecream==2.1.2 138 | - imageio==2.3.0 139 | - imageio-ffmpeg==0.4.7 140 | - importlib-resources==5.7.0 141 | - ipykernel==6.13.0 142 | - ipython==7.32.0 143 | - ipython-genutils==0.2.0 144 | - ipywidgets==7.7.0 145 | - jedi==0.18.1 146 | - jinja2==3.1.1 147 | - joblib==1.1.0 148 | - jsonschema==4.4.0 149 | - jupyter-client==7.2.2 150 | - jupyter-core==4.9.2 151 | - jupyterlab-pygments==0.2.1 152 | - jupyterlab-widgets==1.1.0 153 | - kiwisolver==1.4.2 154 | - lark==1.1.2 155 | - libcst==0.4.7 156 | - markdown==3.3.6 157 | - markupsafe==2.1.1 158 | - matplotlib==3.5.1 159 | - matplotlib-inline==0.1.3 160 | - mistune==0.8.4 161 | - multidict==6.0.2 162 | - mypy-extensions==0.4.3 163 | - nbclient==0.6.0 164 | - nbconvert==6.5.0 165 | - nbformat==5.3.0 166 | - nest-asyncio==1.5.5 167 | - networkx==2.6.3 168 | - ninja==1.10.2.3 169 | - notebook==6.4.10 170 | - numpy==1.21.6 171 | - nuscenes-devkit==1.1.9 172 | - oauthlib==3.2.0 173 | - omegaconf==2.0.6 174 | - open3d==0.9.0.0 175 | - opencv-python==3.4.17.63 176 | - opencv-python-headless==4.6.0.66 177 | - packaging==21.3 178 | - pandas==1.3.5 179 | - pandocfilters==1.5.0 180 | - parso==0.8.3 181 | - pathspec==0.9.0 182 | - pathtools==0.1.2 183 | - pccm==0.3.4 184 | - pexpect==4.8.0 185 | - pickleshare==0.7.5 186 | - pillow==6.2.2 187 | - plotly==5.7.0 188 | - plyfile==0.6 189 | - pptk==0.1.0 190 | - prometheus-client==0.14.1 191 | - promise==2.3 192 | - prompt-toolkit==3.0.29 193 | - protobuf==3.20.0 194 | - psutil==5.9.0 195 | - ptyprocess==0.7.0 196 | - pyasn1==0.4.8 197 | - pyasn1-modules==0.2.8 198 | - pybind11==2.9.2 199 | - pycln==2.0.1 200 | - pydeprecate==0.3.2 201 | - pyglet==1.5.16 202 | - pygments==2.11.2 203 | - pyhocon==0.3.59 204 | - pymcubes==0.1.0 205 | - pyopengl==3.1.0 206 | - pyparsing==2.4.7 207 | - pyquaternion==0.9.9 208 | - pyrender==0.1.45 209 | - pyrsistent==0.18.1 210 | - python-dateutil==2.8.2 211 | - pytorch-fid==0.2.0 212 | - pytorch-lightning==1.6.1 213 | - pytz==2022.1 214 | - pywavelets==1.3.0 215 | - pyyaml==5.4.1 216 | - pyzmq==22.3.0 217 | - requests-oauthlib==1.3.1 218 | - rsa==4.8 219 | - rtree==1.0.0 220 | - scikit-image==0.17.2 221 | - scikit-learn==1.0.2 222 | - scikit-video==1.1.11 223 | - scipy==1.5.2 224 | - send2trash==1.8.0 225 | - sentry-sdk==1.5.12 226 | - setproctitle==1.2.3 227 | - shortuuid==1.0.9 228 | - smmap==5.0.0 229 | - smplx==0.1.28 230 | - soupsieve==2.3.2 231 | - spconv-cu113==2.1.21 232 | - tenacity==8.0.1 233 | - tensorboard==2.4.1 234 | - tensorboard-plugin-wit==1.8.0 
235 | - tensorboardx==1.2 236 | - termcolor==1.1.0 237 | - terminado==0.13.3 238 | - threadpoolctl==3.1.0 239 | - tifffile==2021.11.2 240 | - tinycss2==1.1.1 241 | - toml==0.10.2 242 | - torch==1.7.0 243 | - torch-ema==0.2 244 | - torch-fidelity==0.2.0 245 | - torchmetrics==0.8.2 246 | - tornado==6.1 247 | - tqdm==4.32.1 248 | - traitlets==5.1.1 249 | - trimesh==3.6.10 250 | - typer==0.4.2 251 | - typing-inspect==0.7.1 252 | - unknown==0.0.0 253 | - usd-core==20.11 254 | - wandb==0.12.18 255 | - wcwidth==0.2.5 256 | - webencodings==0.5.1 257 | - werkzeug==2.1.1 258 | - widgetsnbextension==3.6.0 259 | - yarl==1.7.2 260 | prefix: /home/zhangjunzhe/anaconda3/envs/mesh_inv 261 | -------------------------------------------------------------------------------- /env_sub.yml: -------------------------------------------------------------------------------- 1 | name: mesh_inv 2 | channels: 3 | - pytorch3d 4 | - pytorch 5 | - bottler 6 | - iopath 7 | - fvcore 8 | - conda-forge 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - _openmp_mutex=4.5=1_gnu 13 | - absl-py=1.0.0=pyhd8ed1ab_0 14 | - blas=1.0=mkl 15 | - brotlipy=0.7.0=py37h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.18.1=h7f8727e_0 18 | - ca-certificates=2022.5.18.1=ha878542_0 19 | - certifi=2022.5.18.1=py37h89c1867_0 20 | - cffi=1.15.0=py37hd667e15_1 21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 22 | - colorama=0.4.4=pyh9f0ad1d_0 23 | - cryptography=36.0.0=py37h9ce1e76_0 24 | - cudatoolkit=11.3.1=h2bc3f7f_2 25 | - ffmpeg=4.3.2=hca11adc_0 26 | - freetype=2.11.0=h70c0345_0 27 | - future=0.18.2=py37h89c1867_5 28 | - fvcore=0.1.5.post20210915=py37 29 | - giflib=5.2.1=h7b6447c_0 30 | - gmp=6.2.1=h2531618_2 31 | - gnutls=3.6.15=he1e5248_0 32 | - idna=3.3=pyhd3eb1b0_0 33 | - importlib-metadata=4.11.3=py37h89c1867_1 34 | - intel-openmp=2021.4.0=h06a4308_3561 35 | - iopath=0.1.9=py37 36 | - jpeg=9d=h7f8727e_0 37 | - lame=3.100=h7b6447c_0 38 | - lcms2=2.12=h3be6417_0 39 | - ld_impl_linux-64=2.35.1=h7274673_9 40 | - libffi=3.3=he6710b0_2 41 | - libgcc-ng=9.3.0=h5101ec6_17 42 | - libgomp=9.3.0=h5101ec6_17 43 | - libiconv=1.15=h63c8f33_5 44 | - libidn2=2.3.2=h7f8727e_0 45 | - libpng=1.6.37=hbc83047_0 46 | - libprotobuf=3.15.8=h780b84a_0 47 | - libstdcxx-ng=9.3.0=hd4cf53a_17 48 | - libtasn1=4.16.0=h27cfd23_0 49 | - libtiff=4.2.0=h85742a9_0 50 | - libunistring=0.9.10=h27cfd23_0 51 | - libuv=1.40.0=h7b6447c_0 52 | - libwebp=1.2.2=h55f646e_0 53 | - libwebp-base=1.2.2=h7f8727e_0 54 | - lz4-c=1.9.3=h295c915_1 55 | - mkl=2021.4.0=h06a4308_640 56 | - mkl-service=2.4.0=py37h7f8727e_0 57 | - mkl_fft=1.3.1=py37hd3c417c_0 58 | - mkl_random=1.2.2=py37h51133e4_0 59 | - ncurses=6.3=h7f8727e_2 60 | - nettle=3.7.3=hbbd107a_1 61 | - nvidiacub=1.10.0=0 62 | - olefile=0.46=pyh9f0ad1d_1 63 | - openh264=2.1.1=h4ff587b_0 64 | - openjpeg=2.4.0=hb52868f_1 65 | - openssl=1.1.1o=h7f8727e_0 66 | - portalocker=2.4.0=py37h89c1867_0 67 | - pycparser=2.21=pyhd3eb1b0_0 68 | - pyopenssl=22.0.0=pyhd3eb1b0_0 69 | - pysocks=1.7.1=py37_1 70 | - python=3.7.13=h12debd9_0 71 | - python_abi=3.7=2_cp37m 72 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 73 | - pytorch-mutex=1.0=cuda 74 | - pytorch3d=0.6.2=py37_cu113_pyt1110 75 | - pyyaml=5.4.1=py37h5e8e339_0 76 | - readline=8.1.2=h7f8727e_1 77 | - requests=2.27.1=pyhd3eb1b0_0 78 | - six=1.16.0=pyhd3eb1b0_1 79 | - sqlite=3.38.2=hc218d9a_0 80 | - tabulate=0.8.9=pyhd8ed1ab_0 81 | - tk=8.6.11=h1ccaba5_0 82 | - torchaudio=0.11.0=py37_cu113 83 | - torchvision=0.12.0=py37_cu113 84 | - typing_extensions=4.1.1=pyh06a4308_0 
85 | - urllib3=1.26.8=pyhd3eb1b0_0 86 | - x264=1!161.3030=h7f98852_1 87 | - xz=5.2.5=h7b6447c_0 88 | - yacs=0.1.8=pyhd8ed1ab_0 89 | - yaml=0.2.5=h516909a_0 90 | - zipp=3.8.0=pyhd8ed1ab_0 91 | - zlib=1.2.11=h7f8727e_4 92 | - zstd=1.4.9=haebb681_0 93 | - pip: 94 | - addict==2.4.0 95 | - aiohttp==3.8.1 96 | - aiosignal==1.2.0 97 | - antlr4-python3-runtime==4.8 98 | - argon2-cffi==21.3.0 99 | - argon2-cffi-bindings==21.2.0 100 | - asttokens==2.0.5 101 | - async-timeout==4.0.2 102 | - asynctest==0.13.0 103 | - attrs==21.4.0 104 | - backcall==0.2.0 105 | - beautifulsoup4==4.11.1 106 | - bleach==5.0.0 107 | - cachetools==4.2.4 108 | - ccimport==0.3.7 109 | - chumpy==0.70 110 | - click==8.1.3 111 | - cumm-cu113==0.2.8 112 | - cycler==0.11.0 113 | - cython==0.29.28 114 | - dataclasses==0.6 115 | - debugpy==1.6.0 116 | - decorator==5.1.1 117 | - defusedxml==0.7.1 118 | - descartes==1.1.0 119 | - docker-pycreds==0.4.0 120 | - entrypoints==0.4 121 | - executing==0.8.3 122 | - fastjsonschema==2.15.3 123 | - fire==0.4.0 124 | - fonttools==4.32.0 125 | - freetype-py==2.2.0 126 | - frozenlist==1.3.0 127 | - fsspec==2022.3.0 128 | - gitdb==4.0.9 129 | - gitpython==3.1.27 130 | - google-auth==1.35.0 131 | - google-auth-oauthlib==0.4.6 132 | - gputil==1.4.0 133 | - grpcio==1.44.0 134 | - human-det==0.0.2 135 | - hydra-core==1.0.6 136 | - icecream==2.1.2 137 | - imageio==2.3.0 138 | - imageio-ffmpeg==0.4.7 139 | - importlib-resources==5.7.0 140 | - ipykernel==6.13.0 141 | - ipython==7.32.0 142 | - ipython-genutils==0.2.0 143 | - ipywidgets==7.7.0 144 | - jedi==0.18.1 145 | - jinja2==3.1.1 146 | - joblib==1.1.0 147 | - jsonschema==4.4.0 148 | - jupyter==1.0.0 149 | - jupyter-client==7.2.2 150 | - jupyter-console==6.4.4 151 | - jupyter-core==4.9.2 152 | - jupyterlab-pygments==0.2.1 153 | - jupyterlab-widgets==1.1.0 154 | - kiwisolver==1.4.2 155 | - lark==1.1.2 156 | - libcst==0.4.7 157 | - markdown==3.3.6 158 | - markupsafe==2.1.1 159 | - matplotlib==3.5.1 160 | - matplotlib-inline==0.1.3 161 | - mistune==0.8.4 162 | - multidict==6.0.2 163 | - mypy-extensions==0.4.3 164 | - nbclient==0.6.0 165 | - nbconvert==6.5.0 166 | - nbformat==5.3.0 167 | - nest-asyncio==1.5.5 168 | - networkx==2.6.3 169 | - ninja==1.10.2.3 170 | - notebook==6.4.10 171 | - numpy==1.21.6 172 | - nuscenes-devkit==1.1.9 173 | - oauthlib==3.2.0 174 | - omegaconf==2.0.6 175 | - open3d==0.9.0.0 176 | - opencv-python==3.4.17.63 177 | - opencv-python-headless==4.6.0.66 178 | - packaging==21.3 179 | - pandas==1.3.5 180 | - pandocfilters==1.5.0 181 | - parso==0.8.3 182 | - pathspec==0.9.0 183 | - pathtools==0.1.2 184 | - pccm==0.3.4 185 | - pexpect==4.8.0 186 | - pickleshare==0.7.5 187 | - pillow==6.2.2 188 | - pip==22.3.1 189 | - plotly==5.7.0 190 | - plyfile==0.6 191 | - pptk==0.1.0 192 | - prometheus-client==0.14.1 193 | - promise==2.3 194 | - prompt-toolkit==3.0.29 195 | - protobuf==3.20.0 196 | - psutil==5.9.0 197 | - ptyprocess==0.7.0 198 | - pyasn1==0.4.8 199 | - pyasn1-modules==0.2.8 200 | - pybind11==2.9.2 201 | - pycln==2.0.1 202 | - pycocotools==2.0.6 203 | - pydeprecate==0.3.2 204 | - pyglet==1.5.16 205 | - pygments==2.11.2 206 | - pyhocon==0.3.59 207 | - pymcubes==0.1.0 208 | - pyopengl==3.1.0 209 | - pyparsing==2.4.7 210 | - pyquaternion==0.9.9 211 | - pyrender==0.1.45 212 | - pyrsistent==0.18.1 213 | - python-dateutil==2.8.2 214 | - pytorch-fid==0.2.1 215 | - pytorch-lightning==1.6.1 216 | - pytz==2022.1 217 | - pywavelets==1.3.0 218 | - pyzmq==22.3.0 219 | - qtconsole==5.4.0 220 | - qtpy==2.3.0 221 | - requests-oauthlib==1.3.1 222 | 
- rsa==4.8 223 | - rtree==1.0.0 224 | - scikit-image==0.16.2 225 | - scikit-learn==1.0.2 226 | - scikit-video==1.1.11 227 | - scipy==1.5.2 228 | - send2trash==1.8.0 229 | - sentry-sdk==1.5.12 230 | - setproctitle==1.2.3 231 | - setuptools==65.6.3 232 | - shapely==1.8.5.post1 233 | - shortuuid==1.0.9 234 | - smmap==5.0.0 235 | - smplx==0.1.28 236 | - soupsieve==2.3.2 237 | - spconv-cu113==2.1.21 238 | - tenacity==8.0.1 239 | - tensorboard==2.4.1 240 | - tensorboard-plugin-wit==1.8.0 241 | - tensorboardx==1.2 242 | - termcolor==1.1.0 243 | - terminado==0.13.3 244 | - threadpoolctl==3.1.0 245 | - tifffile==2021.11.2 246 | - tinycss2==1.1.1 247 | - toml==0.10.2 248 | - torch-ema==0.2 249 | - torch-fidelity==0.3.0 250 | - torchmetrics==0.8.2 251 | - tornado==6.1 252 | - tqdm==4.32.1 253 | - traitlets==5.1.1 254 | - trimesh==3.6.10 255 | - typer==0.4.2 256 | - typing-inspect==0.7.1 257 | - usd-core==20.11 258 | - wandb==0.12.18 259 | - wcwidth==0.2.5 260 | - webencodings==0.5.1 261 | - werkzeug==2.1.1 262 | - wheel==0.38.4 263 | - widgetsnbextension==3.6.0 264 | - yarl==1.7.2 265 | prefix: /home/watarukawakami/.conda/envs/mesh_inv 266 | -------------------------------------------------------------------------------- /lib/data/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base data loading class. 3 | 4 | Should output: 5 | - img: B X 3 X H X W 6 | - kp: B X nKp X 2 7 | - mask: B X H X W 8 | - sfm_pose: B X 7 (s, tr, q) 9 | (kp, sfm_pose) correspond to image coordinates in [-1, 1] 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import os.path as osp 17 | import numpy as np 18 | 19 | import scipy.linalg 20 | import scipy.ndimage.interpolation 21 | from skimage.io import imread 22 | from absl import flags, app 23 | 24 | import torch 25 | from torch.utils.data import Dataset 26 | from torch.utils.data import DataLoader 27 | from torch.utils.data.dataloader import default_collate 28 | 29 | from lib.utils import image as image_utils 30 | from lib.utils import transformations 31 | from lib.utils.inversion_dist import * 32 | import cv2 33 | import os 34 | 35 | # -------------- Dataset ------------- # 36 | # ------------------------------------ # 37 | class BaseDataset(Dataset): 38 | ''' 39 | img, mask, kp, pose data loader 40 | ''' 41 | 42 | def __init__(self, args, filter_key=None): 43 | self.args = args 44 | self.img_size = args.img_size 45 | self.jitter_frac = args.jitter_frac 46 | self.padding_frac = args.padding_frac 47 | self.filter_key = filter_key 48 | 49 | 50 | 51 | 52 | def forward_img(self, index): 53 | data = self.anno[index] 54 | data_sfm = self.anno_sfm[index] 55 | 56 | if self.args.use_predicted_mask: 57 | pred_mask_path = os.path.join(self.pred_mask_dir,str(index)+'.npy') 58 | pred_mask = np.load(pred_mask_path,allow_pickle=True) 59 | 60 | ### IoU 61 | input_mask = data.mask 62 | overlap_mask = input_mask * pred_mask 63 | union_mask = input_mask + pred_mask - overlap_mask 64 | iou = overlap_mask.sum()/union_mask.sum() 65 | 66 | ### replace 67 | data.mask = pred_mask 68 | 69 | # sfm_pose = (sfm_c, sfm_t, sfm_r) 70 | sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)] 71 | 72 | sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant') 73 | sfm_rot[3, 3] = 1 74 | sfm_pose[2] = transformations.quaternion_from_matrix(sfm_rot, isprecise=True) 75 | 76 | if self.args.dataset == 'p3d': 77 | # NOTE: temp fix 
car_imagenet\\n02814533_4600.JPEG 78 | sub_dirs = data.rel_path.split('\\') 79 | data.rel_path = sub_dirs[0]+'/'+sub_dirs[1] 80 | img_path = osp.join(self.img_dir, str(data.rel_path)) 81 | 82 | img = imread(img_path) / 255.0 83 | # Some are grayscale: 84 | if len(img.shape) == 2: 85 | img = np.repeat(np.expand_dims(img, 2), 3, axis=2) 86 | mask = np.expand_dims(data.mask, 2) 87 | 88 | 89 | # Adjust to 0 indexing 90 | bbox = np.array( 91 | [data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2], 92 | float) - 1 93 | 94 | parts = data.parts.T.astype(float) 95 | kp = np.copy(parts) 96 | vis = kp[:, 2] > 0 97 | kp[vis, :2] -= 1 98 | 99 | # Peturb bbox 100 | if self.args.split == 'train': 101 | bbox = image_utils.peturb_bbox( 102 | bbox, pf=self.padding_frac, jf=self.jitter_frac) 103 | else: 104 | bbox = image_utils.peturb_bbox( 105 | bbox, pf=self.padding_frac, jf=0) 106 | bbox = image_utils.square_bbox(bbox) 107 | 108 | # crop image around bbox, translate kps 109 | img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose) 110 | # set(data.mask.reshape(-1).tolist()) --> {0,1} I # NOTE: mask error is in below 111 | # scale image, and mask. And scale kps. 112 | img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose) 113 | 114 | # Mirror image on random. 115 | if self.args.split == 'train' and (not self.args.no_mirror): 116 | img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose) 117 | 118 | # Normalize kp to be [-1, 1] 119 | img_h, img_w = img.shape[:2] 120 | kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w) 121 | 122 | # Finally transpose the image to 3xHxW 123 | img = np.transpose(img, (2, 0, 1)) 124 | 125 | return img, kp_norm, mask, sfm_pose, img_path 126 | 127 | def normalize_kp(self, kp, sfm_pose, img_h, img_w): 128 | vis = kp[:, 2, None] > 0 129 | new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1, 130 | 2 * (kp[:, 1] / img_h) - 1, 131 | kp[:, 2]]).T 132 | sfm_pose[0] *= (1.0/img_w + 1.0/img_h) 133 | sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1 134 | sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1 135 | new_kp = vis * new_kp 136 | 137 | return new_kp, sfm_pose 138 | 139 | def crop_image(self, img, mask, bbox, kp, vis, sfm_pose): 140 | # crop image and mask and translate kps 141 | img = image_utils.crop(img, bbox, bgval=1) 142 | mask = image_utils.crop(mask, bbox, bgval=0) 143 | kp[vis, 0] -= bbox[0] 144 | kp[vis, 1] -= bbox[1] 145 | sfm_pose[1][0] -= bbox[0] 146 | sfm_pose[1][1] -= bbox[1] 147 | return img, mask, kp, sfm_pose 148 | 149 | def scale_image(self, img, mask, kp, vis, sfm_pose): 150 | # Scale image so largest bbox size is img_size 151 | bwidth = np.shape(img)[0] 152 | bheight = np.shape(img)[1] 153 | scale = self.img_size / float(max(bwidth, bheight)) 154 | 155 | img_scale, _ = image_utils.resize_img(img, scale) 156 | mask_scale, _ = image_utils.resize_img(mask, scale) # NOTE bug is here 157 | 158 | kp[vis, :2] *= scale 159 | sfm_pose[0] *= scale 160 | sfm_pose[1] *= scale 161 | 162 | mask_scale04 = ((mask_scale > self.args.target_mask_threshold) * 1).astype('float64') 163 | 164 | return img_scale, mask_scale04, kp, sfm_pose 165 | 166 | def mirror_image(self, img, mask, kp, sfm_pose): 167 | kp_perm = self.kp_perm 168 | if np.random.rand(1) > 0.5: 169 | # Need copy bc torch collate doesnt like neg strides 170 | img_flip = img[:, ::-1, :].copy() 171 | mask_flip = mask[:, ::-1].copy() 172 | 173 | # Flip kps. 
174 | new_x = img.shape[1] - kp[:, 0] - 1 175 | kp_flip = np.hstack((new_x[:, None], kp[:, 1:])) 176 | kp_flip = kp_flip[kp_perm, :] 177 | # Flip sfm_pose Rot. 178 | R = transformations.quaternion_matrix(sfm_pose[2]) 179 | flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1]))) 180 | sfm_pose[2] = transformations.quaternion_from_matrix(flip_R, isprecise=True) 181 | # Flip tx 182 | tx = img.shape[1] - sfm_pose[1][0] - 1 183 | sfm_pose[1][0] = tx 184 | return img_flip, mask_flip, kp_flip, sfm_pose 185 | else: 186 | return img, mask, kp, sfm_pose 187 | 188 | def __len__(self): 189 | return self.num_imgs 190 | 191 | def __getitem__(self, index): 192 | self.index = index 193 | img, kp, mask, sfm_pose, img_path = self.forward_img(index) 194 | sfm_pose[0].shape = 1 195 | 196 | mask_tensor = torch.from_numpy(mask) 197 | mask_dts = image_utils.compute_dt_barrier(mask_tensor) 198 | 199 | basename = os.path.basename(img_path) 200 | 201 | category = np.array([self.basename_to_class[basename]]) 202 | 203 | elem = { 204 | 'idx': index, 205 | 'img': img, 206 | 'kp': kp, 207 | 'mask': mask, 208 | 'sfm_pose': np.concatenate(sfm_pose), 209 | 'inds': index, 210 | 'img_path': img_path, 211 | 'mask_dt': mask_dts, 212 | 'class': category, 213 | } 214 | 215 | if self.filter_key is not None: 216 | if self.filter_key not in elem.keys(): 217 | print('Bad filter key %s' % self.filter_key) 218 | if self.filter_key == 'sfm_pose': 219 | # Return both vis and sfm_pose 220 | vis = elem['kp'][:, 2] 221 | elem = { 222 | 'vis': vis, 223 | 'sfm_pose': elem['sfm_pose'], 224 | } 225 | else: 226 | elem = elem[self.filter_key] 227 | 228 | return elem 229 | 230 | # ------------ Data Loader ----------- # 231 | # ------------------------------------ # 232 | def base_loader(d_set_func, batch_size, args, filter_key=None, shuffle=True): 233 | if args.dataset == 'cub': 234 | dset = d_set_func(args, filter_key=filter_key) 235 | elif args.dataset == 'p3d': 236 | dset = d_set_func(args) 237 | else: 238 | raise 239 | try: 240 | sampler = DistributedSampler(dset) if args.dist else None 241 | if args.dist: 242 | shuffle = False 243 | except: 244 | sampler = None 245 | 246 | return DataLoader( 247 | dset, 248 | batch_size=batch_size, 249 | shuffle=shuffle, 250 | sampler=sampler, 251 | num_workers=args.num_workers, 252 | drop_last=False) 253 | -------------------------------------------------------------------------------- /lib/data/cub.py: -------------------------------------------------------------------------------- 1 | """ 2 | CUB has 11788 images total, for 200 subcategories. 3 | 5994 train, 5794 test images. 4 | 5 | After removing images that are truncated: 6 | min kp threshold 6: 5964 train, 5771 test. 7 | min_kp threshold 7: 5937 train, 5747 test. 
8 | 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os.path as osp 15 | import numpy as np 16 | 17 | import scipy.io as sio 18 | from absl import flags, app 19 | 20 | import torch 21 | from torch.utils.data import Dataset 22 | 23 | from lib.data import base as base_data 24 | 25 | # -------------- Dataset ------------- # 26 | # ------------------------------------ # 27 | class CUBDataset(base_data.BaseDataset): 28 | ''' 29 | CUB Data loader 30 | ''' 31 | 32 | def __init__(self, args, filter_key=None): 33 | super(CUBDataset, self).__init__(args, filter_key=filter_key) 34 | self.data_dir = args.data_dir 35 | 36 | self.img_dir = osp.join(self.data_dir, 'CUB_200_2011', 'images') 37 | self.anno_path = osp.join(self.data_dir, 'cache', 'data', '%s_cub_cleaned.mat' % args.split) 38 | self.anno_sfm_path = osp.join(self.data_dir, 'cache', 'sfm', 'anno_%s.mat' % args.split) 39 | 40 | if self.args.use_predicted_mask: 41 | self.pred_mask_dir = osp.join(self.data_dir, 'predicted_mask') 42 | 43 | self.filter_key = filter_key 44 | 45 | # Load the annotation file. 46 | self.anno = sio.loadmat( 47 | self.anno_path, struct_as_record=False, squeeze_me=True)['images'] 48 | self.anno_sfm = sio.loadmat( 49 | self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno'] 50 | 51 | self.num_imgs = len(self.anno) 52 | print('%d images' % self.num_imgs) 53 | self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1; 54 | 55 | ### get basename to class dictionary 56 | # copied from CubPseudoDataset 57 | # Load CUB labels 58 | cub_path = 'datasets/cub/CUB_200_2011' 59 | import os 60 | with open(os.path.join(cub_path, 'images.txt'), 'r') as f: 61 | images = f.readlines() 62 | images = [x.split(' ') for x in images] 63 | ids = {k: v.strip() for k, v in images} 64 | 65 | with open(os.path.join(cub_path, 'image_class_labels.txt'), 'r') as f: 66 | classes = f.readlines() 67 | classes = [x.split(' ') for x in classes] 68 | classes = {k: int(v.strip())-1 for k, v in classes} 69 | 70 | self.basename_to_class = {} 71 | for k, c in classes.items(): 72 | fname = ids[k] 73 | basename = os.path.basename(fname) 74 | self.basename_to_class[basename] = c 75 | 76 | #----------- Data Loader ----------# 77 | #----------------------------------# 78 | def data_loader(args, shuffle=False): 79 | return base_data.base_loader(CUBDataset, args.batch_size, args, filter_key=None, shuffle=shuffle) 80 | 81 | 82 | def kp_data_loader(batch_size, args): 83 | return base_data.base_loader(CUBDataset, batch_size, args, filter_key='kp') 84 | 85 | 86 | def mask_data_loader(batch_size, args): 87 | return base_data.base_loader(CUBDataset, batch_size, args, filter_key='mask') 88 | 89 | 90 | def sfm_data_loader(batch_size, args): 91 | return base_data.base_loader(CUBDataset, batch_size, args, filter_key='sfm_pose') 92 | -------------------------------------------------------------------------------- /lib/data/cub_pseudo_dataset.py: -------------------------------------------------------------------------------- 1 | from lib.data.pseudo_dataset import BasePseudoDataset 2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | import pickle 7 | 8 | class CubPseudoDataset(BasePseudoDataset): 9 | def __init__(self, args, **kwargs): 10 | super().__init__(args, **kwargs) 11 | 12 | self.n_classes = args.n_classes 13 | 14 | # Load CUB labels 15 | cub_path = os.path.join(self.args.data_dir, 'CUB_200_2011') 16 | 17 | with 
open(os.path.join(cub_path, 'images.txt'), 'r') as f: 18 | images = f.readlines() 19 | images = [x.split(' ') for x in images] 20 | ids = {k: v.strip() for k, v in images} 21 | 22 | with open(os.path.join(cub_path, 'image_class_labels.txt'), 'r') as f: 23 | classes = f.readlines() 24 | classes = [x.split(' ') for x in classes] 25 | classes = {k: int(v.strip())-1 for k, v in classes} 26 | 27 | self.filename_to_class = {} 28 | for k, c in classes.items(): 29 | fname = ids[k] 30 | self.filename_to_class[fname] = c 31 | 32 | self.classes = [np.array([self.filename_to_class[x]]) for x in self.data['path']] 33 | 34 | num_images = len(self.data['path']) 35 | if args.conditional_encoding: 36 | filename = args.encoding_pickle_pathname 37 | with open(filename, 'rb') as filehandler: 38 | self.feat = pickle.load(filehandler) 39 | 40 | if args.conditional_text: 41 | from lib.utils.text_functions import TextDataProcessorCUB 42 | 43 | cub_text_path = cub_path 44 | self.text_processor = TextDataProcessorCUB(cub_text_path, 'train', 45 | captions_per_image=10, 46 | words_num=args.text_max_length) 47 | 48 | self.image_index_to_caption_index = {} 49 | for ind, el in enumerate(self.data['path']): 50 | self.image_index_to_caption_index[ind] = self.text_processor.filenames_to_index[el] 51 | 52 | # Randomly select a sentence for evaluation 53 | np.random.seed(1234) 54 | sent_ix = np.random.randint(0, 10) 55 | self.index_captions = [self.text_processor.get_caption(self.image_index_to_caption_index[idx_gt] *\ 56 | self.text_processor.embeddings_num+sent_ix, words_num=25) for idx_gt in range(num_images)] 57 | 58 | # print('Loaded CUB dataset with {} images and {} classes'.format(num_images, self.n_classes)) 59 | 60 | def name(self): 61 | return 'cub' 62 | 63 | def suggest_truncation_sigma(self): 64 | args = self.args 65 | if args.conditional_class: 66 | return 0.25 67 | elif args.conditional_text: 68 | return 0.5 69 | else: # Unconditional 70 | return 1.0 71 | 72 | def suggest_num_discriminators(self): 73 | if self.args.texture_resolution >= 512: 74 | return 3 75 | else: 76 | return 2 77 | 78 | def suggest_mesh_template(self): 79 | return 'mesh_templates/uvsphere_16rings.obj' 80 | 81 | def get_random_caption(self, idx): 82 | # Randomly select a sentence belonging to image idx 83 | sent_ix = torch.randint(0, self.text_processor.embeddings_num, size=(1,)).item() 84 | new_sent_ix = self.image_index_to_caption_index[idx] * self.text_processor.embeddings_num + sent_ix 85 | return self.text_processor.get_caption(new_sent_ix) # Tuple (padded tokens, lengths) 86 | 87 | def __getitem__(self, idx): 88 | gt_dict = super().__getitem__(idx) 89 | 90 | if self.args.conditional_encoding: 91 | gt_dict['encoding'] = self.feat[idx] 92 | 93 | if self.args.conditional_text: 94 | gt_dict['caption'] = self.get_random_caption(idx) 95 | 96 | return gt_dict -------------------------------------------------------------------------------- /lib/data/pseudo_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import glob 5 | 6 | class BasePseudoDataset(torch.utils.data.Dataset): 7 | def __init__(self, args, augment=True): 8 | self.args = args 9 | 10 | self.data_dir = args.data_dir 11 | 12 | self.data = np.load(os.path.join(self.data_dir, 'cache', 'poses_metadata.npz'), allow_pickle=True) 13 | self.data = self.data['data'].item() 14 | num_images = len(self.data['path']) 15 | self.augment = augment 16 | pseudogt_files = 
glob.glob(os.path.join(self.data_dir, 17 | f'pseudogt_{args.texture_resolution}x{args.texture_resolution}', 18 | '*.npz')) 19 | if len(pseudogt_files) == 0: 20 | print('Pseudo-ground-truth not found, only "Full FID" evaluation is available') 21 | self.has_pseudogt = False 22 | elif len(pseudogt_files) == num_images: 23 | print(f'Pseudo-ground-truth found! ({len(pseudogt_files)} images)') 24 | self.has_pseudogt = True 25 | else: 26 | raise ValueError('Found pseudo-ground-truth directory, but number of files does not match! ' 27 | f'Expected {num_images}, got {len(pseudogt_files)}. ' 28 | 'Please check your dataset setup.') 29 | 30 | if not self.has_pseudogt and not args.evaluate: 31 | raise ValueError('Training a model requires the pseudo-ground-truth to be setup beforehand.') 32 | 33 | def name(self): 34 | raise NotImplementedError() 35 | 36 | def suggest_truncation_sigma(self): 37 | raise NotImplementedError() 38 | 39 | def suggest_num_discriminators(self): 40 | raise NotImplementedError() 41 | 42 | def suggest_mesh_template(self): 43 | raise NotImplementedError() 44 | 45 | def __len__(self): 46 | return len(self.data['path']) 47 | 48 | def _load_pseudogt(self, idx): 49 | tex_res = self.args.texture_resolution 50 | data = np.load(os.path.join(self.data_dir, 51 | f'pseudogt_{tex_res}x{tex_res}', 52 | f'{idx}.npz'), allow_pickle=True) 53 | 54 | data = data['data'].item() 55 | 56 | gt_dict = { 57 | 'image': data['image'][:3].float(), 58 | 'texture': data['texture'].float(), 59 | 'texture_alpha': data['texture_alpha'].float(), 60 | 'mesh': data['mesh'] 61 | } 62 | if 'image_256' in data: 63 | gt_dict['image_256'] = data['image_256'][:3].float() 64 | return gt_dict 65 | 66 | def __getitem__(self, idx): 67 | gt_dict = self._load_pseudogt(idx) 68 | # del gt_dict['image'] # Not needed # NOTE removed by author 69 | 70 | # "Virtual" mirroring in UV space 71 | # A very simple form of data augmentation that does not require re-rendering 72 | if self.augment and not self.args.evaluate: 73 | if torch.randint(0, 2, size=(1,)).item() == 1: 74 | for k, v in gt_dict.items(): 75 | gt_dict[k] = BasePseudoDataset.mirror_tex(v) 76 | 77 | if self.args.conditional_class: 78 | gt_dict['class'] = self.classes[idx] 79 | 80 | gt_dict['idx'] = idx 81 | return gt_dict 82 | 83 | @staticmethod 84 | def mirror_tex(tr): 85 | # "Virtually" flip a texture or displacement map of shape (nc, H, W) 86 | # This is achieved by mirroring the image and shifting the u coordinate, 87 | # which is consistent with reprojecting the mirrored 2D image. 
88 | tr = torch.flip(tr, dims=(2,)) 89 | tr = torch.cat((tr, tr), dim=2) 90 | tr = tr[:, :, tr.shape[2]//4:-tr.shape[2]//4] 91 | return tr 92 | 93 | 94 | 95 | 96 | class PseudoDatasetForEvaluation(torch.utils.data.Dataset): 97 | 98 | def __init__(self, dataset): 99 | self.dataset = dataset 100 | 101 | def __len__(self): 102 | return len(self.dataset) 103 | 104 | def __getitem__(self, idx): 105 | gt_dict = { 106 | 'scale': self.dataset.data['scale'][idx], 107 | 'translation': self.dataset.data['translation'][idx], 108 | 'rotation': self.dataset.data['rotation'][idx], 109 | 'idx': idx, 110 | } 111 | 112 | if self.dataset.args.conditional_class: 113 | gt_dict['class'] = self.dataset.classes[idx] 114 | 115 | if self.dataset.args.conditional_text: 116 | gt_dict['caption'] = self.dataset.index_captions[idx] # Tuple (padded tokens, lengths) 117 | 118 | if self.dataset.args.conditional_encoding: 119 | gt_dict['encoding'] = self.dataset.feat[idx] 120 | 121 | if self.dataset.has_pseudogt: 122 | # Add pseudo-ground-truth entries 123 | gt_dict.update(self.dataset._load_pseudogt(idx)) 124 | 125 | 126 | return gt_dict 127 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/external/ChamferDistancePytorch/.DS_Store -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ThibaultGROUEIX 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Chamfer Distance. 2 | 3 | Include a **CUDA** version, and a **PYTHON** version with pytorch standard operations. 
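As a rough illustration of what the Python version computes, here is a naive sketch using standard PyTorch ops; it is not the exact API of `chamfer_python.py` in this folder.

```python
# Naive reference implementation -- illustration only, not chamfer_python.py's API.
import torch

def naive_chamfer(p1, p2):
    """p1: (B, N, D), p2: (B, M, D). Returns squared nearest-neighbour distances."""
    d = torch.cdist(p1, p2) ** 2       # (B, N, M) squared pairwise distances
    dist1 = d.min(dim=2).values        # nearest point in p2 for each point in p1
    dist2 = d.min(dim=1).values        # nearest point in p1 for each point in p2
    return dist1, dist2
```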
4 | NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt thresholds accordingly. 5 | 6 | - [x] F - Score 7 | 8 | 9 | 10 | ### CUDA VERSION 11 | 12 | - [x] JIT compilation 13 | - [x] Supports multi-gpu 14 | - [x] 2D point clouds. 15 | - [x] 3D point clouds. 16 | - [x] 5D point clouds. 17 | - [x] Contiguous() safe. 18 | 19 | 20 | 21 | ### Python Version 22 | 23 | - [x] Supports any dimension 24 | 25 | 26 | 27 | ### Usage 28 | 29 | ```python 30 | import torch, chamfer3D.dist_chamfer_3D, fscore 31 | chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist() 32 | points1 = torch.rand(32, 1000, 3).cuda() 33 | points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda() 34 | dist1, dist2, idx1, idx2 = chamLoss(points1, points2) 35 | f_score, precision, recall = fscore.fscore(dist1, dist2) 36 | ``` 37 | 38 | 39 | 40 | ### Add it to your project as a submodule 41 | 42 | ```shell 43 | git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch 44 | ``` 45 | 46 | 47 | 48 | ### Benchmark: [forward + backward] pass 49 | - [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4 50 | - [x] p1 : 32 x 2000 x dim 51 | - [x] p2 : 32 x 1000 x dim 52 | 53 | | *Timing (sec * 1000)* | 2D | 3D | 5D | 54 | | ---------- | -------- | ------- | ------- | 55 | | **Cuda Compiled** | **1.2** | 1.4 |1.8 | 56 | | **Cuda JIT** | 1.3 | **1.4** |**1.5** | 57 | | **Python** | 37 | 37 | 37 | 58 | 59 | 60 | | *Memory (MB)* | 2D | 3D | 5D | 61 | | ---------- | -------- | ------- | ------- | 62 | | **Cuda Compiled** | 529 | 529 | 549 | 63 | | **Cuda JIT** | **520** | **529** |**549** | 64 | | **Python** | 2495 | 2495 | 2495 | 65 | 66 | 67 | 68 | ### What is the chamfer distance ? 69 | 70 | [Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning 71 | 72 | 73 | 74 | ### Aknowledgment 75 | 76 | Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu). 77 | 78 | JIT cool trick from [Christian Diller](https://github.com/chrdiller) 79 | 80 | ### Troubleshoot 81 | 82 | - `Undefined symbol: Zxxxxxxxxxxxxxxxxx `: 83 | 84 | --> Fix: Make sure to `import torch` before you `import chamfer`. 
85 | --> Use pytorch.version >= 1.1.0 86 | 87 | - [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167) 88 | 89 | ```shell 90 | wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip 91 | sudo unzip ninja-linux.zip -d /usr/local/bin/ 92 | sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 93 | ``` 94 | 95 | 96 | 97 | 98 | 99 | #### TODO: 100 | 101 | * Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions 102 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer2D/chamfer2D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*2]; 15 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 127 | 128 | const auto batch_size = xyz1.size(0); 129 | const auto n = xyz1.size(1); //num_points point cloud A 130 | const auto m = xyz2.size(1); //num_points point cloud B 131 | 132 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 133 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 134 | 135 | cudaError_t err = cudaGetLastError(); 136 | if (err != cudaSuccess) { 137 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 138 | //THError("aborting"); 139 | return 0; 140 | } 141 | return 1; 142 | 143 | 144 | } 145 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 146 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 171 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 172 | 173 | cudaError_t err = cudaGetLastError(); 174 | if (err != cudaSuccess) { 175 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 176 | //THError("aborting"); 177 | return 0; 178 | } 179 | return 1; 180 | 181 | } 182 | 183 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int 
chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_2D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 2D") 10 | 11 | from torch.utils.cpp_extension import load 12 | chamfer_2D = load(name="chamfer_2D", 13 | sources=[ 14 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), 16 | ]) 17 | print("Loaded JIT 2D CUDA chamfer distance") 18 | 19 | else: 20 | import chamfer_2D 21 | print("Loaded compiled 2D CUDA chamfer distance") 22 | 23 | # Chamfer's distance module @thibaultgroueix 24 | # GPU tensors only 25 | class chamfer_2DFunction(Function): 26 | @staticmethod 27 | def forward(ctx, xyz1, xyz2): 28 | batchsize, n, _ = xyz1.size() 29 | _, m, _ = xyz2.size() 30 | device = xyz1.device 31 | 32 | dist1 = torch.zeros(batchsize, n) 33 | dist2 = torch.zeros(batchsize, m) 34 | 35 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 36 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 37 | 38 | dist1 = dist1.to(device) 39 | dist2 = dist2.to(device) 40 | idx1 = idx1.to(device) 41 | idx2 = idx2.to(device) 42 | torch.cuda.set_device(device) 43 | 44 | chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 45 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 46 | return dist1, dist2, idx1, idx2 47 | 48 | @staticmethod 49 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 50 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 51 | graddist1 = graddist1.contiguous() 52 | graddist2 = graddist2.contiguous() 53 | device = graddist1.device 54 | 55 | gradxyz1 = torch.zeros(xyz1.size()) 56 | gradxyz2 = torch.zeros(xyz2.size()) 57 | 58 | gradxyz1 = gradxyz1.to(device) 59 | gradxyz2 = gradxyz2.to(device) 60 | chamfer_2D.backward( 61 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 62 | ) 63 | return gradxyz1, gradxyz2 64 | 65 | 66 | class chamfer_2DDist(nn.Module): 67 | def __init__(self): 68 | super(chamfer_2DDist, self).__init__() 69 | 70 | def forward(self, input1, input2): 71 | input1 = input1.contiguous() 72 | input2 = input2.contiguous() 73 | return chamfer_2DFunction.apply(input1, input2) 74 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer2D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from 
torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_2D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_2D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer3D/chamfer3D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int 
chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 3D") 10 | 11 | from torch.utils.cpp_extension import load 12 | chamfer_3D = load(name="chamfer_3D", 13 | sources=[ 14 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), 16 | ]) 17 | print("Loaded JIT 3D CUDA chamfer distance") 18 | 19 | else: 20 | import chamfer_3D 21 | print("Loaded compiled 3D CUDA chamfer distance") 22 | 23 | 24 | # Chamfer's distance module @thibaultgroueix 25 | # GPU tensors only 26 | class chamfer_3DFunction(Function): 27 | @staticmethod 28 | def forward(ctx, xyz1, xyz2): 29 | batchsize, n, _ = xyz1.size() 30 | _, m, _ = xyz2.size() 31 | device = xyz1.device 32 | 33 | dist1 = torch.zeros(batchsize, n) 34 | dist2 = torch.zeros(batchsize, m) 35 | 36 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 37 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 38 | 39 | dist1 = dist1.to(device) 40 | dist2 = dist2.to(device) 41 | idx1 = idx1.to(device) 42 | idx2 = idx2.to(device) 43 | torch.cuda.set_device(device) 44 | 45 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 46 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 47 | return dist1, dist2, idx1, idx2 48 | 49 | @staticmethod 50 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 51 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 52 | graddist1 = graddist1.contiguous() 53 | graddist2 = graddist2.contiguous() 54 | device = graddist1.device 55 | 56 | gradxyz1 = torch.zeros(xyz1.size()) 57 | gradxyz2 = torch.zeros(xyz2.size()) 58 | 59 | gradxyz1 = gradxyz1.to(device) 60 | gradxyz2 = gradxyz2.to(device) 61 | chamfer_3D.backward( 62 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 63 | ) 64 | return gradxyz1, gradxyz2 65 | 66 | 67 | class chamfer_3DDist(nn.Module): 68 | def __init__(self): 69 | super(chamfer_3DDist, self).__init__() 70 | 71 | def forward(self, input1, input2): 72 | input1 = input1.contiguous() 73 | input2 = input2.contiguous() 74 | return chamfer_3DFunction.apply(input1, input2) 75 | 76 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer3D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_3D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_3D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | 
"/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer5D/chamfer5D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=2048; 14 | __shared__ float buf[batch*5]; 15 | for (int i=blockIdx.x;ibest){ 147 | result[(i*n+j)]=best; 148 | result_i[(i*n+j)]=best_i; 149 | } 150 | } 151 | __syncthreads(); 152 | } 153 | } 154 | } 155 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 156 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 157 | 158 | const auto batch_size = xyz1.size(0); 159 | const auto n = xyz1.size(1); //num_points point cloud A 160 | const auto m = xyz2.size(1); //num_points point cloud B 161 | 162 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 163 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 164 | 165 | cudaError_t err = cudaGetLastError(); 166 | if (err != cudaSuccess) { 167 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 168 | //THError("aborting"); 169 | return 0; 170 | } 171 | return 1; 172 | 173 | 174 | } 175 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 176 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 213 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 214 | 215 | cudaError_t err = cudaGetLastError(); 216 | if (err != cudaSuccess) { 217 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 218 | //THError("aborting"); 219 | return 0; 220 | } 221 | return 1; 222 | 223 | } 224 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, 
graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | 7 | chamfer_found = importlib.find_loader("chamfer_5D") is not None 8 | if not chamfer_found: 9 | ## Cool trick from https://github.com/chrdiller 10 | print("Jitting Chamfer 5D") 11 | 12 | from torch.utils.cpp_extension import load 13 | chamfer_5D = load(name="chamfer_5D", 14 | sources=[ 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 16 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]), 17 | ]) 18 | print("Loaded JIT 5D CUDA chamfer distance") 19 | 20 | else: 21 | import chamfer_5D 22 | print("Loaded compiled 5D CUDA chamfer distance") 23 | 24 | 25 | # Chamfer's distance module @thibaultgroueix 26 | # GPU tensors only 27 | class chamfer_5DFunction(Function): 28 | @staticmethod 29 | def forward(ctx, xyz1, xyz2): 30 | batchsize, n, _ = xyz1.size() 31 | _, m, _ = xyz2.size() 32 | device = xyz1.device 33 | 34 | dist1 = torch.zeros(batchsize, n) 35 | dist2 = torch.zeros(batchsize, m) 36 | 37 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 38 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 39 | 40 | dist1 = dist1.to(device) 41 | dist2 = dist2.to(device) 42 | idx1 = idx1.to(device) 43 | idx2 = idx2.to(device) 44 | torch.cuda.set_device(device) 45 | 46 | chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 47 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 48 | return dist1, dist2, idx1, idx2 49 | 50 | @staticmethod 51 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 52 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 53 | graddist1 = graddist1.contiguous() 54 | graddist2 = graddist2.contiguous() 55 | device = graddist1.device 56 | 57 | gradxyz1 = torch.zeros(xyz1.size()) 58 | gradxyz2 = torch.zeros(xyz2.size()) 59 | 60 | gradxyz1 = gradxyz1.to(device) 61 | gradxyz2 = gradxyz2.to(device) 62 | chamfer_5D.backward( 63 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 64 | ) 65 | return gradxyz1, gradxyz2 66 | 67 | 68 | class chamfer_5DDist(nn.Module): 69 | def __init__(self): 70 | super(chamfer_5DDist, self).__init__() 71 | 72 | def forward(self, input1, input2): 73 | input1 = input1.contiguous() 74 | input2 = input2.contiguous() 75 | return chamfer_5DFunction.apply(input1, input2) 76 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/chamfer5D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_5D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_5D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- 
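Each of the three `setup.py` files above builds one CUDA extension (`chamfer_2D`, `chamfer_3D`, `chamfer_5D`); when the compiled module is absent, the corresponding `dist_chamfer_*.py` wrapper JIT-compiles it via `torch.utils.cpp_extension.load`. The sketch below is illustrative usage rather than repository code: it assumes a CUDA device, that `chamfer_3D` builds successfully, and arbitrary batch/point-count values.

```
# Illustrative usage of the chamfer_3D extension (run from the ChamferDistancePytorch folder).
import torch
from chamfer3D.dist_chamfer_3D import chamfer_3DDist

cham3D = chamfer_3DDist()
points_a = torch.rand(8, 1024, 3, device='cuda')                        # B x N x 3
points_b = torch.rand(8, 2048, 3, device='cuda', requires_grad=True)    # B x M x 3

# Squared nearest-neighbour distances in both directions, plus the NN indices.
dist_a, dist_b, idx_a, idx_b = cham3D(points_a, points_b)
loss = dist_a.mean() + dist_b.mean()   # symmetric Chamfer loss
loss.backward()                        # gradients reach points_b through the custom backward
```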
/lib/external/ChamferDistancePytorch/chamfer_python.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_dist(x, y): 5 | xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) 6 | rx = xx.diag().unsqueeze(0).expand_as(xx) 7 | ry = yy.diag().unsqueeze(0).expand_as(yy) 8 | P = rx.t() + ry - 2 * zz 9 | return P 10 | 11 | 12 | def NN_loss(x, y, dim=0): 13 | dist = pairwise_dist(x, y) 14 | values, indices = dist.min(dim=dim) 15 | return values.mean() 16 | 17 | 18 | def distChamfer(a, b): 19 | """ 20 | :param a: Pointclouds Batch x nul_points x dim 21 | :param b: Pointclouds Batch x nul_points x dim 22 | :return: 23 | -closest point on b of points from a 24 | -closest point on a of points from b 25 | -idx of closest point on b of points from a 26 | -idx of closest point on a of points from b 27 | Works for pointcloud of any dimension 28 | """ 29 | x, y = a.double(), b.double() 30 | bs, num_points_x, points_dim = x.size() 31 | bs, num_points_y, points_dim = y.size() 32 | 33 | xx = torch.pow(x, 2).sum(2) 34 | yy = torch.pow(y, 2).sum(2) 35 | zz = torch.bmm(x, y.transpose(2, 1)) 36 | rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx 37 | ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy 38 | P = rx.transpose(2, 1) + ry - 2 * zz 39 | return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int() 40 | 41 | def distChamfer_raw(a, b): 42 | """ 43 | :param a: Pointclouds Batch x nul_points x dim 44 | :param b: Pointclouds Batch x nul_points x dim 45 | :return: 46 | -closest point on b of points from a 47 | -closest point on a of points from b 48 | -idx of closest point on b of points from a 49 | -idx of closest point on a of points from b 50 | Works for pointcloud of any dimension 51 | """ 52 | x, y = a.double(), b.double() 53 | bs, num_points_x, points_dim = x.size() 54 | bs, num_points_y, points_dim = y.size() 55 | 56 | xx = torch.pow(x, 2).sum(2) 57 | yy = torch.pow(y, 2).sum(2) 58 | zz = torch.bmm(x, y.transpose(2, 1)) 59 | rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx 60 | ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy 61 | P = rx.transpose(2, 1) + ry - 2 * zz 62 | return P 63 | 64 | 65 | def distChamfer_downsample(a, b, resolution=None, downsample_method='random',idx_a=None,idx_b=None): 66 | ### assume a b of batch size 1 67 | assert a.shape[0] == 1 68 | if resolution is None: 69 | return distChamfer_raw(a, b), idx_a, idx_b 70 | 71 | else: 72 | if downsample_method == 'random': 73 | original_res_a = a.shape[1] 74 | original_res_b = b.shape[1] 75 | # import pdb; pdb.set_trace() 76 | if idx_a is None: 77 | idx_a = torch.randperm(original_res_a)[:min(resolution,original_res_a)] 78 | idx_b = torch.randperm(original_res_b)[:min(resolution,original_res_b)] 79 | a_down = a[:,idx_a,:] 80 | b_down = b[:,idx_b,:] 81 | 82 | dist_mat = distChamfer_raw(a_down, b_down) 83 | 84 | return dist_mat, idx_a, idx_b 85 | 86 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/fscore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def fscore(dist1, dist2, threshold=0.001): 4 | """ 5 | Calculates the F-score between two point clouds with the corresponding threshold value. 
6 | :param dist1: Batch, N-Points 7 | :param dist2: Batch, N-Points 8 | :param th: float 9 | :return: fscore, precision, recall 10 | """ 11 | # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly. 12 | precision_1 = torch.mean((dist1 < threshold).float(), dim=1) 13 | precision_2 = torch.mean((dist2 < threshold).float(), dim=1) 14 | fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2) 15 | fscore[torch.isnan(fscore)] = 0 16 | return fscore, precision_1, precision_2 17 | 18 | -------------------------------------------------------------------------------- /lib/external/ChamferDistancePytorch/unit_test.py: -------------------------------------------------------------------------------- 1 | import torch, time 2 | import chamfer2D.dist_chamfer_2D 3 | import chamfer3D.dist_chamfer_3D 4 | import chamfer5D.dist_chamfer_5D 5 | import chamfer_python 6 | 7 | cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist() 8 | cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist() 9 | cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist() 10 | 11 | from torch.autograd import Variable 12 | from fscore import fscore 13 | 14 | def test_chamfer(distChamfer, dim): 15 | points1 = torch.rand(4, 100, dim).cuda() 16 | points2 = torch.rand(4, 200, dim, requires_grad=True).cuda() 17 | dist1, dist2, idx1, idx2= distChamfer(points1, points2) 18 | 19 | loss = torch.sum(dist1) 20 | loss.backward() 21 | 22 | mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2) 23 | d1 = (dist1 - mydist1) ** 2 24 | d2 = (dist2 - mydist2) ** 2 25 | assert ( 26 | torch.mean(d1) + torch.mean(d2) < 0.00000001 27 | ), "chamfer cuda and chamfer normal are not giving the same results" 28 | 29 | xd1 = idx1 - myidx1 30 | xd2 = idx2 - myidx2 31 | assert ( 32 | torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0 33 | ), "chamfer cuda and chamfer normal are not giving the same results" 34 | print(f"fscore :", fscore(dist1, dist2)) 35 | print("Unit test passed") 36 | 37 | 38 | def timings(distChamfer, dim): 39 | p1 = torch.rand(32, 2000, dim).cuda() 40 | p2 = torch.rand(32, 1000, dim).cuda() 41 | print("Timings : Start CUDA version") 42 | start = time.time() 43 | num_it = 100 44 | for i in range(num_it): 45 | points1 = Variable(p1, requires_grad=True) 46 | points2 = Variable(p2) 47 | mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2) 48 | loss = torch.sum(mydist1) 49 | loss.backward() 50 | print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") 51 | 52 | 53 | print("Timings : Start Pythonic version") 54 | start = time.time() 55 | for i in range(num_it): 56 | points1 = Variable(p1, requires_grad=True) 57 | points2 = Variable(p2) 58 | mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2) 59 | loss = torch.sum(mydist1) 60 | loss.backward() 61 | print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") 62 | 63 | 64 | 65 | dims = [2,3,5] 66 | for i,cham in enumerate([cham2D, cham3D, cham5D]): 67 | print(f"testing Chamfer {dims[i]}D") 68 | test_chamfer(cham, dims[i]) 69 | timings(cham, dims[i]) 70 | -------------------------------------------------------------------------------- /lib/mesh_templates/wireframe_16rings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/mesh_templates/wireframe_16rings.png 
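`chamfer_python.distChamfer` is the pure-PyTorch fallback exercised by `unit_test.py`: it returns the same four outputs as the CUDA kernels (squared nearest-neighbour distances and indices, in both directions), and `fscore` consumes the two distance maps, so its threshold applies to squared distances. A minimal, CPU-only sketch with illustrative shapes and threshold:

```
# Illustrative only: Chamfer distance and F-score without the CUDA extensions.
import torch
from chamfer_python import distChamfer
from fscore import fscore

a = torch.rand(2, 500, 3)   # B x N x 3
b = torch.rand(2, 800, 3)   # B x M x 3

dist_ab, dist_ba, idx_ab, idx_ba = distChamfer(a, b)    # squared NN distances both ways
chamfer = dist_ab.mean(dim=1) + dist_ba.mean(dim=1)     # per-sample symmetric Chamfer distance
f, precision, recall = fscore(dist_ab, dist_ba, threshold=0.001)
print(chamfer.shape, f.shape)                           # torch.Size([2]) torch.Size([2])
```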
-------------------------------------------------------------------------------- /lib/mesh_templates/wireframe_17rings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/mesh_templates/wireframe_17rings.png -------------------------------------------------------------------------------- /lib/mesh_templates/wireframe_31rings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/mesh_templates/wireframe_31rings.png -------------------------------------------------------------------------------- /lib/mesh_templates/wireframe_32rings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/mesh_templates/wireframe_32rings.png -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 3 | You need to implement the following five functions: 4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 5 | -- : unpack data from dataset and apply preprocessing. 6 | -- : produce intermediate results. 7 | -- : calculate loss, gradients, and update network weights. 8 | -- : (optionally) add model-specific options and set default options. 9 | In the function <__init__>, you need to define four lists: 10 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 11 | -- self.model_names (str list): define networks used in our training. 12 | -- self.visual_names (str list): specify the images that you want to display and save. 13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 14 | Now you can use the model class by specifying flag '--model dummy'. 15 | See our template model class 'template_model.py' for more details. 16 | """ 17 | 18 | import importlib 19 | from lib.models.cyclegan_base_model import BaseModel 20 | 21 | 22 | def find_model_using_name(model_name): 23 | """Import the module "models/[model_name]_model.py". 24 | In the file, the class called DatasetNameModel() will 25 | be instantiated. It has to be a subclass of BaseModel, 26 | and it is case-insensitive. 27 | """ 28 | if model_name == 'cycle_gan': 29 | model_filename = "models.cyclegan_networks" 30 | # model_filename = "models." + model_name + "_model" 31 | else: 32 | raise 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | This function warps the class CustomDatasetDataLoader. 57 | This is the main interface between this package and 'train.py'/'test.py' 58 | Example: 59 | >>> from lib.models import create_model 60 | >>> model = create_model(opt) 61 | """ 62 | # model = find_model_using_name(opt.model) 63 | # instance = model(opt) 64 | # NOTE: simplified 65 | # model = find_model_using_name 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance -------------------------------------------------------------------------------- /lib/models/cmr_mesh_net.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | from . import cmr_net_blocks as nb 7 | #------------- Modules ------------# 8 | #----------------------------------# 9 | class ResNetConv(nn.Module): 10 | def __init__(self, n_blocks=4): 11 | super(ResNetConv, self).__init__() 12 | self.resnet = torchvision.models.resnet18(pretrained=True) 13 | self.n_blocks = n_blocks 14 | 15 | def forward(self, x): 16 | n_blocks = self.n_blocks 17 | x = self.resnet.conv1(x) 18 | x = self.resnet.bn1(x) 19 | x = self.resnet.relu(x) 20 | x = self.resnet.maxpool(x) 21 | 22 | if n_blocks >= 1: 23 | x = self.resnet.layer1(x) 24 | if n_blocks >= 2: 25 | x = self.resnet.layer2(x) 26 | if n_blocks >= 3: 27 | x = self.resnet.layer3(x) 28 | if n_blocks >= 4: 29 | x = self.resnet.layer4(x) 30 | return x 31 | 32 | class Encoder(nn.Module): 33 | """ 34 | Current: 35 | Resnet with 4 blocks (x32 spatial dim reduction) 36 | Another conv with stride 2 (x64) 37 | This is sent to 2 fc layers with final output nz_feat. 
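    In CamPoseEstimator below, this (B, nz_feat) feature feeds the quaternion,
    scale and translation predictors.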
38 | """ 39 | 40 | def __init__(self, input_shape, n_blocks=4, nz_feat=100, batch_norm=True): 41 | super(Encoder, self).__init__() 42 | self.resnet_conv = ResNetConv(n_blocks=4) 43 | self.enc_conv1 = nb.conv2d(batch_norm, 512, 256, stride=2, kernel_size=4) 44 | nc_input = 256 * (input_shape[0] // 64) * (input_shape[1] // 64) 45 | self.enc_fc = nb.fc_stack(nc_input, nz_feat, 2) 46 | 47 | nb.net_init(self.enc_conv1) 48 | 49 | def forward(self, img): 50 | resnet_feat = self.resnet_conv.forward(img) 51 | 52 | out_enc_conv1 = self.enc_conv1(resnet_feat) 53 | out_enc_conv1 = out_enc_conv1.view(img.size(0), -1) 54 | feat = self.enc_fc.forward(out_enc_conv1) 55 | 56 | return feat 57 | 58 | class QuatPredictor(nn.Module): 59 | def __init__(self, nz_feat, nz_rot=4, classify_rot=False): 60 | super(QuatPredictor, self).__init__() 61 | self.pred_layer = nn.Linear(nz_feat, nz_rot) 62 | self.classify_rot = classify_rot 63 | 64 | def forward(self, feat): 65 | quat = self.pred_layer.forward(feat) 66 | if self.classify_rot: 67 | quat = torch.nn.functional.log_softmax(quat) 68 | else: 69 | quat = torch.nn.functional.normalize(quat) 70 | return quat 71 | 72 | 73 | class ScalePredictor(nn.Module): 74 | def __init__(self, nz): 75 | super(ScalePredictor, self).__init__() 76 | self.pred_layer = nn.Linear(nz, 1) 77 | 78 | def forward(self, feat): 79 | scale = self.pred_layer.forward(feat) + 1 #biasing the scale to 1 80 | scale = torch.nn.functional.relu(scale) + 1e-12 81 | return scale 82 | 83 | 84 | class TransPredictor(nn.Module): 85 | """ 86 | Outputs [tx, ty] or [tx, ty, tz] 87 | """ 88 | 89 | def __init__(self, nz, orth=True): 90 | super(TransPredictor, self).__init__() 91 | if orth: 92 | self.pred_layer = nn.Linear(nz, 2) 93 | else: 94 | self.pred_layer = nn.Linear(nz, 3) 95 | 96 | def forward(self, feat): 97 | trans = self.pred_layer.forward(feat) 98 | return trans 99 | 100 | 101 | class CamPoseEstimator(nn.Module): 102 | def __init__(self, nz_feat=200, input_shape=(256,256), args=None): 103 | 104 | super(CamPoseEstimator, self).__init__() 105 | # self.model = mesh_net.MeshNet( 106 | # img_size, args, nz_feat=args.nz_feat, num_kps=args.num_kps, sfm_mean_shape=sfm_mean_shape) 107 | # class MeshNet(nn.Module): 108 | # def __init__(self, input_shape, args, nz_feat=100, num_kps=15, sfm_mean_shape=None): 109 | # self.code_predictor = CodePredictor(nz_feat=nz_feat, num_verts=self.num_output,args=args) 110 | self.encoder = Encoder(input_shape=input_shape, n_blocks=4, nz_feat=nz_feat) 111 | self.quat_predictor = QuatPredictor(nz_feat) 112 | 113 | self.scale_predictor = ScalePredictor(nz_feat) 114 | self.trans_predictor = TransPredictor(nz_feat) 115 | 116 | def forward(self,x): 117 | feat = self.encoder.forward(x) 118 | scale_pred = self.scale_predictor.forward(feat) 119 | quat_pred = self.quat_predictor.forward(feat) 120 | trans_pred = self.trans_predictor.forward(feat) 121 | 122 | return scale_pred, trans_pred, trans_pred 123 | -------------------------------------------------------------------------------- /lib/models/cmr_net_blocks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CNN building blocks. 
3 | Taken from https://github.com/shubhtuls/factored3d/ 4 | ''' 5 | from __future__ import division 6 | from __future__ import print_function 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | 11 | class Flatten(nn.Module): 12 | def forward(self, x): 13 | return x.view(x.size()[0], -1) 14 | 15 | class Unsqueeze(nn.Module): 16 | def __init__(self, dim): 17 | super(Unsqueeze, self).__init__() 18 | self.dim = dim 19 | 20 | def forward(self, x): 21 | return x.unsqueeze(self.dim) 22 | 23 | ## fc layers 24 | def fc(batch_norm, nc_inp, nc_out): 25 | if batch_norm: 26 | return nn.Sequential( 27 | nn.Linear(nc_inp, nc_out, bias=True), 28 | nn.BatchNorm1d(nc_out), 29 | nn.LeakyReLU(0.2,inplace=True) 30 | ) 31 | else: 32 | return nn.Sequential( 33 | nn.Linear(nc_inp, nc_out), 34 | nn.LeakyReLU(0.1,inplace=True) 35 | ) 36 | 37 | def fc_stack(nc_inp, nc_out, nlayers, use_bn=True): 38 | modules = [] 39 | for l in range(nlayers): 40 | modules.append(fc(use_bn, nc_inp, nc_out)) 41 | nc_inp = nc_out 42 | encoder = nn.Sequential(*modules) 43 | net_init(encoder) 44 | return encoder 45 | 46 | ## 2D convolution layers 47 | def conv2d(batch_norm, in_planes, out_planes, kernel_size=3, stride=1): 48 | if batch_norm: 49 | return nn.Sequential( 50 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 51 | nn.BatchNorm2d(out_planes), 52 | nn.LeakyReLU(0.2,inplace=True) 53 | ) 54 | else: 55 | return nn.Sequential( 56 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 57 | nn.LeakyReLU(0.2,inplace=True) 58 | ) 59 | 60 | 61 | def deconv2d(in_planes, out_planes): 62 | return nn.Sequential( 63 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 64 | nn.LeakyReLU(0.2,inplace=True) 65 | ) 66 | 67 | 68 | def upconv2d(in_planes, out_planes, mode='bilinear'): 69 | if mode == 'nearest': 70 | print('Using NN upsample!!') 71 | upconv = nn.Sequential( 72 | nn.Upsample(scale_factor=2, mode=mode, align_corners=True), 73 | nn.ReflectionPad2d(1), 74 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0), 75 | nn.LeakyReLU(0.2,inplace=True) 76 | ) 77 | return upconv 78 | 79 | 80 | def decoder2d(nlayers, nz_shape, nc_input, use_bn=True, nc_final=1, nc_min=8, nc_step=1, init_fc=True, use_deconv=False, upconv_mode='bilinear'): 81 | ''' Simple 3D encoder with nlayers. 
82 | 83 | Args: 84 | nlayers: number of decoder layers 85 | nz_shape: number of bottleneck 86 | nc_input: number of channels to start upconvolution from 87 | use_bn: whether to use batch_norm 88 | nc_final: number of output channels 89 | nc_min: number of min channels 90 | nc_step: double number of channels every nc_step layers 91 | init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D 92 | ''' 93 | modules = [] 94 | if init_fc: 95 | modules.append(fc(use_bn, nz_shape, nc_input)) 96 | for d in range(3): 97 | modules.append(Unsqueeze(2)) 98 | nc_output = nc_input 99 | for nl in range(nlayers): 100 | if (nl % nc_step==0) and (nc_output//2 >= nc_min): 101 | nc_output = nc_output//2 102 | if use_deconv: 103 | print('Using deconv decoder!') 104 | modules.append(deconv2d(nc_input, nc_output)) 105 | nc_input = nc_output 106 | modules.append(conv2d(use_bn, nc_input, nc_output)) 107 | else: 108 | modules.append(upconv2d(nc_input, nc_output, mode=upconv_mode)) 109 | nc_input = nc_output 110 | modules.append(conv2d(use_bn, nc_input, nc_output)) 111 | 112 | modules.append(nn.Conv2d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True)) 113 | decoder = nn.Sequential(*modules) 114 | net_init(decoder) 115 | return decoder 116 | 117 | 118 | ## 3D convolution layers 119 | def conv3d(batch_norm, in_planes, out_planes, kernel_size=3, stride=1): 120 | if batch_norm: 121 | return nn.Sequential( 122 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 123 | nn.BatchNorm3d(out_planes), 124 | nn.LeakyReLU(0.2,inplace=True) 125 | ) 126 | else: 127 | return nn.Sequential( 128 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 129 | nn.LeakyReLU(0.2,inplace=True) 130 | ) 131 | 132 | 133 | def deconv3d(batch_norm, in_planes, out_planes): 134 | if batch_norm: 135 | return nn.Sequential( 136 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 137 | nn.BatchNorm3d(out_planes), 138 | nn.LeakyReLU(0.2,inplace=True) 139 | ) 140 | else: 141 | return nn.Sequential( 142 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 143 | nn.LeakyReLU(0.2,inplace=True) 144 | ) 145 | 146 | 147 | ## 3D Network Modules 148 | def encoder3d(nlayers, use_bn=True, nc_input=1, nc_max=128, nc_l1=8, nc_step=1, nz_shape=20): 149 | ''' Simple 3D encoder with nlayers. 
150 | 151 | Args: 152 | nlayers: number of encoder layers 153 | use_bn: whether to use batch_norm 154 | nc_input: number of input channels 155 | nc_max: number of max channels 156 | nc_l1: number of channels in layer 1 157 | nc_step: double number of channels every nc_step layers 158 | nz_shape: size of bottleneck layer 159 | ''' 160 | modules = [] 161 | nc_output = nc_l1 162 | for nl in range(nlayers): 163 | if (nl>=1) and (nl%nc_step==0) and (nc_output <= nc_max*2): 164 | nc_output *= 2 165 | 166 | modules.append(conv3d(use_bn, nc_input, nc_output, stride=1)) 167 | nc_input = nc_output 168 | modules.append(conv3d(use_bn, nc_input, nc_output, stride=1)) 169 | modules.append(torch.nn.MaxPool3d(kernel_size=2, stride=2)) 170 | 171 | modules.append(Flatten()) 172 | modules.append(fc_stack(nc_output, nz_shape, 2, use_bn=True)) 173 | encoder = nn.Sequential(*modules) 174 | net_init(encoder) 175 | return encoder, nc_output 176 | 177 | 178 | def decoder3d(nlayers, nz_shape, nc_input, use_bn=True, nc_final=1, nc_min=8, nc_step=1, init_fc=True): 179 | ''' Simple 3D encoder with nlayers. 180 | 181 | Args: 182 | nlayers: number of decoder layers 183 | nz_shape: number of bottleneck 184 | nc_input: number of channels to start upconvolution from 185 | use_bn: whether to use batch_norm 186 | nc_final: number of output channels 187 | nc_min: number of min channels 188 | nc_step: double number of channels every nc_step layers 189 | init_fc: initial features are not spatial, use an fc & unsqueezing to make them 3D 190 | ''' 191 | modules = [] 192 | if init_fc: 193 | modules.append(fc(use_bn, nz_shape, nc_input)) 194 | for d in range(3): 195 | modules.append(Unsqueeze(2)) 196 | nc_output = nc_input 197 | for nl in range(nlayers): 198 | if (nl%nc_step==0) and (nc_output//2 >= nc_min): 199 | nc_output = nc_output//2 200 | 201 | modules.append(deconv3d(use_bn, nc_input, nc_output)) 202 | nc_input = nc_output 203 | modules.append(conv3d(use_bn, nc_input, nc_output)) 204 | 205 | modules.append(nn.Conv3d(nc_output, nc_final, kernel_size=3, stride=1, padding=1, bias=True)) 206 | decoder = nn.Sequential(*modules) 207 | net_init(decoder) 208 | return decoder 209 | 210 | 211 | def net_init(net): 212 | for m in net.modules(): 213 | if isinstance(m, nn.Linear): 214 | #n = m.out_features 215 | #m.weight.data.normal_(0, 0.02 / n) #this modified initialization seems to work better, but it's very hacky 216 | #n = m.in_features 217 | #m.weight.data.normal_(0, math.sqrt(2. / n)) #xavier 218 | m.weight.data.normal_(0, 0.02) 219 | if m.bias is not None: 220 | m.bias.data.zero_() 221 | 222 | if isinstance(m, nn.Conv2d): #or isinstance(m, nn.ConvTranspose2d): 223 | #n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 224 | #m.weight.data.normal_(0, math.sqrt(2. / n)) #this modified initialization seems to work better, but it's very hacky 225 | m.weight.data.normal_(0, 0.02) 226 | if m.bias is not None: 227 | m.bias.data.zero_() 228 | 229 | if isinstance(m, nn.ConvTranspose2d): 230 | # Initialize Deconv with bilinear weights. 231 | base_weights = bilinear_init(m.weight.data.size(-1)) 232 | base_weights = base_weights.unsqueeze(0).unsqueeze(0) 233 | m.weight.data = base_weights.repeat(m.weight.data.size(0), m.weight.data.size(1), 1, 1) 234 | if m.bias is not None: 235 | m.bias.data.zero_() 236 | 237 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d): 238 | #n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.in_channels 239 | #m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 240 | m.weight.data.normal_(0, 0.02) 241 | if m.bias is not None: 242 | m.bias.data.zero_() 243 | 244 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d): 245 | m.weight.data.fill_(1) 246 | m.bias.data.zero_() 247 | 248 | 249 | def bilinear_init(kernel_size=4): 250 | # Following Caffe's BilinearUpsamplingFiller 251 | # https://github.com/BVLC/caffe/pull/2213/files 252 | import numpy as np 253 | width = kernel_size 254 | height = kernel_size 255 | f = int(np.ceil(width / 2.)) 256 | cc = (2 * f - 1 - f % 2) / (2.*f) 257 | weights = torch.zeros((height, width)) 258 | for y in range(height): 259 | for x in range(width): 260 | weights[y, x] = (1 - np.abs(x / f - cc)) * (1 - np.abs(y / f - cc)) 261 | 262 | return weights 263 | 264 | 265 | if __name__ == '__main__': 266 | decoder2d(5, None, 256, use_deconv=True, init_fc=False) 267 | bilinear_init() -------------------------------------------------------------------------------- /lib/models/reconstruction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from lib.rendering.utils import circpad, symmetrize_texture, adjust_poles 6 | import numpy as np 7 | from torch.autograd import Variable 8 | 9 | class ResBlock(nn.Module): 10 | def __init__(self, ch_in, ch_out, pad_fn): 11 | super().__init__() 12 | self.conv1 = nn.Conv2d(ch_in, ch_in, 3, padding=(1, 0), bias=False) 13 | self.conv2 = nn.Conv2d(ch_in, ch_out, 3, padding=(1, 0), bias=False) 14 | self.bn1 = nn.BatchNorm2d(ch_in) 15 | self.bn2 = nn.BatchNorm2d(ch_out) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.pad_fn = pad_fn 18 | if ch_in != ch_out: 19 | self.shortcut = nn.Conv2d(ch_in, ch_out, 1, bias=False) 20 | else: 21 | self.shortcut = lambda x: x 22 | 23 | def forward(self, x): 24 | shortcut = self.shortcut(x) 25 | x = self.relu(self.bn1(self.conv1(self.pad_fn(x, 1)))) 26 | x = self.relu(self.bn2(self.conv2(self.pad_fn(x, 1)))) 27 | return x + shortcut 28 | 29 | 30 | class ShapeRetriever(nn.Module): 31 | """ 32 | output classification 33 | """ 34 | def __init__(self, nz_feat, n_mean_shapes): 35 | super(ShapeRetriever, self).__init__() 36 | self.pred_layer = nn.Sequential( 37 | nn.Linear(nz_feat, n_mean_shapes), 38 | nn.Softmax()) 39 | 40 | def forward(self, feat): 41 | wt = self.pred_layer.forward(feat) 42 | # Make it B x num_verts x 3 43 | # delta_v = delta_v.view(delta_v.size(0), -1, 3) 44 | # print('shape: ( Mean = {}, Var = {} )'.format(delta_v.mean().data[0], delta_v.var().data[0])) 45 | return wt 46 | 47 | 48 | 49 | class ReconstructionNetwork(nn.Module): 50 | def __init__(self, symmetric=True, texture_res=64, mesh_res=32, interpolation_mode='nearest', \ 51 | use_multitpl=False,n_templates=0,use_kp=False): 52 | super().__init__() 53 | 54 | self.symmetric = symmetric 55 | 56 | if symmetric: 57 | self.pad = lambda x, amount: F.pad(x, (amount, amount, 0, 0), mode='replicate') 58 | else: 59 | self.pad = lambda x, amount: circpad(x, amount) 60 | 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | if interpolation_mode == 'nearest': 64 | self.up = lambda x: F.interpolate(x, scale_factor=2, mode='nearest') 65 | elif interpolation_mode == 'bilinear': 66 | self.up = lambda x: F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 67 | else: 68 | raise 69 | 70 | assert mesh_res >= 32 71 | assert texture_res >= 64 72 | 73 | self.conv1e = nn.Conv2d(4, 64, 5, stride=2, padding=2, bias=False) # 128 -> 64 74 | self.bn1e = nn.BatchNorm2d(64) 
75 | self.conv2e = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False) # 64 > 32 76 | self.bn2e = nn.BatchNorm2d(128) 77 | self.conv3e = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False) # 32 -> 16 78 | self.bn3e = nn.BatchNorm2d(256) 79 | self.conv4e = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False) # 16 -> 8 80 | self.bn4e = nn.BatchNorm2d(512) 81 | 82 | bottleneck_dim = 256 83 | self.conv5e = nn.Conv2d(512, 64, 3, stride=2, padding=1, bias=False) # 8 -> 4 84 | self.bn5e = nn.BatchNorm2d(64) 85 | self.fc1e = nn.Linear(64*8*8, bottleneck_dim, bias=False) 86 | self.bnfc1e = nn.BatchNorm1d(bottleneck_dim) 87 | 88 | self.fc3e = nn.Linear(bottleneck_dim, 1024, bias=False) 89 | self.bnfc3e = nn.BatchNorm1d(1024) 90 | 91 | # Texture generation 92 | self.base_res_h = 4 93 | self.base_res_w = 2 if symmetric else 4 94 | 95 | self.fc1_tex = nn.Linear(1024, self.base_res_h*self.base_res_w*256) 96 | self.blk1 = ResBlock(256, 512, self.pad) # 4 -> 8 97 | self.blk2 = ResBlock(512, 256, self.pad) # 8 -> 16 98 | self.blk3 = ResBlock(256, 256, self.pad) # 16 -> 32 (k=1) 99 | 100 | assert texture_res in [64, 128, 256] 101 | self.texture_res = texture_res 102 | if texture_res >= 128: 103 | self.blk3b_tex = ResBlock(256, 256, self.pad) # k = 2 104 | if texture_res >= 256: 105 | self.blk3c_tex = ResBlock(256, 256, self.pad) # k = 4 106 | 107 | self.blk4_tex = ResBlock(256, 128, self.pad) # k*32 -> k*64 108 | self.blk5_tex = ResBlock(128, 64, self.pad) # k*64 -> k*64 (no upsampling) 109 | 110 | self.conv_tex = nn.Conv2d(64, 3, 5, padding=(2, 0)) 111 | 112 | # Mesh generation 113 | self.blk4_mesh = ResBlock(256, 64, self.pad) # 32 -> 32 (no upsampling) 114 | self.conv_mesh = nn.Conv2d(64, 3, 5, padding=(2, 0)) 115 | 116 | # Zero-initialize mesh output layer for stability (avoids self-intersections) 117 | self.conv_mesh.bias.data[:] = 0 118 | self.conv_mesh.weight.data[:] = 0 119 | 120 | self.use_multitpl = use_multitpl 121 | if self.use_multitpl: 122 | self.shape_wt = ShapeRetriever(nz_feat=256, n_mean_shapes=n_templates) 123 | self.use_kp = use_kp 124 | if self.use_kp: 125 | num_kps, num_verts = 15, 482 # NOTE only for birds 126 | vert2kp_init = torch.Tensor(np.ones((num_kps, num_verts)) / float(num_verts)) 127 | # vert2kp_init = torch.rand([num_kps,num_verts]) 128 | 129 | # Remember initial vert2kp (after softmax) 130 | self.vert2kp_init = torch.nn.functional.softmax(Variable(vert2kp_init.cuda(), requires_grad=False), dim=1) 131 | self.vert2kp = nn.Parameter(vert2kp_init) 132 | # self.register_parameter(name='vert2kp',param=nn.Parameter(vert2kp_init)) 133 | 134 | total_params = 0 135 | for param in self.parameters(): 136 | total_params += param.nelement() 137 | print('Model parameters: {:.2f}M'.format(total_params/1000000)) 138 | 139 | 140 | def forward(self, x): 141 | other_outputs = {} 142 | # Generate latent code 143 | x = self.relu(self.bn1e(self.conv1e(x))) 144 | x = self.relu(self.bn2e(self.conv2e(x))) 145 | x = self.relu(self.bn3e(self.conv3e(x))) 146 | x = self.relu(self.bn4e(self.conv4e(x))) 147 | x = self.relu(self.bn5e(self.conv5e(x))) 148 | 149 | x = x.view(x.shape[0], -1) # Flatten 150 | z = self.relu(self.bnfc1e(self.fc1e(x))) 151 | 152 | if self.use_multitpl: 153 | wt = self.shape_wt(z) # (B, k) 154 | other_outputs['wt'] = wt 155 | 156 | z = self.relu(self.bnfc3e(self.fc3e(z))) 157 | 158 | bb = self.fc1_tex(z).view(z.shape[0], -1, self.base_res_h, self.base_res_w) 159 | bb = self.up(self.blk1(bb)) 160 | bb = self.up(self.blk2(bb)) 161 | bb = self.up(self.blk3(bb)) 162 | bb_mesh 
= bb 163 | if self.texture_res >= 128: 164 | bb = self.up(self.blk3b_tex(bb)) 165 | if self.texture_res >= 256: 166 | bb = self.up(self.blk3c_tex(bb)) 167 | 168 | mesh_map = self.blk4_mesh(bb_mesh) 169 | mesh_map = self.conv_mesh(self.pad(self.relu(mesh_map), 2)) 170 | mesh_map = adjust_poles(mesh_map) 171 | 172 | tex = self.up(self.blk4_tex(bb)) 173 | tex = self.blk5_tex(tex) 174 | tex = self.conv_tex(self.pad(self.relu(tex), 2)).tanh_() 175 | 176 | if self.symmetric: 177 | tex = symmetrize_texture(tex) 178 | mesh_map = symmetrize_texture(mesh_map) 179 | 180 | 181 | if self.use_kp: 182 | vert2kp = torch.nn.functional.softmax(self.vert2kp, dim=1) # [15, 482] 183 | other_outputs['vert2kp'] = vert2kp 184 | 185 | return tex, mesh_map, other_outputs 186 | # # tex torch.Size([50, 3, 128, 128]) 187 | # # mesh_map: torch.Size([50, 3, 128, 128]) 188 | # return tex, mesh_map 189 | 190 | 191 | class DatasetParams(nn.Module): 192 | def __init__(self, args, dataset_size): 193 | super().__init__() 194 | # Dataset offsets 195 | self.dataset_size = dataset_size 196 | if args.optimize_deltas: 197 | self.ds_translation = nn.Parameter(torch.zeros(dataset_size, 2)) 198 | self.ds_scale = nn.Parameter(torch.zeros(dataset_size, 1)) 199 | if args.optimize_z0: 200 | self.ds_z0 = nn.Parameter(torch.ones(dataset_size, 1)) 201 | 202 | def forward(self, indices, mode): 203 | assert mode in ['deltas', 'z0'] 204 | if indices is not None: 205 | # Indices between N and 2N indicate that the image is mirrored (data augmentation) 206 | # Therefore, we flip the sign of the x translation 207 | x_sign = (1 - 2*(indices // self.dataset_size).float()).unsqueeze(-1) 208 | indices = indices % self.dataset_size 209 | else: 210 | x_sign = 1 211 | 212 | if mode == 'deltas': 213 | if indices is not None: 214 | translation_delta = self.ds_translation[indices] 215 | else: 216 | translation_delta = self.ds_translation.mean(dim=0, keepdim=True) 217 | 218 | translation_delta = torch.cat((translation_delta[:, :1] * x_sign, 219 | translation_delta[:, 1:2], 220 | torch.zeros_like(translation_delta[:, :1])), dim=1) 221 | if indices is not None: 222 | scale_delta = self.ds_scale[indices] 223 | else: 224 | scale_delta = self.ds_scale.mean(dim=0, keepdim=True) 225 | return translation_delta, scale_delta 226 | else: # z0 227 | if indices is not None: 228 | z0 = self.ds_z0[indices] 229 | else: 230 | z0 = self.ds_z0.mean(dim=0, keepdim=True) 231 | return 1 + torch.exp(z0) -------------------------------------------------------------------------------- /lib/rendering/cmr_geom_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils related to geometry like projection,, 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import torch 9 | 10 | 11 | def sample_textures(texture_flow, images): 12 | """ 13 | texture_flow: B x F x T x T x 2 14 | (In normalized coordinate [-1, 1]) 15 | images: B x 3 x N x N 16 | 17 | output: B x F x T x T x 3 18 | """ 19 | # Reshape into B x F x T*T x 2 20 | T = texture_flow.size(-2) 21 | F = texture_flow.size(1) 22 | flow_grid = texture_flow.view(-1, F, T * T, 2) 23 | # B x 3 x F x T*T 24 | samples = torch.nn.functional.grid_sample(images, flow_grid, align_corners=True) 25 | # B x 3 x F x T x T 26 | samples = samples.view(-1, 3, F, T, T) 27 | # B x F x T x T x 3 28 | return samples.permute(0, 2, 3, 4, 1) 29 | 30 | def orthographic_proj(X, cam): 31 | """ 32 | X: B x N x 3 33 | cam: B x 7: 
[sc, tx, ty, quaternions] 34 | """ 35 | quat = cam[:, -4:] 36 | X_rot = quat_rotate(X, quat) 37 | 38 | scale = cam[:, 0].contiguous().view(-1, 1, 1) 39 | trans = cam[:, 1:3].contiguous().view(cam.size(0), 1, -1) 40 | 41 | return scale * X_rot[:, :, :2] + trans 42 | 43 | def orthographic_proj_withz(X, cam, offset_z=0.): 44 | """ 45 | X: B x N x 3 46 | cam: B x 7: [sc, tx, ty, quaternions] 47 | Orth preserving the z. 48 | """ 49 | quat = cam[:, -4:] 50 | X_rot = quat_rotate(X, quat) 51 | 52 | scale = cam[:, 0].contiguous().view(-1, 1, 1) 53 | trans = cam[:, 1:3].contiguous().view(cam.size(0), 1, -1) 54 | 55 | proj = scale * X_rot 56 | 57 | proj_xy = proj[:, :, :2] + trans 58 | proj_z = proj[:, :, 2, None] + offset_z 59 | 60 | return torch.cat((proj_xy, proj_z), 2) 61 | 62 | 63 | def cross_product(qa, qb): 64 | """Cross product of va by vb. 65 | 66 | Args: 67 | qa: B X N X 3 vectors 68 | qb: B X N X 3 vectors 69 | Returns: 70 | q_mult: B X N X 3 vectors 71 | """ 72 | qa_0 = qa[:, :, 0] 73 | qa_1 = qa[:, :, 1] 74 | qa_2 = qa[:, :, 2] 75 | 76 | qb_0 = qb[:, :, 0] 77 | qb_1 = qb[:, :, 1] 78 | qb_2 = qb[:, :, 2] 79 | 80 | # See https://en.wikipedia.org/wiki/Cross_product 81 | q_mult_0 = qa_1*qb_2 - qa_2*qb_1 82 | q_mult_1 = qa_2*qb_0 - qa_0*qb_2 83 | q_mult_2 = qa_0*qb_1 - qa_1*qb_0 84 | 85 | return torch.stack([q_mult_0, q_mult_1, q_mult_2], dim=-1) 86 | 87 | 88 | def hamilton_product(qa, qb): 89 | """Multiply qa by qb. 90 | 91 | Args: 92 | qa: B X N X 4 quaternions 93 | qb: B X N X 4 quaternions 94 | Returns: 95 | q_mult: B X N X 4 96 | """ 97 | qa_0 = qa[:, :, 0] 98 | qa_1 = qa[:, :, 1] 99 | qa_2 = qa[:, :, 2] 100 | qa_3 = qa[:, :, 3] 101 | 102 | qb_0 = qb[:, :, 0] 103 | qb_1 = qb[:, :, 1] 104 | qb_2 = qb[:, :, 2] 105 | qb_3 = qb[:, :, 3] 106 | 107 | # See https://en.wikipedia.org/wiki/Quaternion#Hamilton_product 108 | q_mult_0 = qa_0*qb_0 - qa_1*qb_1 - qa_2*qb_2 - qa_3*qb_3 109 | q_mult_1 = qa_0*qb_1 + qa_1*qb_0 + qa_2*qb_3 - qa_3*qb_2 110 | q_mult_2 = qa_0*qb_2 - qa_1*qb_3 + qa_2*qb_0 + qa_3*qb_1 111 | q_mult_3 = qa_0*qb_3 + qa_1*qb_2 - qa_2*qb_1 + qa_3*qb_0 112 | 113 | return torch.stack([q_mult_0, q_mult_1, q_mult_2, q_mult_3], dim=-1) 114 | 115 | 116 | def quat_rotate(X, q): 117 | """Rotate points by quaternions. 118 | 119 | Args: 120 | X: B X N X 3 points 121 | q: B X 4 quaternions 122 | 123 | Returns: 124 | X_rot: B X N X 3 (rotated points) 125 | """ 126 | # repeat q along 2nd dim 127 | ones_x = X[[0], :, :][:, :, [0]]*0 + 1 128 | q = torch.unsqueeze(q, 1)*ones_x 129 | 130 | q_conj = torch.cat([ q[:, :, [0]] , -1*q[:, :, 1:4] ], dim=-1) 131 | X = torch.cat([ X[:, :, [0]]*0, X ], dim=-1) 132 | 133 | X_rot = hamilton_product(q, hamilton_product(X, q_conj)) 134 | return X_rot[:, :, 1:4] 135 | -------------------------------------------------------------------------------- /lib/rendering/cmr_mesh.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mesh stuff. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | from lib.rendering.cmr_meshzoo import iso_sphere 11 | 12 | def create_sphere(n_subdivide=3): 13 | # 3 makes 642 verts, 1280 faces, 14 | # 4 makes 2562 verts, 5120 faces 15 | verts, faces = iso_sphere(n_subdivide) 16 | return verts, faces 17 | 18 | 19 | def make_symmetric(verts, faces): 20 | """ 21 | Assumes that the input mesh {V,F} is perfectly symmetric 22 | Splits the mesh along the X-axis, and reorders the mesh s.t. 
23 | (so this is reflection on Y-axis..) 24 | [indept verts, right (x>0) verts, left verts] 25 | 26 | v[:num_indept + num_sym] = A 27 | v[:-num_sym] = -A[num_indept:] 28 | """ 29 | left = verts[:, 0] < 0 30 | right = verts[:, 0] > 0 31 | center = verts[:, 0] == 0 32 | 33 | left_inds = np.where(left)[0] 34 | right_inds = np.where(right)[0] 35 | center_inds = np.where(center)[0] 36 | 37 | num_indept = len(center_inds) 38 | num_sym = len(left_inds) 39 | assert(len(left_inds) == len(right_inds)) 40 | 41 | # For each right verts, find the corresponding left verts. 42 | prop_left_inds = np.hstack([np.where(np.all(verts == np.array([-1, 1, 1]) * verts[ri], 1))[0] for ri in right_inds]) 43 | assert(prop_left_inds.shape[0] == num_sym) 44 | 45 | # Make sure right/left order are symmetric. 46 | for ind, (ri, li) in enumerate(zip(right_inds, prop_left_inds)): 47 | if np.any(verts[ri] != np.array([-1, 1, 1]) * verts[li]): 48 | print('bad! %d' % ind) 49 | import pdb; pdb.set_trace() 50 | 51 | new_order = np.hstack([center_inds, right_inds, prop_left_inds]) 52 | # verts i is now vert j 53 | ind_perm = np.hstack([np.where(new_order==i)[0] for i in range(verts.shape[0])]) 54 | 55 | new_verts = verts[new_order, :] 56 | new_faces0 = ind_perm[faces] 57 | 58 | new_faces, num_indept_faces, num_sym_faces = make_faces_symmetric(new_verts, new_faces0, num_indept, num_sym) 59 | 60 | return new_verts, new_faces, num_indept, num_sym, num_indept_faces, num_sym_faces 61 | 62 | def make_faces_symmetric(verts, faces, num_indept_verts, num_sym_verts): 63 | """ 64 | This reorders the faces, such that it has this order: 65 | F_indept - independent face ids 66 | F_right (x>0) 67 | F_left 68 | 69 | 1. For each face, identify whether it's independent or has a symmetric face. 70 | 71 | A face is independent, if v_i is an independent vertex and if the other two v_j, v_k are the symmetric pairs. 72 | Otherwise, there are two kinds of symmetric faces: 73 | - v_i is indept, v_j, v_k are not the symmetric paris) 74 | - all three have symmetric counter verts. 75 | 76 | Returns a new set of faces that is in the above order. 77 | Also, the symmetric face pairs are reordered so that the vertex order is the same. 78 | i.e. verts[f_id] and verts[f_id_sym] is in the same vertex order, except the x coord are flipped 79 | """ 80 | DRAW = False 81 | indept_faces = [] 82 | right_faces = [] 83 | left_faces = [] 84 | 85 | indept_verts = verts[:num_indept_verts] 86 | symmetric_verts = verts[num_indept_verts:] 87 | # These are symmetric pairs 88 | right_ids = np.arange(num_indept_verts, num_indept_verts+num_sym_verts) 89 | left_ids = np.arange(num_indept_verts+num_sym_verts, num_indept_verts+2*num_sym_verts) 90 | # Make this for easy lookup 91 | # Saves for each vert_id, the symmetric vert_ids 92 | v_dict = {} 93 | for r_id, l_id in zip(right_ids, left_ids): 94 | v_dict[r_id] = l_id 95 | v_dict[l_id] = r_id 96 | # Return itself for indepentnet. 97 | for ind in range(num_indept_verts): 98 | v_dict[ind] = ind 99 | 100 | # Saves faces that contain this verts 101 | verts2faces = [np.where((faces == v_id).any(axis=1))[0] for v_id in range(verts.shape[0])] 102 | done_face = np.zeros(faces.shape[0]) 103 | # Make faces symmetric: 104 | for f_id in range(faces.shape[0]): 105 | if done_face[f_id]: 106 | continue 107 | v_ids = sorted(faces[f_id]) 108 | # This is triangles x [x,y,z] 109 | vs = verts[v_ids] 110 | # Find the corresponding vs? 
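        # v_dict maps each right-side vertex id to its left-side mirror (and vice versa), while
        # independent (x == 0) vertices map to themselves, so a face is independent exactly when
        # its mirrored id set equals its own id set (the check below).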
111 | v_sym_ids = sorted([v_dict[v_id] for v_id in v_ids]) 112 | 113 | # Check if it's independent 114 | if sorted(v_sym_ids) == sorted(v_ids): 115 | # Independent!! 116 | indept_faces.append(faces[f_id]) 117 | # indept_faces.append(f_id) 118 | done_face[f_id] = 1 119 | else: 120 | # Find the face with these verts. (so we can mark it done) 121 | possible_faces = np.hstack([verts2faces[v_id] for v_id in v_sym_ids]) 122 | possible_fids, counts = np.unique(possible_faces, return_counts=True) 123 | # The face id is the one that appears 3 times in this list. 124 | sym_fid = possible_fids[counts == 3][0] 125 | assert(sorted(v_sym_ids) == sorted(faces[sym_fid])) 126 | # Make sure that the order of these vertices are the same. 127 | # Go in the order of face: f_id 128 | face_here = faces[f_id] 129 | sym_face_here = [v_dict[v_id] for v_id in face_here] 130 | # Above is the same tri as faces[sym_fid], but vertices are in the order of faces[f_id] 131 | # Which one is right x > 0? 132 | # Only use unique verts in these faces to compute. 133 | unique_vids = np.array(v_ids) != np.array(v_sym_ids) 134 | if np.all(verts[face_here][unique_vids, 0] < verts[sym_face_here][unique_vids, 0]): 135 | # f_id is left 136 | left_faces.append(face_here) 137 | right_faces.append(sym_face_here) 138 | else: 139 | left_faces.append(sym_face_here) 140 | right_faces.append(face_here) 141 | done_face[f_id] = 1 142 | done_face[sym_fid] = 1 143 | # Draw 144 | # tri_sym = Mesh(verts[v_sym_ids], [[0, 1, 2]], vc='red') 145 | # mv.set_dynamic_meshes([mesh, tri, tri_sym]) 146 | 147 | assert(len(left_faces) + len(right_faces) + len(indept_faces) == faces.shape[0]) 148 | # Now concatenate them,, 149 | new_faces = np.vstack([indept_faces, right_faces, left_faces]) 150 | # Now sort each row of new_faces to make sure that bary centric coord will be same. 151 | num_indept_faces = len(indept_faces) 152 | num_sym_faces = len(right_faces) 153 | 154 | return new_faces, num_indept_faces, num_sym_faces 155 | 156 | 157 | def compute_edges2verts(verts, faces): 158 | """ 159 | Returns a list: [A, B, C, D] the 4 vertices for each edge. 160 | """ 161 | edge_dict = {} 162 | for face_id, (face) in enumerate(faces): 163 | for e1, e2, o_id in [(0, 1, 2), (0, 2, 1), (1, 2, 0)]: 164 | edge = tuple(sorted((face[e1], face[e2]))) 165 | other_v = face[o_id] 166 | if edge not in edge_dict.keys(): 167 | edge_dict[edge] = [other_v] 168 | else: 169 | if other_v not in edge_dict[edge]: 170 | edge_dict[edge].append(other_v) 171 | result = np.stack([np.hstack((edge, other_vs)) for edge, other_vs in edge_dict.items()]) 172 | return result 173 | 174 | def compute_vert2kp(verts, mean_shape): 175 | # verts: N x 3 176 | # mean_shape: 3 x K (K=15) 177 | # 178 | # computes vert2kp: K x N matrix by picking NN to each point in mean_shape. 
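    # The values returned are unnormalized log-weights: less negative for vertices close to each
    # keypoint's nearest mesh vertex. A softmax over the vertex dimension turns each row into a
    # soft keypoint-to-vertex assignment.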
179 | 180 | if mean_shape.shape[0] == 3: 181 | # Make it K x 3 182 | mean_shape = mean_shape.T 183 | num_kp = mean_shape.shape[1] 184 | 185 | nn_inds = [np.argmin(np.linalg.norm(verts - pt, axis=1)) for pt in mean_shape] 186 | 187 | dists = np.stack([np.linalg.norm(verts - verts[nn_ind], axis=1) for nn_ind in nn_inds]) 188 | vert2kp = -.5*(dists)/.01 189 | return vert2kp 190 | 191 | def get_spherical_coords(X): 192 | 193 | # X is N x 3 194 | rad = np.linalg.norm(X, axis=1) 195 | # Inclination 196 | theta = np.arccos(X[:, 2] / rad) 197 | # Azimuth 198 | phi = np.arctan2(X[:, 1], X[:, 0]) 199 | 200 | # Normalize both to be between [-1, 1] 201 | vv = (theta / np.pi) * 2 - 1 202 | uu = ((phi + np.pi) / (2*np.pi)) * 2 - 1 203 | # Return N x 2 204 | return np.stack([uu, vv],1) 205 | 206 | 207 | def compute_uvsampler(verts, faces, tex_size=2): 208 | """ 209 | For this mesh, pre-computes the UV coordinates for 210 | F x T x T points. 211 | Returns F x T x T x 2 212 | """ 213 | # verts: (642, 3); faces: (656, 3) 214 | 215 | alpha = np.arange(tex_size, dtype=np.float) / (tex_size-1) 216 | beta = np.arange(tex_size, dtype=np.float) / (tex_size-1) 217 | import itertools 218 | # Barycentric coordinate values 219 | coords = np.stack([p for p in itertools.product(*[alpha, beta])]) # (36, 2) 220 | vs = verts[faces] # (656, 3, 3) (F, 3 nodes, posistions) 221 | # Compute alpha, beta (this is the same order as NMR) 222 | v2 = vs[:, 2] # (656, 3) 223 | v0v2 = vs[:, 0] - vs[:, 2] # (656, 3) 224 | v1v2 = vs[:, 1] - vs[:, 2] # (656, 3) 225 | # F x 3 x T*2 226 | samples = np.dstack([v0v2, v1v2]).dot(coords.T) + v2.reshape(-1, 3, 1) 227 | # F x T*2 x 3 points on the sphere 228 | samples = np.transpose(samples, (0, 2, 1)) # (656, 36, 3) 229 | # Now convert these to uv. 230 | uv = get_spherical_coords(samples.reshape(-1, 3)) 231 | # uv = uv.reshape(-1, len(coords), 2) 232 | 233 | uv = uv.reshape(-1, tex_size, tex_size, 2) 234 | return uv 235 | 236 | 237 | def append_obj(mf_handle, vertices, faces): 238 | for vx in range(vertices.shape[0]): 239 | mf_handle.write('v {:f} {:f} {:f}\n'.format(vertices[vx, 0], vertices[vx, 1], vertices[vx, 2])) 240 | for fx in range(faces.shape[0]): 241 | mf_handle.write('f {:d} {:d} {:d}\n'.format(faces[fx, 0], faces[fx, 1], faces[fx, 2])) 242 | return 243 | -------------------------------------------------------------------------------- /lib/rendering/cmr_nmr_kaolin.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import numpy as np 7 | import scipy.misc 8 | import tqdm 9 | import cv2 10 | 11 | import torch 12 | 13 | import rendering.cmr_geom_utils as geom_utils 14 | 15 | from kaolin.graphics.NeuralMeshRenderer import NeuralMeshRenderer 16 | 17 | class NeuralRenderer(torch.nn.Module): 18 | """ 19 | replace NeuralRenderer from nmr.py with the kaolin's 20 | """ 21 | def __init__(self, img_size=256,uv_sampler=None): 22 | super(NeuralRenderer, self).__init__() 23 | self.renderer = NeuralMeshRenderer(image_size=img_size, camera_mode='look_at',perspective=False,viewing_angle=30,light_intensity_ambient=0.8) 24 | # 30 degree is equivalent to self.renderer.eye = [0, 0, -2.732] 25 | 26 | self.offset_z = 5. 27 | self.proj_fn = geom_utils.orthographic_proj_withz 28 | print('NMR-kaolin initiated') 29 | 30 | def ambient_light_only(self): 31 | # Make light only ambient. 
32 | self.renderer.light_intensity_ambient = 1 33 | self.renderer.light_intensity_directional = 0 34 | 35 | def set_bgcolor(self, color): 36 | self.renderer.background_color = color 37 | 38 | def project_points(self, verts, cams): 39 | proj = self.proj_fn(verts, cams) 40 | return proj[:, :, :2] 41 | 42 | def forward(self, vertices, faces, cams, textures=None): 43 | # faces: B, 1280, 3 44 | # vertices:[B, 642, 3] 45 | # texture: texture: [B, 1280, 6, 6, 6, 3] - should be RGB, with range [0,1] 46 | # if textures is not None: 47 | 48 | verts = self.proj_fn(vertices, cams, offset_z=self.offset_z) 49 | 50 | vs = verts.clone() 51 | vs[:, :, 1] *= -1 52 | fs = faces.clone() 53 | if textures is None: 54 | self.mask_only = True 55 | masks = self.renderer.render_silhouettes(vs,fs) 56 | return masks 57 | else: 58 | self.mask_only = False 59 | ts = textures.clone() 60 | # print('tx shape:',ts.shape) 61 | imgs = self.renderer.render(vs, fs, ts)[0] #only keep rgb, no alpha and depth 62 | # for i, img in enumerate(imgs): 63 | 64 | # img = img.permute([1,2,0]).detach().cpu().numpy() 65 | # 66 | # cv2.imwrite('./vis/nmr'+str(i)+'.jpg',img*255) 67 | # print('saved img') 68 | # print('!!!imgs:',imgs.shape) 69 | return imgs -------------------------------------------------------------------------------- /lib/rendering/fragment_shader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | 4 | from .utils import grid_sample_bilinear 5 | 6 | def texinterpolation(imtexcoord_bxhxwx2, texture_bx3xthxtw, filtering='bilinear'): 7 | 8 | imtexcoord_bxhxwx2 = imtexcoord_bxhxwx2 * 2 - 1 # [0, 1] to [-1, 1] 9 | imtexcoord_bxhxwx2 = imtexcoord_bxhxwx2 * torch.FloatTensor([1, -1]).to(imtexcoord_bxhxwx2.device) # Flip y 10 | 11 | if filtering == 'bilinear': 12 | # Ensures consistent behavior across different PyTorch versions 13 | texcolor = grid_sample_bilinear(texture_bx3xthxtw, imtexcoord_bxhxwx2) 14 | else: 15 | texcolor = torch.nn.functional.grid_sample(texture_bx3xthxtw, 16 | imtexcoord_bxhxwx2, 17 | mode=filtering) 18 | 19 | texcolor = texcolor.permute(0, 2, 3, 1) 20 | return texcolor 21 | 22 | def fragmentshader(imtexcoord_bxhxwx2, 23 | texture_bx3xthxtw, 24 | improb_bxhxwx1, 25 | filtering='bilinear', 26 | background_image=None): 27 | 28 | texcolor_bxhxwx3 = texinterpolation(imtexcoord_bxhxwx2, 29 | texture_bx3xthxtw, 30 | filtering=filtering) 31 | 32 | if background_image is None: 33 | color = texcolor_bxhxwx3 * improb_bxhxwx1 34 | else: 35 | color = torch.lerp(background_image, texcolor_bxhxwx3, improb_bxhxwx1) 36 | 37 | return color 38 | -------------------------------------------------------------------------------- /lib/rendering/monkey_patches.py: -------------------------------------------------------------------------------- 1 | # At the time of writing, Kaolin is not compatible with PyTorch >= 1.5, 2 | # so we monkey patch the methods we need to make it work with newer versions. 3 | 4 | # Adapted from: https://github.com/NVIDIAGameWorks/kaolin/blob/e7e513173bd4159ae45be6b3e156a3ad156a3eb9/kaolin/rep/Mesh.py#L586-L735 5 | 6 | import torch 7 | 8 | def compute_adjacency_info_patched(vertices: torch.Tensor, faces: torch.Tensor): 9 | """Build data structures to help speed up connectivity queries. Assumes 10 | a homogeneous mesh, i.e., each face has the same number of vertices. 11 | 12 | The outputs have the following format: AA, AA_count 13 | 14 | AA_count: ``[count_0, ..., count_n]`` 15 | 16 | with AA: 17 | 18 | .. 
code-block:: 19 | 20 | [[aa_{0,0}, ..., aa_{0,count_0} (, -1, ..., -1)], 21 | [aa_{1,0}, ..., aa_{1,count_1} (, -1, ..., -1)], 22 | ... 23 | [aa_{n,0}, ..., aa_{n,count_n} (, -1, ..., -1)]] 24 | """ 25 | 26 | device = vertices.device 27 | facesize = faces.shape[1] 28 | nb_vertices = vertices.shape[0] 29 | nb_faces = faces.shape[0] 30 | edges = torch.cat([faces[:,i:i+2] for i in range(facesize - 1)] + 31 | [faces[:,[-1,0]]], dim=0) 32 | # Sort the vertex of edges in increasing order 33 | edges = torch.sort(edges, dim=1)[0] 34 | # id of corresponding face in edges 35 | face_ids = torch.arange(nb_faces, device=device, dtype=torch.long).repeat(facesize) 36 | # remove multiple occurences and sort by the first vertex 37 | # the edge key / id is fixed from now as the first axis position 38 | # edges_ids will give the key of the edges on the original vector 39 | edges, edges_ids = torch.unique(edges, sorted=True, return_inverse=True, dim=0) 40 | nb_edges = edges.shape[0] 41 | 42 | # EDGE2EDGES 43 | _edges_ids = edges_ids.reshape(facesize, nb_faces) 44 | edges2edges = torch.cat([ 45 | torch.stack([_edges_ids[1:], _edges_ids[:-1]], dim=-1).reshape(-1, 2), 46 | torch.stack([_edges_ids[-1:], _edges_ids[:1]], dim=-1).reshape(-1, 2) 47 | ], dim=0) 48 | 49 | double_edges2edges = torch.cat([edges2edges, torch.flip(edges2edges, dims=(1,))], dim=0) 50 | double_edges2edges = torch.cat( 51 | [double_edges2edges, torch.arange(double_edges2edges.shape[0], device=device, dtype=torch.long).reshape(-1, 1)], dim=1) 52 | double_edges2edges = torch.unique(double_edges2edges, sorted=True, dim=0)[:,:2] 53 | idx_first = torch.where( 54 | torch.nn.functional.pad(double_edges2edges[1:,0] != double_edges2edges[:-1,0], 55 | (1, 0), value=1))[0] 56 | nb_edges_per_edge = idx_first[1:] - idx_first[:-1] 57 | offsets = torch.zeros(double_edges2edges.shape[0], device=device, dtype=torch.long) 58 | offsets[idx_first[1:]] = nb_edges_per_edge 59 | sub_idx = (torch.arange(double_edges2edges.shape[0], device=device,dtype=torch.long) - 60 | torch.cumsum(offsets, dim=0)) 61 | nb_edges_per_edge = torch.cat([nb_edges_per_edge, 62 | double_edges2edges.shape[0] - idx_first[-1:]], 63 | dim=0) 64 | max_sub_idx = torch.max(nb_edges_per_edge) 65 | ee = torch.full((nb_edges, max_sub_idx), device=device, dtype=torch.long, fill_value=-1) 66 | ee[double_edges2edges[:,0], sub_idx] = double_edges2edges[:,1] 67 | 68 | # EDGE2FACE 69 | sorted_edges_ids, order_edges_ids = torch.sort(edges_ids) 70 | sorted_faces_ids = face_ids[order_edges_ids] 71 | # indices of first occurences of each key 72 | idx_first = torch.where( 73 | torch.nn.functional.pad(sorted_edges_ids[1:] != sorted_edges_ids[:-1], 74 | (1,0), value=1))[0] 75 | nb_faces_per_edge = idx_first[1:] - idx_first[:-1] 76 | # compute sub_idx (2nd axis indices to store the faces) 77 | offsets = torch.zeros(sorted_edges_ids.shape[0], device=device, dtype=torch.long) 78 | offsets[idx_first[1:]] = nb_faces_per_edge 79 | sub_idx = (torch.arange(sorted_edges_ids.shape[0], device=device, dtype=torch.long) - 80 | torch.cumsum(offsets, dim=0)) 81 | # TODO(cfujitsang): potential way to compute sub_idx differently 82 | # to test with bigger model 83 | #sub_idx = torch.ones(sorted_edges_ids.shape[0], device=device, dtype=torch.long) 84 | #sub_idx[0] = 0 85 | #sub_idx[idx_first[1:]] = 1 - nb_faces_per_edge 86 | #sub_idx = torch.cumsum(sub_idx, dim=0) 87 | nb_faces_per_edge = torch.cat([nb_faces_per_edge, 88 | sorted_edges_ids.shape[0] - idx_first[-1:]], 89 | dim=0) 90 | max_sub_idx = torch.max(nb_faces_per_edge) 
91 | ef = torch.full((nb_edges, max_sub_idx), device=device, dtype=torch.long, fill_value=-1) 92 | ef[sorted_edges_ids, sub_idx] = sorted_faces_ids 93 | # FACE2FACES 94 | nb_faces_per_face = torch.stack([nb_faces_per_edge[edges_ids[i*nb_faces:(i+1)*nb_faces]] 95 | for i in range(facesize)], dim=1).sum(dim=1) - facesize 96 | ff = torch.cat([ef[edges_ids[i*nb_faces:(i+1)*nb_faces]] for i in range(facesize)], dim=1) 97 | # remove self occurences 98 | ff[ff == torch.arange(nb_faces, device=device, dtype=torch.long).view(-1,1)] = -1 99 | ff = torch.sort(ff, dim=-1, descending=True)[0] 100 | to_del = (ff[:,1:] == ff[:,:-1]) & (ff[:,1:] != -1) 101 | ff[:,1:][to_del] = -1 102 | nb_faces_per_face = nb_faces_per_face - torch.sum(to_del, dim=1) 103 | max_sub_idx = torch.max(nb_faces_per_face) 104 | ff = torch.sort(ff, dim=-1, descending=True)[0][:,:max_sub_idx] 105 | 106 | # VERTEX2VERTICES and VERTEX2EDGES 107 | npy_edges = edges.cpu().numpy() 108 | edge2key = {tuple(npy_edges[i]): i for i in range(nb_edges)} 109 | #_edges and double_edges 2nd axis correspond to the triplet: 110 | # [left vertex, right vertex, edge key] 111 | _edges = torch.cat([edges, torch.arange(nb_edges, device=device).view(-1, 1)], 112 | dim=1) 113 | double_edges = torch.cat([_edges, _edges[:,[1,0,2]]], dim=0) 114 | double_edges = torch.unique(double_edges, sorted=True, dim=0) 115 | # TODO(cfujitsang): potential improvment, to test with bigger model: 116 | #double_edges0, order_double_edges = torch.sort(double_edges[0]) 117 | nb_double_edges = double_edges.shape[0] 118 | # indices of first occurences of each key 119 | idx_first = torch.where( 120 | torch.nn.functional.pad(double_edges[1:,0] != double_edges[:-1,0], 121 | (1,0), value=1))[0] 122 | nb_edges_per_vertex = idx_first[1:] - idx_first[:-1] 123 | # compute sub_idx (2nd axis indices to store the edges) 124 | offsets = torch.zeros(nb_double_edges, device=device, dtype=torch.long) 125 | offsets[idx_first[1:]] = nb_edges_per_vertex 126 | sub_idx = (torch.arange(nb_double_edges, device=device, dtype=torch.long) - 127 | torch.cumsum(offsets, dim=0)) 128 | nb_edges_per_vertex = torch.cat([nb_edges_per_vertex, 129 | nb_double_edges - idx_first[-1:]], dim=0) 130 | max_sub_idx = torch.max(nb_edges_per_vertex) 131 | vv = torch.full((nb_vertices, max_sub_idx), device=device, dtype=torch.long, fill_value=-1) 132 | vv[double_edges[:,0], sub_idx] = double_edges[:,1] 133 | ve = torch.full((nb_vertices, max_sub_idx), device=device, dtype=torch.long, fill_value=-1) 134 | ve[double_edges[:,0], sub_idx] = double_edges[:,2] 135 | 136 | # VERTEX2FACES 137 | vertex_ordered, order_vertex = torch.sort(faces.view(-1)) 138 | face_ids_in_vertex_order = order_vertex // facesize # This line has been patched 139 | # indices of first occurences of each id 140 | idx_first = torch.where( 141 | torch.nn.functional.pad(vertex_ordered[1:] != vertex_ordered[:-1], (1,0), value=1))[0] 142 | nb_faces_per_vertex = idx_first[1:] - idx_first[:-1] 143 | # compute sub_idx (2nd axis indices to store the faces) 144 | offsets = torch.zeros(vertex_ordered.shape[0], device=device, dtype=torch.long) 145 | offsets[idx_first[1:]] = nb_faces_per_vertex 146 | sub_idx = (torch.arange(vertex_ordered.shape[0], device=device, dtype=torch.long) - 147 | torch.cumsum(offsets, dim=0)) 148 | # TODO(cfujitsang): it seems that nb_faces_per_vertex == nb_edges_per_vertex ? 
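The `-1`-padded layout described in the docstring at the top of this function is easiest to see on a toy mesh. A minimal sketch; the import path is an assumption and presumes `lib` is on the Python path, which the in-repo imports (e.g. `rendering.cmr_geom_utils`) suggest:

```
import torch
from rendering.monkey_patches import compute_adjacency_info_patched  # assumed import path

# Two triangles sharing the edge (1, 2); vertex positions are irrelevant here,
# only the device and the vertex count are read from `vertices`.
verts = torch.zeros(4, 3)
faces = torch.tensor([[0, 1, 2], [1, 3, 2]], dtype=torch.long)

out = compute_adjacency_info_patched(verts, faces)
vv, nb_edges_per_vertex = out[2], out[3]   # vertex -> neighbouring vertices, padded with -1
print(vv)                    # expected (neighbour order may vary): [[1, 2, -1], [0, 2, 3], [0, 1, 3], [1, 2, -1]]
print(nb_edges_per_vertex)   # [2, 3, 3, 2]
```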
149 | nb_faces_per_vertex = torch.cat([nb_faces_per_vertex, 150 | vertex_ordered.shape[0] - idx_first[-1:]], dim=0) 151 | max_sub_idx = torch.max(nb_faces_per_vertex) 152 | vf = torch.full((nb_vertices, max_sub_idx), device=device, dtype=torch.long, fill_value=-1) 153 | vf[vertex_ordered, sub_idx] = face_ids_in_vertex_order 154 | 155 | return edge2key, edges, vv, nb_edges_per_vertex, ve, nb_edges_per_vertex, vf, \ 156 | nb_faces_per_vertex, ff, nb_faces_per_face, ee, nb_edges_per_edge, ef, nb_faces_per_edge -------------------------------------------------------------------------------- /lib/rendering/renderer.py: -------------------------------------------------------------------------------- 1 | from kaolin.graphics.dib_renderer.rasterizer import linear_rasterizer 2 | from kaolin.graphics.dib_renderer.utils import datanormalize 3 | 4 | from .fragment_shader import fragmentshader 5 | import sys 6 | import os 7 | sys.path.append(os.path.abspath('..')) 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | def ortho_projection(points_bxpx3, faces_fx3): 13 | 14 | xy_bxpx3 = points_bxpx3 # xyz [10, 482, 3] 15 | xy_bxpx2 = xy_bxpx3[:, :, :2] # xy 16 | 17 | # 1st, 2nd, 3rd point of the faces; concat 18 | pf0_bxfx3 = points_bxpx3[:, faces_fx3[:, 0], :] 19 | pf1_bxfx3 = points_bxpx3[:, faces_fx3[:, 1], :] 20 | pf2_bxfx3 = points_bxpx3[:, faces_fx3[:, 2], :] 21 | points3d_bxfx9 = torch.cat((pf0_bxfx3, pf1_bxfx3, pf2_bxfx3), dim=2) 22 | 23 | xy_f0 = xy_bxpx2[:, faces_fx3[:, 0], :] 24 | xy_f1 = xy_bxpx2[:, faces_fx3[:, 1], :] 25 | xy_f2 = xy_bxpx2[:, faces_fx3[:, 2], :] 26 | points2d_bxfx6 = torch.cat((xy_f0, xy_f1, xy_f2), dim=2) 27 | 28 | v01_bxfx3 = pf1_bxfx3 - pf0_bxfx3 29 | v02_bxfx3 = pf2_bxfx3 - pf0_bxfx3 30 | 31 | normal_bxfx3 = torch.cross(v01_bxfx3, v02_bxfx3, dim=2) 32 | 33 | return points3d_bxfx9, points2d_bxfx6, normal_bxfx3 34 | 35 | class Renderer(nn.Module): 36 | 37 | def __init__(self, height, width, filtering='bilinear'): 38 | super().__init__() 39 | 40 | self.height = height 41 | self.width = width 42 | self.filtering = filtering 43 | 44 | def forward(self, points, uv_bxpx2, texture_bx3xthxtw, ft_fx3=None, background_image=None, return_hardmask=False): 45 | # print('rendering...') 46 | 47 | ### half vertices ? 
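As context for the forward pass below: `ortho_projection` above gathers the three vertices of every face and forms an unnormalized face normal, and its z-component (`normalz_bxfx1`) is what the rasterizer uses to separate front from back faces. A self-contained sketch of that normal computation on a single triangle:

```
import torch

# One triangle in the xy-plane; this mirrors the cross product in ortho_projection
pf0 = torch.tensor([[0., 0., 0.]])
pf1 = torch.tensor([[1., 0., 0.]])
pf2 = torch.tensor([[0., 1., 0.]])
normal = torch.cross(pf1 - pf0, pf2 - pf0, dim=1)
print(normal)        # tensor([[0., 0., 1.]])
print(normal[:, 2])  # the z-component passed to the rasterizer as normalz_bxfx1
```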
48 | points_bxpx3, faces_fx3 = points 49 | 50 | if ft_fx3 is None: 51 | ft_fx3 = faces_fx3 52 | 53 | points3d_bxfx9, points2d_bxfx6, normal_bxfx3 = ortho_projection(points_bxpx3, faces_fx3) 54 | 55 | # Detect front/back faces 56 | normalz_bxfx1 = normal_bxfx3[:, :, 2:3] 57 | 58 | # Ensure that normals are unit length 59 | normal1_bxfx3 = datanormalize(normal_bxfx3, axis=2) 60 | 61 | c0 = uv_bxpx2[:, ft_fx3[:, 0], :] 62 | c1 = uv_bxpx2[:, ft_fx3[:, 1], :] 63 | c2 = uv_bxpx2[:, ft_fx3[:, 2], :] 64 | mask = torch.ones_like(c0[:, :, :1]) 65 | uv_bxfx9 = torch.cat((c0, mask, c1, mask, c2, mask), dim=2) 66 | 67 | # print(points3d_bxfx9.shape,points2d_bxfx6.shape, normalz_bxfx1.shape,uv_bxfx9.shape) 68 | # torch.Size([16, 960, 9]) torch.Size([16, 960, 6]) torch.Size([16, 960, 1]) torch.Size([1, 960, 9]) 69 | # torch.Size([1, 960, 9]) torch.Size([1, 960, 6]) torch.Size([1, 960, 1]) torch.Size([1, 960, 9]) 70 | imfeat, improb_bxhxwx1 = linear_rasterizer( 71 | self.height, 72 | self.width, 73 | points3d_bxfx9, 74 | points2d_bxfx6, 75 | normalz_bxfx1, 76 | uv_bxfx9, 77 | ) 78 | 79 | imtexcoords = imfeat[:, :, :, :2] 80 | hardmask = imfeat[:, :, :, 2:3] 81 | 82 | imrender = fragmentshader(imtexcoords, texture_bx3xthxtw, hardmask, 83 | filtering=self.filtering, background_image=background_image) 84 | 85 | if return_hardmask: 86 | improb_bxhxwx1 = hardmask 87 | return imrender, improb_bxhxwx1, normal1_bxfx3 88 | -------------------------------------------------------------------------------- /lib/rendering/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from packaging import version 5 | 6 | def grid_sample_bilinear(input, grid): 7 | # PyTorch 1.3 introduced an API change (breaking change in version 1.4), therefore we check this explicitly 8 | # to make sure that the behavior is consistent across different versions 9 | if version.parse(torch.__version__) < version.parse('1.3'): 10 | return F.grid_sample(input, grid, mode='bilinear') 11 | else: 12 | return F.grid_sample(input, grid, mode='bilinear', align_corners=True) 13 | 14 | 15 | def symmetrize_texture(x): 16 | # Apply even symmetry along the x-axis (from length N to 2N) 17 | x_flip = torch.flip(x, (len(x.shape) - 1,)) 18 | return torch.cat((x_flip[:, :, :, x_flip.shape[3]//2:], x, x_flip[:, :, :, :x_flip.shape[3]//2]), dim=-1) 19 | 20 | 21 | def adjust_poles(tex): 22 | # Average top and bottom rows (corresponding to poles) -- for mesh only 23 | top = tex[:, :, :1].mean(dim=3, keepdim=True).expand(-1, -1, -1, tex.shape[3]) 24 | middle = tex[:, :, 1:-1] 25 | bottom = tex[:, :, -1:].mean(dim=3, keepdim=True).expand(-1, -1, -1, tex.shape[3]) 26 | return torch.cat((top, middle, bottom), dim=2) 27 | 28 | 29 | def circpad(x, amount=1): 30 | # Circular padding along x-axis (before a convolution) 31 | left = x[:, :, :, :amount] 32 | right = x[:, :, :, -amount:] 33 | return torch.cat((right, x, left), dim=3) 34 | 35 | 36 | def qrot(q, v): 37 | """ 38 | Quaternion-vector multiplication (rotation of a vector) 39 | """ 40 | assert q.shape[-1] == 4 41 | assert v.shape[-1] == 3 42 | 43 | qvec = q[:, 1:].unsqueeze(1).expand(-1, v.shape[1], -1) 44 | uv = torch.cross(qvec, v, dim=2) 45 | uuv = torch.cross(qvec, uv, dim=2) 46 | return v + 2 * (q[:, :1].unsqueeze(1) * uv + uuv) 47 | 48 | def qmul(q, r): 49 | """ 50 | Quaternion-quaternion multiplication 51 | """ 52 | assert q.shape[-1] == 4 53 | assert r.shape[-1] == 4 54 | 55 | original_shape = q.shape 56 | 57 | # 
Compute outer product 58 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 59 | 60 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 61 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 62 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 63 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 64 | return torch.stack((w, x, y, z), dim=1).view(original_shape) -------------------------------------------------------------------------------- /lib/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /lib/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /lib/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /lib/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
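Returning briefly to `comm.py` above: the `SyncMaster` round trip (slaves register, `run_master` collects every message, and the callback's replies are pushed back through each slave's `FutureResult`) can be exercised in isolation. A minimal sketch, assuming `lib` is on the Python path:

```
import threading
from sync_batchnorm.comm import SyncMaster  # assumed import path

# The callback receives [(id, msg), ...] with the master first (id 0)
# and returns [(id, reply), ...] for every participant.
def master_callback(intermediates):
    total = sum(msg for _, msg in intermediates)
    return [(idx, total) for idx, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(identifier=1)

reply = {}
t = threading.Thread(target=lambda: reply.setdefault('slave', pipe.run_slave(2)))
t.start()
print(master.run_master(3))  # 5: the master's 3 plus the slave's 2
t.join()
print(reply['slave'])        # 5: the reduced value sent back to the slave
```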
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junzhezhang/mesh-inversion/d6614726344f5a56c068df2750fefc593c4ca43d/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/cam_pose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | def quaternion_apply(quaternion1, quaternion0): 5 | w0, x0, y0, z0 = quaternion0 6 | w1, x1, y1, z1 = quaternion1 7 | return np.array([-x1 * x0 - y1 * y0 - z1 * z0 + w1 * w0, 8 | x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0, 9 | -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0, 10 | x1 * y0 - y1 * x0 + z1 * w0 + w1 * z0], dtype=np.float32) 11 | -------------------------------------------------------------------------------- /lib/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import logging 4 | import torch 5 | import torchvision 6 | import numpy as np 7 | import datetime 8 | import time 9 | from functools import wraps 10 | from typing import Any, Callable 11 | import os 12 | 13 | 14 | def str2bool(v): 15 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 16 | return True 17 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 18 | return False 19 | else: 20 | raise argparse.ArgumentTypeError('Boolean value expected.') 21 | 22 | class LRScheduler(object): 23 | 24 | def __init__(self, optimizer, warm_up): 25 | super(LRScheduler, self).__init__() 26 | self.optimizer = optimizer 27 | self.warm_up = warm_up 28 | 29 | def update(self, iteration, learning_rate, num_group=1000, ratio=1): 30 | if iteration < self.warm_up: 31 | learning_rate *= iteration / self.warm_up 32 | for i, param_group in enumerate(self.optimizer.param_groups): 33 | param_group['lr'] = learning_rate * ratio**i 34 | 35 | 36 | def timeit(func: Callable[..., Any]) -> Callable[..., Any]: 37 | """Times a function, usually used as decorator""" 38 | # ref: http://zyxue.github.io/2017/09/21/python-timeit-decorator.html 39 | @wraps(func) 40 | def timed_func(*args: Any, **kwargs: Any) -> Any: 41 | """Returns the timed function""" 42 | start_time = time.time() 43 | result = func(*args, **kwargs) 44 | elapsed_time = datetime.timedelta(seconds=(time.time() - start_time)) 45 | print("time spent on %s: %s"%(func.__name__, elapsed_time)) 46 | return result 47 | 48 | return timed_func 49 | 50 | 51 | def to_grid_tex(x): 52 | with torch.no_grad(): 53 | return torchvision.utils.make_grid((x.data[:, :3]+1)/2, nrow=4) 54 | 55 | def to_grid_mesh(x): 56 | with torch.no_grad(): 57 | x = x.data[:, :3] 58 | minv = x.min(dim=3, keepdim=True)[0].min(dim=2, keepdim=True)[0] 59 | maxv = x.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0] 60 | x = (x - minv)/(maxv-minv) 
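A note on the `LRScheduler` defined above: it implements a plain linear warm-up, scaling the requested learning rate by `iteration / warm_up` for the first `warm_up` iterations and applying it unchanged afterwards. A short sketch, assuming `lib` is on the Python path:

```
import torch
from utils.common_utils import LRScheduler  # assumed import path

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = LRScheduler(opt, warm_up=100)

for it in (1, 50, 100, 200):
    sched.update(it, 1e-4)
    print(it, opt.param_groups[0]['lr'])
# 1e-06 at iteration 1, 5e-05 at 50, then 1e-04 once iteration >= warm_up
```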
61 | return torchvision.utils.make_grid(x, nrow=4) 62 | 63 | 64 | def set_requires_grad(nets, requires_grad=False): 65 | # ref: https://github.com/lyndonzheng/F-LSeSim/blob/e092e62ed8a2f51f3661630e1522ec2549ec31d3/models/base_model.py#L229 66 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 67 | Parameters: 68 | nets (network list) -- a list of networks 69 | requires_grad (bool) -- whether the networks require gradients or not 70 | """ 71 | if not isinstance(nets, list): 72 | nets = [nets] 73 | for net in nets: 74 | if net is not None: 75 | for param in net.parameters(): 76 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /lib/utils/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | from packaging import version 4 | import scipy 5 | 6 | from .inception import InceptionV3 7 | import torch.nn.functional as F 8 | 9 | def init_inception(): 10 | block_dim = 2048 11 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[block_dim] 12 | scipy_version = version.parse(scipy.__version__) 13 | if scipy_version > version.parse('1.3.3') and scipy_version < version.parse('1.5.0'): 14 | print('[Performance warning] If this step takes too long, downgrade SciPy to version 1.3.3 or update it to 1.5+, e.g. ' 15 | 'pip install scipy==1.5.2') 16 | inception_model = InceptionV3([block_idx]) 17 | return inception_model 18 | 19 | def forward_inception_batch(inception_model, images): 20 | pred = inception_model(images)[0] 21 | if pred.shape[2] != 1 or pred.shape[3] != 1: 22 | pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1)) 23 | return pred.data.cpu().numpy().reshape(images.shape[0], -1) 24 | 25 | def calculate_stats(act): 26 | mu = np.mean(act, axis=0) 27 | sigma = np.cov(act, rowvar=False) 28 | return mu, sigma 29 | 30 | # Borrowed from https://github.com/mseitzer/pytorch-fid 31 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 32 | """Numpy implementation of the Frechet Distance. 33 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 34 | and X_2 ~ N(mu_2, C_2) is 35 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 36 | 37 | Stable version by Dougal J. Sutherland. 38 | 39 | Params: 40 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the 41 | inception net ( like returned by the function 'get_predictions') 42 | for generated samples. 43 | -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted 44 | on an representive data set. 45 | -- sigma1: The covariance matrix over activations of the pool_3 layer for 46 | generated samples. 47 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 48 | precalcualted on an representive data set. 49 | 50 | Returns: 51 | -- : The Frechet Distance. 
52 | """ 53 | 54 | mu1 = np.atleast_1d(mu1) 55 | mu2 = np.atleast_1d(mu2) 56 | 57 | sigma1 = np.atleast_2d(sigma1) 58 | sigma2 = np.atleast_2d(sigma2) 59 | 60 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" 61 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" 62 | 63 | diff = mu1 - mu2 64 | 65 | # product might be almost singular 66 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 67 | if not np.isfinite(covmean).all(): 68 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps 69 | warnings.warn(msg) 70 | offset = np.eye(sigma1.shape[0]) * eps 71 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 72 | 73 | # numerical error might give slight imaginary component 74 | if np.iscomplexobj(covmean): 75 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 76 | m = np.max(np.abs(covmean.imag)) 77 | raise ValueError("Imaginary component {}".format(m)) 78 | covmean = covmean.real 79 | 80 | tr_covmean = np.trace(covmean) 81 | 82 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean -------------------------------------------------------------------------------- /lib/utils/image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | def resize_img(img, scale_factor): 9 | new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int) 10 | new_img = cv2.resize(img, (new_size[1], new_size[0])) 11 | # This is scale factor of [height, width] i.e. [y, x] 12 | actual_factor = [new_size[0] / float(img.shape[0]), 13 | new_size[1] / float(img.shape[1])] 14 | return new_img, actual_factor 15 | 16 | 17 | def peturb_bbox(bbox, pf=0, jf=0): 18 | ''' 19 | Jitters and pads the input bbox. 20 | 21 | Args: 22 | bbox: Zero-indexed tight bbox. 23 | pf: padding fraction. 24 | jf: jittering fraction. 25 | Returns: 26 | pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates 27 | ''' 28 | pet_bbox = [coord for coord in bbox] 29 | bwidth = bbox[2] - bbox[0] + 1 30 | bheight = bbox[3] - bbox[1] + 1 31 | 32 | pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth 33 | pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight 34 | pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth 35 | pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight 36 | 37 | return pet_bbox 38 | 39 | 40 | def square_bbox(bbox): 41 | ''' 42 | Converts a bbox to have a square shape by increasing size along non-max dimension. 43 | ''' 44 | sq_bbox = [int(round(coord)) for coord in bbox] 45 | bwidth = sq_bbox[2] - sq_bbox[0] + 1 46 | bheight = sq_bbox[3] - sq_bbox[1] + 1 47 | maxdim = float(max(bwidth, bheight)) 48 | 49 | dw_b_2 = int(round((maxdim-bwidth)/2.0)) 50 | dh_b_2 = int(round((maxdim-bheight)/2.0)) 51 | 52 | sq_bbox[0] -= dw_b_2 53 | sq_bbox[1] -= dh_b_2 54 | sq_bbox[2] = sq_bbox[0] + maxdim - 1 55 | sq_bbox[3] = sq_bbox[1] + maxdim - 1 56 | 57 | return sq_bbox 58 | 59 | 60 | def crop(img, bbox, bgval=0): 61 | ''' 62 | Crops a region from the image corresponding to the bbox. 63 | If some regions specified go outside the image boundaries, the pixel values are set to bgval. 
64 | 65 | Args: 66 | img: image to crop 67 | bbox: bounding box to crop 68 | bgval: default background for regions outside image 69 | ''' 70 | bbox = [int(round(c)) for c in bbox] 71 | bwidth = bbox[2] - bbox[0] + 1 72 | bheight = bbox[3] - bbox[1] + 1 73 | 74 | im_shape = np.shape(img) 75 | im_h, im_w = im_shape[0], im_shape[1] 76 | 77 | nc = 1 if len(im_shape) < 3 else im_shape[2] 78 | 79 | img_out = np.ones((bheight, bwidth, nc))*bgval 80 | x_min_src = max(0, bbox[0]) 81 | x_max_src = min(im_w, bbox[2]+1) 82 | y_min_src = max(0, bbox[1]) 83 | y_max_src = min(im_h, bbox[3]+1) 84 | 85 | x_min_trg = x_min_src - bbox[0] 86 | x_max_trg = x_max_src - x_min_src + x_min_trg 87 | y_min_trg = y_min_src - bbox[1] 88 | y_max_trg = y_max_src - y_min_src + y_min_trg 89 | 90 | img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :] 91 | return img_out 92 | 93 | 94 | def compute_dt(mask): 95 | """ 96 | Computes distance transform of mask. 97 | """ 98 | from scipy.ndimage import distance_transform_edt 99 | dist = distance_transform_edt(1-mask) / max(mask.shape) 100 | return dist 101 | 102 | def compute_dt_barrier(mask, k=50): 103 | """ 104 | Computes barrier distance transform of mask. 105 | """ 106 | from scipy.ndimage import distance_transform_edt 107 | dist_out = distance_transform_edt(1-mask) 108 | dist_in = distance_transform_edt(mask) 109 | 110 | dist_diff = (dist_out - dist_in) / max(mask.shape) 111 | 112 | dist = 1. / (1 + np.exp(k * -dist_diff)) 113 | return dist 114 | -------------------------------------------------------------------------------- /lib/utils/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | This buffer enables us to update discriminators using a history of generated images 8 | rather than the ones produced by the latest generators. 9 | """ 10 | 11 | def __init__(self, pool_size): 12 | """Initialize the ImagePool class 13 | Parameters: 14 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 15 | """ 16 | self.pool_size = pool_size 17 | if self.pool_size > 0: # create an empty pool 18 | self.num_imgs = 0 19 | self.images = [] 20 | 21 | def query(self, images): 22 | """Return an image from the pool. 23 | Parameters: 24 | images: the latest generated images from the generator 25 | Returns images from the buffer. 26 | By 50/100, the buffer will return input images. 27 | By 50/100, the buffer will return images previously stored in the buffer, 28 | and insert the current images to the buffer. 
29 | """ 30 | if self.pool_size == 0: # if the buffer size is 0, do nothing 31 | return images 32 | return_images = [] 33 | for image in images: 34 | image = torch.unsqueeze(image.data, 0) 35 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 36 | self.num_imgs = self.num_imgs + 1 37 | self.images.append(image) 38 | return_images.append(image) 39 | else: 40 | p = random.uniform(0, 1) 41 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 42 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 43 | tmp = self.images[random_id].clone() 44 | self.images[random_id] = image 45 | return_images.append(tmp) 46 | else: # by another 50% chance, the buffer will return the current image 47 | return_images.append(image) 48 | return_images = torch.cat(return_images, 0) # collect all the images and return 49 | return return_images -------------------------------------------------------------------------------- /lib/utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, scales the input from range (0, 1) to the range the 43 | pretrained Inception network expects, namely (-1, 1) 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. 
Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 114 | range (0, 1) 115 | 116 | Returns 117 | ------- 118 | List of torch.autograd.Variable, corresponding to the selected output 119 | block, sorted ascending by index 120 | """ 121 | outp = [] 122 | x = inp 123 | 124 | if self.resize_input: 125 | x = F.interpolate(x, 126 | size=(299, 299), 127 | mode='bilinear', 128 | align_corners=False) 129 | 130 | if self.normalize_input: 131 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 132 | 133 | for idx, block in enumerate(self.blocks): 134 | x = block(x) 135 | if idx in self.output_blocks: 136 | outp.append(x) 137 | 138 | if idx == self.last_needed_block: 139 | break 140 | 141 | return outp 142 | -------------------------------------------------------------------------------- /lib/utils/inversion_dist.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing as mp 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.nn import Module 8 | from torch.utils.data import Sampler 9 | 10 | 11 | class DistModule(Module): 12 | 13 | def __init__(self, module): 14 | super(DistModule, self).__init__() 15 | self.module = module 16 | broadcast_params(self.module) 17 | 18 | def forward(self, *inputs, **kwargs): 19 | return self.module(*inputs, **kwargs) 20 | 21 | def train(self, mode=True): 22 | super(DistModule, self).train(mode) 23 | self.module.train(mode) 24 | 25 | 26 | def average_gradients(model): 27 | """ average gradients """ 28 | for param in model.parameters(): 29 | if param.requires_grad and param.grad is not None: 30 | dist.all_reduce(param.grad.data) 31 | 32 | 33 | def broadcast_params(model): 34 | """ broadcast model parameters """ 35 | for p in model.state_dict().values(): 36 | dist.broadcast(p, 0) 37 | 38 | 39 | def average_params(model): 40 | """ broadcast model parameters """ 41 | worldsize = dist.get_world_size() 42 | for p in model.state_dict().values(): 43 | dist.all_reduce(p) 44 | p /= worldsize 45 | 46 | 47 | def dist_init(port): 48 | if mp.get_start_method(allow_none=True) != 'spawn': 49 | mp.set_start_method('spawn') 50 | proc_id = int(os.environ['SLURM_PROCID']) 51 | ntasks = int(os.environ['SLURM_NTASKS']) 52 | node_list = os.environ['SLURM_NODELIST'] 53 | num_gpus = torch.cuda.device_count() 54 | torch.cuda.set_device(proc_id % num_gpus) 55 | 56 | if '[' in node_list: 57 | beg = node_list.find('[') 58 | pos1 = node_list.find('-', beg) 59 | if pos1 < 0: 60 | pos1 = 1000 61 | pos2 = node_list.find(',', beg) 62 | if pos2 < 0: 63 | pos2 = 1000 64 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 65 | addr = node_list[8:].replace('-', '.') 66 | # print(addr) 67 | 68 | os.environ['MASTER_PORT'] = port 69 | os.environ['MASTER_ADDR'] = addr 70 | os.environ['WORLD_SIZE'] = str(ntasks) 71 | os.environ['RANK'] = str(proc_id) 72 | dist.init_process_group(backend='nccl') 73 | 74 | rank = dist.get_rank() 75 | world_size = dist.get_world_size() 76 | return rank, world_size 77 | 78 | 79 | class DistributedSampler(Sampler): 80 | """Sampler that restricts data loading to a subset of the dataset. 81 | 82 | It is especially useful in conjunction with 83 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 84 | process can pass a DistributedSampler instance as a DataLoader sampler, 85 | and load a subset of the original dataset that is exclusive to it. 86 | 87 | .. note:: 88 | Dataset is assumed to be of constant size. 89 | 90 | Arguments: 91 | dataset: Dataset used for sampling. 
92 | num_replicas (optional): Number of processes participating in 93 | distributed training. 94 | rank (optional): Rank of the current process within num_replicas. 95 | """ 96 | 97 | def __init__(self, dataset, num_replicas=None, rank=None): 98 | if num_replicas is None: 99 | if not dist.is_available(): 100 | raise RuntimeError( 101 | "Requires distributed package to be available") 102 | num_replicas = dist.get_world_size() 103 | if rank is None: 104 | if not dist.is_available(): 105 | raise RuntimeError( 106 | "Requires distributed package to be available") 107 | rank = dist.get_rank() 108 | self.dataset = dataset 109 | self.num_replicas = num_replicas 110 | self.rank = rank 111 | self.epoch = 0 112 | self.num_samples = int( 113 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 114 | self.total_size = self.num_samples * self.num_replicas 115 | 116 | def __iter__(self): 117 | # deterministically shuffle based on epoch 118 | indices = [i for i in range(len(self.dataset))] 119 | 120 | # add extra samples to make it evenly divisible 121 | indices += indices[:(self.total_size - len(indices))] 122 | assert len(indices) == self.total_size 123 | 124 | # subsample 125 | indices = indices[self.rank * self.num_samples:(self.rank + 1) * 126 | self.num_samples] 127 | assert len(indices) == self.num_samples 128 | 129 | return iter(indices) 130 | 131 | def __len__(self): 132 | return self.num_samples 133 | 134 | def set_epoch(self, epoch): 135 | self.epoch = epoch 136 | -------------------------------------------------------------------------------- /lib/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import time 9 | from torch.autograd import Function 10 | 11 | import math 12 | import sys 13 | from numbers import Number 14 | from collections import Set, Mapping, deque 15 | import os 16 | 17 | 18 | def loss_flat(mesh, norms): 19 | """ 20 | Smoothness regularizer. 21 | Encourages neighboring faces to have similar normals (low cosine distance). 22 | """ 23 | loss = 0. 24 | for i in range(3): 25 | norm1 = norms 26 | norm2 = norms[:, mesh.ff[:, i]] 27 | cos = torch.sum(norm1 * norm2, dim=-1) 28 | loss += torch.mean((cos - 1) ** 2) 29 | loss *= (mesh.faces.shape[0]/2.) 
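On the `DistributedSampler` defined above: it pads the index list so that every rank receives exactly `ceil(len(dataset) / num_replicas)` indices and then hands each rank a contiguous slice. A quick sketch; the import path is an assumption, and no process group is needed when `num_replicas` and `rank` are given explicitly:

```
from utils.inversion_dist import DistributedSampler  # assumed import path

dataset = list(range(10))  # only len(dataset) is used by the sampler
print(list(DistributedSampler(dataset, num_replicas=4, rank=0)))  # [0, 1, 2]
print(list(DistributedSampler(dataset, num_replicas=4, rank=3)))  # [9, 0, 1]: padded to 12 indices, 3 per rank
```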
30 | return loss 31 | 32 | # This class was borrowed from the pix2pix(HD) / SPADE repo, 33 | # and has been modified to add support for output masking and weighting 34 | class GANLoss(nn.Module): 35 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 36 | tensor=torch.FloatTensor, opt=None): 37 | super().__init__() 38 | self.real_label = target_real_label 39 | self.fake_label = target_fake_label 40 | self.real_label_tensor = None 41 | self.fake_label_tensor = None 42 | self.zero_tensor = None 43 | self.Tensor = tensor 44 | self.gan_mode = gan_mode 45 | self.opt = opt 46 | if gan_mode == 'ls': 47 | pass 48 | elif gan_mode == 'original': 49 | pass 50 | elif gan_mode == 'w': 51 | pass 52 | elif gan_mode == 'hinge': 53 | pass 54 | else: 55 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 56 | 57 | def get_target_tensor(self, input, target_is_real): 58 | if target_is_real: 59 | if self.real_label_tensor is None: 60 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 61 | self.real_label_tensor.requires_grad_(False) 62 | return self.real_label_tensor.expand_as(input) 63 | else: 64 | if self.fake_label_tensor is None: 65 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 66 | self.fake_label_tensor.requires_grad_(False) 67 | return self.fake_label_tensor.expand_as(input) 68 | 69 | def get_zero_tensor(self, input): 70 | if self.zero_tensor is None: 71 | self.zero_tensor = self.Tensor(1).fill_(0) 72 | self.zero_tensor.requires_grad_(False) 73 | return self.zero_tensor.expand_as(input) 74 | 75 | def mean(self, x, mask=None, weight=None, get_average=True): 76 | if weight is None: 77 | weight = 1 78 | 79 | if mask is None: 80 | if get_average: 81 | return torch.mean(x) * weight 82 | else: 83 | return ret * weight 84 | else: 85 | assert x.shape == mask.shape, (x.shape, mask.shape) 86 | ret = torch.sum(x * mask, dim=[1, 2, 3]) / torch.sum(mask, dim=[1, 2, 3]) 87 | if get_average: 88 | return torch.mean(ret) * weight 89 | else: 90 | return ret * weight 91 | 92 | def loss(self, input, target_is_real, for_discriminator=True, mask=None, weight=None, get_average=True): 93 | if self.gan_mode == 'original': # cross entropy loss 94 | target_tensor = self.get_target_tensor(input, target_is_real) 95 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 96 | return loss 97 | elif self.gan_mode == 'ls': 98 | target_tensor = self.get_target_tensor(input, target_is_real) 99 | return F.mse_loss(input, target_tensor) 100 | elif self.gan_mode == 'hinge': 101 | if for_discriminator: 102 | if target_is_real: 103 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 104 | loss = -self.mean(minval, mask, weight, get_average=get_average) 105 | else: 106 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 107 | loss = -self.mean(minval, mask, weight, get_average=get_average) 108 | else: 109 | assert target_is_real, "The generator's hinge loss must be aiming for real" 110 | loss = -self.mean(input, mask, weight, get_average=get_average) 111 | return loss 112 | else: 113 | # wgan 114 | if target_is_real: 115 | return -input.mean() 116 | else: 117 | return input.mean() 118 | 119 | def __call__(self, input, target_is_real, for_discriminator=True, mask=None, weight=None,get_average=True): 120 | if isinstance(input, list): 121 | if mask is not None: 122 | assert isinstance(mask, list) 123 | assert len(input) == len(mask) 124 | loss = 0 125 | for idx, pred_i in enumerate(input): 126 | if isinstance(pred_i, list): 127 | pred_i = pred_i[-1] 
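To make the hinge branch of `GANLoss` above concrete: for the discriminator it penalizes real logits below +1 and fake logits above -1, while the generator simply maximizes the fake logits. A small numeric sketch, assuming `lib` is on the Python path (and Python < 3.10, since this module imports `Set` and `Mapping` from `collections`):

```
import torch
from utils.losses import GANLoss  # assumed import path

crit = GANLoss('hinge')
pred_real = torch.tensor([[0.5, 2.0]])
pred_fake = torch.tensor([[-0.2, 1.5]])

d_loss = crit(pred_real, True, for_discriminator=True) + crit(pred_fake, False, for_discriminator=True)
g_loss = crit(pred_fake, True, for_discriminator=False)
print(d_loss.item())  # 1.9   = mean(relu(1 - [0.5, 2.0])) + mean(relu(1 + [-0.2, 1.5]))
print(g_loss.item())  # -0.65 = -mean([-0.2, 1.5])
```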
128 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator, 129 | mask[idx] if mask is not None else None, 130 | weight[idx] if weight is not None else None, 131 | get_average=get_average) 132 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 133 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 134 | loss += new_loss 135 | if weight is None: 136 | return loss / len(input) 137 | else: 138 | return loss / sum(weight) 139 | else: 140 | return self.loss(input, target_is_real, for_discriminator, mask,get_average=get_average) 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /lib/utils/mask_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import torch.nn.functional as F 4 | 5 | def mask2proj(mask_4d, threshold=0.9): 6 | """ 7 | convert from mask into coordinates 8 | input [1,1,299,299]: torch 9 | outout [1,N,2] 10 | NOTE: plt.scatter(yy,-xx), that's why swap x,y and make y * -1 after scale 11 | """ 12 | mask_2d = mask_4d[0,0] 13 | indices_2d = torch.where(mask_2d>threshold) 14 | indices = torch.stack([indices_2d[1],indices_2d[0]],-1) 15 | assert mask_4d.shape[2] == mask_4d.shape[3] 16 | scale = mask_4d.shape[3]/2.0 17 | coords = indices/scale -1 18 | coords[:,1]*=(-1) # indices from top to down (row 0 to row N), coords fron down to top [-1,1] 19 | return coords.unsqueeze(0) 20 | 21 | def get_vtx_color(mask_4d, img_4d, threshold=0.9): 22 | """ 23 | given image and mask 24 | img: (1,3,299,299) [-1,1] 25 | mask: [1,1,299,299]: torch 26 | output: 27 | vtx: [1,N,2] coords 28 | color: [1,N,3] 29 | """ 30 | mask_2d = mask_4d[0,0] 31 | indices_2d = torch.where(mask_2d>threshold) 32 | 33 | indices = torch.stack([indices_2d[1],indices_2d[0]],-1) 34 | assert mask_4d.shape[2] == mask_4d.shape[3] 35 | scale = mask_4d.shape[3]/2.0 36 | coords = indices/scale -1 37 | coords[:,1]*=(-1) 38 | 39 | color = img_4d[0,:,indices_2d[0],indices_2d[1]] 40 | return coords.unsqueeze(0), color.permute([1,0]).contiguous().unsqueeze(0) #,indices 41 | 42 | def grid_sample_from_vtx(vtx, color_map): 43 | """ 44 | grid sample from vtx 45 | the vtx can be form mask2proj() or get_vtx_color(), or projected from vtx_3d 46 | color_map can be target image, rendered image, or feature map of any size 47 | vtx: [B, N, 2] 48 | color_map: [B, C, H, W] 49 | """ 50 | vtx_copy = vtx.clone() 51 | vtx_copy[:,:,1] *= (-1) 52 | 53 | clr_sampled = F.grid_sample(color_map,vtx_copy.unsqueeze(2), align_corners=True).squeeze(-1).permute(0, 2, 1) 54 | 55 | return clr_sampled 56 | 57 | 58 | 59 | def farthest_point_sample(xyz, npoint): 60 | 61 | """ 62 | code borrowed from: http://www.programmersought.com/article/8737853003/#14_query_ball_point_93 63 | Input: 64 | xyz: pointcloud data, [B, N, C] 65 | npoint: number of samples 66 | Return: 67 | centroids: sampled pointcloud index, [B, npoint] 68 | """ 69 | device = xyz.device 70 | B, N, C = xyz.shape 71 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 72 | distance = torch.ones(B, N).to(device) * 1e10 73 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 74 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 75 | for i in range(npoint): 76 | # Update the i-th farthest point 77 | centroids[:, i] = farthest 78 | # Take the xyz coordinate of the farthest point 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, C) 80 | # Calculate the Euclidean distance from all points in the 
point set to this farthest point 81 | dist = torch.sum((xyz - centroid) ** 2, -1) 82 | # Update distances to record the minimum distance of each point in the sample from all existing sample points 83 | mask = dist < distance 84 | distance[mask] = dist[mask] 85 | # Find the farthest point from the updated distances matrix, and use it as the farthest point for the next iteration 86 | farthest = torch.max(distance, -1)[1] 87 | return centroids 88 | 89 | # def mask2proj(mask_tensor,threshold=0.9): 90 | # """ 91 | # assume mask of shape (1, 1, h, w) 92 | # return proj (1,N,2) of xy points of the N points 93 | # assume in the [0,1] range 94 | # """ 95 | # 96 | # # set(mask_tensor.detach().cpu().numpy().reshape(-1).tolist()) 97 | # # ans = mask > threshold # return boolen 98 | 99 | # idx_tuple = (mask_tensor > threshold).nonzero(as_tuple=True) 100 | # h_idx = idx_tuple[2].type(torch.float32) 101 | # w_idx = idx_tuple[3].type(torch.float32) 102 | 103 | # # NOTE: normalize to [0,1] 104 | # h_coords = h_idx/mask_tensor.shape[2] 105 | # w_coords = w_idx/mask_tensor.shape[3] 106 | 107 | # proj = torch.stack([h_coords,w_coords],-1).unsqueeze(0) 108 | # return proj 109 | 110 | # def mask2proj_loop(mask_tensor,threshold=0.9): 111 | # ### v1 112 | # # tic = time.time() 113 | # # coords = [] 114 | # # for i in range(mask_tensor.shape[2]): 115 | # # for j in range(mask_tensor.shape[3]): 116 | # # if mask_tensor[0,0,i,j] > threshold: 117 | # # coords.append([i,j]) 118 | # # toc = time.time() 119 | # # print('time spent in loop:',int(toc-tic)) 120 | 121 | # ### v2 122 | # tic = time.time() 123 | # coords = [] 124 | # for i in range(mask_tensor.shape[2]): 125 | # for j in range(mask_tensor.shape[3]): 126 | # if mask_tensor[0,0,i,j] > threshold: 127 | # coords.append([j,i]) 128 | # toc = time.time() 129 | # print('time spent in loop:',int(toc-tic)) 130 | # return coords 131 | 132 | 133 | # def visualize -------------------------------------------------------------------------------- /lib/utils/nn_modules.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | # create a module to normalize input image so we can easily put it in a 9 | # nn.Sequential 10 | class Normalization(nn.Module): 11 | def __init__(self, mean, std, input_range='01'): 12 | super(Normalization, self).__init__() 13 | # .view the mean and std to make them [C x 1 x 1] so that they can 14 | # directly work with image Tensor of shape [B x C x H x W]. 15 | # B is batch size. C is number of channels. H is height and W is width. 16 | # self.mean = torch.tensor(mean).view(-1, 1, 1) 17 | # self.std = torch.tensor(std).view(-1, 1, 1) 18 | self.mean = mean.view(-1, 1, 1) 19 | self.std = std.view(-1, 1, 1) 20 | self.input_range = input_range 21 | 22 | def forward(self, img): 23 | if self.input_range == 'n11': 24 | img = img/2. 
+ 0.5 25 | # normalize img 26 | return (img - self.mean) / self.std 27 | 28 | class FeaturesRes18(nn.Module): 29 | """ 30 | backbone: resnet18 31 | """ 32 | 33 | def __init__(self): 34 | super(FeaturesRes18, self).__init__() 35 | resnet18 = models.resnet18(pretrained=True) 36 | # resnet18.load_state_dict(torch.load('resnet18.pth')) 37 | modules = list(resnet18.children())[:-1] 38 | self.features = nn.Sequential(*modules) 39 | 40 | def forward(self, x): 41 | output = self.features(x) 42 | output = output.view(output.size()[0], -1) 43 | return output -------------------------------------------------------------------------------- /lib/utils/text_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | from nltk.tokenize import RegexpTokenizer 5 | from collections import defaultdict 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 10 | import torch.nn.functional as F 11 | 12 | # The following functions have been adapted from AttnGAN: 13 | # https://github.com/taoxugit/AttnGAN/blob/0d000e652b407e976cb88fab299e8566f3de8a37/code/datasets.py 14 | 15 | class TextDataProcessorCUB: 16 | def __init__(self, data_dir, split='train', 17 | captions_per_image=10, 18 | words_num=18 # max amount of words per caption 19 | ): 20 | 21 | self.embeddings_num = captions_per_image 22 | self.words_num = words_num 23 | 24 | self.data = [] 25 | self.data_dir = data_dir 26 | 27 | split_dir = os.path.join(data_dir, split) 28 | self.split = split 29 | 30 | self.filenames, self.captions, self.ixtoword, \ 31 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split) 32 | 33 | self.filenames_to_index = {} 34 | for fc, fn in enumerate(self.filenames): 35 | self.filenames_to_index[fn+".jpg"] = fc 36 | 37 | self.number_example = len(self.filenames) 38 | 39 | 40 | def load_captions(self, data_dir, filenames): 41 | all_captions = [] 42 | for i in range(len(filenames)): 43 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 44 | with open(cap_path, "r") as f: 45 | captions = f.read().decode('utf8').split('\n') 46 | cnt = 0 47 | for cap in captions: 48 | if len(cap) == 0: 49 | continue 50 | cap = cap.replace("\ufffd\ufffd", " ") 51 | # picks out sequences of alphanumeric characters as tokens 52 | # and drops everything else 53 | tokenizer = RegexpTokenizer(r'\w+') 54 | tokens = tokenizer.tokenize(cap.lower()) 55 | # print('tokens', tokens) 56 | if len(tokens) == 0: 57 | print('cap', cap) 58 | continue 59 | 60 | tokens_new = [] 61 | for t in tokens: 62 | t = t.encode('ascii', 'ignore').decode('ascii') 63 | if len(t) > 0: 64 | tokens_new.append(t) 65 | all_captions.append(tokens_new) 66 | cnt += 1 67 | if cnt == self.embeddings_num: 68 | break 69 | if cnt < self.embeddings_num: 70 | print('ERROR: the captions for %s less than %d' 71 | % (filenames[i], cnt)) 72 | return all_captions 73 | 74 | def build_dictionary(self, train_captions, test_captions): 75 | word_counts = defaultdict(float) 76 | captions = train_captions + test_captions 77 | for sent in captions: 78 | for word in sent: 79 | word_counts[word] += 1 80 | 81 | vocab = [w for w in word_counts if word_counts[w] >= 0] 82 | 83 | ixtoword = {} 84 | ixtoword[0] = '' 85 | wordtoix = {} 86 | wordtoix[''] = 0 87 | ix = 1 88 | for w in vocab: 89 | wordtoix[w] = ix 90 | ixtoword[ix] = w 91 | ix += 1 92 | 93 | train_captions_new = [] 94 | for t in train_captions: 95 | rev = [] 96 | for w in t: 97 | if 
w in wordtoix: 98 | rev.append(wordtoix[w]) 99 | # rev.append(0) # do not need '' token 100 | train_captions_new.append(rev) 101 | 102 | test_captions_new = [] 103 | for t in test_captions: 104 | rev = [] 105 | for w in t: 106 | if w in wordtoix: 107 | rev.append(wordtoix[w]) 108 | # rev.append(0) # do not need '' token 109 | test_captions_new.append(rev) 110 | 111 | return [train_captions_new, test_captions_new, 112 | ixtoword, wordtoix, len(ixtoword)] 113 | 114 | def load_text_data(self, data_dir, split): 115 | 116 | data_dir = 'cache/cub/captions' 117 | 118 | filepath = os.path.join(data_dir, 'captions.pickle') 119 | train_names = self.load_filenames(data_dir, 'train') 120 | test_names = self.load_filenames(data_dir, 'test') 121 | if not os.path.isfile(filepath): 122 | train_captions = self.load_captions(data_dir, train_names) 123 | test_captions = self.load_captions(data_dir, test_names) 124 | 125 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 126 | self.build_dictionary(train_captions, test_captions) 127 | with open(filepath, 'wb') as f: 128 | pickle.dump([train_captions, test_captions, 129 | ixtoword, wordtoix], f, protocol=2) 130 | print('Save to: ', filepath) 131 | else: 132 | with open(filepath, 'rb') as f: 133 | x = pickle.load(f) 134 | train_captions, test_captions = x[0], x[1] 135 | ixtoword, wordtoix = x[2], x[3] 136 | del x 137 | n_words = len(ixtoword) 138 | print('Load from: ', filepath) 139 | 140 | #keep all captions and select based on fileid 141 | captions = train_captions+test_captions 142 | filenames = train_names+test_names 143 | return filenames, captions, ixtoword, wordtoix, n_words 144 | 145 | 146 | 147 | def load_filenames(self, data_dir, split): 148 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 149 | if os.path.isfile(filepath): 150 | with open(filepath, 'rb') as f: 151 | filenames = pickle.load(f) 152 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 153 | else: 154 | filenames = [] 155 | return filenames 156 | 157 | def get_caption(self, sent_ix, words_num=None): 158 | if words_num is None: 159 | words_num = self.words_num 160 | 161 | # a list of indices for a sentence 162 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') 163 | 164 | if (sent_caption == 0).sum() > 0: 165 | print('ERROR: do not need END (0) token', sent_caption) 166 | num_words = len(sent_caption) 167 | # pad with 0s (i.e., '') 168 | x = np.zeros((words_num,), dtype='int64') 169 | x_len = num_words 170 | if num_words <= words_num: 171 | x[:num_words] = sent_caption 172 | else: 173 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum 174 | np.random.shuffle(ix) 175 | ix = ix[:words_num] 176 | ix = np.sort(ix) 177 | x[:] = sent_caption[ix] 178 | x_len = words_num 179 | return x, x_len 180 | 181 | 182 | 183 | # LSTM text encoder adapted from AttnGAN: 184 | # https://github.com/taoxugit/AttnGAN/blob/0d000e652b407e976cb88fab299e8566f3de8a37/code/model.py#L75-L159 185 | class RNN_Encoder(nn.Module): 186 | def __init__(self, ntoken, words_num, ninput=300, drop_prob=0.5, 187 | nhidden=128, nlayers=1, bidirectional=True, rnn_type='LSTM'): 188 | super(RNN_Encoder, self).__init__() 189 | self.n_steps = words_num 190 | self.ntoken = ntoken # size of the dictionary 191 | self.ninput = ninput # size of each embedding vector 192 | self.drop_prob = drop_prob # probability of an element to be zeroed 193 | self.nlayers = nlayers # Number of recurrent layers 194 | self.bidirectional = bidirectional 195 | self.rnn_type = rnn_type #'LSTM' or 'GRU' 
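# NOTE: nhidden is halved per direction below so that the concatenated forward/backward states match the requested hidden size; forward() reshapes sent_emb back to nhidden * num_directions.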
196 | 197 | if bidirectional: 198 | self.num_directions = 2 199 | else: 200 | self.num_directions = 1 201 | # number of features in the hidden state 202 | self.nhidden = nhidden // self.num_directions 203 | 204 | self.define_module() 205 | self.init_weights() 206 | 207 | def define_module(self): 208 | self.encoder = nn.Embedding(self.ntoken, self.ninput) 209 | self.drop = nn.Dropout(self.drop_prob) 210 | if self.rnn_type == 'LSTM': 211 | # dropout: If non-zero, introduces a dropout layer on 212 | # the outputs of each RNN layer except the last layer 213 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 214 | self.nlayers, batch_first=True, 215 | dropout=self.drop_prob, 216 | bidirectional=self.bidirectional) 217 | elif self.rnn_type == 'GRU': 218 | self.rnn = nn.GRU(self.ninput, self.nhidden, 219 | self.nlayers, batch_first=True, 220 | dropout=self.drop_prob, 221 | bidirectional=self.bidirectional) 222 | else: 223 | raise NotImplementedError 224 | 225 | def init_weights(self): 226 | initrange = 0.1 227 | self.encoder.weight.data.uniform_(-initrange, initrange) 228 | 229 | def forward(self, captions, cap_lens): 230 | # input: torch.LongTensor of size batch x n_steps 231 | # --> emb: batch x n_steps x ninput 232 | emb = self.drop(self.encoder(captions)) 233 | # 234 | # Returns: a PackedSequence object 235 | cap_lens = cap_lens.data.tolist() 236 | total_length = captions.size(1) # max sequence length 237 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True, enforce_sorted=False) 238 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 239 | # tensor containing the initial hidden state for each element in batch. 240 | # #output (batch, seq_len, hidden_size * num_directions) 241 | # #or a PackedSequence object: 242 | # tensor containing output features (h_t) from the last layer of RNN 243 | self.rnn.flatten_parameters() 244 | output, hidden = self.rnn(emb) 245 | # PackedSequence object 246 | # --> (batch, seq_len, hidden_size * num_directions) 247 | output = pad_packed_sequence(output, batch_first=True, total_length=total_length)[0] 248 | # output = self.drop(output) 249 | # --> batch x hidden_size*num_directions x seq_len 250 | words_emb = output.transpose(1, 2) 251 | # --> batch x num_directions*hidden_size 252 | if self.rnn_type == 'LSTM': 253 | sent_emb = hidden[0].transpose(0, 1).contiguous() 254 | else: 255 | sent_emb = hidden.transpose(0, 1).contiguous() 256 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 257 | return words_emb, sent_emb -------------------------------------------------------------------------------- /lib/utils/tf_visualizer.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------- 2 | # Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 3 | # Nvidia Source Code License-NC 4 | # Code written by Xueting Li. 
5 | # ----------------------------------------------------------- 6 | import numpy as np 7 | import os 8 | import ntpath 9 | import time 10 | import termcolor 11 | 12 | # convert to colored strings 13 | def red(content): return termcolor.colored(str(content),"red",attrs=["bold"]) 14 | def green(content): return termcolor.colored(str(content),"green",attrs=["bold"]) 15 | def blue(content): return termcolor.colored(str(content),"blue",attrs=["bold"]) 16 | def cyan(content): return termcolor.colored(str(content),"cyan",attrs=["bold"]) 17 | def yellow(content): return termcolor.colored(str(content),"yellow",attrs=["bold"]) 18 | def magenta(content): return termcolor.colored(str(content),"magenta",attrs=["bold"]) 19 | 20 | class Visualizer(): 21 | def __init__(self, opt): 22 | # self.opt = opt 23 | self.log_name = os.path.join(opt.checkpoint_dir, opt.name, 'loss_log.txt') 24 | with open(self.log_name, "a") as log_file: 25 | now = time.strftime("%c") 26 | log_file.write('================ Training Loss (%s) ================\n' % now) 27 | 28 | # scalars: same format as |scalars| of plot_current_scalars 29 | def print_current_scalars(self, epoch, i, scalars): 30 | message = green('(epoch: %d, iters: %d) ' % (epoch, i)) 31 | for k, v in scalars.items(): 32 | if("lr" in k): 33 | message += '%s: %.6f ' % (k, v) 34 | else: 35 | message += '%s: %.3f ' % (k, v) 36 | 37 | print(message) 38 | with open(self.log_name, "a") as log_file: 39 | log_file.write('%s\n' % message) 40 | -------------------------------------------------------------------------------- /lib/utils/vgg_feat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import math 6 | 7 | 8 | def normalize_batch(batch): 9 | # normalize using imagenet mean and std 10 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 11 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 12 | # shift to [0 1] before normalization 13 | batch = batch/2. 
+ 0.5 14 | return (batch - mean) / std 15 | 16 | # VGG architecter, used for the perceptual-like loss using a pretrained VGG network 17 | class VGG19(torch.nn.Module): 18 | def __init__(self, requires_grad=False): 19 | super().__init__() 20 | 21 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 22 | 23 | self.slice1 = torch.nn.Sequential() 24 | self.slice2 = torch.nn.Sequential() 25 | self.slice3 = torch.nn.Sequential() 26 | self.slice4 = torch.nn.Sequential() 27 | self.slice5 = torch.nn.Sequential() 28 | for x in range(2): 29 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 30 | for x in range(2, 7): 31 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 32 | for x in range(7, 12): 33 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 34 | for x in range(12, 21): 35 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(21, 30): 37 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 38 | if not requires_grad: 39 | for param in self.parameters(): 40 | param.requires_grad = False 41 | 42 | def forward(self, X, mask): 43 | B, C, H, W = X.shape 44 | X = normalize_batch(X[:, :3]) 45 | X = X * mask # apply mask after normalization 46 | h_relu1 = self.slice1(X) 47 | h_relu2 = self.slice2(h_relu1) 48 | h_relu3 = self.slice3(h_relu2) 49 | # h_relu4 = self.slice4(h_relu3) 50 | # h_relu5 = self.slice5(h_relu4) 51 | # out = [F.interpolate(h_relu3, [H//2, W//2], mode='bilinear'), 52 | # F.interpolate(h_relu4, [H//2, W//2], mode='bilinear')] 53 | # out = torch.cat(out, dim=1) 54 | return h_relu3 55 | -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.optim 5 | from torch.utils.data import Dataset 6 | from lib.arguments import Arguments 7 | 8 | from lib.utils.common_utils import * 9 | from lib.utils.inversion_dist import * 10 | 11 | from lib.mesh_inversion import MeshInversion 12 | from scipy.spatial.transform import Rotation 13 | from lib.utils.cam_pose import quaternion_apply 14 | from lib.utils.fid import calculate_stats, calculate_frechet_distance, init_inception, forward_inception_batch 15 | from PIL import Image 16 | 17 | import numpy as np 18 | import imageio 19 | import glob 20 | from tqdm import tqdm 21 | 22 | def imread(filename): 23 | """ 24 | Loads an image file into a (height, width, 3) uint8 ndarray. 
25 | """ 26 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3] 27 | 28 | class EvalDataset(Dataset): 29 | def __init__(self, path, eval_option, skip_angles=False): 30 | 31 | self.eval_option = eval_option 32 | 33 | if eval_option == 'FID_1': 34 | self.files = glob.glob(path+'/*.pth') 35 | else: 36 | self.files = glob.glob(path+'/*.png') 37 | 38 | if skip_angles: 39 | new_files = [itm for itm in self.files if ('_90' not in itm and '_270' not in itm)] 40 | self.files = new_files 41 | 42 | 43 | def __getitem__(self, idx): 44 | if self.eval_option == 'FID_1': 45 | this_itm = torch.load(self.files[idx]) 46 | clean_input = (this_itm['input_img'] / 2 + 0.5) * this_itm['input_mask'] 47 | clean_pred = (this_itm['pred_img'] / 2 + 0.5) * this_itm['pred_mask'] 48 | clean_input = (this_itm['input_img'] / 2 + 0.5) * this_itm['input_mask'] 49 | ret = { 50 | 'input': clean_input.squeeze(0), 51 | 'pred': clean_pred.squeeze(0), 52 | } 53 | 54 | return ret 55 | else: 56 | filename = self.files[idx] 57 | img = imread(str(filename)).astype(np.float32) 58 | 59 | img = img.transpose((2, 0, 1)) 60 | img /= 255 61 | 62 | img_t = torch.from_numpy(img).type(torch.FloatTensor) 63 | 64 | return img_t 65 | 66 | def __len__(self): 67 | return len(self.files) 68 | 69 | 70 | class Tester(object): 71 | def __init__(self, args): 72 | self.args = args 73 | 74 | 75 | def eval_iou(self): 76 | self.results_dir = os.path.join('./outputs/inversion_results',self.args.name) 77 | self.pathnames = sorted(glob.glob(self.results_dir+'/*.pth')) 78 | iou_ls = [] 79 | 80 | for i, pathname in enumerate(self.pathnames): 81 | this_unit = torch.load(pathname) 82 | input_mask = this_unit['input_mask'] 83 | pred_mask = this_unit['pred_mask'] 84 | 85 | overlap_mask = input_mask * pred_mask 86 | union_mask = input_mask + pred_mask - overlap_mask 87 | iou = overlap_mask.sum()/union_mask.sum() 88 | iou_ls.append(iou) 89 | 90 | ious = torch.stack(iou_ls) 91 | print('Mean IoU:{:4.3f}'.format(ious.mean().item())) 92 | 93 | def render_multiview(self, n_views=12): 94 | # init MeshInversion, which got GAN, mesh template, and renderer 95 | self.model = MeshInversion(self.args) 96 | 97 | results_dir = os.path.join('./outputs/inversion_results',self.args.name) 98 | pathnames = sorted(glob.glob(results_dir+'/*.pth')) 99 | 100 | rendering_dir = os.path.join(f'./outputs/multiview_renderings_{n_views}',self.args.name) 101 | os.makedirs(rendering_dir, exist_ok=True) 102 | 103 | for i, pathname in enumerate(pathnames): 104 | data = torch.load(pathname) 105 | idx = data['idx'] 106 | 107 | pred_tex = data['pred_tex'].cuda() 108 | pred_shape = data['pred_shape'].cuda() 109 | scale = torch.tensor([self.args.default_scale]).cuda() 110 | translation = data['translation'].cuda() # not in use 111 | 112 | for angle in range(0,360,self.args.angle_interval): 113 | if n_views == 10 and angle in [90, 270]: 114 | continue 115 | if self.args.canonical_pose: 116 | original_rot = Rotation.from_euler('xyz', [0, -90, 90-self.args.default_orientation], degrees=True) 117 | temp_quat = original_rot.as_quat().astype(np.float32) 118 | original_rotation = [temp_quat[-1],temp_quat[0],temp_quat[1],temp_quat[2]] 119 | else: 120 | raise 121 | rot = Rotation.from_euler('xyz', [0, angle, 0], degrees=True) 122 | rot_quat = rot.as_quat().astype(np.float32) 123 | quaternion = [rot_quat[-1],rot_quat[0],rot_quat[1],rot_quat[2]] 124 | quaternion = quaternion_apply(quaternion, original_rotation) 125 | img, mask, _ = self.model.render(pred_tex,pred_shape, attn_map=None, 
rotation=torch.tensor([quaternion]).cuda(), \ 126 | scale=scale, translation=translation, novel_view=True) 127 | 128 | img = img / 2 + 0.5 129 | img = img + (1-torch.cat([mask,mask,mask],1)) 130 | img = img.squeeze(0) 131 | img = (img.permute(1, 2, 0)*255).clamp(0, 255).cpu().byte().numpy() 132 | 133 | pathname = os.path.join(rendering_dir,f'{idx}_{angle:03d}.png') 134 | imageio.imwrite(pathname, img) 135 | 136 | if i%200 == 0: 137 | print(f'done {i} out of {len(pathnames)}') 138 | 139 | def eval_fid(self): 140 | 141 | inception_model = torch.nn.DataParallel(init_inception()).cuda().eval() 142 | 143 | if self.args.eval_option == 'FID_1': 144 | path = os.path.join('./outputs/inversion_results',self.args.name) 145 | data_set = EvalDataset(path=path, eval_option='FID_1',skip_angles=False) 146 | data_loader = torch.utils.data.DataLoader(data_set, batch_size=40, num_workers=8, \ 147 | pin_memory=True, drop_last=False, shuffle=False) 148 | emb_fake = [] 149 | emb_real = [] 150 | for i, data in enumerate(tqdm(data_loader)): 151 | pred_data = data['pred'] 152 | input_data = data['input'] 153 | emb_fake.append(forward_inception_batch(inception_model, pred_data.cuda())) 154 | emb_real.append(forward_inception_batch(inception_model, input_data.cuda())) 155 | 156 | emb_fake = np.concatenate(emb_fake, axis=0) 157 | emb_real = np.concatenate(emb_real, axis=0) 158 | m1, s1 = calculate_stats(emb_fake) 159 | m2, s2 = calculate_stats(emb_real) 160 | else: 161 | if self.args.eval_option == 'FID_12': 162 | path = os.path.join('./outputs/multiview_renderings_12',self.args.name) 163 | data_set = EvalDataset(path=path, eval_option='FID_12',skip_angles=False) 164 | else: 165 | # reuse the 12-view renderings with skip_angles=True 166 | path = os.path.join('./outputs/multiview_renderings_12',self.args.name) 167 | data_set = EvalDataset(path=path, eval_option='FID_10',skip_angles=True) 168 | 169 | data_loader = torch.utils.data.DataLoader(data_set, batch_size=40, num_workers=8, \ 170 | pin_memory=True, drop_last=False, shuffle=False) 171 | 172 | # load_gt_stats 173 | filepath = os.path.join('./datasets/cub', 'cache', 'precomputed_fid_299x299_testval.npz') 174 | stats = np.load(filepath) 175 | m2 = stats['stats_m'] 176 | s2 = stats['stats_s'] + np.triu(stats['stats_s'].T, 1) 177 | 178 | emb_fake = [] 179 | 180 | for i, data in enumerate(tqdm(data_loader)): 181 | emb_fake.append(forward_inception_batch(inception_model, data.cuda())) 182 | 183 | emb_fake = np.concatenate(emb_fake, axis=0) 184 | m1, s1 = calculate_stats(emb_fake) 185 | fid = calculate_frechet_distance(m1, s1, m2, s2) 186 | 187 | 188 | fid = calculate_frechet_distance(m1, s1, m2, s2) 189 | 190 | print('{}:{:.02f}'.format(self.args.eval_option,fid)) 191 | 192 | 193 | if __name__ == "__main__": 194 | args = Arguments(stage='evaluation').parser().parse_args() 195 | 196 | tester = Tester(args) 197 | if args.eval_option == 'IoU': 198 | tester.eval_iou() 199 | if args.eval_option == 'FID_1': 200 | tester.eval_fid() 201 | if args.eval_option == 'FID_12': 202 | tester.render_multiview(n_views=12) 203 | tester.eval_fid() 204 | if args.eval_option == 'FID_10': 205 | # reuse the 12-view renderings with skip_angles=True, can comment out if FID_12 called before 206 | # tester.render_multiview(n_views=12) 207 | tester.eval_fid() 208 | 209 | 210 | -------------------------------------------------------------------------------- /run_inversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 
import torch.optim 5 | 6 | from lib.arguments import Arguments 7 | 8 | from lib.utils.common_utils import * 9 | from lib.utils.inversion_dist import * 10 | from lib.data import cub as cub_data 11 | 12 | from lib.mesh_inversion import MeshInversion 13 | 14 | 15 | 16 | class Trainer(object): 17 | def __init__(self, args): 18 | self.args = args 19 | self.data_module = cub_data 20 | self.model = MeshInversion(self.args) 21 | self.dataloader = self.data_module.data_loader(self.args, shuffle=self.args.shuffle) 22 | 23 | # dir of saved results, in .pth file for each instance 24 | os.makedirs("./outputs", exist_ok=True) 25 | os.makedirs("./outputs/inversion_results", exist_ok=True) 26 | os.makedirs(f"./outputs/inversion_results/{args.name}", exist_ok=True) 27 | 28 | 29 | def run(self): 30 | 31 | if self.args.use_pred_pose: 32 | cmr_dict_path = os.path.join(self.args.data_dir, 'cache','cmr_pred_cam.pth') 33 | cmr_dict = torch.load(cmr_dict_path) 34 | 35 | for i, data in enumerate(self.dataloader): 36 | 37 | idx = data['idx'][0].item() 38 | 39 | if self.args.use_pred_pose: 40 | # replace sfm pose with cmr predicted ones 41 | img_key, ext = os.path.splitext(os.path.basename(data['img_path'][0])) 42 | cmr_item = cmr_dict[img_key] 43 | # to avoid the outliers 44 | if cmr_item['pred_pose_overlay_iou'] > self.args.filter_noisy_pred_pose: 45 | cmr_pred_cam = cmr_item['pred_cam'].unsqueeze(0).type(torch.float32) 46 | data['sfm_pose'] = cmr_pred_cam 47 | 48 | self.model.set_target(idx, data, seq=i) 49 | self.model.init_z() 50 | self.model.run() 51 | 52 | print(f"{idx} completed.") 53 | 54 | 55 | if __name__ == "__main__": 56 | args = Arguments(stage='inversion').parser().parse_args() 57 | 58 | trainer = Trainer(args) 59 | trainer.run() --------------------------------------------------------------------------------
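
For a quick sanity check outside `run_evaluation.py`, a single saved inversion result can be inspected directly. The sketch below is a minimal example, assuming only the fields read by `Tester.eval_iou` above (`input_mask`, `pred_mask`) and the `./outputs/inversion_results/<name>/*.pth` layout written by `run_inversion.py`; the run name `author_released` follows the README example and is interchangeable with your own `--name`.

```
import glob
import torch

# Example run name from the README; replace with your own --name.
results_dir = './outputs/inversion_results/author_released'
pathname = sorted(glob.glob(results_dir + '/*.pth'))[0]
unit = torch.load(pathname)

# Field names follow what Tester.eval_iou reads from each .pth file.
input_mask = unit['input_mask']
pred_mask = unit['pred_mask']

# Mask IoU for this single instance, mirroring the loop in eval_iou.
overlap = input_mask * pred_mask
union = input_mask + pred_mask - overlap
print('IoU: {:4.3f}'.format((overlap.sum() / union.sum()).item()))
```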