├── .gitignore ├── .gitmodules ├── Doc └── images │ ├── Teaser.png │ ├── clothing_transfer.png │ ├── data_teaser.png │ ├── mesh.gif │ └── teaser.gif ├── LICENSE ├── README.md ├── configs ├── data │ ├── mpiis │ │ ├── DSC_7151.yml │ │ └── DSC_7157.yml │ └── snapshot │ │ ├── female-3-casual.yml │ │ ├── female-4-casual.yml │ │ ├── male-3-casual.yml │ │ └── male-4-casual.yml └── exp │ ├── stage_0_nerf.yml │ ├── stage_1_hybrid.yml │ └── stage_1_hybrid_perceptual.yml ├── environment.yml ├── fetch_data.sh ├── lib ├── __init__.py ├── datasets │ ├── build_datasets.py │ └── scarf_data.py ├── models │ ├── __init__.py │ ├── embedding.py │ ├── lbs.py │ ├── nerf.py │ ├── scarf.py │ ├── siren.py │ └── smplx.py ├── trainer.py ├── utils │ ├── camera_util.py │ ├── config.py │ ├── lossfunc.py │ ├── metric.py │ ├── perceptual_loss.py │ ├── rasterize_rendering.py │ ├── rotation_converter.py │ ├── util.py │ └── volumetric_rendering.py └── visualizer.py ├── main_demo.py ├── main_train.py ├── process_data ├── README.md ├── fetch_asset_data.sh ├── lists │ └── video_list.txt ├── logs │ └── generate_data.log ├── process_video.py └── submodules │ └── detector.py ├── requirements.txt └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.o 4 | *.so 5 | 6 | # Packages # 7 | ############ 8 | # it's better to unpack these files and commit the raw source 9 | # git has its own built in compression methods 10 | *.7z 11 | *.dmg 12 | *.gz 13 | *.iso 14 | *.jar 15 | *.rar 16 | *.tar 17 | *.zip 18 | 19 | # OS generated files # 20 | ###################### 21 | .DS_Store 22 | .DS_Store? 23 | ._* 24 | .Spotlight-V100 25 | .Trashes 26 | ehthumbs.db 27 | Thumbs.db 28 | .vscode 29 | 30 | # 3D data # 31 | ############ 32 | *.mat 33 | *.pkl 34 | *.obj 35 | *.dat 36 | *.npz 37 | 38 | # python file # 39 | ############ 40 | *.pyc 41 | __pycache__ 42 | 43 | *results* 44 | # *_vis.jpg 45 | ## internal use 46 | cluster_scripts 47 | internal 48 | *TestSamples* 49 | data 50 | exps 51 | wandb 52 | process_data/assets 53 | debug* -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "process_data/submodules/cloth-segmentation"] 2 | path = process_data/submodules/cloth-segmentation 3 | url = https://github.com/levindabhi/cloth-segmentation.git 4 | [submodule "process_data/submodules/RobustVideoMatting"] 5 | path = process_data/submodules/RobustVideoMatting 6 | url = https://github.com/PeterL1n/RobustVideoMatting.git 7 | [submodule "process_data/submodules/PIXIE"] 8 | path = process_data/submodules/PIXIE 9 | url = https://github.com/yfeng95/PIXIE.git 10 | -------------------------------------------------------------------------------- /Doc/images/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/Teaser.png -------------------------------------------------------------------------------- /Doc/images/clothing_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/clothing_transfer.png -------------------------------------------------------------------------------- /Doc/images/data_teaser.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/data_teaser.png -------------------------------------------------------------------------------- /Doc/images/mesh.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/mesh.gif -------------------------------------------------------------------------------- /Doc/images/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/teaser.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | License 2 | 3 | Software Copyright License for non-commercial scientific research purposes 4 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the ECON model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License 5 | 6 | Ownership / Licensees 7 | The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) hereinafter the “Licensor”. 8 | 9 | License Grant 10 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 11 | 12 | • To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization; 13 | • To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects; 14 | • To modify, adapt, translate or create derivative works based upon the Model & Software. 15 | 16 | Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. 17 | 18 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. 
This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it. 19 | 20 | No Distribution 21 | The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 22 | 23 | Disclaimer of Representations and Warranties 24 | You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party. 25 | 26 | Limitation of Liability 27 | Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 28 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 29 | Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders. 30 | The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause. 31 | 32 | No Maintenance Services 33 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time. 34 | 35 | Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. 36 | 37 | Publications using the Model & Software 38 | You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software. 
39 | 40 | Citation: 41 | @inproceedings{Feng2022scarf, 42 | author = {Feng, Yao and Yang, Jinlong and Pollefeys, Marc and Black, Michael J. and Bolkart, Timo}, 43 | title = {Capturing and Animation of Body and Clothing from Monocular Video}, 44 | year = {2022}, 45 | booktitle = {SIGGRAPH Asia 2022 Conference Papers}, 46 | articleno = {45}, 47 | numpages = {9}, 48 | location = {Daegu, Republic of Korea}, 49 | series = {SA '22} 50 | } 51 | 52 | Commercial licensing opportunities 53 | For commercial uses of the Model & Software, please send email to ps-license@tue.mpg.de 54 | 55 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

4 | 5 |

SCARF: Capturing and Animation of Body and Clothing from Monocular Video 6 |

7 | 20 |
21 | teaser 22 |
23 |

24 | 25 | This is the PyTorch implementation of SCARF. For more details, please check our [Project](https://yfeng95.github.io/scarf/) page. 26 | 27 | SCARF extracts a 3D clothed avatar from a monocular video. 28 | SCARF allows us to synthesize new views of the reconstructed avatar, and to animate the avatar with SMPL-X identity shape and pose control. 29 | The disentanglement of the body and clothing further enables us to transfer clothing between subjects for virtual try-on applications. 30 | 31 | The key features: 32 | 1. animate the avatar by changing body poses (including hand articulation and facial expressions), 33 | 2. synthesize novel views of the avatar, and 34 | 3. transfer clothing between avatars for virtual try-on applications. 35 | 36 | 37 | ## Getting Started 38 | Clone the repo: 39 | ```bash 40 | git clone https://github.com/yfeng95/SCARF 41 | cd SCARF 42 | ``` 43 | ### Requirements 44 | ```bash 45 | conda create -n scarf python=3.9 46 | conda activate scarf 47 | pip install -r requirements.txt 48 | ``` 49 | If you have problems installing [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), please follow their instructions. 50 | ### Download data 51 | ``` 52 | bash fetch_data.sh 53 | ``` 54 | 55 | ## Visualization 56 | * check training frames: 57 | ```bash 58 | python main_demo.py --vis_type capture --frame_id 0 59 | ``` 60 | * **novel view** synthesis of a given frame id: 61 | ```bash 62 | python main_demo.py --vis_type novel_view --frame_id 0 63 | ``` 64 | * extract the **mesh** and visualize it: 65 | ```bash 66 | python main_demo.py --vis_type extract_mesh --frame_id 0 67 | ``` 68 | You can go to our [project](https://yfeng95.github.io/scarf/) page and play with the extracted meshes. 69 |
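If you prefer to inspect the extracted mesh in Python rather than in the browser viewer, here is a minimal sketch using `trimesh` (already listed in `environment.yml`); the mesh path below is a placeholder, since the actual output location depends on the experiment folder that `main_demo.py` writes to.

```python
# Minimal sketch: load and inspect an extracted SCARF mesh with trimesh.
# NOTE: the path is a placeholder; point it at the mesh file that the
# extract_mesh step wrote into your experiment folder.
import trimesh

mesh_path = "exps/snapshot/male-3-casual/extracted_mesh_f000000.obj"  # placeholder
mesh = trimesh.load(mesh_path, process=False)

print(f"vertices: {mesh.vertices.shape}, faces: {mesh.faces.shape}")
print(f"watertight: {mesh.is_watertight}, bounds:\n{mesh.bounds}")

mesh.export("extracted_mesh.ply")  # convert, e.g. to PLY
mesh.show()                        # interactive viewer (requires pyglet)
```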

70 | 71 |

72 | 73 | * **animation** 74 | ```bash 75 | python main_demo.py --vis_type animate 76 | ``` 77 | 78 | * **clothing transfer** 79 | ```bash 80 | # apply clothing from another model 81 | python main_demo.py --vis_type novel_view --clothing_model_path exps/snapshot/male-3-casual 82 | # transfer the clothing to a new body 83 | python main_demo.py --vis_type novel_view --body_model_path exps/snapshot/male-3-casual 84 | ``` 85 |
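To batch the clothing transfer over several downloaded avatars, here is a small Python sketch; it assumes additional trained avatars have been placed under `exps/snapshot/` (see the download note below), and the folder pattern is only an example.

```python
# Sketch: render novel views of the current avatar wearing the clothing of every
# trained avatar found under exps/snapshot/ (the glob is only an example).
import glob
import subprocess

for avatar in sorted(glob.glob("exps/snapshot/*-casual")):
    print(f"transferring clothing from {avatar}")
    subprocess.run(
        ["python", "main_demo.py",
         "--vis_type", "novel_view",
         "--clothing_model_path", avatar],
        check=True,
    )
```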

86 | 87 |

88 | 89 | More data and trained models can be found [here](https://nextcloud.tuebingen.mpg.de/index.php/s/3SEwJmZcfY5LnnN); you can download them and put them into `./exps`. 90 | 91 | ## Training 92 | * training with the SCARF video example 93 | ```bash 94 | bash train.sh 95 | ``` 96 | * training with other videos 97 | check [here](./process_data/README.md) to **prepare data with your own videos**, then change the data_cfg accordingly. 98 | 99 | ## TODO 100 | - [ ] add more processed data and trained models 101 | - [ ] code for refining the pose of trained models 102 | - [ ] version with instant ngp 103 | 104 | ## Citation 105 | ```bibtex 106 | @inproceedings{Feng2022scarf, 107 | author = {Feng, Yao and Yang, Jinlong and Pollefeys, Marc and Black, Michael J. and Bolkart, Timo}, 108 | title = {Capturing and Animation of Body and Clothing from Monocular Video}, 109 | year = {2022}, 110 | booktitle = {SIGGRAPH Asia 2022 Conference Papers}, 111 | articleno = {45}, 112 | numpages = {9}, 113 | location = {Daegu, Republic of Korea}, 114 | series = {SA '22} 115 | } 116 | ``` 117 | 118 | ## Acknowledgments 119 | We thank [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin), [Weiyang Liu](https://wyliu.com), [Yuliang Xiu](https://xiuyuliang.cn/), [Songyou Peng](https://pengsongyou.github.io/), [Qianli Ma](https://qianlim.github.io/) for fruitful discussions, and PS members for proofreading. We also thank Betty 120 | Mohler, Tsvetelina Alexiadis, Claudia Gallatz, and Andres Camilo Mendoza Patino for their support with data. 121 | 122 | Special thanks to [Boyi Jiang](https://scholar.google.com/citations?user=lTlZV8wAAAAJ&hl=zh-CN) and [Sida Peng](https://pengsida.net/) for sharing their data. 123 | 124 | Here are some great resources we benefit from: 125 | - [FasterRCNN](https://pytorch.org/vision/main/models/faster_rcnn.html) for detection 126 | - [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting) for background segmentation 127 | - [cloth-segmentation](https://github.com/levindabhi/cloth-segmentation) for clothing segmentation 128 | - [PIXIE](https://github.com/yfeng95/PIXIE) for SMPL-X parameter estimation 129 | - [smplx](https://github.com/vchoutas/smplx) for body models 130 | - [PyTorch3D](https://github.com/facebookresearch/pytorch3d) for differentiable rendering 131 | 132 | Some functions are based on other repositories; we acknowledge the origin individually in each file. 133 | 134 | ## License 135 | 136 | This code and model are available for non-commercial scientific research purposes as defined in the [LICENSE](LICENSE) file. By downloading and using the code and model you agree to the terms in the [LICENSE](LICENSE). 137 | 138 | ## Disclosure 139 | MJB has received research gift funds from Adobe, Intel, Nvidia, Meta/Facebook, and Amazon. MJB has financial interests in Amazon, Datagen Technologies, and Meshcapade GmbH. While MJB is a part-time employee of Meshcapade, his research was performed solely at, and funded solely by, the Max Planck Society. 140 | While TB is a part-time employee of Amazon, this research was performed 141 | solely at, and funded solely by, MPI.
142 | 143 | ## Contact 144 | For more questions, please contact yao.feng@tue.mpg.de 145 | For commercial licensing, please contact ps-licensing@tue.mpg.de -------------------------------------------------------------------------------- /configs/data/mpiis/DSC_7151.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | type: 'scarf' 6 | path: 'exps/mpiis/DSC_7151' 7 | subjects: ['DSC_7151'] 8 | train: 9 | frame_start: 0 10 | frame_end: 400 11 | frame_step: 2 12 | val: 13 | frame_start: 1 14 | frame_end: 400 15 | frame_step: 60 16 | -------------------------------------------------------------------------------- /configs/data/mpiis/DSC_7157.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | type: 'scarf' 6 | path: 'exps/mpiis/DSC_7157' 7 | subjects: ['DSC_7157'] 8 | train: 9 | frame_start: 0 10 | frame_end: 400 11 | frame_step: 2 12 | val: 13 | frame_start: 1 14 | frame_end: 400 15 | frame_step: 60 16 | -------------------------------------------------------------------------------- /configs/data/snapshot/female-3-casual.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | subjects: ['female-3-casual'] 6 | train: 7 | frame_start: 1 8 | frame_end: 446 9 | frame_step: 4 10 | val: 11 | frame_start: 47 12 | frame_end: 648 13 | frame_step: 50 14 | test: 15 | frame_start: 447 16 | frame_end: 648 17 | frame_step: 4 -------------------------------------------------------------------------------- /configs/data/snapshot/female-4-casual.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | subjects: ['female-4-casual'] 6 | train: 7 | frame_start: 1 8 | frame_end: 336 9 | frame_step: 4 10 | val: 11 | frame_start: 336 12 | frame_end: 524 13 | frame_step: 50 14 | test: 15 | frame_start: 336 16 | frame_end: 524 17 | frame_step: 4 -------------------------------------------------------------------------------- /configs/data/snapshot/male-3-casual.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | type: 'scarf' 6 | path: '/home/yfeng/Data/Projects-data/DELTA/datasets/body' 7 | subjects: ['male-3-casual'] 8 | train: 9 | frame_start: 1 10 | frame_end: 456 11 | frame_step: 4 12 | val: 13 | frame_start: 456 14 | frame_end: 676 15 | frame_step: 50 16 | test: 17 | frame_start: 457 18 | frame_end: 676 19 | frame_step: 4 -------------------------------------------------------------------------------- /configs/data/snapshot/male-4-casual.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | image_size: 512 3 | num_workers: 1 4 | white_bg: False 5 | subjects: ['male-4-casual'] 6 | train: 7 | frame_start: 1 8 | frame_end: 660 9 | frame_step: 6 10 | val: 11 | frame_start: 661 12 | frame_end: 873 13 | frame_step: 50 14 | test: 15 | frame_start: 661 16 | frame_end: 873 17 | frame_step: 4 -------------------------------------------------------------------------------- /configs/exp/stage_0_nerf.yml: -------------------------------------------------------------------------------- 1 | # scarf model 2 | use_mesh: False 3 | use_nerf: True 4 | 
use_fine: True 5 | k_neigh: 6 6 | chunk: 4096 7 | opt_pose: True 8 | opt_cam: True 9 | use_highres_smplx: True 10 | 11 | # training 12 | train: 13 | batch_size: 1 14 | max_steps: 100000 # training longer for more details 15 | lr: 5e-4 16 | tex_lr: 5e-4 17 | geo_lr: 5e-4 18 | pose_lr: 5e-4 19 | precrop_iters: 400 20 | precrop_frac: 0.3 21 | log_steps: 20 22 | val_steps: 500 23 | checkpoint_steps: 1000 24 | 25 | dataset: 26 | image_size: 512 27 | num_workers: 1 28 | white_bg: False 29 | type: 'scarf' 30 | 31 | loss: 32 | w_rgb: 1. 33 | w_alpha: 0.1 34 | w_depth: 0. 35 | mesh_w_rgb: 1. 36 | mesh_w_mrf: 0. #0.0001 37 | skin_consistency_type: verts_all_mean 38 | mesh_skin_consistency: 0.001 39 | mesh_w_alpha: 0.05 40 | mesh_w_alpha_skin: 1. 41 | mesh_inside_mask: 1. 42 | reg_offset_w: 10. 43 | reg_lap_w: 2.0 44 | reg_edge_w: 1. 45 | reg_normal_w: 0.01 46 | nerf_reg_normal_w: 0.001 -------------------------------------------------------------------------------- /configs/exp/stage_1_hybrid.yml: -------------------------------------------------------------------------------- 1 | # scarf model 2 | use_fine: True 3 | use_mesh: True 4 | use_nerf: True 5 | lbs_map: True 6 | k_neigh: 6 7 | chunk: 4096 8 | opt_pose: True 9 | opt_cam: True 10 | tex_network: siren 11 | use_highres_smplx: True 12 | exclude_hand: False 13 | opt_mesh: True 14 | mesh_offset_scale: 0.04 15 | 16 | sample_patch_rays: True 17 | sample_patch_size: 48 18 | 19 | use_deformation: True 20 | deformation_dim: 3 21 | deformation_type: posed_verts 22 | 23 | # training 24 | train: 25 | batch_size: 1 26 | max_steps: 50000 ## normally, 10k is enough 27 | lr: 1e-4 28 | tex_lr: 1e-5 29 | geo_lr: 1e-5 30 | pose_lr: 1e-4 31 | # precrop_iters: 400 32 | precrop_frac: 0.3 33 | log_steps: 20 34 | val_steps: 500 35 | checkpoint_steps: 1000 36 | 37 | dataset: 38 | image_size: 512 39 | num_workers: 1 40 | white_bg: False 41 | type: 'scarf' 42 | 43 | loss: 44 | # nerf 45 | w_rgb: 1. 46 | w_patch_mrf: 0.0005 47 | w_alpha: 0.5 48 | w_depth: 0. 49 | # mesh 50 | mesh_w_rgb: 1. 51 | mesh_w_alpha: 0.001 52 | mesh_w_alpha_skin: 30. 53 | mesh_w_mrf: 0.0005 54 | reg_offset_w: 400. 55 | reg_offset_w_face: 200. 56 | reg_lap_w: 0. #130.0 57 | reg_edge_w: 500.0 58 | use_new_edge_loss: True 59 | skin_consistency_type: render_hand_mean 60 | mesh_skin_consistency: 0.01 61 | mesh_inside_mask: 40. 62 | nerf_reg_dxyz_w: 2. 63 | 64 | -------------------------------------------------------------------------------- /configs/exp/stage_1_hybrid_perceptual.yml: -------------------------------------------------------------------------------- 1 | # scarf model 2 | use_fine: True 3 | use_mesh: True 4 | use_nerf: True 5 | lbs_map: True 6 | k_neigh: 6 7 | chunk: 4096 8 | opt_pose: True 9 | opt_cam: True 10 | tex_network: siren 11 | use_highres_smplx: True 12 | exclude_hand: False 13 | opt_mesh: True 14 | mesh_offset_scale: 0.04 15 | 16 | sample_patch_rays: True 17 | sample_patch_size: 48 18 | 19 | use_deformation: True 20 | deformation_dim: 3 21 | deformation_type: posed_verts 22 | 23 | # training 24 | train: 25 | batch_size: 1 26 | max_steps: 50000 27 | lr: 1e-4 28 | tex_lr: 1e-5 29 | geo_lr: 1e-5 30 | pose_lr: 1e-4 31 | # precrop_iters: 400 32 | precrop_frac: 0.3 33 | log_steps: 20 34 | val_steps: 500 35 | checkpoint_steps: 1000 36 | 37 | dataset: 38 | image_size: 512 39 | num_workers: 1 40 | white_bg: False 41 | type: 'scarf' 42 | 43 | loss: 44 | # nerf 45 | w_rgb: 1. 46 | w_patch_mrf: 0. #0005 47 | w_patch_perceptual: 0.04 48 | w_alpha: 0.5 49 | w_depth: 0. 
50 | # mesh 51 | mesh_w_rgb: 1. 52 | mesh_w_alpha: 0.001 53 | mesh_w_alpha_skin: 30. 54 | mesh_w_mrf: 0. #0005 55 | mesh_w_perceptual: 0.04 56 | reg_offset_w: 400. 57 | reg_offset_w_face: 200. 58 | reg_lap_w: 0. #130.0 59 | reg_edge_w: 500.0 60 | use_new_edge_loss: True 61 | skin_consistency_type: render_hand_mean 62 | mesh_skin_consistency: 0.01 63 | mesh_inside_mask: 40. 64 | nerf_reg_dxyz_w: 2. 65 | 66 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: scarf 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.10.11=h06a4308_0 8 | - certifi=2022.9.24=py39h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.2=h6a678d5_6 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h5eee18b_3 15 | - openssl=1.1.1s=h7f8727e_0 16 | - pip=22.3.1=py39h06a4308_0 17 | - python=3.9.15=h7a1cb2a_2 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.5.0=py39h06a4308_0 20 | - sqlite=3.40.0=h5082296_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - tzdata=2022g=h04d1e81_0 23 | - wheel=0.37.1=pyhd3eb1b0_0 24 | - xz=5.2.8=h5eee18b_0 25 | - zlib=1.2.13=h5eee18b_0 26 | - pip: 27 | - asttokens==2.2.1 28 | - backcall==0.2.0 29 | - charset-normalizer==2.1.1 30 | - chumpy==0.70 31 | - click==8.1.3 32 | - contourpy==1.0.6 33 | - cycler==0.11.0 34 | - decorator==5.1.1 35 | - docker-pycreds==0.4.0 36 | - executing==1.2.0 37 | - fonttools==4.38.0 38 | - fvcore==0.1.5.post20221122 39 | - gitdb==4.0.10 40 | - gitpython==3.1.29 41 | - idna==3.4 42 | - imageio==2.22.4 43 | - iopath==0.1.10 44 | - ipdb==0.13.9 45 | - ipython==8.7.0 46 | - jedi==0.18.2 47 | - kiwisolver==1.4.4 48 | - kornia==0.6.0 49 | - loguru==0.6.0 50 | - lpips==0.1.4 51 | - matplotlib==3.6.2 52 | - matplotlib-inline==0.1.6 53 | - networkx==2.8.8 54 | - numpy==1.23.5 55 | - nvidia-cublas-cu11==11.10.3.66 56 | - nvidia-cuda-nvrtc-cu11==11.7.99 57 | - nvidia-cuda-runtime-cu11==11.7.99 58 | - nvidia-cudnn-cu11==8.5.0.96 59 | - opencv-python==4.6.0.66 60 | - packaging==22.0 61 | - parso==0.8.3 62 | - pathtools==0.1.2 63 | - pexpect==4.8.0 64 | - pickleshare==0.7.5 65 | - pillow==9.3.0 66 | - portalocker==2.6.0 67 | - promise==2.3 68 | - prompt-toolkit==3.0.36 69 | - protobuf==4.21.11 70 | - psutil==5.9.4 71 | - ptyprocess==0.7.0 72 | - pure-eval==0.2.2 73 | - pygments==2.13.0 74 | - pymcubes==0.1.2 75 | - pyparsing==3.0.9 76 | - python-dateutil==2.8.2 77 | - pywavelets==1.4.1 78 | - pyyaml==5.1.1 79 | - requests==2.28.1 80 | - scikit-image==0.19.3 81 | - scipy==1.9.3 82 | - sentry-sdk==1.11.1 83 | - setproctitle==1.3.2 84 | - shortuuid==1.0.11 85 | - six==1.16.0 86 | - smmap==5.0.0 87 | - stack-data==0.6.2 88 | - tabulate==0.9.0 89 | - termcolor==2.1.1 90 | - tifffile==2022.10.10 91 | - toml==0.10.2 92 | - torch==1.13.0 93 | - torchmetrics==0.11.0 94 | - torchvision==0.14.0 95 | - tqdm==4.64.1 96 | - traitlets==5.7.0 97 | - trimesh==3.17.1 98 | - typing-extensions==4.4.0 99 | - urllib3==1.26.13 100 | - vulture==2.6 101 | - wandb==0.13.6 102 | - wcwidth==0.2.5 103 | - yacs==0.1.8 -------------------------------------------------------------------------------- /fetch_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p ./data 3 | 4 | # SMPL-X 2020 (neutral SMPL-X model with the FLAME 2020 expression 
blendshapes) 5 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 6 | echo -e "\nYou need to register at https://smpl-x.is.tue.mpg.de" 7 | read -p "Username (SMPL-X):" username 8 | read -p "Password (SMPL-X):" password 9 | username=$(urle $username) 10 | password=$(urle $password) 11 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=SMPLX_NEUTRAL_2020.npz&resume=1' -O './data/SMPLX_NEUTRAL_2020.npz' --no-check-certificate --continue 12 | 13 | # scarf utilities 14 | echo -e "\nDownloading SCARF data..." 15 | wget https://owncloud.tuebingen.mpg.de/index.php/s/n58Fzbzz7Ei9x2W/download -O ./data/scarf_utilities.zip 16 | unzip ./data/scarf_utilities.zip -d ./data 17 | rm ./data/scarf_utilities.zip 18 | 19 | # download two examples 20 | echo -e "\nDownloading SCARF training data and trained avatars..." 21 | wget https://owncloud.tuebingen.mpg.de/index.php/s/geTtN4p5YTJaqPi/download -O scarf-exp-data-small.zip 22 | unzip ./scarf-exp-data-small.zip -d . 23 | rm ./scarf-exp-data-small.zip 24 | mv ./scarf-exp-data-small ./exps 25 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/lib/__init__.py -------------------------------------------------------------------------------- /lib/datasets/build_datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | def build_train(cfg, mode='train'): 4 | if cfg.type == 'scarf': 5 | from .scarf_data import NerfDataset 6 | return NerfDataset(cfg, mode=mode) 7 | -------------------------------------------------------------------------------- /lib/datasets/scarf_data.py: -------------------------------------------------------------------------------- 1 | from skimage.transform import rescale, resize, downscale_local_mean 2 | from skimage.io import imread 3 | import cv2 4 | import pickle 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import os 9 | from glob import glob 10 | from ..utils import rotation_converter 11 | 12 | class NerfDataset(torch.utils.data.Dataset): 13 | """SCARF Dataset""" 14 | 15 | def __init__(self, cfg, mode='train'): 16 | super().__init__() 17 | subject = cfg.subjects[0] 18 | imagepath_list = [] 19 | self.dataset_path = os.path.join(cfg.path, subject) 20 | if not os.path.exists (self.dataset_path): 21 | print(f'{self.dataset_path} not exists, please check the data path') 22 | exit() 23 | imagepath_list = glob(os.path.join(self.dataset_path, 'image', f'{subject}_*.png')) 24 | root_dir = os.path.join(self.dataset_path, 'cache') 25 | os.makedirs(root_dir, exist_ok=True) 26 | self.pose_cache_path = os.path.join(root_dir, 'pose.pt') 27 | self.cam_cache_path = os.path.join(root_dir, 'cam.pt') 28 | self.exp_cache_path = os.path.join(root_dir, 'exp.pt') 29 | self.beta_cache_path = os.path.join(root_dir, 'beta.pt') 30 | self.subject_id = subject 31 | 32 | imagepath_list = sorted(imagepath_list) 33 | frame_start = getattr(cfg, mode).frame_start 34 | frame_end = getattr(cfg, mode).frame_end 35 | frame_step = getattr(cfg, mode).frame_step 36 | imagepath_list = imagepath_list[frame_start:min(len(imagepath_list), frame_end):frame_step] 37 | 38 | 
self.data = imagepath_list 39 | if cfg.n_images < 10: 40 | self.data = self.data[:cfg.n_images] 41 | assert len(self.data) > 0, f"Can't find data; make sure data path {self.dataset_path} is correct" 42 | 43 | self.image_size = cfg.image_size 44 | self.white_bg = cfg.white_bg 45 | 46 | def __len__(self): 47 | return len(self.data) 48 | 49 | def __getitem__(self, index): 50 | # load image 51 | imagepath = self.data[index] 52 | image = imread(imagepath) / 255. 53 | imagename = imagepath.split('/')[-1].split('.')[0] 54 | image = image[:, :, :3] 55 | frame_id = int(imagename.split('_f')[-1]) 56 | frame_id = f'{frame_id:06d}' 57 | 58 | # load mask 59 | maskpath = os.path.join(self.dataset_path, 'matting', f'{imagename}.png') 60 | alpha_image = imread(maskpath) / 255. 61 | alpha_image = (alpha_image > 0.5).astype(np.float32) 62 | alpha_image = alpha_image[:, :, -1:] 63 | if self.white_bg: 64 | image = image[..., :3] * alpha_image + (1. - alpha_image) 65 | else: 66 | image = image[..., :3] * alpha_image 67 | # add alpha channel 68 | image = np.concatenate([image, alpha_image[:, :, :1]], axis=-1) 69 | image = resize(image, [self.image_size, self.image_size]) 70 | image = torch.from_numpy(image.transpose(2, 0, 1)).float() 71 | mask = image[3:] 72 | image = image[:3] 73 | 74 | # load camera and pose 75 | frame_id = int(imagename.split('_f')[-1]) 76 | name = self.subject_id 77 | 78 | # load pickle 79 | pkl_file = os.path.join(self.dataset_path, 'pixie', f'{imagename}_param.pkl') 80 | with open(pkl_file, 'rb') as f: 81 | codedict = pickle.load(f) 82 | param_dict = {} 83 | for key in codedict.keys(): 84 | if isinstance(codedict[key], str): 85 | param_dict[key] = codedict[key] 86 | else: 87 | param_dict[key] = torch.from_numpy(codedict[key]) 88 | beta = param_dict['shape'].squeeze()[:10] 89 | # full_pose = param_dict['full_pose'].squeeze() 90 | jaw_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0) #param_dict['jaw_pose'] 91 | eye_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0).repeat(2,1,1) 92 | # hand_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0).repeat(15,1,1) 93 | full_pose = torch.cat([param_dict['global_pose'], param_dict['body_pose'], 94 | jaw_pose, eye_pose, 95 | # hand_pose, hand_pose], dim=0) 96 | param_dict['left_hand_pose'], param_dict['right_hand_pose']], dim=0) 97 | cam = param_dict['body_cam'].squeeze() 98 | exp = torch.zeros_like(param_dict['exp'].squeeze()[:10]) 99 | frame_id = f'{frame_id:06}' 100 | data = { 101 | 'idx': index, 102 | 'frame_id': frame_id, 103 | 'name': name, 104 | 'imagepath': imagepath, 105 | 'image': image, 106 | 'mask': mask, 107 | 'full_pose': full_pose, 108 | 'cam': cam, 109 | 'beta': beta, 110 | 'exp': exp 111 | } 112 | 113 | seg_image_path = os.path.join(self.dataset_path, 'cloth_segmentation', f"{imagename}.png") 114 | cloth_seg = imread(seg_image_path)/255. 115 | cloth_seg = resize(cloth_seg, [self.image_size, self.image_size]) 116 | cloth_mask = torch.from_numpy(cloth_seg[:,:,:3].sum(-1))[None,...] 
117 | cloth_mask = (cloth_mask > 0.1).float() 118 | cloth_mask = ((mask + cloth_mask) > 1.5).float() 119 | skin_mask = ((mask - cloth_mask) > 0).float() 120 | data['cloth_mask'] = cloth_mask 121 | data['skin_mask'] = skin_mask 122 | 123 | return data -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/lib/models/__init__.py -------------------------------------------------------------------------------- /lib/models/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Embedding(nn.Module): 5 | def __init__(self, in_channels, N_freqs, logscale=True): 6 | """ 7 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 8 | in_channels: number of input channels (3 for both xyz and direction) 9 | """ 10 | super(Embedding, self).__init__() 11 | self.N_freqs = N_freqs 12 | self.in_channels = in_channels 13 | self.funcs = [torch.sin, torch.cos] 14 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) 15 | 16 | if logscale: 17 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) 18 | else: 19 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) 20 | 21 | def forward(self, x): 22 | """ 23 | Embeds x to (x, sin(2^k x), cos(2^k x), ...) 24 | Different from the paper, "x" is also in the output 25 | See https://github.com/bmild/nerf/issues/12 26 | 27 | Inputs: 28 | x: (B, self.in_channels) 29 | 30 | Outputs: 31 | out: (B, self.out_channels) 32 | """ 33 | out = [x] 34 | for freq in self.freq_bands: 35 | for func in self.funcs: 36 | out += [func(freq*x)] 37 | 38 | return torch.cat(out, -1) 39 | 40 | -------------------------------------------------------------------------------- /lib/models/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | import os 23 | import yaml 24 | import torch 25 | import torch.nn.functional as F 26 | from torch import nn 27 | 28 | def rot_mat_to_euler(rot_mats): 29 | # Calculates rotation matrix to euler angles 30 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 31 | 32 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 33 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 34 | return torch.atan2(-rot_mats[:, 2, 0], sy) 35 | 36 | 37 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 38 | dynamic_lmk_b_coords, 39 | head_kin_chain, dtype=torch.float32): 40 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 41 | 42 | 43 | To do so, we first compute the rotation of the neck around the y-axis 44 | and then use a pre-computed look-up table to find the faces and the 45 | barycentric coordinates that will be used. 46 | 47 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 48 | for providing the original TensorFlow implementation and for the LUT. 49 | 50 | Parameters 51 | ---------- 52 | vertices: torch.tensor BxVx3, dtype = torch.float32 53 | The tensor of input vertices 54 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 55 | The current pose of the body model 56 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 57 | The look-up table from neck rotation to faces 58 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 59 | The look-up table from neck rotation to barycentric coordinates 60 | head_kin_chain: list 61 | A python list that contains the indices of the joints that form the 62 | kinematic chain of the neck. 63 | dtype: torch.dtype, optional 64 | 65 | Returns 66 | ------- 67 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 68 | A tensor of size BxL that contains the indices of the faces that 69 | will be used to compute the current dynamic landmarks. 70 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 71 | A tensor of size BxL that contains the indices of the faces that 72 | will be used to compute the current dynamic landmarks. 
73 | ''' 74 | 75 | batch_size = vertices.shape[0] 76 | pose = pose.detach() 77 | # aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 78 | # head_kin_chain) 79 | # rot_mats = batch_rodrigues( 80 | # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 81 | rot_mats = torch.index_select(pose, 1, head_kin_chain) 82 | 83 | rel_rot_mat = torch.eye(3, device=vertices.device, 84 | dtype=dtype).unsqueeze_(dim=0) 85 | for idx in range(len(head_kin_chain)): 86 | # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 87 | rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat) 88 | 89 | y_rot_angle = torch.round( 90 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 91 | max=39)).to(dtype=torch.long) 92 | # print(y_rot_angle[0]) 93 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 94 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 95 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 96 | y_rot_angle = (neg_mask * neg_vals + 97 | (1 - neg_mask) * y_rot_angle) 98 | # print(y_rot_angle[0]) 99 | 100 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 101 | 0, y_rot_angle) 102 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 103 | 0, y_rot_angle) 104 | 105 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 106 | 107 | 108 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 109 | ''' Calculates landmarks by barycentric interpolation 110 | 111 | Parameters 112 | ---------- 113 | vertices: torch.tensor BxVx3, dtype = torch.float32 114 | The tensor of input vertices 115 | faces: torch.tensor Fx3, dtype = torch.long 116 | The faces of the mesh 117 | lmk_faces_idx: torch.tensor L, dtype = torch.long 118 | The tensor with the indices of the faces used to calculate the 119 | landmarks. 
120 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 121 | The tensor of barycentric coordinates that are used to interpolate 122 | the landmarks 123 | 124 | Returns 125 | ------- 126 | landmarks: torch.tensor BxLx3, dtype = torch.float32 127 | The coordinates of the landmarks for each mesh in the batch 128 | ''' 129 | # Extract the indices of the vertices for each face 130 | # BxLx3 131 | batch_size, num_verts = vertices.shape[:2] 132 | device = vertices.device 133 | 134 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 135 | batch_size, -1, 3) 136 | 137 | lmk_faces += torch.arange( 138 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 139 | 140 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 141 | batch_size, -1, 3, 3) 142 | 143 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 144 | return landmarks 145 | 146 | 147 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 148 | lbs_weights, pose2rot=True, dtype=torch.float32): 149 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 150 | 151 | Parameters 152 | ---------- 153 | betas : torch.tensor BxNB 154 | The tensor of shape parameters 155 | pose : torch.tensor Bx(J + 1) * 3 156 | The pose parameters in axis-angle format 157 | v_template torch.tensor BxVx3 158 | The template mesh that will be deformed 159 | shapedirs : torch.tensor 1xNB 160 | The tensor of PCA shape displacements 161 | posedirs : torch.tensor Px(V * 3) 162 | The pose PCA coefficients 163 | J_regressor : torch.tensor JxV 164 | The regressor array that is used to calculate the joints from 165 | the position of the vertices 166 | parents: torch.tensor J 167 | The array that describes the kinematic tree for the model 168 | lbs_weights: torch.tensor N x V x (J + 1) 169 | The linear blend skinning weights that represent how much the 170 | rotation matrix of each part affects each vertex 171 | pose2rot: bool, optional 172 | Flag on whether to convert the input pose tensor to rotation 173 | matrices. The default value is True. If False, then the pose tensor 174 | should already contain rotation matrices and have a size of 175 | Bx(J + 1)x9 176 | dtype: torch.dtype, optional 177 | 178 | Returns 179 | ------- 180 | verts: torch.tensor BxVx3 181 | The vertices of the mesh after applying the shape and pose 182 | displacements. 183 | joints: torch.tensor BxJx3 184 | The joints of the model 185 | ''' 186 | 187 | batch_size = max(betas.shape[0], pose.shape[0]) 188 | device = betas.device 189 | 190 | # Add shape contribution 191 | shape_offsets = blend_shapes(betas, shapedirs) 192 | v_shaped = v_template + shape_offsets 193 | 194 | # Get the joints 195 | # NxJx3 array 196 | J = vertices2joints(J_regressor, v_shaped) 197 | 198 | # 3. 
Add pose blend shapes 199 | # N x J x 3 x 3 200 | ident = torch.eye(3, dtype=dtype, device=device) 201 | if pose2rot: 202 | rot_mats = batch_rodrigues( 203 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 204 | 205 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 206 | # (N x P) x (P, V * 3) -> N x V x 3 207 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 208 | .view(batch_size, -1, 3) 209 | else: 210 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 211 | rot_mats = pose.view(batch_size, -1, 3, 3) 212 | 213 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 214 | posedirs).view(batch_size, -1, 3) 215 | 216 | v_posed = pose_offsets + v_shaped 217 | # 4. Get the global joint location 218 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 219 | 220 | # 5. Do skinning: 221 | # W is N x V x (J + 1) 222 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 223 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 224 | num_joints = J_regressor.shape[0] 225 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 226 | .view(batch_size, -1, 4, 4) 227 | 228 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 229 | dtype=dtype, device=device) 230 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 231 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 232 | 233 | verts = v_homo[:, :, :3, 0] 234 | 235 | return verts, J_transformed, A, T, shape_offsets, pose_offsets 236 | 237 | 238 | 239 | def vertices2joints(J_regressor, vertices): 240 | ''' Calculates the 3D joint locations from the vertices 241 | 242 | Parameters 243 | ---------- 244 | J_regressor : torch.tensor JxV 245 | The regressor array that is used to calculate the joints from the 246 | position of the vertices 247 | vertices : torch.tensor BxVx3 248 | The tensor of mesh vertices 249 | 250 | Returns 251 | ------- 252 | torch.tensor BxJx3 253 | The location of the joints 254 | ''' 255 | 256 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 257 | 258 | 259 | def blend_shapes(betas, shape_disps): 260 | ''' Calculates the per vertex displacement due to the blend shapes 261 | 262 | 263 | Parameters 264 | ---------- 265 | betas : torch.tensor Bx(num_betas) 266 | Blend shape coefficients 267 | shape_disps: torch.tensor Vx3x(num_betas) 268 | Blend shapes 269 | 270 | Returns 271 | ------- 272 | torch.tensor BxVx3 273 | The per-vertex displacement due to shape deformation 274 | ''' 275 | 276 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 277 | # i.e. Multiply each shape displacement by its corresponding beta and 278 | # then sum them. 
279 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 280 | return blend_shape 281 | 282 | 283 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 284 | ''' Calculates the rotation matrices for a batch of rotation vectors 285 | Parameters 286 | ---------- 287 | rot_vecs: torch.tensor Nx3 288 | array of N axis-angle vectors 289 | Returns 290 | ------- 291 | R: torch.tensor Nx3x3 292 | The rotation matrices for the given axis-angle parameters 293 | ''' 294 | 295 | batch_size = rot_vecs.shape[0] 296 | device = rot_vecs.device 297 | 298 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 299 | rot_dir = rot_vecs / angle 300 | 301 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 302 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 303 | 304 | # Bx1 arrays 305 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 306 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 307 | 308 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 309 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 310 | .view((batch_size, 3, 3)) 311 | 312 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 313 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 314 | return rot_mat 315 | 316 | 317 | def transform_mat(R, t): 318 | ''' Creates a batch of transformation matrices 319 | Args: 320 | - R: Bx3x3 array of a batch of rotation matrices 321 | - t: Bx3x1 array of a batch of translation vectors 322 | Returns: 323 | - T: Bx4x4 Transformation matrix 324 | ''' 325 | # No padding left or right, only add an extra row 326 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 327 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 328 | 329 | 330 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 331 | """ 332 | Applies a batch of rigid transformations to the joints 333 | 334 | Parameters 335 | ---------- 336 | rot_mats : torch.tensor BxNx3x3 337 | Tensor of rotation matrices 338 | joints : torch.tensor BxNx3 339 | Locations of joints 340 | parents : torch.tensor BxN 341 | The kinematic tree of each object 342 | dtype : torch.dtype, optional: 343 | The data type of the created tensors, the default is torch.float32 344 | 345 | Returns 346 | ------- 347 | posed_joints : torch.tensor BxNx3 348 | The locations of the joints after applying the pose rotations 349 | rel_transforms : torch.tensor BxNx4x4 350 | The relative (with respect to the root joint) rigid transformations 351 | for all the joints 352 | """ 353 | 354 | joints = torch.unsqueeze(joints, dim=-1) 355 | 356 | rel_joints = joints.clone() 357 | rel_joints[:, 1:] -= joints[:, parents[1:]] 358 | 359 | transforms_mat = transform_mat( 360 | rot_mats.reshape(-1, 3, 3), 361 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) 362 | 363 | transform_chain = [transforms_mat[:, 0]] 364 | for i in range(1, parents.shape[0]): 365 | # Subtract the joint location at the rest pose 366 | # No need for rotation, since it's identity when at rest 367 | curr_res = torch.matmul(transform_chain[parents[i]], 368 | transforms_mat[:, i]) 369 | transform_chain.append(curr_res) 370 | 371 | transforms = torch.stack(transform_chain, dim=1) 372 | 373 | # The last column of the transformations contains the posed joints 374 | posed_joints = transforms[:, :, :3, 3] 375 | 376 | # # The last column of the transformations contains the posed joints 377 | # posed_joints = transforms[:, :, :3, 3] 378 | 379 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 380 | 381 | rel_transforms = 
transforms - F.pad( 382 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 383 | 384 | return posed_joints, rel_transforms 385 | 386 | class JointsFromVerticesSelector(nn.Module): 387 | def __init__(self, fname): 388 | ''' Selects extra joints from vertices 389 | ''' 390 | super(JointsFromVerticesSelector, self).__init__() 391 | 392 | err_msg = ( 393 | 'Either pass a filename or triangle face ids, names and' 394 | ' barycentrics') 395 | assert fname is not None or ( 396 | face_ids is not None and bcs is not None and names is not None 397 | ), err_msg 398 | if fname is not None: 399 | fname = os.path.expanduser(os.path.expandvars(fname)) 400 | with open(fname, 'r') as f: 401 | data = yaml.load(f) 402 | names = list(data.keys()) 403 | bcs = [] 404 | face_ids = [] 405 | for name, d in data.items(): 406 | face_ids.append(d['face']) 407 | bcs.append(d['bc']) 408 | bcs = np.array(bcs, dtype=np.float32) 409 | face_ids = np.array(face_ids, dtype=np.int32) 410 | assert len(bcs) == len(face_ids), ( 411 | 'The number of barycentric coordinates must be equal to the faces' 412 | ) 413 | assert len(names) == len(face_ids), ( 414 | 'The number of names must be equal to the number of ' 415 | ) 416 | 417 | self.names = names 418 | self.register_buffer('bcs', torch.tensor(bcs, dtype=torch.float32)) 419 | self.register_buffer( 420 | 'face_ids', torch.tensor(face_ids, dtype=torch.long)) 421 | 422 | def extra_joint_names(self): 423 | ''' Returns the names of the extra joints 424 | ''' 425 | return self.names 426 | 427 | def forward(self, vertices, faces): 428 | if len(self.face_ids) < 1: 429 | return [] 430 | vertex_ids = faces[self.face_ids].reshape(-1) 431 | # Should be BxNx3x3 432 | triangles = torch.index_select(vertices, 1, vertex_ids).reshape( 433 | -1, len(self.bcs), 3, 3) 434 | return (triangles * self.bcs[None, :, :, None]).sum(dim=2) 435 | 436 | # def to_tensor(array, dtype=torch.float32): 437 | # if torch.is_tensor(array): 438 | # return array 439 | # else: 440 | # return torch.tensor(array, dtype=dtype) 441 | 442 | def to_tensor(array, dtype=torch.float32): 443 | if 'torch.tensor' not in str(type(array)): 444 | return torch.tensor(array, dtype=dtype) 445 | 446 | def to_np(array, dtype=np.float32): 447 | if 'scipy.sparse' in str(type(array)): 448 | array = array.todense() 449 | return np.array(array, dtype=dtype) 450 | 451 | class Struct(object): 452 | def __init__(self, **kwargs): 453 | for key, val in kwargs.items(): 454 | setattr(self, key, val) 455 | -------------------------------------------------------------------------------- /lib/models/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .embedding import Embedding 6 | 7 | class DeRF(nn.Module): 8 | def __init__(self, 9 | D=6, W=128, 10 | freqs_xyz=10, 11 | deformation_dim=0, 12 | out_channels=3, 13 | skips=[4]): 14 | """ 15 | D: number of layers for density (sigma) encoder 16 | W: number of hidden units in each layer 17 | in_channels: number of input channels for xyz (3+3*10*2=63 by default) 18 | skips: add skip connection in the Dth layer 19 | """ 20 | super(DeRF, self).__init__() 21 | self.D = D 22 | self.W = W 23 | self.freqs_xyz = freqs_xyz 24 | self.deformation_dim = deformation_dim 25 | self.skips = skips 26 | 27 | self.in_channels = 3 + 3*freqs_xyz*2 + deformation_dim 28 | self.out_channels = out_channels 29 | 30 | self.encoding_xyz = Embedding(3, freqs_xyz) 31 | 32 | # xyz encoding 
layers 33 | for i in range(D): 34 | if i == 0: 35 | layer = nn.Linear(self.in_channels, W) 36 | elif i in skips: 37 | layer = nn.Linear(W+self.in_channels, W) 38 | else: 39 | layer = nn.Linear(W, W) 40 | layer = nn.Sequential(layer, nn.ReLU(True)) 41 | setattr(self, f"xyz_encoding_{i+1}", layer) 42 | 43 | self.out = nn.Linear(W, self.out_channels) 44 | 45 | def forward(self, xyz, deformation_code=None): 46 | xyz_encoded = self.encoding_xyz(xyz) 47 | 48 | if self.deformation_dim > 0: 49 | xyz_encoded = torch.cat([xyz_encoded, deformation_code], -1) 50 | 51 | xyz_ = xyz_encoded 52 | for i in range(self.D): 53 | if i in self.skips: 54 | xyz_ = torch.cat([xyz_encoded, xyz_], -1) 55 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 56 | out = self.out(xyz_) 57 | 58 | return out 59 | 60 | class NeRF(nn.Module): 61 | def __init__(self, 62 | D=8, W=256, 63 | freqs_xyz=10, freqs_dir=4, 64 | use_view=True, use_normal=False, 65 | deformation_dim=0, appearance_dim=0, 66 | skips=[4], actvn_type='relu'): 67 | """ 68 | D: number of layers for density (sigma) encoder 69 | W: number of hidden units in each layer 70 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default) 71 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default) 72 | skips: add skip connection in the Dth layer 73 | """ 74 | super(NeRF, self).__init__() 75 | self.D = D 76 | self.W = W 77 | self.freqs_xyz = freqs_xyz 78 | self.freqs_dir = freqs_dir 79 | self.deformation_dim = deformation_dim 80 | self.appearance_dim = appearance_dim 81 | self.skips = skips 82 | self.use_view = use_view 83 | self.use_normal = use_normal 84 | 85 | self.encoding_xyz = Embedding(3, freqs_xyz) 86 | if self.use_view: 87 | self.encoding_dir = Embedding(3, freqs_dir) 88 | 89 | self.in_channels_xyz = 3 + 3*freqs_xyz*2 + deformation_dim 90 | 91 | self.in_channels_dir = appearance_dim 92 | if self.use_view: 93 | self.in_channels_dir += 3 + 3*freqs_dir*2 94 | if self.use_normal: 95 | self.in_channels_dir += 3 96 | 97 | if actvn_type == 'relu': 98 | actvn = nn.ReLU(inplace=True) 99 | elif actvn_type == 'leaky_relu': 100 | actvn = nn.LeakyReLU(0.2, inplace=True) 101 | elif actvn_type == 'softplus': 102 | actvn = nn.Softplus(beta=100) 103 | else: 104 | assert NotImplementedError 105 | 106 | # xyz encoding layers 107 | for i in range(D): 108 | if i == 0: 109 | layer = nn.Linear(self.in_channels_xyz, W) 110 | elif i in skips: 111 | layer = nn.Linear(W+self.in_channels_xyz, W) 112 | else: 113 | layer = nn.Linear(W, W) 114 | layer = nn.Sequential(layer, actvn) 115 | setattr(self, f"xyz_encoding_{i+1}", layer) 116 | self.xyz_encoding_final = nn.Linear(W, W) 117 | 118 | # direction encoding layers 119 | self.dir_encoding = nn.Sequential( 120 | nn.Linear(W+self.in_channels_dir, W//2), 121 | nn.ReLU(True)) 122 | 123 | # output layers 124 | self.sigma = nn.Linear(W, 1) 125 | self.rgb = nn.Sequential( 126 | nn.Linear(W//2, 3), 127 | nn.Sigmoid()) 128 | 129 | def forward(self, xyz, viewdir=None, deformation_code=None, appearance_code=None): 130 | """ 131 | Inputs: 132 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 133 | the embedded vector of position and direction 134 | Outputs: 135 | out: (B, 4), rgb and sigma 136 | """ 137 | sigma, xyz_encoding_final = self.get_sigma(xyz, deformation_code=deformation_code) 138 | 139 | dir_encoding_input = xyz_encoding_final 140 | 141 | if self.use_view: 142 | viewdir_encoded = self.encoding_dir(viewdir) 143 | dir_encoding_input = torch.cat([dir_encoding_input, viewdir_encoded], -1) 144 | if 
self.use_normal: 145 | normal = self.get_normal(xyz, deformation_code=deformation_code) 146 | dir_encoding_input = torch.cat([dir_encoding_input, normal], -1) 147 | if self.appearance_dim > 0: 148 | dir_encoding_input = torch.cat([dir_encoding_input, appearance_code], -1) 149 | 150 | dir_encoding = self.dir_encoding(dir_encoding_input) 151 | rgb = self.rgb(dir_encoding) 152 | 153 | return rgb, sigma 154 | 155 | def get_sigma(self, xyz, deformation_code=None, only_sigma=False): 156 | 157 | xyz_encoded = self.encoding_xyz(xyz) 158 | 159 | if self.deformation_dim > 0: 160 | xyz_encoded = torch.cat([xyz_encoded, deformation_code], -1) 161 | 162 | xyz_ = xyz_encoded 163 | for i in range(self.D): 164 | if i in self.skips: 165 | xyz_ = torch.cat([xyz_encoded, xyz_], -1) 166 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 167 | 168 | sigma = self.sigma(xyz_) 169 | 170 | if only_sigma: 171 | return sigma 172 | 173 | xyz_encoding_final = self.xyz_encoding_final(xyz_) 174 | 175 | return sigma, xyz_encoding_final 176 | 177 | def get_normal(self, xyz, deformation_code=None, delta=0.02): 178 | with torch.set_grad_enabled(True): 179 | xyz.requires_grad_(True) 180 | sigma = self.get_sigma(xyz, deformation_code=deformation_code, only_sigma=True) 181 | alpha = 1 - torch.exp(-delta * torch.relu(sigma)) 182 | normal = torch.autograd.grad( 183 | outputs=alpha, 184 | inputs=xyz, 185 | grad_outputs=torch.ones_like(alpha, requires_grad=False, device=alpha.device), 186 | create_graph=True, 187 | retain_graph=True, 188 | only_inputs=True)[0] 189 | 190 | return normal 191 | 192 | 193 | #### 194 | 195 | class MLP(nn.Module): 196 | def __init__(self, 197 | filter_channels, 198 | merge_layer=0, 199 | res_layers=[], 200 | norm='group', 201 | last_op=None): 202 | super(MLP, self).__init__() 203 | 204 | self.filters = nn.ModuleList() 205 | self.norms = nn.ModuleList() 206 | self.merge_layer = merge_layer if merge_layer > 0 else len(filter_channels) // 2 207 | self.res_layers = res_layers 208 | self.norm = norm 209 | self.last_op = last_op 210 | 211 | for l in range(0, len(filter_channels)-1): 212 | if l in self.res_layers: 213 | self.filters.append(nn.Conv1d( 214 | filter_channels[l] + filter_channels[0], 215 | filter_channels[l+1], 216 | 1)) 217 | else: 218 | self.filters.append(nn.Conv1d( 219 | filter_channels[l], 220 | filter_channels[l+1], 221 | 1)) 222 | if l != len(filter_channels)-2: 223 | if norm == 'group': 224 | self.norms.append(nn.GroupNorm(32, filter_channels[l+1])) 225 | elif norm == 'batch': 226 | self.norms.append(nn.BatchNorm1d(filter_channels[l+1])) 227 | 228 | ## init 229 | # for l in range(0, len(filter_channels)-1): 230 | # conv = self.filters[l] 231 | # conv.weight.data.fill_(0.00001) 232 | # conv.bias.data.fill_(0.0) 233 | 234 | def forward(self, feature): 235 | ''' 236 | feature may include multiple view inputs 237 | args: 238 | feature: [B, C_in, N] 239 | return: 240 | [B, C_out, N] prediction 241 | ''' 242 | y = feature 243 | tmpy = feature 244 | phi = None 245 | for i, f in enumerate(self.filters): 246 | y = f( 247 | y if i not in self.res_layers 248 | else torch.cat([y, tmpy], 1) 249 | ) 250 | if i != len(self.filters)-1: 251 | if self.norm not in ['batch', 'group']: 252 | y = F.leaky_relu(y) 253 | else: 254 | y = F.leaky_relu(self.norms[i](y)) 255 | if i == self.merge_layer: 256 | phi = y.clone() 257 | 258 | if self.last_op is not None: 259 | y = self.last_op(y) 260 | 261 | return y #, phi 262 | 263 | 264 | class GeoMLP(nn.Module): 265 | def __init__(self, 266 | filter_channels = [128, 
256, 128], 267 | input_dim = 3, embedding_freqs = 10, 268 | output_dim = 3, 269 | cond_dim = 72, 270 | last_op=torch.tanh, 271 | scale=0.1): 272 | super(GeoMLP, self).__init__() 273 | self.input_dim = input_dim 274 | self.cond_dim = cond_dim 275 | self.embedding_freqs = embedding_freqs 276 | self.embedding_dim = input_dim*(2*embedding_freqs+1) 277 | # Embeddings 278 | self.embedding = Embedding(self.input_dim, embedding_freqs) # 10 is the default number 279 | 280 | # xyz encoding layers 281 | filter_channels = [self.embedding_dim + cond_dim] + filter_channels + [output_dim] 282 | self.mlp = MLP(filter_channels, last_op=last_op) 283 | self.scale = scale 284 | 285 | def forward(self, x, cond): 286 | """ 287 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 288 | For rendering this ray, please see rendering.py 289 | Inputs: 290 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 291 | the embedded vector of position and direction 292 | sigma_only: whether to infer sigma only. If True, 293 | x is of shape (B, self.in_channels_xyz) 294 | Outputs: 295 | if sigma_ony: 296 | sigma: (B, 1) sigma 297 | else: 298 | out: (B, 4), rgb and sigma 299 | """ 300 | # x: [B, nv, 3] 301 | # cond: [B, n_theta] 302 | batch_size, nv, _ = x.shape 303 | pos_embedding = self.embedding(x.reshape(batch_size, -1)).reshape(batch_size, nv, -1) 304 | cond = cond[:,None,:].expand(-1, nv, -1) 305 | inputs = torch.cat([pos_embedding, cond], -1) #[B, nv, n_position+n_theta] 306 | inputs = inputs.permute(0,2,1) 307 | out = self.mlp(inputs).permute(0,2,1)*self.scale 308 | return out 309 | 310 | -------------------------------------------------------------------------------- /lib/models/siren.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | import torch.nn.functional as F 6 | 7 | class Sine(nn.Module): 8 | """Sine Activation Function.""" 9 | 10 | def __init__(self): 11 | super().__init__() 12 | def forward(self, x): 13 | return torch.sin(30. 
* x) 14 | 15 | def sine_init(m): 16 | with torch.no_grad(): 17 | if isinstance(m, nn.Linear): 18 | num_input = m.weight.size(-1) 19 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30) 20 | 21 | 22 | def first_layer_sine_init(m): 23 | with torch.no_grad(): 24 | if isinstance(m, nn.Linear): 25 | num_input = m.weight.size(-1) 26 | m.weight.uniform_(-1 / num_input, 1 / num_input) 27 | 28 | 29 | def film_sine_init(m): 30 | with torch.no_grad(): 31 | if isinstance(m, nn.Linear): 32 | num_input = m.weight.size(-1) 33 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30) 34 | 35 | 36 | def first_layer_film_sine_init(m): 37 | with torch.no_grad(): 38 | if isinstance(m, nn.Linear): 39 | num_input = m.weight.size(-1) 40 | m.weight.uniform_(-1 / num_input, 1 / num_input) 41 | 42 | 43 | def kaiming_leaky_init(m): 44 | classname = m.__class__.__name__ 45 | if classname.find('Linear') != -1: 46 | torch.nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 47 | 48 | class CustomMappingNetwork(nn.Module): 49 | def __init__(self, z_dim, map_hidden_dim, map_output_dim): 50 | super().__init__() 51 | 52 | 53 | 54 | self.network = nn.Sequential(nn.Linear(z_dim, map_hidden_dim), 55 | nn.LeakyReLU(0.2, inplace=True), 56 | 57 | nn.Linear(map_hidden_dim, map_hidden_dim), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | nn.Linear(map_hidden_dim, map_hidden_dim), 61 | nn.LeakyReLU(0.2, inplace=True), 62 | 63 | nn.Linear(map_hidden_dim, map_output_dim)) 64 | 65 | self.network.apply(kaiming_leaky_init) 66 | with torch.no_grad(): 67 | self.network[-1].weight *= 0.25 68 | 69 | def forward(self, z): 70 | frequencies_offsets = self.network(z) 71 | frequencies = frequencies_offsets[..., :frequencies_offsets.shape[-1]//2] 72 | phase_shifts = frequencies_offsets[..., frequencies_offsets.shape[-1]//2:] 73 | 74 | return frequencies, phase_shifts 75 | 76 | 77 | def frequency_init(freq): 78 | def init(m): 79 | with torch.no_grad(): 80 | if isinstance(m, nn.Linear): 81 | num_input = m.weight.size(-1) 82 | m.weight.uniform_(-np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq) 83 | return init 84 | 85 | class FiLMLayer(nn.Module): 86 | def __init__(self, input_dim, hidden_dim): 87 | super().__init__() 88 | self.layer = nn.Linear(input_dim, hidden_dim) 89 | 90 | def forward(self, x, freq, phase_shift): 91 | x = self.layer(x) 92 | freq = freq.unsqueeze(1).expand_as(x) 93 | phase_shift = phase_shift.unsqueeze(1).expand_as(x) 94 | return torch.sin(freq * x + phase_shift) 95 | 96 | 97 | class TALLSIREN(nn.Module): 98 | """Primary SIREN architecture used in pi-GAN generators.""" 99 | 100 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None): 101 | super().__init__() 102 | self.device = device 103 | self.input_dim = input_dim 104 | self.z_dim = z_dim 105 | self.hidden_dim = hidden_dim 106 | self.output_dim = output_dim 107 | 108 | self.network = nn.ModuleList([ 109 | FiLMLayer(input_dim, hidden_dim), 110 | FiLMLayer(hidden_dim, hidden_dim), 111 | FiLMLayer(hidden_dim, hidden_dim), 112 | FiLMLayer(hidden_dim, hidden_dim), 113 | FiLMLayer(hidden_dim, hidden_dim), 114 | FiLMLayer(hidden_dim, hidden_dim), 115 | FiLMLayer(hidden_dim, hidden_dim), 116 | FiLMLayer(hidden_dim, hidden_dim), 117 | ]) 118 | self.final_layer = nn.Linear(hidden_dim, 1) 119 | 120 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim) 121 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3), nn.Sigmoid()) 122 | 123 | 
self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2) 124 | 125 | self.network.apply(frequency_init(25)) 126 | self.final_layer.apply(frequency_init(25)) 127 | self.color_layer_sine.apply(frequency_init(25)) 128 | self.color_layer_linear.apply(frequency_init(25)) 129 | self.network[0].apply(first_layer_film_sine_init) 130 | 131 | def forward(self, input, z, ray_directions, **kwargs): 132 | frequencies, phase_shifts = self.mapping_network(z) 133 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs) 134 | 135 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs): 136 | frequencies = frequencies*15 + 30 137 | 138 | x = input 139 | 140 | for index, layer in enumerate(self.network): 141 | start = index * self.hidden_dim 142 | end = (index+1) * self.hidden_dim 143 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) 144 | 145 | sigma = self.final_layer(x) 146 | rbg = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:]) 147 | rbg = self.color_layer_linear(rbg) 148 | 149 | return torch.cat([rbg, sigma], dim=-1) 150 | 151 | 152 | class UniformBoxWarp(nn.Module): 153 | def __init__(self, sidelength): 154 | super().__init__() 155 | self.scale_factor = 2/sidelength 156 | 157 | def forward(self, coordinates): 158 | return coordinates * self.scale_factor 159 | 160 | class SPATIALSIRENBASELINE(nn.Module): 161 | """Same architecture as TALLSIREN but adds a UniformBoxWarp to map input points to -1, 1""" 162 | 163 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None): 164 | super().__init__() 165 | self.device = device 166 | self.input_dim = input_dim 167 | self.z_dim = z_dim 168 | self.hidden_dim = hidden_dim 169 | self.output_dim = output_dim 170 | 171 | self.network = nn.ModuleList([ 172 | FiLMLayer(3, hidden_dim), 173 | FiLMLayer(hidden_dim, hidden_dim), 174 | FiLMLayer(hidden_dim, hidden_dim), 175 | FiLMLayer(hidden_dim, hidden_dim), 176 | FiLMLayer(hidden_dim, hidden_dim), 177 | FiLMLayer(hidden_dim, hidden_dim), 178 | FiLMLayer(hidden_dim, hidden_dim), 179 | FiLMLayer(hidden_dim, hidden_dim), 180 | ]) 181 | self.final_layer = nn.Linear(hidden_dim, 1) 182 | 183 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim) 184 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3)) 185 | 186 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2) 187 | 188 | self.network.apply(frequency_init(25)) 189 | self.final_layer.apply(frequency_init(25)) 190 | self.color_layer_sine.apply(frequency_init(25)) 191 | self.color_layer_linear.apply(frequency_init(25)) 192 | self.network[0].apply(first_layer_film_sine_init) 193 | 194 | self.gridwarper = UniformBoxWarp(0.24) # Don't worry about this, it was added to ensure compatibility with another model. Shouldn't affect performance. 
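    # [Editor's sketch -- not part of the original pi-GAN/SCARF code] The mapping
    # network above emits (len(self.network) + 1) * hidden_dim * 2 scalars per
    # sample; forward_with_frequencies_phase_shifts slices them into per-layer
    # FiLM frequencies and phase shifts, with the final hidden_dim-wide slice
    # reserved for color_layer_sine. A hypothetical helper that makes this
    # slicing explicit (equivalent to the indexing in the forward pass below):
    def _film_slices(self, frequencies, phase_shifts):
        """Yield one (freq, phase) pair of width hidden_dim per FiLM layer."""
        for index in range(len(self.network)):
            start = index * self.hidden_dim
            end = (index + 1) * self.hidden_dim
            yield frequencies[..., start:end], phase_shifts[..., start:end]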
195 | 196 | def forward(self, input, z, ray_directions, **kwargs): 197 | frequencies, phase_shifts = self.mapping_network(z) 198 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs) 199 | 200 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs): 201 | frequencies = frequencies*15 + 30 202 | 203 | input = self.gridwarper(input) 204 | x = input 205 | 206 | for index, layer in enumerate(self.network): 207 | start = index * self.hidden_dim 208 | end = (index+1) * self.hidden_dim 209 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) 210 | 211 | sigma = self.final_layer(x) 212 | rbg = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:]) 213 | rbg = torch.sigmoid(self.color_layer_linear(rbg)) 214 | 215 | return torch.cat([rbg, sigma], dim=-1) 216 | 217 | 218 | 219 | class UniformBoxWarp(nn.Module): 220 | def __init__(self, sidelength): 221 | super().__init__() 222 | self.scale_factor = 2/sidelength 223 | 224 | def forward(self, coordinates): 225 | return coordinates * self.scale_factor 226 | 227 | 228 | def sample_from_3dgrid(coordinates, grid): 229 | """ 230 | Expects coordinates in shape (batch_size, num_points_per_batch, 3) 231 | Expects grid in shape (1, channels, H, W, D) 232 | (Also works if grid has batch size) 233 | Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) 234 | """ 235 | coordinates = coordinates.float() 236 | grid = grid.float() 237 | 238 | batch_size, n_coords, n_dims = coordinates.shape 239 | sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), 240 | coordinates.reshape(batch_size, 1, 1, -1, n_dims), 241 | mode='bilinear', padding_mode='zeros', align_corners=True) 242 | N, C, H, W, D = sampled_features.shape 243 | sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) 244 | return sampled_features 245 | 246 | 247 | def modified_first_sine_init(m): 248 | with torch.no_grad(): 249 | # if hasattr(m, 'weight'): 250 | if isinstance(m, nn.Linear): 251 | num_input = 3 252 | m.weight.uniform_(-1 / num_input, 1 / num_input) 253 | 254 | 255 | class EmbeddingPiGAN128(nn.Module): 256 | """Smaller architecture that has an additional cube of embeddings. 
Often gives better fine details.""" 257 | 258 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=128, output_dim=1, device=None): 259 | super().__init__() 260 | self.device = device 261 | self.input_dim = input_dim 262 | self.z_dim = z_dim 263 | self.hidden_dim = hidden_dim 264 | self.output_dim = output_dim 265 | 266 | self.network = nn.ModuleList([ 267 | FiLMLayer(32 + 3, hidden_dim), 268 | FiLMLayer(hidden_dim, hidden_dim), 269 | FiLMLayer(hidden_dim, hidden_dim), 270 | FiLMLayer(hidden_dim, hidden_dim), 271 | FiLMLayer(hidden_dim, hidden_dim), 272 | FiLMLayer(hidden_dim, hidden_dim), 273 | FiLMLayer(hidden_dim, hidden_dim), 274 | FiLMLayer(hidden_dim, hidden_dim), 275 | ]) 276 | print(self.network) 277 | 278 | self.final_layer = nn.Linear(hidden_dim, 1) 279 | 280 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim) 281 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3)) 282 | 283 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2) 284 | 285 | self.network.apply(frequency_init(25)) 286 | self.final_layer.apply(frequency_init(25)) 287 | self.color_layer_sine.apply(frequency_init(25)) 288 | self.color_layer_linear.apply(frequency_init(25)) 289 | self.network[0].apply(modified_first_sine_init) 290 | 291 | self.spatial_embeddings = nn.Parameter(torch.randn(1, 32, 96, 96, 96)*0.01) 292 | 293 | # !! Important !! Set this value to the expected side-length of your scene. e.g. for for faces, heads usually fit in 294 | # a box of side-length 0.24, since the camera has such a narrow FOV. For other scenes, with higher FOV, probably needs to be bigger. 295 | self.gridwarper = UniformBoxWarp(0.24) 296 | 297 | def forward(self, input, z, ray_directions, **kwargs): 298 | frequencies, phase_shifts = self.mapping_network(z) 299 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs) 300 | 301 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs): 302 | frequencies = frequencies*15 + 30 303 | 304 | input = self.gridwarper(input) 305 | shared_features = sample_from_3dgrid(input, self.spatial_embeddings) 306 | x = torch.cat([shared_features, input], -1) 307 | 308 | for index, layer in enumerate(self.network): 309 | start = index * self.hidden_dim 310 | end = (index+1) * self.hidden_dim 311 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) 312 | 313 | sigma = self.final_layer(x) 314 | rbg = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:]) 315 | rbg = torch.sigmoid(self.color_layer_linear(rbg)) 316 | 317 | return torch.cat([rbg, sigma], dim=-1) 318 | 319 | 320 | 321 | class EmbeddingPiGAN256(EmbeddingPiGAN128): 322 | def __init__(self, *args, **kwargs): 323 | super().__init__(*args, **kwargs, hidden_dim=256) 324 | self.spatial_embeddings = nn.Parameter(torch.randn(1, 32, 64, 64, 64)*0.1) 325 | 326 | 327 | 328 | 329 | 330 | class GeoSIREN(nn.Module): 331 | """Primary SIREN architecture used in pi-GAN generators.""" 332 | 333 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None, last_op=None, scale=1.): 334 | super().__init__() 335 | self.device = device 336 | self.input_dim = input_dim 337 | self.z_dim = z_dim 338 | self.hidden_dim = hidden_dim 339 | self.output_dim = output_dim 340 | self.scale = scale 341 | self.last_op = last_op 342 | 343 | self.network = nn.ModuleList([ 344 | 
FiLMLayer(input_dim, hidden_dim), 345 | FiLMLayer(hidden_dim, hidden_dim), 346 | FiLMLayer(hidden_dim, hidden_dim), 347 | FiLMLayer(hidden_dim, hidden_dim), 348 | FiLMLayer(hidden_dim, hidden_dim), 349 | FiLMLayer(hidden_dim, hidden_dim), 350 | FiLMLayer(hidden_dim, hidden_dim), 351 | FiLMLayer(hidden_dim, hidden_dim), 352 | ]) 353 | self.final_layer = nn.Linear(hidden_dim, output_dim) 354 | 355 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim) 356 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3), nn.Sigmoid()) 357 | 358 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2) 359 | 360 | self.network.apply(frequency_init(25)) 361 | self.final_layer.apply(frequency_init(25)) 362 | self.color_layer_sine.apply(frequency_init(25)) 363 | self.color_layer_linear.apply(frequency_init(25)) 364 | self.network[0].apply(first_layer_film_sine_init) 365 | 366 | def forward(self, input, z, **kwargs): 367 | frequencies, phase_shifts = self.mapping_network(z) 368 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, **kwargs) 369 | 370 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, **kwargs): 371 | frequencies = frequencies*15 + 30 372 | 373 | x = input 374 | 375 | for index, layer in enumerate(self.network): 376 | start = index * self.hidden_dim 377 | end = (index+1) * self.hidden_dim 378 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end]) 379 | 380 | sigma = self.final_layer(x) 381 | 382 | if self.last_op is not None: 383 | sigma = self.last_op(sigma) 384 | sigma = sigma*self.scale 385 | 386 | return sigma #torch.cat([sigma], dim=-1) 387 | 388 | 389 | -------------------------------------------------------------------------------- /lib/models/smplx.py: -------------------------------------------------------------------------------- 1 | """ 2 | original from https://github.com/vchoutas/smplx 3 | modified by Vassilis and Yao 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | import pickle 10 | import torch.nn.functional as F 11 | import os 12 | import yaml 13 | 14 | from .lbs import Struct, to_tensor, to_np, lbs, vertices2landmarks, JointsFromVerticesSelector, find_dynamic_lmk_idx_and_bcoords 15 | 16 | ## SMPLX 17 | J14_NAMES = [ 18 | 'right_ankle', 19 | 'right_knee', 20 | 'right_hip', 21 | 'left_hip', 22 | 'left_knee', 23 | 'left_ankle', 24 | 'right_wrist', 25 | 'right_elbow', 26 | 'right_shoulder', 27 | 'left_shoulder', 28 | 'left_elbow', 29 | 'left_wrist', 30 | 'neck', 31 | 'head', 32 | ] 33 | SMPLX_names = ['pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee', 'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot', 'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'jaw', 'left_eye_smplx', 'right_eye_smplx', 'left_index1', 'left_index2', 'left_index3', 'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1', 'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3', 'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1', 'right_index2', 'right_index3', 'right_middle1', 'right_middle2', 'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3', 'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1', 'right_thumb2', 'right_thumb3', 'right_eye_brow1', 'right_eye_brow2', 'right_eye_brow3', 'right_eye_brow4', 'right_eye_brow5', 'left_eye_brow5', 
'left_eye_brow4', 'left_eye_brow3', 'left_eye_brow2', 'left_eye_brow1', 'nose1', 'nose2', 'nose3', 'nose4', 'right_nose_2', 'right_nose_1', 'nose_middle', 'left_nose_1', 'left_nose_2', 'right_eye1', 'right_eye2', 'right_eye3', 'right_eye4', 'right_eye5', 'right_eye6', 'left_eye4', 'left_eye3', 'left_eye2', 'left_eye1', 'left_eye6', 'left_eye5', 'right_mouth_1', 'right_mouth_2', 'right_mouth_3', 'mouth_top', 'left_mouth_3', 'left_mouth_2', 'left_mouth_1', 'left_mouth_5', 'left_mouth_4', 'mouth_bottom', 'right_mouth_4', 'right_mouth_5', 'right_lip_1', 'right_lip_2', 'lip_top', 'left_lip_2', 'left_lip_1', 'left_lip_3', 'lip_bottom', 'right_lip_3', 'right_contour_1', 'right_contour_2', 'right_contour_3', 'right_contour_4', 'right_contour_5', 'right_contour_6', 'right_contour_7', 'right_contour_8', 'contour_middle', 'left_contour_8', 'left_contour_7', 'left_contour_6', 'left_contour_5', 'left_contour_4', 'left_contour_3', 'left_contour_2', 'left_contour_1', 'head_top', 'left_big_toe', 'left_ear', 'left_eye', 'left_heel', 'left_index', 'left_middle', 'left_pinky', 'left_ring', 'left_small_toe', 'left_thumb', 'nose', 'right_big_toe', 'right_ear', 'right_eye', 'right_heel', 'right_index', 'right_middle', 'right_pinky', 'right_ring', 'right_small_toe', 'right_thumb'] 34 | extra_names = ['head_top', 'left_big_toe', 'left_ear', 'left_eye', 'left_heel', 'left_index', 'left_middle', 'left_pinky', 'left_ring', 'left_small_toe', 'left_thumb', 'nose', 'right_big_toe', 'right_ear', 'right_eye', 'right_heel', 'right_index', 'right_middle', 'right_pinky', 'right_ring', 'right_small_toe', 'right_thumb'] 35 | SMPLX_names += extra_names 36 | 37 | part_indices = {} 38 | part_indices['body'] = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 39 | 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 123, 40 | 124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143]) 41 | part_indices['torso'] = np.array([ 0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 42 | 19, 22, 23, 24, 55, 56, 57, 58, 59, 76, 77, 78, 79, 43 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 44 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 45 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 46 | 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 47 | 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144]) 48 | part_indices['head'] = np.array([ 12, 15, 22, 23, 24, 55, 56, 57, 58, 59, 60, 61, 62, 49 | 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 50 | 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 51 | 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 52 | 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 53 | 115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 134, 136, 54 | 137]) 55 | part_indices['face'] = np.array([ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 56 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 57 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 58 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 59 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 60 | 119, 120, 121, 122]) 61 | part_indices['upper'] = np.array([ 12, 13, 14, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 62 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 63 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 64 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 65 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 66 | 119, 120, 121, 122]) 67 | part_indices['hand'] = np.array([ 20, 21, 25, 26, 27, 
28, 29, 30, 31, 32, 33, 34, 35, 68 | 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 69 | 49, 50, 51, 52, 53, 54, 128, 129, 130, 131, 133, 139, 140, 70 | 141, 142, 144]) 71 | part_indices['left_hand'] = np.array([ 20, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 72 | 37, 38, 39, 128, 129, 130, 131, 133]) 73 | part_indices['right_hand'] = np.array([ 21, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 74 | 52, 53, 54, 139, 140, 141, 142, 144]) 75 | # kinematic tree 76 | head_kin_chain = [15,12,9,6,3,0] 77 | 78 | #--smplx joints 79 | # 00 - Global 80 | # 01 - L_Thigh 81 | # 02 - R_Thigh 82 | # 03 - Spine 83 | # 04 - L_Calf 84 | # 05 - R_Calf 85 | # 06 - Spine1 86 | # 07 - L_Foot 87 | # 08 - R_Foot 88 | # 09 - Spine2 89 | # 10 - L_Toes 90 | # 11 - R_Toes 91 | # 12 - Neck 92 | # 13 - L_Shoulder 93 | # 14 - R_Shoulder 94 | # 15 - Head 95 | # 16 - L_UpperArm 96 | # 17 - R_UpperArm 97 | # 18 - L_ForeArm 98 | # 19 - R_ForeArm 99 | # 20 - L_Hand 100 | # 21 - R_Hand 101 | # 22 - Jaw 102 | # 23 - L_Eye 103 | # 24 - R_Eye 104 | 105 | class SMPLX(nn.Module): 106 | """ 107 | Given smplx parameters, this class generates a differentiable SMPLX function 108 | which outputs a mesh and 3D joints 109 | """ 110 | def __init__(self, config): 111 | super(SMPLX, self).__init__() 112 | print("creating the SMPLX Decoder") 113 | ss = np.load(config.smplx_model_path, allow_pickle=True) 114 | smplx_model = Struct(**ss) 115 | 116 | self.dtype = torch.float32 117 | self.register_buffer('faces_tensor', to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long)) 118 | # The vertices of the template model 119 | self.register_buffer('v_template', to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)) 120 | # The shape components and expression 121 | # expression space is the same as FLAME 122 | shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype) 123 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2) 124 | self.register_buffer('shapedirs', shapedirs) 125 | # The pose components 126 | num_pose_basis = smplx_model.posedirs.shape[-1] 127 | posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T 128 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype)) 129 | self.register_buffer('J_regressor', to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)) 130 | parents = to_tensor(to_np(smplx_model.kintree_table[0])).long(); parents[0] = -1 131 | self.register_buffer('parents', parents) 132 | self.register_buffer('lbs_weights', to_tensor(to_np(smplx_model.weights), dtype=self.dtype)) 133 | # for face keypoints 134 | self.register_buffer('lmk_faces_idx', torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)) 135 | self.register_buffer('lmk_bary_coords', torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype)) 136 | self.register_buffer('dynamic_lmk_faces_idx', torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long)) 137 | self.register_buffer('dynamic_lmk_bary_coords', torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype)) 138 | # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks 139 | self.register_buffer('head_kin_chain', torch.tensor(head_kin_chain, dtype=torch.long)) 140 | 141 | self.n_shape = config.n_shape 142 | self.n_pose = num_pose_basis 143 | 144 | #-- initialize parameters 145 | # shape and expression 146 | self.register_buffer('shape_params', nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False)) 147 | 
self.register_buffer('expression_params', nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False)) 148 | # pose: represented as rotation matrx [number of joints, 3, 3] 149 | self.register_buffer('global_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False)) 150 | self.register_buffer('head_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False)) 151 | self.register_buffer('neck_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False)) 152 | self.register_buffer('jaw_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False)) 153 | self.register_buffer('eye_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(2,1,1), requires_grad=False)) 154 | self.register_buffer('body_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(21,1,1), requires_grad=False)) 155 | self.register_buffer('left_hand_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15,1,1), requires_grad=False)) 156 | self.register_buffer('right_hand_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15,1,1), requires_grad=False)) 157 | 158 | if config.extra_joint_path: 159 | self.extra_joint_selector = JointsFromVerticesSelector( 160 | fname=config.extra_joint_path) 161 | self.use_joint_regressor = True 162 | self.keypoint_names = SMPLX_names 163 | if self.use_joint_regressor: 164 | with open(config.j14_regressor_path, 'rb') as f: 165 | j14_regressor = pickle.load(f, encoding='latin1') 166 | source = [] 167 | target = [] 168 | for idx, name in enumerate(self.keypoint_names): 169 | if name in J14_NAMES: 170 | source.append(idx) 171 | target.append(J14_NAMES.index(name)) 172 | source = np.asarray(source) 173 | target = np.asarray(target) 174 | self.register_buffer('source_idxs', torch.from_numpy(source)) 175 | self.register_buffer('target_idxs', torch.from_numpy(target)) 176 | joint_regressor = torch.from_numpy( 177 | j14_regressor).to(dtype=torch.float32) 178 | self.register_buffer('extra_joint_regressor', joint_regressor) 179 | self.part_indices = part_indices 180 | 181 | def forward(self, shape_params=None, expression_params=None, 182 | global_pose=None, body_pose=None, 183 | jaw_pose=None, eye_pose=None, 184 | left_hand_pose=None, right_hand_pose=None, full_pose=None, 185 | offset=None, transl=None, return_T = False): 186 | """ 187 | Args: 188 | shape_params: [N, number of shape parameters] 189 | expression_params: [N, number of expression parameters] 190 | global_pose: pelvis pose, [N, 1, 3, 3] 191 | body_pose: [N, 21, 3, 3] 192 | jaw_pose: [N, 1, 3, 3] 193 | eye_pose: [N, 2, 3, 3] 194 | left_hand_pose: [N, 15, 3, 3] 195 | right_hand_pose: [N, 15, 3, 3] 196 | Returns: 197 | vertices: [N, number of vertices, 3] 198 | landmarks: [N, number of landmarks (68 face keypoints), 3] 199 | joints: [N, number of smplx joints (145), 3] 200 | """ 201 | if shape_params is None: 202 | batch_size = full_pose.shape[0] 203 | shape_params = self.shape_params.expand(batch_size, -1) 204 | else: 205 | batch_size = shape_params.shape[0] 206 | if expression_params is None: 207 | expression_params = self.expression_params.expand(batch_size, -1) 208 | 209 | if full_pose is None: 210 | if global_pose is None: 211 | global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) 212 | if body_pose is None: 213 | body_pose = self.body_pose.unsqueeze(0).expand(batch_size, 
-1, -1, -1) 214 | if jaw_pose is None: 215 | jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) 216 | if eye_pose is None: 217 | eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) 218 | if left_hand_pose is None: 219 | left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) 220 | if right_hand_pose is None: 221 | right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1) 222 | full_pose = torch.cat([global_pose, body_pose, 223 | jaw_pose, eye_pose, 224 | left_hand_pose, right_hand_pose], dim=1) 225 | 226 | shape_components = torch.cat([shape_params, expression_params], dim=1) 227 | if offset is not None: 228 | if len(offset.shape) == 2: 229 | template_vertices = (self.v_template+offset).unsqueeze(0).expand(batch_size, -1, -1) 230 | else: 231 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) + offset 232 | else: 233 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) 234 | 235 | # smplx 236 | vertices, joints, A, T, shape_offsets, pose_offsets = lbs(shape_components, full_pose, template_vertices, 237 | self.shapedirs, self.posedirs, 238 | self.J_regressor, self.parents, 239 | self.lbs_weights, dtype=self.dtype, 240 | pose2rot = False) 241 | # face dynamic landmarks 242 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1) 243 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1) 244 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = ( 245 | find_dynamic_lmk_idx_and_bcoords( 246 | vertices, full_pose, 247 | self.dynamic_lmk_faces_idx, 248 | self.dynamic_lmk_bary_coords, 249 | self.head_kin_chain) 250 | ) 251 | lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) 252 | lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1) 253 | landmarks = vertices2landmarks(vertices, self.faces_tensor, 254 | lmk_faces_idx, 255 | lmk_bary_coords) 256 | 257 | final_joint_set = [joints, landmarks] 258 | if hasattr(self, 'extra_joint_selector'): 259 | # Add any extra joints that might be needed 260 | extra_joints = self.extra_joint_selector(vertices, self.faces_tensor) 261 | final_joint_set.append(extra_joints) 262 | # Create the final joint set 263 | joints = torch.cat(final_joint_set, dim=1) 264 | if self.use_joint_regressor: 265 | reg_joints = torch.einsum( 266 | 'ji,bik->bjk', self.extra_joint_regressor, vertices) 267 | joints[:, self.source_idxs] = ( 268 | joints[:, self.source_idxs].detach() * 0.0 + 269 | reg_joints[:, self.target_idxs] * 1.0 270 | ) 271 | ### translate z. 
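        # [Editor's note -- descriptive comment, not original code] `A` holds the
        # per-joint 4x4 rigid transforms and `T` the per-vertex skinning
        # transforms returned by lbs(). When a global translation `transl` is
        # given, it is added to the joints and vertices below and, for
        # return_T=True, also folded into the translation column of A and T,
        # so that points transformed with these matrices end up in the same
        # translated space as the returned mesh.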
272 | # original: -0.3 ~ 0.5 273 | # now: + 0.5 274 | # vertices[:,:,-1] = vertices[:,:,-1] + 1 275 | if transl is not None: 276 | joints = joints + transl.unsqueeze(dim=1) 277 | vertices = vertices + transl.unsqueeze(dim=1) 278 | if return_T: 279 | A[..., :3, 3] += transl.unsqueeze(dim=1) 280 | T[..., :3, 3] += transl.unsqueeze(dim=1) 281 | 282 | if return_T: 283 | return vertices, landmarks, joints, A, T, shape_offsets, pose_offsets 284 | else: 285 | return vertices, landmarks, joints 286 | 287 | def pose_abs2rel(self, global_pose, body_pose, abs_joint = 'head'): 288 | ''' change absolute pose to relative pose 289 | Basic knowledge for SMPLX kinematic tree: 290 | absolute pose = parent pose * relative pose 291 | Here, pose must be represented as rotation matrix (batch_sizexnx3x3) 292 | ''' 293 | if abs_joint == 'head': 294 | # Pelvis -> Spine 1, 2, 3 -> Neck -> Head 295 | kin_chain = [15, 12, 9, 6, 3, 0] 296 | elif abs_joint == 'right_wrist': 297 | # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder 298 | # -> right elbow -> right wrist 299 | kin_chain = [21, 19, 17, 14, 9, 6, 3, 0] 300 | elif abs_joint == 'left_wrist': 301 | # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder 302 | # -> Left elbow -> Left wrist 303 | kin_chain = [20, 18, 16, 13, 9, 6, 3, 0] 304 | else: 305 | raise NotImplementedError( 306 | f'pose_abs2rel does not support: {abs_joint}') 307 | 308 | batch_size = global_pose.shape[0] 309 | dtype = global_pose.dtype 310 | device = global_pose.device 311 | full_pose = torch.cat([global_pose, body_pose], dim=1) 312 | rel_rot_mat = torch.eye( 313 | 3, device=device, 314 | dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1) 315 | for idx in kin_chain[1:]: 316 | rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat) 317 | 318 | # This contains the absolute pose of the parent 319 | abs_parent_pose = rel_rot_mat.detach() 320 | # Let's assume that in the input this specific joint is predicted as an absolute value 321 | abs_joint_pose = body_pose[:, kin_chain[0] - 1] 322 | # abs_head = parents(abs_neck) * rel_head ==> rel_head = abs_neck.T * abs_head 323 | rel_joint_pose = torch.matmul( 324 | abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2), 325 | abs_joint_pose.reshape(-1, 3, 3)) 326 | # Replace the new relative pose 327 | body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose 328 | return body_pose 329 | -------------------------------------------------------------------------------- /lib/utils/camera_util.py: -------------------------------------------------------------------------------- 1 | ''' Projections 2 | Pinhole Camera Model 3 | 4 | Ref: 5 | http://web.stanford.edu/class/cs231a/lectures/lecture2_camera_models.pdf 6 | https://github.com/YadiraF/face3d/blob/master/face3d/mesh/transform.py 7 | https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer 8 | https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function 9 | https://www.scratchapixel.com/lessons/3d-basic-rendering/3d-viewing-pinhole-camera/implementing-virtual-pinhole-camera 10 | https://kornia.readthedocs.io/en/v0.1.2/pinhole.html 11 | 12 | Pinhole Camera Model: 13 | 1. world space to camera space: 14 | Normally, the object is represented in the world reference system, 15 | need to first map it into camera system/coordinates/space 16 | P_camera = [R|t]P_world 17 | [R|t]: camera to world transformation matrix, defines the camera position and oritation 18 | Given: where the camera is in the world system. 
(extrinsic/external) 19 | represented by: 20 | look at camera: eye position, at, up direction 21 | 2. camera space to image space: 22 | Then project the obj into image plane 23 | P_image = K[I|0]P_camera 24 | K = [[fx, 0, cx], 25 | [ 0, fy, cy], 26 | [ 0, 0, 1]] 27 | Given: settings of te camera. (intrinsic/internal) 28 | represented by: 29 | focal lengh, image/film size 30 | or fov, near, far 31 | Angle of View: computed from the focal length and the film size parameters. 32 | perspective projection: 33 | x' = f*x/z; y' = f*y/z 34 | 35 | Finally: 36 | P_image = MXP_world 37 | M = K[R T] 38 | homo: 4x4 39 | ''' 40 | import torch 41 | import numpy as np 42 | import torch.nn.functional as F 43 | 44 | # --------------------- 0. camera projection 45 | def transform(points, intrinsic, extrinsic): 46 | ''' Perspective projection 47 | Args: 48 | points: [bz, np, 3] 49 | intrinsic: [bz, 3, 3] 50 | extrinsic: [bz, 3, 4] 51 | ''' 52 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:] 53 | points = torch.matmul(points, intrinsic.transpose(1,2)) 54 | 55 | ## if homo 56 | # vertices_homo = camera_util.homogeneous(mesh['vertices']) 57 | # transformed_verts = torch.matmul(vertices_homo, extrinsic.transpose(1,2)) 58 | return points 59 | 60 | def perspective_project(points, focal=None, image_hw=None, extrinsic=None, transl=None): 61 | ''' points from world space to ndc space (for pytorch3d rendering) 62 | TODO 63 | ''' 64 | batch_size = points.shape[0] 65 | device = points.device 66 | dtype = points.dtype 67 | # non homo 68 | if points.shape[-1] == 3: 69 | if transl is not None: 70 | points = points + transl[:,None,:] 71 | if extrinsic is not None: 72 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:] 73 | if focal is not None: 74 | # import ipdb; ipdb.set_trace() 75 | if image_hw is not None: 76 | H, W = image_hw 77 | fx = 2*focal/H 78 | fy = 2*focal/W 79 | else: 80 | fx = fy = focal 81 | 82 | # 2/H is for normalization 83 | # intrinsic = torch.tensor( 84 | # [[fx, 0, 0], 85 | # [0, fy, 0], 86 | # [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 1, 1) 87 | intrinsic = torch.tensor( 88 | [[1, 0, 0], 89 | [0, 1, 0], 90 | [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 1, 1) 91 | intrinsic_xy = intrinsic[:,:2]*fx 92 | intrinsic_z = intrinsic[:,2:] 93 | intrinsic = torch.cat([intrinsic_xy, intrinsic_z], dim = 1) 94 | # if points.requires_grad: 95 | # import ipdb; ipdb.set_trace() 96 | points = torch.matmul(points, intrinsic.transpose(1,2)) 97 | # perspective distortion 98 | # points[:,:,:2] = points[:,:,:2]/(points[:,:,[2]]+1e-5) # inplace, has problem for gradient 99 | z = points[:,:,[2]] 100 | xy = points[:,:,:2]/(z+1e-6) 101 | points = torch.cat([xy, z], dim=-1) 102 | return points 103 | 104 | def perspective_project_inv(points, focal=None, image_hw=None, extrinsic=None, transl=None): 105 | ''' points from world space to ndc space (for pytorch3d rendering) 106 | TODO 107 | ''' 108 | batch_size = points.shape[0] 109 | device = points.device 110 | dtype = points.dtype 111 | # non homo 112 | if points.shape[-1] == 3: 113 | if focal is not None: 114 | # import ipdb; ipdb.set_trace() 115 | if image_hw is not None: 116 | H, W = image_hw 117 | fx = 2*focal/H 118 | fy = 2*focal/W 119 | else: 120 | fx = fy = focal 121 | # 2/H is for normalization 122 | intrinsic = torch.tensor( 123 | [[1, 0, 0], 124 | [0, 1, 0], 125 | [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 
1, 1) 126 | intrinsic_xy = intrinsic[:,:2]*fx 127 | intrinsic_z = intrinsic[:,2:] 128 | intrinsic = torch.cat([intrinsic_xy, intrinsic_z], dim = 1) 129 | 130 | z = points[:,:,[2]] 131 | xy = points[:,:,:2]*(z+1e-5) 132 | points = torch.cat([xy, z], dim=-1) 133 | 134 | intrinsic = torch.inverse(intrinsic) 135 | points = torch.matmul(points, intrinsic.transpose(1,2)) 136 | if transl is not None: 137 | points = points - transl[:,None,:] 138 | if extrinsic is not None: 139 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:] 140 | return points 141 | 142 | 143 | # TODO: homo 144 | 145 | # --------------------- 1. world space to camera space: 146 | def look_at(eye, at=[0, 0, 0], up=[0, 1, 0]): 147 | """ 148 | "Look at" transformation of vertices. 149 | standard camera space: 150 | camera located at the origin. 151 | looking down negative z-axis. 152 | vertical vector is y-axis. 153 | Xcam = R(X - C) 154 | Homo: [[R, -RC], 155 | [0, 1]] 156 | Args: 157 | eye: [3,] the XYZ world space position of the camera. 158 | at: [3,] a position along the center of the camera's gaze. 159 | up: [3,] up direction 160 | Returns: 161 | extrinsic: R, t 162 | """ 163 | device = eye.device 164 | # if list or tuple convert to numpy array 165 | if isinstance(at, list) or isinstance(at, tuple): 166 | at = torch.tensor(at, dtype=torch.float32, device=device) 167 | # if numpy array convert to tensor 168 | elif isinstance(at, np.ndarray): 169 | at = torch.from_numpy(at).to(device) 170 | elif torch.is_tensor(at): 171 | at.to(device) 172 | 173 | if isinstance(up, list) or isinstance(up, tuple): 174 | up = torch.tensor(up, dtype=torch.float32, device=device) 175 | elif isinstance(up, np.ndarray): 176 | up = torch.from_numpy(up).to(device) 177 | elif torch.is_tensor(up): 178 | up.to(device) 179 | 180 | if isinstance(eye, list) or isinstance(eye, tuple): 181 | eye = torch.tensor(eye, dtype=torch.float32, device=device) 182 | elif isinstance(eye, np.ndarray): 183 | eye = torch.from_numpy(eye).to(device) 184 | elif torch.is_tensor(eye): 185 | eye = eye.to(device) 186 | 187 | batch_size = eye.shape[0] 188 | if eye.ndimension() == 1: 189 | eye = eye[None, :].repeat(batch_size, 1) 190 | if at.ndimension() == 1: 191 | at = at[None, :].repeat(batch_size, 1) 192 | if up.ndimension() == 1: 193 | up = up[None, :].repeat(batch_size, 1) 194 | 195 | # create new axes 196 | # eps is chosen as 0.5 to match the chainer version 197 | z_axis = F.normalize(at - eye, eps=1e-5) 198 | x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5) 199 | y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5) 200 | # create rotation matrix: [bs, 3, 3] 201 | r = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1) 202 | 203 | camera_R = r 204 | camera_t = eye 205 | # Note: x_new = R(x - t) 206 | return camera_R, camera_t 207 | 208 | def get_extrinsic(R, t, homo=False): 209 | batch_size = R.shape[0] 210 | device = R.device 211 | extrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1) 212 | extrinsic[:, :3, :3] = R 213 | extrinsic[:, :3, [-1]] = -torch.matmul(R, t[:,:,None]) 214 | if homo: 215 | return extrinsic 216 | else: 217 | return extrinsic[:,:3,:] 218 | 219 | #------------------------- 2. 
camera space to image space (perspective projection): 220 | def get_intrinsic(focal, H, W, cx=0., cy=0., homo=False, batch_size=1., device='cuda:0'): 221 | ''' 222 | given different control parameteres 223 | TODO: generate intrinsic matrix from other inputs 224 | P = np.array([[near/right, 0, 0, 0], 225 | [0, near/top, 0, 0], 226 | [0, 0, -(far+near)/(far-near), -2*far*near/(far-near)], 227 | [0, 0, -1, 0]]) 228 | ''' 229 | intrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1) 230 | f = focal 231 | K = torch.tensor( 232 | [[f/(H/2), 0, cx], 233 | [0, f/(W/2), cy], 234 | [0, 0, 1]], device=device, dtype=torch.float32)[None,...] 235 | intrinsic[:, :3, :3] = K 236 | if homo: 237 | return intrinsic 238 | else: 239 | return intrinsic[:,:3,:3] 240 | 241 | 242 | #------------------------- composite intrinsic and extrinsic into one matrix 243 | def compose_matrix(K, R, t): 244 | ''' 245 | Args: 246 | K: [N, 3, 3] 247 | R: [N, 3, 3] 248 | t: [N, 3] 249 | Returns: 250 | P: [N, 4, 4] 251 | ## test if homo is the same as no homo: 252 | batch_size, nv, _ = trans_verts.shape 253 | trans_verts = torch.cat([trans_verts, torch.ones([batch_size, nv, 1], device=trans_verts.device)], dim=-1) 254 | trans_verts = torch.matmul(trans_verts, P.transpose(1,2))[:,:,:3] 255 | (tested, the same) 256 | # t = -Rt 257 | ''' 258 | batch_size = K.shape[0] 259 | device = K.device 260 | intrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1) 261 | extrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1) 262 | intrinsic[:, :3, :3] = K 263 | extrinsic[:, :3, :3] = R 264 | # import ipdb; ipdb.set_trace() 265 | extrinsic[:, :3, [-1]] = -torch.matmul(R, t[:,:,None]) 266 | P = torch.matmul(intrinsic, extrinsic) 267 | return P, intrinsic, extrinsic 268 | 269 | # def perspective_project(vertices, K, R, t, dist_coeffs, orig_size, eps=1e-9): 270 | # ''' 271 | # Calculate projective transformation of vertices given a projection matrix 272 | # Input parameters: 273 | # K: batch_size * 3 * 3 intrinsic camera matrix 274 | # R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters 275 | # dist_coeffs: vector of distortion coefficients 276 | # orig_size: original size of image captured by the camera 277 | # Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are the coordinates of the projection in 278 | # pixels and z is the depth 279 | # ''' 280 | # # instead of P*x we compute x'*P' 281 | # vertices = torch.matmul(vertices, R.transpose(2,1)) + t 282 | # x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2] 283 | # x_ = x / (z + eps) 284 | # y_ = y / (z + eps) 285 | 286 | # # Get distortion coefficients from vector 287 | # k1 = dist_coeffs[:, None, 0] 288 | # k2 = dist_coeffs[:, None, 1] 289 | # p1 = dist_coeffs[:, None, 2] 290 | # p2 = dist_coeffs[:, None, 3] 291 | # k3 = dist_coeffs[:, None, 4] 292 | 293 | # # we use x_ for x' and x__ for x'' etc. 294 | # r = torch.sqrt(x_ ** 2 + y_ ** 2) 295 | # x__ = x_*(1 + k1*(r**2) + k2*(r**4) + k3*(r**6)) + 2*p1*x_*y_ + p2*(r**2 + 2*x_**2) 296 | # y__ = y_*(1 + k1*(r**2) + k2*(r**4) + k3 *(r**6)) + p1*(r**2 + 2*y_**2) + 2*p2*x_*y_ 297 | # vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1) 298 | # vertices = torch.matmul(vertices, K.transpose(1,2)) 299 | # u, v = vertices[:, :, 0], vertices[:, :, 1] 300 | # v = orig_size - v 301 | # # map u,v from [0, img_size] to [-1, 1] to use by the renderer 302 | # u = 2 * (u - orig_size / 2.) / orig_size 303 | # v = 2 * (v - orig_size / 2.) 
/ orig_size 304 | # vertices = torch.stack([u, v, z], dim=-1) 305 | # return vertices 306 | 307 | def to_homo(points): 308 | ''' 309 | points: [N, num of points, 2/3] 310 | ''' 311 | batch_size, num, _ = points.shape 312 | points = torch.cat([points, torch.ones([batch_size, num, 1], device=points.device, dtype=points.dtype)], dim=-1) 313 | return points 314 | 315 | def homogeneous(points): 316 | """ 317 | Concat 1 to each point 318 | :param points (..., 3) 319 | :return (..., 4) 320 | """ 321 | return F.pad(points, (0, 1), "constant", 1.0) 322 | 323 | 324 | def batch_orth_proj(X, camera): 325 | ''' orthgraphic projection 326 | X: 3d vertices, [bz, n_point, 3] 327 | camera: scale and translation, [bz, 3], [scale, tx, ty] 328 | ''' 329 | camera = camera.clone().view(-1, 1, 3) 330 | X_trans = X[:, :, :2] + camera[:, :, 1:] 331 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2) 332 | shape = X_trans.shape 333 | Xn = (camera[:, :, 0:1] * X_trans) 334 | return Xn -------------------------------------------------------------------------------- /lib/utils/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | import argparse 3 | import yaml 4 | import os 5 | 6 | cfg = CN() 7 | 8 | workdir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) 9 | 10 | # project settings 11 | cfg.output_dir = '' 12 | cfg.group = 'test' 13 | cfg.exp_name = None 14 | cfg.device = 'cuda:0' 15 | 16 | # load models 17 | cfg.ckpt_path = None 18 | cfg.nerf_ckpt_path = 'exps/mpiis/DSC_7157/model.tar' 19 | cfg.mesh_ckpt_path = '' 20 | cfg.pose_ckpt_path = '' 21 | 22 | # ---------------------------------------------------------------------------- # 23 | # Options for SCARF model 24 | # ---------------------------------------------------------------------------- # 25 | cfg.depth_std = 0. 
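# [Editor's note -- illustrative, not part of the original config] The yacs
# defaults in this module are intended to be overridden by the YAML files
# shipped under configs/ (see update_cfg() / parse_args() at the bottom of
# this file). A minimal sketch, assuming the repository root is on PYTHONPATH
# and the chosen YAML path exists:
#
#   from lib.utils.config import get_cfg_defaults, update_cfg
#   cfg = get_cfg_defaults()                               # clone of the defaults
#   cfg = update_cfg(cfg, 'configs/exp/stage_0_nerf.yml')  # overrides matching keys only
#   print(cfg.use_nerf, cfg.train.lr)
#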
26 | cfg.opt_nerf_pose = True 27 | cfg.opt_beta = True 28 | cfg.opt_mesh = True 29 | cfg.sample_patch_rays = False 30 | cfg.sample_patch_size = 32 31 | 32 | # merf model 33 | cfg.mesh_offset_scale = 0.01 34 | cfg.exclude_hand = False 35 | cfg.use_perspective = False 36 | cfg.cam_mode = 'orth' 37 | cfg.use_mesh = True 38 | cfg.mesh_only_body = False 39 | cfg.use_highres_smplx = True 40 | cfg.freqs_tex = 10 41 | cfg.tex_detach_verts = False 42 | cfg.tex_network = 'siren' # siren 43 | cfg.use_nerf = True 44 | cfg.use_fine = True 45 | cfg.share_fine = False 46 | cfg.freqs_xyz = 10 47 | cfg.freqs_dir = 4 48 | cfg.use_view = False 49 | # skinning 50 | cfg.use_outer_mesh = False 51 | cfg.lbs_map = False 52 | cfg.k_neigh = 1 53 | cfg.weighted_neigh = False 54 | cfg.use_valid = False 55 | cfg.dis_threshold = 0.01 56 | ## deformation code 57 | cfg.use_deformation = False #'opt_code' 58 | cfg.deformation_type = '' #'opt_code' 59 | cfg.deformation_dim = 0 60 | cfg.latent_dim = 0 61 | cfg.pose_dim = 69 62 | ## appearance code 63 | cfg.use_appearance = False 64 | cfg.appearance_type = '' #'opt_code' # or pose 65 | cfg.appearance_dim = 0 # if pose, should be 3 66 | ## mesh tex code 67 | cfg.use_texcond = False 68 | cfg.texcond_type = '' #'opt_code' # or pose 69 | cfg.texcond_dim = 0 # if pose, should be 3 70 | 71 | # nerf rendering 72 | cfg.n_samples = 64 73 | cfg.n_importance = 32 74 | cfg.n_depth = 0 75 | cfg.chunk = 32*32 # chunk size to split the input to avoid OOM 76 | cfg.query_inside = False 77 | cfg.white_bkgd = True 78 | 79 | ####--------pose model 80 | cfg.opt_pose = False 81 | cfg.opt_exp = False 82 | cfg.opt_cam = False 83 | cfg.opt_appearance = False 84 | cfg.opt_focal = False 85 | 86 | # ---------------------------------------------------------------------------- # 87 | # Options for Training 88 | # ---------------------------------------------------------------------------- # 89 | cfg.train = CN() 90 | cfg.train.optimizer = 'adam' 91 | cfg.train.resume = True 92 | cfg.train.batch_size = 1 93 | cfg.train.max_epochs = 100000 94 | cfg.train.max_steps = 100*1000 95 | cfg.train.lr = 5e-4 96 | cfg.train.pose_lr = 5e-4 97 | cfg.train.tex_lr = 5e-4 98 | cfg.train.geo_lr = 5e-4 99 | cfg.train.decay_steps = [50] 100 | cfg.train.decay_gamma = 0.5 101 | cfg.train.precrop_iters = 400 102 | cfg.train.precrop_frac = 0.5 103 | cfg.train.coarse_iters = 0 # in the begining, only train coarse 104 | # logger 105 | cfg.train.log_dir = 'logs' 106 | cfg.train.log_steps = 10 107 | cfg.train.vis_dir = 'train_images' 108 | cfg.train.vis_steps = 200 109 | cfg.train.write_summary = True 110 | cfg.train.checkpoint_steps = 500 111 | cfg.train.val_steps = 200 112 | cfg.train.val_vis_dir = 'val_images' 113 | cfg.train.eval_steps = 5000 114 | 115 | # ---------------------------------------------------------------------------- # 116 | # Options for Dataset 117 | # ---------------------------------------------------------------------------- # 118 | cfg.dataset = CN() 119 | cfg.dataset.type = 'scarf' 120 | cfg.dataset.path = '' 121 | cfg.dataset.white_bg = True 122 | cfg.dataset.subjects = None 123 | cfg.dataset.n_subjects = 1 124 | cfg.dataset.image_size = 512 125 | cfg.dataset.num_workers = 4 126 | cfg.dataset.n_images = 1000000 127 | cfg.dataset.updated_params_path = '' 128 | cfg.dataset.load_gt_pose = True 129 | cfg.dataset.load_normal = False 130 | cfg.dataset.load_perspective = False 131 | 132 | # training setting 133 | cfg.dataset.train = CN() 134 | cfg.dataset.train.cam_list = [] 135 | cfg.dataset.train.frame_start = 
0 136 | cfg.dataset.train.frame_end = 10000 137 | cfg.dataset.train.frame_step = 4 138 | cfg.dataset.val = CN() 139 | cfg.dataset.val.cam_list = [] 140 | cfg.dataset.val.frame_start = 400 141 | cfg.dataset.val.frame_end = 500 142 | cfg.dataset.val.frame_step = 4 143 | cfg.dataset.test = CN() 144 | cfg.dataset.test.cam_list = [] 145 | cfg.dataset.test.frame_start = 400 146 | cfg.dataset.test.frame_end = 500 147 | cfg.dataset.test.frame_step = 4 148 | 149 | # ---------------------------------------------------------------------------- # 150 | # Options for losses 151 | # ---------------------------------------------------------------------------- # 152 | cfg.loss = CN() 153 | cfg.loss.w_rgb = 1. 154 | cfg.loss.w_patch_mrf = 0.#0005 155 | cfg.loss.w_patch_perceptual = 0.#0005 156 | cfg.loss.w_alpha = 0. 157 | cfg.loss.w_xyz = 1. 158 | cfg.loss.w_depth = 0. 159 | cfg.loss.w_depth_close = 0. 160 | cfg.loss.reg_depth = 0. 161 | cfg.loss.use_mse = False 162 | cfg.loss.mesh_w_rgb = 1. 163 | cfg.loss.mesh_w_normal = 0.1 164 | cfg.loss.mesh_w_mrf = 0. 165 | cfg.loss.mesh_w_perceptual = 0. 166 | cfg.loss.mesh_w_alpha = 0. 167 | cfg.loss.mesh_w_alpha_skin = 0. 168 | cfg.loss.skin_consistency_type = 'verts_all_mean' 169 | cfg.loss.mesh_skin_consistency = 0.001 170 | cfg.loss.mesh_inside_mask = 100. 171 | cfg.loss.nerf_hard = 0. 172 | cfg.loss.nerf_hard_scale = 1. 173 | cfg.loss.mesh_reg_wdecay = 0. 174 | # regs 175 | cfg.loss.geo_reg = True 176 | cfg.loss.reg_beta_l1 = 1e-4 177 | cfg.loss.reg_cam_l1 = 1e-4 178 | cfg.loss.reg_pose_l1 = 1e-4 179 | cfg.loss.reg_a_norm = 1e-4 180 | cfg.loss.reg_beta_temp = 1e-4 181 | cfg.loss.reg_cam_temp = 1e-4 182 | cfg.loss.reg_pose_temp = 1e-4 183 | cfg.loss.nerf_reg_dxyz_w = 1e-4 184 | ## 185 | cfg.loss.reg_lap_w = 1.0 186 | cfg.loss.reg_edge_w = 10.0 187 | cfg.loss.reg_normal_w = 0.01 188 | cfg.loss.reg_offset_w = 100. 189 | cfg.loss.reg_offset_w_face = 500. 190 | cfg.loss.reg_offset_w_body = 0. 191 | cfg.loss.use_new_edge_loss = False 192 | ## new 193 | cfg.loss.pose_reg = False 194 | cfg.loss.background_reg = False 195 | cfg.loss.nerf_reg_normal_w = 0. 
#0.01 196 | 197 | # ---------------------------------------------------------------------------- # 198 | # Options for Body model 199 | # ---------------------------------------------------------------------------- # 200 | cfg.data_dir = os.path.join(workdir, 'data') 201 | cfg.model = CN() 202 | cfg.model.highres_path = os.path.join(cfg.data_dir, 'subdiv_level_1') 203 | cfg.model.topology_path = os.path.join(cfg.data_dir, 'SMPL_X_template_FLAME_uv.obj') 204 | cfg.model.smplx_model_path = os.path.join(cfg.data_dir, 'SMPLX_NEUTRAL_2020.npz') 205 | cfg.model.extra_joint_path = os.path.join(cfg.data_dir, 'smplx_extra_joints.yaml') 206 | cfg.model.j14_regressor_path = os.path.join(cfg.data_dir, 'SMPLX_to_J14.pkl') 207 | cfg.model.mano_ids_path = os.path.join(cfg.data_dir, 'MANO_SMPLX_vertex_ids.pkl') 208 | cfg.model.flame_vertex_masks_path = os.path.join(cfg.data_dir, 'FLAME_masks.pkl') 209 | cfg.model.flame_ids_path = os.path.join(cfg.data_dir, 'SMPL-X__FLAME_vertex_ids.npy') 210 | cfg.model.n_shape = 10 211 | cfg.model.n_exp = 10 212 | 213 | def get_cfg_defaults(): 214 | """Get a yacs CfgNode object with default values for my_project.""" 215 | # Return a clone so that the defaults will not be altered 216 | # This is for the "local variable" use pattern 217 | return cfg.clone() 218 | 219 | def update_cfg(cfg, cfg_file): 220 | cfg.merge_from_file(cfg_file) 221 | return cfg.clone() 222 | 223 | def parse_args(): 224 | parser = argparse.ArgumentParser() 225 | parser.add_argument('--cfg', type=str, default = os.path.join(os.path.join(root_dir, 'configs/nerf_pl'), 'test.yaml'), help='cfg file path', ) 226 | parser.add_argument('--mode', type=str, default = 'train', help='mode: train, test') 227 | parser.add_argument('--random_beta', action="store_true", default = False, help='delete folders') 228 | parser.add_argument('--clean', action="store_true", default = False, help='delete folders') 229 | parser.add_argument('--debug', action="store_true", default = False, help='debug model') 230 | 231 | args = parser.parse_args() 232 | print(args, end='\n\n') 233 | 234 | cfg = get_cfg_defaults() 235 | if args.cfg is not None: 236 | cfg_file = args.cfg 237 | cfg = update_cfg(cfg, args.cfg) 238 | cfg.cfg_file = cfg_file 239 | cfg.mode = args.mode 240 | cfg.clean = args.clean 241 | cfg.debug = args.debug 242 | cfg.random_beta = args.random_beta 243 | return cfg 244 | -------------------------------------------------------------------------------- /lib/utils/lossfunc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import torchvision.models as models 5 | import torch.nn.functional as F 6 | from functools import reduce 7 | import scipy.sparse as sp 8 | import numpy as np 9 | from chumpy.utils import row, col 10 | 11 | def get_vert_connectivity(num_vertices, faces): 12 | """ 13 | Returns a sparse matrix (of size #verts x #verts) where each nonzero 14 | element indicates a neighborhood relation. For example, if there is a 15 | nonzero element in position (15,12), that means vertex 15 is connected 16 | by an edge to vertex 12. 17 | Adapted from https://github.com/mattloper/opendr/ 18 | """ 19 | 20 | vpv = sp.csc_matrix((num_vertices,num_vertices)) 21 | 22 | # for each column in the faces... 
23 | for i in range(3): 24 | IS = faces[:,i] 25 | JS = faces[:,(i+1)%3] 26 | data = np.ones(len(IS)) 27 | ij = np.vstack((row(IS.flatten()), row(JS.flatten()))) 28 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape) 29 | vpv = vpv + mtx + mtx.T 30 | return vpv 31 | 32 | def get_vertices_per_edge(num_vertices, faces): 33 | """ 34 | Returns an Ex2 array of adjacencies between vertices, where 35 | each element in the array is a vertex index. Each edge is included 36 | only once. If output of get_faces_per_edge is provided, this is used to 37 | avoid call to get_vert_connectivity() 38 | Adapted from https://github.com/mattloper/opendr/ 39 | """ 40 | 41 | vc = sp.coo_matrix(get_vert_connectivity(num_vertices, faces)) 42 | result = np.hstack((col(vc.row), col(vc.col))) 43 | result = result[result[:,0] < result[:,1]] # for uniqueness 44 | return result 45 | 46 | def relative_edge_loss(vertices1, vertices2, vertices_per_edge=None, faces=None, lossfunc=torch.nn.functional.mse_loss): 47 | """ 48 | Given two meshes of the same topology, returns the relative edge differences. 49 | 50 | """ 51 | 52 | if vertices_per_edge is None and faces is not None: 53 | vertices_per_edge = get_vertices_per_edge(len(vertices1), faces) 54 | elif vertices_per_edge is None and faces is None: 55 | raise ValueError("Either vertices_per_edge or faces must be specified") 56 | 57 | edges_for = lambda x: x[:, vertices_per_edge[:, 0], :] - x[:, vertices_per_edge[:, 1], :] 58 | return lossfunc(edges_for(vertices1), edges_for(vertices2)) 59 | 60 | 61 | def relative_laplacian_loss(mesh1, mesh2, lossfunc=torch.nn.functional.mse_loss): 62 | L = mesh2.laplacian_packed() 63 | L1 = L.mm(mesh1.verts_packed()) 64 | L2 = L.mm(mesh2.verts_packed()) 65 | loss = lossfunc(L1, L2) 66 | return loss 67 | 68 | def mesh_laplacian(meshes, method: str = "uniform"): 69 | r""" 70 | Computes the laplacian smoothing objective for a batch of meshes. 71 | This function supports three variants of Laplacian smoothing, 72 | namely with uniform weights("uniform"), with cotangent weights ("cot"), 73 | and cotangent curvature ("cotcurv").For more details read [1, 2]. 74 | 75 | Args: 76 | meshes: Meshes object with a batch of meshes. 77 | method: str specifying the method for the laplacian. 78 | Returns: 79 | loss: Average laplacian smoothing loss across the batch. 80 | Returns 0 if meshes contains no meshes or all empty meshes. 81 | 82 | Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3. 83 | The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors: 84 | for a uniform Laplacian, LuV[i] points to the centroid of its neighboring 85 | vertices, a cotangent Laplacian LcV[i] is known to be an approximation of 86 | the surface normal, while the curvature variant LckV[i] scales the normals 87 | by the discrete mean curvature. For vertex i, assume S[i] is the set of 88 | neighboring vertices to i, a_ij and b_ij are the "outside" angles in the 89 | two triangles connecting vertex v_i and its neighboring vertex v_j 90 | for j in S[i], as seen in the diagram below. 91 | 92 | .. 
code-block:: python 93 | 94 | a_ij 95 | /\ 96 | / \ 97 | / \ 98 | / \ 99 | v_i /________\ v_j 100 | \ / 101 | \ / 102 | \ / 103 | \ / 104 | \/ 105 | b_ij 106 | 107 | The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i) 108 | For the uniform variant, w_ij = 1 / |S[i]| 109 | For the cotangent variant, 110 | w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik) 111 | For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i]) 112 | where A[i] is the sum of the areas of all triangles containing vertex v_i. 113 | 114 | There is a nice trigonometry identity to compute cotangents. Consider a triangle 115 | with side lengths A, B, C and angles a, b, c. 116 | 117 | .. code-block:: python 118 | 119 | c 120 | /|\ 121 | / | \ 122 | / | \ 123 | B / H| \ A 124 | / | \ 125 | / | \ 126 | /a_____|_____b\ 127 | C 128 | 129 | Then cot a = (B^2 + C^2 - A^2) / 4 * area 130 | We know that area = CH/2, and by the law of cosines we have 131 | 132 | A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a 133 | 134 | Putting these together, we get: 135 | 136 | B^2 + C^2 - A^2 2BC cos a 137 | _______________ = _________ = (B/H) cos a = cos a / sin a = cot a 138 | 4 * area 2CH 139 | 140 | 141 | [1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion 142 | and curvature flow", SIGGRAPH 1999. 143 | 144 | [2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006. 145 | """ 146 | 147 | if meshes.isempty(): 148 | return torch.tensor( 149 | [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True 150 | ) 151 | 152 | N = len(meshes) 153 | verts_packed = meshes.verts_packed() # (sum(V_n), 3) 154 | faces_packed = meshes.faces_packed() # (sum(F_n), 3) 155 | num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,) 156 | verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),) 157 | weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),) 158 | weights = 1.0 / weights.float() 159 | 160 | # We don't want to backprop through the computation of the Laplacian; 161 | # just treat it as a magic constant matrix that is used to transform 162 | # verts into normals 163 | with torch.no_grad(): 164 | if method == "uniform": 165 | L = meshes.laplacian_packed() 166 | elif method in ["cot", "cotcurv"]: 167 | L, inv_areas = cot_laplacian(verts_packed, faces_packed) 168 | if method == "cot": 169 | norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) 170 | idx = norm_w > 0 171 | norm_w[idx] = 1.0 / norm_w[idx] 172 | else: 173 | L_sum = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) 174 | norm_w = 0.25 * inv_areas 175 | else: 176 | raise ValueError("Method should be one of {uniform, cot, cotcurv}") 177 | 178 | if method == "uniform": 179 | loss = L.mm(verts_packed) 180 | elif method == "cot": 181 | loss = L.mm(verts_packed) * norm_w - verts_packed 182 | elif method == "cotcurv": 183 | # pyre-fixme[61]: `norm_w` may not be initialized here. 184 | loss = (L.mm(verts_packed) - L_sum * verts_packed) * norm_w 185 | # import ipdb; ipdb.set_trace() 186 | # loss = loss.norm(dim=1) 187 | 188 | # loss = loss * weights 189 | # return loss.sum() / N 190 | return loss 191 | 192 | 193 | 194 | def huber(x, y, scaling=0.1): 195 | """ 196 | A helper function for evaluating the smooth L1 (huber) loss 197 | between the rendered silhouettes and colors. 
198 | """ 199 | # import ipdb; ipdb.set_trace() 200 | diff_sq = (x - y) ** 2 201 | loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling) 202 | # if mask is not None: 203 | # loss = loss.abs().sum()/mask.sum() 204 | # else: 205 | loss = loss.abs().mean() 206 | return loss 207 | 208 | class MSELoss(nn.Module): 209 | def __init__(self, w_alpha=0.): 210 | super(MSELoss, self).__init__() 211 | # self.loss = nn.MSELoss(reduction='mean') 212 | self.w_alpha = w_alpha 213 | 214 | def forward(self, inputs, targets): 215 | # import ipdb; ipdb.set_trace() 216 | # print(inputs['rgb_coarse'].min(), inputs['rgb_coarse'].max(), targets.min(), targets.max()) 217 | # loss = self.loss(inputs['rgb_coarse'], targets[:,:3]) 218 | loss = huber(inputs['rgb_coarse'], targets[:,:3]) 219 | 220 | if 'rgb_fine' in inputs: 221 | loss += huber(inputs['rgb_fine'], targets[:,:3]) 222 | 223 | # import ipdb; ipdb.set_trace() 224 | if self.w_alpha>0. and targets.shape[1]>3: 225 | weights = inputs['weights_coarse'] 226 | pix_alpha = weights.sum(dim=1) 227 | # pix_alpha[pix_alpha>1] = 1 228 | # loss += self.loss(pix_alpha, targets[:,3])*self.w_alpha 229 | loss += huber(pix_alpha, targets[:,3])*self.w_alpha 230 | if 'weights_fine' in inputs: 231 | weight = inputs['weights_fine'] 232 | pix_alpha = weights.sum(dim=1) 233 | # loss += self.loss(pix_alpha, targets[:,3])*self.w_alpha 234 | loss += huber(pix_alpha, targets[:,3])*self.w_alpha 235 | 236 | return loss 237 | 238 | 239 | loss_dict = {'mse': MSELoss} 240 | 241 | 242 | ### IDMRF loss 243 | class VGG19FeatLayer(nn.Module): 244 | def __init__(self): 245 | super(VGG19FeatLayer, self).__init__() 246 | self.vgg19 = models.vgg19(pretrained=True).features.eval().cuda() 247 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 248 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 249 | 250 | def forward(self, x): 251 | out = {} 252 | x = x - self.mean 253 | x = x/self.std 254 | ci = 1 255 | ri = 0 256 | for layer in self.vgg19.children(): 257 | if isinstance(layer, nn.Conv2d): 258 | ri += 1 259 | name = 'conv{}_{}'.format(ci, ri) 260 | elif isinstance(layer, nn.ReLU): 261 | ri += 1 262 | name = 'relu{}_{}'.format(ci, ri) 263 | layer = nn.ReLU(inplace=False) 264 | elif isinstance(layer, nn.MaxPool2d): 265 | ri = 0 266 | name = 'pool_{}'.format(ci) 267 | ci += 1 268 | elif isinstance(layer, nn.BatchNorm2d): 269 | name = 'bn_{}'.format(ci) 270 | else: 271 | raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) 272 | x = layer(x) 273 | out[name] = x 274 | # print([x for x in out]) 275 | return out 276 | class IDMRFLoss(nn.Module): 277 | def __init__(self, featlayer=VGG19FeatLayer): 278 | super(IDMRFLoss, self).__init__() 279 | self.featlayer = featlayer() 280 | self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0} 281 | self.feat_content_layers = {'relu4_2': 1.0} 282 | self.bias = 1.0 283 | self.nn_stretch_sigma = 0.5 284 | self.lambda_style = 1.0 285 | self.lambda_content = 1.0 286 | 287 | def sum_normalize(self, featmaps): 288 | reduce_sum = torch.sum(featmaps, dim=1, keepdim=True) + 1e-5 289 | return featmaps / reduce_sum 290 | 291 | def patch_extraction(self, featmaps): 292 | patch_size = 1 293 | patch_stride = 1 294 | patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(3, patch_size, patch_stride) 295 | self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5) 296 | dims = self.patches_OIHW.size() 297 | self.patches_OIHW = self.patches_OIHW.view(-1, 
dims[3], dims[4], dims[5]) 298 | return self.patches_OIHW 299 | 300 | def compute_relative_distances(self, cdist): 301 | epsilon = 1e-5 302 | div = torch.min(cdist, dim=1, keepdim=True)[0] 303 | relative_dist = cdist / (div + epsilon) 304 | return relative_dist 305 | 306 | def exp_norm_relative_dist(self, relative_dist): 307 | scaled_dist = relative_dist 308 | dist_before_norm = torch.exp((self.bias - scaled_dist)/self.nn_stretch_sigma) 309 | self.cs_NCHW = self.sum_normalize(dist_before_norm) 310 | return self.cs_NCHW 311 | 312 | def mrf_loss(self, gen, tar): 313 | meanT = torch.mean(tar, 1, keepdim=True) 314 | gen_feats, tar_feats = gen - meanT, tar - meanT 315 | gen_feats = gen_feats 316 | tar_feats = tar_feats 317 | gen_feats_norm = torch.norm(gen_feats, p=2, dim=1, keepdim=True) 318 | tar_feats_norm = torch.norm(tar_feats, p=2, dim=1, keepdim=True) 319 | gen_normalized = gen_feats / (gen_feats_norm + 1e-5) 320 | tar_normalized = tar_feats / (tar_feats_norm + 1e-5) 321 | 322 | cosine_dist_l = [] 323 | BatchSize = tar.size(0) 324 | 325 | for i in range(BatchSize): 326 | tar_feat_i = tar_normalized[i:i+1, :, :, :] 327 | gen_feat_i = gen_normalized[i:i+1, :, :, :] 328 | patches_OIHW = self.patch_extraction(tar_feat_i) 329 | cosine_dist_i = F.conv2d(gen_feat_i, patches_OIHW) 330 | cosine_dist_l.append(cosine_dist_i) 331 | cosine_dist = torch.cat(cosine_dist_l, dim=0) 332 | cosine_dist_zero_2_one = - (cosine_dist - 1) / 2 333 | relative_dist = self.compute_relative_distances(cosine_dist_zero_2_one) 334 | rela_dist = self.exp_norm_relative_dist(relative_dist) 335 | dims_div_mrf = rela_dist.size() 336 | k_max_nc = torch.max(rela_dist.view(dims_div_mrf[0], dims_div_mrf[1], -1), dim=2)[0] 337 | div_mrf = torch.mean(k_max_nc, dim=1) 338 | div_mrf_sum = -torch.log(div_mrf) 339 | div_mrf_sum = torch.sum(div_mrf_sum) 340 | return div_mrf_sum 341 | 342 | def forward(self, gen, tar): 343 | ## gen: [bz,3,h,w] rgb [0,1] 344 | gen_vgg_feats = self.featlayer(gen) 345 | tar_vgg_feats = self.featlayer(tar) 346 | style_loss_list = [self.feat_style_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_style_layers] 347 | self.style_loss = reduce(lambda x, y: x+y, style_loss_list) * self.lambda_style 348 | content_loss_list = [self.feat_content_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers] 349 | self.content_loss = reduce(lambda x, y: x+y, content_loss_list) * self.lambda_content 350 | return self.style_loss + self.content_loss 351 | 352 | # loss = 0 353 | # for key in self.feat_style_layers.keys(): 354 | # loss += torch.mean((gen_vgg_feats[key] - tar_vgg_feats[key])**2) 355 | # return loss 356 | 357 | 358 | class VGGPerceptualLoss(torch.nn.Module): 359 | def __init__(self, resize=True): 360 | super(VGGPerceptualLoss, self).__init__() 361 | blocks = [] 362 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 363 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 364 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 365 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 366 | for bl in blocks: 367 | for p in bl.parameters(): 368 | p.requires_grad = False 369 | self.blocks = torch.nn.ModuleList(blocks) 370 | self.transform = torch.nn.functional.interpolate 371 | self.resize = resize 372 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 373 | 
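        # Descriptive note: together with the "mean" buffer above, these are the standard ImageNet
        # channel statistics; forward() normalizes inputs with them before the frozen VGG16 blocks,
        # so inputs are assumed to be RGB images scaled to [0, 1].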
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 374 | 375 | def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): 376 | if input.shape[1] != 3: 377 | input = input.repeat(1, 3, 1, 1) 378 | target = target.repeat(1, 3, 1, 1) 379 | input = (input-self.mean) / self.std 380 | target = (target-self.mean) / self.std 381 | if self.resize: 382 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) 383 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) 384 | loss = 0.0 385 | x = input 386 | y = target 387 | for i, block in enumerate(self.blocks): 388 | x = block(x) 389 | y = block(y) 390 | if i in feature_layers: 391 | loss += torch.nn.functional.l1_loss(x, y) 392 | if i in style_layers: 393 | act_x = x.reshape(x.shape[0], x.shape[1], -1) 394 | act_y = y.reshape(y.shape[0], y.shape[1], -1) 395 | gram_x = act_x @ act_x.permute(0, 2, 1) 396 | gram_y = act_y @ act_y.permute(0, 2, 1) 397 | loss += torch.nn.functional.l1_loss(gram_x, gram_y) 398 | return loss -------------------------------------------------------------------------------- /lib/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.losses import ssim as dssim 3 | 4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'): 5 | value = (image_pred-image_gt[:,:3])**2 6 | if valid_mask is not None: 7 | value = value[valid_mask] 8 | if reduction == 'mean': 9 | return torch.mean(value) 10 | return value 11 | 12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'): 13 | return -10*torch.log10(mse(image_pred, image_gt[:,:3], valid_mask, reduction)) 14 | 15 | def ssim(image_pred, image_gt, reduction='mean'): 16 | """ 17 | image_pred and image_gt: (1, 3, H, W) 18 | """ 19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1] 20 | return 1-2*dssim_ # in [-1, 1] -------------------------------------------------------------------------------- /lib/utils/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code heavily inspired by https://github.com/JustusThies/NeuralTexGen/blob/master/models/VGG_LOSS.py 3 | """ 4 | import torch 5 | from torchvision import models 6 | from torchvision.transforms import Normalize 7 | from collections import namedtuple 8 | import sys 9 | from pathlib import Path 10 | 11 | sys.path.append(str((Path(__file__).parents[2]/"deps"))) 12 | from InsightFace.recognition.arcface_torch.backbones import get_model 13 | 14 | 15 | class VGG16(torch.nn.Module): 16 | def __init__(self, requires_grad=False): 17 | super(VGG16, self).__init__() 18 | vgg_pretrained_features = models.vgg16(pretrained=True).features 19 | self.slice1 = torch.nn.Sequential() 20 | self.slice2 = torch.nn.Sequential() 21 | self.slice3 = torch.nn.Sequential() 22 | self.slice4 = torch.nn.Sequential() 23 | for x in range(4): 24 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(4, 9): 26 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(9, 16): 28 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 29 | for x in range(16, 23): 30 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 35 | 36 | def 
forward(self, X): 37 | """ 38 | assuming rgb input of shape N x 3 x H x W normalized to -1 ... +1 39 | :param X: 40 | :return: 41 | """ 42 | X = self.normalize(X * .5 + .5) 43 | 44 | h = self.slice1(X) 45 | h_relu1_2 = h 46 | h = self.slice2(h) 47 | h_relu2_2 = h 48 | h = self.slice3(h) 49 | h_relu3_3 = h 50 | h = self.slice4(h) 51 | h_relu4_3 = h 52 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 53 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 54 | return out 55 | 56 | 57 | class ResNet18(torch.nn.Module): 58 | def __init__(self, weight_path): 59 | super(ResNet18, self).__init__() 60 | net = get_model("r18") 61 | net.load_state_dict(torch.load(weight_path)) 62 | net.eval() 63 | self.conv1 = net.conv1 64 | self.bn1 = net.bn1 65 | self.prelu = net.prelu 66 | self.layer1 = net.layer1 67 | self.layer2 = net.layer2 68 | self.layer3 = net.layer3 69 | self.layer4 = net.layer4 70 | 71 | for p in self.parameters(): 72 | p.requires_grad = False 73 | 74 | def forward(self, x): 75 | """ 76 | assuming rgb input of shape N x 3 x H x W normalized to -1 ... +1 77 | :param X: 78 | :return: 79 | """ 80 | self.eval() 81 | x = self.conv1(x) 82 | x = self.bn1(x) 83 | x = self.prelu(x) 84 | x = self.layer1(x) 85 | relu1_2 = x.clone() 86 | x = self.layer2(x) 87 | relu2_2 = x.clone() 88 | x = self.layer3(x) 89 | relu3_3 = x.clone() 90 | x = self.layer4(x) 91 | relu4_3 = x.clone() 92 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 93 | out = vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3) 94 | return out 95 | 96 | 97 | def gram_matrix(y): 98 | (b, ch, h, w) = y.size() 99 | features = y.view(b, ch, w * h) 100 | features_t = features.transpose(1, 2) 101 | gram = features.bmm(features_t) / (ch * h * w) 102 | return gram 103 | 104 | 105 | class ResNetLOSS(torch.nn.Module): 106 | def __init__(self, criterion=torch.nn.L1Loss(reduction='mean')): 107 | super(ResNetLOSS, self).__init__() 108 | self.model = ResNet18("assets/InsightFace/backbone.pth") 109 | self.model.eval() 110 | self.criterion = criterion 111 | self.criterion.reduction = "mean" 112 | 113 | for p in self.parameters(): 114 | p.requires_grad = False 115 | 116 | def forward(self, fake, target, content_weight=1.0, style_weight=1.0): 117 | """ 118 | assumes input images normalize to -1 ... + 1 and of shape N x 3 x H x W 119 | :param fake: 120 | :param target: 121 | :param content_weight: 122 | :param style_weight: 123 | :return: 124 | """ 125 | vgg_fake = self.model(fake) 126 | vgg_target = self.model(target) 127 | 128 | content_loss = self.criterion(vgg_target.relu2_2, vgg_fake.relu2_2) 129 | 130 | # gram_matrix 131 | gram_style = [gram_matrix(y) for y in vgg_target] 132 | style_loss = 0.0 133 | for ft_y, gm_s in zip(vgg_fake, gram_style): 134 | gm_y = gram_matrix(ft_y) 135 | style_loss += self.criterion(gm_y, gm_s) 136 | 137 | total_loss = content_weight * content_loss + style_weight * style_loss 138 | return total_loss 139 | -------------------------------------------------------------------------------- /lib/utils/rasterize_rendering.py: -------------------------------------------------------------------------------- 1 | ''' 2 | rasterization 3 | basic rasterize 4 | render shape (for visualization) 5 | render texture (need uv information) 6 | ''' 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from pytorch3d.structures import Meshes 13 | from pytorch3d.renderer.mesh import rasterize_meshes 14 | from . 
import util 15 | 16 | def add_directionlight(normals, lights=None): 17 | ''' 18 | normals: [bz, nv, 3] 19 | lights: [bz, nlight, 6] 20 | returns: 21 | shading: [bz, nv, 3] 22 | ''' 23 | light_direction = lights[:,:,:3]; light_intensities = lights[:,:,3:] 24 | directions_to_lights = F.normalize(light_direction[:,:,None,:].expand(-1,-1,normals.shape[1],-1), dim=3) 25 | # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) 26 | # normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3) 27 | normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.) 28 | shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:] 29 | return shading.mean(1) 30 | 31 | 32 | def pytorch3d_rasterize(vertices, faces, image_size, attributes=None, 33 | soft=False, blur_radius=0.0, sigma=1e-8, faces_per_pixel=1, gamma=1e-4, 34 | perspective_correct=False, clip_barycentric_coords=True, h=None, w=None): 35 | fixed_vertices = vertices.clone() 36 | fixed_vertices[...,:2] = -fixed_vertices[...,:2] 37 | 38 | if h is None and w is None: 39 | image_size = image_size 40 | else: 41 | image_size = [h, w] 42 | if h>w: 43 | fixed_vertices[..., 1] = fixed_vertices[..., 1]*h/w 44 | else: 45 | fixed_vertices[..., 0] = fixed_vertices[..., 0]*w/h 46 | 47 | meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long()) 48 | # import ipdb; ipdb.set_trace() 49 | # pytorch3d rasterize 50 | pix_to_face, zbuf, bary_coords, dists = rasterize_meshes( 51 | meshes_screen, 52 | image_size=image_size, 53 | blur_radius=blur_radius, 54 | faces_per_pixel=faces_per_pixel, 55 | perspective_correct=perspective_correct, 56 | clip_barycentric_coords=clip_barycentric_coords, 57 | # max_faces_per_bin = faces.shape[1], 58 | bin_size = 0 59 | ) 60 | # import ipdb; ipdb.set_trace() 61 | vismask = (pix_to_face > -1).float().squeeze(-1) 62 | depth = zbuf.squeeze(-1) 63 | 64 | if soft: 65 | from pytorch3d.renderer.blending import _sigmoid_alpha 66 | colors = torch.ones_like(bary_coords) 67 | N, H, W, K = pix_to_face.shape 68 | pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) 69 | pixel_colors[..., :3] = colors[..., 0, :] 70 | alpha = _sigmoid_alpha(dists, pix_to_face, sigma) 71 | pixel_colors[..., 3] = alpha 72 | pixel_colors = pixel_colors.permute(0,3,1,2) 73 | return pixel_colors 74 | 75 | if attributes is None: 76 | return depth, vismask 77 | else: 78 | vismask = (pix_to_face > -1).float() 79 | D = attributes.shape[-1] 80 | attributes = attributes.clone(); attributes = attributes.view(attributes.shape[0]*attributes.shape[1], 3, attributes.shape[-1]) 81 | N, H, W, K, _ = bary_coords.shape 82 | mask = pix_to_face == -1 83 | pix_to_face = pix_to_face.clone() 84 | pix_to_face[mask] = 0 85 | idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) 86 | pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D) 87 | pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2) 88 | pixel_vals[mask] = 0 # Replace masked values in output. 
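        # At this point pixel_vals has shape (N, H, W, K, D): per-pixel attributes obtained by
        # blending the three per-face attribute vectors with the barycentric coordinates of the
        # hit point. Below, only the closest face (K index 0) is kept, channels are moved to
        # dim 1, and the visibility mask is appended as an extra channel, giving (N, D+1, H, W).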
89 | pixel_vals = pixel_vals[:,:,:,0].permute(0,3,1,2) 90 | pixel_vals = torch.cat([pixel_vals, vismask[:,:,:,0][:,None,:,:]], dim=1) 91 | return pixel_vals 92 | 93 | # For visualization 94 | def render_shape(vertices, faces, image_size=None, background=None, lights=None, blur_radius=0., shift=True, colors=None, h=None, w=None): 95 | ''' 96 | -- rendering shape with detail normal map 97 | ''' 98 | batch_size = vertices.shape[0] 99 | transformed_vertices = vertices.clone() 100 | # set lighting 101 | # if lights is None: 102 | # light_positions = torch.tensor( 103 | # [ 104 | # [-1,1,1], 105 | # [1,1,1], 106 | # [-1,-1,1], 107 | # [1,-1,1], 108 | # [0,0,1] 109 | # ] 110 | # )[None,:,:].expand(batch_size, -1, -1).float() 111 | # light_intensities = torch.ones_like(light_positions).float()*1.7 112 | # lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device) 113 | if lights is None: 114 | light_positions = torch.tensor( 115 | [ 116 | [-5, 5, -5], 117 | [5, 5, -5], 118 | [-5, -5, -5], 119 | [5, -5, -5], 120 | [0, 0, -5], 121 | ] 122 | )[None,:,:].expand(batch_size, -1, -1).float() 123 | 124 | light_intensities = torch.ones_like(light_positions).float()*1.7 125 | lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device) 126 | if shift: 127 | transformed_vertices[:,:,2] = transformed_vertices[:,:,2] - transformed_vertices[:,:,2].min() 128 | transformed_vertices[:,:,2] = transformed_vertices[:,:,2]/transformed_vertices[:,:,2].max()*80 + 10 129 | 130 | # Attributes 131 | face_vertices = util.face_vertices(vertices, faces) 132 | normals = util.vertex_normals(vertices,faces); face_normals = util.face_vertices(normals, faces) 133 | transformed_normals = util.vertex_normals(transformed_vertices, faces); transformed_face_normals = util.face_vertices(transformed_normals, faces) 134 | if colors is None: 135 | face_colors = torch.ones_like(face_vertices)*180/255. 
136 | else: 137 | face_colors = util.face_vertices(colors, faces) 138 | 139 | attributes = torch.cat([face_colors, 140 | transformed_face_normals.detach(), 141 | face_vertices.detach(), 142 | face_normals], 143 | -1) 144 | # rasterize 145 | rendering = pytorch3d_rasterize(transformed_vertices, faces, image_size=image_size, attributes=attributes, blur_radius=blur_radius, h=h, w=w) 146 | 147 | #### 148 | alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() 149 | 150 | # albedo 151 | albedo_images = rendering[:, :3, :, :] 152 | # mask 153 | transformed_normal_map = rendering[:, 3:6, :, :].detach() 154 | pos_mask = (transformed_normal_map[:, 2:, :, :] < 0.15).float() 155 | 156 | # shading 157 | normal_images = rendering[:, 9:12, :, :].detach() 158 | vertice_images = rendering[:, 6:9, :, :].detach() 159 | 160 | shading = add_directionlight(normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights) 161 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2).contiguous() 162 | shaded_images = albedo_images*shading_images 163 | 164 | # alpha_images = alpha_images*pos_mask 165 | if background is None: 166 | shape_images = shaded_images*alpha_images + torch.ones_like(shaded_images).to(vertices.device)*(1-alpha_images) 167 | else: 168 | shape_images = shaded_images *alpha_images + background*(1-alpha_images) 169 | return shape_images 170 | 171 | 172 | def render_texture(transformed_vertices, faces, image_size, 173 | vertices, albedos, 174 | uv_faces, uv_coords, 175 | lights=None, light_type='point'): 176 | ''' 177 | -- Texture Rendering 178 | vertices: [batch_size, V, 3], vertices in world space, for calculating normals, then shading 179 | transformed_vertices: [batch_size, V, 3], range:normalized to [-1,1], projected vertices in image space (that is aligned to the iamge pixel), for rasterization 180 | albedos: [batch_size, 3, h, w], uv map 181 | lights: 182 | spherical homarnic: [N, 9(shcoeff), 3(rgb)] 183 | points/directional lighting: [N, n_lights, 6(xyzrgb)] 184 | light_type: 185 | point or directional 186 | ''' 187 | batch_size = vertices.shape[0] 188 | face_uvcoords = util.face_vertices(uv_coords, uv_faces) 189 | ## rasterizer near 0 far 100. 
move mesh so minz larger than 0 190 | # import ipdb; ipdb.set_trace() 191 | # normalize to 0, 100 192 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2] + 10 193 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2] - transformed_vertices[:,:,2].min() 194 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2]/transformed_vertices[:,:,2].max()*80 + 10 195 | 196 | # import ipdb; ipdb.set_trace() 197 | # attributes 198 | face_vertices = util.face_vertices(vertices, faces) 199 | normals = util.vertex_normals(vertices, faces); face_normals = util.face_vertices(normals, faces) 200 | transformed_normals = util.vertex_normals(transformed_vertices, faces); transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1)) 201 | 202 | attributes = torch.cat([face_uvcoords, 203 | transformed_face_normals.detach(), 204 | face_vertices.detach(), 205 | face_normals], 206 | -1) 207 | # rasterize 208 | rendering = pytorch3d_rasterize(transformed_vertices, faces, attributes) 209 | 210 | #### 211 | # vis mask 212 | alpha_images = rendering[:, -1, :, :][:, None, :, :].detach() 213 | 214 | # albedo 215 | uvcoords_images = rendering[:, :3, :, :]; grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2] 216 | albedo_images = F.grid_sample(albedos, grid, align_corners=False) 217 | 218 | # visible mask for pixels with positive normal direction 219 | transformed_normal_map = rendering[:, 3:6, :, :].detach() 220 | pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float() 221 | 222 | # shading 223 | normal_images = rendering[:, 9:12, :, :] 224 | if lights is not None: 225 | if lights.shape[1] == 9: 226 | shading_images = self.add_SHlight(normal_images, lights) 227 | else: 228 | if light_type=='point': 229 | vertice_images = rendering[:, 6:9, :, :].detach() 230 | shading = self.add_pointlight(vertice_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights) 231 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2) 232 | else: 233 | shading = self.add_directionlight(normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights) 234 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2) 235 | images = albedo_images*shading_images 236 | else: 237 | images = albedo_images 238 | shading_images = images.detach()*0. 239 | 240 | outputs = { 241 | 'images': images*alpha_images, 242 | 'albedo_images': albedo_images*alpha_images, 243 | 'alpha_images': alpha_images, 244 | 'pos_mask': pos_mask, 245 | 'shading_images': shading_images, 246 | 'grid': grid, 247 | 'normals': normals, 248 | 'normal_images': normal_images*alpha_images, 249 | 'transformed_normals': transformed_normals, 250 | } 251 | 252 | return outputs -------------------------------------------------------------------------------- /lib/utils/volumetric_rendering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Differentiable volumetric implementation used by pi-GAN generator. 3 | """ 4 | 5 | import time 6 | from functools import partial 7 | 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import matplotlib.pyplot as plt 13 | import random 14 | 15 | import torch 16 | 17 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 18 | """ 19 | Left-multiplies MxM @ NxM. Returns NxM. 
20 | """ 21 | res = torch.matmul(vectors4, matrix.T) 22 | return res 23 | 24 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 25 | """ 26 | Normalize vector lengths. 27 | """ 28 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 29 | 30 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 31 | """ 32 | Dot product of two tensors. 33 | """ 34 | return (x * y).sum(-1) 35 | 36 | def fancy_integration(rgb_sigma, z_vals, device, noise_std=0.5, last_back=False, white_back=False, clamp_mode='softplus', fill_mode=None): 37 | """Performs NeRF volumetric rendering.""" 38 | rgbs = rgb_sigma[..., :3] 39 | sigmas = rgb_sigma[..., 3:] 40 | 41 | # import ipdb; ipdb.set_trace() 42 | deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1] 43 | delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1]) 44 | deltas = torch.cat([deltas, delta_inf], -2) 45 | 46 | noise = torch.randn(sigmas.shape, device=device) * noise_std 47 | 48 | if clamp_mode == 'softplus': 49 | alphas = 1-torch.exp(-deltas * (F.softplus(sigmas + noise))) 50 | elif clamp_mode == 'relu': 51 | alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise))) 52 | else: 53 | raise "Need to choose clamp mode" 54 | 55 | alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1-alphas + 1e-10], -2) 56 | weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1] 57 | weights_sum = weights.sum(2) 58 | # weights_sum = weights[:,:,:-1].sum(2) 59 | 60 | if last_back: 61 | weights[:, :, -1] += (1 - weights_sum) 62 | 63 | rgb_final = torch.sum(weights * rgbs, -2) 64 | depth_final = torch.sum(weights * z_vals, -2) 65 | 66 | # if white_back: 67 | # rgb_final = rgb_final + 1-weights_sum 68 | 69 | if fill_mode == 'debug': 70 | rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device) 71 | elif fill_mode == 'weight': 72 | rgb_final = weights_sum.expand_as(rgb_final) 73 | 74 | return rgb_final, depth_final, weights.squeeze(-1) 75 | 76 | 77 | def get_initial_rays_trig(n, num_steps, device, fov, resolution, ray_start, ray_end): 78 | """Returns sample points, z_vals, and ray directions in camera space.""" 79 | 80 | W, H = resolution 81 | # Create full screen NDC (-1 to +1) coords [x, y, 0, 1]. 82 | # Y is flipped to follow image memory layouts. 
83 | x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device), 84 | torch.linspace(1, -1, H, device=device)) 85 | x = x.T.flatten() 86 | y = y.T.flatten() 87 | z = -torch.ones_like(x, device=device) / np.tan((2 * math.pi * fov / 360)/2) 88 | 89 | rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1)) 90 | # rays_d_cam = torch.stack([x, y, z], -1) 91 | 92 | 93 | z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W*H, 1, 1) 94 | points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals 95 | 96 | points = torch.stack(n*[points]) 97 | z_vals = torch.stack(n*[z_vals]) 98 | rays_d_cam = torch.stack(n*[rays_d_cam]).to(device) 99 | # import ipdb; ipdb.set_trace() 100 | return points, z_vals, rays_d_cam 101 | 102 | def perturb_points(points, z_vals, ray_directions, device): 103 | distance_between_points = z_vals[:,:,1:2,:] - z_vals[:,:,0:1,:] 104 | offset = (torch.rand(z_vals.shape, device=device)-0.5) * distance_between_points 105 | # z_vals = z_vals + offset 106 | z_vals[:,:,1:-1,:] = z_vals[:,:,1:-1,:] + offset[:,:,1:-1,:] 107 | points[:,:,1:-1,:] = points[:,:,1:-1,:] + offset[:,:,1:-1,:] * ray_directions.unsqueeze(2) 108 | return points, z_vals 109 | 110 | 111 | def transform_sampled_points(points, z_vals, ray_directions, device, h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5, v_mean=math.pi * 0.5, mode='normal'): 112 | """Samples a camera position and maps points in camera space to world space.""" 113 | 114 | n, num_rays, num_steps, channels = points.shape 115 | 116 | # points, z_vals = perturb_points(points, z_vals, ray_directions, device) 117 | 118 | 119 | camera_origin, pitch, yaw = sample_camera_positions(n=points.shape[0], r=1, horizontal_stddev=h_stddev, vertical_stddev=v_stddev, horizontal_mean=h_mean, vertical_mean=v_mean, device=device, mode=mode) 120 | forward_vector = normalize_vecs(-camera_origin) 121 | 122 | cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device) 123 | 124 | points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device) 125 | points_homogeneous[:, :, :, :3] = points 126 | 127 | # should be n x 4 x 4 , n x r^2 x num_steps x 4 128 | transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, num_steps, 4) 129 | 130 | 131 | transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, 3) 132 | 133 | homogeneous_origins = torch.zeros((n, 4, num_rays), device=device) 134 | homogeneous_origins[:, 3, :] = 1 135 | transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3] 136 | 137 | # return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw 138 | return points, z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw 139 | 140 | def truncated_normal_(tensor, mean=0, std=1): 141 | size = tensor.shape 142 | tmp = tensor.new_empty(size + (4,)).normal_() 143 | valid = (tmp < 2) & (tmp > -2) 144 | ind = valid.max(-1, keepdim=True)[1] 145 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 146 | tensor.data.mul_(std).add_(mean) 147 | return tensor 148 | 149 | def sample_camera_positions(device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1, horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'): 150 | """ 
151 | Samples n random locations along a sphere of radius r. Uses the specified distribution. 152 | Theta is yaw in radians (-pi, pi) 153 | Phi is pitch in radians (0, pi) 154 | """ 155 | 156 | if mode == 'uniform': 157 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean 158 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean 159 | 160 | elif mode == 'normal' or mode == 'gaussian': 161 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean 162 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean 163 | 164 | elif mode == 'hybrid': 165 | if random.random() < 0.5: 166 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean 167 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean 168 | else: 169 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean 170 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean 171 | 172 | elif mode == 'truncated_gaussian': 173 | theta = truncated_normal_(torch.zeros((n, 1), device=device)) * horizontal_stddev + horizontal_mean 174 | phi = truncated_normal_(torch.zeros((n, 1), device=device)) * vertical_stddev + vertical_mean 175 | 176 | elif mode == 'spherical_uniform': 177 | theta = (torch.rand((n, 1), device=device) - .5) * 2 * horizontal_stddev + horizontal_mean 178 | v_stddev, v_mean = vertical_stddev / math.pi, vertical_mean / math.pi 179 | v = ((torch.rand((n,1), device=device) - .5) * 2 * v_stddev + v_mean) 180 | v = torch.clamp(v, 1e-5, 1 - 1e-5) 181 | phi = torch.arccos(1 - 2 * v) 182 | 183 | else: 184 | # Just use the mean. 185 | theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean 186 | phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean 187 | 188 | phi = torch.clamp(phi, 1e-5, math.pi - 1e-5) 189 | 190 | output_points = torch.zeros((n, 3), device=device) 191 | output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta) 192 | output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta) 193 | output_points[:, 1:2] = r*torch.cos(phi) 194 | 195 | return output_points, phi, theta 196 | 197 | def create_cam2world_matrix(forward_vector, origin, device=None): 198 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.""" 199 | 200 | forward_vector = normalize_vecs(forward_vector) 201 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector) 202 | 203 | left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1)) 204 | 205 | up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1)) 206 | 207 | rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 208 | rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1) 209 | 210 | translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 211 | translation_matrix[:, :3, 3] = origin 212 | 213 | cam2world = translation_matrix @ rotation_matrix 214 | 215 | return cam2world 216 | 217 | 218 | def create_world2cam_matrix(forward_vector, origin): 219 | """Takes in the direction the camera is pointing and the camera origin and returns a world2cam matrix.""" 220 | cam2world = create_cam2world_matrix(forward_vector, origin, device=device) 221 | world2cam = torch.inverse(cam2world) 222 | 
return world2cam 223 | 224 | 225 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): 226 | """ 227 | Sample @N_importance samples from @bins with distribution defined by @weights. 228 | Inputs: 229 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" 230 | weights: (N_rays, N_samples_) 231 | N_importance: the number of samples to draw from the distribution 232 | det: deterministic or not 233 | eps: a small number to prevent division by zero 234 | Outputs: 235 | samples: the sampled samples 236 | Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py 237 | """ 238 | N_rays, N_samples_ = weights.shape 239 | weights = weights + eps # prevent division by zero (don't do inplace op!) 240 | pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) 241 | cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function 242 | cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) 243 | # padded to 0~1 inclusive 244 | if det: 245 | u = torch.linspace(0, 1, N_importance, device=bins.device) 246 | u = u.expand(N_rays, N_importance) 247 | else: 248 | u = torch.rand(N_rays, N_importance, device=bins.device) 249 | u = u.contiguous() 250 | 251 | # inds = torch.searchsorted(cdf, u) 252 | inds = torch.searchsorted(cdf, u, right=True) 253 | below = torch.clamp_min(inds-1, 0) 254 | above = torch.clamp_max(inds, N_samples_) 255 | 256 | inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) 257 | cdf_g = torch.gather(cdf, 1, inds_sampled) 258 | cdf_g = cdf_g.view(N_rays, N_importance, 2) 259 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) 260 | 261 | denom = cdf_g[...,1]-cdf_g[...,0] 262 | denom[denom 2 | 3 |

SCARF: Capturing and Animation of Body and Clothing from Monocular Video

visualize data: cropped image, subject segmentation, clothing segmentation, SMPL-X estimation

25 | This is the script to process video data for SCARF training. 26 | 27 | ## Getting Started 28 | ### Environment 29 | SCARF needs input images, subject masks, clothing masks, and an initial SMPL-X estimation for training. 30 | Specifically, we use 31 | * [FasterRCNN](https://pytorch.org/vision/main/models/faster_rcnn.html) to detect the subject and crop the image 32 | * [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting) to remove the background 33 | * [cloth-segmentation](https://github.com/levindabhi/cloth-segmentation) to segment clothing 34 | * [PIXIE](https://github.com/yfeng95/PIXIE) to estimate SMPL-X parameters 35 | 36 | When using the processing script, it is necessary to agree to the terms of their licenses and to cite them properly in your work. 37 | 38 | 1. Clone the submodule repositories: 39 | ``` 40 | git submodule update --init --recursive 41 | ``` 42 | 2. Download their required data: 43 | ```bash 44 | bash fetch_asset_data.sh 45 | ``` 46 | If the script fails, please check their websites and download the models manually. 47 | 48 | ### Process video data 49 | Put your data list into ./lists/video_list.txt; each entry can be a video path or an image folder. 50 | Then run: 51 | ```bash 52 | python process_video.py --crop --ignore_existing 53 | ``` 54 | Processing time depends on the number of frames and the video resolution; the mpiis-scarf video (400 frames at 1028x1920 resolution) takes around 12 minutes. 55 | 56 | 57 | ## Video Data 58 | The script has been verified to work on the following datasets: 59 | ``` 60 | a. mpiis-scarf (recorded video for this paper) 61 | b. People Snapshot Dataset (https://graphics.tu-bs.de/people-snapshot) 62 | c. SelfRecon dataset (https://jby1993.github.io/SelfRecon/) 63 | d. iPER dataset (https://svip-lab.github.io/dataset/iPER_dataset.html) 64 | ``` 65 | To get optimal results for your own video, it is recommended to capture it with settings similar to those of the datasets above. 66 | 67 | This means keeping the camera static, recording the subject from multiple views, and using uniform lighting. 68 | It is also better to keep the training video under 1000 frames. 69 | For more information, please refer to the limitations section of the SCARF paper. -------------------------------------------------------------------------------- /process_data/fetch_asset_data.sh: -------------------------------------------------------------------------------- 1 | # trained model for RobustVideoMatting 2 | mkdir -p ./assets 3 | mkdir -p ./assets/RobustVideoMatting 4 | echo -e "Downloading RobustVideoMatting model..." 5 | wget https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth -P ./assets/RobustVideoMatting 6 | 7 | # trained model for cloth-segmentation 8 | # if this fails, please download the model manually 9 | # old link: https://drive.google.com/file/d/1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ/view -- not valid anymore 10 | # use the new link: https://drive.google.com/file/d/18_ekYXKLgd0H8XOE7jMJqQrlsoyhQU67/view?usp=drive_link 11 | mkdir -p ./assets/cloth-segmentation 12 | echo -e "Downloading cloth-segmentation model..." 
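# The lines below use the usual Google Drive "large file" workaround: an inner wget fetches the
# download page and saves the session cookie to extract the confirmation token, and the outer wget
# then downloads the checkpoint (id ${FILEID}) to ${FILENAME} before removing the cookie file.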
13 | FILEID=18_ekYXKLgd0H8XOE7jMJqQrlsoyhQU67 14 | FILENAME=./assets/cloth-segmentation/cloth_segm_u2net_latest.pth 15 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt 16 | 17 | # trained model for PIXIE 18 | # if failed, please check https://github.com/yfeng95/PIXIE/blob/master/fetch_model.sh 19 | cd submodules/PIXIE 20 | echo -e "Downloading PIXIE data..." 21 | #!/bin/bash 22 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 23 | 24 | # SMPL-X 2020 (neutral SMPL-X model with the FLAME 2020 expression blendshapes) 25 | echo -e "\nYou need to register at https://smpl-x.is.tue.mpg.de" 26 | read -p "Username (SMPL-X):" username 27 | read -p "Password (SMPL-X):" password 28 | username=$(urle $username) 29 | password=$(urle $password) 30 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=SMPLX_NEUTRAL_2020.npz&resume=1' -O './data/SMPLX_NEUTRAL_2020.npz' --no-check-certificate --continue 31 | 32 | # PIXIE pretrained model and utilities 33 | echo -e "\nYou need to register at https://pixie.is.tue.mpg.de/" 34 | read -p "Username (PIXIE):" username 35 | read -p "Password (PIXIE):" password 36 | username=$(urle $username) 37 | password=$(urle $password) 38 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=pixie&sfile=pixie_model.tar&resume=1' -O './data/pixie_model.tar' --no-check-certificate --continue 39 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=pixie&sfile=utilities.zip&resume=1' -O './data/utilities.zip' --no-check-certificate --continue 40 | 41 | cd ./data 42 | unzip utilities.zip -------------------------------------------------------------------------------- /process_data/lists/video_list.txt: -------------------------------------------------------------------------------- 1 | ../exps/mpiis/DSC_7157/DSC_7157.mp4 -------------------------------------------------------------------------------- /process_data/logs/generate_data.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/process_data/logs/generate_data.log -------------------------------------------------------------------------------- /process_data/process_video.py: -------------------------------------------------------------------------------- 1 | from ipaddress import ip_address 2 | import os, sys 3 | import argparse 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | from loguru import logger 7 | from glob import glob 8 | from skimage.io import imread, imsave 9 | from skimage.transform import estimate_transform, warp, resize, rescale 10 | import shutil 11 | import torch 12 | import cv2 13 | import numpy as np 14 | from PIL import Image 15 | 16 | def get_palette(num_cls): 17 | """Returns the color map for visualizing the segmentation mask. 
18 | Args: 19 | num_cls: Number of classes 20 | Returns: 21 | The color map 22 | """ 23 | n = num_cls 24 | palette = [0] * (n * 3) 25 | for j in range(0, n): 26 | lab = j 27 | palette[j * 3 + 0] = 0 28 | palette[j * 3 + 1] = 0 29 | palette[j * 3 + 2] = 0 30 | i = 0 31 | while lab: 32 | palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i) 33 | palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i) 34 | palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i) 35 | i += 1 36 | lab >>= 3 37 | return palette 38 | 39 | def generate_frame(inputpath, savepath, subject_name=None, n_frames=2000, fps=30): 40 | ''' extract frames from video or copy frames from image folder 41 | ''' 42 | os.makedirs(savepath, exist_ok=True) 43 | if subject_name is None: 44 | subject_name = Path(inputpath).stem 45 | ## video data 46 | if os.path.isfile(inputpath) and (os.path.splitext(inputpath)[-1] in ['.mp4', '.csv', '.MOV']): 47 | videopath = os.path.join(os.path.dirname(savepath), f'{subject_name}.mp4') 48 | logger.info(f'extract frames from video: {inputpath}..., then save to {videopath}') 49 | vidcap = cv2.VideoCapture(inputpath) 50 | count = 0 51 | success, image = vidcap.read() 52 | cv2.imwrite(os.path.join(savepath, f'{subject_name}_f{count:06d}.png'), image) 53 | h, w = image.shape[:2] 54 | print(h, w) 55 | # import imageio 56 | # savecap = imageio.get_writer(videopath, fps=fps) 57 | # savecap.append_data(image[:,:,::-1]) 58 | while success: 59 | count += 1 60 | success,image = vidcap.read() 61 | if count > n_frames or image is None: 62 | break 63 | imagepath = os.path.join(savepath, f'{subject_name}_f{count:06d}.png') 64 | cv2.imwrite(imagepath, image) # save frame as JPEG png 65 | # savecap.append_data(image[:,:,::-1]) 66 | logger.info(f'extracted {count} frames') 67 | elif os.path.isdir(inputpath): 68 | logger.info(f'copy frames from folder: {inputpath}...') 69 | imagepath_list = glob(inputpath + '/*.jpg') + glob(inputpath + '/*.png') + glob(inputpath + '/*.jpeg') 70 | imagepath_list = sorted(imagepath_list) 71 | for count, imagepath in enumerate(imagepath_list): 72 | shutil.copyfile(imagepath, os.path.join(savepath, f'{subject_name}_f{count:06d}.png')) 73 | print('frames are stored in {}'.format(savepath)) 74 | else: 75 | logger.info(f'please check the input path: {inputpath}') 76 | logger.info(f'video frames are stored in {savepath}') 77 | 78 | def generate_image(inputpath, savepath, subject_name=None, crop=False, crop_each=False, image_size=512, scale_bbox=1.1, device='cuda:0'): 79 | ''' generate image from given frame path. 80 | ''' 81 | logger.info(f'generae images, crop {crop}, image size {image_size}') 82 | os.makedirs(savepath, exist_ok=True) 83 | # load detection model 84 | from submodules.detector import FasterRCNN 85 | detector = FasterRCNN(device=device) 86 | if os.path.isdir(inputpath): 87 | imagepath_list = glob(inputpath + '/*.jpg') + glob(inputpath + '/*.png') + glob(inputpath + '/*.jpeg') 88 | imagepath_list = sorted(imagepath_list) 89 | # if crop, detect the bbox of the first image and use the bbox for all frames 90 | if crop: 91 | imagepath = imagepath_list[0] 92 | logger.info(f'detect first image {imagepath}') 93 | imagename = os.path.splitext(os.path.basename(imagepath))[0] 94 | image = imread(imagepath)[:,:,:3]/255. 95 | h, w, _ = image.shape 96 | 97 | image_tensor = torch.tensor(image.transpose(2,0,1), dtype=torch.float32)[None, ...] 
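        # The detector (FasterRCNN wrapper in submodules/detector.py) is expected to return a single
        # subject bounding box as [left, top, right, bottom] in pixel coordinates; below, the box is
        # enlarged by `scale_bbox` and turned into a similarity transform that crops the same square
        # region from every frame of the sequence.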
98 | bbox = detector.run(image_tensor) 99 | left = bbox[0]; right = bbox[2]; top = bbox[1]; bottom = bbox[3] 100 | np.savetxt(os.path.join(Path(inputpath).parent, 'image_bbox.txt'), bbox) 101 | 102 | ## calculate warping function for image cropping 103 | old_size = max(right - left, bottom - top) 104 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) 105 | size = int(old_size*scale_bbox) 106 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]]) 107 | DST_PTS = np.array([[0,0], [0,image_size - 1], [image_size - 1, 0]]) 108 | tform = estimate_transform('similarity', src_pts, DST_PTS) 109 | 110 | for count, imagepath in enumerate(tqdm(imagepath_list)): 111 | if crop: 112 | image = imread(imagepath) 113 | dst_image = warp(image, tform.inverse, output_shape=(image_size, image_size)) 114 | dst_image = (dst_image*255).astype(np.uint8) 115 | imsave(os.path.join(savepath, f'{subject_name}_f{count:06d}.png'), dst_image) 116 | else: 117 | shutil.copyfile(imagepath, os.path.join(savepath, f'{subject_name}_f{count:06d}.png')) 118 | logger.info(f'images are stored in {savepath}') 119 | 120 | def generate_matting_rvm(inputpath, savepath, ckpt_path='assets/RobustVideoMatting/rvm_resnet50.pth', device='cuda:0'): 121 | sys.path.append('./submodules/RobustVideoMatting') 122 | from model import MattingNetwork 123 | EXTS = ['jpg', 'jpeg', 'png'] 124 | segmentor = MattingNetwork(variant='resnet50').eval().to(device) 125 | segmentor.load_state_dict(torch.load(ckpt_path)) 126 | 127 | images_folder = inputpath 128 | output_folder = savepath 129 | os.makedirs(output_folder, exist_ok=True) 130 | 131 | frame_IDs = os.listdir(images_folder) 132 | frame_IDs = [id.split('.')[0] for id in frame_IDs if id.split('.')[-1] in EXTS] 133 | frame_IDs.sort() 134 | frame_IDs = frame_IDs[:4][::-1] + frame_IDs 135 | 136 | rec = [None] * 4 # Initial recurrent 137 | downsample_ratio = 1.0 # Adjust based on your video. 138 | 139 | # bgr = torch.tensor([1, 1, 1.]).view(3, 1, 1).cuda() 140 | for i in tqdm(range(len(frame_IDs))): 141 | frame_ID = frame_IDs[i] 142 | img_path = os.path.join(images_folder, '{}.png'.format(frame_ID)) 143 | try: 144 | img_masked_path = os.path.join(output_folder, '{}.png'.format(frame_ID)) 145 | img = cv2.imread(img_path) 146 | src = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 147 | src = torch.from_numpy(src).float() / 255. 148 | src = src.permute(2, 0, 1).unsqueeze(0) 149 | with torch.no_grad(): 150 | fgr, pha, *rec = segmentor(src.to(device), *rec, downsample_ratio) # Cycle the recurrent states. 
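                # RobustVideoMatting returns the predicted foreground (fgr), the alpha matte (pha)
                # in [0, 1], and the updated recurrent states (rec) that carry temporal context to
                # the next frame; only the alpha matte is used below, stored as a 4th channel of
                # the saved image.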
151 | pha = pha.permute(0, 2, 3, 1).cpu().numpy().squeeze(0) 152 | # check the difference of current 153 | mask = (pha * 255).astype(np.uint8) 154 | img_masked = np.concatenate([img, mask], axis=-1) 155 | cv2.imwrite(img_masked_path, img_masked) 156 | except: 157 | os.remove(img_path) 158 | logger.info(f'matting failed for image {img_path}, delete it') 159 | 160 | sys.modules.pop('model') 161 | 162 | def generate_cloth_segmentation(inputpath, savepath, ckpt_path='assets/cloth-segmentation/cloth_segm_u2net_latest.pth', device='cuda:0', vis=False): 163 | logger.info(f'generate cloth segmentation for {inputpath}') 164 | os.makedirs(savepath, exist_ok=True) 165 | # load model 166 | sys.path.insert(0, './submodules/cloth-segmentation') 167 | import torch.nn.functional as F 168 | import torchvision.transforms as transforms 169 | from data.base_dataset import Normalize_image 170 | from utils.saving_utils import load_checkpoint_mgpu 171 | from networks import U2NET 172 | 173 | transforms_list = [] 174 | transforms_list += [transforms.ToTensor()] 175 | transforms_list += [Normalize_image(0.5, 0.5)] 176 | transform_rgb = transforms.Compose(transforms_list) 177 | 178 | net = U2NET(in_ch=3, out_ch=4) 179 | net = load_checkpoint_mgpu(net, ckpt_path) 180 | net = net.to(device) 181 | net = net.eval() 182 | 183 | palette = get_palette(4) 184 | 185 | images_list = sorted(os.listdir(inputpath)) 186 | pbar = tqdm(total=len(images_list)) 187 | for image_name in tqdm(images_list): 188 | img = Image.open(os.path.join(inputpath, image_name)).convert("RGB") 189 | image_tensor = transform_rgb(img) 190 | image_tensor = torch.unsqueeze(image_tensor, 0) 191 | 192 | output_tensor = net(image_tensor.to(device)) 193 | output_tensor = F.log_softmax(output_tensor[0], dim=1) 194 | output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1] 195 | output_tensor = torch.squeeze(output_tensor, dim=0) 196 | output_tensor = torch.squeeze(output_tensor, dim=0) 197 | output_arr = output_tensor.cpu().numpy() 198 | 199 | output_img = Image.fromarray(output_arr.astype("uint8"), mode="L") 200 | output_img.putpalette(palette) 201 | name = Path(image_name).stem 202 | output_img.save(os.path.join(savepath, f'{name}.png')) 203 | pbar.update(1) 204 | pbar.close() 205 | 206 | def generate_pixie(inputpath, savepath, ckpt_path='assets/face_normals/model.pth', device='cuda:0', image_size=512, vis=False): 207 | logger.info(f'generate pixie results') 208 | os.makedirs(savepath, exist_ok=True) 209 | # load model 210 | sys.path.insert(0, './submodules/PIXIE') 211 | from pixielib.pixie import PIXIE 212 | from pixielib.visualizer import Visualizer 213 | from pixielib.datasets.body_datasets import TestData 214 | from pixielib.utils import util 215 | from pixielib.utils.config import cfg as pixie_cfg 216 | from pixielib.utils.tensor_cropper import transform_points 217 | # run pixie 218 | testdata = TestData(inputpath, iscrop=False) 219 | pixie_cfg.model.use_tex = False 220 | pixie = PIXIE(config = pixie_cfg, device=device) 221 | visualizer = Visualizer(render_size=image_size, config = pixie_cfg, device=device, rasterizer_type='standard') 222 | testdata = TestData(inputpath, iscrop=False) 223 | for i, batch in enumerate(tqdm(testdata, dynamic_ncols=True)): 224 | batch['image'] = batch['image'].unsqueeze(0) 225 | batch['image_hd'] = batch['image_hd'].unsqueeze(0) 226 | name = batch['name'] 227 | util.move_dict_to_device(batch, device) 228 | data = { 229 | 'body': batch 230 | } 231 | param_dict = pixie.encode(data, threthold=True, keep_local=True, 
copy_and_paste=False) 232 | codedict = param_dict['body'] 233 | opdict = pixie.decode(codedict, param_type='body') 234 | util.save_pkl(os.path.join(savepath, f'{name}_param.pkl'), codedict) 235 | if vis: 236 | opdict['albedo'] = visualizer.tex_flame2smplx(opdict['albedo']) 237 | visdict = visualizer.render_results(opdict, data['body']['image_hd'], overlay=True, use_deca=False) 238 | cv2.imwrite(os.path.join(savepath, f'{name}_vis.jpg'), visualizer.visualize_grid(visdict, size=image_size)) 239 | 240 | def process_video(subjectpath, savepath=None, vis=False, crop=False, crop_each=False, ignore_existing=False, n_frames=2000): 241 | if savepath is None: 242 | savepath = Path(subjectpath).parent 243 | subject_name = Path(subjectpath).stem 244 | savepath = os.path.join(savepath, subject_name) 245 | os.makedirs(savepath, exist_ok=True) 246 | logger.info(f'processing {subject_name}') 247 | # 0. copy frames from video or image folder 248 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'frame')): 249 | generate_frame(subjectpath, os.path.join(savepath, 'frame'), n_frames=n_frames) 250 | 251 | # 1. crop image from frames, use fasterrcnn for detection 252 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'image')): 253 | generate_image(os.path.join(savepath, 'frame'), os.path.join(savepath, 'image'), subject_name=subject_name, 254 | crop=crop, crop_each=crop_each, image_size=512, scale_bbox=1.1, device='cuda:0') 255 | 256 | # 2. video matting 257 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'matting')): 258 | generate_matting_rvm(os.path.join(savepath, 'image'), os.path.join(savepath, 'matting')) 259 | 260 | # 3. cloth segmentation 261 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'cloth_segmentation')): 262 | generate_cloth_segmentation(os.path.join(savepath, 'image'), os.path.join(savepath, 'cloth_segmentation'), vis=vis) 263 | 264 | # 4. 
smplx estimation using PIXIE (https://github.com/yfeng95/PIXIE) 265 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'pixie')): 266 | generate_pixie(os.path.join(savepath, 'image'), os.path.join(savepath, 'pixie'), vis=vis) 267 | logger.info(f'finished {subject_name}') 268 | 269 | def main(args): 270 | logger.add(args.logpath) 271 | 272 | with open(args.list, 'r') as f: 273 | lines = f.readlines() 274 | subject_list = [s.strip() for s in lines] 275 | if args.subject_idx is not None: 276 | if args.subject_idx >= len(subject_list): 277 | print(f'subject_idx {args.subject_idx} is out of range!') 278 | else: 279 | subject_list = [subject_list[args.subject_idx]] 280 | 281 | for subjectpath in tqdm(subject_list): 282 | process_video(subjectpath, savepath=args.savepath, vis=args.vis, crop=args.crop, crop_each=args.crop_each, ignore_existing=args.ignore_existing, 283 | n_frames=args.n_frames) 284 | 285 | if __name__ == "__main__": 286 | parser = argparse.ArgumentParser(description='generate dataset from video or image folder') 287 | parser.add_argument('--list', default='lists/video_list.txt', type=str, 288 | help='path to a text file listing the subject data; each line can be an image folder or a video') 289 | parser.add_argument('--logpath', default='logs/generate_data.log', type=str, 290 | help='path to save log') 291 | parser.add_argument('--savepath', default=None, type=str, 292 | help='path to save processed data, if not specified, then save to the same folder as the subject data') 293 | parser.add_argument('--subject_idx', default=None, type=int, 294 | help='specify subject idx, if None (default), then use all the subject data in the list') 295 | parser.add_argument("--image_size", default=512, type=int, 296 | help='image size') 297 | parser.add_argument("--crop", default=True, action="store_true", 298 | help='whether to crop images according to the subject detection bbox') 299 | parser.add_argument("--crop_each", default=False, action="store_true", 300 | help='TODO: whether to crop each frame with its own detection bbox instead of one bbox for the whole video') 301 | parser.add_argument("--vis", default=True, action="store_true", 302 | help='whether to save visualizations of the generated labels') 303 | parser.add_argument("--ignore_existing", default=False, action="store_true", 304 | help='ignore existing processed data and regenerate it') 305 | parser.add_argument("--filter_data", default=False, action="store_true", 306 | help='check label quality and delete images with bad labels') 307 | parser.add_argument("--n_frames", default=400, type=int, 308 | help='number of frames to be processed') 309 | args = parser.parse_args() 310 | 311 | main(args) 312 | 313 | -------------------------------------------------------------------------------- /process_data/submodules/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import cv2 5 | import scipy 6 | from skimage.io import imread, imsave 7 | from skimage.transform import estimate_transform, warp, resize, rescale 8 | from glob import glob 9 | import os 10 | from tqdm import tqdm 11 | import shutil 12 | 13 | ''' 14 | For cropping body: 15 | 1. using bbox from object detectors 16 | 2.
calculate bbox from body joints regressor 17 | object detectors: 18 | identify the person class by its label number 19 | https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ 20 | label for person: 1 21 | COCO_INSTANCE_CATEGORY_NAMES = [ 22 | '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 23 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 24 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 25 | 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 26 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 27 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 28 | 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 29 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 30 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 31 | 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 32 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 33 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 34 | ] 35 | ''' 36 | import os 37 | import os.path as osp 38 | import sys 39 | import numpy as np 40 | import cv2 41 | 42 | import torch 43 | import torchvision.transforms as transforms 44 | # from PIL import Image 45 | 46 | class FasterRCNN(object): 47 | ''' detect body 48 | ''' 49 | def __init__(self, device='cuda:0'): 50 | ''' 51 | https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn 52 | ''' 53 | import torchvision 54 | self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) 55 | self.model.to(device) 56 | self.model.eval() 57 | self.device = device 58 | 59 | @torch.no_grad() 60 | def run(self, input): 61 | ''' 62 | input: 63 | The input to the model is expected to be a list of tensors, 64 | each of shape [C, H, W], one for each image, and should be in 0-1 range. 65 | Different images can have different sizes. 66 | return: 67 | detected box, [x1, y1, x2, y2] 68 | ''' 69 | prediction = self.model(input.to(self.device))[0] 70 | inds = (prediction['labels']==1)*(prediction['scores']>0.5) 71 | if len(inds) < 1 or inds.sum()<0.5: # no predictions, or no confident person detection 72 | return None 73 | else: 74 | # if inds.sum()<0.5: 75 | # inds = (prediction['labels']==1)*(prediction['scores']>0.05) 76 | bbox = prediction['boxes'][inds][0].cpu().numpy() 77 | return bbox 78 | 79 | @torch.no_grad() 80 | def run_multi(self, input): 81 | ''' 82 | input: 83 | The input to the model is expected to be a list of tensors, 84 | each of shape [C, H, W], one for each image, and should be in 0-1 range. 85 | Different images can have different sizes.
86 | return: 87 | detected boxes, one [x1, y1, x2, y2] row per detected person 88 | ''' 89 | prediction = self.model(input.to(self.device))[0] 90 | inds = (prediction['labels']==1)*(prediction['scores']>0.9) 91 | if inds.sum() < 1: # no confident person detections 92 | return None 93 | else: 94 | bbox = prediction['boxes'][inds].cpu().numpy() 95 | return bbox -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # to install all these requirements, run: 2 | # `pip install -r requirements.txt` 3 | torch 4 | torchvision 5 | 6 | scikit-image 7 | opencv-python 8 | chumpy 9 | matplotlib 10 | loguru 11 | tqdm 12 | yacs 13 | wandb 14 | PyMCubes 15 | trimesh 16 | # for pixie 17 | kornia==0.4.0 18 | PyYAML==5.1.1 19 | # -- evaluation 20 | lpips 21 | torchmetrics 22 | ipdb 23 | # -- pytorch3d, if installation fails, see its install page for more information 24 | # git+https://github.com/facebookresearch/pytorch3d.git 25 | # face-alignment 26 | # imageio 27 | # others 28 | # kornia 29 | # pytorch3d 30 | # trimesh 31 | # einops 32 | # pytorch_lightning 33 | # imageio-ffmpeg 34 | # PyMCubes 35 | # open3d 36 | # plyfile 37 | # 38 | # nerfacc 39 | # packaging 40 | # "git+https://github.com/facebookresearch/pytorch3d.git" 41 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # train nerf (stage 0) 2 | python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_0_nerf.yml 3 | 4 | # train hybrid (stage 1) 5 | python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_1_hybrid_perceptual.yml --clean 6 | ## Note: this config uses a perceptual loss instead of the MRF loss described in the paper. 7 | ## Reason: the MRF loss trains fine on V100 GPUs, but on A100 it produces NaN losses. 8 | ## If you want to use the MRF loss, make sure you are on a V100, then run: 9 | ## python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_1_hybrid.yml --clean 10 | 11 | #--- training male-3-casual example 12 | # python main_train.py --data_cfg configs/data/snapshot/male-3-casual.yml --exp_cfg configs/exp/stage_0_nerf.yml 13 | --------------------------------------------------------------------------------
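
Editor's note, not part of the repository: the FasterRCNN wrapper in process_data/submodules/detector.py is the detector that generate_image in process_data/process_video.py relies on. The minimal sketch below shows the expected input format and return value; the sys.path line, the 'cpu' device, and the random frame are assumptions made purely for illustration.

import sys
import numpy as np
import torch

sys.path.append('process_data/submodules')  # assumption: the sketch is run from the repo root
from detector import FasterRCNN

# 'cpu' keeps the sketch runnable without a GPU; the repo code defaults to 'cuda:0'.
detector = FasterRCNN(device='cpu')

# run() expects a float tensor in the 0-1 range with shape [1, C, H, W];
# a random image stands in for a real video frame here.
frame = np.random.rand(512, 512, 3).astype(np.float32)
image_tensor = torch.tensor(frame.transpose(2, 0, 1))[None, ...]

bbox = detector.run(image_tensor)  # [x1, y1, x2, y2] of the top-scoring person, or None
if bbox is None:
    print('no confident person detection in this frame')
else:
    left, top, right, bottom = bbox
    print(f'person bbox: ({left:.1f}, {top:.1f}) -> ({right:.1f}, {bottom:.1f})')

As the processing code above shows, generate_image calls this detector once, outside the per-frame loop, and reuses the resulting bbox and crop transform for the whole video; per-frame cropping is only planned via the --crop_each flag.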
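
Building on that bbox, the warping block inside generate_image is compact enough to misread, so here is a standalone sketch of the same arithmetic (also not part of the repository); the bbox numbers and frame size are hypothetical, and only numpy and scikit-image are needed.

import numpy as np
from skimage.transform import estimate_transform, warp

image_size = 512   # output crop resolution, as in process_video.py
scale_bbox = 1.1   # enlarge the detected bbox slightly

# hypothetical person bbox [x1, y1, x2, y2] from the detector
left, top, right, bottom = 100.0, 40.0, 380.0, 500.0

# square region centred on the bbox, enlarged by scale_bbox
old_size = max(right - left, bottom - top)
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
size = int(old_size * scale_bbox)

# three corresponding points are enough to fit the similarity transform
# from the source square to the image_size x image_size output frame
src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
                    [center[0] - size / 2, center[1] + size / 2],
                    [center[0] + size / 2, center[1] - size / 2]])
dst_pts = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
tform = estimate_transform('similarity', src_pts, dst_pts)

frame = np.random.rand(540, 960, 3)  # stand-in for a video frame with values in [0, 1]
cropped = warp(frame, tform.inverse, output_shape=(image_size, image_size))
print(cropped.shape)  # (512, 512, 3)

Because the fitted transform is a similarity, the crop keeps the subject's aspect ratio, and warp simply fills with zeros wherever the enlarged square extends beyond the frame.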