├── .gitignore
├── .gitmodules
├── Doc
│   └── images
│       ├── Teaser.png
│       ├── clothing_transfer.png
│       ├── data_teaser.png
│       ├── mesh.gif
│       └── teaser.gif
├── LICENSE
├── README.md
├── configs
│   ├── data
│   │   ├── mpiis
│   │   │   ├── DSC_7151.yml
│   │   │   └── DSC_7157.yml
│   │   └── snapshot
│   │       ├── female-3-casual.yml
│   │       ├── female-4-casual.yml
│   │       ├── male-3-casual.yml
│   │       └── male-4-casual.yml
│   └── exp
│       ├── stage_0_nerf.yml
│       ├── stage_1_hybrid.yml
│       └── stage_1_hybrid_perceptual.yml
├── environment.yml
├── fetch_data.sh
├── lib
│   ├── __init__.py
│   ├── datasets
│   │   ├── build_datasets.py
│   │   └── scarf_data.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── embedding.py
│   │   ├── lbs.py
│   │   ├── nerf.py
│   │   ├── scarf.py
│   │   ├── siren.py
│   │   └── smplx.py
│   ├── trainer.py
│   ├── utils
│   │   ├── camera_util.py
│   │   ├── config.py
│   │   ├── lossfunc.py
│   │   ├── metric.py
│   │   ├── perceptual_loss.py
│   │   ├── rasterize_rendering.py
│   │   ├── rotation_converter.py
│   │   ├── util.py
│   │   └── volumetric_rendering.py
│   └── visualizer.py
├── main_demo.py
├── main_train.py
├── process_data
│   ├── README.md
│   ├── fetch_asset_data.sh
│   ├── lists
│   │   └── video_list.txt
│   ├── logs
│   │   └── generate_data.log
│   ├── process_video.py
│   └── submodules
│       └── detector.py
├── requirements.txt
└── train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled source #
2 | ###################
3 | *.o
4 | *.so
5 |
6 | # Packages #
7 | ############
8 | # it's better to unpack these files and commit the raw source
9 | # git has its own built in compression methods
10 | *.7z
11 | *.dmg
12 | *.gz
13 | *.iso
14 | *.jar
15 | *.rar
16 | *.tar
17 | *.zip
18 |
19 | # OS generated files #
20 | ######################
21 | .DS_Store
22 | .DS_Store?
23 | ._*
24 | .Spotlight-V100
25 | .Trashes
26 | ehthumbs.db
27 | Thumbs.db
28 | .vscode
29 |
30 | # 3D data #
31 | ############
32 | *.mat
33 | *.pkl
34 | *.obj
35 | *.dat
36 | *.npz
37 |
38 | # python file #
39 | ############
40 | *.pyc
41 | __pycache__
42 |
43 | *results*
44 | # *_vis.jpg
45 | ## internal use
46 | cluster_scripts
47 | internal
48 | *TestSamples*
49 | data
50 | exps
51 | wandb
52 | process_data/assets
53 | debug*
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "process_data/submodules/cloth-segmentation"]
2 | path = process_data/submodules/cloth-segmentation
3 | url = https://github.com/levindabhi/cloth-segmentation.git
4 | [submodule "process_data/submodules/RobustVideoMatting"]
5 | path = process_data/submodules/RobustVideoMatting
6 | url = https://github.com/PeterL1n/RobustVideoMatting.git
7 | [submodule "process_data/submodules/PIXIE"]
8 | path = process_data/submodules/PIXIE
9 | url = https://github.com/yfeng95/PIXIE.git
10 |
--------------------------------------------------------------------------------
/Doc/images/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/Teaser.png
--------------------------------------------------------------------------------
/Doc/images/clothing_transfer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/clothing_transfer.png
--------------------------------------------------------------------------------
/Doc/images/data_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/data_teaser.png
--------------------------------------------------------------------------------
/Doc/images/mesh.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/mesh.gif
--------------------------------------------------------------------------------
/Doc/images/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/Doc/images/teaser.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | License
2 |
3 | Software Copyright License for non-commercial scientific research purposes
4 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SCARF model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License.
5 |
6 | Ownership / Licensees
7 | The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) hereinafter the “Licensor”.
8 |
9 | License Grant
10 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
11 |
12 | • To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
13 | • To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
14 | • To modify, adapt, translate or create derivative works based upon the Model & Software.
15 |
16 | Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
17 |
18 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
19 |
20 | No Distribution
21 | The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
22 |
23 | Disclaimer of Representations and Warranties
24 | You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
25 |
26 | Limitation of Liability
27 | Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
28 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
29 | Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
30 | The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
31 |
32 | No Maintenance Services
33 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
34 |
35 | Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
36 |
37 | Publications using the Model & Software
38 | You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
39 |
40 | Citation:
41 | @inproceedings{Feng2022scarf,
42 | author = {Feng, Yao and Yang, Jinlong and Pollefeys, Marc and Black, Michael J. and Bolkart, Timo},
43 | title = {Capturing and Animation of Body and Clothing from Monocular Video},
44 | year = {2022},
45 | booktitle = {SIGGRAPH Asia 2022 Conference Papers},
46 | articleno = {45},
47 | numpages = {9},
48 | location = {Daegu, Republic of Korea},
49 | series = {SA '22}
50 | }
51 |
52 | Commercial licensing opportunities
53 | For commercial uses of the Model & Software, please send email to ps-license@tue.mpg.de
54 |
55 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCARF: Capturing and Animation of Body and Clothing from Monocular Video
2 |
3 | (teaser image)
4 |
25 | This is the PyTorch implementation of SCARF. For more details, please check our [Project](https://yfeng95.github.io/scarf/) page.
26 |
27 | SCARF extracts a 3D clothed avatar from a monocular video.
28 | SCARF allows us to synthesize new views of the reconstructed avatar, and to animate the avatar with SMPL-X identity shape and pose control.
29 | The disentanglement of the body and clothing further enables us to transfer clothing between subjects for virtual try-on applications.
30 |
31 | The key features:
32 | 1. animate the avatar by changing body poses (including hand articulation and facial expressions),
33 | 2. synthesize novel views of the avatar, and
34 | 3. transfer clothing between avatars for virtual try-on applications.
35 |
36 |
37 | ## Getting Started
38 | Clone the repo:
39 | ```bash
40 | git clone https://github.com/yfeng95/SCARF
41 | cd SCARF
42 | ```
43 | ### Requirements
44 | ```bash
45 | conda create -n scarf python=3.9
46 | conda activate scarf
47 | pip install -r requirements.txt
48 | ```
49 | If you have problems when installing [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), please follow their instructions.
50 | ### Download data
51 | ```
52 | bash fetch_data.sh
53 | ```
54 |
55 | ## Visualization
56 | * check training frames:
57 | ```bash
58 | python main_demo.py --vis_type capture --frame_id 0
59 | ```
60 | * **novel view** synthesis of given frame id:
61 | ```bash
62 | python main_demo.py --vis_type novel_view --frame_id 0
63 | ```
64 | * extract **mesh** and visualize
65 | ```bash
66 | python main_demo.py --vis_type extract_mesh --frame_id 0
67 | ```
68 | You can go to our [project](https://yfeng95.github.io/scarf/) page and play with the extracted meshes.
69 |
70 |
71 |
72 |
73 | * **animation**
74 | ```bash
75 | python main_demo.py --vis_type animate
76 | ```
77 |
78 | * **clothing transfer**
79 | ```bash
80 | # apply clothing from other model
81 | python main_demo.py --vis_type novel_view --clothing_model_path exps/snapshot/male-3-casual
82 | # transfer clothing to new body
83 | python main_demo.py --vis_type novel_view --body_model_path exps/snapshot/male-3-casual
84 | ```
85 |
86 |
87 |
88 |
89 | More data and trained models can be found [here](https://nextcloud.tuebingen.mpg.de/index.php/s/3SEwJmZcfY5LnnN); download them and put them into `./exps`.
90 |
91 | ## Training
92 | * training with the SCARF video example
93 | ```bash
94 | bash train.sh
95 | ```
96 | * training with other videos
97 | Check [here](./process_data/README.md) to **prepare data with your own videos**, then change the `data_cfg` accordingly.
98 |
99 | ## TODO
100 | - [ ] add more processed data and trained models
101 | - [ ] code for refining the pose of trained models
102 | - [ ] integration with Instant-NGP
103 |
104 | ## Citation
105 | ```bibtex
106 | @inproceedings{Feng2022scarf,
107 | author = {Feng, Yao and Yang, Jinlong and Pollefeys, Marc and Black, Michael J. and Bolkart, Timo},
108 | title = {Capturing and Animation of Body and Clothing from Monocular Video},
109 | year = {2022},
110 | booktitle = {SIGGRAPH Asia 2022 Conference Papers},
111 | articleno = {45},
112 | numpages = {9},
113 | location = {Daegu, Republic of Korea},
114 | series = {SA '22}
115 | }
116 | ```
117 |
118 | ## Acknowledgments
119 | We thank [Sergey Prokudin](https://ps.is.mpg.de/people/sprokudin), [Weiyang Liu](https://wyliu.com), [Yuliang Xiu](https://xiuyuliang.cn/), [Songyou Peng](https://pengsongyou.github.io/), and [Qianli Ma](https://qianlim.github.io/) for fruitful discussions, and PS members for proofreading. We also thank Betty
120 | Mohler, Tsvetelina Alexiadis, Claudia Gallatz, and Andres Camilo Mendoza Patino for their support with data.
121 |
122 | Special thanks to [Boyi Jiang](https://scholar.google.com/citations?user=lTlZV8wAAAAJ&hl=zh-CN) and [Sida Peng](https://pengsida.net/) for sharing their data.
123 |
124 | Here are some great resources we benefit from:
125 | - [FasterRCNN](https://pytorch.org/vision/main/models/faster_rcnn.html) for detection
126 | - [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting) for background segmentation
127 | - [cloth-segmentation](https://github.com/levindabhi/cloth-segmentation) for clothing segmentation
128 | - [PIXIE](https://github.com/yfeng95/PIXIE) for SMPL-X parameter estimation
129 | - [smplx](https://github.com/vchoutas/smplx) for body models
130 | - [PyTorch3D](https://github.com/facebookresearch/pytorch3d) for differentiable rendering
131 |
132 | Some functions are based on other repositories; we acknowledge the origin individually in each file.
133 |
134 | ## License
135 |
136 | This code and model are available for non-commercial scientific research purposes as defined in the [LICENSE](LICENSE) file. By downloading and using the code and model you agree to the terms in the [LICENSE](LICENSE).
137 |
138 | ## Disclosure
139 | MJB has received research gift funds from Adobe, Intel, Nvidia, Meta/Facebook, and Amazon. MJB has financial interests in Amazon, Datagen Technologies, and Meshcapade GmbH. While MJB is a part-time employee of Meshcapade, his research was performed solely at, and funded solely by, the Max Planck Society.
140 | While TB is a part-time employee of Amazon, this research was performed
141 | solely at, and funded solely by, MPI.
142 |
143 | ## Contact
144 | For more questions, please contact yao.feng@tue.mpg.de
145 | For commercial licensing, please contact ps-licensing@tue.mpg.de
--------------------------------------------------------------------------------
/configs/data/mpiis/DSC_7151.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | type: 'scarf'
6 | path: 'exps/mpiis/DSC_7151'
7 | subjects: ['DSC_7151']
8 | train:
9 | frame_start: 0
10 | frame_end: 400
11 | frame_step: 2
12 | val:
13 | frame_start: 1
14 | frame_end: 400
15 | frame_step: 60
16 |
--------------------------------------------------------------------------------
/configs/data/mpiis/DSC_7157.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | type: 'scarf'
6 | path: 'exps/mpiis/DSC_7157'
7 | subjects: ['DSC_7157']
8 | train:
9 | frame_start: 0
10 | frame_end: 400
11 | frame_step: 2
12 | val:
13 | frame_start: 1
14 | frame_end: 400
15 | frame_step: 60
16 |
--------------------------------------------------------------------------------
/configs/data/snapshot/female-3-casual.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | subjects: ['female-3-casual']
6 | train:
7 | frame_start: 1
8 | frame_end: 446
9 | frame_step: 4
10 | val:
11 | frame_start: 47
12 | frame_end: 648
13 | frame_step: 50
14 | test:
15 | frame_start: 447
16 | frame_end: 648
17 | frame_step: 4
--------------------------------------------------------------------------------
/configs/data/snapshot/female-4-casual.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | subjects: ['female-4-casual']
6 | train:
7 | frame_start: 1
8 | frame_end: 336
9 | frame_step: 4
10 | val:
11 | frame_start: 336
12 | frame_end: 524
13 | frame_step: 50
14 | test:
15 | frame_start: 336
16 | frame_end: 524
17 | frame_step: 4
--------------------------------------------------------------------------------
/configs/data/snapshot/male-3-casual.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | type: 'scarf'
6 | path: '/home/yfeng/Data/Projects-data/DELTA/datasets/body'
7 | subjects: ['male-3-casual']
8 | train:
9 | frame_start: 1
10 | frame_end: 456
11 | frame_step: 4
12 | val:
13 | frame_start: 456
14 | frame_end: 676
15 | frame_step: 50
16 | test:
17 | frame_start: 457
18 | frame_end: 676
19 | frame_step: 4
--------------------------------------------------------------------------------
/configs/data/snapshot/male-4-casual.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | image_size: 512
3 | num_workers: 1
4 | white_bg: False
5 | subjects: ['male-4-casual']
6 | train:
7 | frame_start: 1
8 | frame_end: 660
9 | frame_step: 6
10 | val:
11 | frame_start: 661
12 | frame_end: 873
13 | frame_step: 50
14 | test:
15 | frame_start: 661
16 | frame_end: 873
17 | frame_step: 4
--------------------------------------------------------------------------------
/configs/exp/stage_0_nerf.yml:
--------------------------------------------------------------------------------
1 | # scarf model
2 | use_mesh: False
3 | use_nerf: True
4 | use_fine: True
5 | k_neigh: 6
6 | chunk: 4096
7 | opt_pose: True
8 | opt_cam: True
9 | use_highres_smplx: True
10 |
11 | # training
12 | train:
13 | batch_size: 1
14 | max_steps: 100000 # training longer for more details
15 | lr: 5e-4
16 | tex_lr: 5e-4
17 | geo_lr: 5e-4
18 | pose_lr: 5e-4
19 | precrop_iters: 400
20 | precrop_frac: 0.3
21 | log_steps: 20
22 | val_steps: 500
23 | checkpoint_steps: 1000
24 |
25 | dataset:
26 | image_size: 512
27 | num_workers: 1
28 | white_bg: False
29 | type: 'scarf'
30 |
31 | loss:
32 | w_rgb: 1.
33 | w_alpha: 0.1
34 | w_depth: 0.
35 | mesh_w_rgb: 1.
36 | mesh_w_mrf: 0. #0.0001
37 | skin_consistency_type: verts_all_mean
38 | mesh_skin_consistency: 0.001
39 | mesh_w_alpha: 0.05
40 | mesh_w_alpha_skin: 1.
41 | mesh_inside_mask: 1.
42 | reg_offset_w: 10.
43 | reg_lap_w: 2.0
44 | reg_edge_w: 1.
45 | reg_normal_w: 0.01
46 | nerf_reg_normal_w: 0.001
--------------------------------------------------------------------------------
/configs/exp/stage_1_hybrid.yml:
--------------------------------------------------------------------------------
1 | # scarf model
2 | use_fine: True
3 | use_mesh: True
4 | use_nerf: True
5 | lbs_map: True
6 | k_neigh: 6
7 | chunk: 4096
8 | opt_pose: True
9 | opt_cam: True
10 | tex_network: siren
11 | use_highres_smplx: True
12 | exclude_hand: False
13 | opt_mesh: True
14 | mesh_offset_scale: 0.04
15 |
16 | sample_patch_rays: True
17 | sample_patch_size: 48
18 |
19 | use_deformation: True
20 | deformation_dim: 3
21 | deformation_type: posed_verts
22 |
23 | # training
24 | train:
25 | batch_size: 1
26 | max_steps: 50000 ## normally, 10k is enough
27 | lr: 1e-4
28 | tex_lr: 1e-5
29 | geo_lr: 1e-5
30 | pose_lr: 1e-4
31 | # precrop_iters: 400
32 | precrop_frac: 0.3
33 | log_steps: 20
34 | val_steps: 500
35 | checkpoint_steps: 1000
36 |
37 | dataset:
38 | image_size: 512
39 | num_workers: 1
40 | white_bg: False
41 | type: 'scarf'
42 |
43 | loss:
44 | # nerf
45 | w_rgb: 1.
46 | w_patch_mrf: 0.0005
47 | w_alpha: 0.5
48 | w_depth: 0.
49 | # mesh
50 | mesh_w_rgb: 1.
51 | mesh_w_alpha: 0.001
52 | mesh_w_alpha_skin: 30.
53 | mesh_w_mrf: 0.0005
54 | reg_offset_w: 400.
55 | reg_offset_w_face: 200.
56 | reg_lap_w: 0. #130.0
57 | reg_edge_w: 500.0
58 | use_new_edge_loss: True
59 | skin_consistency_type: render_hand_mean
60 | mesh_skin_consistency: 0.01
61 | mesh_inside_mask: 40.
62 | nerf_reg_dxyz_w: 2.
63 |
64 |
--------------------------------------------------------------------------------
/configs/exp/stage_1_hybrid_perceptual.yml:
--------------------------------------------------------------------------------
1 | # scarf model
2 | use_fine: True
3 | use_mesh: True
4 | use_nerf: True
5 | lbs_map: True
6 | k_neigh: 6
7 | chunk: 4096
8 | opt_pose: True
9 | opt_cam: True
10 | tex_network: siren
11 | use_highres_smplx: True
12 | exclude_hand: False
13 | opt_mesh: True
14 | mesh_offset_scale: 0.04
15 |
16 | sample_patch_rays: True
17 | sample_patch_size: 48
18 |
19 | use_deformation: True
20 | deformation_dim: 3
21 | deformation_type: posed_verts
22 |
23 | # training
24 | train:
25 | batch_size: 1
26 | max_steps: 50000
27 | lr: 1e-4
28 | tex_lr: 1e-5
29 | geo_lr: 1e-5
30 | pose_lr: 1e-4
31 | # precrop_iters: 400
32 | precrop_frac: 0.3
33 | log_steps: 20
34 | val_steps: 500
35 | checkpoint_steps: 1000
36 |
37 | dataset:
38 | image_size: 512
39 | num_workers: 1
40 | white_bg: False
41 | type: 'scarf'
42 |
43 | loss:
44 | # nerf
45 | w_rgb: 1.
46 | w_patch_mrf: 0. #0005
47 | w_patch_perceptual: 0.04
48 | w_alpha: 0.5
49 | w_depth: 0.
50 | # mesh
51 | mesh_w_rgb: 1.
52 | mesh_w_alpha: 0.001
53 | mesh_w_alpha_skin: 30.
54 | mesh_w_mrf: 0. #0005
55 | mesh_w_perceptual: 0.04
56 | reg_offset_w: 400.
57 | reg_offset_w_face: 200.
58 | reg_lap_w: 0. #130.0
59 | reg_edge_w: 500.0
60 | use_new_edge_loss: True
61 | skin_consistency_type: render_hand_mean
62 | mesh_skin_consistency: 0.01
63 | mesh_inside_mask: 40.
64 | nerf_reg_dxyz_w: 2.
65 |
66 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: scarf
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2022.10.11=h06a4308_0
8 | - certifi=2022.9.24=py39h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.4.2=h6a678d5_6
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - ncurses=6.3=h5eee18b_3
15 | - openssl=1.1.1s=h7f8727e_0
16 | - pip=22.3.1=py39h06a4308_0
17 | - python=3.9.15=h7a1cb2a_2
18 | - readline=8.2=h5eee18b_0
19 | - setuptools=65.5.0=py39h06a4308_0
20 | - sqlite=3.40.0=h5082296_0
21 | - tk=8.6.12=h1ccaba5_0
22 | - tzdata=2022g=h04d1e81_0
23 | - wheel=0.37.1=pyhd3eb1b0_0
24 | - xz=5.2.8=h5eee18b_0
25 | - zlib=1.2.13=h5eee18b_0
26 | - pip:
27 | - asttokens==2.2.1
28 | - backcall==0.2.0
29 | - charset-normalizer==2.1.1
30 | - chumpy==0.70
31 | - click==8.1.3
32 | - contourpy==1.0.6
33 | - cycler==0.11.0
34 | - decorator==5.1.1
35 | - docker-pycreds==0.4.0
36 | - executing==1.2.0
37 | - fonttools==4.38.0
38 | - fvcore==0.1.5.post20221122
39 | - gitdb==4.0.10
40 | - gitpython==3.1.29
41 | - idna==3.4
42 | - imageio==2.22.4
43 | - iopath==0.1.10
44 | - ipdb==0.13.9
45 | - ipython==8.7.0
46 | - jedi==0.18.2
47 | - kiwisolver==1.4.4
48 | - kornia==0.6.0
49 | - loguru==0.6.0
50 | - lpips==0.1.4
51 | - matplotlib==3.6.2
52 | - matplotlib-inline==0.1.6
53 | - networkx==2.8.8
54 | - numpy==1.23.5
55 | - nvidia-cublas-cu11==11.10.3.66
56 | - nvidia-cuda-nvrtc-cu11==11.7.99
57 | - nvidia-cuda-runtime-cu11==11.7.99
58 | - nvidia-cudnn-cu11==8.5.0.96
59 | - opencv-python==4.6.0.66
60 | - packaging==22.0
61 | - parso==0.8.3
62 | - pathtools==0.1.2
63 | - pexpect==4.8.0
64 | - pickleshare==0.7.5
65 | - pillow==9.3.0
66 | - portalocker==2.6.0
67 | - promise==2.3
68 | - prompt-toolkit==3.0.36
69 | - protobuf==4.21.11
70 | - psutil==5.9.4
71 | - ptyprocess==0.7.0
72 | - pure-eval==0.2.2
73 | - pygments==2.13.0
74 | - pymcubes==0.1.2
75 | - pyparsing==3.0.9
76 | - python-dateutil==2.8.2
77 | - pywavelets==1.4.1
78 | - pyyaml==5.1.1
79 | - requests==2.28.1
80 | - scikit-image==0.19.3
81 | - scipy==1.9.3
82 | - sentry-sdk==1.11.1
83 | - setproctitle==1.3.2
84 | - shortuuid==1.0.11
85 | - six==1.16.0
86 | - smmap==5.0.0
87 | - stack-data==0.6.2
88 | - tabulate==0.9.0
89 | - termcolor==2.1.1
90 | - tifffile==2022.10.10
91 | - toml==0.10.2
92 | - torch==1.13.0
93 | - torchmetrics==0.11.0
94 | - torchvision==0.14.0
95 | - tqdm==4.64.1
96 | - traitlets==5.7.0
97 | - trimesh==3.17.1
98 | - typing-extensions==4.4.0
99 | - urllib3==1.26.13
100 | - vulture==2.6
101 | - wandb==0.13.6
102 | - wcwidth==0.2.5
103 | - yacs==0.1.8
--------------------------------------------------------------------------------
/fetch_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p ./data
3 |
4 | # SMPL-X 2020 (neutral SMPL-X model with the FLAME 2020 expression blendshapes)
5 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
6 | echo -e "\nYou need to register at https://smpl-x.is.tue.mpg.de"
7 | read -p "Username (SMPL-X):" username
8 | read -p "Password (SMPL-X):" password
9 | username=$(urle $username)
10 | password=$(urle $password)
11 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=SMPLX_NEUTRAL_2020.npz&resume=1' -O './data/SMPLX_NEUTRAL_2020.npz' --no-check-certificate --continue
12 |
13 | # scarf utilities
14 | echo -e "\nDownloading SCARF data..."
15 | wget https://owncloud.tuebingen.mpg.de/index.php/s/n58Fzbzz7Ei9x2W/download -O ./data/scarf_utilities.zip
16 | unzip ./data/scarf_utilities.zip -d ./data
17 | rm ./data/scarf_utilities.zip
18 |
19 | # download two examples
20 | echo -e "\nDownloading SCARF training data and trained avatars..."
21 | wget https://owncloud.tuebingen.mpg.de/index.php/s/geTtN4p5YTJaqPi/download -O scarf-exp-data-small.zip
22 | unzip ./scarf-exp-data-small.zip -d .
23 | rm ./scarf-exp-data-small.zip
24 | mv ./scarf-exp-data-small ./exps
25 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/lib/__init__.py
--------------------------------------------------------------------------------
/lib/datasets/build_datasets.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | def build_train(cfg, mode='train'):
4 | if cfg.type == 'scarf':
5 | from .scarf_data import NerfDataset
6 | return NerfDataset(cfg, mode=mode)
7 |
--------------------------------------------------------------------------------
/lib/datasets/scarf_data.py:
--------------------------------------------------------------------------------
1 | from skimage.transform import rescale, resize, downscale_local_mean
2 | from skimage.io import imread
3 | import cv2
4 | import pickle
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | import os
9 | from glob import glob
10 | from ..utils import rotation_converter
11 |
12 | class NerfDataset(torch.utils.data.Dataset):
13 | """SCARF Dataset"""
14 |
15 | def __init__(self, cfg, mode='train'):
16 | super().__init__()
17 | subject = cfg.subjects[0]
18 | imagepath_list = []
19 | self.dataset_path = os.path.join(cfg.path, subject)
20 |         if not os.path.exists(self.dataset_path):
21 |             print(f'{self.dataset_path} does not exist, please check the data path')
22 |             exit()
23 | imagepath_list = glob(os.path.join(self.dataset_path, 'image', f'{subject}_*.png'))
24 | root_dir = os.path.join(self.dataset_path, 'cache')
25 | os.makedirs(root_dir, exist_ok=True)
26 | self.pose_cache_path = os.path.join(root_dir, 'pose.pt')
27 | self.cam_cache_path = os.path.join(root_dir, 'cam.pt')
28 | self.exp_cache_path = os.path.join(root_dir, 'exp.pt')
29 | self.beta_cache_path = os.path.join(root_dir, 'beta.pt')
30 | self.subject_id = subject
31 |
32 | imagepath_list = sorted(imagepath_list)
33 | frame_start = getattr(cfg, mode).frame_start
34 | frame_end = getattr(cfg, mode).frame_end
35 | frame_step = getattr(cfg, mode).frame_step
36 | imagepath_list = imagepath_list[frame_start:min(len(imagepath_list), frame_end):frame_step]
37 |
38 | self.data = imagepath_list
39 | if cfg.n_images < 10:
40 | self.data = self.data[:cfg.n_images]
41 | assert len(self.data) > 0, f"Can't find data; make sure data path {self.dataset_path} is correct"
42 |
43 | self.image_size = cfg.image_size
44 | self.white_bg = cfg.white_bg
45 |
46 | def __len__(self):
47 | return len(self.data)
48 |
49 | def __getitem__(self, index):
50 | # load image
51 | imagepath = self.data[index]
52 | image = imread(imagepath) / 255.
53 | imagename = imagepath.split('/')[-1].split('.')[0]
54 | image = image[:, :, :3]
55 | frame_id = int(imagename.split('_f')[-1])
56 | frame_id = f'{frame_id:06d}'
57 |
58 | # load mask
59 | maskpath = os.path.join(self.dataset_path, 'matting', f'{imagename}.png')
60 | alpha_image = imread(maskpath) / 255.
61 | alpha_image = (alpha_image > 0.5).astype(np.float32)
62 | alpha_image = alpha_image[:, :, -1:]
63 | if self.white_bg:
64 | image = image[..., :3] * alpha_image + (1. - alpha_image)
65 | else:
66 | image = image[..., :3] * alpha_image
67 | # add alpha channel
68 | image = np.concatenate([image, alpha_image[:, :, :1]], axis=-1)
69 | image = resize(image, [self.image_size, self.image_size])
70 | image = torch.from_numpy(image.transpose(2, 0, 1)).float()
71 | mask = image[3:]
72 | image = image[:3]
73 |
74 | # load camera and pose
75 | frame_id = int(imagename.split('_f')[-1])
76 | name = self.subject_id
77 |
78 | # load pickle
79 | pkl_file = os.path.join(self.dataset_path, 'pixie', f'{imagename}_param.pkl')
80 | with open(pkl_file, 'rb') as f:
81 | codedict = pickle.load(f)
82 | param_dict = {}
83 | for key in codedict.keys():
84 | if isinstance(codedict[key], str):
85 | param_dict[key] = codedict[key]
86 | else:
87 | param_dict[key] = torch.from_numpy(codedict[key])
88 | beta = param_dict['shape'].squeeze()[:10]
89 | # full_pose = param_dict['full_pose'].squeeze()
90 | jaw_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0) #param_dict['jaw_pose']
91 | eye_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0).repeat(2,1,1)
92 | # hand_pose = torch.eye(3, dtype=torch.float32).unsqueeze(0).repeat(15,1,1)
93 | full_pose = torch.cat([param_dict['global_pose'], param_dict['body_pose'],
94 | jaw_pose, eye_pose,
95 | # hand_pose, hand_pose], dim=0)
96 | param_dict['left_hand_pose'], param_dict['right_hand_pose']], dim=0)
97 | cam = param_dict['body_cam'].squeeze()
98 | exp = torch.zeros_like(param_dict['exp'].squeeze()[:10])
99 | frame_id = f'{frame_id:06}'
100 | data = {
101 | 'idx': index,
102 | 'frame_id': frame_id,
103 | 'name': name,
104 | 'imagepath': imagepath,
105 | 'image': image,
106 | 'mask': mask,
107 | 'full_pose': full_pose,
108 | 'cam': cam,
109 | 'beta': beta,
110 | 'exp': exp
111 | }
112 |
113 | seg_image_path = os.path.join(self.dataset_path, 'cloth_segmentation', f"{imagename}.png")
114 | cloth_seg = imread(seg_image_path)/255.
115 | cloth_seg = resize(cloth_seg, [self.image_size, self.image_size])
116 | cloth_mask = torch.from_numpy(cloth_seg[:,:,:3].sum(-1))[None,...]
117 | cloth_mask = (cloth_mask > 0.1).float()
118 | cloth_mask = ((mask + cloth_mask) > 1.5).float()
119 | skin_mask = ((mask - cloth_mask) > 0).float()
120 | data['cloth_mask'] = cloth_mask
121 | data['skin_mask'] = skin_mask
122 |
123 | return data
--------------------------------------------------------------------------------
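The dataset above returns one dictionary per frame (image, alpha mask, SMPL-X pose/shape/camera parameters from PIXIE, and clothing/skin masks). As a minimal sketch, not part of the repo, here is how it might be wrapped in a standard PyTorch `DataLoader`; the config values mirror `configs/data/mpiis/DSC_7151.yml`, while the `n_images` value and the assumption that the data fetched by `fetch_data.sh` already sits under `exps/` are placeholders to adjust for your setup.

```python
from torch.utils.data import DataLoader
from yacs.config import CfgNode as CN
from lib.datasets.build_datasets import build_train

# dataset config mirroring configs/data/mpiis/DSC_7151.yml (values are assumptions)
cfg = CN()
cfg.type = 'scarf'
cfg.path = 'exps/mpiis/DSC_7151'
cfg.subjects = ['DSC_7151']
cfg.image_size = 512
cfg.white_bg = False
cfg.num_workers = 1
cfg.n_images = 10000  # >= 10 keeps all frames (see NerfDataset.__init__)
cfg.train = CN(dict(frame_start=0, frame_end=400, frame_step=2))

dataset = build_train(cfg, mode='train')   # dispatches to NerfDataset for type 'scarf'
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=cfg.num_workers)
batch = next(iter(loader))
# keys built in __getitem__: 'image' (B,3,H,W), 'mask' (B,1,H,W), 'full_pose'
# (B,J,3,3) rotation matrices, 'cam', 'beta', 'exp', 'cloth_mask', 'skin_mask', ...
print(batch['image'].shape, batch['mask'].shape, batch['full_pose'].shape)
```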
/lib/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/lib/models/__init__.py
--------------------------------------------------------------------------------
/lib/models/embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class Embedding(nn.Module):
5 | def __init__(self, in_channels, N_freqs, logscale=True):
6 | """
7 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
8 | in_channels: number of input channels (3 for both xyz and direction)
9 | """
10 | super(Embedding, self).__init__()
11 | self.N_freqs = N_freqs
12 | self.in_channels = in_channels
13 | self.funcs = [torch.sin, torch.cos]
14 | self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
15 |
16 | if logscale:
17 | self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
18 | else:
19 | self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
20 |
21 | def forward(self, x):
22 | """
23 | Embeds x to (x, sin(2^k x), cos(2^k x), ...)
24 | Different from the paper, "x" is also in the output
25 | See https://github.com/bmild/nerf/issues/12
26 |
27 | Inputs:
28 | x: (B, self.in_channels)
29 |
30 | Outputs:
31 | out: (B, self.out_channels)
32 | """
33 | out = [x]
34 | for freq in self.freq_bands:
35 | for func in self.funcs:
36 | out += [func(freq*x)]
37 |
38 | return torch.cat(out, -1)
39 |
40 |
--------------------------------------------------------------------------------
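For reference, a tiny usage sketch (not part of the repo) of the positional encoding above: with `N_freqs=10`, a batch of 3-D points is lifted to 3 + 3*2*10 = 63 channels, which matches the default `in_channels_xyz` of the NeRF MLP in `lib/models/nerf.py`.

```python
import torch
from lib.models.embedding import Embedding

embed_xyz = Embedding(in_channels=3, N_freqs=10)   # (x, sin(2^k x), cos(2^k x), ...)
xyz = torch.rand(4096, 3)                          # a chunk of sampled points
encoded = embed_xyz(xyz)
print(embed_xyz.out_channels)                      # 63
print(encoded.shape)                               # torch.Size([4096, 63])
```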
/lib/models/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import absolute_import
18 | from __future__ import print_function
19 | from __future__ import division
20 |
21 | import numpy as np
22 | import os
23 | import yaml
24 | import torch
25 | import torch.nn.functional as F
26 | from torch import nn
27 |
28 | def rot_mat_to_euler(rot_mats):
29 | # Calculates rotation matrix to euler angles
30 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0]
31 |
32 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
33 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
34 | return torch.atan2(-rot_mats[:, 2, 0], sy)
35 |
36 |
37 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
38 | dynamic_lmk_b_coords,
39 | head_kin_chain, dtype=torch.float32):
40 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks
41 |
42 |
43 | To do so, we first compute the rotation of the neck around the y-axis
44 | and then use a pre-computed look-up table to find the faces and the
45 | barycentric coordinates that will be used.
46 |
47 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
48 | for providing the original TensorFlow implementation and for the LUT.
49 |
50 | Parameters
51 | ----------
52 | vertices: torch.tensor BxVx3, dtype = torch.float32
53 | The tensor of input vertices
54 | pose: torch.tensor Bx(Jx3), dtype = torch.float32
55 | The current pose of the body model
56 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
57 | The look-up table from neck rotation to faces
58 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
59 | The look-up table from neck rotation to barycentric coordinates
60 | head_kin_chain: list
61 | A python list that contains the indices of the joints that form the
62 | kinematic chain of the neck.
63 | dtype: torch.dtype, optional
64 |
65 | Returns
66 | -------
67 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
68 | A tensor of size BxL that contains the indices of the faces that
69 | will be used to compute the current dynamic landmarks.
70 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
71 | A tensor of size BxL that contains the indices of the faces that
72 | will be used to compute the current dynamic landmarks.
73 | '''
74 |
75 | batch_size = vertices.shape[0]
76 | pose = pose.detach()
77 | # aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
78 | # head_kin_chain)
79 | # rot_mats = batch_rodrigues(
80 | # aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
81 | rot_mats = torch.index_select(pose, 1, head_kin_chain)
82 |
83 | rel_rot_mat = torch.eye(3, device=vertices.device,
84 | dtype=dtype).unsqueeze_(dim=0)
85 | for idx in range(len(head_kin_chain)):
86 | # rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
87 | rel_rot_mat = torch.matmul(rot_mats[:, idx], rel_rot_mat)
88 |
89 | y_rot_angle = torch.round(
90 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
91 | max=39)).to(dtype=torch.long)
92 | # print(y_rot_angle[0])
93 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
94 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
95 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
96 | y_rot_angle = (neg_mask * neg_vals +
97 | (1 - neg_mask) * y_rot_angle)
98 | # print(y_rot_angle[0])
99 |
100 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
101 | 0, y_rot_angle)
102 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
103 | 0, y_rot_angle)
104 |
105 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
106 |
107 |
108 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
109 | ''' Calculates landmarks by barycentric interpolation
110 |
111 | Parameters
112 | ----------
113 | vertices: torch.tensor BxVx3, dtype = torch.float32
114 | The tensor of input vertices
115 | faces: torch.tensor Fx3, dtype = torch.long
116 | The faces of the mesh
117 | lmk_faces_idx: torch.tensor L, dtype = torch.long
118 | The tensor with the indices of the faces used to calculate the
119 | landmarks.
120 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
121 | The tensor of barycentric coordinates that are used to interpolate
122 | the landmarks
123 |
124 | Returns
125 | -------
126 | landmarks: torch.tensor BxLx3, dtype = torch.float32
127 | The coordinates of the landmarks for each mesh in the batch
128 | '''
129 | # Extract the indices of the vertices for each face
130 | # BxLx3
131 | batch_size, num_verts = vertices.shape[:2]
132 | device = vertices.device
133 |
134 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
135 | batch_size, -1, 3)
136 |
137 | lmk_faces += torch.arange(
138 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
139 |
140 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
141 | batch_size, -1, 3, 3)
142 |
143 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
144 | return landmarks
145 |
146 |
147 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
148 | lbs_weights, pose2rot=True, dtype=torch.float32):
149 | ''' Performs Linear Blend Skinning with the given shape and pose parameters
150 |
151 | Parameters
152 | ----------
153 | betas : torch.tensor BxNB
154 | The tensor of shape parameters
155 | pose : torch.tensor Bx(J + 1) * 3
156 | The pose parameters in axis-angle format
157 | v_template torch.tensor BxVx3
158 | The template mesh that will be deformed
159 | shapedirs : torch.tensor 1xNB
160 | The tensor of PCA shape displacements
161 | posedirs : torch.tensor Px(V * 3)
162 | The pose PCA coefficients
163 | J_regressor : torch.tensor JxV
164 | The regressor array that is used to calculate the joints from
165 | the position of the vertices
166 | parents: torch.tensor J
167 | The array that describes the kinematic tree for the model
168 | lbs_weights: torch.tensor N x V x (J + 1)
169 | The linear blend skinning weights that represent how much the
170 | rotation matrix of each part affects each vertex
171 | pose2rot: bool, optional
172 | Flag on whether to convert the input pose tensor to rotation
173 | matrices. The default value is True. If False, then the pose tensor
174 | should already contain rotation matrices and have a size of
175 | Bx(J + 1)x9
176 | dtype: torch.dtype, optional
177 |
178 | Returns
179 | -------
180 | verts: torch.tensor BxVx3
181 | The vertices of the mesh after applying the shape and pose
182 | displacements.
183 | joints: torch.tensor BxJx3
184 | The joints of the model
185 | '''
186 |
187 | batch_size = max(betas.shape[0], pose.shape[0])
188 | device = betas.device
189 |
190 | # Add shape contribution
191 | shape_offsets = blend_shapes(betas, shapedirs)
192 | v_shaped = v_template + shape_offsets
193 |
194 | # Get the joints
195 | # NxJx3 array
196 | J = vertices2joints(J_regressor, v_shaped)
197 |
198 | # 3. Add pose blend shapes
199 | # N x J x 3 x 3
200 | ident = torch.eye(3, dtype=dtype, device=device)
201 | if pose2rot:
202 | rot_mats = batch_rodrigues(
203 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
204 |
205 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
206 | # (N x P) x (P, V * 3) -> N x V x 3
207 | pose_offsets = torch.matmul(pose_feature, posedirs) \
208 | .view(batch_size, -1, 3)
209 | else:
210 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
211 | rot_mats = pose.view(batch_size, -1, 3, 3)
212 |
213 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
214 | posedirs).view(batch_size, -1, 3)
215 |
216 | v_posed = pose_offsets + v_shaped
217 | # 4. Get the global joint location
218 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
219 |
220 | # 5. Do skinning:
221 | # W is N x V x (J + 1)
222 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
223 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
224 | num_joints = J_regressor.shape[0]
225 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
226 | .view(batch_size, -1, 4, 4)
227 |
228 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
229 | dtype=dtype, device=device)
230 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
231 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
232 |
233 | verts = v_homo[:, :, :3, 0]
234 |
235 | return verts, J_transformed, A, T, shape_offsets, pose_offsets
236 |
237 |
238 |
239 | def vertices2joints(J_regressor, vertices):
240 | ''' Calculates the 3D joint locations from the vertices
241 |
242 | Parameters
243 | ----------
244 | J_regressor : torch.tensor JxV
245 | The regressor array that is used to calculate the joints from the
246 | position of the vertices
247 | vertices : torch.tensor BxVx3
248 | The tensor of mesh vertices
249 |
250 | Returns
251 | -------
252 | torch.tensor BxJx3
253 | The location of the joints
254 | '''
255 |
256 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
257 |
258 |
259 | def blend_shapes(betas, shape_disps):
260 | ''' Calculates the per vertex displacement due to the blend shapes
261 |
262 |
263 | Parameters
264 | ----------
265 | betas : torch.tensor Bx(num_betas)
266 | Blend shape coefficients
267 | shape_disps: torch.tensor Vx3x(num_betas)
268 | Blend shapes
269 |
270 | Returns
271 | -------
272 | torch.tensor BxVx3
273 | The per-vertex displacement due to shape deformation
274 | '''
275 |
276 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
277 | # i.e. Multiply each shape displacement by its corresponding beta and
278 | # then sum them.
279 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
280 | return blend_shape
281 |
282 |
283 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
284 | ''' Calculates the rotation matrices for a batch of rotation vectors
285 | Parameters
286 | ----------
287 | rot_vecs: torch.tensor Nx3
288 | array of N axis-angle vectors
289 | Returns
290 | -------
291 | R: torch.tensor Nx3x3
292 | The rotation matrices for the given axis-angle parameters
293 | '''
294 |
295 | batch_size = rot_vecs.shape[0]
296 | device = rot_vecs.device
297 |
298 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
299 | rot_dir = rot_vecs / angle
300 |
301 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
302 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
303 |
304 | # Bx1 arrays
305 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
306 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
307 |
308 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
309 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
310 | .view((batch_size, 3, 3))
311 |
312 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
313 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
314 | return rot_mat
315 |
316 |
317 | def transform_mat(R, t):
318 | ''' Creates a batch of transformation matrices
319 | Args:
320 | - R: Bx3x3 array of a batch of rotation matrices
321 | - t: Bx3x1 array of a batch of translation vectors
322 | Returns:
323 | - T: Bx4x4 Transformation matrix
324 | '''
325 | # No padding left or right, only add an extra row
326 | return torch.cat([F.pad(R, [0, 0, 0, 1]),
327 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
328 |
329 |
330 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
331 | """
332 | Applies a batch of rigid transformations to the joints
333 |
334 | Parameters
335 | ----------
336 | rot_mats : torch.tensor BxNx3x3
337 | Tensor of rotation matrices
338 | joints : torch.tensor BxNx3
339 | Locations of joints
340 | parents : torch.tensor BxN
341 | The kinematic tree of each object
342 | dtype : torch.dtype, optional:
343 | The data type of the created tensors, the default is torch.float32
344 |
345 | Returns
346 | -------
347 | posed_joints : torch.tensor BxNx3
348 | The locations of the joints after applying the pose rotations
349 | rel_transforms : torch.tensor BxNx4x4
350 | The relative (with respect to the root joint) rigid transformations
351 | for all the joints
352 | """
353 |
354 | joints = torch.unsqueeze(joints, dim=-1)
355 |
356 | rel_joints = joints.clone()
357 | rel_joints[:, 1:] -= joints[:, parents[1:]]
358 |
359 | transforms_mat = transform_mat(
360 | rot_mats.reshape(-1, 3, 3),
361 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
362 |
363 | transform_chain = [transforms_mat[:, 0]]
364 | for i in range(1, parents.shape[0]):
365 | # Subtract the joint location at the rest pose
366 | # No need for rotation, since it's identity when at rest
367 | curr_res = torch.matmul(transform_chain[parents[i]],
368 | transforms_mat[:, i])
369 | transform_chain.append(curr_res)
370 |
371 | transforms = torch.stack(transform_chain, dim=1)
372 |
373 | # The last column of the transformations contains the posed joints
374 | posed_joints = transforms[:, :, :3, 3]
375 |
376 | # # The last column of the transformations contains the posed joints
377 | # posed_joints = transforms[:, :, :3, 3]
378 |
379 | joints_homogen = F.pad(joints, [0, 0, 0, 1])
380 |
381 | rel_transforms = transforms - F.pad(
382 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
383 |
384 | return posed_joints, rel_transforms
385 |
386 | class JointsFromVerticesSelector(nn.Module):
387 | def __init__(self, fname):
388 | ''' Selects extra joints from vertices
389 | '''
390 | super(JointsFromVerticesSelector, self).__init__()
391 |
392 |         err_msg = (
393 |             'A filename with the triangle face ids, names and'
394 |             ' barycentric coordinates must be provided')
395 |         assert fname is not None, err_msg
398 | if fname is not None:
399 | fname = os.path.expanduser(os.path.expandvars(fname))
400 | with open(fname, 'r') as f:
401 | data = yaml.load(f)
402 | names = list(data.keys())
403 | bcs = []
404 | face_ids = []
405 | for name, d in data.items():
406 | face_ids.append(d['face'])
407 | bcs.append(d['bc'])
408 | bcs = np.array(bcs, dtype=np.float32)
409 | face_ids = np.array(face_ids, dtype=np.int32)
410 | assert len(bcs) == len(face_ids), (
411 | 'The number of barycentric coordinates must be equal to the faces'
412 | )
413 |         assert len(names) == len(face_ids), (
414 |             'The number of names must be equal to the number of face ids'
415 |         )
416 |
417 | self.names = names
418 | self.register_buffer('bcs', torch.tensor(bcs, dtype=torch.float32))
419 | self.register_buffer(
420 | 'face_ids', torch.tensor(face_ids, dtype=torch.long))
421 |
422 | def extra_joint_names(self):
423 | ''' Returns the names of the extra joints
424 | '''
425 | return self.names
426 |
427 | def forward(self, vertices, faces):
428 | if len(self.face_ids) < 1:
429 | return []
430 | vertex_ids = faces[self.face_ids].reshape(-1)
431 | # Should be BxNx3x3
432 | triangles = torch.index_select(vertices, 1, vertex_ids).reshape(
433 | -1, len(self.bcs), 3, 3)
434 | return (triangles * self.bcs[None, :, :, None]).sum(dim=2)
435 |
436 | # def to_tensor(array, dtype=torch.float32):
437 | # if torch.is_tensor(array):
438 | # return array
439 | # else:
440 | # return torch.tensor(array, dtype=dtype)
441 |
442 | def to_tensor(array, dtype=torch.float32):
443 |     # return tensors unchanged, convert everything else (e.g. numpy arrays)
444 |     return array if torch.is_tensor(array) else torch.tensor(array, dtype=dtype)
445 |
446 | def to_np(array, dtype=np.float32):
447 | if 'scipy.sparse' in str(type(array)):
448 | array = array.todense()
449 | return np.array(array, dtype=dtype)
450 |
451 | class Struct(object):
452 | def __init__(self, **kwargs):
453 | for key, val in kwargs.items():
454 | setattr(self, key, val)
455 |
--------------------------------------------------------------------------------
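A short sanity-check sketch (not part of the repo) for the helpers above, using tiny random tensors instead of a real SMPL-X model: zero axis-angle vectors map to identity rotations under `batch_rodrigues`, `blend_shapes` and `vertices2joints` are plain einsum reductions, and `transform_mat` assembles homogeneous 4x4 transforms.

```python
import torch
from lib.models.lbs import batch_rodrigues, blend_shapes, vertices2joints, transform_mat

B, V, J, NB = 2, 100, 5, 10                       # batch, vertices, joints, betas
aa = torch.zeros(B * J, 3)                        # zero axis-angle -> identity rotation
R = batch_rodrigues(aa).view(B, J, 3, 3)
print(torch.allclose(R, torch.eye(3).expand(B, J, 3, 3)))   # True

betas = torch.randn(B, NB)
shapedirs = torch.randn(V, 3, NB)
offsets = blend_shapes(betas, shapedirs)          # (B, V, 3) per-vertex displacements

J_regressor = torch.rand(J, V)
verts = torch.randn(B, V, 3)
joints = vertices2joints(J_regressor, verts)      # (B, J, 3) regressed joints

T = transform_mat(R.view(-1, 3, 3), torch.zeros(B * J, 3, 1))  # (B*J, 4, 4)
print(offsets.shape, joints.shape, T.shape)
```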
/lib/models/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .embedding import Embedding
6 |
7 | class DeRF(nn.Module):
8 | def __init__(self,
9 | D=6, W=128,
10 | freqs_xyz=10,
11 | deformation_dim=0,
12 | out_channels=3,
13 | skips=[4]):
14 | """
15 | D: number of layers for density (sigma) encoder
16 | W: number of hidden units in each layer
17 | in_channels: number of input channels for xyz (3+3*10*2=63 by default)
18 | skips: add skip connection in the Dth layer
19 | """
20 | super(DeRF, self).__init__()
21 | self.D = D
22 | self.W = W
23 | self.freqs_xyz = freqs_xyz
24 | self.deformation_dim = deformation_dim
25 | self.skips = skips
26 |
27 | self.in_channels = 3 + 3*freqs_xyz*2 + deformation_dim
28 | self.out_channels = out_channels
29 |
30 | self.encoding_xyz = Embedding(3, freqs_xyz)
31 |
32 | # xyz encoding layers
33 | for i in range(D):
34 | if i == 0:
35 | layer = nn.Linear(self.in_channels, W)
36 | elif i in skips:
37 | layer = nn.Linear(W+self.in_channels, W)
38 | else:
39 | layer = nn.Linear(W, W)
40 | layer = nn.Sequential(layer, nn.ReLU(True))
41 | setattr(self, f"xyz_encoding_{i+1}", layer)
42 |
43 | self.out = nn.Linear(W, self.out_channels)
44 |
45 | def forward(self, xyz, deformation_code=None):
46 | xyz_encoded = self.encoding_xyz(xyz)
47 |
48 | if self.deformation_dim > 0:
49 | xyz_encoded = torch.cat([xyz_encoded, deformation_code], -1)
50 |
51 | xyz_ = xyz_encoded
52 | for i in range(self.D):
53 | if i in self.skips:
54 | xyz_ = torch.cat([xyz_encoded, xyz_], -1)
55 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
56 | out = self.out(xyz_)
57 |
58 | return out
59 |
60 | class NeRF(nn.Module):
61 | def __init__(self,
62 | D=8, W=256,
63 | freqs_xyz=10, freqs_dir=4,
64 | use_view=True, use_normal=False,
65 | deformation_dim=0, appearance_dim=0,
66 | skips=[4], actvn_type='relu'):
67 | """
68 | D: number of layers for density (sigma) encoder
69 | W: number of hidden units in each layer
70 | in_channels_xyz: number of input channels for xyz (3+3*10*2=63 by default)
71 | in_channels_dir: number of input channels for direction (3+3*4*2=27 by default)
72 | skips: add skip connection in the Dth layer
73 | """
74 | super(NeRF, self).__init__()
75 | self.D = D
76 | self.W = W
77 | self.freqs_xyz = freqs_xyz
78 | self.freqs_dir = freqs_dir
79 | self.deformation_dim = deformation_dim
80 | self.appearance_dim = appearance_dim
81 | self.skips = skips
82 | self.use_view = use_view
83 | self.use_normal = use_normal
84 |
85 | self.encoding_xyz = Embedding(3, freqs_xyz)
86 | if self.use_view:
87 | self.encoding_dir = Embedding(3, freqs_dir)
88 |
89 | self.in_channels_xyz = 3 + 3*freqs_xyz*2 + deformation_dim
90 |
91 | self.in_channels_dir = appearance_dim
92 | if self.use_view:
93 | self.in_channels_dir += 3 + 3*freqs_dir*2
94 | if self.use_normal:
95 | self.in_channels_dir += 3
96 |
97 | if actvn_type == 'relu':
98 | actvn = nn.ReLU(inplace=True)
99 | elif actvn_type == 'leaky_relu':
100 | actvn = nn.LeakyReLU(0.2, inplace=True)
101 | elif actvn_type == 'softplus':
102 | actvn = nn.Softplus(beta=100)
103 | else:
104 |             raise NotImplementedError(f'activation type [{actvn_type}] is not supported')
105 |
106 | # xyz encoding layers
107 | for i in range(D):
108 | if i == 0:
109 | layer = nn.Linear(self.in_channels_xyz, W)
110 | elif i in skips:
111 | layer = nn.Linear(W+self.in_channels_xyz, W)
112 | else:
113 | layer = nn.Linear(W, W)
114 | layer = nn.Sequential(layer, actvn)
115 | setattr(self, f"xyz_encoding_{i+1}", layer)
116 | self.xyz_encoding_final = nn.Linear(W, W)
117 |
118 | # direction encoding layers
119 | self.dir_encoding = nn.Sequential(
120 | nn.Linear(W+self.in_channels_dir, W//2),
121 | nn.ReLU(True))
122 |
123 | # output layers
124 | self.sigma = nn.Linear(W, 1)
125 | self.rgb = nn.Sequential(
126 | nn.Linear(W//2, 3),
127 | nn.Sigmoid())
128 |
129 | def forward(self, xyz, viewdir=None, deformation_code=None, appearance_code=None):
130 | """
131 | Inputs:
132 | x: (B, self.in_channels_xyz(+self.in_channels_dir))
133 | the embedded vector of position and direction
134 | Outputs:
135 | out: (B, 4), rgb and sigma
136 | """
137 | sigma, xyz_encoding_final = self.get_sigma(xyz, deformation_code=deformation_code)
138 |
139 | dir_encoding_input = xyz_encoding_final
140 |
141 | if self.use_view:
142 | viewdir_encoded = self.encoding_dir(viewdir)
143 | dir_encoding_input = torch.cat([dir_encoding_input, viewdir_encoded], -1)
144 | if self.use_normal:
145 | normal = self.get_normal(xyz, deformation_code=deformation_code)
146 | dir_encoding_input = torch.cat([dir_encoding_input, normal], -1)
147 | if self.appearance_dim > 0:
148 | dir_encoding_input = torch.cat([dir_encoding_input, appearance_code], -1)
149 |
150 | dir_encoding = self.dir_encoding(dir_encoding_input)
151 | rgb = self.rgb(dir_encoding)
152 |
153 | return rgb, sigma
154 |
155 | def get_sigma(self, xyz, deformation_code=None, only_sigma=False):
156 |
157 | xyz_encoded = self.encoding_xyz(xyz)
158 |
159 | if self.deformation_dim > 0:
160 | xyz_encoded = torch.cat([xyz_encoded, deformation_code], -1)
161 |
162 | xyz_ = xyz_encoded
163 | for i in range(self.D):
164 | if i in self.skips:
165 | xyz_ = torch.cat([xyz_encoded, xyz_], -1)
166 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_)
167 |
168 | sigma = self.sigma(xyz_)
169 |
170 | if only_sigma:
171 | return sigma
172 |
173 | xyz_encoding_final = self.xyz_encoding_final(xyz_)
174 |
175 | return sigma, xyz_encoding_final
176 |
177 | def get_normal(self, xyz, deformation_code=None, delta=0.02):
178 | with torch.set_grad_enabled(True):
179 | xyz.requires_grad_(True)
180 | sigma = self.get_sigma(xyz, deformation_code=deformation_code, only_sigma=True)
181 | alpha = 1 - torch.exp(-delta * torch.relu(sigma))
182 | normal = torch.autograd.grad(
183 | outputs=alpha,
184 | inputs=xyz,
185 | grad_outputs=torch.ones_like(alpha, requires_grad=False, device=alpha.device),
186 | create_graph=True,
187 | retain_graph=True,
188 | only_inputs=True)[0]
189 |
190 | return normal
191 |
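# A minimal usage sketch of the NeRF module above (an illustration only, assuming
# the repo root is on PYTHONPATH; with the default freqs_xyz=10 the encoded
# position has 3 + 3*10*2 = 63 channels, matching the docstring):
#
#     import torch
#     from lib.models.nerf import NeRF
#
#     model = NeRF(D=8, W=256, freqs_xyz=10, freqs_dir=4, use_view=True)
#     xyz = torch.rand(1024, 3)        # sampled points along rays
#     viewdir = torch.rand(1024, 3)    # per-point viewing directions
#     rgb, sigma = model(xyz, viewdir=viewdir)
#     # rgb: [1024, 3] in (0, 1) after the sigmoid; sigma: [1024, 1] raw density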
192 |
193 | ####
194 |
195 | class MLP(nn.Module):
196 | def __init__(self,
197 | filter_channels,
198 | merge_layer=0,
199 | res_layers=[],
200 | norm='group',
201 | last_op=None):
202 | super(MLP, self).__init__()
203 |
204 | self.filters = nn.ModuleList()
205 | self.norms = nn.ModuleList()
206 | self.merge_layer = merge_layer if merge_layer > 0 else len(filter_channels) // 2
207 | self.res_layers = res_layers
208 | self.norm = norm
209 | self.last_op = last_op
210 |
211 | for l in range(0, len(filter_channels)-1):
212 | if l in self.res_layers:
213 | self.filters.append(nn.Conv1d(
214 | filter_channels[l] + filter_channels[0],
215 | filter_channels[l+1],
216 | 1))
217 | else:
218 | self.filters.append(nn.Conv1d(
219 | filter_channels[l],
220 | filter_channels[l+1],
221 | 1))
222 | if l != len(filter_channels)-2:
223 | if norm == 'group':
224 | self.norms.append(nn.GroupNorm(32, filter_channels[l+1]))
225 | elif norm == 'batch':
226 | self.norms.append(nn.BatchNorm1d(filter_channels[l+1]))
227 |
228 | ## init
229 | # for l in range(0, len(filter_channels)-1):
230 | # conv = self.filters[l]
231 | # conv.weight.data.fill_(0.00001)
232 | # conv.bias.data.fill_(0.0)
233 |
234 | def forward(self, feature):
235 | '''
236 | feature may include multiple view inputs
237 | args:
238 | feature: [B, C_in, N]
239 | return:
240 | [B, C_out, N] prediction
241 | '''
242 | y = feature
243 | tmpy = feature
244 | phi = None
245 | for i, f in enumerate(self.filters):
246 | y = f(
247 | y if i not in self.res_layers
248 | else torch.cat([y, tmpy], 1)
249 | )
250 | if i != len(self.filters)-1:
251 | if self.norm not in ['batch', 'group']:
252 | y = F.leaky_relu(y)
253 | else:
254 | y = F.leaky_relu(self.norms[i](y))
255 | if i == self.merge_layer:
256 | phi = y.clone()
257 |
258 | if self.last_op is not None:
259 | y = self.last_op(y)
260 |
261 | return y #, phi
262 |
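# A minimal shape sketch for the point-wise MLP above (illustration only; the
# channel sizes are arbitrary): each point along the last dimension is processed
# independently by shared 1x1 Conv1d filters.
#
#     mlp = MLP(filter_channels=[63, 128, 3], last_op=torch.tanh)
#     feats = torch.rand(2, 63, 5000)    # [B, C_in, N]
#     pred = mlp(feats)                  # [B, C_out, N] == [2, 3, 5000]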
263 |
264 | class GeoMLP(nn.Module):
265 | def __init__(self,
266 | filter_channels = [128, 256, 128],
267 | input_dim = 3, embedding_freqs = 10,
268 | output_dim = 3,
269 | cond_dim = 72,
270 | last_op=torch.tanh,
271 | scale=0.1):
272 | super(GeoMLP, self).__init__()
273 | self.input_dim = input_dim
274 | self.cond_dim = cond_dim
275 | self.embedding_freqs = embedding_freqs
276 | self.embedding_dim = input_dim*(2*embedding_freqs+1)
277 | # Embeddings
278 | self.embedding = Embedding(self.input_dim, embedding_freqs) # 10 is the default number
279 |
280 | # xyz encoding layers
281 | filter_channels = [self.embedding_dim + cond_dim] + filter_channels + [output_dim]
282 | self.mlp = MLP(filter_channels, last_op=last_op)
283 | self.scale = scale
284 |
285 | def forward(self, x, cond):
286 |         """
287 |         Predicts per-vertex offsets from vertex positions and a condition code.
288 |         The positions are positionally encoded, concatenated with the (broadcast)
289 |         condition code, and passed through the point-wise MLP; the result is
290 |         squashed by last_op (tanh by default) and multiplied by self.scale.
291 |         Inputs:
292 |             x: [B, nv, input_dim]
293 |                 vertex positions
294 |             cond: [B, cond_dim]
295 |                 condition code (e.g. pose parameters)
296 |         Outputs:
297 |             out: [B, nv, output_dim]
298 |                 predicted offsets
299 |         """
300 | # x: [B, nv, 3]
301 | # cond: [B, n_theta]
302 | batch_size, nv, _ = x.shape
303 | pos_embedding = self.embedding(x.reshape(batch_size, -1)).reshape(batch_size, nv, -1)
304 | cond = cond[:,None,:].expand(-1, nv, -1)
305 | inputs = torch.cat([pos_embedding, cond], -1) #[B, nv, n_position+n_theta]
306 | inputs = inputs.permute(0,2,1)
307 | out = self.mlp(inputs).permute(0,2,1)*self.scale
308 | return out
309 |
310 |
--------------------------------------------------------------------------------
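A minimal sketch of driving GeoMLP from lib/models/nerf.py (illustration only; the
vertex count and condition size are made-up values, not taken from the training code):

    import torch
    from lib.models.nerf import GeoMLP

    geo = GeoMLP(filter_channels=[128, 256, 128], input_dim=3, embedding_freqs=10,
                 output_dim=3, cond_dim=72, last_op=torch.tanh, scale=0.1)
    verts = torch.rand(1, 1024, 3)    # [B, nv, 3] query/template vertices
    cond = torch.rand(1, 72)          # [B, cond_dim] condition code (e.g. pose)
    offsets = geo(verts, cond)        # [B, nv, 3], tanh-bounded and scaled by 0.1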
/lib/models/siren.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch
4 | import math
5 | import torch.nn.functional as F
6 |
7 | class Sine(nn.Module):
8 | """Sine Activation Function."""
9 |
10 | def __init__(self):
11 | super().__init__()
12 | def forward(self, x):
13 | return torch.sin(30. * x)
14 |
15 | def sine_init(m):
16 | with torch.no_grad():
17 | if isinstance(m, nn.Linear):
18 | num_input = m.weight.size(-1)
19 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
20 |
21 |
22 | def first_layer_sine_init(m):
23 | with torch.no_grad():
24 | if isinstance(m, nn.Linear):
25 | num_input = m.weight.size(-1)
26 | m.weight.uniform_(-1 / num_input, 1 / num_input)
27 |
28 |
29 | def film_sine_init(m):
30 | with torch.no_grad():
31 | if isinstance(m, nn.Linear):
32 | num_input = m.weight.size(-1)
33 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
34 |
35 |
36 | def first_layer_film_sine_init(m):
37 | with torch.no_grad():
38 | if isinstance(m, nn.Linear):
39 | num_input = m.weight.size(-1)
40 | m.weight.uniform_(-1 / num_input, 1 / num_input)
41 |
42 |
43 | def kaiming_leaky_init(m):
44 | classname = m.__class__.__name__
45 | if classname.find('Linear') != -1:
46 | torch.nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
47 |
48 | class CustomMappingNetwork(nn.Module):
49 | def __init__(self, z_dim, map_hidden_dim, map_output_dim):
50 | super().__init__()
51 |
52 |
53 |
54 | self.network = nn.Sequential(nn.Linear(z_dim, map_hidden_dim),
55 | nn.LeakyReLU(0.2, inplace=True),
56 |
57 | nn.Linear(map_hidden_dim, map_hidden_dim),
58 | nn.LeakyReLU(0.2, inplace=True),
59 |
60 | nn.Linear(map_hidden_dim, map_hidden_dim),
61 | nn.LeakyReLU(0.2, inplace=True),
62 |
63 | nn.Linear(map_hidden_dim, map_output_dim))
64 |
65 | self.network.apply(kaiming_leaky_init)
66 | with torch.no_grad():
67 | self.network[-1].weight *= 0.25
68 |
69 | def forward(self, z):
70 | frequencies_offsets = self.network(z)
71 | frequencies = frequencies_offsets[..., :frequencies_offsets.shape[-1]//2]
72 | phase_shifts = frequencies_offsets[..., frequencies_offsets.shape[-1]//2:]
73 |
74 | return frequencies, phase_shifts
75 |
76 |
77 | def frequency_init(freq):
78 | def init(m):
79 | with torch.no_grad():
80 | if isinstance(m, nn.Linear):
81 | num_input = m.weight.size(-1)
82 | m.weight.uniform_(-np.sqrt(6 / num_input) / freq, np.sqrt(6 / num_input) / freq)
83 | return init
84 |
85 | class FiLMLayer(nn.Module):
86 | def __init__(self, input_dim, hidden_dim):
87 | super().__init__()
88 | self.layer = nn.Linear(input_dim, hidden_dim)
89 |
90 | def forward(self, x, freq, phase_shift):
91 | x = self.layer(x)
92 | freq = freq.unsqueeze(1).expand_as(x)
93 | phase_shift = phase_shift.unsqueeze(1).expand_as(x)
94 | return torch.sin(freq * x + phase_shift)
95 |
96 |
97 | class TALLSIREN(nn.Module):
98 | """Primary SIREN architecture used in pi-GAN generators."""
99 |
100 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None):
101 | super().__init__()
102 | self.device = device
103 | self.input_dim = input_dim
104 | self.z_dim = z_dim
105 | self.hidden_dim = hidden_dim
106 | self.output_dim = output_dim
107 |
108 | self.network = nn.ModuleList([
109 | FiLMLayer(input_dim, hidden_dim),
110 | FiLMLayer(hidden_dim, hidden_dim),
111 | FiLMLayer(hidden_dim, hidden_dim),
112 | FiLMLayer(hidden_dim, hidden_dim),
113 | FiLMLayer(hidden_dim, hidden_dim),
114 | FiLMLayer(hidden_dim, hidden_dim),
115 | FiLMLayer(hidden_dim, hidden_dim),
116 | FiLMLayer(hidden_dim, hidden_dim),
117 | ])
118 | self.final_layer = nn.Linear(hidden_dim, 1)
119 |
120 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim)
121 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3), nn.Sigmoid())
122 |
123 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2)
124 |
125 | self.network.apply(frequency_init(25))
126 | self.final_layer.apply(frequency_init(25))
127 | self.color_layer_sine.apply(frequency_init(25))
128 | self.color_layer_linear.apply(frequency_init(25))
129 | self.network[0].apply(first_layer_film_sine_init)
130 |
131 | def forward(self, input, z, ray_directions, **kwargs):
132 | frequencies, phase_shifts = self.mapping_network(z)
133 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs)
134 |
135 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs):
136 | frequencies = frequencies*15 + 30
137 |
138 | x = input
139 |
140 | for index, layer in enumerate(self.network):
141 | start = index * self.hidden_dim
142 | end = (index+1) * self.hidden_dim
143 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
144 |
145 | sigma = self.final_layer(x)
146 |         rgb = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:])
147 |         rgb = self.color_layer_linear(rgb)
148 | 
149 |         return torch.cat([rgb, sigma], dim=-1)
150 |
151 |
152 | class UniformBoxWarp(nn.Module):
153 | def __init__(self, sidelength):
154 | super().__init__()
155 | self.scale_factor = 2/sidelength
156 |
157 | def forward(self, coordinates):
158 | return coordinates * self.scale_factor
159 |
160 | class SPATIALSIRENBASELINE(nn.Module):
161 | """Same architecture as TALLSIREN but adds a UniformBoxWarp to map input points to -1, 1"""
162 |
163 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None):
164 | super().__init__()
165 | self.device = device
166 | self.input_dim = input_dim
167 | self.z_dim = z_dim
168 | self.hidden_dim = hidden_dim
169 | self.output_dim = output_dim
170 |
171 | self.network = nn.ModuleList([
172 | FiLMLayer(3, hidden_dim),
173 | FiLMLayer(hidden_dim, hidden_dim),
174 | FiLMLayer(hidden_dim, hidden_dim),
175 | FiLMLayer(hidden_dim, hidden_dim),
176 | FiLMLayer(hidden_dim, hidden_dim),
177 | FiLMLayer(hidden_dim, hidden_dim),
178 | FiLMLayer(hidden_dim, hidden_dim),
179 | FiLMLayer(hidden_dim, hidden_dim),
180 | ])
181 | self.final_layer = nn.Linear(hidden_dim, 1)
182 |
183 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim)
184 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3))
185 |
186 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2)
187 |
188 | self.network.apply(frequency_init(25))
189 | self.final_layer.apply(frequency_init(25))
190 | self.color_layer_sine.apply(frequency_init(25))
191 | self.color_layer_linear.apply(frequency_init(25))
192 | self.network[0].apply(first_layer_film_sine_init)
193 |
194 | self.gridwarper = UniformBoxWarp(0.24) # Don't worry about this, it was added to ensure compatibility with another model. Shouldn't affect performance.
195 |
196 | def forward(self, input, z, ray_directions, **kwargs):
197 | frequencies, phase_shifts = self.mapping_network(z)
198 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs)
199 |
200 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs):
201 | frequencies = frequencies*15 + 30
202 |
203 | input = self.gridwarper(input)
204 | x = input
205 |
206 | for index, layer in enumerate(self.network):
207 | start = index * self.hidden_dim
208 | end = (index+1) * self.hidden_dim
209 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
210 |
211 | sigma = self.final_layer(x)
212 |         rgb = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:])
213 |         rgb = torch.sigmoid(self.color_layer_linear(rgb))
214 | 
215 |         return torch.cat([rgb, sigma], dim=-1)
216 |
217 |
218 |
219 | class UniformBoxWarp(nn.Module):
220 | def __init__(self, sidelength):
221 | super().__init__()
222 | self.scale_factor = 2/sidelength
223 |
224 | def forward(self, coordinates):
225 | return coordinates * self.scale_factor
226 |
227 |
228 | def sample_from_3dgrid(coordinates, grid):
229 | """
230 | Expects coordinates in shape (batch_size, num_points_per_batch, 3)
231 | Expects grid in shape (1, channels, H, W, D)
232 | (Also works if grid has batch size)
233 | Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
234 | """
235 | coordinates = coordinates.float()
236 | grid = grid.float()
237 |
238 | batch_size, n_coords, n_dims = coordinates.shape
239 | sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
240 | coordinates.reshape(batch_size, 1, 1, -1, n_dims),
241 | mode='bilinear', padding_mode='zeros', align_corners=True)
242 | N, C, H, W, D = sampled_features.shape
243 | sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
244 | return sampled_features
245 |
246 |
247 | def modified_first_sine_init(m):
248 | with torch.no_grad():
249 | # if hasattr(m, 'weight'):
250 | if isinstance(m, nn.Linear):
251 | num_input = 3
252 | m.weight.uniform_(-1 / num_input, 1 / num_input)
253 |
254 |
255 | class EmbeddingPiGAN128(nn.Module):
256 | """Smaller architecture that has an additional cube of embeddings. Often gives better fine details."""
257 |
258 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=128, output_dim=1, device=None):
259 | super().__init__()
260 | self.device = device
261 | self.input_dim = input_dim
262 | self.z_dim = z_dim
263 | self.hidden_dim = hidden_dim
264 | self.output_dim = output_dim
265 |
266 | self.network = nn.ModuleList([
267 | FiLMLayer(32 + 3, hidden_dim),
268 | FiLMLayer(hidden_dim, hidden_dim),
269 | FiLMLayer(hidden_dim, hidden_dim),
270 | FiLMLayer(hidden_dim, hidden_dim),
271 | FiLMLayer(hidden_dim, hidden_dim),
272 | FiLMLayer(hidden_dim, hidden_dim),
273 | FiLMLayer(hidden_dim, hidden_dim),
274 | FiLMLayer(hidden_dim, hidden_dim),
275 | ])
276 | print(self.network)
277 |
278 | self.final_layer = nn.Linear(hidden_dim, 1)
279 |
280 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim)
281 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3))
282 |
283 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2)
284 |
285 | self.network.apply(frequency_init(25))
286 | self.final_layer.apply(frequency_init(25))
287 | self.color_layer_sine.apply(frequency_init(25))
288 | self.color_layer_linear.apply(frequency_init(25))
289 | self.network[0].apply(modified_first_sine_init)
290 |
291 | self.spatial_embeddings = nn.Parameter(torch.randn(1, 32, 96, 96, 96)*0.01)
292 |
293 |         # !! Important !! Set this value to the expected side-length of your scene, e.g. for faces, heads usually fit in
294 |         # a box of side-length 0.24, since the camera has such a narrow FOV. For other scenes with a higher FOV, it probably needs to be bigger.
295 | self.gridwarper = UniformBoxWarp(0.24)
296 |
297 | def forward(self, input, z, ray_directions, **kwargs):
298 | frequencies, phase_shifts = self.mapping_network(z)
299 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions, **kwargs)
300 |
301 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs):
302 | frequencies = frequencies*15 + 30
303 |
304 | input = self.gridwarper(input)
305 | shared_features = sample_from_3dgrid(input, self.spatial_embeddings)
306 | x = torch.cat([shared_features, input], -1)
307 |
308 | for index, layer in enumerate(self.network):
309 | start = index * self.hidden_dim
310 | end = (index+1) * self.hidden_dim
311 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
312 |
313 | sigma = self.final_layer(x)
314 |         rgb = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1), frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:])
315 |         rgb = torch.sigmoid(self.color_layer_linear(rgb))
316 | 
317 |         return torch.cat([rgb, sigma], dim=-1)
318 |
319 |
320 |
321 | class EmbeddingPiGAN256(EmbeddingPiGAN128):
322 | def __init__(self, *args, **kwargs):
323 | super().__init__(*args, **kwargs, hidden_dim=256)
324 | self.spatial_embeddings = nn.Parameter(torch.randn(1, 32, 64, 64, 64)*0.1)
325 |
326 |
327 |
328 |
329 |
330 | class GeoSIREN(nn.Module):
331 |     """FiLM-conditioned SIREN (pi-GAN style): maps input points plus a latent code z to output_dim values, optionally squashed by last_op and scaled."""
332 |
333 | def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None, last_op=None, scale=1.):
334 | super().__init__()
335 | self.device = device
336 | self.input_dim = input_dim
337 | self.z_dim = z_dim
338 | self.hidden_dim = hidden_dim
339 | self.output_dim = output_dim
340 | self.scale = scale
341 | self.last_op = last_op
342 |
343 | self.network = nn.ModuleList([
344 | FiLMLayer(input_dim, hidden_dim),
345 | FiLMLayer(hidden_dim, hidden_dim),
346 | FiLMLayer(hidden_dim, hidden_dim),
347 | FiLMLayer(hidden_dim, hidden_dim),
348 | FiLMLayer(hidden_dim, hidden_dim),
349 | FiLMLayer(hidden_dim, hidden_dim),
350 | FiLMLayer(hidden_dim, hidden_dim),
351 | FiLMLayer(hidden_dim, hidden_dim),
352 | ])
353 | self.final_layer = nn.Linear(hidden_dim, output_dim)
354 |
355 | self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim)
356 | self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3), nn.Sigmoid())
357 |
358 | self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2)
359 |
360 | self.network.apply(frequency_init(25))
361 | self.final_layer.apply(frequency_init(25))
362 | self.color_layer_sine.apply(frequency_init(25))
363 | self.color_layer_linear.apply(frequency_init(25))
364 | self.network[0].apply(first_layer_film_sine_init)
365 |
366 | def forward(self, input, z, **kwargs):
367 | frequencies, phase_shifts = self.mapping_network(z)
368 | return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, **kwargs)
369 |
370 | def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, **kwargs):
371 | frequencies = frequencies*15 + 30
372 |
373 | x = input
374 |
375 | for index, layer in enumerate(self.network):
376 | start = index * self.hidden_dim
377 | end = (index+1) * self.hidden_dim
378 | x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
379 |
380 | sigma = self.final_layer(x)
381 |
382 | if self.last_op is not None:
383 | sigma = self.last_op(sigma)
384 | sigma = sigma*self.scale
385 |
386 | return sigma #torch.cat([sigma], dim=-1)
387 |
388 |
389 |
--------------------------------------------------------------------------------
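A minimal sketch of the FiLM conditioning used throughout lib/models/siren.py
(illustration only; the sizes are made-up values): the mapping network turns a latent
code into per-layer frequencies and phase shifts that modulate each sine layer.

    import torch
    from lib.models.siren import GeoSIREN

    net = GeoSIREN(input_dim=3, z_dim=128, hidden_dim=256, output_dim=3,
                   last_op=torch.tanh, scale=0.1)
    points = torch.rand(1, 1024, 3)   # [B, N, input_dim] query points
    z = torch.rand(1, 128)            # [B, z_dim] conditioning code
    out = net(points, z)              # [B, N, output_dim], tanh-bounded and scaled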
/lib/models/smplx.py:
--------------------------------------------------------------------------------
1 | """
2 | original from https://github.com/vchoutas/smplx
3 | modified by Vassilis and Yao
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import pickle
10 | import torch.nn.functional as F
11 | import os
12 | import yaml
13 |
14 | from .lbs import Struct, to_tensor, to_np, lbs, vertices2landmarks, JointsFromVerticesSelector, find_dynamic_lmk_idx_and_bcoords
15 |
16 | ## SMPLX
17 | J14_NAMES = [
18 | 'right_ankle',
19 | 'right_knee',
20 | 'right_hip',
21 | 'left_hip',
22 | 'left_knee',
23 | 'left_ankle',
24 | 'right_wrist',
25 | 'right_elbow',
26 | 'right_shoulder',
27 | 'left_shoulder',
28 | 'left_elbow',
29 | 'left_wrist',
30 | 'neck',
31 | 'head',
32 | ]
33 | SMPLX_names = ['pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee', 'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot', 'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'jaw', 'left_eye_smplx', 'right_eye_smplx', 'left_index1', 'left_index2', 'left_index3', 'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1', 'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3', 'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1', 'right_index2', 'right_index3', 'right_middle1', 'right_middle2', 'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3', 'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1', 'right_thumb2', 'right_thumb3', 'right_eye_brow1', 'right_eye_brow2', 'right_eye_brow3', 'right_eye_brow4', 'right_eye_brow5', 'left_eye_brow5', 'left_eye_brow4', 'left_eye_brow3', 'left_eye_brow2', 'left_eye_brow1', 'nose1', 'nose2', 'nose3', 'nose4', 'right_nose_2', 'right_nose_1', 'nose_middle', 'left_nose_1', 'left_nose_2', 'right_eye1', 'right_eye2', 'right_eye3', 'right_eye4', 'right_eye5', 'right_eye6', 'left_eye4', 'left_eye3', 'left_eye2', 'left_eye1', 'left_eye6', 'left_eye5', 'right_mouth_1', 'right_mouth_2', 'right_mouth_3', 'mouth_top', 'left_mouth_3', 'left_mouth_2', 'left_mouth_1', 'left_mouth_5', 'left_mouth_4', 'mouth_bottom', 'right_mouth_4', 'right_mouth_5', 'right_lip_1', 'right_lip_2', 'lip_top', 'left_lip_2', 'left_lip_1', 'left_lip_3', 'lip_bottom', 'right_lip_3', 'right_contour_1', 'right_contour_2', 'right_contour_3', 'right_contour_4', 'right_contour_5', 'right_contour_6', 'right_contour_7', 'right_contour_8', 'contour_middle', 'left_contour_8', 'left_contour_7', 'left_contour_6', 'left_contour_5', 'left_contour_4', 'left_contour_3', 'left_contour_2', 'left_contour_1', 'head_top', 'left_big_toe', 'left_ear', 'left_eye', 'left_heel', 'left_index', 'left_middle', 'left_pinky', 'left_ring', 'left_small_toe', 'left_thumb', 'nose', 'right_big_toe', 'right_ear', 'right_eye', 'right_heel', 'right_index', 'right_middle', 'right_pinky', 'right_ring', 'right_small_toe', 'right_thumb']
34 | extra_names = ['head_top', 'left_big_toe', 'left_ear', 'left_eye', 'left_heel', 'left_index', 'left_middle', 'left_pinky', 'left_ring', 'left_small_toe', 'left_thumb', 'nose', 'right_big_toe', 'right_ear', 'right_eye', 'right_heel', 'right_index', 'right_middle', 'right_pinky', 'right_ring', 'right_small_toe', 'right_thumb']
35 | SMPLX_names += extra_names
36 |
37 | part_indices = {}
38 | part_indices['body'] = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
39 | 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 123,
40 | 124, 125, 126, 127, 132, 134, 135, 136, 137, 138, 143])
41 | part_indices['torso'] = np.array([ 0, 1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17, 18,
42 | 19, 22, 23, 24, 55, 56, 57, 58, 59, 76, 77, 78, 79,
43 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92,
44 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
45 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
46 | 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131,
47 | 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144])
48 | part_indices['head'] = np.array([ 12, 15, 22, 23, 24, 55, 56, 57, 58, 59, 60, 61, 62,
49 | 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
50 | 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88,
51 | 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101,
52 | 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
53 | 115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 134, 136,
54 | 137])
55 | part_indices['face'] = np.array([ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
56 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
57 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92,
58 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
59 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
60 | 119, 120, 121, 122])
61 | part_indices['upper'] = np.array([ 12, 13, 14, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
62 | 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
63 | 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92,
64 | 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
65 | 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
66 | 119, 120, 121, 122])
67 | part_indices['hand'] = np.array([ 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
68 | 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
69 | 49, 50, 51, 52, 53, 54, 128, 129, 130, 131, 133, 139, 140,
70 | 141, 142, 144])
71 | part_indices['left_hand'] = np.array([ 20, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
72 | 37, 38, 39, 128, 129, 130, 131, 133])
73 | part_indices['right_hand'] = np.array([ 21, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
74 | 52, 53, 54, 139, 140, 141, 142, 144])
75 | # kinematic tree
76 | head_kin_chain = [15,12,9,6,3,0]
77 |
78 | #--smplx joints
79 | # 00 - Global
80 | # 01 - L_Thigh
81 | # 02 - R_Thigh
82 | # 03 - Spine
83 | # 04 - L_Calf
84 | # 05 - R_Calf
85 | # 06 - Spine1
86 | # 07 - L_Foot
87 | # 08 - R_Foot
88 | # 09 - Spine2
89 | # 10 - L_Toes
90 | # 11 - R_Toes
91 | # 12 - Neck
92 | # 13 - L_Shoulder
93 | # 14 - R_Shoulder
94 | # 15 - Head
95 | # 16 - L_UpperArm
96 | # 17 - R_UpperArm
97 | # 18 - L_ForeArm
98 | # 19 - R_ForeArm
99 | # 20 - L_Hand
100 | # 21 - R_Hand
101 | # 22 - Jaw
102 | # 23 - L_Eye
103 | # 24 - R_Eye
104 |
105 | class SMPLX(nn.Module):
106 | """
107 | Given smplx parameters, this class generates a differentiable SMPLX function
108 | which outputs a mesh and 3D joints
109 | """
110 | def __init__(self, config):
111 | super(SMPLX, self).__init__()
112 | print("creating the SMPLX Decoder")
113 | ss = np.load(config.smplx_model_path, allow_pickle=True)
114 | smplx_model = Struct(**ss)
115 |
116 | self.dtype = torch.float32
117 | self.register_buffer('faces_tensor', to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long))
118 | # The vertices of the template model
119 | self.register_buffer('v_template', to_tensor(to_np(smplx_model.v_template), dtype=self.dtype))
120 | # The shape components and expression
121 | # expression space is the same as FLAME
122 | shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
123 | shapedirs = torch.cat([shapedirs[:,:,:config.n_shape], shapedirs[:,:,300:300+config.n_exp]], 2)
124 | self.register_buffer('shapedirs', shapedirs)
125 | # The pose components
126 | num_pose_basis = smplx_model.posedirs.shape[-1]
127 | posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
128 | self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=self.dtype))
129 | self.register_buffer('J_regressor', to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype))
130 | parents = to_tensor(to_np(smplx_model.kintree_table[0])).long(); parents[0] = -1
131 | self.register_buffer('parents', parents)
132 | self.register_buffer('lbs_weights', to_tensor(to_np(smplx_model.weights), dtype=self.dtype))
133 | # for face keypoints
134 | self.register_buffer('lmk_faces_idx', torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long))
135 | self.register_buffer('lmk_bary_coords', torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype))
136 | self.register_buffer('dynamic_lmk_faces_idx', torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long))
137 | self.register_buffer('dynamic_lmk_bary_coords', torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype))
138 | # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
139 | self.register_buffer('head_kin_chain', torch.tensor(head_kin_chain, dtype=torch.long))
140 |
141 | self.n_shape = config.n_shape
142 | self.n_pose = num_pose_basis
143 |
144 | #-- initialize parameters
145 | # shape and expression
146 | self.register_buffer('shape_params', nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False))
147 | self.register_buffer('expression_params', nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False))
148 |         # pose: represented as rotation matrices [number of joints, 3, 3]
149 | self.register_buffer('global_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False))
150 | self.register_buffer('head_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False))
151 | self.register_buffer('neck_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False))
152 | self.register_buffer('jaw_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(1,1,1), requires_grad=False))
153 | self.register_buffer('eye_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(2,1,1), requires_grad=False))
154 | self.register_buffer('body_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(21,1,1), requires_grad=False))
155 | self.register_buffer('left_hand_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15,1,1), requires_grad=False))
156 | self.register_buffer('right_hand_pose', nn.Parameter(torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(15,1,1), requires_grad=False))
157 |
158 | if config.extra_joint_path:
159 | self.extra_joint_selector = JointsFromVerticesSelector(
160 | fname=config.extra_joint_path)
161 | self.use_joint_regressor = True
162 | self.keypoint_names = SMPLX_names
163 | if self.use_joint_regressor:
164 | with open(config.j14_regressor_path, 'rb') as f:
165 | j14_regressor = pickle.load(f, encoding='latin1')
166 | source = []
167 | target = []
168 | for idx, name in enumerate(self.keypoint_names):
169 | if name in J14_NAMES:
170 | source.append(idx)
171 | target.append(J14_NAMES.index(name))
172 | source = np.asarray(source)
173 | target = np.asarray(target)
174 | self.register_buffer('source_idxs', torch.from_numpy(source))
175 | self.register_buffer('target_idxs', torch.from_numpy(target))
176 | joint_regressor = torch.from_numpy(
177 | j14_regressor).to(dtype=torch.float32)
178 | self.register_buffer('extra_joint_regressor', joint_regressor)
179 | self.part_indices = part_indices
180 |
181 | def forward(self, shape_params=None, expression_params=None,
182 | global_pose=None, body_pose=None,
183 | jaw_pose=None, eye_pose=None,
184 | left_hand_pose=None, right_hand_pose=None, full_pose=None,
185 | offset=None, transl=None, return_T = False):
186 | """
187 | Args:
188 | shape_params: [N, number of shape parameters]
189 | expression_params: [N, number of expression parameters]
190 | global_pose: pelvis pose, [N, 1, 3, 3]
191 | body_pose: [N, 21, 3, 3]
192 | jaw_pose: [N, 1, 3, 3]
193 | eye_pose: [N, 2, 3, 3]
194 | left_hand_pose: [N, 15, 3, 3]
195 | right_hand_pose: [N, 15, 3, 3]
196 | Returns:
197 | vertices: [N, number of vertices, 3]
198 | landmarks: [N, number of landmarks (68 face keypoints), 3]
199 | joints: [N, number of smplx joints (145), 3]
200 | """
201 | if shape_params is None:
202 | batch_size = full_pose.shape[0]
203 | shape_params = self.shape_params.expand(batch_size, -1)
204 | else:
205 | batch_size = shape_params.shape[0]
206 | if expression_params is None:
207 | expression_params = self.expression_params.expand(batch_size, -1)
208 |
209 | if full_pose is None:
210 | if global_pose is None:
211 | global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
212 | if body_pose is None:
213 | body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
214 | if jaw_pose is None:
215 | jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
216 | if eye_pose is None:
217 | eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
218 | if left_hand_pose is None:
219 | left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
220 | if right_hand_pose is None:
221 | right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
222 | full_pose = torch.cat([global_pose, body_pose,
223 | jaw_pose, eye_pose,
224 | left_hand_pose, right_hand_pose], dim=1)
225 |
226 | shape_components = torch.cat([shape_params, expression_params], dim=1)
227 | if offset is not None:
228 | if len(offset.shape) == 2:
229 | template_vertices = (self.v_template+offset).unsqueeze(0).expand(batch_size, -1, -1)
230 | else:
231 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1) + offset
232 | else:
233 | template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
234 |
235 | # smplx
236 | vertices, joints, A, T, shape_offsets, pose_offsets = lbs(shape_components, full_pose, template_vertices,
237 | self.shapedirs, self.posedirs,
238 | self.J_regressor, self.parents,
239 | self.lbs_weights, dtype=self.dtype,
240 | pose2rot = False)
241 | # face dynamic landmarks
242 | lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
243 | lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
244 | dyn_lmk_faces_idx, dyn_lmk_bary_coords = (
245 | find_dynamic_lmk_idx_and_bcoords(
246 | vertices, full_pose,
247 | self.dynamic_lmk_faces_idx,
248 | self.dynamic_lmk_bary_coords,
249 | self.head_kin_chain)
250 | )
251 | lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
252 | lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
253 | landmarks = vertices2landmarks(vertices, self.faces_tensor,
254 | lmk_faces_idx,
255 | lmk_bary_coords)
256 |
257 | final_joint_set = [joints, landmarks]
258 | if hasattr(self, 'extra_joint_selector'):
259 | # Add any extra joints that might be needed
260 | extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
261 | final_joint_set.append(extra_joints)
262 | # Create the final joint set
263 | joints = torch.cat(final_joint_set, dim=1)
264 | if self.use_joint_regressor:
265 | reg_joints = torch.einsum(
266 | 'ji,bik->bjk', self.extra_joint_regressor, vertices)
267 | joints[:, self.source_idxs] = (
268 | joints[:, self.source_idxs].detach() * 0.0 +
269 | reg_joints[:, self.target_idxs] * 1.0
270 | )
271 | ### translate z.
272 | # original: -0.3 ~ 0.5
273 | # now: + 0.5
274 | # vertices[:,:,-1] = vertices[:,:,-1] + 1
275 | if transl is not None:
276 | joints = joints + transl.unsqueeze(dim=1)
277 | vertices = vertices + transl.unsqueeze(dim=1)
278 | if return_T:
279 | A[..., :3, 3] += transl.unsqueeze(dim=1)
280 | T[..., :3, 3] += transl.unsqueeze(dim=1)
281 |
282 | if return_T:
283 | return vertices, landmarks, joints, A, T, shape_offsets, pose_offsets
284 | else:
285 | return vertices, landmarks, joints
286 |
287 | def pose_abs2rel(self, global_pose, body_pose, abs_joint = 'head'):
288 | ''' change absolute pose to relative pose
289 | Basic knowledge for SMPLX kinematic tree:
290 | absolute pose = parent pose * relative pose
291 | Here, pose must be represented as rotation matrix (batch_sizexnx3x3)
292 | '''
293 | if abs_joint == 'head':
294 | # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
295 | kin_chain = [15, 12, 9, 6, 3, 0]
296 | elif abs_joint == 'right_wrist':
297 | # Pelvis -> Spine 1, 2, 3 -> right Collar -> right shoulder
298 | # -> right elbow -> right wrist
299 | kin_chain = [21, 19, 17, 14, 9, 6, 3, 0]
300 | elif abs_joint == 'left_wrist':
301 | # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder
302 | # -> Left elbow -> Left wrist
303 | kin_chain = [20, 18, 16, 13, 9, 6, 3, 0]
304 | else:
305 | raise NotImplementedError(
306 | f'pose_abs2rel does not support: {abs_joint}')
307 |
308 | batch_size = global_pose.shape[0]
309 | dtype = global_pose.dtype
310 | device = global_pose.device
311 | full_pose = torch.cat([global_pose, body_pose], dim=1)
312 | rel_rot_mat = torch.eye(
313 | 3, device=device,
314 | dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
315 | for idx in kin_chain[1:]:
316 | rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
317 |
318 | # This contains the absolute pose of the parent
319 | abs_parent_pose = rel_rot_mat.detach()
320 | # Let's assume that in the input this specific joint is predicted as an absolute value
321 | abs_joint_pose = body_pose[:, kin_chain[0] - 1]
322 | # abs_head = parents(abs_neck) * rel_head ==> rel_head = abs_neck.T * abs_head
323 | rel_joint_pose = torch.matmul(
324 | abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
325 | abs_joint_pose.reshape(-1, 3, 3))
326 | # Replace the new relative pose
327 | body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
328 | return body_pose
329 |
--------------------------------------------------------------------------------
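A minimal sketch of driving the SMPLX layer above (illustration only; it assumes the
SMPL-X assets referenced in lib/utils/config.py have already been downloaded):

    import torch
    from lib.models.smplx import SMPLX
    from lib.utils.config import get_cfg_defaults

    cfg = get_cfg_defaults()
    body_model = SMPLX(cfg.model)

    B = 1
    shape = torch.zeros(B, cfg.model.n_shape)
    # poses are rotation matrices; identity rotations give the template pose
    global_pose = torch.eye(3)[None, None].repeat(B, 1, 1, 1)
    body_pose = torch.eye(3)[None, None].repeat(B, 21, 1, 1)
    verts, landmarks, joints = body_model(shape_params=shape,
                                          global_pose=global_pose,
                                          body_pose=body_pose)
    # verts: [B, 10475, 3] SMPL-X vertices; landmarks: face landmarks;
    # joints: [B, 145, 3] (per the docstring of forward)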
/lib/utils/camera_util.py:
--------------------------------------------------------------------------------
1 | ''' Projections
2 | Pinhole Camera Model
3 |
4 | Ref:
5 | http://web.stanford.edu/class/cs231a/lectures/lecture2_camera_models.pdf
6 | https://github.com/YadiraF/face3d/blob/master/face3d/mesh/transform.py
7 | https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer
8 | https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function
9 | https://www.scratchapixel.com/lessons/3d-basic-rendering/3d-viewing-pinhole-camera/implementing-virtual-pinhole-camera
10 | https://kornia.readthedocs.io/en/v0.1.2/pinhole.html
11 |
12 | Pinhole Camera Model:
13 | 1. world space to camera space:
14 |         Normally, the object is represented in the world reference system,
15 |         so we first need to map it into the camera system/coordinates/space
16 |         P_camera = [R|t]P_world
17 |         [R|t]: world-to-camera (extrinsic) transformation; encodes the camera position and orientation
18 | Given: where the camera is in the world system. (extrinsic/external)
19 | represented by:
20 | look at camera: eye position, at, up direction
21 | 2. camera space to image space:
22 | Then project the obj into image plane
23 | P_image = K[I|0]P_camera
24 | K = [[fx, 0, cx],
25 | [ 0, fy, cy],
26 | [ 0, 0, 1]]
27 |         Given: settings of the camera. (intrinsic/internal)
28 |             represented by:
29 |             focal length, image/film size
30 | or fov, near, far
31 | Angle of View: computed from the focal length and the film size parameters.
32 | perspective projection:
33 | x' = f*x/z; y' = f*y/z
34 |
35 | Finally:
36 |         P_image = M P_world
37 |         M = K[R|t]
38 | homo: 4x4
39 | '''
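# A small worked example of the intrinsic step described above (illustration only):
# for a camera-space point P_camera = (0.1, 0.2, 1.0) and
#     K = [[500,   0, 0],
#          [  0, 500, 0],
#          [  0,   0, 1]]
# K @ P_camera = (50, 100, 1.0), and the perspective division x' = f*x/z,
# y' = f*y/z gives the image-plane coordinates (50, 100).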
40 | import torch
41 | import numpy as np
42 | import torch.nn.functional as F
43 |
44 | # --------------------- 0. camera projection
45 | def transform(points, intrinsic, extrinsic):
46 | ''' Perspective projection
47 | Args:
48 | points: [bz, np, 3]
49 | intrinsic: [bz, 3, 3]
50 | extrinsic: [bz, 3, 4]
51 | '''
52 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:]
53 | points = torch.matmul(points, intrinsic.transpose(1,2))
54 |
55 | ## if homo
56 | # vertices_homo = camera_util.homogeneous(mesh['vertices'])
57 | # transformed_verts = torch.matmul(vertices_homo, extrinsic.transpose(1,2))
58 | return points
59 |
60 | def perspective_project(points, focal=None, image_hw=None, extrinsic=None, transl=None):
61 | ''' points from world space to ndc space (for pytorch3d rendering)
62 | TODO
63 | '''
64 | batch_size = points.shape[0]
65 | device = points.device
66 | dtype = points.dtype
67 | # non homo
68 | if points.shape[-1] == 3:
69 | if transl is not None:
70 | points = points + transl[:,None,:]
71 | if extrinsic is not None:
72 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:]
73 | if focal is not None:
74 | # import ipdb; ipdb.set_trace()
75 | if image_hw is not None:
76 | H, W = image_hw
77 | fx = 2*focal/H
78 | fy = 2*focal/W
79 | else:
80 | fx = fy = focal
81 |
82 | # 2/H is for normalization
83 | # intrinsic = torch.tensor(
84 | # [[fx, 0, 0],
85 | # [0, fy, 0],
86 | # [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 1, 1)
87 | intrinsic = torch.tensor(
88 | [[1, 0, 0],
89 | [0, 1, 0],
90 | [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 1, 1)
91 | intrinsic_xy = intrinsic[:,:2]*fx
92 | intrinsic_z = intrinsic[:,2:]
93 | intrinsic = torch.cat([intrinsic_xy, intrinsic_z], dim = 1)
94 | # if points.requires_grad:
95 | # import ipdb; ipdb.set_trace()
96 | points = torch.matmul(points, intrinsic.transpose(1,2))
97 | # perspective distortion
98 | # points[:,:,:2] = points[:,:,:2]/(points[:,:,[2]]+1e-5) # inplace, has problem for gradient
99 | z = points[:,:,[2]]
100 | xy = points[:,:,:2]/(z+1e-6)
101 | points = torch.cat([xy, z], dim=-1)
102 | return points
103 |
104 | def perspective_project_inv(points, focal=None, image_hw=None, extrinsic=None, transl=None):
105 |     ''' inverse of perspective_project: points from ndc space back to world space
106 | TODO
107 | '''
108 | batch_size = points.shape[0]
109 | device = points.device
110 | dtype = points.dtype
111 | # non homo
112 | if points.shape[-1] == 3:
113 | if focal is not None:
114 | # import ipdb; ipdb.set_trace()
115 | if image_hw is not None:
116 | H, W = image_hw
117 | fx = 2*focal/H
118 | fy = 2*focal/W
119 | else:
120 | fx = fy = focal
121 | # 2/H is for normalization
122 | intrinsic = torch.tensor(
123 | [[1, 0, 0],
124 | [0, 1, 0],
125 | [0, 0, 1]], device=device, dtype=dtype)[None,...].repeat(batch_size, 1, 1)
126 | intrinsic_xy = intrinsic[:,:2]*fx
127 | intrinsic_z = intrinsic[:,2:]
128 | intrinsic = torch.cat([intrinsic_xy, intrinsic_z], dim = 1)
129 |
130 | z = points[:,:,[2]]
131 | xy = points[:,:,:2]*(z+1e-5)
132 | points = torch.cat([xy, z], dim=-1)
133 |
134 | intrinsic = torch.inverse(intrinsic)
135 | points = torch.matmul(points, intrinsic.transpose(1,2))
136 | if transl is not None:
137 | points = points - transl[:,None,:]
138 | if extrinsic is not None:
139 | points = torch.matmul(points, extrinsic[:,:3,:3].transpose(1,2)) + extrinsic[:,:3,3][:,None,:]
140 | return points
141 |
142 |
143 | # TODO: homo
144 |
145 | # --------------------- 1. world space to camera space:
146 | def look_at(eye, at=[0, 0, 0], up=[0, 1, 0]):
147 | """
148 | "Look at" transformation of vertices.
149 | standard camera space:
150 | camera located at the origin.
151 | looking down negative z-axis.
152 | vertical vector is y-axis.
153 | Xcam = R(X - C)
154 | Homo: [[R, -RC],
155 | [0, 1]]
156 | Args:
157 | eye: [3,] the XYZ world space position of the camera.
158 | at: [3,] a position along the center of the camera's gaze.
159 | up: [3,] up direction
160 | Returns:
161 | extrinsic: R, t
162 | """
163 | device = eye.device
164 | # if list or tuple convert to numpy array
165 | if isinstance(at, list) or isinstance(at, tuple):
166 | at = torch.tensor(at, dtype=torch.float32, device=device)
167 | # if numpy array convert to tensor
168 | elif isinstance(at, np.ndarray):
169 | at = torch.from_numpy(at).to(device)
170 | elif torch.is_tensor(at):
171 |         at = at.to(device)
172 |
173 | if isinstance(up, list) or isinstance(up, tuple):
174 | up = torch.tensor(up, dtype=torch.float32, device=device)
175 | elif isinstance(up, np.ndarray):
176 | up = torch.from_numpy(up).to(device)
177 | elif torch.is_tensor(up):
178 |         up = up.to(device)
179 |
180 | if isinstance(eye, list) or isinstance(eye, tuple):
181 | eye = torch.tensor(eye, dtype=torch.float32, device=device)
182 | elif isinstance(eye, np.ndarray):
183 | eye = torch.from_numpy(eye).to(device)
184 | elif torch.is_tensor(eye):
185 | eye = eye.to(device)
186 |
187 |     if eye.ndimension() == 1:
188 |         eye = eye[None, :]
189 |     batch_size = eye.shape[0]
190 | if at.ndimension() == 1:
191 | at = at[None, :].repeat(batch_size, 1)
192 | if up.ndimension() == 1:
193 | up = up[None, :].repeat(batch_size, 1)
194 |
195 | # create new axes
196 |     # eps (1e-5) avoids division by zero in the normalization below
197 | z_axis = F.normalize(at - eye, eps=1e-5)
198 | x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5)
199 | y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5)
200 | # create rotation matrix: [bs, 3, 3]
201 | r = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
202 |
203 | camera_R = r
204 | camera_t = eye
205 | # Note: x_new = R(x - t)
206 | return camera_R, camera_t
207 |
208 | def get_extrinsic(R, t, homo=False):
209 | batch_size = R.shape[0]
210 | device = R.device
211 | extrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1)
212 | extrinsic[:, :3, :3] = R
213 | extrinsic[:, :3, [-1]] = -torch.matmul(R, t[:,:,None])
214 | if homo:
215 | return extrinsic
216 | else:
217 | return extrinsic[:,:3,:]
218 |
219 | #------------------------- 2. camera space to image space (perspective projection):
220 | def get_intrinsic(focal, H, W, cx=0., cy=0., homo=False, batch_size=1, device='cuda:0'):
221 | '''
222 |     given different control parameters
223 | TODO: generate intrinsic matrix from other inputs
224 | P = np.array([[near/right, 0, 0, 0],
225 | [0, near/top, 0, 0],
226 | [0, 0, -(far+near)/(far-near), -2*far*near/(far-near)],
227 | [0, 0, -1, 0]])
228 | '''
229 | intrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1)
230 | f = focal
231 | K = torch.tensor(
232 | [[f/(H/2), 0, cx],
233 | [0, f/(W/2), cy],
234 | [0, 0, 1]], device=device, dtype=torch.float32)[None,...]
235 | intrinsic[:, :3, :3] = K
236 | if homo:
237 | return intrinsic
238 | else:
239 | return intrinsic[:,:3,:3]
240 |
241 |
242 | #------------------------- composite intrinsic and extrinsic into one matrix
243 | def compose_matrix(K, R, t):
244 | '''
245 | Args:
246 | K: [N, 3, 3]
247 | R: [N, 3, 3]
248 | t: [N, 3]
249 | Returns:
250 | P: [N, 4, 4]
251 | ## test if homo is the same as no homo:
252 | batch_size, nv, _ = trans_verts.shape
253 | trans_verts = torch.cat([trans_verts, torch.ones([batch_size, nv, 1], device=trans_verts.device)], dim=-1)
254 | trans_verts = torch.matmul(trans_verts, P.transpose(1,2))[:,:,:3]
255 | (tested, the same)
256 | # t = -Rt
257 | '''
258 | batch_size = K.shape[0]
259 | device = K.device
260 | intrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1)
261 | extrinsic = torch.eye(4, device=device).repeat(batch_size, 1, 1)
262 | intrinsic[:, :3, :3] = K
263 | extrinsic[:, :3, :3] = R
264 | # import ipdb; ipdb.set_trace()
265 | extrinsic[:, :3, [-1]] = -torch.matmul(R, t[:,:,None])
266 | P = torch.matmul(intrinsic, extrinsic)
267 | return P, intrinsic, extrinsic
268 |
269 | # def perspective_project(vertices, K, R, t, dist_coeffs, orig_size, eps=1e-9):
270 | # '''
271 | # Calculate projective transformation of vertices given a projection matrix
272 | # Input parameters:
273 | # K: batch_size * 3 * 3 intrinsic camera matrix
274 | # R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters
275 | # dist_coeffs: vector of distortion coefficients
276 | # orig_size: original size of image captured by the camera
277 | # Returns: For each point [X,Y,Z] in world coordinates [u,v,z] where u,v are the coordinates of the projection in
278 | # pixels and z is the depth
279 | # '''
280 | # # instead of P*x we compute x'*P'
281 | # vertices = torch.matmul(vertices, R.transpose(2,1)) + t
282 | # x, y, z = vertices[:, :, 0], vertices[:, :, 1], vertices[:, :, 2]
283 | # x_ = x / (z + eps)
284 | # y_ = y / (z + eps)
285 |
286 | # # Get distortion coefficients from vector
287 | # k1 = dist_coeffs[:, None, 0]
288 | # k2 = dist_coeffs[:, None, 1]
289 | # p1 = dist_coeffs[:, None, 2]
290 | # p2 = dist_coeffs[:, None, 3]
291 | # k3 = dist_coeffs[:, None, 4]
292 |
293 | # # we use x_ for x' and x__ for x'' etc.
294 | # r = torch.sqrt(x_ ** 2 + y_ ** 2)
295 | # x__ = x_*(1 + k1*(r**2) + k2*(r**4) + k3*(r**6)) + 2*p1*x_*y_ + p2*(r**2 + 2*x_**2)
296 | # y__ = y_*(1 + k1*(r**2) + k2*(r**4) + k3 *(r**6)) + p1*(r**2 + 2*y_**2) + 2*p2*x_*y_
297 | # vertices = torch.stack([x__, y__, torch.ones_like(z)], dim=-1)
298 | # vertices = torch.matmul(vertices, K.transpose(1,2))
299 | # u, v = vertices[:, :, 0], vertices[:, :, 1]
300 | # v = orig_size - v
301 | # # map u,v from [0, img_size] to [-1, 1] to use by the renderer
302 | # u = 2 * (u - orig_size / 2.) / orig_size
303 | # v = 2 * (v - orig_size / 2.) / orig_size
304 | # vertices = torch.stack([u, v, z], dim=-1)
305 | # return vertices
306 |
307 | def to_homo(points):
308 | '''
309 | points: [N, num of points, 2/3]
310 | '''
311 | batch_size, num, _ = points.shape
312 | points = torch.cat([points, torch.ones([batch_size, num, 1], device=points.device, dtype=points.dtype)], dim=-1)
313 | return points
314 |
315 | def homogeneous(points):
316 | """
317 | Concat 1 to each point
318 | :param points (..., 3)
319 | :return (..., 4)
320 | """
321 | return F.pad(points, (0, 1), "constant", 1.0)
322 |
323 |
324 | def batch_orth_proj(X, camera):
325 |     ''' orthographic projection
326 | X: 3d vertices, [bz, n_point, 3]
327 | camera: scale and translation, [bz, 3], [scale, tx, ty]
328 | '''
329 | camera = camera.clone().view(-1, 1, 3)
330 | X_trans = X[:, :, :2] + camera[:, :, 1:]
331 | X_trans = torch.cat([X_trans, X[:,:,2:]], 2)
332 | shape = X_trans.shape
333 | Xn = (camera[:, :, 0:1] * X_trans)
334 | return Xn
--------------------------------------------------------------------------------
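A minimal sketch chaining the helpers above (illustration only; the camera position
and focal length are made-up values):

    import torch
    from lib.utils.camera_util import look_at, get_extrinsic, get_intrinsic, transform

    eye = torch.tensor([[0., 0., 3.]])                 # [bz, 3] camera position
    R, t = look_at(eye, at=[0., 0., 0.], up=[0., 1., 0.])
    extrinsic = get_extrinsic(R, t)                    # [bz, 3, 4]
    intrinsic = get_intrinsic(focal=500., H=512, W=512,
                              batch_size=1, device='cpu')  # [bz, 3, 3], normalized by H/2, W/2
    points = torch.rand(1, 100, 3)                     # world-space points
    projected = transform(points, intrinsic, extrinsic)
    # divide x, y by z (as in perspective_project) to get normalized image coordinates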
/lib/utils/config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | import argparse
3 | import yaml
4 | import os
5 |
6 | cfg = CN()
7 |
8 | workdir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
9 |
10 | # project settings
11 | cfg.output_dir = ''
12 | cfg.group = 'test'
13 | cfg.exp_name = None
14 | cfg.device = 'cuda:0'
15 |
16 | # load models
17 | cfg.ckpt_path = None
18 | cfg.nerf_ckpt_path = 'exps/mpiis/DSC_7157/model.tar'
19 | cfg.mesh_ckpt_path = ''
20 | cfg.pose_ckpt_path = ''
21 |
22 | # ---------------------------------------------------------------------------- #
23 | # Options for SCARF model
24 | # ---------------------------------------------------------------------------- #
25 | cfg.depth_std = 0.
26 | cfg.opt_nerf_pose = True
27 | cfg.opt_beta = True
28 | cfg.opt_mesh = True
29 | cfg.sample_patch_rays = False
30 | cfg.sample_patch_size = 32
31 |
32 | # merf model
33 | cfg.mesh_offset_scale = 0.01
34 | cfg.exclude_hand = False
35 | cfg.use_perspective = False
36 | cfg.cam_mode = 'orth'
37 | cfg.use_mesh = True
38 | cfg.mesh_only_body = False
39 | cfg.use_highres_smplx = True
40 | cfg.freqs_tex = 10
41 | cfg.tex_detach_verts = False
42 | cfg.tex_network = 'siren' # siren
43 | cfg.use_nerf = True
44 | cfg.use_fine = True
45 | cfg.share_fine = False
46 | cfg.freqs_xyz = 10
47 | cfg.freqs_dir = 4
48 | cfg.use_view = False
49 | # skinning
50 | cfg.use_outer_mesh = False
51 | cfg.lbs_map = False
52 | cfg.k_neigh = 1
53 | cfg.weighted_neigh = False
54 | cfg.use_valid = False
55 | cfg.dis_threshold = 0.01
56 | ## deformation code
57 | cfg.use_deformation = False #'opt_code'
58 | cfg.deformation_type = '' #'opt_code'
59 | cfg.deformation_dim = 0
60 | cfg.latent_dim = 0
61 | cfg.pose_dim = 69
62 | ## appearance code
63 | cfg.use_appearance = False
64 | cfg.appearance_type = '' #'opt_code' # or pose
65 | cfg.appearance_dim = 0 # if pose, should be 3
66 | ## mesh tex code
67 | cfg.use_texcond = False
68 | cfg.texcond_type = '' #'opt_code' # or pose
69 | cfg.texcond_dim = 0 # if pose, should be 3
70 |
71 | # nerf rendering
72 | cfg.n_samples = 64
73 | cfg.n_importance = 32
74 | cfg.n_depth = 0
75 | cfg.chunk = 32*32 # chunk size to split the input to avoid OOM
76 | cfg.query_inside = False
77 | cfg.white_bkgd = True
78 |
79 | ####--------pose model
80 | cfg.opt_pose = False
81 | cfg.opt_exp = False
82 | cfg.opt_cam = False
83 | cfg.opt_appearance = False
84 | cfg.opt_focal = False
85 |
86 | # ---------------------------------------------------------------------------- #
87 | # Options for Training
88 | # ---------------------------------------------------------------------------- #
89 | cfg.train = CN()
90 | cfg.train.optimizer = 'adam'
91 | cfg.train.resume = True
92 | cfg.train.batch_size = 1
93 | cfg.train.max_epochs = 100000
94 | cfg.train.max_steps = 100*1000
95 | cfg.train.lr = 5e-4
96 | cfg.train.pose_lr = 5e-4
97 | cfg.train.tex_lr = 5e-4
98 | cfg.train.geo_lr = 5e-4
99 | cfg.train.decay_steps = [50]
100 | cfg.train.decay_gamma = 0.5
101 | cfg.train.precrop_iters = 400
102 | cfg.train.precrop_frac = 0.5
103 | cfg.train.coarse_iters = 0 # in the beginning, only train the coarse model
104 | # logger
105 | cfg.train.log_dir = 'logs'
106 | cfg.train.log_steps = 10
107 | cfg.train.vis_dir = 'train_images'
108 | cfg.train.vis_steps = 200
109 | cfg.train.write_summary = True
110 | cfg.train.checkpoint_steps = 500
111 | cfg.train.val_steps = 200
112 | cfg.train.val_vis_dir = 'val_images'
113 | cfg.train.eval_steps = 5000
114 |
115 | # ---------------------------------------------------------------------------- #
116 | # Options for Dataset
117 | # ---------------------------------------------------------------------------- #
118 | cfg.dataset = CN()
119 | cfg.dataset.type = 'scarf'
120 | cfg.dataset.path = ''
121 | cfg.dataset.white_bg = True
122 | cfg.dataset.subjects = None
123 | cfg.dataset.n_subjects = 1
124 | cfg.dataset.image_size = 512
125 | cfg.dataset.num_workers = 4
126 | cfg.dataset.n_images = 1000000
127 | cfg.dataset.updated_params_path = ''
128 | cfg.dataset.load_gt_pose = True
129 | cfg.dataset.load_normal = False
130 | cfg.dataset.load_perspective = False
131 |
132 | # training setting
133 | cfg.dataset.train = CN()
134 | cfg.dataset.train.cam_list = []
135 | cfg.dataset.train.frame_start = 0
136 | cfg.dataset.train.frame_end = 10000
137 | cfg.dataset.train.frame_step = 4
138 | cfg.dataset.val = CN()
139 | cfg.dataset.val.cam_list = []
140 | cfg.dataset.val.frame_start = 400
141 | cfg.dataset.val.frame_end = 500
142 | cfg.dataset.val.frame_step = 4
143 | cfg.dataset.test = CN()
144 | cfg.dataset.test.cam_list = []
145 | cfg.dataset.test.frame_start = 400
146 | cfg.dataset.test.frame_end = 500
147 | cfg.dataset.test.frame_step = 4
148 |
149 | # ---------------------------------------------------------------------------- #
150 | # Options for losses
151 | # ---------------------------------------------------------------------------- #
152 | cfg.loss = CN()
153 | cfg.loss.w_rgb = 1.
154 | cfg.loss.w_patch_mrf = 0.#0005
155 | cfg.loss.w_patch_perceptual = 0.#0005
156 | cfg.loss.w_alpha = 0.
157 | cfg.loss.w_xyz = 1.
158 | cfg.loss.w_depth = 0.
159 | cfg.loss.w_depth_close = 0.
160 | cfg.loss.reg_depth = 0.
161 | cfg.loss.use_mse = False
162 | cfg.loss.mesh_w_rgb = 1.
163 | cfg.loss.mesh_w_normal = 0.1
164 | cfg.loss.mesh_w_mrf = 0.
165 | cfg.loss.mesh_w_perceptual = 0.
166 | cfg.loss.mesh_w_alpha = 0.
167 | cfg.loss.mesh_w_alpha_skin = 0.
168 | cfg.loss.skin_consistency_type = 'verts_all_mean'
169 | cfg.loss.mesh_skin_consistency = 0.001
170 | cfg.loss.mesh_inside_mask = 100.
171 | cfg.loss.nerf_hard = 0.
172 | cfg.loss.nerf_hard_scale = 1.
173 | cfg.loss.mesh_reg_wdecay = 0.
174 | # regs
175 | cfg.loss.geo_reg = True
176 | cfg.loss.reg_beta_l1 = 1e-4
177 | cfg.loss.reg_cam_l1 = 1e-4
178 | cfg.loss.reg_pose_l1 = 1e-4
179 | cfg.loss.reg_a_norm = 1e-4
180 | cfg.loss.reg_beta_temp = 1e-4
181 | cfg.loss.reg_cam_temp = 1e-4
182 | cfg.loss.reg_pose_temp = 1e-4
183 | cfg.loss.nerf_reg_dxyz_w = 1e-4
184 | ##
185 | cfg.loss.reg_lap_w = 1.0
186 | cfg.loss.reg_edge_w = 10.0
187 | cfg.loss.reg_normal_w = 0.01
188 | cfg.loss.reg_offset_w = 100.
189 | cfg.loss.reg_offset_w_face = 500.
190 | cfg.loss.reg_offset_w_body = 0.
191 | cfg.loss.use_new_edge_loss = False
192 | ## new
193 | cfg.loss.pose_reg = False
194 | cfg.loss.background_reg = False
195 | cfg.loss.nerf_reg_normal_w = 0. #0.01
196 |
197 | # ---------------------------------------------------------------------------- #
198 | # Options for Body model
199 | # ---------------------------------------------------------------------------- #
200 | cfg.data_dir = os.path.join(workdir, 'data')
201 | cfg.model = CN()
202 | cfg.model.highres_path = os.path.join(cfg.data_dir, 'subdiv_level_1')
203 | cfg.model.topology_path = os.path.join(cfg.data_dir, 'SMPL_X_template_FLAME_uv.obj')
204 | cfg.model.smplx_model_path = os.path.join(cfg.data_dir, 'SMPLX_NEUTRAL_2020.npz')
205 | cfg.model.extra_joint_path = os.path.join(cfg.data_dir, 'smplx_extra_joints.yaml')
206 | cfg.model.j14_regressor_path = os.path.join(cfg.data_dir, 'SMPLX_to_J14.pkl')
207 | cfg.model.mano_ids_path = os.path.join(cfg.data_dir, 'MANO_SMPLX_vertex_ids.pkl')
208 | cfg.model.flame_vertex_masks_path = os.path.join(cfg.data_dir, 'FLAME_masks.pkl')
209 | cfg.model.flame_ids_path = os.path.join(cfg.data_dir, 'SMPL-X__FLAME_vertex_ids.npy')
210 | cfg.model.n_shape = 10
211 | cfg.model.n_exp = 10
212 |
213 | def get_cfg_defaults():
214 | """Get a yacs CfgNode object with default values for my_project."""
215 | # Return a clone so that the defaults will not be altered
216 | # This is for the "local variable" use pattern
217 | return cfg.clone()
218 |
219 | def update_cfg(cfg, cfg_file):
220 | cfg.merge_from_file(cfg_file)
221 | return cfg.clone()
222 |
223 | def parse_args():
224 | parser = argparse.ArgumentParser()
225 | parser.add_argument('--cfg', type=str, default = os.path.join(os.path.join(root_dir, 'configs/nerf_pl'), 'test.yaml'), help='cfg file path', )
226 | parser.add_argument('--mode', type=str, default = 'train', help='mode: train, test')
228 | parser.add_argument('--random_beta', action="store_true", default = False, help='use random beta (body shape) parameters')
228 | parser.add_argument('--clean', action="store_true", default = False, help='delete folders')
229 | parser.add_argument('--debug', action="store_true", default = False, help='debug model')
230 |
231 | args = parser.parse_args()
232 | print(args, end='\n\n')
233 |
234 | cfg = get_cfg_defaults()
235 | if args.cfg is not None:
236 | cfg_file = args.cfg
237 | cfg = update_cfg(cfg, args.cfg)
238 | cfg.cfg_file = cfg_file
239 | cfg.mode = args.mode
240 | cfg.clean = args.clean
241 | cfg.debug = args.debug
242 | cfg.random_beta = args.random_beta
243 | return cfg
244 |
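
A minimal usage sketch for the config helpers above; the YAML path is an assumption based on the repository's `exp/` folder:

```python
# Hedged sketch: load the defaults defined above and merge an experiment YAML over them.
# 'exp/stage_0_nerf.yml' is assumed from the repository layout.
from lib.utils.config import get_cfg_defaults, update_cfg

cfg = get_cfg_defaults()                       # clone of the default CfgNode
cfg = update_cfg(cfg, 'exp/stage_0_nerf.yml')  # override defaults from the YAML file
print(cfg.dataset.image_size, cfg.loss.w_rgb)  # access nested options
```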
--------------------------------------------------------------------------------
/lib/utils/lossfunc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchvision
4 | import torchvision.models as models
5 | import torch.nn.functional as F
6 | from functools import reduce
7 | import scipy.sparse as sp
8 | import numpy as np
9 | from chumpy.utils import row, col
10 |
11 | def get_vert_connectivity(num_vertices, faces):
12 | """
13 | Returns a sparse matrix (of size #verts x #verts) where each nonzero
14 | element indicates a neighborhood relation. For example, if there is a
15 | nonzero element in position (15,12), that means vertex 15 is connected
16 | by an edge to vertex 12.
17 | Adapted from https://github.com/mattloper/opendr/
18 | """
19 |
20 | vpv = sp.csc_matrix((num_vertices,num_vertices))
21 |
22 | # for each column in the faces...
23 | for i in range(3):
24 | IS = faces[:,i]
25 | JS = faces[:,(i+1)%3]
26 | data = np.ones(len(IS))
27 | ij = np.vstack((row(IS.flatten()), row(JS.flatten())))
28 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
29 | vpv = vpv + mtx + mtx.T
30 | return vpv
31 |
32 | def get_vertices_per_edge(num_vertices, faces):
33 | """
34 | Returns an Ex2 array of adjacencies between vertices, where
35 | each element in the array is a vertex index. Each edge is included
36 | only once. If output of get_faces_per_edge is provided, this is used to
37 | avoid call to get_vert_connectivity()
38 | Adapted from https://github.com/mattloper/opendr/
39 | """
40 |
41 | vc = sp.coo_matrix(get_vert_connectivity(num_vertices, faces))
42 | result = np.hstack((col(vc.row), col(vc.col)))
43 | result = result[result[:,0] < result[:,1]] # for uniqueness
44 | return result
45 |
46 | def relative_edge_loss(vertices1, vertices2, vertices_per_edge=None, faces=None, lossfunc=torch.nn.functional.mse_loss):
47 | """
48 | Given two meshes of the same topology, returns the relative edge differences.
49 |
50 | """
51 |
52 | if vertices_per_edge is None and faces is not None:
53 | vertices_per_edge = get_vertices_per_edge(vertices1.shape[1], faces)  # number of vertices per mesh, not the batch size
54 | elif vertices_per_edge is None and faces is None:
55 | raise ValueError("Either vertices_per_edge or faces must be specified")
56 |
57 | edges_for = lambda x: x[:, vertices_per_edge[:, 0], :] - x[:, vertices_per_edge[:, 1], :]
58 | return lossfunc(edges_for(vertices1), edges_for(vertices2))
59 |
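A small illustrative sketch of the edge helpers above on a toy two-triangle quad; the edges are precomputed with get_vertices_per_edge and all values are made up:

```python
# Hedged sketch: relative edge loss between a reference mesh and a slightly deformed copy.
import numpy as np
import torch
from lib.utils.lossfunc import get_vertices_per_edge, relative_edge_loss

faces = np.array([[0, 1, 2], [1, 3, 2]])
vpe = get_vertices_per_edge(4, faces)          # (E, 2) unique vertex pairs
ref = torch.rand(1, 4, 3)                      # (batch, n_verts, 3)
deformed = ref + 0.01 * torch.randn(1, 4, 3)
print(relative_edge_loss(deformed, ref, vertices_per_edge=vpe))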
60 |
61 | def relative_laplacian_loss(mesh1, mesh2, lossfunc=torch.nn.functional.mse_loss):
62 | L = mesh2.laplacian_packed()
63 | L1 = L.mm(mesh1.verts_packed())
64 | L2 = L.mm(mesh2.verts_packed())
65 | loss = lossfunc(L1, L2)
66 | return loss
67 |
68 | def mesh_laplacian(meshes, method: str = "uniform"):
69 | r"""
70 | Computes the laplacian smoothing objective for a batch of meshes.
71 | This function supports three variants of Laplacian smoothing,
72 | namely with uniform weights("uniform"), with cotangent weights ("cot"),
73 | and cotangent curvature ("cotcurv").For more details read [1, 2].
74 |
75 | Args:
76 | meshes: Meshes object with a batch of meshes.
77 | method: str specifying the method for the laplacian.
78 | Returns:
79 | loss: Average laplacian smoothing loss across the batch.
80 | Returns 0 if meshes contains no meshes or all empty meshes.
81 |
82 | Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3.
83 | The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors:
84 | for a uniform Laplacian, LuV[i] points to the centroid of its neighboring
85 | vertices, a cotangent Laplacian LcV[i] is known to be an approximation of
86 | the surface normal, while the curvature variant LckV[i] scales the normals
87 | by the discrete mean curvature. For vertex i, assume S[i] is the set of
88 | neighboring vertices to i, a_ij and b_ij are the "outside" angles in the
89 | two triangles connecting vertex v_i and its neighboring vertex v_j
90 | for j in S[i], as seen in the diagram below.
91 |
92 | .. code-block:: python
93 |
94 | a_ij
95 | /\
96 | / \
97 | / \
98 | / \
99 | v_i /________\ v_j
100 | \ /
101 | \ /
102 | \ /
103 | \ /
104 | \/
105 | b_ij
106 |
107 | The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i)
108 | For the uniform variant, w_ij = 1 / |S[i]|
109 | For the cotangent variant,
110 | w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik)
111 | For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i])
112 | where A[i] is the sum of the areas of all triangles containing vertex v_i.
113 |
114 | There is a nice trigonometry identity to compute cotangents. Consider a triangle
115 | with side lengths A, B, C and angles a, b, c.
116 |
117 | .. code-block:: python
118 |
119 | c
120 | /|\
121 | / | \
122 | / | \
123 | B / H| \ A
124 | / | \
125 | / | \
126 | /a_____|_____b\
127 | C
128 |
129 | Then cot a = (B^2 + C^2 - A^2) / 4 * area
130 | We know that area = CH/2, and by the law of cosines we have
131 |
132 | A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a
133 |
134 | Putting these together, we get:
135 |
136 | B^2 + C^2 - A^2 2BC cos a
137 | _______________ = _________ = (B/H) cos a = cos a / sin a = cot a
138 | 4 * area 2CH
139 |
140 |
141 | [1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion
142 | and curvature flow", SIGGRAPH 1999.
143 |
144 | [2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006.
145 | """
146 |
147 | if meshes.isempty():
148 | return torch.tensor(
149 | [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
150 | )
151 |
152 | N = len(meshes)
153 | verts_packed = meshes.verts_packed() # (sum(V_n), 3)
154 | faces_packed = meshes.faces_packed() # (sum(F_n), 3)
155 | num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
156 | verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
157 | weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
158 | weights = 1.0 / weights.float()
159 |
160 | # We don't want to backprop through the computation of the Laplacian;
161 | # just treat it as a magic constant matrix that is used to transform
162 | # verts into normals
163 | with torch.no_grad():
164 | if method == "uniform":
165 | L = meshes.laplacian_packed()
166 | elif method in ["cot", "cotcurv"]:
167 | L, inv_areas = cot_laplacian(verts_packed, faces_packed)  # cot_laplacian is expected from pytorch3d.ops; it is not imported at the top of this file
168 | if method == "cot":
169 | norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
170 | idx = norm_w > 0
171 | norm_w[idx] = 1.0 / norm_w[idx]
172 | else:
173 | L_sum = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
174 | norm_w = 0.25 * inv_areas
175 | else:
176 | raise ValueError("Method should be one of {uniform, cot, cotcurv}")
177 |
178 | if method == "uniform":
179 | loss = L.mm(verts_packed)
180 | elif method == "cot":
181 | loss = L.mm(verts_packed) * norm_w - verts_packed
182 | elif method == "cotcurv":
183 | # pyre-fixme[61]: `norm_w` may not be initialized here.
184 | loss = (L.mm(verts_packed) - L_sum * verts_packed) * norm_w
185 | # import ipdb; ipdb.set_trace()
186 | # loss = loss.norm(dim=1)
187 |
188 | # loss = loss * weights
189 | # return loss.sum() / N
190 | return loss
191 |
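A hedged sketch of the uniform variant on a toy mesh (requires pytorch3d; the values are illustrative only):

```python
import torch
from pytorch3d.structures import Meshes
from lib.utils.lossfunc import mesh_laplacian

verts = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [1., 1., 0.]])
faces = torch.tensor([[0, 1, 2], [1, 3, 2]])
lap = mesh_laplacian(Meshes(verts=[verts], faces=[faces]), method="uniform")
print(lap.shape)   # (4, 3): one Laplacian vector per vertex
```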
192 |
193 |
194 | def huber(x, y, scaling=0.1):
195 | """
196 | A helper function for evaluating the smooth L1 (huber) loss
197 | between the rendered silhouettes and colors.
198 | """
199 | # import ipdb; ipdb.set_trace()
200 | diff_sq = (x - y) ** 2
201 | loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
202 | # if mask is not None:
203 | # loss = loss.abs().sum()/mask.sum()
204 | # else:
205 | loss = loss.abs().mean()
206 | return loss
207 |
208 | class MSELoss(nn.Module):
209 | def __init__(self, w_alpha=0.):
210 | super(MSELoss, self).__init__()
211 | # self.loss = nn.MSELoss(reduction='mean')
212 | self.w_alpha = w_alpha
213 |
214 | def forward(self, inputs, targets):
215 | # import ipdb; ipdb.set_trace()
216 | # print(inputs['rgb_coarse'].min(), inputs['rgb_coarse'].max(), targets.min(), targets.max())
217 | # loss = self.loss(inputs['rgb_coarse'], targets[:,:3])
218 | loss = huber(inputs['rgb_coarse'], targets[:,:3])
219 |
220 | if 'rgb_fine' in inputs:
221 | loss += huber(inputs['rgb_fine'], targets[:,:3])
222 |
223 | # import ipdb; ipdb.set_trace()
224 | if self.w_alpha>0. and targets.shape[1]>3:
225 | weights = inputs['weights_coarse']
226 | pix_alpha = weights.sum(dim=1)
227 | # pix_alpha[pix_alpha>1] = 1
228 | # loss += self.loss(pix_alpha, targets[:,3])*self.w_alpha
229 | loss += huber(pix_alpha, targets[:,3])*self.w_alpha
230 | if 'weights_fine' in inputs:
231 | weights = inputs['weights_fine']
232 | pix_alpha = weights.sum(dim=1)
233 | # loss += self.loss(pix_alpha, targets[:,3])*self.w_alpha
234 | loss += huber(pix_alpha, targets[:,3])*self.w_alpha
235 |
236 | return loss
237 |
238 |
239 | loss_dict = {'mse': MSELoss}
240 |
241 |
242 | ### IDMRF loss
243 | class VGG19FeatLayer(nn.Module):
244 | def __init__(self):
245 | super(VGG19FeatLayer, self).__init__()
246 | self.vgg19 = models.vgg19(pretrained=True).features.eval().cuda()
247 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
248 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
249 |
250 | def forward(self, x):
251 | out = {}
252 | x = x - self.mean
253 | x = x/self.std
254 | ci = 1
255 | ri = 0
256 | for layer in self.vgg19.children():
257 | if isinstance(layer, nn.Conv2d):
258 | ri += 1
259 | name = 'conv{}_{}'.format(ci, ri)
260 | elif isinstance(layer, nn.ReLU):
261 | ri += 1
262 | name = 'relu{}_{}'.format(ci, ri)
263 | layer = nn.ReLU(inplace=False)
264 | elif isinstance(layer, nn.MaxPool2d):
265 | ri = 0
266 | name = 'pool_{}'.format(ci)
267 | ci += 1
268 | elif isinstance(layer, nn.BatchNorm2d):
269 | name = 'bn_{}'.format(ci)
270 | else:
271 | raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
272 | x = layer(x)
273 | out[name] = x
274 | # print([x for x in out])
275 | return out
276 | class IDMRFLoss(nn.Module):
277 | def __init__(self, featlayer=VGG19FeatLayer):
278 | super(IDMRFLoss, self).__init__()
279 | self.featlayer = featlayer()
280 | self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}
281 | self.feat_content_layers = {'relu4_2': 1.0}
282 | self.bias = 1.0
283 | self.nn_stretch_sigma = 0.5
284 | self.lambda_style = 1.0
285 | self.lambda_content = 1.0
286 |
287 | def sum_normalize(self, featmaps):
288 | reduce_sum = torch.sum(featmaps, dim=1, keepdim=True) + 1e-5
289 | return featmaps / reduce_sum
290 |
291 | def patch_extraction(self, featmaps):
292 | patch_size = 1
293 | patch_stride = 1
294 | patches_as_depth_vectors = featmaps.unfold(2, patch_size, patch_stride).unfold(3, patch_size, patch_stride)
295 | self.patches_OIHW = patches_as_depth_vectors.permute(0, 2, 3, 1, 4, 5)
296 | dims = self.patches_OIHW.size()
297 | self.patches_OIHW = self.patches_OIHW.view(-1, dims[3], dims[4], dims[5])
298 | return self.patches_OIHW
299 |
300 | def compute_relative_distances(self, cdist):
301 | epsilon = 1e-5
302 | div = torch.min(cdist, dim=1, keepdim=True)[0]
303 | relative_dist = cdist / (div + epsilon)
304 | return relative_dist
305 |
306 | def exp_norm_relative_dist(self, relative_dist):
307 | scaled_dist = relative_dist
308 | dist_before_norm = torch.exp((self.bias - scaled_dist)/self.nn_stretch_sigma)
309 | self.cs_NCHW = self.sum_normalize(dist_before_norm)
310 | return self.cs_NCHW
311 |
312 | def mrf_loss(self, gen, tar):
313 | meanT = torch.mean(tar, 1, keepdim=True)
314 | gen_feats, tar_feats = gen - meanT, tar - meanT
315 | gen_feats = gen_feats
316 | tar_feats = tar_feats
317 | gen_feats_norm = torch.norm(gen_feats, p=2, dim=1, keepdim=True)
318 | tar_feats_norm = torch.norm(tar_feats, p=2, dim=1, keepdim=True)
319 | gen_normalized = gen_feats / (gen_feats_norm + 1e-5)
320 | tar_normalized = tar_feats / (tar_feats_norm + 1e-5)
321 |
322 | cosine_dist_l = []
323 | BatchSize = tar.size(0)
324 |
325 | for i in range(BatchSize):
326 | tar_feat_i = tar_normalized[i:i+1, :, :, :]
327 | gen_feat_i = gen_normalized[i:i+1, :, :, :]
328 | patches_OIHW = self.patch_extraction(tar_feat_i)
329 | cosine_dist_i = F.conv2d(gen_feat_i, patches_OIHW)
330 | cosine_dist_l.append(cosine_dist_i)
331 | cosine_dist = torch.cat(cosine_dist_l, dim=0)
332 | cosine_dist_zero_2_one = - (cosine_dist - 1) / 2
333 | relative_dist = self.compute_relative_distances(cosine_dist_zero_2_one)
334 | rela_dist = self.exp_norm_relative_dist(relative_dist)
335 | dims_div_mrf = rela_dist.size()
336 | k_max_nc = torch.max(rela_dist.view(dims_div_mrf[0], dims_div_mrf[1], -1), dim=2)[0]
337 | div_mrf = torch.mean(k_max_nc, dim=1)
338 | div_mrf_sum = -torch.log(div_mrf)
339 | div_mrf_sum = torch.sum(div_mrf_sum)
340 | return div_mrf_sum
341 |
342 | def forward(self, gen, tar):
343 | ## gen: [bz,3,h,w] rgb [0,1]
344 | gen_vgg_feats = self.featlayer(gen)
345 | tar_vgg_feats = self.featlayer(tar)
346 | style_loss_list = [self.feat_style_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_style_layers]
347 | self.style_loss = reduce(lambda x, y: x+y, style_loss_list) * self.lambda_style
348 | content_loss_list = [self.feat_content_layers[layer] * self.mrf_loss(gen_vgg_feats[layer], tar_vgg_feats[layer]) for layer in self.feat_content_layers]
349 | self.content_loss = reduce(lambda x, y: x+y, content_loss_list) * self.lambda_content
350 | return self.style_loss + self.content_loss
351 |
352 | # loss = 0
353 | # for key in self.feat_style_layers.keys():
354 | # loss += torch.mean((gen_vgg_feats[key] - tar_vgg_feats[key])**2)
355 | # return loss
356 |
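A hedged usage sketch for the ID-MRF loss above; VGG19FeatLayer is hard-coded to CUDA, so a GPU is required, and the inputs are just random placeholders:

```python
import torch
from lib.utils.lossfunc import IDMRFLoss

idmrf = IDMRFLoss()                      # builds a pretrained VGG19 feature extractor on CUDA
gen = torch.rand(1, 3, 128, 128).cuda()  # generated patches, rgb in [0, 1]
tar = torch.rand(1, 3, 128, 128).cuda()  # target patches
print(idmrf(gen, tar))
```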
357 |
358 | class VGGPerceptualLoss(torch.nn.Module):
359 | def __init__(self, resize=True):
360 | super(VGGPerceptualLoss, self).__init__()
361 | blocks = []
362 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
363 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
364 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
365 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
366 | for bl in blocks:
367 | for p in bl.parameters():
368 | p.requires_grad = False
369 | self.blocks = torch.nn.ModuleList(blocks)
370 | self.transform = torch.nn.functional.interpolate
371 | self.resize = resize
372 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
373 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
374 |
375 | def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
376 | if input.shape[1] != 3:
377 | input = input.repeat(1, 3, 1, 1)
378 | target = target.repeat(1, 3, 1, 1)
379 | input = (input-self.mean) / self.std
380 | target = (target-self.mean) / self.std
381 | if self.resize:
382 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
383 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
384 | loss = 0.0
385 | x = input
386 | y = target
387 | for i, block in enumerate(self.blocks):
388 | x = block(x)
389 | y = block(y)
390 | if i in feature_layers:
391 | loss += torch.nn.functional.l1_loss(x, y)
392 | if i in style_layers:
393 | act_x = x.reshape(x.shape[0], x.shape[1], -1)
394 | act_y = y.reshape(y.shape[0], y.shape[1], -1)
395 | gram_x = act_x @ act_x.permute(0, 2, 1)
396 | gram_y = act_y @ act_y.permute(0, 2, 1)
397 | loss += torch.nn.functional.l1_loss(gram_x, gram_y)
398 | return loss
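
A hedged usage sketch for the VGG16 perceptual loss above; inputs are expected in [0, 1], and the pretrained VGG16 weights are downloaded by torchvision on first use:

```python
import torch
from lib.utils.lossfunc import VGGPerceptualLoss

vgg_loss = VGGPerceptualLoss(resize=True)
pred = torch.rand(2, 3, 128, 128)
gt = torch.rand(2, 3, 128, 128)
print(vgg_loss(pred, gt, feature_layers=[0, 1, 2, 3], style_layers=[2]))
```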
--------------------------------------------------------------------------------
/lib/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia.losses import ssim as dssim
3 |
4 | def mse(image_pred, image_gt, valid_mask=None, reduction='mean'):
5 | value = (image_pred-image_gt[:,:3])**2
6 | if valid_mask is not None:
7 | value = value[valid_mask]
8 | if reduction == 'mean':
9 | return torch.mean(value)
10 | return value
11 |
12 | def psnr(image_pred, image_gt, valid_mask=None, reduction='mean'):
13 | return -10*torch.log10(mse(image_pred, image_gt[:,:3], valid_mask, reduction))
14 |
15 | def ssim(image_pred, image_gt, reduction='mean'):
16 | """
17 | image_pred and image_gt: (1, 3, H, W)
18 | """
19 | dssim_ = dssim(image_pred, image_gt, 3, reduction) # dissimilarity in [0, 1]
20 | return 1-2*dssim_ # in [-1, 1]
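
A hedged sketch of the metrics above on dummy images; shapes follow the docstrings, and the ssim wrapper assumes the kornia version this module was written against:

```python
import torch
from lib.utils.metric import psnr, ssim

pred = torch.rand(1, 3, 64, 64)
gt = torch.rand(1, 3, 64, 64)
print(psnr(pred, gt))   # peak signal-to-noise ratio in dB
print(ssim(pred, gt))   # structural similarity in [-1, 1]
```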
--------------------------------------------------------------------------------
/lib/utils/perceptual_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Code heavily inspired by https://github.com/JustusThies/NeuralTexGen/blob/master/models/VGG_LOSS.py
3 | """
4 | import torch
5 | from torchvision import models
6 | from torchvision.transforms import Normalize
7 | from collections import namedtuple
8 | import sys
9 | from pathlib import Path
10 |
11 | sys.path.append(str((Path(__file__).parents[2]/"deps")))
12 | from InsightFace.recognition.arcface_torch.backbones import get_model
13 |
14 |
15 | class VGG16(torch.nn.Module):
16 | def __init__(self, requires_grad=False):
17 | super(VGG16, self).__init__()
18 | vgg_pretrained_features = models.vgg16(pretrained=True).features
19 | self.slice1 = torch.nn.Sequential()
20 | self.slice2 = torch.nn.Sequential()
21 | self.slice3 = torch.nn.Sequential()
22 | self.slice4 = torch.nn.Sequential()
23 | for x in range(4):
24 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
25 | for x in range(4, 9):
26 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
27 | for x in range(9, 16):
28 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
29 | for x in range(16, 23):
30 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 | self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
35 |
36 | def forward(self, X):
37 | """
38 | assuming rgb input of shape N x 3 x H x W normalized to -1 ... +1
39 | :param X:
40 | :return:
41 | """
42 | X = self.normalize(X * .5 + .5)
43 |
44 | h = self.slice1(X)
45 | h_relu1_2 = h
46 | h = self.slice2(h)
47 | h_relu2_2 = h
48 | h = self.slice3(h)
49 | h_relu3_3 = h
50 | h = self.slice4(h)
51 | h_relu4_3 = h
52 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
53 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
54 | return out
55 |
56 |
57 | class ResNet18(torch.nn.Module):
58 | def __init__(self, weight_path):
59 | super(ResNet18, self).__init__()
60 | net = get_model("r18")
61 | net.load_state_dict(torch.load(weight_path))
62 | net.eval()
63 | self.conv1 = net.conv1
64 | self.bn1 = net.bn1
65 | self.prelu = net.prelu
66 | self.layer1 = net.layer1
67 | self.layer2 = net.layer2
68 | self.layer3 = net.layer3
69 | self.layer4 = net.layer4
70 |
71 | for p in self.parameters():
72 | p.requires_grad = False
73 |
74 | def forward(self, x):
75 | """
76 | assuming rgb input of shape N x 3 x H x W normalized to -1 ... +1
77 | :param X:
78 | :return:
79 | """
80 | self.eval()
81 | x = self.conv1(x)
82 | x = self.bn1(x)
83 | x = self.prelu(x)
84 | x = self.layer1(x)
85 | relu1_2 = x.clone()
86 | x = self.layer2(x)
87 | relu2_2 = x.clone()
88 | x = self.layer3(x)
89 | relu3_3 = x.clone()
90 | x = self.layer4(x)
91 | relu4_3 = x.clone()
92 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
93 | out = vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3)
94 | return out
95 |
96 |
97 | def gram_matrix(y):
98 | (b, ch, h, w) = y.size()
99 | features = y.view(b, ch, w * h)
100 | features_t = features.transpose(1, 2)
101 | gram = features.bmm(features_t) / (ch * h * w)
102 | return gram
103 |
104 |
105 | class ResNetLOSS(torch.nn.Module):
106 | def __init__(self, criterion=torch.nn.L1Loss(reduction='mean')):
107 | super(ResNetLOSS, self).__init__()
108 | self.model = ResNet18("assets/InsightFace/backbone.pth")
109 | self.model.eval()
110 | self.criterion = criterion
111 | self.criterion.reduction = "mean"
112 |
113 | for p in self.parameters():
114 | p.requires_grad = False
115 |
116 | def forward(self, fake, target, content_weight=1.0, style_weight=1.0):
117 | """
118 | assumes input images normalize to -1 ... + 1 and of shape N x 3 x H x W
119 | :param fake:
120 | :param target:
121 | :param content_weight:
122 | :param style_weight:
123 | :return:
124 | """
125 | vgg_fake = self.model(fake)
126 | vgg_target = self.model(target)
127 |
128 | content_loss = self.criterion(vgg_target.relu2_2, vgg_fake.relu2_2)
129 |
130 | # gram_matrix
131 | gram_style = [gram_matrix(y) for y in vgg_target]
132 | style_loss = 0.0
133 | for ft_y, gm_s in zip(vgg_fake, gram_style):
134 | gm_y = gram_matrix(ft_y)
135 | style_loss += self.criterion(gm_y, gm_s)
136 |
137 | total_loss = content_weight * content_loss + style_weight * style_loss
138 | return total_loss
139 |
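A hedged usage sketch for the ArcFace-feature perceptual loss above; it expects images in [-1, 1] and requires the InsightFace r18 weights at the hard-coded path assets/InsightFace/backbone.pth plus the InsightFace code under deps/:

```python
import torch
from lib.utils.perceptual_loss import ResNetLOSS

criterion = ResNetLOSS()
fake = torch.rand(1, 3, 112, 112) * 2 - 1
real = torch.rand(1, 3, 112, 112) * 2 - 1
print(criterion(fake, real, content_weight=1.0, style_weight=1.0))
```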
--------------------------------------------------------------------------------
/lib/utils/rasterize_rendering.py:
--------------------------------------------------------------------------------
1 | '''
2 | rasterization
3 | basic rasterize
4 | render shape (for visualization)
5 | render texture (need uv information)
6 | '''
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | from pytorch3d.structures import Meshes
13 | from pytorch3d.renderer.mesh import rasterize_meshes
14 | from . import util
15 |
16 | def add_directionlight(normals, lights=None):
17 | '''
18 | normals: [bz, nv, 3]
19 | lights: [bz, nlight, 6]
20 | returns:
21 | shading: [bz, nv, 3]
22 | '''
23 | light_direction = lights[:,:,:3]; light_intensities = lights[:,:,3:]
24 | directions_to_lights = F.normalize(light_direction[:,:,None,:].expand(-1,-1,normals.shape[1],-1), dim=3)
25 | # normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
26 | # normals_dot_lights = (normals[:,None,:,:]*directions_to_lights).sum(dim=3)
27 | normals_dot_lights = torch.clamp((normals[:,None,:,:]*directions_to_lights).sum(dim=3), 0., 1.)
28 | shading = normals_dot_lights[:,:,:,None]*light_intensities[:,:,None,:]
29 | return shading.mean(1)
30 |
31 |
32 | def pytorch3d_rasterize(vertices, faces, image_size, attributes=None,
33 | soft=False, blur_radius=0.0, sigma=1e-8, faces_per_pixel=1, gamma=1e-4,
34 | perspective_correct=False, clip_barycentric_coords=True, h=None, w=None):
35 | fixed_vertices = vertices.clone()
36 | fixed_vertices[...,:2] = -fixed_vertices[...,:2]
37 |
38 | if h is None and w is None:
39 | image_size = image_size
40 | else:
41 | image_size = [h, w]
42 | if h>w:
43 | fixed_vertices[..., 1] = fixed_vertices[..., 1]*h/w
44 | else:
45 | fixed_vertices[..., 0] = fixed_vertices[..., 0]*w/h
46 |
47 | meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
48 | # import ipdb; ipdb.set_trace()
49 | # pytorch3d rasterize
50 | pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
51 | meshes_screen,
52 | image_size=image_size,
53 | blur_radius=blur_radius,
54 | faces_per_pixel=faces_per_pixel,
55 | perspective_correct=perspective_correct,
56 | clip_barycentric_coords=clip_barycentric_coords,
57 | # max_faces_per_bin = faces.shape[1],
58 | bin_size = 0
59 | )
60 | # import ipdb; ipdb.set_trace()
61 | vismask = (pix_to_face > -1).float().squeeze(-1)
62 | depth = zbuf.squeeze(-1)
63 |
64 | if soft:
65 | from pytorch3d.renderer.blending import _sigmoid_alpha
66 | colors = torch.ones_like(bary_coords)
67 | N, H, W, K = pix_to_face.shape
68 | pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
69 | pixel_colors[..., :3] = colors[..., 0, :]
70 | alpha = _sigmoid_alpha(dists, pix_to_face, sigma)
71 | pixel_colors[..., 3] = alpha
72 | pixel_colors = pixel_colors.permute(0,3,1,2)
73 | return pixel_colors
74 |
75 | if attributes is None:
76 | return depth, vismask
77 | else:
78 | vismask = (pix_to_face > -1).float()
79 | D = attributes.shape[-1]
80 | attributes = attributes.clone(); attributes = attributes.view(attributes.shape[0]*attributes.shape[1], 3, attributes.shape[-1])
81 | N, H, W, K, _ = bary_coords.shape
82 | mask = pix_to_face == -1
83 | pix_to_face = pix_to_face.clone()
84 | pix_to_face[mask] = 0
85 | idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
86 | pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
87 | pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
88 | pixel_vals[mask] = 0 # Replace masked values in output.
89 | pixel_vals = pixel_vals[:,:,:,0].permute(0,3,1,2)
90 | pixel_vals = torch.cat([pixel_vals, vismask[:,:,:,0][:,None,:,:]], dim=1)
91 | return pixel_vals
92 |
93 | # For visualization
94 | def render_shape(vertices, faces, image_size=None, background=None, lights=None, blur_radius=0., shift=True, colors=None, h=None, w=None):
95 | '''
96 | -- rendering shape with detail normal map
97 | '''
98 | batch_size = vertices.shape[0]
99 | transformed_vertices = vertices.clone()
100 | # set lighting
101 | # if lights is None:
102 | # light_positions = torch.tensor(
103 | # [
104 | # [-1,1,1],
105 | # [1,1,1],
106 | # [-1,-1,1],
107 | # [1,-1,1],
108 | # [0,0,1]
109 | # ]
110 | # )[None,:,:].expand(batch_size, -1, -1).float()
111 | # light_intensities = torch.ones_like(light_positions).float()*1.7
112 | # lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device)
113 | if lights is None:
114 | light_positions = torch.tensor(
115 | [
116 | [-5, 5, -5],
117 | [5, 5, -5],
118 | [-5, -5, -5],
119 | [5, -5, -5],
120 | [0, 0, -5],
121 | ]
122 | )[None,:,:].expand(batch_size, -1, -1).float()
123 |
124 | light_intensities = torch.ones_like(light_positions).float()*1.7
125 | lights = torch.cat((light_positions, light_intensities), 2).to(vertices.device)
126 | if shift:
127 | transformed_vertices[:,:,2] = transformed_vertices[:,:,2] - transformed_vertices[:,:,2].min()
128 | transformed_vertices[:,:,2] = transformed_vertices[:,:,2]/transformed_vertices[:,:,2].max()*80 + 10
129 |
130 | # Attributes
131 | face_vertices = util.face_vertices(vertices, faces)
132 | normals = util.vertex_normals(vertices,faces); face_normals = util.face_vertices(normals, faces)
133 | transformed_normals = util.vertex_normals(transformed_vertices, faces); transformed_face_normals = util.face_vertices(transformed_normals, faces)
134 | if colors is None:
135 | face_colors = torch.ones_like(face_vertices)*180/255.
136 | else:
137 | face_colors = util.face_vertices(colors, faces)
138 |
139 | attributes = torch.cat([face_colors,
140 | transformed_face_normals.detach(),
141 | face_vertices.detach(),
142 | face_normals],
143 | -1)
144 | # rasterize
145 | rendering = pytorch3d_rasterize(transformed_vertices, faces, image_size=image_size, attributes=attributes, blur_radius=blur_radius, h=h, w=w)
146 |
147 | ####
148 | alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()
149 |
150 | # albedo
151 | albedo_images = rendering[:, :3, :, :]
152 | # mask
153 | transformed_normal_map = rendering[:, 3:6, :, :].detach()
154 | pos_mask = (transformed_normal_map[:, 2:, :, :] < 0.15).float()
155 |
156 | # shading
157 | normal_images = rendering[:, 9:12, :, :].detach()
158 | vertice_images = rendering[:, 6:9, :, :].detach()
159 |
160 | shading = add_directionlight(normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights)
161 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2).contiguous()
162 | shaded_images = albedo_images*shading_images
163 |
164 | # alpha_images = alpha_images*pos_mask
165 | if background is None:
166 | shape_images = shaded_images*alpha_images + torch.ones_like(shaded_images).to(vertices.device)*(1-alpha_images)
167 | else:
168 | shape_images = shaded_images *alpha_images + background*(1-alpha_images)
169 | return shape_images
170 |
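A hedged sketch of the shape renderer above on a single toy triangle; the coordinates are illustrative only, and the util helpers (vertex_normals, face_vertices) are assumed to handle CPU tensors:

```python
import torch
from lib.utils.rasterize_rendering import render_shape

# toy triangle with x/y roughly in [-1, 1]; not a real SCARF mesh
verts = torch.tensor([[[-0.5, -0.5, 0.1], [0.5, -0.5, 0.2], [0.0, 0.5, 0.3]]])
faces = torch.tensor([[[0, 1, 2]]])
img = render_shape(verts, faces, image_size=256)   # (1, 3, 256, 256), values in [0, 1]
```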
171 |
172 | def render_texture(transformed_vertices, faces, image_size,
173 | vertices, albedos,
174 | uv_faces, uv_coords,
175 | lights=None, light_type='point'):
176 | '''
177 | -- Texture Rendering
178 | vertices: [batch_size, V, 3], vertices in world space, for calculating normals, then shading
179 | transformed_vertices: [batch_size, V, 3], range normalized to [-1,1], projected vertices in image space (aligned to the image pixels), for rasterization
180 | albedos: [batch_size, 3, h, w], uv map
181 | lights:
182 | spherical harmonics: [N, 9 (SH coeff), 3 (rgb)]
183 | points/directional lighting: [N, n_lights, 6(xyzrgb)]
184 | light_type:
185 | point or directional
186 | '''
187 | batch_size = vertices.shape[0]
188 | face_uvcoords = util.face_vertices(uv_coords, uv_faces)
189 | ## rasterizer near 0 far 100. move mesh so minz larger than 0
190 | # import ipdb; ipdb.set_trace()
191 | # normalize to 0, 100
192 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2] + 10
193 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2] - transformed_vertices[:,:,2].min()
194 | # transformed_vertices[:,:,2] = transformed_vertices[:,:,2]/transformed_vertices[:,:,2].max()*80 + 10
195 |
196 | # import ipdb; ipdb.set_trace()
197 | # attributes
198 | face_vertices = util.face_vertices(vertices, faces)
199 | normals = util.vertex_normals(vertices, faces); face_normals = util.face_vertices(normals, faces)
200 | transformed_normals = util.vertex_normals(transformed_vertices, faces); transformed_face_normals = util.face_vertices(transformed_normals, faces)
201 |
202 | attributes = torch.cat([face_uvcoords,
203 | transformed_face_normals.detach(),
204 | face_vertices.detach(),
205 | face_normals],
206 | -1)
207 | # rasterize
208 | rendering = pytorch3d_rasterize(transformed_vertices, faces, image_size=image_size, attributes=attributes)
209 |
210 | ####
211 | # vis mask
212 | alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()
213 |
214 | # albedo
215 | uvcoords_images = rendering[:, :3, :, :]; grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2]
216 | albedo_images = F.grid_sample(albedos, grid, align_corners=False)
217 |
218 | # visible mask for pixels with positive normal direction
219 | transformed_normal_map = rendering[:, 3:6, :, :].detach()
220 | pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float()
221 |
222 | # shading
223 | normal_images = rendering[:, 9:12, :, :]
224 | if lights is not None:
225 | if lights.shape[1] == 9:
226 | shading_images = add_SHlight(normal_images, lights)  # NOTE: no SH-lighting helper is defined in this module; one is assumed to be provided
227 | else:
228 | if light_type=='point':
229 | vertice_images = rendering[:, 6:9, :, :].detach()
230 | shading = add_pointlight(vertice_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights)  # NOTE: no point-light helper is defined in this module; one is assumed to be provided
231 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2)
232 | else:
233 | shading = add_directionlight(normal_images.permute(0,2,3,1).reshape([batch_size, -1, 3]), lights)  # module-level helper defined above
234 | shading_images = shading.reshape([batch_size, albedo_images.shape[2], albedo_images.shape[3], 3]).permute(0,3,1,2)
235 | images = albedo_images*shading_images
236 | else:
237 | images = albedo_images
238 | shading_images = images.detach()*0.
239 |
240 | outputs = {
241 | 'images': images*alpha_images,
242 | 'albedo_images': albedo_images*alpha_images,
243 | 'alpha_images': alpha_images,
244 | 'pos_mask': pos_mask,
245 | 'shading_images': shading_images,
246 | 'grid': grid,
247 | 'normals': normals,
248 | 'normal_images': normal_images*alpha_images,
249 | 'transformed_normals': transformed_normals,
250 | }
251 |
252 | return outputs
--------------------------------------------------------------------------------
/lib/utils/volumetric_rendering.py:
--------------------------------------------------------------------------------
1 | """
2 | Differentiable volumetric implementation used by pi-GAN generator.
3 | """
4 |
5 | import time
6 | from functools import partial
7 |
8 | import math
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import matplotlib.pyplot as plt
13 | import random
14 |
15 | import torch
16 |
17 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
18 | """
19 | Left-multiplies MxM @ NxM. Returns NxM.
20 | """
21 | res = torch.matmul(vectors4, matrix.T)
22 | return res
23 |
24 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
25 | """
26 | Normalize vector lengths.
27 | """
28 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
29 |
30 | def torch_dot(x: torch.Tensor, y: torch.Tensor):
31 | """
32 | Dot product of two tensors.
33 | """
34 | return (x * y).sum(-1)
35 |
36 | def fancy_integration(rgb_sigma, z_vals, device, noise_std=0.5, last_back=False, white_back=False, clamp_mode='softplus', fill_mode=None):
37 | """Performs NeRF volumetric rendering."""
38 | rgbs = rgb_sigma[..., :3]
39 | sigmas = rgb_sigma[..., 3:]
40 |
41 | # import ipdb; ipdb.set_trace()
42 | deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1]
43 | delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1])
44 | deltas = torch.cat([deltas, delta_inf], -2)
45 |
46 | noise = torch.randn(sigmas.shape, device=device) * noise_std
47 |
48 | if clamp_mode == 'softplus':
49 | alphas = 1-torch.exp(-deltas * (F.softplus(sigmas + noise)))
50 | elif clamp_mode == 'relu':
51 | alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise)))
52 | else:
53 | raise ValueError("Need to choose clamp mode")
54 |
55 | alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1-alphas + 1e-10], -2)
56 | weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1]
57 | weights_sum = weights.sum(2)
58 | # weights_sum = weights[:,:,:-1].sum(2)
59 |
60 | if last_back:
61 | weights[:, :, -1] += (1 - weights_sum)
62 |
63 | rgb_final = torch.sum(weights * rgbs, -2)
64 | depth_final = torch.sum(weights * z_vals, -2)
65 |
66 | # if white_back:
67 | # rgb_final = rgb_final + 1-weights_sum
68 |
69 | if fill_mode == 'debug':
70 | rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device)
71 | elif fill_mode == 'weight':
72 | rgb_final = weights_sum.expand_as(rgb_final)
73 |
74 | return rgb_final, depth_final, weights.squeeze(-1)
75 |
76 |
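A hedged sketch of the integration above on random samples; shapes follow the slicing inside fancy_integration, i.e. (batch, n_rays, n_steps, rgb+sigma):

```python
import torch
from lib.utils.volumetric_rendering import fancy_integration

rgb_sigma = torch.rand(1, 16, 32, 4)
z_vals = torch.linspace(0.5, 2.0, 32).view(1, 1, 32, 1).repeat(1, 16, 1, 1)
rgb, depth, weights = fancy_integration(rgb_sigma, z_vals, device='cpu', clamp_mode='relu')
print(rgb.shape, depth.shape, weights.shape)   # (1, 16, 3) (1, 16, 1) (1, 16, 32)
```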
77 | def get_initial_rays_trig(n, num_steps, device, fov, resolution, ray_start, ray_end):
78 | """Returns sample points, z_vals, and ray directions in camera space."""
79 |
80 | W, H = resolution
81 | # Create full screen NDC (-1 to +1) coords [x, y, 0, 1].
82 | # Y is flipped to follow image memory layouts.
83 | x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device),
84 | torch.linspace(1, -1, H, device=device))
85 | x = x.T.flatten()
86 | y = y.T.flatten()
87 | z = -torch.ones_like(x, device=device) / np.tan((2 * math.pi * fov / 360)/2)
88 |
89 | rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1))
90 | # rays_d_cam = torch.stack([x, y, z], -1)
91 |
92 |
93 | z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W*H, 1, 1)
94 | points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals
95 |
96 | points = torch.stack(n*[points])
97 | z_vals = torch.stack(n*[z_vals])
98 | rays_d_cam = torch.stack(n*[rays_d_cam]).to(device)
99 | # import ipdb; ipdb.set_trace()
100 | return points, z_vals, rays_d_cam
101 |
102 | def perturb_points(points, z_vals, ray_directions, device):
103 | distance_between_points = z_vals[:,:,1:2,:] - z_vals[:,:,0:1,:]
104 | offset = (torch.rand(z_vals.shape, device=device)-0.5) * distance_between_points
105 | # z_vals = z_vals + offset
106 | z_vals[:,:,1:-1,:] = z_vals[:,:,1:-1,:] + offset[:,:,1:-1,:]
107 | points[:,:,1:-1,:] = points[:,:,1:-1,:] + offset[:,:,1:-1,:] * ray_directions.unsqueeze(2)
108 | return points, z_vals
109 |
110 |
111 | def transform_sampled_points(points, z_vals, ray_directions, device, h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5, v_mean=math.pi * 0.5, mode='normal'):
112 | """Samples a camera position and maps points in camera space to world space."""
113 |
114 | n, num_rays, num_steps, channels = points.shape
115 |
116 | # points, z_vals = perturb_points(points, z_vals, ray_directions, device)
117 |
118 |
119 | camera_origin, pitch, yaw = sample_camera_positions(n=points.shape[0], r=1, horizontal_stddev=h_stddev, vertical_stddev=v_stddev, horizontal_mean=h_mean, vertical_mean=v_mean, device=device, mode=mode)
120 | forward_vector = normalize_vecs(-camera_origin)
121 |
122 | cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)
123 |
124 | points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1), device=device)
125 | points_homogeneous[:, :, :, :3] = points
126 |
127 | # should be n x 4 x 4 , n x r^2 x num_steps x 4
128 | transformed_points = torch.bmm(cam2world_matrix, points_homogeneous.reshape(n, -1, 4).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, num_steps, 4)
129 |
130 |
131 | transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3], ray_directions.reshape(n, -1, 3).permute(0,2,1)).permute(0, 2, 1).reshape(n, num_rays, 3)
132 |
133 | homogeneous_origins = torch.zeros((n, 4, num_rays), device=device)
134 | homogeneous_origins[:, 3, :] = 1
135 | transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1).reshape(n, num_rays, 4)[..., :3]
136 |
137 | # return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
138 | return points, z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
139 |
140 | def truncated_normal_(tensor, mean=0, std=1):
141 | size = tensor.shape
142 | tmp = tensor.new_empty(size + (4,)).normal_()
143 | valid = (tmp < 2) & (tmp > -2)
144 | ind = valid.max(-1, keepdim=True)[1]
145 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
146 | tensor.data.mul_(std).add_(mean)
147 | return tensor
148 |
149 | def sample_camera_positions(device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1, horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'):
150 | """
151 | Samples n random locations along a sphere of radius r. Uses the specified distribution.
152 | Theta is yaw in radians (-pi, pi)
153 | Phi is pitch in radians (0, pi)
154 | """
155 |
156 | if mode == 'uniform':
157 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean
158 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean
159 |
160 | elif mode == 'normal' or mode == 'gaussian':
161 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
162 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
163 |
164 | elif mode == 'hybrid':
165 | if random.random() < 0.5:
166 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean
167 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean
168 | else:
169 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
170 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
171 |
172 | elif mode == 'truncated_gaussian':
173 | theta = truncated_normal_(torch.zeros((n, 1), device=device)) * horizontal_stddev + horizontal_mean
174 | phi = truncated_normal_(torch.zeros((n, 1), device=device)) * vertical_stddev + vertical_mean
175 |
176 | elif mode == 'spherical_uniform':
177 | theta = (torch.rand((n, 1), device=device) - .5) * 2 * horizontal_stddev + horizontal_mean
178 | v_stddev, v_mean = vertical_stddev / math.pi, vertical_mean / math.pi
179 | v = ((torch.rand((n,1), device=device) - .5) * 2 * v_stddev + v_mean)
180 | v = torch.clamp(v, 1e-5, 1 - 1e-5)
181 | phi = torch.arccos(1 - 2 * v)
182 |
183 | else:
184 | # Just use the mean.
185 | theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean
186 | phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean
187 |
188 | phi = torch.clamp(phi, 1e-5, math.pi - 1e-5)
189 |
190 | output_points = torch.zeros((n, 3), device=device)
191 | output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta)
192 | output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta)
193 | output_points[:, 1:2] = r*torch.cos(phi)
194 |
195 | return output_points, phi, theta
196 |
197 | def create_cam2world_matrix(forward_vector, origin, device=None):
198 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix."""
199 |
200 | forward_vector = normalize_vecs(forward_vector)
201 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector)
202 |
203 | left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
204 |
205 | up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1))
206 |
207 | rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
208 | rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1)
209 |
210 | translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
211 | translation_matrix[:, :3, 3] = origin
212 |
213 | cam2world = translation_matrix @ rotation_matrix
214 |
215 | return cam2world
216 |
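A hedged sketch of building a cam2world matrix for a camera placed 2.7 units along +z and looking back at the world origin (values are illustrative only):

```python
import torch
from lib.utils.volumetric_rendering import normalize_vecs, create_cam2world_matrix

origin = torch.tensor([[0., 0., 2.7]])
forward = normalize_vecs(-origin)                      # unit vector toward the origin
cam2world = create_cam2world_matrix(forward, origin, device='cpu')
print(cam2world.shape)                                 # (1, 4, 4)
```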
217 |
218 | def create_world2cam_matrix(forward_vector, origin, device=None):
219 | """Takes in the direction the camera is pointing and the camera origin and returns a world2cam matrix."""
220 | cam2world = create_cam2world_matrix(forward_vector, origin, device=device)
221 | world2cam = torch.inverse(cam2world)
222 | return world2cam
223 |
224 |
225 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
226 | """
227 | Sample @N_importance samples from @bins with distribution defined by @weights.
228 | Inputs:
229 | bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
230 | weights: (N_rays, N_samples_)
231 | N_importance: the number of samples to draw from the distribution
232 | det: deterministic or not
233 | eps: a small number to prevent division by zero
234 | Outputs:
235 | samples: the sampled samples
236 | Source: https://github.com/kwea123/nerf_pl/blob/master/models/rendering.py
237 | """
238 | N_rays, N_samples_ = weights.shape
239 | weights = weights + eps # prevent division by zero (don't do inplace op!)
240 | pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
241 | cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
242 | cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
243 | # padded to 0~1 inclusive
244 | if det:
245 | u = torch.linspace(0, 1, N_importance, device=bins.device)
246 | u = u.expand(N_rays, N_importance)
247 | else:
248 | u = torch.rand(N_rays, N_importance, device=bins.device)
249 | u = u.contiguous()
250 |
251 | # inds = torch.searchsorted(cdf, u)
252 | inds = torch.searchsorted(cdf, u, right=True)
253 | below = torch.clamp_min(inds-1, 0)
254 | above = torch.clamp_max(inds, N_samples_)
255 |
256 | inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
257 | cdf_g = torch.gather(cdf, 1, inds_sampled)
258 | cdf_g = cdf_g.view(N_rays, N_importance, 2)
259 | bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
260 |
261 | denom = cdf_g[...,1]-cdf_g[...,0]
262 | denom[denom<eps] = 1 # denom==0 means the bin has zero weight and will not be sampled anyway, so any value is fine
263 | samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
264 | return samples
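
A hedged sketch of hierarchical re-sampling with sample_pdf; shapes follow the docstring and the values are illustrative:

```python
import torch
from lib.utils.volumetric_rendering import sample_pdf

bins = torch.linspace(0., 1., 33).expand(1024, 33).contiguous()  # (N_rays, N_coarse+1) bin edges
weights = torch.rand(1024, 32)                                   # (N_rays, N_coarse) coarse weights
fine_z = sample_pdf(bins, weights, N_importance=64, det=False)
print(fine_z.shape)                                              # (1024, 64)
```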
--------------------------------------------------------------------------------
/process_data/README.md:
--------------------------------------------------------------------------------
1 | # SCARF: Capturing and Animation of Body and Clothing from Monocular Video
2 |
22 | <!-- image placeholder: cropped image, subject segmentation, clothing segmentation, SMPL-X estimation -->
23 |
24 |
25 | This is the script to process video data for SCARF training.
26 |
27 | ## Getting Started
28 | ### Environment
29 | SCARF needs input images, subject masks, clothing masks, and initial SMPL-X estimates for training.
30 | Specifically, we use
31 | * [FasterRCNN](https://pytorch.org/vision/main/models/faster_rcnn.html) to detect the subject and crop the images
32 | * [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting) to remove background
33 | * [cloth-segmentation](https://github.com/levindabhi/cloth-segmentation) to segment clothing
34 | * [PIXIE](https://github.com/yfeng95/PIXIE) to estimate SMPL-X parameters
35 |
36 | When using the processing script, it is necessary to agree to the terms of their licenses and properly cite them in your work.
37 |
38 | 1. Clone submodule repositories:
39 | ```
40 | git submodule update --init --recursive
41 | ```
42 | 2. Download their needed data:
43 | ```bash
44 | bash fetch_asset_data.sh
45 | ```
46 | If the script fails, please check their websites and download the models manually.
47 |
48 | ### Process video data
49 | Put your data list into ./lists/subject_list.txt; entries can be video paths or image folders.
50 | Then run
51 | ```bash
52 | python process_video.py --crop --ignore_existing
53 | ```
54 | Processing time depends on the number of frames and the video resolution; for the mpiis-scarf video (400 frames at 1028x1920), it takes around 12 minutes.
55 |
56 |
57 | ## Video Data
58 | The script has been verified to work on the following datasets:
59 | ```
60 | a. mpiis-scarf (recorded video for this paper)
61 | b. People Snapshot Dataset (https://graphics.tu-bs.de/people-snapshot)
62 | c. SelfRecon dataset (https://jby1993.github.io/SelfRecon/)
63 | d. iPER dataset (https://svip-lab.github.io/dataset/iPER_dataset.html)
64 | ```
65 | To get optimal results on your own video, it is recommended to capture it with settings similar to the datasets mentioned above.
66 |
67 | This means keeping the camera static,
68 | recording the subject from many viewpoints, and using uniform lighting. It is also better to keep fewer than 1000 frames for training.
69 | For more information, please refer to the limitations section of SCARF.
--------------------------------------------------------------------------------
/process_data/fetch_asset_data.sh:
--------------------------------------------------------------------------------
1 | # trained model for RobustVideoMatting
2 | mkdir -p ./assets
3 | mkdir -p ./assets/RobustVideoMatting
4 | echo -e "Downloading RobustVideoMatting model..."
5 | wget https://github.com/PeterL1n/RobustVideoMatting/releases/download/v1.0.0/rvm_resnet50.pth -P ./assets/RobustVideoMatting
6 |
7 | # trained model for cloth-segmentation
8 | # if failed, please download the model from
9 | # https://drive.google.com/file/d/1mhF3yqd7R-Uje092eypktNl-RoZNuiCJ/view -- not valid anymore
10 | # use new link: https://drive.google.com/file/d/18_ekYXKLgd0H8XOE7jMJqQrlsoyhQU67/view?usp=drive_link
11 | mkdir -p ./assets/cloth-segmentation
12 | echo -e "Downloading cloth-segmentation model..."
13 | FILEID=18_ekYXKLgd0H8XOE7jMJqQrlsoyhQU67
14 | FILENAME=./assets/cloth-segmentation/cloth_segm_u2net_latest.pth
15 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='${FILEID} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt
16 |
17 | # trained model for PIXIE
18 | # if failed, please check https://github.com/yfeng95/PIXIE/blob/master/fetch_model.sh
19 | cd submodules/PIXIE
20 | echo -e "Downloading PIXIE data..."
21 | #!/bin/bash
22 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
23 |
24 | # SMPL-X 2020 (neutral SMPL-X model with the FLAME 2020 expression blendshapes)
25 | echo -e "\nYou need to register at https://smpl-x.is.tue.mpg.de"
26 | read -p "Username (SMPL-X):" username
27 | read -p "Password (SMPL-X):" password
28 | username=$(urle $username)
29 | password=$(urle $password)
30 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=SMPLX_NEUTRAL_2020.npz&resume=1' -O './data/SMPLX_NEUTRAL_2020.npz' --no-check-certificate --continue
31 |
32 | # PIXIE pretrained model and utilities
33 | echo -e "\nYou need to register at https://pixie.is.tue.mpg.de/"
34 | read -p "Username (PIXIE):" username
35 | read -p "Password (PIXIE):" password
36 | username=$(urle $username)
37 | password=$(urle $password)
38 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=pixie&sfile=pixie_model.tar&resume=1' -O './data/pixie_model.tar' --no-check-certificate --continue
39 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=pixie&sfile=utilities.zip&resume=1' -O './data/utilities.zip' --no-check-certificate --continue
40 |
41 | cd ./data
42 | unzip utilities.zip
--------------------------------------------------------------------------------
/process_data/lists/video_list.txt:
--------------------------------------------------------------------------------
1 | ../exps/mpiis/DSC_7157/DSC_7157.mp4
--------------------------------------------------------------------------------
/process_data/logs/generate_data.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yfeng95/SCARF/d8d89d27431120569380f16094d419df063c2fe4/process_data/logs/generate_data.log
--------------------------------------------------------------------------------
/process_data/process_video.py:
--------------------------------------------------------------------------------
1 | from ipaddress import ip_address
2 | import os, sys
3 | import argparse
4 | from tqdm import tqdm
5 | from pathlib import Path
6 | from loguru import logger
7 | from glob import glob
8 | from skimage.io import imread, imsave
9 | from skimage.transform import estimate_transform, warp, resize, rescale
10 | import shutil
11 | import torch
12 | import cv2
13 | import numpy as np
14 | from PIL import Image
15 |
16 | def get_palette(num_cls):
17 | """Returns the color map for visualizing the segmentation mask.
18 | Args:
19 | num_cls: Number of classes
20 | Returns:
21 | The color map
22 | """
23 | n = num_cls
24 | palette = [0] * (n * 3)
25 | for j in range(0, n):
26 | lab = j
27 | palette[j * 3 + 0] = 0
28 | palette[j * 3 + 1] = 0
29 | palette[j * 3 + 2] = 0
30 | i = 0
31 | while lab:
32 | palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
33 | palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
34 | palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
35 | i += 1
36 | lab >>= 3
37 | return palette
38 |
39 | def generate_frame(inputpath, savepath, subject_name=None, n_frames=2000, fps=30):
40 | ''' extract frames from video or copy frames from image folder
41 | '''
42 | os.makedirs(savepath, exist_ok=True)
43 | if subject_name is None:
44 | subject_name = Path(inputpath).stem
45 | ## video data
46 | if os.path.isfile(inputpath) and (os.path.splitext(inputpath)[-1] in ['.mp4', '.csv', '.MOV']):
47 | videopath = os.path.join(os.path.dirname(savepath), f'{subject_name}.mp4')
48 | logger.info(f'extract frames from video: {inputpath}..., then save to {videopath}')
49 | vidcap = cv2.VideoCapture(inputpath)
50 | count = 0
51 | success, image = vidcap.read()
52 | cv2.imwrite(os.path.join(savepath, f'{subject_name}_f{count:06d}.png'), image)
53 | h, w = image.shape[:2]
54 | print(h, w)
55 | # import imageio
56 | # savecap = imageio.get_writer(videopath, fps=fps)
57 | # savecap.append_data(image[:,:,::-1])
58 | while success:
59 | count += 1
60 | success,image = vidcap.read()
61 | if count > n_frames or image is None:
62 | break
63 | imagepath = os.path.join(savepath, f'{subject_name}_f{count:06d}.png')
64 | cv2.imwrite(imagepath, image) # save frame as png
65 | # savecap.append_data(image[:,:,::-1])
66 | logger.info(f'extracted {count} frames')
67 | elif os.path.isdir(inputpath):
68 | logger.info(f'copy frames from folder: {inputpath}...')
69 | imagepath_list = glob(inputpath + '/*.jpg') + glob(inputpath + '/*.png') + glob(inputpath + '/*.jpeg')
70 | imagepath_list = sorted(imagepath_list)
71 | for count, imagepath in enumerate(imagepath_list):
72 | shutil.copyfile(imagepath, os.path.join(savepath, f'{subject_name}_f{count:06d}.png'))
73 | print('frames are stored in {}'.format(savepath))
74 | else:
75 | logger.info(f'please check the input path: {inputpath}')
76 | logger.info(f'video frames are stored in {savepath}')
77 |
78 | def generate_image(inputpath, savepath, subject_name=None, crop=False, crop_each=False, image_size=512, scale_bbox=1.1, device='cuda:0'):
79 | ''' generate image from given frame path.
80 | '''
81 | logger.info(f'generate images, crop {crop}, image size {image_size}')
82 | os.makedirs(savepath, exist_ok=True)
83 | # load detection model
84 | from submodules.detector import FasterRCNN
85 | detector = FasterRCNN(device=device)
86 | if os.path.isdir(inputpath):
87 | imagepath_list = glob(inputpath + '/*.jpg') + glob(inputpath + '/*.png') + glob(inputpath + '/*.jpeg')
88 | imagepath_list = sorted(imagepath_list)
89 | # if crop, detect the bbox of the first image and use the bbox for all frames
90 | if crop:
91 | imagepath = imagepath_list[0]
92 | logger.info(f'detect first image {imagepath}')
93 | imagename = os.path.splitext(os.path.basename(imagepath))[0]
94 | image = imread(imagepath)[:,:,:3]/255.
95 | h, w, _ = image.shape
96 |
97 | image_tensor = torch.tensor(image.transpose(2,0,1), dtype=torch.float32)[None, ...]
98 | bbox = detector.run(image_tensor)
99 | left = bbox[0]; right = bbox[2]; top = bbox[1]; bottom = bbox[3]
100 | np.savetxt(os.path.join(Path(inputpath).parent, 'image_bbox.txt'), bbox)
101 |
102 | ## calculate warping function for image cropping
103 | old_size = max(right - left, bottom - top)
104 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
105 | size = int(old_size*scale_bbox)
106 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
107 | DST_PTS = np.array([[0,0], [0,image_size - 1], [image_size - 1, 0]])
108 | tform = estimate_transform('similarity', src_pts, DST_PTS)
109 |
110 | for count, imagepath in enumerate(tqdm(imagepath_list)):
111 | if crop:
112 | image = imread(imagepath)
113 | dst_image = warp(image, tform.inverse, output_shape=(image_size, image_size))
114 | dst_image = (dst_image*255).astype(np.uint8)
115 | imsave(os.path.join(savepath, f'{subject_name}_f{count:06d}.png'), dst_image)
116 | else:
117 | shutil.copyfile(imagepath, os.path.join(savepath, f'{subject_name}_f{count:06d}.png'))
118 | logger.info(f'images are stored in {savepath}')
119 |
120 | def generate_matting_rvm(inputpath, savepath, ckpt_path='assets/RobustVideoMatting/rvm_resnet50.pth', device='cuda:0'):
121 | sys.path.append('./submodules/RobustVideoMatting')
122 | from model import MattingNetwork
123 | EXTS = ['jpg', 'jpeg', 'png']
124 | segmentor = MattingNetwork(variant='resnet50').eval().to(device)
125 | segmentor.load_state_dict(torch.load(ckpt_path))
126 |
127 | images_folder = inputpath
128 | output_folder = savepath
129 | os.makedirs(output_folder, exist_ok=True)
130 |
131 | frame_IDs = os.listdir(images_folder)
132 | frame_IDs = [id.split('.')[0] for id in frame_IDs if id.split('.')[-1] in EXTS]
133 | frame_IDs.sort()
134 | frame_IDs = frame_IDs[:4][::-1] + frame_IDs  # prepend the first frames (reversed) to warm up the recurrent states
135 |
136 | rec = [None] * 4 # Initial recurrent
137 | downsample_ratio = 1.0 # Adjust based on your video.
138 |
139 | # bgr = torch.tensor([1, 1, 1.]).view(3, 1, 1).cuda()
140 | for i in tqdm(range(len(frame_IDs))):
141 | frame_ID = frame_IDs[i]
142 | img_path = os.path.join(images_folder, '{}.png'.format(frame_ID))
143 | try:
144 | img_masked_path = os.path.join(output_folder, '{}.png'.format(frame_ID))
145 | img = cv2.imread(img_path)
146 | src = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
147 | src = torch.from_numpy(src).float() / 255.
148 | src = src.permute(2, 0, 1).unsqueeze(0)
149 | with torch.no_grad():
150 | fgr, pha, *rec = segmentor(src.to(device), *rec, downsample_ratio) # Cycle the recurrent states.
151 | pha = pha.permute(0, 2, 3, 1).cpu().numpy().squeeze(0)
152 | # append the alpha matte as a 4th channel and save the result as a 4-channel png
153 | mask = (pha * 255).astype(np.uint8)
154 | img_masked = np.concatenate([img, mask], axis=-1)
155 | cv2.imwrite(img_masked_path, img_masked)
156 | except Exception:
157 | os.remove(img_path)
158 | logger.warning(f'matting failed for image {img_path}, deleting it')
159 |
160 | sys.modules.pop('model')  # drop the cached 'model' module so later submodule imports do not clash
161 |
162 | def generate_cloth_segmentation(inputpath, savepath, ckpt_path='assets/cloth-segmentation/cloth_segm_u2net_latest.pth', device='cuda:0', vis=False):
163 | logger.info(f'generate cloth segmentation for {inputpath}')
164 | os.makedirs(savepath, exist_ok=True)
165 | # load model
166 | sys.path.insert(0, './submodules/cloth-segmentation')
167 | import torch.nn.functional as F
168 | import torchvision.transforms as transforms
169 | from data.base_dataset import Normalize_image
170 | from utils.saving_utils import load_checkpoint_mgpu
171 | from networks import U2NET
172 |
173 | transforms_list = []
174 | transforms_list += [transforms.ToTensor()]
175 | transforms_list += [Normalize_image(0.5, 0.5)]
176 | transform_rgb = transforms.Compose(transforms_list)
177 |
178 | net = U2NET(in_ch=3, out_ch=4)
179 | net = load_checkpoint_mgpu(net, ckpt_path)
180 | net = net.to(device)
181 | net = net.eval()
182 |
183 | palette = get_palette(4)
184 |
185 | images_list = sorted(os.listdir(inputpath))
186 | pbar = tqdm(total=len(images_list))
187 | for image_name in images_list:  # progress is tracked by pbar above
188 | img = Image.open(os.path.join(inputpath, image_name)).convert("RGB")
189 | image_tensor = transform_rgb(img)
190 | image_tensor = torch.unsqueeze(image_tensor, 0)
191 |
192 | output_tensor = net(image_tensor.to(device))
193 | output_tensor = F.log_softmax(output_tensor[0], dim=1)
194 | output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
195 | output_tensor = torch.squeeze(output_tensor, dim=0)
196 | output_tensor = torch.squeeze(output_tensor, dim=0)
197 | output_arr = output_tensor.cpu().numpy()
198 |
199 | output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
200 | output_img.putpalette(palette)
201 | name = Path(image_name).stem
202 | output_img.save(os.path.join(savepath, f'{name}.png'))
203 | pbar.update(1)
204 | pbar.close()
205 |
206 | def generate_pixie(inputpath, savepath, ckpt_path='assets/face_normals/model.pth', device='cuda:0', image_size=512, vis=False):
207 | logger.info(f'generate pixie results')
208 | os.makedirs(savepath, exist_ok=True)
209 | # load model
210 | sys.path.insert(0, './submodules/PIXIE')
211 | from pixielib.pixie import PIXIE
212 | from pixielib.visualizer import Visualizer
213 | from pixielib.datasets.body_datasets import TestData
214 | from pixielib.utils import util
215 | from pixielib.utils.config import cfg as pixie_cfg
216 | from pixielib.utils.tensor_cropper import transform_points
217 | # load the PIXIE model and visualizer
218 | pixie_cfg.model.use_tex = False
219 | pixie = PIXIE(config=pixie_cfg, device=device)
220 | visualizer = Visualizer(render_size=image_size, config=pixie_cfg, device=device, rasterizer_type='standard')
221 | # run PIXIE on each cropped frame
222 | testdata = TestData(inputpath, iscrop=False)
223 | for i, batch in enumerate(tqdm(testdata, dynamic_ncols=True)):
224 | batch['image'] = batch['image'].unsqueeze(0)
225 | batch['image_hd'] = batch['image_hd'].unsqueeze(0)
226 | name = batch['name']
227 | util.move_dict_to_device(batch, device)
228 | data = {
229 | 'body': batch
230 | }
231 | param_dict = pixie.encode(data, threthold=True, keep_local=True, copy_and_paste=False)
232 | codedict = param_dict['body']
233 | opdict = pixie.decode(codedict, param_type='body')
234 | util.save_pkl(os.path.join(savepath, f'{name}_param.pkl'), codedict)
235 | if vis:
236 | opdict['albedo'] = visualizer.tex_flame2smplx(opdict['albedo'])
237 | visdict = visualizer.render_results(opdict, data['body']['image_hd'], overlay=True, use_deca=False)
238 | cv2.imwrite(os.path.join(savepath, f'{name}_vis.jpg'), visualizer.visualize_grid(visdict, size=image_size))
239 |
240 | def process_video(subjectpath, savepath=None, vis=False, crop=False, crop_each=False, ignore_existing=False, n_frames=2000):
241 | if savepath is None:
242 | savepath = Path(subjectpath).parent
243 | subject_name = Path(subjectpath).stem
244 | savepath = os.path.join(savepath, subject_name)
245 | os.makedirs(savepath, exist_ok=True)
246 | logger.info(f'processing {subject_name}')
247 | # 0. copy frames from video or image folder
248 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'frame')):
249 | generate_frame(subjectpath, os.path.join(savepath, 'frame'), n_frames=n_frames)
250 |
251 | # 1. crop images from the frames, using Faster R-CNN for person detection
252 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'image')):
253 | generate_image(os.path.join(savepath, 'frame'), os.path.join(savepath, 'image'), subject_name=subject_name,
254 | crop=crop, crop_each=crop_each, image_size=512, scale_bbox=1.1, device='cuda:0')
255 |
256 | # 2. video matting
257 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'matting')):
258 | generate_matting_rvm(os.path.join(savepath, 'image'), os.path.join(savepath, 'matting'))
259 |
260 | # 3. cloth segmentation
261 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'cloth_segmentation')):
262 | generate_cloth_segmentation(os.path.join(savepath, 'image'), os.path.join(savepath, 'cloth_segmentation'), vis=vis)
263 |
264 | # 4. smplx estimation using PIXIE (https://github.com/yfeng95/PIXIE)
265 | if ignore_existing or not os.path.exists(os.path.join(savepath, 'pixie')):
266 | generate_pixie(os.path.join(savepath, 'image'), os.path.join(savepath, 'pixie'), vis=vis)
267 | logger.info(f'finish {subject_name}')
268 |
269 | def main(args):
270 | logger.add(args.logpath)
271 |
272 | with open(args.list, 'r') as f:
273 | lines = f.readlines()
274 | subject_list = [s.strip() for s in lines]
275 | if args.subject_idx is not None:
276 | if args.subject_idx >= len(subject_list):
277 | logger.error(f'subject_idx {args.subject_idx} is out of range, the list has {len(subject_list)} subjects')
278 | else:
279 | subject_list = [subject_list[args.subject_idx]]
280 |
281 | for subjectpath in tqdm(subject_list):
282 | process_video(subjectpath, savepath=args.savepath, vis=args.vis, crop=args.crop, crop_each=args.crop_each, ignore_existing=args.ignore_existing,
283 | n_frames=args.n_frames)
284 |
285 | if __name__ == "__main__":
286 | parser = argparse.ArgumentParser(description='generate dataset from video or image folder')
287 | parser.add_argument('--list', default='lists/video_list.txt', type=str,
288 | help='path to a txt file listing the subject data, one image folder or video per line')
289 | parser.add_argument('--logpath', default='logs/generate_data.log', type=str,
290 | help='path to save log')
291 | parser.add_argument('--savepath', default=None, type=str,
292 | help='path to save processed data, if not specified, then save to the same folder as the subject data')
293 | parser.add_argument('--subject_idx', default=None, type=int,
294 | help='specify subject idx, if None (default), then use all the subject data in the list')
295 | parser.add_argument("--image_size", default=512, type=int,
296 | help = 'image size')
297 | parser.add_argument("--crop", default=True, action="store_true",
298 | help='whether to crop images according to the subject bbox detected on the first frame')
299 | parser.add_argument("--crop_each", default=False, action="store_true",
300 | help='TODO: whether to crop using a separate detection bbox for each frame of the video')
301 | parser.add_argument("--vis", default=True, action="store_true",
302 | help='whether to visualize intermediate results (e.g. PIXIE renderings)')
303 | parser.add_argument("--ignore_existing", default=False, action="store_true",
304 | help='ignore existing data')
305 | parser.add_argument("--filter_data", default=False, action="store_true",
306 | help='check the labels and delete images whose labels are bad')
307 | parser.add_argument("--n_frames", default=400, type=int,
308 | help='number of frames to be processed')
309 | args = parser.parse_args()
310 |
311 | main(args)
312 |
313 |
--------------------------------------------------------------------------------
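For reference, the pipeline in process_video.py can also be called directly from Python instead of through the argument parser. Below is a minimal sketch, assuming it is run from the process_data/ directory (so the relative assets/ and submodules/ paths resolve) and using a hypothetical input video videos/my_video.mp4:

from process_video import process_video

# process one video with roughly the CLI defaults shown above;
# the input path is a hypothetical example
process_video(
    'videos/my_video.mp4',   # a video file or a folder of frames
    savepath=None,           # None: save next to the input, i.e. videos/my_video/
    vis=True,                # also save PIXIE visualizations
    crop=True,               # crop frames around the detected person bbox
    n_frames=400,            # number of frames to extract (CLI default)
)
# outputs land in videos/my_video/{frame,image,matting,cloth_segmentation,pixie}
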
/process_data/submodules/detector.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | import cv2
5 | import scipy
6 | from skimage.io import imread, imsave
7 | from skimage.transform import estimate_transform, warp, resize, rescale
8 | from glob import glob
9 | import os
10 | from tqdm import tqdm
11 | import shutil
12 |
13 | '''
14 | For cropping the body, either:
15 | 1. use the bbox from an object detector, or
16 | 2. compute the bbox from a body joint regressor.
17 | object detectors:
18 | the person class is identified by its COCO label number
19 | https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
20 | label for person: 1
21 | COCO_INSTANCE_CATEGORY_NAMES = [
22 | '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
23 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
24 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
25 | 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
26 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
27 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
28 | 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
29 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
30 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
31 | 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
32 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
33 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
34 | ]
35 | '''
36 | import os
37 | import os.path as osp
38 | import sys
39 | import numpy as np
40 | import cv2
41 |
42 | import torch
43 | import torchvision.transforms as transforms
44 | # from PIL import Image
45 |
46 | class FasterRCNN(object):
47 | ''' detect body
48 | '''
49 | def __init__(self, device='cuda:0'):
50 | '''
51 | https://pytorch.org/docs/stable/torchvision/models.html#faster-r-cnn
52 | '''
53 | import torchvision
54 | self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
55 | self.model.to(device)
56 | self.model.eval()
57 | self.device = device
58 |
59 | @torch.no_grad()
60 | def run(self, input):
61 | '''
62 | input:
63 | The input to the model is expected to be a list of tensors,
64 | each of shape [C, H, W], one for each image, and should be in 0-1 range.
65 | Different images can have different sizes.
66 | return:
67 | detected box, [x1, y1, x2, y2]
68 | '''
69 | prediction = self.model(input.to(self.device))[0]
70 | inds = (prediction['labels']==1)*(prediction['scores']>0.5)
71 | if len(inds) < 1 or inds.sum() < 1:  # no person detected above the score threshold
72 | return None
73 | else:
74 | # if inds.sum()<0.5:
75 | # inds = (prediction['labels']==1)*(prediction['scores']>0.05)
76 | bbox = prediction['boxes'][inds][0].cpu().numpy()  # boxes are sorted by score, take the top-scoring person
77 | return bbox
78 |
79 | @torch.no_grad()
80 | def run_multi(self, input):
81 | '''
82 | input:
83 | The input to the model is expected to be a list of tensors,
84 | each of shape [C, H, W], one for each image, and should be in 0-1 range.
85 | Different images can have different sizes.
86 | return:
87 | detected box, [x1, y1, x2, y2]
88 | '''
89 | prediction = self.model(input.to(self.device))[0]
90 | inds = (prediction['labels']==1)*(prediction['scores']>0.9)
91 | if len(inds) < 1 or inds.sum() < 1:  # no person detected above the score threshold
92 | return None
93 | else:
94 | bbox = prediction['boxes'][inds].cpu().numpy()
95 | return bbox
--------------------------------------------------------------------------------
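As a usage reference, the sketch below mirrors how generate_image() in process_video.py uses this detector to crop a square, resized image around the detected person. The file name frame.png is a hypothetical example; image_size=512 and scale_bbox=1.1 match the defaults used there.

import numpy as np
import torch
from skimage.io import imread
from skimage.transform import estimate_transform, warp
from detector import FasterRCNN

image_size, scale_bbox = 512, 1.1
detector = FasterRCNN(device='cuda:0')

# hypothetical input frame, converted to float RGB in [0, 1]
image = imread('frame.png')[:, :, :3] / 255.
image_tensor = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
bbox = detector.run(image_tensor)            # [x1, y1, x2, y2], or None if no person is found
left, top, right, bottom = bbox

# square crop window around the person, enlarged by scale_bbox
size = int(max(right - left, bottom - top) * scale_bbox)
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
                    [center[0] - size / 2, center[1] + size / 2],
                    [center[0] + size / 2, center[1] - size / 2]])
dst_pts = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
tform = estimate_transform('similarity', src_pts, dst_pts)
cropped = warp(image, tform.inverse, output_shape=(image_size, image_size))  # float crop in [0, 1]
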
/requirements.txt:
--------------------------------------------------------------------------------
1 | # to install all these requirements:
2 | # `pip install -r requirements.txt`
3 | torch
4 | torchvision
5 |
6 | scikit-image
7 | opencv-python
8 | chumpy
9 | matplotlib
10 | loguru
11 | tqdm
12 | yacs
13 | wandb
14 | PyMCubes
15 | trimesh
16 | # for pixie
17 | kornia==0.4.0
18 | PyYAML==5.1.1
19 | # -- evaluation
20 | lpips
21 | torchmetrics
22 | ipdb
23 | # -- pytorch3d; if installation fails, see its install page for more information
24 | # git+https://github.com/facebookresearch/pytorch3d.git
25 | # face-alignment
26 | # imageio
27 | # others
28 | # kornia
29 | # pytorch3d
30 | # trimesh
31 | # einops
32 | # pytorch_lightning
33 | # imageio-ffmpeg
34 | # PyMCubes
35 | # open3d
36 | # plyfile
37 | #
38 | # nerfacc
39 | # packaging
40 | # "git+https://github.com/facebookresearch/pytorch3d.git"
41 | # git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | # train nerf
2 | python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_0_nerf.yml
3 |
4 | # train hybrid
5 | python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_1_hybrid_perceptual.yml --clean
6 | ## Note: this config uses a perceptual loss instead of the MRF loss described in the paper.
7 | ## Reason: the MRF loss trains fine on V100 GPUs, but produces NaN losses on A100s.
8 | ## If you want to use the MRF loss, make sure you are on a V100, then run:
9 | ## python main_train.py --data_cfg configs/data/mpiis/DSC_7157.yml --exp_cfg configs/exp/stage_1_hybrid.yml --clean
10 |
11 | #--- training male-3-casual example
12 | # python main_train.py --data_cfg configs/data/snapshot/male-3-casual.yml --exp_cfg configs/exp/stage_0_nerf.yml
13 |
--------------------------------------------------------------------------------