├── LICENSE ├── README.md ├── assets ├── append1_00.png ├── append2_00.png ├── closed_mano_faces.pkl └── pipeline_00.png ├── data ├── __pycache__ │ └── grab_test.cpython-38.pyc └── grab_test.py ├── eval.py ├── grab_data ├── grab_demo.npy ├── obj_meshes │ ├── airplane.ply │ ├── alarmclock.ply │ ├── apple.ply │ ├── banana.ply │ ├── binoculars.ply │ ├── body.ply │ ├── bowl.ply │ ├── camera.ply │ ├── coffeemug.ply │ ├── cubelarge.ply │ ├── cubemedium.ply │ ├── cubemiddle.ply │ ├── cubesmall.ply │ ├── cup.ply │ ├── cylinderlarge.ply │ ├── cylindermedium.ply │ ├── cylindersmall.ply │ ├── doorknob.ply │ ├── duck.ply │ ├── elephant.ply │ ├── eyeglasses.ply │ ├── flashlight.ply │ ├── flute.ply │ ├── fryingpan.ply │ ├── gamecontroller.ply │ ├── hammer.ply │ ├── hand.ply │ ├── headphones.ply │ ├── knife.ply │ ├── lightbulb.ply │ ├── mouse.ply │ ├── mug.ply │ ├── phone.ply │ ├── piggybank.ply │ ├── pyramidlarge.ply │ ├── pyramidmedium.ply │ ├── pyramidsmall.ply │ ├── rubberduck.ply │ ├── scissors.ply │ ├── spherelarge.ply │ ├── spheremedium.ply │ ├── spheresmall.ply │ ├── stamp.ply │ ├── stanfordbunny.ply │ ├── stapler.ply │ ├── table.ply │ ├── teapot.ply │ ├── toothbrush.ply │ ├── toothpaste.ply │ ├── toruslarge.ply │ ├── torusmedium.ply │ ├── torussmall.ply │ ├── train.ply │ ├── watch.ply │ ├── waterbottle.ply │ ├── wineglass.ply │ └── wristwatch.ply ├── test │ ├── frame_names.npz │ └── grabnet_test.npz ├── train │ ├── frame_names.npz │ └── grabnet_train.npz └── val │ ├── frame_names.npz │ └── grabnet_val.npz ├── mano ├── __pycache__ │ └── mano_models.cpython-38.pyc └── mano_models.py ├── models ├── __pycache__ │ ├── cond_diffusion_model.cpython-38.pyc │ ├── contact_conditional_module.cpython-38.pyc │ ├── contact_conditional_moudle.cpython-38.pyc │ ├── pointnet.cpython-38.pyc │ ├── semantic_conditional_module.cpython-38.pyc │ ├── semantic_conditional_moudle.cpython-38.pyc │ └── transformer_module.cpython-38.pyc ├── cond_diffusion_model.py ├── contact_conditional_module.py ├── mlp.py ├── pointnet.py ├── resnet.py ├── semantic_conditional_module.py └── transformer_module.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClickDiff 2 | [![Paper](https://img.shields.io/badge/cs.CV-Paper-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2407.19370) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/clickdiff-click-to-induce-semantic-contact/controllable-grasp-generation-on-grab)](https://paperswithcode.com/sota/controllable-grasp-generation-on-grab?p=clickdiff-click-to-induce-semantic-contact) 3 | 4 | Official code for "ClickDiff: Click to Induce Semantic Contact Map for Controllable Grasp Generation with Diffusion Models", ACM MM 2024 (Oral). 5 | 6 | ![pipeline_00](assets/pipeline_00.png) 7 | 8 | ## Download datasets 9 | 10 | 1. **GRAB** dataset from [https://grab.is.tue.mpg.de/](https://grab.is.tue.mpg.de/) 11 | 2. **ARCTIC** dataset from [https://arctic.is.tue.mpg.de/](https://arctic.is.tue.mpg.de/) 12 | 13 | 14 | ## RUN 15 | 16 | Click [here](https://drive.google.com/drive/folders/1bnJjyJbSrf1978lCh80Zo8gaHdu8K_wp?usp=sharing) to download our weights and place them in the `checkpoint` directory. Then generate results on the test set; the outputs are saved to `exp/demo`: 17 | 18 | 19 | ```bash 20 | python eval.py 21 | ``` 22 | 23 | ![append1_00](assets/append1_00.png) 24 | 25 | 26 | ![append2_00](assets/append2_00.png) 27 | -------------------------------------------------------------------------------- /assets/append1_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/assets/append1_00.png -------------------------------------------------------------------------------- /assets/append2_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/assets/append2_00.png -------------------------------------------------------------------------------- /assets/closed_mano_faces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/assets/closed_mano_faces.pkl -------------------------------------------------------------------------------- /assets/pipeline_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/assets/pipeline_00.png -------------------------------------------------------------------------------- /data/__pycache__/grab_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/data/__pycache__/grab_test.cpython-38.pyc -------------------------------------------------------------------------------- /data/grab_test.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from loguru import logger 7 | from torch.utils.data import Dataset 8 | from ipdb import set_trace as st 9 | 10 | class GrabDataset(Dataset): 11 | def __getitem__(self, index): 12 | idx = self.idxs[index] 13 | data = self.getitem(idx) 14 | return data 15 |
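# getitem() below packs one flattened GRAB sample into a dict:
#   'seq_len' - length of the sample's feature vector
#   'motion'  - that vector with a leading singleton dimension (unsqueeze(0))
#   'name'    - the corresponding frame name
#   'targets' - entries of self.targets (except 'v', 'mask', 'parts_ids'); entries whose
#               length matches targets['v_sub'] are indexed per sample, the rest are copied as shared data.
# eval.py wraps this dataset in a DataLoader and feeds each batch's 'motion' tensor to the two diffusion modules.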
16 | def getitem(self, idx, load_rgb=True): 17 | 18 | data_input_all = self.data[idx, :] 19 | data_input_dict = {} 20 | data_input_dict['seq_len'] = data_input_all.shape[0] 21 | data_input_dict['motion'] = data_input_all.unsqueeze(0) 22 | data_input_dict['name'] = self.frame_names[idx] 23 | data_input_dict['targets'] = {} 24 | 25 | for key in self.targets: 26 | if key != 'v' and key != 'mask' and key != 'parts_ids': 27 | if len(self.targets[key]) == self.targets['v_sub'].shape[0]: 28 | data_input_dict['targets'][key] = self.targets[key][idx] 29 | else: 30 | data_input_dict['targets'][key] = self.targets[key] 31 | 32 | return data_input_dict 33 | 34 | 35 | def _load_data(self): 36 | data_p = op.join( 37 | "grab_data/grab_demo.npy" 38 | ) 39 | 40 | logger.info(f"Loading {data_p}") 41 | data = np.load(data_p, allow_pickle=True).item() 42 | # st() 43 | self.data = data["data_dict"] 44 | self.idxs = data["imgnames"] 45 | self.frame_names = data["frame_names"] 46 | self.targets = data["targets"] 47 | 48 | 49 | def __init__(self, args='', split='', seq=None): 50 | self._load_data() 51 | logger.info( 52 | f"Dataset Loaded, num samples {len(self.idxs)}" 53 | ) 54 | 55 | def __len__(self): 56 | return len(self.idxs) 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | ds = GrabDataset(args='', split='') 62 | ds.__getitem__(0) 63 | ds.__getitem__(1) -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | 5 | import argparse 6 | import os 7 | from pathlib import Path 8 | import numpy as np 9 | import os.path as op 10 | import pickle 11 | 12 | import torch 13 | from mano.mano_models import build_mano_aa 14 | 15 | from torch.utils import data 16 | from ipdb import set_trace as st 17 | import trimesh 18 | from models.semantic_conditional_module import get_moudle as moudle_1 19 | from models.contact_conditional_module import get_moudle as moudle_2 20 | from data.grab_test import GrabDataset 21 | 22 | 23 | def test(opt, device): 24 | diffusion_moudle_1 = moudle_1(opt) 25 | diffusion_moudle_2 = moudle_2(opt) 26 | idxs = 0 27 | 28 | diffusion_weight_path_1 = "checkpoint/semantic_conditional_module.pt" 29 | diffusion_weight_path_2 = "checkpoint/contact_conditional_module.pt" 30 | diffusion_moudle_1.load_weight_path(diffusion_weight_path_1) 31 | diffusion_moudle_2.load_weight_path(diffusion_weight_path_2) 32 | 33 | val_dataset = GrabDataset() 34 | val_dl = data.DataLoader(val_dataset, batch_size=3000, shuffle=False, pin_memory=True, num_workers=0) 35 | 36 | new_item = {} 37 | with torch.no_grad(): 38 | for i, item in enumerate(val_dl): 39 | test_input_data_dict = item # item["motion"].shape torch.Size([32, 2531]) 40 | one_motion = test_input_data_dict['motion'] 41 | 42 | res_list_1 = diffusion_moudle_1.full_body_gen_cond_head_pose_sliding_window(one_motion.to(device)) 43 | one_motion[:,:,6205:8253] = res_list_1 44 | all_res_list = diffusion_moudle_2.full_body_gen_cond_head_pose_sliding_window(one_motion.to(device)) 45 | 46 | preds_new = my_process_data(preds=all_res_list, namelist=test_input_data_dict['name']) 47 | with open("assets/closed_mano_faces.pkl", 'rb') as f: 48 | hand_face = pickle.load(f) 49 | 50 | hand_verts = preds_new["manov3d.r"] 51 | exp_name = 'demo' 52 | save_dir = f'exp/{exp_name}' 53 | os.makedirs(save_dir, exist_ok=True) # create the output folder before exporting meshes 54 | aa_name = test_input_data_dict['name'] 55 | for i in range(len(hand_verts)): 56 | hand_mesh =
trimesh.Trimesh(vertices=hand_verts[i], faces=hand_face) 57 | parts = aa_name[i].split('/') 58 | relevant_parts = [parts[2]] + parts[3].split('_') + [parts[-1].split('.')[0]] 59 | formatted_string = '_'.join(relevant_parts) 60 | formatted_string = exp_name+'_'+ formatted_string 61 | # st() 62 | hand_mesh.export(os.path.join(save_dir, f'{formatted_string}.obj'.format(i))) 63 | 64 | def my_process_data(preds,namelist): 65 | models = {'mano_r':build_mano_aa(is_rhand=True,flat_hand=True)} 66 | 67 | targets=dict() 68 | for i in range(len(namelist)): 69 | 70 | rot_r = preds[i][0][:3] 71 | pose_r = preds[i][0][3:48] 72 | trans_r = preds[i][0][48:51] 73 | betas_r = preds[i][0][51:61] 74 | 75 | pose_r = np.concatenate((rot_r.to('cpu'), pose_r.cpu()), axis=0) 76 | 77 | if i == 0: 78 | targets["mano.pose.r"] = torch.from_numpy(pose_r).float().unsqueeze(0).to('cpu') 79 | targets["mano.beta.r"] = np.expand_dims(betas_r.cpu().numpy(), axis=0) 80 | targets["mano.trans.r"] = np.expand_dims(trans_r.cpu().numpy(), axis=0) 81 | else: 82 | targets["mano.pose.r"] = torch.cat([targets["mano.pose.r"],torch.from_numpy(pose_r).float().unsqueeze(0).to('cpu')],dim=0) 83 | targets["mano.beta.r"] = np.concatenate([targets["mano.beta.r"],np.expand_dims(betas_r.cpu().numpy(), axis=0)],axis=0) 84 | targets["mano.trans.r"] = np.concatenate([targets["mano.trans.r"],np.expand_dims(trans_r.cpu().numpy(), axis=0)],axis=0) 85 | # st() 86 | 87 | gt_pose_r = targets["mano.pose.r"] 88 | gt_betas_r = targets["mano.beta.r"] 89 | gt_trans_r = targets["mano.trans.r"] 90 | 91 | temp_gt_out_r = models["mano_r"]( 92 | betas=torch.from_numpy(gt_betas_r), 93 | hand_pose=gt_pose_r[:, 3:], 94 | global_orient=gt_pose_r[:, :3], 95 | transl=torch.from_numpy(gt_trans_r), 96 | ) 97 | targets["manoj21.r"] = temp_gt_out_r.joints 98 | targets["manov3d.r"] = temp_gt_out_r.vertices 99 | 100 | return targets 101 | 102 | def parse_opt(): 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument('--workers', type=int, default=0, help='the number of workers for data loading') 105 | parser.add_argument('--device', default='0', help='cuda device') 106 | parser.add_argument('--weight', default='latest') 107 | parser.add_argument("--gen_vis", action="store_true") 108 | # For AvatarPoser config 109 | parser.add_argument('--kinpoly_cfg', type=str, default="", help='Path to option JSON file.') 110 | # Diffusion model settings 111 | parser.add_argument('--diffusion_window', type=int, default=1, help='horizon') 112 | parser.add_argument('--diffusion_batch_size', type=int, default=200, help='batch size') 113 | parser.add_argument('--diffusion_learning_rate', type=float, default=1e-5, help='generator_learning_rate') 114 | 115 | parser.add_argument('--diffusion_n_dec_layers', type=int, default=4, help='the number of decoder layers') 116 | parser.add_argument('--diffusion_n_head', type=int, default=4, help='the number of heads in self-attention') 117 | parser.add_argument('--diffusion_d_k', type=int, default=256, help='the dimension of keys in transformer') 118 | parser.add_argument('--diffusion_d_v', type=int, default=256, help='the dimension of values in transformer') 119 | parser.add_argument('--diffusion_d_model', type=int, default=512, help='the dimension of intermediate representation in transformer') 120 | 121 | parser.add_argument('--diffusion_project', default='runs/test', help='project/name') 122 | parser.add_argument('--diffusion_exp_name', default='', help='save to project/name') 123 | 124 | # For data representation 125 | 
parser.add_argument("--canonicalize_init_head", action="store_true") 126 | parser.add_argument("--use_min_max", action="store_true") 127 | 128 | parser.add_argument('--data_root_folder', default='', help='') 129 | 130 | opt = parser.parse_args() 131 | return opt 132 | 133 | if __name__ == "__main__": 134 | opt = parse_opt() 135 | opt.diffusion_save_dir = str(Path(opt.diffusion_project) / opt.diffusion_exp_name) 136 | device = torch.device(f"cuda:{opt.device}" if torch.cuda.is_available() else "cpu") 137 | test(opt, device) -------------------------------------------------------------------------------- /grab_data/grab_demo.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/grab_demo.npy -------------------------------------------------------------------------------- /grab_data/obj_meshes/airplane.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/airplane.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/alarmclock.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/alarmclock.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/apple.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/apple.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/banana.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/banana.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/binoculars.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/binoculars.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/body.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/body.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/bowl.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/bowl.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/camera.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/camera.ply -------------------------------------------------------------------------------- 
/grab_data/obj_meshes/coffeemug.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/coffeemug.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cubelarge.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cubelarge.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cubemedium.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cubemedium.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cubemiddle.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cubemiddle.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cubesmall.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cubesmall.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cup.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cup.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cylinderlarge.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cylinderlarge.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cylindermedium.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cylindermedium.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/cylindersmall.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/cylindersmall.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/doorknob.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/doorknob.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/duck.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/duck.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/elephant.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/elephant.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/eyeglasses.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/eyeglasses.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/flashlight.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/flashlight.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/flute.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/flute.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/fryingpan.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/fryingpan.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/gamecontroller.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/gamecontroller.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/hammer.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/hammer.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/hand.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/hand.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/headphones.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/headphones.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/knife.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/knife.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/lightbulb.ply: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/lightbulb.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/mouse.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/mouse.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/mug.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/mug.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/phone.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/phone.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/piggybank.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/piggybank.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/pyramidlarge.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/pyramidlarge.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/pyramidmedium.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/pyramidmedium.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/pyramidsmall.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/pyramidsmall.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/rubberduck.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/rubberduck.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/scissors.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/scissors.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/spherelarge.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/spherelarge.ply 
-------------------------------------------------------------------------------- /grab_data/obj_meshes/spheremedium.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/spheremedium.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/spheresmall.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/spheresmall.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/stamp.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/stamp.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/stanfordbunny.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/stanfordbunny.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/stapler.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/stapler.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/table.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/table.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/teapot.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/teapot.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/toothbrush.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/toothbrush.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/toothpaste.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/toothpaste.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/toruslarge.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/toruslarge.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/torusmedium.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/torusmedium.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/torussmall.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/torussmall.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/train.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/train.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/watch.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/watch.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/waterbottle.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/waterbottle.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/wineglass.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/wineglass.ply -------------------------------------------------------------------------------- /grab_data/obj_meshes/wristwatch.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/obj_meshes/wristwatch.ply -------------------------------------------------------------------------------- /grab_data/test/frame_names.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/test/frame_names.npz -------------------------------------------------------------------------------- /grab_data/test/grabnet_test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/test/grabnet_test.npz -------------------------------------------------------------------------------- /grab_data/train/frame_names.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/train/frame_names.npz -------------------------------------------------------------------------------- /grab_data/train/grabnet_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/train/grabnet_train.npz -------------------------------------------------------------------------------- /grab_data/val/frame_names.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/val/frame_names.npz -------------------------------------------------------------------------------- /grab_data/val/grabnet_val.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/grab_data/val/grabnet_val.npz -------------------------------------------------------------------------------- /mano/__pycache__/mano_models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/mano/__pycache__/mano_models.cpython-38.pyc -------------------------------------------------------------------------------- /mano/mano_models.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import numpy as np 4 | import torch 5 | from smplx import MANO 6 | from ipdb import set_trace as st 7 | 8 | 9 | MODEL_DIR = "/data-home/arctic-master/data/body_models/mano" 10 | SEAL_FACES_R = [ 11 | [120, 108, 778], 12 | [108, 79, 778], 13 | [79, 78, 778], 14 | [78, 121, 778], 15 | [121, 214, 778], 16 | [214, 215, 778], 17 | [215, 279, 778], 18 | [279, 239, 778], 19 | [239, 234, 778], 20 | [234, 92, 778], 21 | [92, 38, 778], 22 | [38, 122, 778], 23 | [122, 118, 778], 24 | [118, 117, 778], 25 | [117, 119, 778], 26 | [119, 120, 778], 27 | ] 28 | 29 | # vertex ids around the ring of the wrist 30 | CIRCLE_V_ID = np.array( 31 | [108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120], 32 | dtype=np.int64, 33 | ) 34 | 35 | 36 | def seal_mano_mesh(v3d, faces, is_rhand): 37 | # v3d: B, 778, 3 38 | # faces: 1538, 3 39 | # output: v3d(B, 779, 3); faces (1554, 3) 40 | 41 | seal_faces = torch.LongTensor(np.array(SEAL_FACES_R)).to(faces.device) 42 | if not is_rhand: 43 | # left hand 44 | seal_faces = seal_faces[:, np.array([1, 0, 2])] # invert face normal 45 | centers = v3d[:, CIRCLE_V_ID].mean(dim=1)[:, None, :] 46 | sealed_vertices = torch.cat((v3d, centers), dim=1) 47 | faces = torch.cat((faces, seal_faces), dim=0) 48 | return sealed_vertices, faces 49 | 50 | 51 | MANO_MODEL_DIR = "/data-home/arctic-master/data/body_models/mano" 52 | 53 | def build_mano_aa(is_rhand, create_transl=False, flat_hand=False): 54 | return MANO( 55 | MODEL_DIR, 56 | create_transl=create_transl, 57 | use_pca=False, 58 | flat_hand_mean=flat_hand, 59 | is_rhand=is_rhand, 60 | ) 61 | 62 | 63 | def construct_layers(dev): 64 | mano_layers = { 65 | "right": build_mano_aa(True, create_transl=True, flat_hand=False), 66 | "left": build_mano_aa(False, create_transl=True, flat_hand=False) 67 | } 68 | for layer in mano_layers.values(): 69 | layer.to(dev) 70 | return mano_layers 71 | -------------------------------------------------------------------------------- /models/__pycache__/cond_diffusion_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/cond_diffusion_model.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/contact_conditional_module.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/contact_conditional_module.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/contact_conditional_moudle.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/contact_conditional_moudle.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/pointnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/pointnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/semantic_conditional_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/semantic_conditional_module.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/semantic_conditional_moudle.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/semantic_conditional_moudle.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adventurer-w/ClickDiff/d476810e212945ed46c8c5025e5bfb5791267886/models/__pycache__/transformer_module.cpython-38.pyc -------------------------------------------------------------------------------- /models/cond_diffusion_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.transformer_module import Decoder 5 | 6 | import os 7 | import math 8 | from tqdm.auto import tqdm 9 | from einops import rearrange, reduce 10 | from einops.layers.torch import Rearrange 11 | from inspect import isfunction 12 | import torch.nn.functional as F 13 | from ipdb import set_trace as st 14 | 15 | def exists(x): 16 | return x is not None 17 | 18 | def default(val, d): 19 | if exists(val): 20 | return val 21 | return d() if isfunction(d) else d 22 | 23 | def extract(a, t, x_shape): 24 | b, *_ = t.shape 25 | out = a.gather(-1, t) 26 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 27 | 28 | def linear_beta_schedule(timesteps): 29 | scale = 1000 / timesteps 30 | beta_start = scale * 0.0001 31 | beta_end = scale * 0.02 32 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 33 | 34 | def cosine_beta_schedule(timesteps, s = 0.008): 35 | """ 36 | cosine schedule 37 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 38 | """ 39 | steps = timesteps + 1 40 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 41 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 42 | alphas_cumprod = 
alphas_cumprod / alphas_cumprod[0] 43 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 44 | return torch.clip(betas, 0, 0.999) 45 | 46 | 47 | class SinusoidalPosEmb(nn.Module): 48 | def __init__(self, dim): 49 | super().__init__() 50 | self.dim = dim 51 | 52 | def forward(self, x): 53 | device = x.device 54 | half_dim = self.dim // 2 55 | emb = math.log(10000) / (half_dim - 1) 56 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 57 | emb = x[:, None] * emb[None, :] 58 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 59 | return emb 60 | 61 | class TransformerDiffusionModel(nn.Module): 62 | def __init__( 63 | self, 64 | d_feats, 65 | d_condfeat, 66 | d_model, 67 | n_dec_layers, 68 | n_head, 69 | d_k, 70 | d_v, 71 | max_timesteps, 72 | ): 73 | super().__init__() 74 | 75 | self.d_feats = d_feats 76 | self.d_condfeat= d_condfeat 77 | self.d_model = d_model 78 | self.n_head = n_head 79 | self.n_dec_layers = n_dec_layers 80 | self.d_k = d_k 81 | self.d_v = d_v 82 | self.max_timesteps = max_timesteps 83 | 84 | # Input: BS X D X T 85 | # Output: BS X T X D' 86 | self.motion_transformer = Decoder(d_feats=self.d_condfeat, d_model=self.d_model, \ 87 | n_layers=self.n_dec_layers, n_head=self.n_head, d_k=self.d_k, d_v=self.d_v, \ 88 | max_timesteps=self.max_timesteps, use_full_attention=True) 89 | 90 | self.linear_out = nn.Linear(self.d_model, self.d_feats) 91 | 92 | # For noise level t embedding 93 | dim = 64 94 | time_dim = dim * 4 95 | 96 | sinu_pos_emb = SinusoidalPosEmb(dim) 97 | fourier_dim = dim 98 | 99 | self.time_mlp = nn.Sequential( 100 | sinu_pos_emb, 101 | nn.Linear(fourier_dim, time_dim), 102 | nn.GELU(), 103 | nn.Linear(time_dim, d_model) 104 | ) 105 | 106 | def forward(self, src, noise_t, padding_mask=None): 107 | 108 | noise_t_embed = self.time_mlp(noise_t) # BS X d_model 109 | noise_t_embed = noise_t_embed[:, None, :] # BS X 1 X d_model 110 | 111 | bs = src.shape[0] 112 | num_steps = src.shape[1] + 1 113 | 114 | if padding_mask is None: 115 | # In training, no need for masking 116 | padding_mask = torch.ones(bs, 1, num_steps).to(src.device).bool() # BS X 1 X timesteps 117 | 118 | # Get position vec for position-wise embedding 119 | pos_vec = torch.arange(num_steps)+1 # timesteps 120 | pos_vec = pos_vec[None, None, :].to(src.device).repeat(bs, 1, 1) # BS X 1 X timesteps 121 | 122 | data_input = src.transpose(1, 2).detach() # BS X D X T 123 | feat_pred, _ = self.motion_transformer(data_input, padding_mask, pos_vec, obj_embedding=noise_t_embed) 124 | 125 | output = self.linear_out(feat_pred[:, 1:]) # BS X T X D 126 | 127 | return output # predicted noise, the same size as the input 128 | 129 | class CondGaussianDiffusion(nn.Module): 130 | def __init__( 131 | self, 132 | d_feats, 133 | d_condfeat, 134 | d_model, 135 | n_head, 136 | n_dec_layers, 137 | d_k, 138 | d_v, 139 | max_timesteps, 140 | out_dim, 141 | timesteps = 1000, 142 | loss_type = 'l1', 143 | objective = 'pred_noise', 144 | beta_schedule = 'cosine', 145 | p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. 
is recommended 146 | p2_loss_weight_k = 1, 147 | batch_size=None, 148 | ): 149 | super().__init__() 150 | 151 | self.denoise_fn = TransformerDiffusionModel(d_feats=d_feats, d_condfeat=d_condfeat,d_model=d_model, n_head=n_head, \ 152 | d_k=d_k, d_v=d_v, n_dec_layers=n_dec_layers, max_timesteps=max_timesteps) 153 | # Input condition and noisy motion, noise level t, predict gt motion 154 | 155 | self.objective = objective 156 | 157 | self.seq_len = max_timesteps - 1 158 | self.out_dim = out_dim 159 | self.d_feats = d_feats 160 | 161 | if beta_schedule == 'linear': 162 | betas = linear_beta_schedule(timesteps) 163 | elif beta_schedule == 'cosine': 164 | betas = cosine_beta_schedule(timesteps) 165 | else: 166 | raise ValueError(f'unknown beta schedule {beta_schedule}') 167 | 168 | alphas = 1. - betas 169 | alphas_cumprod = torch.cumprod(alphas, axis=0) 170 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) 171 | 172 | timesteps, = betas.shape 173 | self.num_timesteps = int(timesteps) 174 | self.loss_type = loss_type 175 | 176 | # helper function to register buffer from float64 to float32 177 | register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) 178 | 179 | register_buffer('betas', betas) 180 | register_buffer('alphas_cumprod', alphas_cumprod) 181 | register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) 182 | 183 | # calculations for diffusion q(x_t | x_{t-1}) and others 184 | register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) 185 | register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) 186 | register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) 187 | register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) 188 | register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) 189 | 190 | # calculations for posterior q(x_{t-1} | x_t, x_0) 191 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 192 | 193 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 194 | register_buffer('posterior_variance', posterior_variance) 195 | 196 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 197 | register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) 198 | register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) 199 | register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod)) 200 | 201 | # calculate p2 reweighting 202 | register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) 203 | 204 | def predict_start_from_noise(self, x_t, t, noise): 205 | return ( 206 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 207 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 208 | ) 209 | 210 | def q_posterior(self, x_start, x_t, t): 211 | posterior_mean = ( 212 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 213 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 214 | ) 215 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 216 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) 217 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 218 | 219 | def p_mean_variance(self, x, t, x_cond, clip_denoised, padding_mask=None): 220 | x_cond = x_cond[:,:,self.d_feats:] 221 | x_all = torch.cat((x, x_cond), dim=-1) 222 | 223 | model_output = self.denoise_fn(x_all, t, padding_mask=padding_mask) 224 | 225 | if self.objective == 'pred_noise': 226 | x_start = self.predict_start_from_noise(x, t=t, noise=model_output) 227 | elif self.objective == 'pred_x0': 228 | x_start = model_output 229 | else: 230 | raise ValueError(f'unknown objective {self.objective}') 231 | 232 | 233 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) 234 | return model_mean, posterior_variance, posterior_log_variance 235 | 236 | @torch.no_grad() 237 | def p_sample(self, x, t, x_cond, clip_denoised=True, padding_mask=None): 238 | 239 | b, *_, device = *x.shape, x.device 240 | model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, x_cond=x_cond, \ 241 | clip_denoised=clip_denoised, padding_mask=padding_mask) 242 | noise = torch.randn_like(x) 243 | 244 | nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) 245 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 246 | 247 | @torch.no_grad() 248 | def p_sample_loop(self, shape, x_start, hand, cond_mask, padding_mask=None): 249 | device = self.betas.device 250 | 251 | b = shape[0] 252 | x = torch.randn(hand.shape, device=device) 253 | 254 | x_cond = x_start * (1. - cond_mask) + \ 255 | cond_mask * torch.randn_like(x_start).to(x_start.device) 256 | 257 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 258 | x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), x_cond, padding_mask=padding_mask) 259 | 260 | return x # BS X T X D 261 | 262 | 263 | @torch.no_grad() 264 | def p_sample_loop_sliding_window_w_canonical(self, ds,x_denoise,obj, shape, cond_mask): 265 | 266 | 267 | device = self.betas.device 268 | 269 | b = shape[0] 270 | # assert b == 1 271 | 272 | x_all = torch.randn(x_denoise.shape, device=device) 273 | 274 | curr_x = x_all 275 | 276 | curr_x_start = torch.zeros(shape[0], shape[1], shape[2]).to(device) 277 | # st() 278 | curr_x_start[:, :, (curr_x_start.shape[2]-obj.shape[2]):] = obj # BS X T X 6 279 | 280 | curr_cond_mask = cond_mask[:, :,:] # BS X T X D 281 | curr_x_cond = curr_x_start * (1. 
- curr_cond_mask) + curr_cond_mask * torch.randn_like(curr_x_start).to(curr_x_start.device) 282 | 283 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 284 | curr_x = self.p_sample(curr_x, torch.full((b,), i, device=device, dtype=torch.long), curr_x_cond) 285 | 286 | return curr_x 287 | # T X 22 X 3, T X 3 288 | 289 | 290 | @torch.no_grad() 291 | def sample(self, x_start, hand, cond_mask=None, padding_mask=None): 292 | # naive conditional sampling by replacing the noisy prediction with input target data. 293 | self.denoise_fn.eval() 294 | sample_res = self.p_sample_loop(x_start.shape, \ 295 | x_start, hand, cond_mask) 296 | # BS X T X D 297 | self.denoise_fn.train() 298 | return sample_res 299 | 300 | @torch.no_grad() 301 | def sample_sliding_window(self, x_start, cond_mask): 302 | # If the sequence is longer than trained max window, divide 303 | self.denoise_fn.eval() 304 | sample_res = self.p_sample_loop_sliding_window(x_start.shape, \ 305 | x_start, cond_mask) 306 | # BS X T X D 307 | self.denoise_fn.train() 308 | return sample_res 309 | 310 | @torch.no_grad() 311 | def sample_sliding_window_w_canonical(self, ds, x_denoise, obj, x_start, cond_mask): 312 | 313 | self.denoise_fn.eval() 314 | sample_res = self.p_sample_loop_sliding_window_w_canonical(ds = ds, x_denoise = x_denoise, obj = obj,shape = x_start.shape,cond_mask = cond_mask) 315 | # BS X T X D 316 | self.denoise_fn.train() 317 | return sample_res 318 | 319 | def q_sample(self, x_start, t, noise=None): 320 | noise = default(noise, lambda: torch.randn_like(x_start)) 321 | 322 | return ( 323 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 324 | extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 325 | ) 326 | 327 | @property 328 | def loss_fn(self): 329 | if self.loss_type == 'l1': 330 | return F.l1_loss 331 | elif self.loss_type == 'l2': 332 | return F.mse_loss 333 | else: 334 | raise ValueError(f'invalid loss type {self.loss_type}') 335 | 336 | def p_losses(self, x_both, x_start, cond_mask, t, noise=None, padding_mask=None): 337 | # x_start: BS X T X D 338 | # cond_mask: BS X T X D, missing regions are 1, head pose conditioned regions are 0. 339 | noise = default(noise, lambda: torch.randn_like(x_start)) 340 | 341 | x = self.q_sample(x_start=x_start, t=t, noise=noise) # noisy motion in noise level t. 342 | 343 | noisy_x_start = x_both.clone() 344 | masked_x_input = x 345 | x_cond = noisy_x_start * (1. - cond_mask) + cond_mask * torch.randn_like(noisy_x_start).to(noisy_x_start.device) 346 | 347 | x_cond = x_cond[:,:,self.d_feats:] 348 | 349 | x_all = torch.cat((masked_x_input, x_cond), dim=-1) 350 | 351 | model_out = self.denoise_fn(x_all, t, padding_mask) 352 | 353 | if self.objective == 'pred_noise': 354 | target = noise 355 | elif self.objective == 'pred_x0': 356 | target = x_start 357 | else: 358 | raise ValueError(f'unknown objective {self.objective}') 359 | 360 | # Predicting both head pose and other joints' pose. 361 | if padding_mask is not None: 362 | loss = self.loss_fn(model_out, target, reduction = 'none') * padding_mask[:, 0, 1:][:, :, None] 363 | else: 364 | loss = self.loss_fn(model_out, target, reduction = 'none') # BS X T X D 365 | 366 | loss = reduce(loss, 'b ... 
-> b (...)', 'mean') 367 | 368 | loss = loss * extract(self.p2_loss_weight, t, loss.shape) 369 | 370 | return loss.mean() 371 | 372 | 373 | def forward(self, x_both, x_start, cond_mask, padding_mask=None): 374 | bs = x_start.shape[0] 375 | t = torch.randint(0, self.num_timesteps, (bs,), device=x_start.device).long() 376 | curr_loss = self.p_losses(x_both, x_start, cond_mask, t, padding_mask=padding_mask) 377 | 378 | return curr_loss -------------------------------------------------------------------------------- /models/contact_conditional_module.py: -------------------------------------------------------------------------------- 1 | # Full dataset: contact heatmap to MANO 2 | 3 | import sys 4 | import argparse 5 | import os 6 | from pathlib import Path 7 | import yaml 8 | 9 | import wandb 10 | import torch 11 | from torch.optim import Adam 12 | from torch.cuda.amp import autocast, GradScaler 13 | from torch.utils import data 14 | import pytorch3d.transforms as transforms 15 | from ema_pytorch import EMA 16 | 17 | 18 | from models.pointnet import PointNetfeat2 #as PointNetfeat #as pointnet 19 | import torch.nn.functional as F 20 | from ipdb import set_trace as st 21 | import torch.nn as nn 22 | from data.grab_test import GrabDataset 23 | from models.cond_diffusion_model import CondGaussianDiffusion 24 | 25 | def cycle(dl): 26 | while True: 27 | for data in dl: 28 | yield data 29 | 30 | class Trainer(object): 31 | def __init__( 32 | self, 33 | opt, 34 | diffusion_model, 35 | pointnet_model, 36 | *, 37 | ema_decay = 0.995, 38 | train_batch_size = 20480, # bs=1 39 | train_lr = 1e-5, 40 | train_num_steps = 500000, 41 | gradient_accumulate_every = 2, 42 | amp = False, 43 | step_start_ema = 2000, 44 | ema_update_every = 10, 45 | save_and_sample_every = 1000, 46 | results_folder = './results', 47 | use_wandb=True, 48 | run_demo=False, 49 | ): 50 | super().__init__() 51 | self.use_wandb = use_wandb 52 | if self.use_wandb: 53 | # Loggers 54 | wandb.init(config=opt, project=opt.wandb_pj_name, entity=opt.entity, name=opt.exp_name, dir=opt.save_dir) 55 | 56 | self.model = diffusion_model 57 | self.ema = EMA(diffusion_model, beta=ema_decay, update_every=ema_update_every) 58 | self.pointnet = pointnet_model 59 | #self.pointnet = PointNetfeat(3,64,64,1024,global_feat=True) 60 | self.ema2 = EMA(pointnet_model, beta=ema_decay, update_every=ema_update_every) 61 | self.step_start_ema = step_start_ema 62 | self.save_and_sample_every = save_and_sample_every 63 | self.batch_size = train_batch_size 64 | self.gradient_accumulate_every = gradient_accumulate_every 65 | self.train_num_steps = train_num_steps 66 | 67 | 68 | self.optimizer = Adam([{'params': pointnet_model.parameters(), 'lr': 1e-4}, 69 | {'params': diffusion_model.parameters(), 'lr': train_lr}]) 70 | #self.optimizer = Adam(diffusion_model.parameters(), lr=train_lr) 71 | self.step = 0 72 | self.amp = amp 73 | self.scaler = GradScaler(enabled=amp) 74 | 75 | self.results_folder = results_folder 76 | 77 | self.vis_folder = results_folder.replace("weights", "vis_res") 78 | 79 | self.opt = opt 80 | # self.mse_loss = nn.MSELoss(reduction="none") 81 | self.ds = GrabDataset() 82 | 83 | self.window = opt.window 84 | self.mse_loss = nn.MSELoss(reduction="none") 85 | 86 | def prep_dataloader(self, args, seq=None): 87 | # Define dataset 88 | if seq is not None: 89 | split = args.run_on 90 | train_dataset = GrabDataset() 91 | val_dataset = GrabDataset() 92 | 93 | self.ds = train_dataset 94 | self.val_ds = val_dataset 95 | self.dl = cycle(data.DataLoader(self.ds,
batch_size=self.batch_size, shuffle=True, pin_memory=True, num_workers=0)) 96 | self.val_dl = cycle(data.DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=0)) 97 | 98 | def save(self, milestone): 99 | data = { 100 | 'step': self.step, 101 | 'model': self.model.state_dict(), 102 | 'pointnet': self.pointnet.state_dict(), 103 | 'ema': self.ema.state_dict(), 104 | 'ema2': self.ema2.state_dict(), 105 | 'scaler': self.scaler.state_dict() # gradient scaler (GradScaler) state dict 106 | } 107 | torch.save(data, os.path.join(self.results_folder, 'model-'+str(milestone)+'.pt')) 108 | print("save success") 109 | 110 | def load(self, milestone): 111 | data = torch.load(os.path.join(self.results_folder, 'model-'+str(milestone)+'.pt')) 112 | self.step = data['step'] 113 | self.model.load_state_dict(data['model'], strict=False) 114 | self.ema.load_state_dict(data['ema'], strict=False) 115 | self.pointnet.load_state_dict(data['pointnet'], strict=False) 116 | self.ema2.load_state_dict(data['ema2'], strict=False) 117 | self.scaler.load_state_dict(data['scaler']) 118 | 119 | def load_weight_path(self, weight_path): 120 | data = torch.load(weight_path) 121 | self.step = data['step'] 122 | self.model.load_state_dict(data['model'], strict=False) 123 | self.ema.load_state_dict(data['ema'], strict=False) 124 | self.pointnet.load_state_dict(data['pointnet'], strict=False) 125 | self.ema2.load_state_dict(data['ema2'], strict=False) 126 | self.scaler.load_state_dict(data['scaler']) 127 | 128 | 129 | def prep_head_condition_mask(self, data, joint_idx=15): 130 | mask = torch.ones_like(data).to(data.device) 131 | mask[:,:,61:] = torch.zeros(data.shape[0], data.shape[1], 3072).to(data.device) 132 | # st() 133 | return mask 134 | 135 | def full_body_gen_cond_head_pose_sliding_window(self, input_data): 136 | self.ema.ema_model.eval() 137 | self.ema2.ema_model.eval() 138 | 139 | with torch.no_grad(): 140 | obj = input_data[:,:,61:6205].view([input_data.shape[0],input_data.shape[1],2048,3]).cuda() 141 | feature_in = obj.transpose(3, 2).cuda() 142 | 143 | val_feature= self.ema2.ema_model(feature_in) 144 | val_data_in = torch.cat([input_data[:,:,:61],val_feature.cuda(),input_data[:,:,6205:8253]],dim=-1).cuda() 145 | 146 | mano = input_data[:,:,:61] 147 | denoise_data = torch.zeros(mano.shape[0], mano.shape[1], mano.shape[2]).to(input_data.device) 148 | 149 | data = torch.zeros(val_data_in.shape[0], val_data_in.shape[1], val_data_in.shape[2]).to(input_data.device) 150 | val_data_in = val_data_in[:,:,61:] 151 | cond_mask = self.prep_head_condition_mask(data) 152 | 153 | pred_x = self.ema.ema_model.sample_sliding_window_w_canonical(self.ds, x_denoise=denoise_data,\ 154 | obj=val_data_in, x_start=data, cond_mask=cond_mask) 155 | 156 | return pred_x 157 | 158 | 159 | def get_moudle(opt, run_demo=False): 160 | opt.window = opt.diffusion_window 161 | 162 | opt.diffusion_save_dir = os.path.join(opt.diffusion_project, opt.diffusion_exp_name) 163 | device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") 164 | 165 | # Prepare Directories 166 | save_dir = Path(opt.diffusion_save_dir) 167 | wdir = save_dir / 'weights' 168 | 169 | # Define model 170 | repr_dim = 61 171 | repr_dim_cond = 3133 172 | 173 | transformer_diffusion = CondGaussianDiffusion(d_feats=repr_dim, d_condfeat=repr_dim_cond, d_model=opt.diffusion_d_model, \ 174 | n_dec_layers=opt.diffusion_n_dec_layers, n_head=opt.diffusion_n_head, \ 175 | d_k=opt.diffusion_d_k, d_v=opt.diffusion_d_v, \ 176 |
max_timesteps=opt.diffusion_window+1, out_dim=repr_dim, timesteps=1000, objective="pred_x0", \ 177 | batch_size=opt.diffusion_batch_size) 178 | 179 | transformer_diffusion.to(device) 180 | pointnet_model = PointNetfeat2(3,64,64,1024,global_feat=True) 181 | #pointnet_model = pointnet(global_feat=True, feature_transform=True, channel=3) 182 | pointnet_model.to(device) 183 | 184 | 185 | trainer = Trainer( 186 | opt, 187 | transformer_diffusion, 188 | pointnet_model, 189 | train_batch_size=opt.diffusion_batch_size, # 32 190 | train_lr=opt.diffusion_learning_rate, # 1e-4 191 | train_num_steps=8000000, # 700000, total training steps 192 | gradient_accumulate_every=2, # gradient accumulation steps 193 | ema_decay=0.995, # exponential moving average decay 194 | amp=True, # turn on mixed precision 195 | results_folder=str(wdir), 196 | use_wandb=False, 197 | run_demo=run_demo, 198 | ) 199 | 200 | return trainer -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, input_dim, hidden_dims=(128, 128), activation='tanh', is_dropout=False): 6 | super().__init__() 7 | if activation == 'tanh': 8 | self.activation = torch.tanh 9 | elif activation == 'relu': 10 | self.activation = torch.relu 11 | elif activation == 'sigmoid': 12 | self.activation = torch.sigmoid 13 | 14 | self.out_dim = hidden_dims[-1] 15 | self.affine_layers = nn.ModuleList() 16 | last_dim = input_dim 17 | for idx, nh in enumerate(hidden_dims): 18 | self.affine_layers.append(nn.Linear(last_dim, nh)) 19 | if idx == 0 and is_dropout: 20 | self.affine_layers.append(nn.Dropout(p=0.5)) 21 | last_dim = nh 22 | 23 | def forward(self, x): 24 | for affine in self.affine_layers: 25 | x = self.activation(affine(x)) 26 | return x 27 | -------------------------------------------------------------------------------- /models/pointnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.parallel 8 | import torch.utils.data 9 | from torch.autograd import Variable 10 | from ipdb import set_trace as st 11 | """ 12 | Source: https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py 13 | """ 14 | 15 | class PointNetfeat(nn.Module): 16 | def __init__(self, input_dim, shallow_dim, mid_dim, out_dim, global_feat=False): 17 | super(PointNetfeat, self).__init__() 18 | self.shallow_layer = nn.Sequential( 19 | nn.Conv1d(input_dim, shallow_dim, 1), nn.BatchNorm1d(shallow_dim) 20 | ) 21 | 22 | self.base_layer = nn.Sequential( 23 | nn.Conv1d(shallow_dim, mid_dim, 1), 24 | nn.BatchNorm1d(mid_dim), 25 | nn.ReLU(), 26 | nn.Conv1d(mid_dim, out_dim, 1), 27 | nn.BatchNorm1d(out_dim), 28 | ) 29 | 30 | self.global_feat = global_feat 31 | self.out_dim = out_dim 32 | 33 | def forward(self, x): 34 | 35 | x= x.squeeze(1) 36 | n_pts = x.size()[2] 37 | x = self.shallow_layer(x) 38 | pointfeat = x 39 | 40 | x = self.base_layer(x) 41 | x = torch.max(x, 2, keepdim=True)[0] 42 | x = x.view(-1, self.out_dim) 43 | 44 | trans_feat = None 45 | trans = None 46 | x = x.unsqueeze(1) 47 | if self.global_feat: 48 | return x#, trans, trans_feat 49 | 50 | 51 | class PointNetfeat2(nn.Module): 52 | def __init__(self, input_dim, shallow_dim, 
mid_dim, out_dim, global_feat=False): 53 | super(PointNetfeat2, self).__init__() 54 | # self.shallow_layer = nn.Sequential( 55 | # nn.Conv1d(input_dim, shallow_dim, 1), nn.BatchNorm1d(shallow_dim) 56 | # ) 57 | 58 | self.base_layer = nn.Sequential( 59 | nn.Conv1d(input_dim, shallow_dim, 1), 60 | nn.BatchNorm1d(shallow_dim), 61 | nn.ReLU(), 62 | nn.Conv1d(shallow_dim, out_dim, 1), 63 | nn.BatchNorm1d(out_dim), 64 | ) 65 | 66 | self.global_feat = global_feat 67 | self.out_dim = out_dim 68 | 69 | def forward(self, x): 70 | 71 | x = x.squeeze(1) 72 | n_pts = x.size()[2] 73 | # x = self.shallow_layer(x) 74 | # pointfeat = x 75 | 76 | x = self.base_layer(x) 77 | x = torch.max(x, 2, keepdim=True)[0] 78 | x = x.view(-1, self.out_dim) 79 | 80 | trans_feat = None 81 | trans = None 82 | x = x.unsqueeze(1) 83 | if self.global_feat: 84 | return x#, trans, trans_feat 85 | # else: 86 | # x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) 87 | # return torch.cat([x, pointfeat], 1), trans, trans_feat 88 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import models 4 | 5 | class ResNet(nn.Module): 6 | def __init__(self, out_dim, fix_params=False, running_stats=False, pretrained=False): 7 | super().__init__() 8 | self.out_dim = out_dim 9 | self.resnet = models.resnet18(pretrained=pretrained) 10 | if fix_params: 11 | for param in self.resnet.parameters(): 12 | param.requires_grad = False 13 | 14 | self.resnet.fc = nn.Linear(self.resnet.fc.in_features, out_dim) 15 | self.bn_stats(running_stats) 16 | 17 | def forward(self, x): 18 | return self.resnet(x) 19 | 20 | def bn_stats(self, track_running_stats): 21 | for m in self.modules(): 22 | if type(m) == nn.BatchNorm2d: 23 | m.track_running_stats = track_running_stats 24 | 25 | class FeatureExtractor(nn.Module): 26 | def __init__(self): 27 | super(FeatureExtractor, self).__init__() 28 | self.cnn_fdim = 512 29 | self.cnn = ResNet(self.cnn_fdim, running_stats=False, pretrained=True) 30 | # Freeze the CNN params 31 | for param in self.cnn.parameters(): 32 | param.requires_grad = False 33 | 34 | def to(self, device): 35 | self.device = device 36 | super().to(device) 37 | return self 38 | 39 | def forward(self, data): 40 | # pose: 69 dim body pose 41 | batch_size, seq_len, _, _, _ = data['of'].shape # 42 | 43 | of_data = data['of'] # B X T X 224 X 224 X 2 44 | of_data = torch.cat((of_data, torch.zeros(of_data.shape[:-1] + (1,), device=of_data.device)), dim=-1) 45 | h, w = 224, 224 46 | c = 3 47 | of_data = of_data.reshape(-1, h, w, c).permute(0, 3, 1, 2) # B X T X 3 X 224 X 224 48 | input_features = self.cnn(of_data).reshape(batch_size, seq_len, self.cnn_fdim) # B X T X D 49 | 50 | return input_features 51 | 52 | if __name__ == '__main__': 53 | net = ResNet(128) 54 | input = torch.ones(1, 3, 224, 224) 55 | out = net(input) 56 | print(out.shape) 57 | -------------------------------------------------------------------------------- /models/semantic_conditional_module.py: -------------------------------------------------------------------------------- 1 | # Full dataset: contact heatmap to MANO 2 | 3 | import sys 4 | import argparse 5 | import os 6 | from pathlib import Path 7 | import yaml 8 | 9 | import wandb 10 | import torch 11 | from torch.optim import Adam 12 | from torch.cuda.amp import GradScaler 13 | from torch.utils import data 14 | import pytorch3d.transforms as transforms 15 | from
ema_pytorch import EMA 16 | 17 | from models.pointnet import PointNetfeat2 #as PointNetfeat #as pointnet 18 | import torch.nn.functional as F 19 | from ipdb import set_trace as st 20 | import torch.nn as nn 21 | from data.grab_test import GrabDataset 22 | from models.cond_diffusion_model import CondGaussianDiffusion 23 | 24 | 25 | def cycle(dl): 26 | while True: 27 | for data in dl: 28 | yield data 29 | 30 | class Trainer(object): 31 | def __init__( 32 | self, 33 | opt, 34 | diffusion_model, 35 | pointnet_model, 36 | *, 37 | ema_decay = 0.995, 38 | train_batch_size = 20480, 39 | train_lr = 1e-5, 40 | train_num_steps = 1000000, 41 | gradient_accumulate_every = 2, 42 | amp = False, 43 | step_start_ema = 2000, 44 | ema_update_every = 10, 45 | save_and_sample_every = 5000, 46 | results_folder = './results', 47 | use_wandb=True, 48 | run_demo=False, 49 | ): 50 | super().__init__() 51 | self.use_wandb = use_wandb 52 | if self.use_wandb: 53 | # Loggers 54 | wandb.init(config=opt, project=opt.wandb_pj_name, entity=opt.entity, name=opt.exp_name, dir=opt.save_dir) 55 | 56 | self.model = diffusion_model 57 | self.ema = EMA(diffusion_model, beta=ema_decay, update_every=ema_update_every) 58 | self.pointnet = pointnet_model 59 | #self.pointnet = PointNetfeat(3,64,64,1024,global_feat=True) 60 | self.ema2 = EMA(pointnet_model, beta=ema_decay, update_every=ema_update_every) 61 | self.step_start_ema = step_start_ema 62 | self.save_and_sample_every = save_and_sample_every 63 | self.batch_size = train_batch_size 64 | self.gradient_accumulate_every = gradient_accumulate_every 65 | self.train_num_steps = train_num_steps 66 | 67 | self.optimizer = Adam([{'params': pointnet_model.parameters(), 'lr': 1e-4}, 68 | {'params': diffusion_model.parameters(), 'lr': train_lr}]) 69 | #self.optimizer = Adam(diffusion_model.parameters(), lr=train_lr) 70 | self.step = 0 71 | self.amp = amp 72 | self.scaler = GradScaler(enabled=amp) 73 | 74 | self.results_folder = results_folder 75 | 76 | self.vis_folder = results_folder.replace("weights", "vis_res") 77 | 78 | self.opt = opt 79 | # self.mse_loss = nn.MSELoss(reduction="none") 80 | 81 | self.ds = GrabDataset() 82 | 83 | self.window = opt.window 84 | self.mse_loss = nn.MSELoss(reduction="none") 85 | 86 | def prep_dataloader(self, args, seq=None): 87 | # Define dataset 88 | if seq is not None: 89 | split = args.run_on 90 | train_dataset = GrabDataset() 91 | val_dataset = GrabDataset() 92 | 93 | self.ds = train_dataset 94 | self.val_ds = val_dataset 95 | self.dl = cycle(data.DataLoader(self.ds, batch_size=self.batch_size, shuffle=True, pin_memory=True, num_workers=0)) 96 | self.val_dl = cycle(data.DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=0)) 97 | 98 | def save(self, milestone): 99 | data = { 100 | 'step': self.step, 101 | 'model': self.model.state_dict(), 102 | 'pointnet': self.pointnet.state_dict(), 103 | 'ema': self.ema.state_dict(), 104 | 'ema2': self.ema2.state_dict(), 105 | 'scaler': self.scaler.state_dict() 106 | } 107 | torch.save(data, os.path.join(self.results_folder, 'model-'+str(milestone)+'.pt')) 108 | print("save success") 109 | 110 | def load(self, milestone): 111 | data = torch.load(os.path.join(self.results_folder, 'model-'+str(milestone)+'.pt')) 112 | self.step = data['step'] 113 | self.model.load_state_dict(data['model'], strict=False) 114 | self.ema.load_state_dict(data['ema'], strict=False) 115 | self.pointnet.load_state_dict(data['pointnet'], strict=False) 116 | self.ema2.load_state_dict(data['ema2'],
strict=False) 117 | self.scaler.load_state_dict(data['scaler']) 118 | 119 | def load_weight_path(self, weight_path): 120 | data = torch.load(weight_path) 121 | self.step = data['step'] 122 | self.model.load_state_dict(data['model'], strict=False) 123 | self.ema.load_state_dict(data['ema'], strict=False) 124 | self.pointnet.load_state_dict(data['pointnet'], strict=False) 125 | self.ema2.load_state_dict(data['ema2'], strict=False) 126 | self.scaler.load_state_dict(data['scaler']) 127 | 128 | 129 | 130 | def prep_head_condition_mask(self, data, joint_idx=15): 131 | mask = torch.ones_like(data).to(data.device) 132 | mask[:,:,2048:] = torch.zeros(data.shape[0], data.shape[1], 11264).to(data.device) 133 | return mask 134 | 135 | 136 | def full_body_gen_cond_head_pose_sliding_window(self, input_data): 137 | self.ema.ema_model.eval() 138 | self.ema2.ema_model.eval() 139 | 140 | with torch.no_grad(): 141 | obj = input_data[:,:,61:6205].view([input_data.shape[0],input_data.shape[1],2048,3]).cuda() 142 | feature_in = obj.transpose(3, 2).cuda() 143 | 144 | val_feature= self.ema2.ema_model(feature_in) 145 | val_data_in = torch.cat([input_data[:,:,6205:8253],val_feature.cuda(),input_data[:,:,8253:18493]],dim=-1).cuda() 146 | 147 | mano = input_data[:,:,6205:8253] 148 | denoise_data = torch.zeros(mano.shape[0], mano.shape[1], mano.shape[2]).to(input_data.device) 149 | 150 | data = torch.zeros(val_data_in.shape[0], val_data_in.shape[1], val_data_in.shape[2]).to(input_data.device) 151 | val_data_in = val_data_in[:,:,2048:] 152 | cond_mask = self.prep_head_condition_mask(data) 153 | 154 | pred_x = self.ema.ema_model.sample_sliding_window_w_canonical(self.ds, x_denoise=denoise_data,\ 155 | obj=val_data_in, x_start=data, cond_mask=cond_mask) 156 | 157 | return pred_x 158 | 159 | 160 | def get_moudle(opt, run_demo=False): 161 | opt.window = opt.diffusion_window 162 | opt.diffusion_save_dir = os.path.join(opt.diffusion_project, opt.diffusion_exp_name) 163 | device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu") 164 | 165 | # Prepare Directories 166 | save_dir = Path(opt.diffusion_save_dir) 167 | wdir = save_dir / 'weights' 168 | 169 | # Define model 170 | repr_dim = 2048 171 | repr_dim_cond = 13312 172 | 173 | transformer_diffusion = CondGaussianDiffusion(d_feats=repr_dim, d_condfeat=repr_dim_cond, d_model=opt.diffusion_d_model, \ 174 | n_dec_layers=opt.diffusion_n_dec_layers, n_head=opt.diffusion_n_head, \ 175 | d_k=opt.diffusion_d_k, d_v=opt.diffusion_d_v, \ 176 | max_timesteps=opt.diffusion_window+1, out_dim=repr_dim, timesteps=1000, objective="pred_x0", \ 177 | batch_size=opt.diffusion_batch_size) 178 | 179 | transformer_diffusion.to(device) 180 | pointnet_model = PointNetfeat2(3,64,64,1024,global_feat=True) 181 | #pointnet_model = pointnet(global_feat=True, feature_transform=True, channel=3) 182 | pointnet_model.to(device) 183 | 184 | trainer = Trainer( 185 | opt, 186 | transformer_diffusion, 187 | pointnet_model, 188 | train_batch_size=opt.diffusion_batch_size, 189 | train_lr=opt.diffusion_learning_rate, 190 | train_num_steps=8000000, 191 | gradient_accumulate_every=2, 192 | ema_decay=0.995, 193 | amp=True, 194 | results_folder=str(wdir), 195 | use_wandb=False, 196 | run_demo=run_demo, 197 | ) 198 | return trainer 199 | -------------------------------------------------------------------------------- /models/transformer_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch
3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 7 | ''' Sinusoid position encoding table ''' 8 | 9 | def cal_angle(position, hid_idx): 10 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 11 | 12 | def get_posi_angle_vec(position): 13 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 14 | 15 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 16 | 17 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 18 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 19 | 20 | if padding_idx is not None: 21 | # zero vector for padding dimension 22 | sinusoid_table[padding_idx] = 0. 23 | 24 | return torch.FloatTensor(sinusoid_table) 25 | 26 | def get_subsequent_mask(seq): 27 | ''' For masking out the subsequent info. ''' 28 | sz_b, len_s = seq.size() 29 | subsequent_mask = torch.triu( 30 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.bool), diagonal=1) 31 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 32 | 33 | return subsequent_mask 34 | 35 | 36 | class MultiHeadAttention(nn.Module): 37 | def __init__(self, n_head, d_model, d_k, d_v): 38 | super(MultiHeadAttention, self).__init__() 39 | 40 | self.n_head = n_head 41 | self.d_model = d_model 42 | self.d_k = d_k 43 | self.d_v = d_v 44 | 45 | self.w_q = nn.Linear(d_model, n_head*d_k) 46 | self.w_k = nn.Linear(d_model, n_head*d_k) 47 | self.w_v = nn.Linear(d_model, n_head*d_v) 48 | nn.init.normal_(self.w_q.weight, mean=0, std=np.sqrt(2.0/(d_model+d_k))) 49 | nn.init.normal_(self.w_k.weight, mean=0, std=np.sqrt(2.0/(d_model+d_k))) 50 | nn.init.normal_(self.w_v.weight, mean=0, std=np.sqrt(2.0/(d_model+d_v))) 51 | 52 | self.temperature = np.power(d_k, 0.5) 53 | self.attn_dropout = nn.Dropout(0.1) 54 | 55 | self.fc = nn.Linear(n_head*d_v, d_model) 56 | nn.init.xavier_normal_(self.fc.weight) 57 | self.layer_norm = nn.LayerNorm(d_model) 58 | 59 | self.dropout = nn.Dropout(0.1) 60 | 61 | def forward(self, q, k, v, mask=None): 62 | # q: BS X T X D, k: BS X T X D, v: BS X T X D, mask: BS X T X T 63 | bs, n_q, _ = q.shape 64 | bs, n_k, _ = k.shape 65 | bs, n_v, _ = v.shape 66 | 67 | assert n_k == n_v 68 | 69 | residual = q 70 | 71 | q = self.w_q(q).view(bs, n_q, self.n_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_q, self.d_k) 72 | k = self.w_k(k).view(bs, n_k, self.n_head, self.d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, self.d_k) 73 | v = self.w_v(v).view(bs, n_v, self.n_head, self.d_v).permute(2, 0, 1, 3).contiguous().view(-1, n_v, self.d_v) 74 | 75 | attn = torch.bmm(q, k.transpose(1, 2)) # (n_head*bs) X n_q X n_k 76 | attn = attn / self.temperature 77 | 78 | if mask is not None: 79 | mask = mask.repeat(self.n_head, 1, 1) # (n_head*bs) x n_q x n_k 80 | attn = attn.masked_fill(mask, -np.inf) 81 | 82 | attn = F.softmax(attn, dim=2) # (n_head*bs) X n_q X n_k 83 | 84 | attn = self.attn_dropout(attn) 85 | output = torch.bmm(attn, v) # (n_head*bs) X n_q X d_v 86 | 87 | output = output.view(self.n_head, bs, n_q, self.d_v) 88 | output = output.permute(1, 2, 0, 3).contiguous().view(bs, n_q, -1) 89 | # BS X n_q X (n_head*D) 90 | 91 | # output = self.fc(output) # BS X n_q X D 92 | output = self.dropout(self.fc(output)) # BS X n_q X D 93 | output = self.layer_norm(output + residual) # BS X n_q X D 94 | 95 | return output, attn 96 | 97 | 98 | class PositionwiseFeedForward(nn.Module): 99 | def __init__(self, d_in, 
d_hid): 100 | super(PositionwiseFeedForward, self).__init__() 101 | 102 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) 103 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) 104 | self.layer_norm = nn.LayerNorm(d_in) 105 | self.dropout = nn.Dropout(0.1) 106 | 107 | def forward(self, x): 108 | # x: BS X N X D 109 | residual = x 110 | output = x.transpose(1, 2) # BS X D X N 111 | output = self.w_2(F.relu(self.w_1(output))) # BS X D X N 112 | output = output.transpose(1, 2) # BS X N X D 113 | output = self.dropout(output) 114 | output = self.layer_norm(output + residual) # BS X N X D 115 | 116 | return output 117 | 118 | 119 | class DecoderLayer(nn.Module): 120 | def __init__(self, d_model, n_head, d_k, d_v): 121 | super(DecoderLayer, self).__init__() 122 | 123 | self.self_attn = MultiHeadAttention(n_head, d_model, d_k, d_v) 124 | self.pos_ffn = PositionwiseFeedForward(d_model, d_model) 125 | 126 | def forward(self, decoder_input, self_attn_time_mask, self_attn_padding_mask): 127 | # decoder_input: BS X T X D 128 | # time_mask: BS X T X T (padding positions are ones) 129 | # padding_mask: BS X T (padding positions are zeros, diff usage from above) 130 | bs, dec_len, dec_hidden = decoder_input.shape 131 | 132 | decoder_out, dec_self_attn = self.self_attn(decoder_input, decoder_input, decoder_input, \ 133 | mask=self_attn_time_mask) 134 | # BS X T X D, BS X T X T 135 | decoder_out *= self_attn_padding_mask.unsqueeze(-1).float() 136 | # BS X T X D 137 | 138 | decoder_out = self.pos_ffn(decoder_out) # BS X T X D 139 | decoder_out *= self_attn_padding_mask.unsqueeze(-1).float() 140 | 141 | return decoder_out, dec_self_attn 142 | # BS X T X D, BS X T X T 143 | 144 | 145 | class CrossDecoderLayer(nn.Module): 146 | def __init__(self, d_model, n_head, d_k, d_v): 147 | super(CrossDecoderLayer, self).__init__() 148 | 149 | self.self_attn = MultiHeadAttention(n_head, d_model, d_k, d_v) 150 | self.pos_ffn = PositionwiseFeedForward(d_model, d_model) 151 | 152 | def forward(self, decoder_input, k_input, self_attn_time_mask, self_attn_padding_mask): 153 | # decoder_input: BS X T X D 154 | # k_input: BS X K X D 155 | # time_mask: BS X T X T (padding positions are ones) 156 | # padding_mask: BS X T (padding positions are zeros, diff usage from above) 157 | bs, dec_len, dec_hidden = decoder_input.shape 158 | 159 | decoder_out, dec_self_attn = self.self_attn(decoder_input, k_input, k_input, \ 160 | mask=self_attn_time_mask) 161 | # BS X T X D, BS X T X T 162 | decoder_out *= self_attn_padding_mask.unsqueeze(-1).float() 163 | # BS X T X D 164 | 165 | decoder_out = self.pos_ffn(decoder_out) # BS X T X D 166 | decoder_out *= self_attn_padding_mask.unsqueeze(-1).float() 167 | 168 | return decoder_out, dec_self_attn 169 | # BS X T X D, BS X T X T 170 | 171 | 172 | class Decoder(nn.Module): 173 | def __init__( 174 | self, 175 | d_feats, d_model, 176 | n_layers, n_head, d_k, d_v, max_timesteps, use_full_attention=False): 177 | super(Decoder, self).__init__() 178 | 179 | self.start_conv = nn.Conv1d(d_feats, d_model, 1) # (input: 17*3) 180 | self.position_vec = nn.Embedding.from_pretrained( 181 | get_sinusoid_encoding_table(max_timesteps+1, d_model, padding_idx=0), 182 | freeze=True) 183 | self.layer_stack = nn.ModuleList([DecoderLayer(d_model, n_head, d_k, d_v) 184 | for _ in range(n_layers)]) 185 | 186 | self.use_full_attention = use_full_attention 187 | 188 | def forward(self, decoder_input, padding_mask, decoder_pos_vec, obj_embedding=None): 189 | # decoder_input: BS X D X T 190 | # padding_mask: BS X 1 X T 191 | # decoder_pos_vec: BS X 1 X T
192 | # obj_embedding: BS X 1 X D 193 | 194 | dec_self_attn_list = [] 195 | 196 | padding_mask = padding_mask.squeeze(1) # BS X T 197 | decoder_pos_vec = decoder_pos_vec.squeeze(1) # BS X T 198 | 199 | input_embedding = self.start_conv(decoder_input) # BS X D X T 200 | input_embedding = input_embedding.transpose(1, 2) # BS X T X D 201 | if obj_embedding is not None: 202 | new_input_embedding = torch.cat((obj_embedding, input_embedding), dim=1) # BS X (T+1) X D 203 | else: 204 | new_input_embedding = input_embedding 205 | 206 | # self.position_vec = self.position_vec.cuda() 207 | pos_embedding = self.position_vec(decoder_pos_vec) # BS X T X D 208 | 209 | # Time mask is same for all blocks, while padding mask differ according to the position of block 210 | if self.use_full_attention: 211 | time_mask = None 212 | else: 213 | time_mask = get_subsequent_mask(decoder_pos_vec) 214 | # BS X T X T (Prev steps are 0, later 1) 215 | 216 | dec_output = new_input_embedding + pos_embedding # BS X T X D 217 | for dec_layer in self.layer_stack: 218 | dec_output, dec_self_attn = dec_layer( 219 | dec_output, # BS X T X D 220 | self_attn_time_mask=time_mask, # BS X T X T 221 | self_attn_padding_mask=padding_mask) # BS X T 222 | 223 | dec_self_attn_list += [dec_self_attn] 224 | 225 | return dec_output, dec_self_attn_list 226 | # BS X T X D, list 227 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | scenepic==1.0.8 3 | torchgeometry==0.1.2 4 | einops==0.4.1 5 | smplx==0.1.28 6 | pyopengl==3.1.0 7 | gym==0.17.2 8 | numpy==1.21.5 9 | mujoco-py==2.1.2.14 10 | joblib 11 | imageio==2.19.3 12 | imageio-ffmpeg==0.4.7 13 | opencv-python 14 | trimesh 15 | scipy 16 | scikit-learn 17 | wandb==0.12.21 18 | ema-pytorch==0.0.10 19 | 20 | --------------------------------------------------------------------------------
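
Usage note: the following minimal sketch illustrates how the contact-conditional branch above might be driven end to end. The hyper-parameter values, checkpoint name, and the per-frame input layout (61 MANO parameters, 2048x3 flattened object points, 2048 contact values, i.e. 8253 features) are assumptions inferred from contact_conditional_module.py, not values documented by the repo; it also assumes a CUDA device and a prepared GRAB split for GrabDataset.

# Minimal sketch, not part of the repository. Assumptions: hyper-parameter values,
# the checkpoint filename, and the 61 + 2048*3 + 2048 = 8253 per-frame layout.
import torch
from types import SimpleNamespace

from models.contact_conditional_module import get_moudle  # note: the repo spells the helper "moudle"

# Hypothetical configuration; the real values come from the project's own config/argparse.
opt = SimpleNamespace(
    diffusion_window=30, diffusion_d_model=512, diffusion_n_dec_layers=4,
    diffusion_n_head=4, diffusion_d_k=256, diffusion_d_v=256,
    diffusion_batch_size=32, diffusion_learning_rate=1e-4,
    diffusion_project='./runs', diffusion_exp_name='demo',
)

trainer = get_moudle(opt, run_demo=True)       # builds CondGaussianDiffusion + PointNetfeat2
# trainer.load_weight_path('model-100.pt')     # load a trained checkpoint before sampling

# One batch of frames: BS X T X 8253 (MANO | object point cloud | contact map).
input_data = torch.zeros(1, opt.diffusion_window, 8253).cuda()
pred_mano = trainer.full_body_gen_cond_head_pose_sliding_window(input_data)  # BS X T X 61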