├── .gitignore ├── LICENSE ├── README.md ├── apps ├── demo.py ├── eval_interhand.py └── train.py ├── core ├── Loss.py ├── gcn_trainer.py ├── loader.py ├── test_utils.py └── vis_train.py ├── dataset ├── dataset_utils.py ├── heatmap.py ├── inference.py └── interhand.py ├── demo ├── 1.jpg ├── 2.jpg └── 3.jpg ├── models ├── decoder.py ├── encoder.py ├── manolayer.py ├── model.py ├── model_attn │ ├── DualGraph.py │ ├── __init__.py │ ├── gcn.py │ ├── img_attn.py │ ├── inter_attn.py │ └── self_attn.py └── model_zoo │ ├── __init__.py │ ├── coarsening.py │ ├── fc.py │ ├── graph_utils.py │ └── hrnet.py └── utils ├── DataProvider.py ├── config.py ├── defaults.yaml ├── lr_sc.py ├── tb_utils.py ├── utils.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | misc/* 2 | output/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IntagHand 2 | 3 | This repository contains a PyTorch implementation of "[Interacting Attention Graph for Single Image Two-Hand Reconstruction](http://www.liuyebin.com/IntagHand/Intaghand.html)". 4 | 5 | Mengcheng Li, [Liang An](https://anl13.github.io), [Hongwen Zhang](https://hongwenzhang.github.io), Lianpeng Wu, Feng Chen, [Tao Yu](http://ytrock.com/), [Yebin Liu](http://www.liuyebin.com/) 6 | 7 | Tsinghua University & Hisense Inc. 8 | 9 | CVPR 2022 (Oral) 10 | 11 | 12 | 13 | **2023.02.02 Update: added an example of the training code** 14 | 15 | 16 | 17 | ![pic](http://www.liuyebin.com/IntagHand/assets/results2.png) 18 | 19 | ## Requirements 20 | 21 | - Tested with Python 3.7 on Ubuntu 16.04, CUDA 10.2. 22 | 23 | ### packages 24 | 25 | - pytorch (tested on 1.10.0+cu102) 26 | 27 | - torchvision (tested on 0.11.0+cu102) 28 | 29 | - pytorch3d (tested on 0.6.1) 30 | 31 | - numpy 32 | 33 | - OpenCV 34 | 35 | - tqdm 36 | 37 | - yacs >= 0.1.8 38 | 39 | ### Pre-trained model and data 40 | 41 | - Download the necessary assets (including the pre-trained models) from [misc.tar.gz](https://github.com/Dw1010/IntagHand/releases/download/v0.0/misc.tar.gz) and unzip it. 42 | - Register and download the [MANO](https://mano.is.tue.mpg.de/) data. Put `MANO_LEFT.pkl` and `MANO_RIGHT.pkl` in `misc/mano`. 43 | 44 | After collecting the above necessary files, the directory structure of `./misc` is expected to be as follows: 45 | 46 | ``` 47 | ./misc 48 | ├── mano 49 | │ └── MANO_LEFT.pkl 50 | │ └── MANO_RIGHT.pkl 51 | ├── model 52 | │ └── config.yaml 53 | │ └── interhand.pth 54 | │ └── wild_demo.pth 55 | ├── graph_left.pkl 56 | ├── graph_right.pkl 57 | ├── upsample.pkl 58 | ├── v_color.pkl 59 | 60 | ``` 61 | 62 | ## DEMO 63 | 64 | 1. Real-time demo: 65 | 66 | ``` 67 | python apps/demo.py --live_demo 68 | ``` 69 | 2. Single-image reconstruction: 70 | 71 | ``` 72 | python apps/demo.py --img_path demo/ --save_path demo/ 73 | ``` 74 | Results will be stored in the folder `./demo`. 75 | 76 | **Note**: We do not perform hand detection, so hands are expected to be roughly centered in the image and to occupy approximately 70-90% of the image area. 77 | 78 | ## Training 79 | 80 | 1. Download the [InterHand2.6M](https://mks0601.github.io/InterHand2.6M/) dataset and unzip it. (**Note**: we used the `v1.0_5fps` version and the `H+M` subset for training and evaluation.) 81 | 82 | 2. Process the dataset by: 83 | ``` 84 | python dataset/interhand.py --data_path PATH_OF_INTERHAND2.6M --save_path ./interhand2.6m/ 85 | ``` 86 | Replace `PATH_OF_INTERHAND2.6M` with your local path to the [InterHand2.6M](https://mks0601.github.io/InterHand2.6M/) dataset. 87 | 88 | 3.
Try the training code: 89 | ``` 90 | python apps/train.py utils/defaults.yaml 91 | ``` 92 | 93 | The output model and TensorBoard log files will be stored in `./output`. 94 | If you have multiple GPUs on your device, set `--gpu` to use them. For example, use: 95 | 96 | ``` 97 | python apps/train.py utils/defaults.yaml --gpu 0,1,2,3 98 | ``` 99 | to train the model on 4 GPUs. 100 | 101 | 4. We highly recommend trying different loss weights and fine-tuning the model with a lower learning rate to get better results. The training configuration can be modified in `utils/defaults.yaml`. 102 | 103 | ## Evaluation 104 | 105 | 1. Download the [InterHand2.6M](https://mks0601.github.io/InterHand2.6M/) dataset and unzip it. (**Note**: we used the `v1.0_5fps` version and the `H+M` subset for training and evaluation.) 106 | 107 | 2. Process the dataset by: 108 | ``` 109 | python dataset/interhand.py --data_path PATH_OF_INTERHAND2.6M --save_path ./interhand2.6m/ 110 | ``` 111 | Replace `PATH_OF_INTERHAND2.6M` with your local path to the [InterHand2.6M](https://mks0601.github.io/InterHand2.6M/) dataset. 112 | 113 | 3. Run the evaluation: 114 | ``` 115 | python apps/eval_interhand.py --data_path ./interhand2.6m/ 116 | ``` 117 | 118 | You should get the following output: 119 | 120 | ``` 121 | joint mean error: 122 | left: 8.93425289541483 mm, right: 8.663229644298553 mm 123 | all: 8.798741269856691 mm 124 | vert mean error: 125 | left: 9.173248894512653 mm, right: 8.890160359442234 mm 126 | all: 9.031704626977444 mm 127 | ``` 128 | 129 | 130 | ## Acknowledgement 131 | 132 | The PyTorch implementation of MANO is based on [manopth](https://github.com/hassony2/manopth). The GCN network is based on [hand-graph-cnn](https://github.com/3d-hand-shape/hand-graph-cnn). The heatmap generation and inference are based on [DarkPose](https://github.com/ilovepose/DarkPose). We thank the authors for their great work! 133 | 134 | ## Citation 135 | 136 | If you find the code useful in your research, please consider citing the paper. 137 | 138 | ``` 139 | @inproceedings{Li2022intaghand, 140 | title={Interacting Attention Graph for Single Image Two-Hand Reconstruction}, 141 | author={Li, Mengcheng and An, Liang and Zhang, Hongwen and Wu, Lianpeng and Chen, Feng and Yu, Tao and Liu, Yebin}, 142 | booktitle={IEEE/CVF Conf.
on Computer Vision and Pattern Recognition (CVPR)}, 143 | month=jun, 144 | year={2022}, 145 | } 146 | ``` 147 | 148 | -------------------------------------------------------------------------------- /apps/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 as cv 4 | import glob 5 | import os 6 | import argparse 7 | 8 | 9 | import sys 10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 11 | 12 | from models.model import load_model 13 | from utils.config import load_cfg 14 | from utils.utils import get_mano_path, imgUtils 15 | from dataset.dataset_utils import IMG_SIZE 16 | from core.test_utils import InterRender 17 | 18 | 19 | def cut_img(img, bbox): 20 | cut = img[max(int(bbox[2]), 0):min(int(bbox[3]), img.shape[0]), 21 | max(int(bbox[0]), 0):min(int(bbox[1]), img.shape[1])] 22 | cut = cv.copyMakeBorder(cut, 23 | max(int(-bbox[2]), 0), 24 | max(int(bbox[3] - img.shape[0]), 0), 25 | max(int(-bbox[0]), 0), 26 | max(int(bbox[1] - img.shape[1]), 0), 27 | borderType=cv.BORDER_CONSTANT, 28 | value=(0, 0, 0)) 29 | return cut 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--cfg", type=str, default='misc/model/config.yaml') 35 | parser.add_argument("--model", type=str, default='misc/model/wild_demo.pth') 36 | parser.add_argument("--live_demo", action='store_true') 37 | parser.add_argument("--img_path", type=str, default='demo/') 38 | parser.add_argument("--save_path", type=str, default='demo/') 39 | parser.add_argument("--render_size", type=int, default=256) 40 | opt = parser.parse_args() 41 | 42 | model = InterRender(cfg_path=opt.cfg, 43 | model_path=opt.model, 44 | render_size=opt.render_size) 45 | 46 | if not opt.live_demo: 47 | img_path_list = glob.glob(os.path.join(opt.img_path, '*.jpg')) + glob.glob(os.path.join(opt.img_path, '*.png')) 48 | for img_path in img_path_list: 49 | img_name = os.path.basename(img_path) 50 | if img_name.find('output.jpg') != -1: 51 | continue 52 | img_name = img_name[:img_name.find('.')] 53 | img = cv.imread(img_path) 54 | params = model.run_model(img) 55 | img_overlap = model.render(params, bg_img=img) 56 | cv.imwrite(os.path.join(opt.save_path, img_name + '_output.jpg'), img_overlap) 57 | else: 58 | video_reader = cv.VideoCapture(0) 59 | fourcc = cv.VideoWriter_fourcc('M', 'J', 'P', 'G') 60 | video_reader.set(cv.CAP_PROP_FOURCC, fourcc) 61 | 62 | smooth = False 63 | params_last = None 64 | params_last_v = None 65 | params_v = None 66 | params_a = None 67 | 68 | fIdx = 0 69 | with torch.no_grad(): 70 | while True: 71 | fIdx = fIdx + 1 72 | _, img = video_reader.read() 73 | if img is None: 74 | exit() 75 | w = min(img.shape[1], img.shape[0]) / 2 * 0.6 76 | left = int(img.shape[1] / 2 - w) 77 | top = int(img.shape[0] / 2 - w) 78 | size = int(2 * w) 79 | bbox = [left, left + size, top, top + size] 80 | bbox = np.array(bbox).astype(np.int32) 81 | crop_img = img[bbox[2]:bbox[3], bbox[0]:bbox[1]] 82 | 83 | params = model.run_model(crop_img) 84 | if smooth and params_last is not None and params_v is not None and params_a is not None: 85 | for k in params.keys(): 86 | if isinstance(params[k], torch.Tensor): 87 | pred = params_last[k] + params_v[k] + 0.5 * params_a[k] 88 | params[k] = (0.7 * params[k] + 0.3 * pred) 89 | 90 | img_out = model.render(params, bg_img=crop_img) 91 | img[bbox[2]:bbox[3], bbox[0]:bbox[1]] = cv.resize(img_out, (size, size)) 92 | cv.line(img, (int(bbox[0]), 
int(bbox[2])), (int(bbox[0]), int(bbox[3])), (0, 0, 255), 2) 93 | cv.line(img, (int(bbox[1]), int(bbox[2])), (int(bbox[1]), int(bbox[3])), (0, 0, 255), 2) 94 | cv.line(img, (int(bbox[0]), int(bbox[2])), (int(bbox[1]), int(bbox[2])), (0, 0, 255), 2) 95 | cv.line(img, (int(bbox[0]), int(bbox[3])), (int(bbox[1]), int(bbox[3])), (0, 0, 255), 2) 96 | cv.imshow('cap', img) 97 | 98 | if params_last is not None: 99 | params_v = {} 100 | for k in params.keys(): 101 | if isinstance(params[k], torch.Tensor): 102 | params_v[k] = (params[k] - params_last[k]) 103 | if params_last_v is not None and params_v is not None: 104 | params_a = {} 105 | for k in params.keys(): 106 | if isinstance(params[k], torch.Tensor): 107 | params_a[k] = (params_v[k] - params_last_v[k]) 108 | params_last = params 109 | params_last_v = params_v 110 | 111 | key = cv.waitKey(1) 112 | 113 | if key == 27: 114 | exit() 115 | -------------------------------------------------------------------------------- /apps/eval_interhand.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 as cv 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import DataLoader 9 | 10 | import sys 11 | import os 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 13 | 14 | from models.model import load_model 15 | from models.manolayer import ManoLayer 16 | from utils.config import load_cfg 17 | from utils.vis_utils import mano_two_hands_renderer 18 | from utils.utils import get_mano_path 19 | from dataset.dataset_utils import IMG_SIZE, cut_img 20 | from dataset.interhand import fix_shape, InterHand_dataset 21 | 22 | 23 | class Jr(): 24 | def __init__(self, J_regressor, 25 | device='cuda'): 26 | self.device = device 27 | self.process_J_regressor(J_regressor) 28 | 29 | def process_J_regressor(self, J_regressor): 30 | J_regressor = J_regressor.clone().detach() 31 | tip_regressor = torch.zeros_like(J_regressor[:5]) 32 | tip_regressor[0, 745] = 1.0 33 | tip_regressor[1, 317] = 1.0 34 | tip_regressor[2, 444] = 1.0 35 | tip_regressor[3, 556] = 1.0 36 | tip_regressor[4, 673] = 1.0 37 | J_regressor = torch.cat([J_regressor, tip_regressor], dim=0) 38 | new_order = [0, 13, 14, 15, 16, 39 | 1, 2, 3, 17, 40 | 4, 5, 6, 18, 41 | 10, 11, 12, 19, 42 | 7, 8, 9, 20] 43 | self.J_regressor = J_regressor[new_order].contiguous().to(self.device) 44 | 45 | def __call__(self, v): 46 | return torch.matmul(self.J_regressor, v) 47 | 48 | 49 | class handDataset(Dataset): 50 | def __init__(self, dataset): 51 | self.dataset = dataset 52 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225]) 54 | 55 | def __len__(self): 56 | return len(self.dataset) 57 | 58 | def __getitem__(self, idx): 59 | img, mask, dense, hand_dict = self.dataset[idx] 60 | img = cv.resize(img, (IMG_SIZE, IMG_SIZE)) 61 | imgTensor = torch.tensor(cv.cvtColor(img, cv.COLOR_BGR2RGB), dtype=torch.float32) / 255 62 | imgTensor = imgTensor.permute(2, 0, 1) 63 | imgTensor = self.normalize_img(imgTensor) 64 | 65 | maskTensor = torch.tensor(mask, dtype=torch.float32) / 255 66 | 67 | joints_left_gt = torch.from_numpy(hand_dict['left']['joints3d']).float() 68 | verts_left_gt = torch.from_numpy(hand_dict['left']['verts3d']).float() 69 | joints_right_gt = torch.from_numpy(hand_dict['right']['joints3d']).float() 70 | verts_right_gt = 
torch.from_numpy(hand_dict['right']['verts3d']).float() 71 | 72 | return imgTensor, maskTensor, joints_left_gt, verts_left_gt, joints_right_gt, verts_right_gt 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--cfg", type=str, default='misc/model/config.yaml') 78 | parser.add_argument("--model", type=str, default='misc/model/interhand.pth') 79 | parser.add_argument("--data_path", type=str) 80 | parser.add_argument("--bs", type=int, default=32) 81 | opt = parser.parse_args() 82 | 83 | opt.map = False 84 | 85 | network = load_model(opt.cfg) 86 | 87 | state = torch.load(opt.model, map_location='cpu') 88 | try: 89 | network.load_state_dict(state) 90 | except: 91 | state2 = {} 92 | for k, v in state.items(): 93 | state2[k[7:]] = v 94 | network.load_state_dict(state2) 95 | 96 | network.eval() 97 | network.cuda() 98 | 99 | mano_path = get_mano_path() 100 | mano_layer = {'left': ManoLayer(mano_path['left'], center_idx=None), 101 | 'right': ManoLayer(mano_path['right'], center_idx=None)} 102 | fix_shape(mano_layer) 103 | J_regressor = {'left': Jr(mano_layer['left'].J_regressor), 104 | 'right': Jr(mano_layer['right'].J_regressor)} 105 | 106 | faces_left = mano_layer['left'].get_faces() 107 | faces_right = mano_layer['right'].get_faces() 108 | 109 | dataset = handDataset(InterHand_dataset(opt.data_path, split='test')) 110 | dataloader = DataLoader(dataset, batch_size=opt.bs, shuffle=False, 111 | num_workers=4, drop_last=False, pin_memory=True) 112 | 113 | joints_loss = {'left': [], 'right': []} 114 | verts_loss = {'left': [], 'right': []} 115 | 116 | with torch.no_grad(): 117 | for data in tqdm(dataloader): 118 | 119 | imgTensors = data[0].cuda() 120 | joints_left_gt = data[2].cuda() 121 | verts_left_gt = data[3].cuda() 122 | joints_right_gt = data[4].cuda() 123 | verts_right_gt = data[5].cuda() 124 | 125 | joints_left_gt = J_regressor['left'](verts_left_gt) 126 | joints_right_gt = J_regressor['right'](verts_right_gt) 127 | 128 | root_left_gt = joints_left_gt[:, 9:10] 129 | root_right_gt = joints_right_gt[:, 9:10] 130 | length_left_gt = torch.linalg.norm(joints_left_gt[:, 9] - joints_left_gt[:, 0], dim=-1) 131 | length_right_gt = torch.linalg.norm(joints_right_gt[:, 9] - joints_right_gt[:, 0], dim=-1) 132 | joints_left_gt = joints_left_gt - root_left_gt 133 | verts_left_gt = verts_left_gt - root_left_gt 134 | joints_right_gt = joints_right_gt - root_right_gt 135 | verts_right_gt = verts_right_gt - root_right_gt 136 | 137 | result, paramsDict, handDictList, otherInfo = network(imgTensors) 138 | 139 | verts_left_pred = result['verts3d']['left'] 140 | verts_right_pred = result['verts3d']['right'] 141 | joints_left_pred = J_regressor['left'](verts_left_pred) 142 | joints_right_pred = J_regressor['right'](verts_right_pred) 143 | 144 | root_left_pred = joints_left_pred[:, 9:10] 145 | root_right_pred = joints_right_pred[:, 9:10] 146 | length_left_pred = torch.linalg.norm(joints_left_pred[:, 9] - joints_left_pred[:, 0], dim=-1) 147 | length_right_pred = torch.linalg.norm(joints_right_pred[:, 9] - joints_right_pred[:, 0], dim=-1) 148 | scale_left = (length_left_gt / length_left_pred).unsqueeze(-1).unsqueeze(-1) 149 | scale_right = (length_right_gt / length_right_pred).unsqueeze(-1).unsqueeze(-1) 150 | 151 | joints_left_pred = (joints_left_pred - root_left_pred) * scale_left 152 | verts_left_pred = (verts_left_pred - root_left_pred) * scale_left 153 | joints_right_pred = (joints_right_pred - root_right_pred) * scale_right 154 | verts_right_pred = 
(verts_right_pred - root_right_pred) * scale_right 155 | 156 | joint_left_loss = torch.linalg.norm((joints_left_pred - joints_left_gt), ord=2, dim=-1) 157 | joint_left_loss = joint_left_loss.detach().cpu().numpy() 158 | joints_loss['left'].append(joint_left_loss) 159 | 160 | joint_right_loss = torch.linalg.norm((joints_right_pred - joints_right_gt), ord=2, dim=-1) 161 | joint_right_loss = joint_right_loss.detach().cpu().numpy() 162 | joints_loss['right'].append(joint_right_loss) 163 | 164 | vert_left_loss = torch.linalg.norm((verts_left_pred - verts_left_gt), ord=2, dim=-1) 165 | vert_left_loss = vert_left_loss.detach().cpu().numpy() 166 | verts_loss['left'].append(vert_left_loss) 167 | 168 | vert_right_loss = torch.linalg.norm((verts_right_pred - verts_right_gt), ord=2, dim=-1) 169 | vert_right_loss = vert_right_loss.detach().cpu().numpy() 170 | verts_loss['right'].append(vert_right_loss) 171 | 172 | joints_loss['left'] = np.concatenate(joints_loss['left'], axis=0) 173 | joints_loss['right'] = np.concatenate(joints_loss['right'], axis=0) 174 | verts_loss['left'] = np.concatenate(verts_loss['left'], axis=0) 175 | verts_loss['right'] = np.concatenate(verts_loss['right'], axis=0) 176 | 177 | joints_mean_loss_left = joints_loss['left'].mean() * 1000 178 | joints_mean_loss_right = joints_loss['right'].mean() * 1000 179 | verts_mean_loss_left = verts_loss['left'].mean() * 1000 180 | verts_mean_loss_right = verts_loss['right'].mean() * 1000 181 | 182 | print('joint mean error:') 183 | print(' left: {} mm, right: {} mm'.format(joints_mean_loss_left, joints_mean_loss_right)) 184 | print(' all: {} mm'.format((joints_mean_loss_left + joints_mean_loss_right) / 2)) 185 | print('vert mean error:') 186 | print(' left: {} mm, right: {} mm'.format(verts_mean_loss_left, verts_mean_loss_right)) 187 | print(' all: {} mm'.format((verts_mean_loss_left + verts_mean_loss_right) / 2)) 188 | -------------------------------------------------------------------------------- /apps/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | import torch 6 | import json 7 | import cv2 as cv 8 | import numpy as np 9 | from tqdm import tqdm 10 | import pickle 11 | import argparse 12 | import yacs 13 | import random 14 | import torch.distributed as dist 15 | import torch.multiprocessing as mp 16 | 17 | 18 | from models.manolayer import ManoLayer 19 | 20 | from utils.config import load_cfg 21 | from core.gcn_trainer import train_gcn 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("cfg", type=str) 27 | parser.add_argument('--gpu', type=str, default='0') 28 | opt = parser.parse_args() 29 | 30 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu) 31 | print("Work on GPU: ", os.environ['CUDA_VISIBLE_DEVICES']) 32 | 33 | gpu_list = opt.gpu.split(',') 34 | num_gpus = len(gpu_list) 35 | dist_training = (num_gpus > 1) 36 | 37 | cfg = load_cfg(opt.cfg) 38 | 39 | if not os.path.isdir(cfg.SAVE.SAVE_DIR): 40 | os.makedirs(cfg.SAVE.SAVE_DIR, exist_ok=True) 41 | if not os.path.isdir(cfg.TB.SAVE_DIR): 42 | os.makedirs(cfg.TB.SAVE_DIR, exist_ok=True) 43 | with open(os.path.join(cfg.SAVE.SAVE_DIR, 'config.yaml'), 'w') as file: 44 | file.write(cfg.dump()) 45 | 46 | if not dist_training: 47 | train_gcn(cfg=cfg) 48 | else: 49 | mp.spawn(train_gcn, 50 | args=(num_gpus, cfg, True), 51 | nprocs=num_gpus, 52 | join=True) 53 | 
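A note on the multi-GPU branch of `apps/train.py` above: it hands `core.gcn_trainer.train_gcn` to `torch.multiprocessing.spawn`, which calls the worker once per GPU with the process rank prepended to `args`; the worker itself then joins the NCCL process group (see `train_gcn` in `core/gcn_trainer.py`). Below is a minimal, self-contained sketch of that same spawn/init pattern; the `demo_worker` function and the fixed port are illustrative only and are not part of this repository.

```
# Minimal sketch of the spawn/init pattern used by apps/train.py (illustrative only).
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo_worker(rank, world_size):
    # mp.spawn passes the rank as the first positional argument, as train_gcn expects.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'  # illustrative port; train_gcn uses cfg.TRAIN.DIST_PORT
    torch.cuda.set_device(rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...
    dist.destroy_process_group()


if __name__ == '__main__':
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        mp.spawn(demo_worker, args=(num_gpus,), nprocs=num_gpus, join=True)
    else:
        demo_worker(0, 1)  # single-process fallback, mirroring the non-distributed branch
```

Launching with `--gpu 0,1,2,3` as in the README only sets `CUDA_VISIBLE_DEVICES`, so ranks 0..3 of the spawned workers map onto those four devices.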
-------------------------------------------------------------------------------- /core/Loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | import numpy as np 6 | import pickle 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.utils import get_upsample_path 12 | 13 | MANO_PARENT = [-1, 0, 1, 2, 3, 14 | 0, 5, 6, 7, 15 | 0, 9, 10, 11, 16 | 0, 13, 14, 15, 17 | 0, 17, 18, 19] 18 | 19 | 20 | class GraphLoss(): 21 | def __init__(self, J_regressor, faces, level=4, 22 | device='cuda'): 23 | # loss function 24 | self.L1Loss = nn.L1Loss() 25 | self.L2Loss = nn.MSELoss() 26 | self.smoothL1Loss = nn.SmoothL1Loss(beta=0.05) 27 | 28 | self.device = device 29 | 30 | self.level = level + 1 31 | self.process_J_regressor(J_regressor) 32 | self.faces = torch.from_numpy(faces.astype(np.int64)).to(self.device) 33 | 34 | with open(get_upsample_path(), 'rb') as file: 35 | upsample_weight = pickle.load(file) 36 | self.upsample_weight = torch.from_numpy(upsample_weight).to(self.device) 37 | 38 | def process_J_regressor(self, J_regressor): 39 | J_regressor = J_regressor.clone().detach() 40 | tip_regressor = torch.zeros_like(J_regressor[:5]) 41 | tip_regressor[0, 745] = 1.0 42 | tip_regressor[1, 317] = 1.0 43 | tip_regressor[2, 444] = 1.0 44 | tip_regressor[3, 556] = 1.0 45 | tip_regressor[4, 673] = 1.0 46 | J_regressor = torch.cat([J_regressor, tip_regressor], dim=0) 47 | new_order = [0, 48 | 13, 14, 15, 16, 49 | 1, 2, 3, 17, 50 | 4, 5, 6, 18, 51 | 10, 11, 12, 19, 52 | 7, 8, 9, 20] 53 | self.J_regressor = J_regressor[new_order].contiguous().to(self.device) 54 | 55 | def mesh_downsample(self, feat, p=2): 56 | # feat: bs x N x f 57 | feat = feat.permute(0, 2, 1).contiguous() # x = bs x f x N 58 | feat = nn.AvgPool1d(p)(feat) # bs x f x N/p 59 | feat = feat.permute(0, 2, 1).contiguous() # x = bs x N/p x f 60 | return feat 61 | 62 | def mesh_upsample(self, x, p=2): 63 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 64 | x = nn.Upsample(scale_factor=p)(x) # B x F x (V*p) 65 | x = x.permute(0, 2, 1).contiguous() # x = B x (V*p) x F 66 | return x 67 | 68 | def norm_loss(self, verts_pred, verts_gt): 69 | edge_gt = verts_gt[:, self.faces] 70 | edge_gt = torch.stack([edge_gt[:, :, 0] - edge_gt[:, :, 1], 71 | edge_gt[:, :, 1] - edge_gt[:, :, 2], 72 | edge_gt[:, :, 2] - edge_gt[:, :, 0], 73 | ], dim=2) # B x F x 3 x 3 74 | edge_pred = verts_pred[:, self.faces] 75 | edge_pred = torch.stack([edge_pred[:, :, 0] - edge_pred[:, :, 1], 76 | edge_pred[:, :, 1] - edge_pred[:, :, 2], 77 | edge_pred[:, :, 2] - edge_pred[:, :, 0], 78 | ], dim=2) # B x F x 3 x 3 79 | 80 | # norm loss 81 | face_norm_gt = torch.cross(edge_gt[:, :, 0], edge_gt[:, :, 1], dim=-1) 82 | face_norm_gt = F.normalize(face_norm_gt, dim=-1) 83 | face_norm_gt = face_norm_gt.unsqueeze(2) # B x F x 1 x 3 84 | edge_pred_normed = F.normalize(edge_pred, dim=-1) 85 | temp = torch.sum(edge_pred_normed * face_norm_gt, dim=-1) # B x F x 3 86 | return self.L1Loss(temp, torch.zeros_like(temp)) 87 | 88 | def edge_loss(self, verts_pred, verts_gt): 89 | edge_gt = verts_gt[:, self.faces] 90 | edge_gt = torch.stack([edge_gt[:, :, 0] - edge_gt[:, :, 1], 91 | edge_gt[:, :, 1] - edge_gt[:, :, 2], 92 | edge_gt[:, :, 2] - edge_gt[:, :, 0], 93 | ], dim=2) # B x F x 3 x 3 94 | edge_pred = verts_pred[:, self.faces] 95 | edge_pred = torch.stack([edge_pred[:, :, 0] - edge_pred[:, 
:, 1], 96 | edge_pred[:, :, 1] - edge_pred[:, :, 2], 97 | edge_pred[:, :, 2] - edge_pred[:, :, 0], 98 | ], dim=2) # B x F x 3 x 3 99 | edge_length_gt = torch.linalg.norm(edge_gt, dim=-1) # B x F x 3 100 | edge_length_pred = torch.linalg.norm(edge_pred, dim=-1) # B x F x 3 101 | edge_length_loss = self.L1Loss(edge_length_pred, edge_length_gt) 102 | return edge_length_loss 103 | 104 | def calc_mano_loss(self, v3d_pred, v2d_pred, v3d_gt, v2d_gt, img_size): 105 | J_r_pred = torch.matmul(self.J_regressor, v3d_pred) 106 | J_r_gt = torch.matmul(self.J_regressor, v3d_gt) 107 | 108 | loss_dict = {} 109 | loss_dict['vert2d_loss'] = self.L2Loss((v2d_pred / img_size * 2 - 1), 110 | (v2d_gt / img_size * 2 - 1)) 111 | loss_dict['vert3d_loss'] = self.L1Loss(v3d_pred, v3d_gt) 112 | loss_dict['joint_loss'] = self.L1Loss(J_r_pred, J_r_gt) 113 | loss_dict['norm_loss'] = self.norm_loss(v3d_pred, v3d_gt) 114 | loss_dict['edge_loss'] = self.edge_loss(v3d_pred, v3d_gt) 115 | return loss_dict 116 | 117 | def upsample_weight_loss(self, w): 118 | x = w - self.upsample_weight 119 | return self.L1Loss(x, torch.zeros_like(x)) 120 | 121 | def rel_loss(self, v1, v2, v1_gt, v2_gt): 122 | rel_gt = v1.unsqueeze(1) - v2.unsqueeze(2) 123 | rel_gt = torch.linalg.norm(rel_gt, dim=-1) # bs x V x V 124 | rel_pred = v1_gt.unsqueeze(1) - v2_gt.unsqueeze(2) 125 | rel_pred = torch.linalg.norm(rel_pred, dim=-1) # bs x 21 x 21 126 | return self.L1Loss(rel_gt, rel_pred) 127 | 128 | def calc_loss(self, converter, 129 | v3d_gt, v2d_gt, 130 | v3d_pred, v2d_pred, 131 | v3dList, v2dList, 132 | img_size): 133 | assert self.faces.device == v3d_gt.device 134 | assert self.faces.device == v3d_pred.device 135 | mano_loss_dict = self.calc_mano_loss(v3d_pred, v2d_pred, v3d_gt, v2d_gt, img_size) 136 | 137 | v3dList_gt = [] 138 | v2dList_gt = [] 139 | v3d_gcn = converter.vert_to_GCN(v3d_gt) 140 | v2d_gcn = converter.vert_to_GCN(v2d_gt) 141 | 142 | for i in range(self.level): 143 | v3dList_gt.append(v3d_gcn) 144 | v2dList_gt.append(v2d_gcn) 145 | v3d_gcn = self.mesh_downsample(v3d_gcn) 146 | v2d_gcn = self.mesh_downsample(v2d_gcn) 147 | 148 | v3dList_gt.reverse() 149 | v2dList_gt.reverse() 150 | 151 | coarsen_loss_dict = {} 152 | coarsen_loss_dict['v3d_loss'] = [] 153 | coarsen_loss_dict['v2d_loss'] = [] 154 | for i in range(len(v2dList)): 155 | for j in range(len(v3dList_gt)): 156 | if v3dList[i].shape[1] == v3dList_gt[j].shape[1]: 157 | break 158 | 159 | coarsen_loss_dict['v3d_loss'].append(self.L1Loss(v3dList[i], 160 | v3dList_gt[j])) 161 | coarsen_loss_dict['v2d_loss'].append(self.L2Loss((v2dList[i] / img_size * 2 - 1), 162 | (v2dList_gt[j] / img_size * 2 - 1))) 163 | 164 | return mano_loss_dict, coarsen_loss_dict 165 | 166 | def range_loss(self, label, Min, Max): 167 | l1 = self._zero_norm_loss(torch.clamp(Min - label, min=0.)) 168 | l2 = self._zero_norm_loss(torch.clamp(label - Max, min=0.)) 169 | return l1 + l2 170 | 171 | def _one_norm_loss(self, p): 172 | return self.L1Loss(p, torch.ones_like(p)) 173 | 174 | def _zero_norm_loss(self, p): 175 | return self.L1Loss(p, torch.zeros_like(p)) 176 | 177 | 178 | 179 | 180 | def calc_aux_loss(cfg, hand_loss, 181 | dataDict, 182 | mask, dense, hms): 183 | loss_dict = {} 184 | total_loss = 0 185 | if 'mask' in dataDict: 186 | loss_dict['mask_loss'] = hand_loss.smoothL1Loss(dataDict['mask'], mask) 187 | total_loss = total_loss + loss_dict['mask_loss'] * cfg.LOSS_WEIGHT.AUX.MASK 188 | if 'dense' in dataDict: 189 | loss_l = hand_loss.smoothL1Loss(dataDict['dense'][:, :3] * mask[:, :1], dense * mask[:, :1]) 
190 | loss_r = hand_loss.smoothL1Loss(dataDict['dense'][:, 3:] * mask[:, 1:], dense * mask[:, 1:]) 191 | loss_dict['dense_loss'] = (loss_l + loss_r) / 2 192 | total_loss = total_loss + loss_dict['dense_loss'] * cfg.LOSS_WEIGHT.AUX.DENSEPOSE 193 | if 'hms' in dataDict: 194 | loss_dict['hms_loss'] = hand_loss.L2Loss(dataDict['hms'], hms) 195 | total_loss = total_loss + loss_dict['hms_loss'] * cfg.LOSS_WEIGHT.AUX.HMS 196 | if total_loss > 0: 197 | loss_dict['total_loss'] = total_loss 198 | return loss_dict 199 | 200 | 201 | def calc_loss_GCN(cfg, epoch, 202 | graph_loss_left, graph_loss_right, 203 | converter_left, converter_right, 204 | result, paramsDict, handDictList, otherInfo, 205 | mask, dense, hms, 206 | v2d_l, j2d_l, v2d_r, j2d_r, 207 | v3d_l, j3d_l, v3d_r, j3d_r, 208 | root_rel, img_size, 209 | upsample_weight=None): 210 | 211 | aux_lost_dict = calc_aux_loss(cfg, graph_loss_left, 212 | otherInfo, 213 | mask, dense, hms) 214 | 215 | v3d_r = v3d_r + root_rel.unsqueeze(1) 216 | j3d_r = j3d_r + root_rel.unsqueeze(1) 217 | 218 | v2dList = [] 219 | v3dList = [] 220 | for i in range(len(handDictList)): 221 | v2dList.append(handDictList[i]['verts2d']['left']) 222 | v3dList.append(handDictList[i]['verts3d']['left']) 223 | mano_loss_dict_left, coarsen_loss_dict_left \ 224 | = graph_loss_left.calc_loss(converter_left, 225 | v3d_l, v2d_l, 226 | result['verts3d']['left'], result['verts2d']['left'], 227 | v3dList, v2dList, 228 | img_size) 229 | 230 | v2dList = [] 231 | v3dList = [] 232 | for i in range(len(handDictList)): 233 | v2dList.append(handDictList[i]['verts2d']['right']) 234 | v3dList.append(handDictList[i]['verts3d']['right']) 235 | mano_loss_dict_right, coarsen_loss_dict_right \ 236 | = graph_loss_right.calc_loss(converter_right, 237 | v3d_r, v2d_r, 238 | result['verts3d']['right'], result['verts2d']['right'], 239 | v3dList, v2dList, 240 | img_size) 241 | 242 | mano_loss_dict = {} 243 | for k in mano_loss_dict_left.keys(): 244 | mano_loss_dict[k] = (mano_loss_dict_left[k] + mano_loss_dict_right[k]) / 2 245 | 246 | coarsen_loss_dict = {} 247 | for k in coarsen_loss_dict_left.keys(): 248 | coarsen_loss_dict[k] = [] 249 | for i in range(len(coarsen_loss_dict_left[k])): 250 | coarsen_loss_dict[k].append((coarsen_loss_dict_left[k][i] + coarsen_loss_dict_right[k][i]) / 2) 251 | 252 | cfg = cfg.LOSS_WEIGHT 253 | alpha = 0 if epoch < cfg.GRAPH.NORM.NORM_EPOCH else 1 254 | 255 | if upsample_weight is not None: 256 | mano_loss_dict['upsample_norm_loss'] = graph_loss_left.upsample_weight_loss(upsample_weight) 257 | else: 258 | mano_loss_dict['upsample_norm_loss'] = torch.zeros_like(mano_loss_dict['vert3d_loss']) 259 | 260 | mano_loss = 0 \ 261 | + cfg.DATA.LABEL_3D * mano_loss_dict['vert3d_loss'] \ 262 | + cfg.DATA.LABEL_2D * mano_loss_dict['vert2d_loss'] \ 263 | + cfg.DATA.LABEL_3D * mano_loss_dict['joint_loss'] \ 264 | + cfg.GRAPH.NORM.NORMAL * mano_loss_dict['norm_loss'] \ 265 | + alpha * cfg.GRAPH.NORM.EDGE * mano_loss_dict['edge_loss'] 266 | 267 | coarsen_loss = 0 268 | for i in range(len(coarsen_loss_dict['v3d_loss'])): 269 | coarsen_loss = coarsen_loss \ 270 | + cfg.DATA.LABEL_3D * coarsen_loss_dict['v3d_loss'][i] \ 271 | + cfg.DATA.LABEL_2D * coarsen_loss_dict['v2d_loss'][i] 272 | 273 | 274 | total_loss = mano_loss + coarsen_loss + cfg.NORM.UPSAMPLE * mano_loss_dict['upsample_norm_loss'] 275 | 276 | if 'total_loss' in aux_lost_dict: 277 | total_loss = total_loss + aux_lost_dict['total_loss'] 278 | 279 | return total_loss, aux_lost_dict, mano_loss_dict, coarsen_loss_dict 280 | 
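For intuition, `GraphLoss.norm_loss` above penalizes the component of each predicted triangle edge along the ground-truth face normal (the edges of a well-predicted face should be perpendicular to that normal), while `edge_loss` matches per-edge lengths between prediction and ground truth. The toy snippet below reproduces those two terms for a single triangle; the tensors and the `face_edges` helper are invented for this sketch and are not part of the repository's API.

```
# Toy, self-contained illustration of the GraphLoss norm/edge terms for one triangle.
import torch
import torch.nn.functional as F


def face_edges(verts, faces):
    tri = verts[faces]                                   # F x 3 x 3 (three vertices per face)
    return torch.stack([tri[:, 0] - tri[:, 1],
                        tri[:, 1] - tri[:, 2],
                        tri[:, 2] - tri[:, 0]], dim=1)   # F x 3 x 3 (three edge vectors per face)


verts_gt = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
verts_pred = verts_gt + 0.01 * torch.randn(3, 3)         # slightly perturbed prediction
faces = torch.tensor([[0, 1, 2]])

edges_gt = face_edges(verts_gt, faces)
edges_pred = face_edges(verts_pred, faces)
normal_gt = F.normalize(torch.cross(edges_gt[:, 0], edges_gt[:, 1], dim=-1), dim=-1)

# norm term: predicted edge directions should be orthogonal to the GT face normal
norm_term = (F.normalize(edges_pred, dim=-1) * normal_gt.unsqueeze(1)).sum(-1).abs().mean()
# edge term: predicted edge lengths should match GT edge lengths
edge_term = (edges_pred.norm(dim=-1) - edges_gt.norm(dim=-1)).abs().mean()
print(norm_term.item(), edge_term.item())
```

In `calc_loss_GCN`, the edge term is switched on only after `LOSS_WEIGHT.GRAPH.NORM.NORM_EPOCH` (via `alpha`), while the normal term is weighted from the first epoch.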
-------------------------------------------------------------------------------- /core/gcn_trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tkinter.messagebox import NO 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | 6 | import torch 7 | import json 8 | import cv2 as cv 9 | import numpy as np 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | import pickle 13 | import random 14 | from torch.utils.data import DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | import torch.distributed as dist 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from torch.distributed.optim import ZeroRedundancyOptimizer 19 | 20 | 21 | from models.model import load_model 22 | 23 | from utils.tb_utils import tbUtils 24 | from utils.lr_sc import StepLR_withWarmUp 25 | from utils.DataProvider import DataProvider 26 | from utils.vis_utils import mano_two_hands_renderer 27 | from utils.utils import get_mano_path 28 | 29 | from core.loader import handDataset 30 | from core.Loss import GraphLoss, calc_loss_GCN 31 | from core.vis_train import tb_vis_train_gcn 32 | from dataset.dataset_utils import IMG_SIZE, BLUR_KERNEL 33 | from dataset.inference import get_final_preds2 34 | 35 | 36 | def freeze_model(model): 37 | for (name, params) in model.named_parameters(): 38 | params.requires_grad = False 39 | 40 | 41 | def train_gcn(rank=0, world_size=1, cfg=None, dist_training=False): 42 | if dist_training: 43 | os.environ['MASTER_ADDR'] = 'localhost' 44 | os.environ['MASTER_PORT'] = str(cfg.TRAIN.DIST_PORT) 45 | print("Init distributed training on local rank {}".format(rank)) 46 | torch.cuda.set_device(rank) 47 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 48 | 49 | torch.manual_seed(cfg.SEED) 50 | torch.cuda.manual_seed(cfg.SEED) 51 | random.seed(cfg.SEED) 52 | np.random.seed(cfg.SEED) 53 | 54 | mano_path = get_mano_path() 55 | 56 | # ------------------------------------------------- 57 | # | 1. 
load model/optimizer/scheduler/tensorboard | 58 | # ------------------------------------------------- 59 | # load network 60 | network = load_model(cfg) 61 | network.to(rank) 62 | 63 | if cfg.MODEL.freeze_upsample: 64 | freeze_model(network.decoder.unsample_layer) 65 | 66 | converter = {} 67 | for hand_type in ['left', 'right']: 68 | converter[hand_type] = network.decoder.converter[hand_type] 69 | 70 | if dist_training: 71 | network = DDP( 72 | network, device_ids=[rank], 73 | output_device=rank, 74 | find_unused_parameters=True, 75 | ) 76 | # print('local rank {}: init model, done'.format(rank)) 77 | 78 | # load optimizer 79 | optim_params = list(filter(lambda p: p.requires_grad, network.parameters())) 80 | if cfg.TRAIN.OPTIM == 'adam': 81 | if dist_training: 82 | optimizer = ZeroRedundancyOptimizer( 83 | optim_params, 84 | optimizer_class=torch.optim.Adam, 85 | lr=cfg.TRAIN.LR 86 | ) 87 | else: 88 | optimizer = torch.optim.Adam(optim_params, lr=cfg.TRAIN.LR) 89 | elif cfg.TRAIN.OPTIM == 'rms': 90 | if dist_training: 91 | optimizer = ZeroRedundancyOptimizer( 92 | optim_params, 93 | optimizer_class=torch.optim.RMSprop, 94 | lr=cfg.TRAIN.LR 95 | ) 96 | else: 97 | optimizer = torch.optim.RMSprop(optim_params, lr=cfg.TRAIN.LR) 98 | else: 99 | raise ValueError('wrong optimizer type') 100 | # print('local rank {}: init optimizer, done'.format(rank)) 101 | 102 | # load learning rate scheduler 103 | lr_scheduler = StepLR_withWarmUp(optimizer, 104 | last_epoch=-1 if cfg.TRAIN.current_epoch == 0 else cfg.TRAIN.current_epoch, 105 | init_lr=1e-3 * cfg.TRAIN.LR, 106 | warm_up_epoch=cfg.TRAIN.warm_up, 107 | gamma=cfg.TRAIN.lr_decay_gamma, 108 | step_size=cfg.TRAIN.lr_decay_step, 109 | min_thres=0.05) 110 | # print('local rank {}: init lr_scheduler, done'.format(rank)) 111 | 112 | if rank == 0: 113 | # tensorboard 114 | writer = SummaryWriter(cfg.TB.SAVE_DIR) 115 | renderer = mano_two_hands_renderer(img_size=IMG_SIZE, device='cuda:{}'.format(rank)) 116 | 117 | # -------------------------- 118 | # | 2. 
load dataset & Loss | 119 | # -------------------------- 120 | aux_lambda = 2**(6 - len(cfg.MODEL.DECONV_DIMS)) 121 | trainDataset = handDataset(mano_path=mano_path, 122 | interPath=cfg.DATASET.INTERHAND_PATH, 123 | theta=[-cfg.DATA_AUGMENT.THETA, cfg.DATA_AUGMENT.THETA], 124 | scale=[1 - cfg.DATA_AUGMENT.SCALE, 1 + cfg.DATA_AUGMENT.SCALE], 125 | uv=[-cfg.DATA_AUGMENT.UV, cfg.DATA_AUGMENT.UV], 126 | aux_size=IMG_SIZE // aux_lambda) 127 | # print('local rank {}: init dataset, done'.format(rank)) 128 | 129 | provider_train = DataProvider(dataset=trainDataset, batch_size=cfg.TRAIN.BATCH_SIZE, 130 | num_workers=4, dist=dist_training) 131 | train_batch_per_epoch = provider_train.batch_per_epoch 132 | # print('local rank {}: init data loader, done'.format(rank)) 133 | 134 | Loss = {} 135 | faces = {} 136 | for hand_type in ['left', 'right']: 137 | with open(mano_path[hand_type], 'rb') as file: 138 | manoData = pickle.load(file, encoding='latin1') 139 | J_regressor = manoData['J_regressor'].tocoo(copy=False) 140 | location = [] 141 | data = [] 142 | for i in range(J_regressor.data.shape[0]): 143 | location.append([J_regressor.row[i], J_regressor.col[i]]) 144 | data.append(J_regressor.data[i]) 145 | i = torch.LongTensor(location) 146 | v = torch.FloatTensor(data) 147 | J_regressor = torch.sparse.FloatTensor(i.t(), v, torch.Size([16, 778])).to_dense() 148 | Loss[hand_type] = GraphLoss(J_regressor, manoData['f'], 149 | level=4, 150 | device=rank) 151 | # device='cuda:{}'.format(rank)) 152 | faces[hand_type] = manoData['f'] 153 | 154 | # print('local rank {}: init training loss, done'.format(rank)) 155 | 156 | # ------------ 157 | # | 3. train | 158 | # ------------ 159 | # print('local rank {}: strat training'.format(rank)) 160 | for epoch in range(cfg.TRAIN.current_epoch, cfg.TRAIN.EPOCHS): 161 | network.train() 162 | train_bar = range(train_batch_per_epoch) 163 | if rank == 0: 164 | train_bar = tqdm(train_bar) 165 | for bIdx in train_bar: 166 | total_idx = epoch * train_batch_per_epoch + bIdx 167 | 168 | # ------------ 169 | # | training | 170 | # ------------ 171 | label_list = provider_train.next() 172 | label_list_out = [] 173 | for label in label_list: 174 | if label is not None: 175 | label_list_out.append(label.to(rank)) 176 | [ori_img, 177 | imgTensors, mask, dense, hms, 178 | v2d_l, j2d_l, v2d_r, j2d_r, 179 | v3d_l, j3d_l, v3d_r, j3d_r, 180 | root_rel] = label_list_out 181 | result, paramsDict, handDictList, otherInfo = network(imgTensors) 182 | 183 | if cfg.MODEL.freeze_upsample: 184 | upsample_weight = None 185 | else: 186 | if dist_training: 187 | upsample_weight = network.module.decoder.get_upsample_weight() 188 | else: 189 | upsample_weight = network.decoder.get_upsample_weight() 190 | 191 | loss, aux_lost_dict, mano_loss_dict, coarsen_loss_dict = \ 192 | calc_loss_GCN(cfg, epoch, 193 | Loss['left'], Loss['right'], 194 | converter['left'], converter['right'], 195 | result, paramsDict, handDictList, otherInfo, 196 | mask, dense, hms, 197 | v2d_l, j2d_l, v2d_r, j2d_r, 198 | v3d_l, j3d_l, v3d_r, j3d_r, 199 | root_rel, img_size=imgTensors.shape[-1], 200 | upsample_weight=upsample_weight) 201 | 202 | optimizer.zero_grad() 203 | loss.backward() 204 | optimizer.step() 205 | 206 | # --------------- 207 | # | tensorboard | 208 | # --------------- 209 | if rank == 0: 210 | writer.add_scalar('learning_rate', lr_scheduler.get_lr()[0], total_idx) 211 | writer.add_scalar('train/total_loss', loss.item(), total_idx) 212 | for k, v in mano_loss_dict.items(): 213 | if k != 'total_loss': 214 | 
writer.add_scalar('train/mano_{}'.format(k), v.item(), total_idx) 215 | for k, v in aux_lost_dict.items(): 216 | if k != 'total_loss': 217 | writer.add_scalar('train/aux_{}'.format(k), v.item(), total_idx) 218 | for k, v in coarsen_loss_dict.items(): 219 | if k != 'total_loss': 220 | for t in range(len(v)): 221 | writer.add_scalar('train/coarsen_{}_{}'.format(k, t), v[t].item(), total_idx) 222 | if (total_idx + 1) % cfg.TB.SHOW_GAP == 0: 223 | tb_vis_train_gcn(cfg, writer, total_idx, renderer, v2d_l, v2d_r, 224 | ori_img, mask, dense, 225 | result, paramsDict, handDictList, otherInfo) 226 | 227 | tbUtils.draw_MANO_joints(writer, 'hms/l_gt', total_idx, ori_img[0], j2d_l[0]) 228 | handJ2d_pred, _ = get_final_preds2(otherInfo['hms'][:, :21].detach().cpu().numpy(), BLUR_KERNEL) 229 | handJ2d_pred = torch.from_numpy(handJ2d_pred) * aux_lambda 230 | tbUtils.draw_MANO_joints(writer, 'hms/l_pred', total_idx, ori_img[0], handJ2d_pred[0]) 231 | 232 | tbUtils.draw_MANO_joints(writer, 'hms/r_gt', total_idx, ori_img[0], j2d_r[0]) 233 | handJ2d_pred, _ = get_final_preds2(otherInfo['hms'][:, 21:].detach().cpu().numpy(), BLUR_KERNEL) 234 | handJ2d_pred = torch.from_numpy(handJ2d_pred) * aux_lambda 235 | tbUtils.draw_MANO_joints(writer, 'hms/r_pred', total_idx, ori_img[0], handJ2d_pred[0]) 236 | 237 | # -------- 238 | # | tqdm | 239 | # -------- 240 | train_bar.set_description('train, epoch:{}'.format(epoch)) 241 | train_bar.set_postfix(totalLoss=loss.item()) 242 | 243 | lr_scheduler.step() 244 | if (epoch + 1) % cfg.SAVE.SAVE_GAP == 0: 245 | if rank == 0: # save checkpoint in main process 246 | torch.save(network.state_dict(), os.path.join(cfg.SAVE.SAVE_DIR, str(epoch + 1) + '.pth')) 247 | 248 | if dist_training: 249 | dist.barrier() 250 | dist.destroy_process_group() 251 | -------------------------------------------------------------------------------- /core/loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | from dataset.dataset_utils import BONE_LENGTH 6 | from dataset.interhand import InterHand_dataset 7 | from utils.utils import imgUtils, get_mano_path 8 | from models.manolayer import ManoLayer 9 | 10 | 11 | import random 12 | import torch 13 | import cv2 as cv 14 | import pickle 15 | import numpy as np 16 | from torch.utils.data import Dataset 17 | import torchvision.transforms as transforms 18 | 19 | 20 | class handDataset(Dataset): 21 | """mix different hand datasets""" 22 | 23 | def __init__(self, mano_path=None, 24 | interPath=None, 25 | theta=[-90, 90], scale=[0.75, 1.25], uv=[-10, 10], 26 | flip=True, 27 | train=True, 28 | aux_size=64, 29 | bone_length=BONE_LENGTH, 30 | noise=0.0): 31 | if mano_path is None: 32 | mano_path = get_mano_path() 33 | self.dataset = {} 34 | self.dataName = [] 35 | self.sizeList = [] 36 | self.theta = theta 37 | self.scale = scale 38 | self.uv = uv 39 | self.noise = noise 40 | self.flip = flip 41 | 42 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]) 44 | 45 | self.train = train 46 | self.aux_size = aux_size 47 | self.bone_length = bone_length 48 | 49 | if interPath is not None and os.path.exists(str(interPath)): 50 | if self.train: 51 | split = 'train' 52 | else: 53 | split = 'val' 54 | self.dataset['inter'] = InterHand_dataset(interPath, split) 55 | self.dataName.append('inter') 56 | self.sizeList.append(len(self.dataset['inter'])) 57 | print('load 
interhand2.6m dataset, size: {}'.format(len(self.dataset['inter']))) 58 | 59 | self.size = 0 60 | for s in self.sizeList: 61 | self.size += s 62 | 63 | for i in range(1, len(self.sizeList)): 64 | self.sizeList[i] += self.sizeList[i - 1] 65 | 66 | def __len__(self): 67 | return self.size 68 | 69 | def augm_params(self): 70 | theta = random.random() * (self.theta[1] - self.theta[0]) + self.theta[0] 71 | scale = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0] 72 | u = random.random() * (self.uv[1] - self.uv[0]) + self.uv[0] 73 | v = random.random() * (self.uv[1] - self.uv[0]) + self.uv[0] 74 | flip = random.random() > 0.5 if self.flip else False 75 | return theta, scale, u, v, flip 76 | 77 | def process_data(self, img, mask, dense, hand_dict): 78 | label2d_list = [hand_dict['left']['verts2d'], 79 | hand_dict['left']['joints2d'], 80 | hand_dict['right']['verts2d'], 81 | hand_dict['right']['joints2d']] 82 | label3d_list = [hand_dict['left']['verts3d'], 83 | hand_dict['left']['joints3d'], 84 | hand_dict['right']['verts3d'], 85 | hand_dict['right']['joints3d']] 86 | 87 | if self.train: 88 | # random sacle and translation 89 | hms_left = hand_dict['left']['hms'] 90 | hms_right = hand_dict['right']['hms'] 91 | 92 | theta, scale, u, v, flip = self.augm_params() 93 | imgList, label2d_list, label3d_list, _ \ 94 | = imgUtils.data_augmentation(theta, scale, u, v, 95 | img_list=[img, mask, dense] + hms_left + hms_right, 96 | label2d_list=label2d_list, 97 | label3d_list=label3d_list, 98 | img_size=img.shape[0]) 99 | img = imgList[0] 100 | mask = imgList[1] 101 | dense_map = imgList[2] 102 | hms = imgList[3:] 103 | 104 | # add img noise 105 | img = imgUtils.add_noise(img.astype(np.float32), 106 | noise=self.noise, 107 | scale=255.0, 108 | alpha=0.3, beta=0.05).astype(np.uint8) 109 | else: 110 | flip = False 111 | 112 | if flip: 113 | img = cv.flip(img, 1) 114 | mask = cv.flip(mask, 1) 115 | dense_map = cv.flip(dense_map, 1) 116 | for i in range(len(hms)): 117 | hms[i] = cv.flip(hms[i], 1) 118 | 119 | # to torch tensor 120 | dense_map = cv.resize(dense_map, (self.aux_size, self.aux_size)) 121 | dense_map = torch.tensor(dense_map, dtype=torch.float32) / 255 122 | dense_map = dense_map.permute(2, 0, 1) 123 | 124 | mask = cv.resize(mask, (self.aux_size, self.aux_size)) 125 | ret, mask = cv.threshold(mask, 127, 255, cv.THRESH_BINARY) 126 | mask = mask.astype(np.float) / 255 127 | mask = mask[..., 1:] 128 | if flip: 129 | mask = mask[..., [1, 0]] 130 | mask = torch.tensor(mask, dtype=torch.float32) 131 | mask = mask.permute(2, 0, 1) 132 | 133 | for i in range(len(hms)): 134 | hms[i] = cv.resize(hms[i], (self.aux_size, self.aux_size)) 135 | hms = np.concatenate(hms, axis=-1) 136 | if flip: 137 | idx = [i + 21 for i in range(21)] + [i for i in range(21)] 138 | hms = hms[..., idx] 139 | hms = torch.tensor(hms, dtype=torch.float32) / 255 140 | hms = hms.permute(2, 0, 1) 141 | 142 | ori_img = torch.tensor(img, dtype=torch.float32) / 255 143 | ori_img = ori_img.permute(2, 0, 1) 144 | imgTensor = torch.tensor(cv.cvtColor(img, cv.COLOR_BGR2RGB), dtype=torch.float32) / 255 145 | imgTensor = imgTensor.permute(2, 0, 1) 146 | imgTensor = self.normalize_img(imgTensor) 147 | 148 | root_left = label3d_list[1][9] 149 | root_right = label3d_list[3][9] 150 | root_rel = root_right - root_left 151 | label3d_list[0] = label3d_list[0] - root_left 152 | label3d_list[1] = label3d_list[1] - root_left 153 | label3d_list[2] = label3d_list[2] - root_right 154 | label3d_list[3] = label3d_list[3] - root_right 155 | 156 | if 
self.bone_length is not None: 157 | length = np.linalg.norm(label3d_list[1][9] - label3d_list[1][0]) \ 158 | + np.linalg.norm(label3d_list[3][9] - label3d_list[3][0]) 159 | length = length / 2 160 | scale = self.bone_length / length 161 | root_rel = root_rel * scale 162 | for i in range(4): 163 | label3d_list[i] = label3d_list[i] * scale 164 | 165 | root_rel = torch.tensor(root_rel, dtype=torch.float32) 166 | for i in range(4): 167 | label2d_list[i] = torch.tensor(label2d_list[i], dtype=torch.float32) 168 | label3d_list[i] = torch.tensor(label3d_list[i], dtype=torch.float32) 169 | 170 | if flip: 171 | root_rel[1:] = -root_rel[1:] 172 | for i in range(4): 173 | label2d_list[i][:, 0] = img.shape[0] - label2d_list[i][:, 0] 174 | label3d_list[i][:, 0] = -label3d_list[i][:, 0] 175 | 176 | [v2d_r, j2d_r, v2d_l, j2d_l] = label2d_list 177 | [v3d_r, j3d_r, v3d_l, j3d_l] = label3d_list 178 | else: 179 | [v2d_l, j2d_l, v2d_r, j2d_r] = label2d_list 180 | [v3d_l, j3d_l, v3d_r, j3d_r] = label3d_list 181 | 182 | return ori_img, \ 183 | imgTensor, mask, dense_map, hms, \ 184 | v2d_l, j2d_l, v2d_r, j2d_r,\ 185 | v3d_l, j3d_l, v3d_r, j3d_r, \ 186 | root_rel 187 | 188 | def __getitem__(self, idx): 189 | # for i in range(len(self.sizeList)): 190 | # if idx < self.sizeList[i]: 191 | # idx2 = idx - (0 if i == 0 else self.sizeList[i - 1]) 192 | # name = self.dataName[i] 193 | # if name == 'inter': 194 | # img, mask, dense, hand_dict = self.dataset[name][idx2] 195 | # break 196 | 197 | img, mask, dense, hand_dict = self.dataset['inter'][idx] 198 | 199 | return self.process_data(img, mask, dense, hand_dict) 200 | -------------------------------------------------------------------------------- /core/test_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import torch 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | import math 6 | 7 | import sys 8 | import os 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | 12 | from models.model import load_model 13 | from utils.config import load_cfg 14 | from utils.utils import imgUtils 15 | from utils.vis_utils import mano_two_hands_renderer 16 | from dataset.dataset_utils import IMG_SIZE 17 | 18 | 19 | class InterRender(): 20 | def __init__(self, 21 | cfg_path, 22 | model_path, 23 | input_size=IMG_SIZE, 24 | render_size=512): 25 | self.input_size = input_size 26 | self.render_size = render_size 27 | self.renderer = mano_two_hands_renderer(img_size=render_size, device='cuda') 28 | self.left_faces = self.renderer.mano['left'].get_faces() 29 | self.right_faces = self.renderer.mano['right'].get_faces() 30 | 31 | self.model = load_model(cfg_path) 32 | state = torch.load(model_path, map_location='cpu') 33 | try: 34 | self.model.load_state_dict(state) 35 | except: 36 | state2 = {} 37 | for k, v in state.items(): 38 | state2[k[7:]] = v 39 | self.model.load_state_dict(state2) 40 | self.model.eval() 41 | self.model.cuda() 42 | 43 | self.img_processor = transforms.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | 46 | def process_img(self, img): 47 | img = imgUtils.pad2squre(img) 48 | img = cv.resize(img, (self.input_size, self.input_size)) 49 | imgTensor = torch.tensor(cv.cvtColor(img, cv.COLOR_BGR2RGB), dtype=torch.float32) / 255 50 | imgTensor = imgTensor.permute(2, 0, 1) 51 | imgTensor = self.img_processor(imgTensor).cuda().unsqueeze(0) 52 | return imgTensor 53 | 54 | @staticmethod 55 | def save_obj(path, verts, faces, 
color=None): 56 | with open(path, 'w') as file: 57 | for i in range(verts.shape[0]): 58 | if color is None: 59 | file.write('v {} {} {}\n'.format(verts[i, 0], verts[i, 1], verts[i, 2])) 60 | else: 61 | file.write('v {} {} {} {} {} {}\n'.format(verts[i, 0], verts[i, 1], verts[i, 2], 62 | color[0], color[1], color[2])) 63 | for i in range(faces.shape[0]): 64 | file.write('f {} {} {}\n'.format(faces[i, 0] + 1, faces[i, 1] + 1, faces[i, 2] + 1)) 65 | 66 | @torch.no_grad() 67 | def run_model(self, img): 68 | imgTensor = self.process_img(img) 69 | result, paramsDict, handDictList, otherInfo = self.model(imgTensor) 70 | 71 | params = {} 72 | params['scale_left'] = paramsDict['scale']['left'] 73 | params['trans2d_left'] = paramsDict['trans2d']['left'] 74 | params['scale_right'] = paramsDict['scale']['right'] 75 | params['trans2d_right'] = paramsDict['trans2d']['right'] 76 | params['v3d_left'] = result['verts3d']['left'] 77 | params['v3d_right'] = result['verts3d']['right'] 78 | params['otherInfo'] = otherInfo 79 | return params 80 | 81 | @torch.no_grad() 82 | def render(self, params, bg_img=None): 83 | img_out, mask_out = self.renderer.render_rgb_orth(scale_left=params['scale_left'], 84 | trans2d_left=params['trans2d_left'], 85 | scale_right=params['scale_right'], 86 | trans2d_right=params['trans2d_right'], 87 | v3d_left=params['v3d_left'], 88 | v3d_right=params['v3d_right']) 89 | img_out = img_out[0].detach().cpu().numpy() * 255 90 | mask_out = mask_out[0].detach().cpu().numpy()[..., np.newaxis] 91 | 92 | if bg_img is None: 93 | bg_img = np.ones_like(img_out) * 255 94 | else: 95 | bg_img = imgUtils.pad2squre(bg_img) 96 | bg_img = cv.resize(bg_img, (self.render_size, self.render_size)) 97 | 98 | img_out = img_out * mask_out + bg_img * (1 - mask_out) 99 | img_out = img_out.astype(np.uint8) 100 | return img_out 101 | 102 | @torch.no_grad() 103 | def render_other_view(self, params, theta=60): 104 | c = (torch.mean(params['v3d_left'], axis=1) + torch.mean(params['v3d_right'], axis=1)).unsqueeze(1) / 2 105 | v3d_left = params['v3d_left'] - c 106 | v3d_right = params['v3d_right'] - c 107 | 108 | theta = 3.14159 / 180 * theta 109 | R = [[math.cos(theta), 0, math.sin(theta)], 110 | [0, 1, 0], 111 | [-math.sin(theta), 0, math.cos(theta)]] 112 | R = torch.tensor(R).float().cuda() 113 | 114 | v3d_left = torch.matmul(v3d_left, R) 115 | v3d_right = torch.matmul(v3d_right, R) 116 | 117 | img_out, mask_out = self.renderer.render_rgb_orth(scale_left=torch.ones((1,)).float().cuda() * 3, 118 | scale_right=torch.ones((1,)).float().cuda() * 3, 119 | trans2d_left=torch.zeros((1, 2)).float().cuda(), 120 | trans2d_right=torch.zeros((1, 2)).float().cuda(), 121 | v3d_left=v3d_left, 122 | v3d_right=v3d_right) 123 | img_out = img_out[0].detach().cpu().numpy() * 255 124 | img = np.ones_like(img_out) * 255 125 | mask_out = mask_out[0].detach().cpu().numpy()[..., np.newaxis] 126 | img_out = img_out * mask_out + img * (1 - mask_out) 127 | img_out = img_out.astype(np.uint8) 128 | 129 | return img_out 130 | -------------------------------------------------------------------------------- /core/vis_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | import torch 6 | import numpy as np 7 | import random 8 | import cv2 as cv 9 | from utils.tb_utils import tbUtils 10 | 11 | 12 | @torch.no_grad() 13 | def tb_vis_train(cfg, writer, idx, renderer, verts2d, 14 | imgTensors, 
mask, dense, 15 | paramsDictList, handDictList, otherInfo): 16 | img_size = imgTensors.shape[-1] 17 | tbUtils.add_image(writer, '0_input', idx, 18 | imgTensors[0, [2, 1, 0]], dataformats='CHW') 19 | 20 | if 'mask' in otherInfo: 21 | tbUtils.add_image(writer, '1_img_aux/mask_gt', idx, 22 | mask[0], dataformats='HW', clamp=True) 23 | tbUtils.add_image(writer, '1_img_aux/mask_pred', idx, 24 | otherInfo['mask'][0], dataformats='HW', clamp=True) 25 | if 'dense' in otherInfo: 26 | tbUtils.add_image(writer, '1_img_aux/dense_gt', idx, 27 | dense[0], dataformats='CHW', clamp=True) 28 | tbUtils.add_image(writer, '1_img_aux/dense_pred', idx, 29 | otherInfo['dense'][0] * mask[0].unsqueeze(0), dataformats='CHW', clamp=True) 30 | 31 | tbUtils.draw_verts(writer, '2_mano/vert_gt', idx, 32 | imgTensors[0], verts2d[0], 33 | color=(0, 0, 255), 34 | parent=None) 35 | 36 | for itIdx in range(len(paramsDictList)): 37 | img, mask = renderer.render_rgb(scale=paramsDictList[itIdx]['scale'][:1], 38 | trans2d=paramsDictList[itIdx]['trans2d'][:1], 39 | v3d=handDictList[itIdx]['verts3d'][:1]) 40 | img = img[0] * mask[0].unsqueeze(-1) + imgTensors[0].permute(1, 2, 0) * (1 - mask[0].unsqueeze(-1)) 41 | tbUtils.add_image(writer, '2_mano/vert_out_{}'.format(itIdx), idx, 42 | img[..., [2, 1, 0]], dataformats='HWC') 43 | 44 | if 'attnList' in otherInfo: 45 | v_idx = random.randint(0, otherInfo['v2dList'][itIdx].shape[1] - 1) 46 | for itIdx in range(3): 47 | v2d = otherInfo['v2dList'][itIdx][0, v_idx].detach().cpu().numpy() 48 | attn = torch.sum(otherInfo['attnList'][itIdx][0], dim=0) 49 | attn = attn[v_idx].detach().cpu().numpy() 50 | attn = attn / attn.max() 51 | attn = cv.resize(attn, (img_size, img_size)) 52 | img = torch.clamp(imgTensors[0], 0, 1) * 255 53 | img = img.detach().cpu().numpy() 54 | img = img.transpose(1, 2, 0) 55 | temp = attn[..., np.newaxis] * img 56 | temp = temp.copy().astype(np.uint8) 57 | cv.circle(temp, (int(v2d[0]), int(v2d[1])), 2, (0, 0, 255), -1) 58 | temp = torch.from_numpy(temp).float() / 255 59 | tbUtils.add_image(writer, '3_attn/{}'.format(itIdx).format(itIdx), idx, 60 | temp[..., [2, 1, 0]], dataformats='HWC') 61 | 62 | 63 | @torch.no_grad() 64 | def tb_vis_train_gcn(cfg, writer, idx, renderer, verts2d_left, verts2d_right, 65 | imgTensors, mask, dense, 66 | result, paramsDict, handDictList, otherInfo): 67 | img_size = imgTensors.shape[-1] 68 | tbUtils.add_image(writer, '0_input', idx, 69 | imgTensors[0, [2, 1, 0]], dataformats='CHW') 70 | 71 | if 'mask' in otherInfo: 72 | tbUtils.add_image(writer, '1_img_aux/mask_gt', idx, 73 | mask[0, 0] * 0.5 + mask[0, 1], dataformats='HW', clamp=True) 74 | tbUtils.add_image(writer, '1_img_aux/mask_pred', idx, 75 | otherInfo['mask'][0, 0] * 0.5 + otherInfo['mask'][0, 1], dataformats='HW', clamp=True) 76 | if 'dense' in otherInfo: 77 | tbUtils.add_image(writer, '1_img_aux/dense_gt', idx, 78 | dense[0], dataformats='CHW', clamp=True) 79 | tbUtils.add_image(writer, '1_img_aux/dense_pred', idx, 80 | otherInfo['dense'][0, :3] * mask[0, :1] + otherInfo['dense'][0, 3:] * mask[0, 1:], 81 | dataformats='CHW', clamp=True) 82 | 83 | tbUtils.draw_verts(writer, '2_mano/vert_gt', idx, 84 | imgTensors[0], [verts2d_left[0], verts2d_right[0]], 85 | color=[(0, 0, 255), (255, 0, 0)]) 86 | 87 | img, mask = renderer.render_rgb_orth(scale_left=paramsDict['scale']['left'][:1], 88 | scale_right=paramsDict['scale']['right'][:1], 89 | trans2d_left=paramsDict['trans2d']['left'][:1], 90 | trans2d_right=paramsDict['trans2d']['right'][:1], 91 | v3d_left=result['verts3d']['left'][:1], 
92 | v3d_right=result['verts3d']['right'][:1]) 93 | img = img[0] * mask[0].unsqueeze(-1) + imgTensors[0].permute(1, 2, 0) * (1 - mask[0].unsqueeze(-1)) 94 | tbUtils.add_image(writer, '2_mano/vert_out_result', idx, 95 | img[..., [2, 1, 0]], dataformats='HWC') 96 | 97 | for itIdx in range(len(handDictList)): 98 | img, mask = renderer.render_rgb_orth(scale_left=paramsDict['scale']['left'][:1], 99 | scale_right=paramsDict['scale']['right'][:1], 100 | trans2d_left=paramsDict['trans2d']['left'][:1], 101 | trans2d_right=paramsDict['trans2d']['right'][:1], 102 | v3d_left=otherInfo['verts3d_MANO_list']['left'][itIdx][:1], 103 | v3d_right=otherInfo['verts3d_MANO_list']['right'][itIdx][:1] 104 | ) 105 | img = img[0] * mask[0].unsqueeze(-1) + imgTensors[0].permute(1, 2, 0) * (1 - mask[0].unsqueeze(-1)) 106 | tbUtils.add_image(writer, '2_mano/vert_out_{}'.format(itIdx), idx, 107 | img[..., [2, 1, 0]], dataformats='HWC') 108 | -------------------------------------------------------------------------------- /dataset/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 as cv 3 | 4 | IMG_SIZE = 256 5 | HAND_BBOX_RATIO = 0.8 6 | HEATMAP_SIZE = 64 7 | HEATMAP_SIGMA = 2 8 | BLUR_KERNEL = 5 9 | BONE_LENGTH = 0.095 10 | 11 | 12 | def cut_img(img_list, label2d_list, camera=None, radio=0.7, img_size=256): 13 | Min = [] 14 | Max = [] 15 | for label2d in label2d_list: 16 | Min.append(np.min(label2d, axis=0)) 17 | Max.append(np.max(label2d, axis=0)) 18 | Min = np.min(np.array(Min), axis=0) 19 | Max = np.max(np.array(Max), axis=0) 20 | 21 | mid = (Min + Max) / 2 22 | L = np.max(Max - Min) / 2 / radio 23 | M = img_size / 2 / L * np.array([[1, 0, L - mid[0]], 24 | [0, 1, L - mid[1]]]) 25 | 26 | img_list_out = [] 27 | for img in img_list: 28 | img_list_out.append(cv.warpAffine(img, M, dsize=(img_size, img_size))) 29 | 30 | label2d_list_out = [] 31 | for label2d in label2d_list: 32 | x = np.concatenate([label2d, np.ones_like(label2d[:, :1])], axis=-1) 33 | x = x @ M.T 34 | label2d_list_out.append(x) 35 | 36 | if camera is not None: 37 | camera[0, 0] = camera[0, 0] * M[0, 0] 38 | camera[1, 1] = camera[1, 1] * M[1, 1] 39 | camera[0, 2] = camera[0, 2] * M[0, 0] + M[0, 2] 40 | camera[1, 2] = camera[1, 2] * M[1, 1] + M[1, 2] 41 | 42 | return img_list_out, label2d_list_out, camera 43 | -------------------------------------------------------------------------------- /dataset/heatmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def build_hm(x, y, sigma=4, res=64): 5 | xl = np.arange(0, res, 1, float)[np.newaxis, :] 6 | yl = np.arange(0, res, 1, float)[:, np.newaxis] 7 | hm = np.exp(- ((xl - x) ** 2 + (yl - y) ** 2) / (2 * sigma ** 2)) 8 | return hm 9 | 10 | 11 | class HeatmapGenerator(): 12 | def __init__(self, output_res=128, sigma=-1): 13 | self.output_res = output_res 14 | if sigma < 0: 15 | sigma = self.output_res / 32 16 | self.sigma = sigma 17 | 18 | def __call__(self, joints, scale=1): 19 | if joints.ndim == 2: 20 | joints = joints[np.newaxis, ...] 
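# -- note: at this point `joints` has been promoted to shape (bs, N, 2) or (bs, N, 3);
#    the branch below pads a third column of ones that acts as a visibility flag
#    (only joints with pt[2] > 0 are rasterized into a Gaussian of width `sigma`).
#    Minimal usage sketch, mirroring the call in dataset/interhand.py (values illustrative):
#        hmg = HeatmapGenerator(output_res=64, sigma=2)
#        hms = hmg(joints2d * 64 / 256)       # joints2d given in 256-px image coordinates
#        # hms: (1, 21, 64, 64) float32 heatmaps, one per joint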
21 | if joints.shape[-1] == 2: 22 | joints = np.concatenate([joints, np.ones_like(joints[..., :1])], -1) 23 | 24 | # input : joints bs x N x 3 25 | bs = joints.shape[0] 26 | num_joints = joints.shape[1] 27 | hms = np.zeros((bs, num_joints, self.output_res, self.output_res), 28 | dtype=np.float32) 29 | sigma = self.sigma * scale 30 | 31 | for bsIdx in range(bs): 32 | for idx, pt in enumerate(joints[bsIdx]): 33 | if pt[2] > 0: 34 | x, y = pt[0], pt[1] 35 | if x < 0 or y < 0 or \ 36 | x >= self.output_res or y >= self.output_res: 37 | continue 38 | hms[bsIdx, idx] = build_hm(x, y, sigma, self.output_res) 39 | return hms 40 | -------------------------------------------------------------------------------- /dataset/inference.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Hanbin Dai (daihanbin.ac@gmail.com) and Feng Zhang (zhangfengwcy@gmail.com) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import math 13 | 14 | import numpy as np 15 | import cv2 16 | 17 | # from utils.transforms import transform_preds 18 | 19 | 20 | def get_max_preds(batch_heatmaps): 21 | ''' 22 | get predictions from score maps 23 | heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) 24 | ''' 25 | assert isinstance(batch_heatmaps, np.ndarray), \ 26 | 'batch_heatmaps should be numpy.ndarray' 27 | assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim' 28 | 29 | batch_size = batch_heatmaps.shape[0] 30 | num_joints = batch_heatmaps.shape[1] 31 | width = batch_heatmaps.shape[3] 32 | heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1)) 33 | idx = np.argmax(heatmaps_reshaped, 2) 34 | maxvals = np.amax(heatmaps_reshaped, 2) 35 | 36 | maxvals = maxvals.reshape((batch_size, num_joints, 1)) 37 | idx = idx.reshape((batch_size, num_joints, 1)) 38 | 39 | preds = np.tile(idx, (1, 1, 2)).astype(np.float32) 40 | 41 | preds[:, :, 0] = (preds[:, :, 0]) % width 42 | preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) 43 | 44 | pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) 45 | pred_mask = pred_mask.astype(np.float32) 46 | 47 | preds *= pred_mask 48 | return preds, maxvals 49 | 50 | 51 | def taylor(hm, coord): 52 | heatmap_height = hm.shape[0] 53 | heatmap_width = hm.shape[1] 54 | px = int(coord[0]) 55 | py = int(coord[1]) 56 | if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2: 57 | dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1]) 58 | dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px]) 59 | dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2]) 60 | dxy = 0.25 * (hm[py + 1][px + 1] - hm[py - 1][px + 1] - hm[py + 1][px - 1] + hm[py - 1][px - 1]) 61 | dyy = 0.25 * (hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px]) 62 | derivative = np.matrix([[dx], [dy]]) 63 | hessian = np.matrix([[dxx, dxy], [dxy, dyy]]) 64 | if dxx * dyy - dxy ** 2 != 0: 65 | hessianinv = hessian.I 66 | offset = -hessianinv * derivative 67 | offset = np.squeeze(np.array(offset.T), axis=0) 68 | coord += offset 69 | return coord 70 | 71 | 72 | def gaussian_blur(hm, kernel): 73 | border = (kernel - 1) // 2 74 | batch_size = hm.shape[0] 75 | num_joints = hm.shape[1] 76 | height = 
hm.shape[2] 77 | width = hm.shape[3] 78 | for i in range(batch_size): 79 | for j in range(num_joints): 80 | origin_max = np.max(hm[i, j]) 81 | dr = np.zeros((height + 2 * border, width + 2 * border)) 82 | dr[border: -border, border: -border] = hm[i, j].copy() 83 | dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) 84 | hm[i, j] = dr[border: -border, border: -border].copy() 85 | hm[i, j] *= origin_max / np.max(hm[i, j]) 86 | return hm 87 | 88 | 89 | def get_final_preds(config, hm, center, scale): 90 | coords, maxvals = get_max_preds(hm) 91 | heatmap_height = hm.shape[2] 92 | heatmap_width = hm.shape[3] 93 | 94 | # post-processing 95 | hm = gaussian_blur(hm, config.TEST.BLUR_KERNEL) 96 | hm = np.maximum(hm, 1e-10) 97 | hm = np.log(hm) 98 | for n in range(coords.shape[0]): 99 | for p in range(coords.shape[1]): 100 | coords[n, p] = taylor(hm[n][p], coords[n][p]) 101 | 102 | preds = coords.copy() 103 | 104 | # Transform back 105 | for i in range(coords.shape[0]): 106 | preds[i] = transform_preds( 107 | coords[i], center[i], scale[i], [heatmap_width, heatmap_height] 108 | ) 109 | 110 | return preds, maxvals 111 | 112 | 113 | def get_final_preds2(hm, kernel): 114 | coords, maxvals = get_max_preds(hm) 115 | 116 | # post-processing 117 | if kernel > 1: 118 | hm = gaussian_blur(hm, kernel=kernel) 119 | hm = np.maximum(hm, 1e-10) 120 | hm = np.log(hm) 121 | for n in range(coords.shape[0]): 122 | for p in range(coords.shape[1]): 123 | coords[n, p] = taylor(hm[n][p], coords[n][p]) 124 | 125 | preds = coords.copy() 126 | 127 | return preds, maxvals 128 | -------------------------------------------------------------------------------- /dataset/interhand.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | from tqdm import tqdm 4 | import cv2 as cv 5 | import numpy as np 6 | import torch 7 | import pickle 8 | from glob import glob 9 | 10 | import os 11 | import sys 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 13 | 14 | 15 | from models.manolayer import ManoLayer, rodrigues_batch 16 | from dataset.dataset_utils import IMG_SIZE, HAND_BBOX_RATIO, HEATMAP_SIGMA, HEATMAP_SIZE, cut_img 17 | from dataset.heatmap import HeatmapGenerator 18 | from utils.vis_utils import mano_two_hands_renderer 19 | from utils.utils import get_mano_path 20 | 21 | 22 | def fix_shape(mano_layer): 23 | if torch.sum(torch.abs(mano_layer['left'].shapedirs[:, 0, :] - mano_layer['right'].shapedirs[:, 0, :])) < 1: 24 | print('Fix shapedirs bug of MANO') 25 | mano_layer['left'].shapedirs[:, 0, :] *= -1 26 | 27 | 28 | class InterHandLoader(): 29 | def __init__(self, data_path, split='train', mano_path=None): 30 | assert split in ['train', 'test', 'val'] 31 | 32 | self.root_path = data_path 33 | self.img_root_path = os.path.join(self.root_path, 'images') 34 | self.annot_root_path = os.path.join(self.root_path, 'annotations') 35 | 36 | self.mano_layer = {'right': ManoLayer(mano_path['right'], center_idx=None), 37 | 'left': ManoLayer(mano_path['left'], center_idx=None)} 38 | fix_shape(self.mano_layer) 39 | 40 | self.split = split 41 | 42 | with open(osp.join(self.annot_root_path, self.split, 43 | 'InterHand2.6M_' + self.split + '_data.json')) as f: 44 | self.data_info = json.load(f) 45 | with open(osp.join(self.annot_root_path, self.split, 46 | 'InterHand2.6M_' + self.split + '_camera.json')) as f: 47 | self.cam_params = json.load(f) 48 | with open(osp.join(self.annot_root_path, self.split, 49 | 'InterHand2.6M_' + self.split + 
'_joint_3d.json')) as f: 50 | self.joints = json.load(f) 51 | with open(osp.join(self.annot_root_path, self.split, 52 | 'InterHand2.6M_' + self.split + '_MANO_NeuralAnnot.json')) as f: 53 | self.mano_params = json.load(f) 54 | 55 | self.data_size = len(self.data_info['images']) 56 | 57 | def __len__(self): 58 | return self.data_size 59 | 60 | def show_data(self, idx): 61 | for k in self.data_info['images'][idx].keys(): 62 | print(k, self.data_info['images'][idx][k]) 63 | for k in self.data_info['annotations'][idx].keys(): 64 | print(k, self.data_info['annotations'][idx][k]) 65 | 66 | def load_camera(self, idx): 67 | img_info = self.data_info['images'][idx] 68 | capture_idx = img_info['capture'] 69 | cam_idx = img_info['camera'] 70 | 71 | capture_idx = str(capture_idx) 72 | cam_idx = str(cam_idx) 73 | cam_param = self.cam_params[str(capture_idx)] 74 | cam_t = np.array(cam_param['campos'][cam_idx], dtype=np.float32).reshape(3) 75 | cam_R = np.array(cam_param['camrot'][cam_idx], dtype=np.float32).reshape(3, 3) 76 | cam_t = -np.dot(cam_R, cam_t.reshape(3, 1)).reshape(3) / 1000 # -Rt -> t 77 | 78 | # add camera intrinsics 79 | focal = np.array(cam_param['focal'][cam_idx], dtype=np.float32).reshape(2) 80 | princpt = np.array(cam_param['princpt'][cam_idx], dtype=np.float32).reshape(2) 81 | cameraIn = np.array([[focal[0], 0, princpt[0]], 82 | [0, focal[1], princpt[1]], 83 | [0, 0, 1]]) 84 | return cam_R, cam_t, cameraIn 85 | 86 | def load_mano(self, idx): 87 | img_info = self.data_info['images'][idx] 88 | capture_idx = img_info['capture'] 89 | frame_idx = img_info['frame_idx'] 90 | 91 | capture_idx = str(capture_idx) 92 | frame_idx = str(frame_idx) 93 | mano_dict = {} 94 | coord_dict = {} 95 | for hand_type in ['left', 'right']: 96 | try: 97 | mano_param = self.mano_params[capture_idx][frame_idx][hand_type] 98 | mano_pose = torch.FloatTensor(mano_param['pose']).view(-1, 3) 99 | root_pose = mano_pose[0].view(1, 3) 100 | hand_pose = mano_pose[1:, :].view(1, -1) 101 | # hand_pose = hand_pose.view(1, -1, 3) 102 | mano = self.mano_layer[hand_type] 103 | mean_pose = mano.hands_mean 104 | hand_pose = mano.axis2pca(hand_pose + mean_pose) 105 | shape = torch.FloatTensor(mano_param['shape']).view(1, -1) 106 | trans = torch.FloatTensor(mano_param['trans']).view(1, 3) 107 | root_pose = rodrigues_batch(root_pose) 108 | 109 | handV, handJ = self.mano_layer[hand_type](root_pose, hand_pose, shape, trans=trans) 110 | mano_dict[hand_type] = {'R': root_pose.numpy(), 'pose': hand_pose.numpy(), 'shape': shape.numpy(), 'trans': trans.numpy()} 111 | coord_dict[hand_type] = {'verts': handV, 'joints': handJ} 112 | except: 113 | mano_dict[hand_type] = None 114 | coord_dict[hand_type] = None 115 | 116 | return mano_dict, coord_dict 117 | 118 | def load_img(self, idx): 119 | img_info = self.data_info['images'][idx] 120 | img = cv.imread(osp.join(self.img_root_path, self.split, img_info['file_name'])) 121 | return img 122 | 123 | 124 | def cut_inter_img(loader, save_path, split): 125 | os.makedirs(osp.join(save_path, split, 'img'), exist_ok=True) 126 | os.makedirs(osp.join(save_path, split, 'anno'), exist_ok=True) 127 | 128 | idx = 0 129 | for i in tqdm(range(len(loader))): 130 | annotation = loader.data_info['annotations'][i] 131 | images_info = loader.data_info['images'][i] 132 | hand_type = annotation['hand_type'] 133 | hand_type_valid = annotation['hand_type_valid'] 134 | 135 | if hand_type == 'interacting' and hand_type_valid: 136 | mano_dict, coord_dict = loader.load_mano(i) 137 | if coord_dict['left'] is not None and 
coord_dict['right'] is not None: 138 | left = coord_dict['left']['verts'][0].detach().numpy() 139 | right = coord_dict['right']['verts'][0].detach().numpy() 140 | dist = np.linalg.norm(left - right, ord=2, axis=-1).min() 141 | if dist < 9999999: 142 | img = loader.load_img(i) 143 | if img.mean() < 10: 144 | continue 145 | 146 | cam_R, cam_t, cameraIn = loader.load_camera(i) 147 | left = left @ cam_R.T + cam_t 148 | left2d = left @ cameraIn.T 149 | left2d = left2d[:, :2] / left2d[:, 2:] 150 | right = right @ cam_R.T + cam_t 151 | right2d = right @ cameraIn.T 152 | right2d = right2d[:, :2] / right2d[:, 2:] 153 | 154 | [img], _, cameraIn = \ 155 | cut_img([img], [left2d, right2d], camera=cameraIn, radio=HAND_BBOX_RATIO, img_size=IMG_SIZE) 156 | cv.imwrite(osp.join(save_path, split, 'img', '{}.jpg'.format(idx)), img) 157 | 158 | data_info = {} 159 | data_info['inter_idx'] = idx 160 | data_info['image'] = images_info 161 | data_info['annotation'] = annotation 162 | data_info['mano_params'] = mano_dict 163 | data_info['camera'] = {'R': cam_R, 't': cam_t, 'camera': cameraIn} 164 | with open(osp.join(save_path, split, 'anno', '{}.pkl'.format(idx)), 'wb') as file: 165 | pickle.dump(data_info, file) 166 | 167 | idx = idx + 1 168 | 169 | 170 | def select_data(DATA_PATH, save_path, split): 171 | loader = InterHandLoader(DATA_PATH, split=split, mano_path=get_mano_path()) 172 | cut_inter_img(loader, save_path, split) 173 | 174 | 175 | def render_data(save_path, split): 176 | mano_path = get_mano_path() 177 | os.makedirs(osp.join(save_path, split, 'mask'), exist_ok=True) 178 | os.makedirs(osp.join(save_path, split, 'dense'), exist_ok=True) 179 | os.makedirs(osp.join(save_path, split, 'hms'), exist_ok=True) 180 | 181 | size = len(glob(osp.join(save_path, split, 'anno', '*.pkl'))) 182 | mano_layer = {'right': ManoLayer(mano_path['right'], center_idx=None), 183 | 'left': ManoLayer(mano_path['left'], center_idx=None)} 184 | fix_shape(mano_layer) 185 | renderer = mano_two_hands_renderer(img_size=IMG_SIZE, device='cuda') 186 | hmg = HeatmapGenerator(HEATMAP_SIZE, HEATMAP_SIGMA) 187 | 188 | for idx in tqdm(range(size)): 189 | with open(osp.join(save_path, split, 'anno', '{}.pkl'.format(idx)), 'rb') as file: 190 | data = pickle.load(file) 191 | 192 | R = data['camera']['R'] 193 | T = data['camera']['t'] 194 | camera = data['camera']['camera'] 195 | 196 | verts = [] 197 | for hand_type in ['left', 'right']: 198 | params = data['mano_params'][hand_type] 199 | handV, handJ = mano_layer[hand_type](torch.from_numpy(params['R']).float(), 200 | torch.from_numpy(params['pose']).float(), 201 | torch.from_numpy(params['shape']).float(), 202 | trans=torch.from_numpy(params['trans']).float()) 203 | handV = handV[0].numpy() 204 | handJ = handJ[0].numpy() 205 | handV = handV @ R.T + T 206 | handJ = handJ @ R.T + T 207 | 208 | handV2d = handV @ camera.T 209 | handV2d = handV2d[:, :2] / handV2d[:, 2:] 210 | handJ2d = handJ @ camera.T 211 | handJ2d = handJ2d[:, :2] / handJ2d[:, 2:] 212 | 213 | verts.append(torch.from_numpy(handV).float().cuda().unsqueeze(0)) 214 | hms = np.split(hmg(handJ2d * HEATMAP_SIZE / IMG_SIZE)[0], 7) # 21 x h x w 215 | for hIdx in range(len(hms)): 216 | cv.imwrite(os.path.join(save_path, split, 'hms', '{}_{}_{}.jpg'.format(idx, hIdx, hand_type)), 217 | hms[hIdx].transpose(1, 2, 0) * 255) 218 | 219 | img_mask = renderer.render_mask(cameras=torch.from_numpy(camera).float().cuda().unsqueeze(0), 220 | v3d_left=verts[0], v3d_right=verts[1]) 221 | img_mask = img_mask.detach().cpu().numpy()[0] * 255 222 | 
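# -- note on the on-disk layout produced by this preprocessing: for every interacting-hand
#    sample `idx`, cut_inter_img() writes img/{idx}.jpg and anno/{idx}.pkl, while
#    render_data() adds mask/{idx}.jpg, dense/{idx}.jpg and seven heatmap tiles
#    hms/{idx}_{h}_{left|right}.jpg per hand (21 joint maps packed 3-per-image).
#    InterHand_dataset.__getitem__ further below reads all of these back by index.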
cv.imwrite(osp.join(save_path, split, 'mask', '{}.jpg'.format(idx)), img_mask) 223 | 224 | img_dense, _ = renderer.render_densepose(cameras=torch.from_numpy(camera).float().cuda().unsqueeze(0), 225 | v3d_left=verts[0], v3d_right=verts[1]) 226 | img_dense = img_dense.detach().cpu().numpy()[0] * 255 227 | cv.imwrite(osp.join(save_path, split, 'dense', '{}.jpg'.format(idx)), img_dense) 228 | 229 | 230 | class InterHand_dataset(): 231 | def __init__(self, data_path, split): 232 | assert split in ['train', 'test', 'val'] 233 | self.split = split 234 | mano_path = get_mano_path() 235 | self.mano_layer = {'right': ManoLayer(mano_path['right'], center_idx=None), 236 | 'left': ManoLayer(mano_path['left'], center_idx=None)} 237 | fix_shape(self.mano_layer) 238 | 239 | self.data_path = data_path 240 | self.size = len(glob(osp.join(data_path, split, 'anno', '*.pkl'))) 241 | 242 | def __len__(self): 243 | return self.size 244 | 245 | def __getitem__(self, idx): 246 | img = cv.imread(osp.join(self.data_path, self.split, 'img', '{}.jpg'.format(idx))) 247 | mask = cv.imread(osp.join(self.data_path, self.split, 'mask', '{}.jpg'.format(idx))) 248 | dense = cv.imread(osp.join(self.data_path, self.split, 'dense', '{}.jpg'.format(idx))) 249 | 250 | with open(os.path.join(self.data_path, self.split, 'anno', '{}.pkl'.format(idx)), 'rb') as file: 251 | data = pickle.load(file) 252 | 253 | R = data['camera']['R'] 254 | T = data['camera']['t'] 255 | camera = data['camera']['camera'] 256 | 257 | hand_dict = {} 258 | for hand_type in ['left', 'right']: 259 | hms = [] 260 | for hIdx in range(7): 261 | hm = cv.imread(os.path.join(self.data_path, self.split, 'hms', '{}_{}_{}.jpg'.format(idx, hIdx, hand_type))) 262 | hm = cv.resize(hm, (img.shape[1], img.shape[0])) 263 | hms.append(hm) 264 | 265 | params = data['mano_params'][hand_type] 266 | handV, handJ = self.mano_layer[hand_type](torch.from_numpy(params['R']).float(), 267 | torch.from_numpy(params['pose']).float(), 268 | torch.from_numpy(params['shape']).float(), 269 | trans=torch.from_numpy(params['trans']).float()) 270 | handV = handV[0].numpy() 271 | handJ = handJ[0].numpy() 272 | handV = handV @ R.T + T 273 | handJ = handJ @ R.T + T 274 | 275 | handV2d = handV @ camera.T 276 | handV2d = handV2d[:, :2] / handV2d[:, 2:] 277 | handJ2d = handJ @ camera.T 278 | handJ2d = handJ2d[:, :2] / handJ2d[:, 2:] 279 | 280 | hand_dict[hand_type] = {'hms': hms, 281 | 'verts3d': handV, 'joints3d': handJ, 282 | 'verts2d': handV2d, 'joints2d': handJ2d, 283 | 'R': R @ params['R'][0], 284 | 'pose': params['pose'][0], 285 | 'shape': params['shape'][0], 286 | 'camera': camera 287 | } 288 | 289 | return img, mask, dense, hand_dict 290 | 291 | 292 | if __name__ == '__main__': 293 | import argparse 294 | 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument("--data_path", type=str) 297 | parser.add_argument("--save_path", type=str) 298 | opt = parser.parse_args() 299 | 300 | for split in ['train', 'test', 'val']: 301 | select_data(opt.data_path, opt.save_path, split=split) 302 | 303 | for split in ['train', 'test', 'val']: 304 | render_data(opt.save_path, split) 305 | -------------------------------------------------------------------------------- /demo/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dw1010/IntagHand/5bb0775b176a6a72162f1750c0226165fbcaa2eb/demo/1.jpg -------------------------------------------------------------------------------- /demo/2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dw1010/IntagHand/5bb0775b176a6a72162f1750c0226165fbcaa2eb/demo/2.jpg -------------------------------------------------------------------------------- /demo/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dw1010/IntagHand/5bb0775b176a6a72162f1750c0226165fbcaa2eb/demo/3.jpg -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import pickle 10 | import numpy as np 11 | 12 | from dataset.dataset_utils import IMG_SIZE, BONE_LENGTH 13 | from utils.utils import projection_batch, get_dense_color_path, get_graph_dict_path, get_upsample_path 14 | from models.model_zoo import GCN_vert_convert, graph_upsample, graph_avg_pool 15 | from models.model_attn import DualGraph 16 | 17 | 18 | def weights_init(layer): 19 | classname = layer.__class__.__name__ 20 | # print(classname) 21 | if classname.find('Conv2d') != -1: 22 | nn.init.xavier_uniform_(layer.weight.data) 23 | elif classname.find('Linear') != -1: 24 | nn.init.xavier_uniform_(layer.weight.data) 25 | if layer.bias is not None: 26 | nn.init.constant_(layer.bias.data, 0.0) 27 | 28 | 29 | class decoder(nn.Module): 30 | def __init__(self, 31 | global_feature_dim=2048, 32 | f_in_Dim=[256, 256, 256, 256], 33 | f_out_Dim=[128, 64, 32], 34 | gcn_in_dim=[256, 128, 128], 35 | gcn_out_dim=[128, 128, 64], 36 | graph_k=2, 37 | graph_layer_num=4, 38 | left_graph_dict={}, 39 | right_graph_dict={}, 40 | vertex_num=778, 41 | dense_coor=None, 42 | num_attn_heads=4, 43 | upsample_weight=None, 44 | dropout=0.05): 45 | super(decoder, self).__init__() 46 | assert len(f_in_Dim) == 4 47 | f_in_Dim = f_in_Dim[:-1] 48 | assert len(gcn_in_dim) == 3 49 | for i in range(len(gcn_out_dim) - 1): 50 | assert gcn_out_dim[i] == gcn_in_dim[i + 1] 51 | 52 | graph_dict = {'left': left_graph_dict, 'right': right_graph_dict} 53 | graph_dict['left']['coarsen_graphs_L'].reverse() 54 | graph_dict['right']['coarsen_graphs_L'].reverse() 55 | graph_L = {} 56 | for hand_type in ['left', 'right']: 57 | graph_L[hand_type] = graph_dict[hand_type]['coarsen_graphs_L'] 58 | 59 | self.vNum_in = graph_L['left'][0].shape[0] 60 | self.vNum_out = graph_L['left'][2].shape[0] 61 | self.vNum_all = graph_L['left'][-1].shape[0] 62 | self.vNum_mano = vertex_num 63 | self.gf_dim = global_feature_dim 64 | self.gcn_in_dim = gcn_in_dim 65 | self.gcn_out_dim = gcn_out_dim 66 | 67 | if dense_coor is not None: 68 | dense_coor = torch.from_numpy(dense_coor).float() 69 | self.register_buffer('dense_coor', dense_coor) 70 | 71 | self.converter = {} 72 | for hand_type in ['left', 'right']: 73 | self.converter[hand_type] = GCN_vert_convert(vertex_num=self.vNum_mano, 74 | graph_perm_reverse=graph_dict[hand_type]['graph_perm_reverse'], 75 | graph_perm=graph_dict[hand_type]['graph_perm']) 76 | 77 | self.dual_gcn = DualGraph(verts_in_dim=self.gcn_in_dim, 78 | verts_out_dim=self.gcn_out_dim, 79 | graph_L_Left=graph_L['left'][:3], 80 | graph_L_Right=graph_L['right'][:3], 81 | graph_k=[graph_k, graph_k, graph_k], 82 | graph_layer_num=[graph_layer_num, graph_layer_num, graph_layer_num], 83 | img_size=[8, 16, 32], 84 | 
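# -- note: each list below holds one value per stage of the coarse-to-fine DualGraph.
#    DualGraph.__init__ asserts that the graph size doubles between consecutive stages
#    and that verts_out_dim[i] equals verts_in_dim[i+1] (mirroring the gcn_in_dim /
#    gcn_out_dim assert above), while image features are attended at 8x8, 16x16 and
#    32x32 resolution with an 8x8 sampling grid at every stage.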
img_f_dim=f_in_Dim, 85 | grid_size=[8, 8, 8], 86 | grid_f_dim=f_out_Dim, 87 | n_heads=num_attn_heads, 88 | dropout=dropout) 89 | 90 | self.gf_layer_left = nn.Sequential(*(nn.Linear(self.gf_dim, self.gcn_in_dim[0] - 3), 91 | nn.LayerNorm(self.gcn_in_dim[0] - 3, eps=1e-6))) 92 | self.gf_layer_right = nn.Sequential(*(nn.Linear(self.gf_dim, self.gcn_in_dim[0] - 3), 93 | nn.LayerNorm(self.gcn_in_dim[0] - 3, eps=1e-6))) 94 | self.unsample_layer = nn.Linear(self.vNum_out, self.vNum_mano, bias=False) 95 | 96 | self.coord_head = nn.Linear(self.gcn_out_dim[-1], 3) 97 | self.avg_head = nn.Linear(self.vNum_out, 1) 98 | self.params_head = nn.Linear(self.gcn_out_dim[-1], 3) 99 | 100 | weights_init(self.gf_layer_left) 101 | weights_init(self.gf_layer_right) 102 | weights_init(self.coord_head) 103 | weights_init(self.avg_head) 104 | weights_init(self.params_head) 105 | 106 | if upsample_weight is not None: 107 | state = {'weight': upsample_weight.to(self.unsample_layer.weight.data.device)} 108 | self.unsample_layer.load_state_dict(state) 109 | else: 110 | weights_init(self.unsample_layer) 111 | 112 | def get_upsample_weight(self): 113 | return self.unsample_layer.weight.data 114 | 115 | def get_converter(self): 116 | return self.converter 117 | 118 | def get_hand_pe(self, bs, num=None): 119 | if num is None: 120 | num = self.vNum_in 121 | dense_coor = self.dense_coor.repeat(bs, 1, 1) * 2 - 1 122 | pel = self.converter['left'].vert_to_GCN(dense_coor) 123 | pel = graph_avg_pool(pel, p=pel.shape[1] // num) 124 | per = self.converter['right'].vert_to_GCN(dense_coor) 125 | per = graph_avg_pool(per, p=per.shape[1] // num) 126 | return pel, per 127 | 128 | def forward(self, x, fmaps): 129 | assert x.shape[1] == self.gf_dim 130 | fmaps = fmaps[:-1] 131 | bs = x.shape[0] 132 | 133 | pel, per = self.get_hand_pe(bs, num=self.vNum_in) 134 | Lf = torch.cat([self.gf_layer_left(x).unsqueeze(1).repeat(1, self.vNum_in, 1), pel], dim=-1) 135 | Rf = torch.cat([self.gf_layer_right(x).unsqueeze(1).repeat(1, self.vNum_in, 1), per], dim=-1) 136 | 137 | Lf, Rf = self.dual_gcn(Lf, Rf, fmaps) 138 | 139 | scale = {} 140 | trans2d = {} 141 | temp = self.avg_head(Lf.transpose(-1, -2))[..., 0] 142 | temp = self.params_head(temp) 143 | scale['left'] = temp[:, 0] 144 | trans2d['left'] = temp[:, 1:] 145 | temp = self.avg_head(Rf.transpose(-1, -2))[..., 0] 146 | temp = self.params_head(temp) 147 | scale['right'] = temp[:, 0] 148 | trans2d['right'] = temp[:, 1:] 149 | 150 | handDictList = [] 151 | 152 | paramsDict = {'scale': scale, 'trans2d': trans2d} 153 | verts3d = {'left': self.coord_head(Lf), 'right': self.coord_head(Rf)} 154 | verts2d = {} 155 | result = {'verts3d': {}, 'verts2d': {}} 156 | for hand_type in ['left', 'right']: 157 | verts2d[hand_type] = projection_batch(scale[hand_type], trans2d[hand_type], verts3d[hand_type], img_size=IMG_SIZE) 158 | result['verts3d'][hand_type] = self.unsample_layer(verts3d[hand_type].transpose(1, 2)).transpose(1, 2) 159 | result['verts2d'][hand_type] = projection_batch(scale[hand_type], trans2d[hand_type], result['verts3d'][hand_type], img_size=IMG_SIZE) 160 | handDictList.append({'verts3d': verts3d, 'verts2d': verts2d}) 161 | 162 | otherInfo = {} 163 | otherInfo['verts3d_MANO_list'] = {'left': [], 'right': []} 164 | otherInfo['verts2d_MANO_list'] = {'left': [], 'right': []} 165 | for i in range(len(handDictList)): 166 | for hand_type in ['left', 'right']: 167 | v = handDictList[i]['verts3d'][hand_type] 168 | v = graph_upsample(v, p=self.vNum_all // v.shape[1]) 169 | 
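# -- note: the per-stage predictions kept in handDictList live on the coarsened graphs,
#    so they are first upsampled to the full coarsened graph (graph_upsample) and then
#    re-ordered to MANO's 778-vertex layout via the converter; core/vis_train.py renders
#    these verts3d_MANO_list entries as the intermediate '2_mano/vert_out_{i}' images.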
otherInfo['verts3d_MANO_list'][hand_type].append(self.converter[hand_type].GCN_to_vert(v)) 170 | v = handDictList[i]['verts2d'][hand_type] 171 | v = graph_upsample(v, p=self.vNum_all // v.shape[1]) 172 | otherInfo['verts2d_MANO_list'][hand_type].append(self.converter[hand_type].GCN_to_vert(v)) 173 | 174 | return result, paramsDict, handDictList, otherInfo 175 | 176 | 177 | def load_decoder(cfg, encoder_info): 178 | graph_path = get_graph_dict_path() 179 | with open(graph_path['left'], 'rb') as file: 180 | left_graph_dict = pickle.load(file) 181 | with open(graph_path['right'], 'rb') as file: 182 | right_graph_dict = pickle.load(file) 183 | 184 | dense_path = get_dense_color_path() 185 | with open(dense_path, 'rb') as file: 186 | dense_coor = pickle.load(file) 187 | 188 | upsample_path = get_upsample_path() 189 | with open(upsample_path, 'rb') as file: 190 | upsample_weight = pickle.load(file) 191 | upsample_weight = torch.from_numpy(upsample_weight).float() 192 | 193 | model = decoder( 194 | global_feature_dim=encoder_info['global_feature_dim'], 195 | f_in_Dim=encoder_info['fmaps_dim'], 196 | f_out_Dim=cfg.MODEL.IMG_DIMS, 197 | gcn_in_dim=cfg.MODEL.GCN_IN_DIM, 198 | gcn_out_dim=cfg.MODEL.GCN_OUT_DIM, 199 | graph_k=cfg.MODEL.graph_k, 200 | graph_layer_num=cfg.MODEL.graph_layer_num, 201 | vertex_num=778, 202 | dense_coor=dense_coor, 203 | left_graph_dict=left_graph_dict, 204 | right_graph_dict=right_graph_dict, 205 | num_attn_heads=4, 206 | upsample_weight=upsample_weight, 207 | dropout=cfg.TRAIN.dropout 208 | ) 209 | 210 | return model 211 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import pickle 9 | import numpy as np 10 | 11 | from dataset.dataset_utils import IMG_SIZE 12 | from utils.utils import projection_batch 13 | from models.manolayer import ManoLayer 14 | from models.model_zoo import get_hrnet, conv1x1, conv3x3, deconv3x3, weights_init, GCN_vert_convert, build_fc_layer, Bottleneck 15 | 16 | from utils.config import load_cfg 17 | 18 | from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152 19 | 20 | 21 | class ResNetSimple_decoder(nn.Module): 22 | def __init__(self, expansion=4, 23 | fDim=[256, 256, 256, 256], direction=['flat', 'up', 'up', 'up'], 24 | out_dim=3): 25 | super(ResNetSimple_decoder, self).__init__() 26 | self.models = nn.ModuleList() 27 | fDim = [512 * expansion] + fDim 28 | for i in range(len(direction)): 29 | kernel_size = 1 if direction[i] == 'flat' else 3 30 | self.models.append(self.make_layer(fDim[i], fDim[i + 1], direction[i], kernel_size=kernel_size)) 31 | 32 | self.final_layer = nn.Conv2d( 33 | in_channels=fDim[-1], 34 | out_channels=out_dim, 35 | kernel_size=1, 36 | stride=1, 37 | padding=0 38 | ) 39 | 40 | def make_layer(self, in_dim, out_dim, 41 | direction, kernel_size=3, relu=True, bn=True): 42 | assert direction in ['flat', 'up'] 43 | assert kernel_size in [1, 3] 44 | if kernel_size == 3: 45 | padding = 1 46 | elif kernel_size == 1: 47 | padding = 0 48 | 49 | layers = [] 50 | if direction == 'up': 51 | layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 52 | layers.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=1, padding=padding, bias=False)) 53 
| layers.append(nn.ReLU(inplace=True)) 54 | layers.append(nn.BatchNorm2d(out_dim)) 55 | 56 | return nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | fmaps = [] 60 | for i in range(len(self.models)): 61 | x = self.models[i](x) 62 | fmaps.append(x) 63 | x = self.final_layer(x) 64 | return x, fmaps 65 | 66 | 67 | class ResNetSimple(nn.Module): 68 | def __init__(self, model_type='resnet50', 69 | pretrained=False, 70 | fmapDim=[256, 256, 256, 256], 71 | handNum=2, 72 | heatmapDim=21): 73 | super(ResNetSimple, self).__init__() 74 | assert model_type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 75 | if model_type == 'resnet18': 76 | self.resnet = resnet18(pretrained=pretrained) 77 | self.expansion = 1 78 | elif model_type == 'resnet34': 79 | self.resnet = resnet34(pretrained=pretrained) 80 | self.expansion = 1 81 | elif model_type == 'resnet50': 82 | self.resnet = resnet50(pretrained=pretrained) 83 | self.expansion = 4 84 | elif model_type == 'resnet101': 85 | self.resnet = resnet101(pretrained=pretrained) 86 | self.expansion = 4 87 | elif model_type == 'resnet152': 88 | self.resnet = resnet152(pretrained=pretrained) 89 | self.expansion = 4 90 | 91 | self.hms_decoder = ResNetSimple_decoder(expansion=self.expansion, 92 | fDim=fmapDim, 93 | direction=['flat', 'up', 'up', 'up'], 94 | out_dim=heatmapDim * handNum) 95 | for m in self.hms_decoder.modules(): 96 | weights_init(m) 97 | 98 | self.dp_decoder = ResNetSimple_decoder(expansion=self.expansion, 99 | fDim=fmapDim, 100 | direction=['flat', 'up', 'up', 'up'], 101 | out_dim=handNum + 3 * handNum) 102 | self.handNum = handNum 103 | 104 | for m in self.dp_decoder.modules(): 105 | weights_init(m) 106 | 107 | def forward(self, x): 108 | x = self.resnet.conv1(x) 109 | x = self.resnet.bn1(x) 110 | x = self.resnet.relu(x) 111 | x = self.resnet.maxpool(x) 112 | 113 | x4 = self.resnet.layer1(x) 114 | x3 = self.resnet.layer2(x4) 115 | x2 = self.resnet.layer3(x3) 116 | x1 = self.resnet.layer4(x2) 117 | 118 | img_fmaps = [x1, x2, x3, x4] 119 | 120 | hms, hms_fmaps = self.hms_decoder(x1) 121 | out, dp_fmaps = self.dp_decoder(x1) 122 | mask = out[:, :self.handNum] 123 | dp = out[:, self.handNum:] 124 | 125 | return hms, mask, dp, \ 126 | img_fmaps, hms_fmaps, dp_fmaps 127 | 128 | 129 | class resnet_mid(nn.Module): 130 | def __init__(self, 131 | model_type='resnet50', 132 | in_fmapDim=[256, 256, 256, 256], 133 | out_fmapDim=[256, 256, 256, 256]): 134 | super(resnet_mid, self).__init__() 135 | assert model_type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 136 | if model_type == 'resnet18' or model_type == 'resnet34': 137 | self.expansion = 1 138 | elif model_type == 'resnet50' or model_type == 'resnet101' or model_type == 'resnet152': 139 | self.expansion = 4 140 | 141 | self.img_fmaps_dim = [512 * self.expansion, 256 * self.expansion, 142 | 128 * self.expansion, 64 * self.expansion] 143 | self.dp_fmaps_dim = in_fmapDim 144 | self.hms_fmaps_dim = in_fmapDim 145 | 146 | self.convs = nn.ModuleList() 147 | for i in range(len(out_fmapDim)): 148 | inDim = self.dp_fmaps_dim[i] + self.hms_fmaps_dim[i] 149 | if i > 0: 150 | inDim = inDim + self.img_fmaps_dim[i] 151 | self.convs.append(conv1x1(inDim, out_fmapDim[i])) 152 | 153 | self.output_layer = nn.Sequential( 154 | nn.AdaptiveAvgPool2d(1), 155 | nn.Flatten(start_dim=1), 156 | ) 157 | 158 | self.global_feature_dim = 512 * self.expansion 159 | self.fmaps_dim = out_fmapDim 160 | 161 | def get_info(self): 162 | return {'global_feature_dim': self.global_feature_dim, 163 | 
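# -- note: this dict is the `encoder_info` handed to load_decoder():
#    'global_feature_dim' sizes the decoder's gf_layer_left/right inputs and
#    'fmaps_dim' becomes its f_in_Dim (the per-level feature-map channel widths).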
'fmaps_dim': self.fmaps_dim} 164 | 165 | def forward(self, img_fmaps, hms_fmaps, dp_fmaps): 166 | global_feature = self.output_layer(img_fmaps[0]) 167 | fmaps = [] 168 | for i in range(len(self.convs)): 169 | x = torch.cat((hms_fmaps[i], dp_fmaps[i]), dim=1) 170 | if i > 0: 171 | x = torch.cat((x, img_fmaps[i]), dim=1) 172 | fmaps.append(self.convs[i](x)) 173 | return global_feature, fmaps 174 | 175 | 176 | class HRnet_encoder(nn.Module): 177 | def __init__(self, model_type, pretrained='', handNum=2, heatmapDim=21): 178 | super(HRnet_encoder, self).__init__() 179 | name = 'w' + model_type[model_type.find('hrnet') + 5:] 180 | assert name in ['w18', 'w18_small_v1', 'w18_small_v2', 'w30', 'w32', 'w40', 'w44', 'w48', 'w64'] 181 | 182 | self.hrnet = get_hrnet(name=name, 183 | in_channels=3, 184 | head_type='none', 185 | pretrained='') 186 | 187 | if os.path.isfile(pretrained): 188 | print('load pretrained params: {}'.format(pretrained)) 189 | pretrained_dict = torch.load(pretrained) 190 | model_dict = self.hrnet.state_dict() 191 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 192 | if k in model_dict.keys() and k.find('classifier') == -1} 193 | model_dict.update(pretrained_dict) 194 | self.hrnet.load_state_dict(model_dict) 195 | 196 | self.fmaps_dim = list(self.hrnet.stage4_cfg['NUM_CHANNELS']) 197 | self.fmaps_dim.reverse() 198 | 199 | self.hms_decoder = self.mask_decoder(outDim=heatmapDim * handNum) 200 | for m in self.hms_decoder.modules(): 201 | weights_init(m) 202 | 203 | self.dp_decoder = self.mask_decoder(outDim=1 + 3 * handNum) 204 | for m in self.dp_decoder.modules(): 205 | weights_init(m) 206 | 207 | def mask_decoder(self, outDim=3): 208 | last_inp_channels = 0 209 | for temp in self.fmaps_dim: 210 | last_inp_channels += temp 211 | 212 | return nn.Sequential( 213 | nn.Conv2d( 214 | in_channels=last_inp_channels, out_channels=last_inp_channels, 215 | kernel_size=1, stride=1, padding=0), 216 | nn.BatchNorm2d(last_inp_channels), 217 | nn.ReLU(inplace=True), 218 | nn.Conv2d( 219 | in_channels=last_inp_channels, out_channels=outDim, 220 | kernel_size=1, stride=1, padding=0) 221 | ) 222 | 223 | def forward(self, img): 224 | ylist = self.hrnet(img) 225 | 226 | # Upsampling 227 | x0_h, x0_w = ylist[0].size(2), ylist[0].size(3) 228 | x1 = F.interpolate(ylist[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True) 229 | x2 = F.interpolate(ylist[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True) 230 | x3 = F.interpolate(ylist[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True) 231 | x = torch.cat([ylist[0], x1, x2, x3], 1) 232 | 233 | hms = self.hms_decoder(x) 234 | out = self.dp_decoder(x) 235 | mask = out[:, 0] 236 | dp = out[:, 1:] 237 | 238 | ylist.reverse() 239 | return hms, mask, dp, \ 240 | ylist, None, None 241 | 242 | 243 | class hrnet_mid(nn.Module): 244 | def __init__(self, 245 | model_type, 246 | in_fmapDim=[256, 256, 256, 256], 247 | out_fmapDim=[256, 256, 256, 256]): 248 | super(hrnet_mid, self).__init__() 249 | name = 'w' + model_type[model_type.find('hrnet') + 5:] 250 | assert name in ['w18', 'w18_small_v1', 'w18_small_v2', 'w30', 'w32', 'w40', 'w44', 'w48', 'w64'] 251 | 252 | self.convs = nn.ModuleList() 253 | for i in range(len(out_fmapDim)): 254 | self.convs.append(conv1x1(in_fmapDim[i], out_fmapDim[i])) 255 | 256 | self.global_feature_dim = 2048 257 | self.fmaps_dim = out_fmapDim 258 | 259 | in_fmapDim.reverse() 260 | self.incre_modules, self.downsamp_modules, \ 261 | self.final_layer = self._make_head(in_fmapDim) 262 | 263 | def 
get_info(self): 264 | return {'global_feature_dim': self.global_feature_dim, 265 | 'fmaps_dim': self.fmaps_dim} 266 | 267 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 268 | downsample = None 269 | if stride != 1 or inplanes != planes * block.expansion: 270 | downsample = nn.Sequential( 271 | nn.Conv2d(inplanes, planes * block.expansion, 272 | kernel_size=1, stride=stride, bias=False), 273 | nn.BatchNorm2d(planes * block.expansion, momentum=0.1), 274 | ) 275 | 276 | layers = [] 277 | layers.append(block(inplanes, planes, stride, downsample)) 278 | inplanes = planes * block.expansion 279 | for i in range(1, blocks): 280 | layers.append(block(inplanes, planes)) 281 | 282 | return nn.Sequential(*layers) 283 | 284 | def _make_head(self, pre_stage_channels): 285 | head_block = Bottleneck 286 | head_channels = [32, 64, 128, 256] 287 | 288 | # Increasing the #channels on each resolution 289 | # from C, 2C, 4C, 8C to 128, 256, 512, 1024 290 | incre_modules = [] 291 | for i, channels in enumerate(pre_stage_channels): 292 | incre_module = self._make_layer(head_block, 293 | channels, 294 | head_channels[i], 295 | 1, 296 | stride=1) 297 | incre_modules.append(incre_module) 298 | incre_modules = nn.ModuleList(incre_modules) 299 | 300 | # downsampling modules 301 | downsamp_modules = [] 302 | for i in range(len(pre_stage_channels) - 1): 303 | in_channels = head_channels[i] * head_block.expansion 304 | out_channels = head_channels[i + 1] * head_block.expansion 305 | 306 | downsamp_module = nn.Sequential( 307 | nn.Conv2d(in_channels=in_channels, 308 | out_channels=out_channels, 309 | kernel_size=3, 310 | stride=2, 311 | padding=1), 312 | nn.BatchNorm2d(out_channels, momentum=0.1), 313 | nn.ReLU(inplace=True) 314 | ) 315 | 316 | downsamp_modules.append(downsamp_module) 317 | downsamp_modules = nn.ModuleList(downsamp_modules) 318 | 319 | final_layer = nn.Sequential( 320 | nn.Conv2d( 321 | in_channels=head_channels[3] * head_block.expansion, 322 | out_channels=2048, 323 | kernel_size=1, 324 | stride=1, 325 | padding=0 326 | ), 327 | nn.BatchNorm2d(2048, momentum=0.1), 328 | nn.ReLU(inplace=True) 329 | ) 330 | 331 | return incre_modules, downsamp_modules, final_layer 332 | 333 | def forward(self, img_fmaps, hms_fmaps=None, dp_fmaps=None): 334 | fmaps = [] 335 | for i in range(len(self.convs)): 336 | fmaps.append(self.convs[i](img_fmaps[i])) 337 | 338 | img_fmaps.reverse() 339 | y = self.incre_modules[0](img_fmaps[0]) 340 | for i in range(len(self.downsamp_modules)): 341 | y = self.incre_modules[i + 1](img_fmaps[i + 1]) + \ 342 | self.downsamp_modules[i](y) 343 | 344 | y = self.final_layer(y) 345 | 346 | if torch._C._get_tracing_state(): 347 | y = y.flatten(start_dim=2).mean(dim=2) 348 | else: 349 | y = F.avg_pool2d(y, kernel_size=y.size() 350 | [2:]).view(y.size(0), -1) 351 | 352 | return y, fmaps 353 | 354 | 355 | def load_encoder(cfg): 356 | if cfg.MODEL.ENCODER_TYPE.find('resnet') != -1: 357 | encoder = ResNetSimple(model_type=cfg.MODEL.ENCODER_TYPE, 358 | pretrained=True, 359 | fmapDim=[128, 128, 128, 128], 360 | handNum=2, 361 | heatmapDim=21) 362 | mid_model = resnet_mid(model_type=cfg.MODEL.ENCODER_TYPE, 363 | in_fmapDim=[128, 128, 128, 128], 364 | out_fmapDim=cfg.MODEL.DECONV_DIMS) 365 | if cfg.MODEL.ENCODER_TYPE.find('hrnet') != -1: 366 | encoder = HRnet_encoder(model_type=cfg.MODEL.ENCODER_TYPE, 367 | pretrained=cfg.MODEL.ENCODER_PRETRAIN_PATH, 368 | handNum=2, 369 | heatmapDim=21) 370 | mid_model = hrnet_mid(model_type=cfg.MODEL.ENCODER_TYPE, 371 | 
in_fmapDim=encoder.fmaps_dim, 372 | out_fmapDim=cfg.MODEL.DECONV_DIMS) 373 | 374 | return encoder, mid_model 375 | -------------------------------------------------------------------------------- /models/manolayer.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import torch 4 | from torch.nn import Module 5 | 6 | 7 | def convert_mano_pkl(loadPath, savePath): 8 | # in original MANO pkl file, 'shapedirs' component is a chumpy object, convert it to a numpy array 9 | manoData = pickle.load(open(loadPath, 'rb'), encoding='latin1') 10 | output = {} 11 | manoData['shapedirs'].r 12 | for (k, v) in manoData.items(): 13 | if k == 'shapedirs': 14 | output['shapedirs'] = v.r 15 | else: 16 | output[k] = v 17 | pickle.dump(output, open(savePath, 'wb')) 18 | 19 | 20 | def vec2mat(vec): 21 | # vec: bs * 6 22 | # output: bs * 3 * 3 23 | x = vec[:, 0:3] 24 | y = vec[:, 3:6] 25 | x = x / (torch.norm(x, p=2, dim=1, keepdim=True) + 1e-8) 26 | y = y - torch.sum(x * y, dim=1, keepdim=True) * x 27 | y = y / (torch.norm(y, p=2, dim=1, keepdim=True) + 1e-8) 28 | z = torch.cross(x, y) 29 | return torch.stack([x, y, z], dim=2) 30 | 31 | 32 | def rodrigues_batch(axis): 33 | # axis : bs * 3 34 | # return: bs * 3 * 3 35 | bs = axis.shape[0] 36 | Imat = torch.eye(3, dtype=axis.dtype, device=axis.device).repeat(bs, 1, 1) # bs * 3 * 3 37 | angle = torch.norm(axis, p=2, dim=1, keepdim=True) + 1e-8 # bs * 1 38 | axes = axis / angle # bs * 3 39 | sin = torch.sin(angle).unsqueeze(2) # bs * 1 * 1 40 | cos = torch.cos(angle).unsqueeze(2) # bs * 1 * 1 41 | L = torch.zeros((bs, 3, 3), dtype=axis.dtype, device=axis.device) 42 | L[:, 2, 1] = axes[:, 0] 43 | L[:, 1, 2] = -axes[:, 0] 44 | L[:, 0, 2] = axes[:, 1] 45 | L[:, 2, 0] = -axes[:, 1] 46 | L[:, 1, 0] = axes[:, 2] 47 | L[:, 0, 1] = -axes[:, 2] 48 | return Imat + sin * L + (1 - cos) * L.bmm(L) 49 | 50 | 51 | def get_trans(old_z, new_z): 52 | # z: bs x 3 53 | x = torch.cross(old_z, new_z) 54 | x = x / torch.norm(x, dim=1, keepdim=True) 55 | old_y = torch.cross(old_z, x) 56 | new_y = torch.cross(new_z, x) 57 | old_frame = torch.stack((x, old_y, old_z), axis=2) 58 | new_frame = torch.stack((x, new_y, new_z), axis=2) 59 | trans = torch.matmul(new_frame, old_frame.permute(0, 2, 1)) 60 | return trans 61 | 62 | 63 | def build_mano_frame(skelBatch): 64 | # skelBatch: bs x 21 x 3 65 | bs = skelBatch.shape[0] 66 | mano_son = [2, 3, 17, 5, 6, 18, 8, 9, 20, 11, 12, 19, 14, 15, 16] # 15 67 | mano_parent = [-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11, 0, 13, 14] # 16 68 | palm_idx = [13, 1, 4, 10, 7] 69 | mano_order = [0, 5, 6, 7, 9, 10, 11, 17, 18, 19, 13, 14, 15, 1, 2, 3, 4, 8, 12, 16, 20] # 21 70 | 71 | skel = skelBatch[:, mano_order] 72 | z = skel[:, mano_son] - skel[:, 1:16] # bs x 15 x 3 73 | z = z / torch.norm(z, dim=2, keepdim=True) 74 | z = torch.cat((torch.zeros_like(z[:, 0:1]), z), axis=1) # bs x 16 x 3 75 | x = torch.zeros_like(z) # bs x 16 x 3 76 | x[:, :, 1] = 1.0 77 | y = torch.zeros_like(z) # bs x 16 x 3 78 | 79 | palm = skel[:, palm_idx] - skel[:, 0:1] # bs x 5 x 3 80 | n = torch.cross(palm[:, :-1], palm[:, 1:]) # bs x 4 x 3 81 | n = n / torch.norm(n, dim=2, keepdim=True) 82 | palm_x = torch.zeros((bs, 5, 3), dtype=n.dtype, device=n.device) 83 | palm_x[:, :-1] = palm_x[:, :-1] + n 84 | palm_x[:, 1:] = palm_x[:, 1:] + n 85 | palm_x = palm_x / torch.norm(palm_x, dim=2, keepdim=True) 86 | x[:, palm_idx] = palm_x 87 | 88 | y[:, palm_idx] = torch.cross(z[:, palm_idx], x[:, palm_idx]) 89 | y[:, palm_idx] 
= y[:, palm_idx] / torch.norm(y[:, palm_idx], dim=2, keepdim=True) 90 | x[:, palm_idx] = torch.cross(y[:, palm_idx], z[:, palm_idx]) 91 | frame = torch.stack((x, y, z), axis=3) # bs x 15 x 3 x 3 92 | for i in range(1, 16): 93 | if i in palm_idx: 94 | continue 95 | trans = get_trans(z[:, mano_parent[i]], z[:, i]) 96 | frame[:, i] = torch.matmul(trans, frame[:, mano_parent[i]]) 97 | return frame[:, 1:] 98 | 99 | 100 | class ManoLayer(Module): 101 | def __init__(self, manoPath, center_idx=9, use_pca=True, new_skel=False): 102 | super(ManoLayer, self).__init__() 103 | 104 | self.center_idx = center_idx 105 | self.use_pca = use_pca 106 | self.new_skel = new_skel 107 | 108 | manoData = pickle.load(open(manoPath, 'rb'), encoding='latin1') 109 | 110 | self.new_order = [0, 111 | 13, 14, 15, 16, 112 | 1, 2, 3, 17, 113 | 4, 5, 6, 18, 114 | 10, 11, 12, 19, 115 | 7, 8, 9, 20] 116 | 117 | # 45 * 45: PCA mat 118 | self.register_buffer('hands_components', torch.from_numpy(manoData['hands_components'].astype(np.float32))) 119 | hands_components_inv = torch.inverse(self.hands_components) 120 | self.register_buffer('hands_components_inv', hands_components_inv) 121 | # 16 * 778, J_regressor is a scipy csc matrix 122 | J_regressor = manoData['J_regressor'].tocoo(copy=False) 123 | location = [] 124 | data = [] 125 | for i in range(J_regressor.data.shape[0]): 126 | location.append([J_regressor.row[i], J_regressor.col[i]]) 127 | data.append(J_regressor.data[i]) 128 | i = torch.LongTensor(location) 129 | v = torch.FloatTensor(data) 130 | self.register_buffer('J_regressor', torch.sparse.FloatTensor(i.t(), v, torch.Size([16, 778])).to_dense(), 131 | persistent=False) 132 | # 16 * 3 133 | self.register_buffer('J_zero', torch.from_numpy(manoData['J'].astype(np.float32)), persistent=False) 134 | # 778 * 16 135 | self.register_buffer('weights', torch.from_numpy(manoData['weights'].astype(np.float32)), persistent=False) 136 | # (778, 3, 135) 137 | self.register_buffer('posedirs', torch.from_numpy(manoData['posedirs'].astype(np.float32)), persistent=False) 138 | # (778, 3) 139 | self.register_buffer('v_template', torch.from_numpy(manoData['v_template'].astype(np.float32)), persistent=False) 140 | # (778, 3, 10) shapedirs is 141 | if isinstance(manoData['shapedirs'], np.ndarray): 142 | self.register_buffer('shapedirs', torch.Tensor(manoData['shapedirs']).float(), persistent=False) 143 | else: 144 | self.register_buffer('shapedirs', torch.Tensor(manoData['shapedirs'].r.copy()).float(), persistent=False) 145 | # 45 146 | self.register_buffer('hands_mean', torch.from_numpy(manoData['hands_mean'].astype(np.float32)), persistent=False) 147 | 148 | self.faces = manoData['f'] # 1538 * 3: faces 149 | 150 | self.parent = [-1, ] 151 | for i in range(1, 16): 152 | self.parent.append(manoData['kintree_table'][0, i]) 153 | 154 | def get_faces(self): 155 | return self.faces 156 | 157 | def train(self, mode=True): 158 | self.is_train = mode 159 | 160 | def eval(self): 161 | self.train(False) 162 | 163 | def pca2axis(self, pca): 164 | rotation_axis = pca.mm(self.hands_components[:pca.shape[1]]) # bs * 45 165 | rotation_axis = rotation_axis + self.hands_mean 166 | return rotation_axis # bs * 45 167 | 168 | def pca2Rmat(self, pca): 169 | return self.axis2Rmat(self.pca2axis(pca)) 170 | 171 | def axis2Rmat(self, axis): 172 | # axis: bs x 45 173 | rotation_mat = rodrigues_batch(axis.view(-1, 3)) 174 | rotation_mat = rotation_mat.view(-1, 15, 3, 3) 175 | return rotation_mat 176 | 177 | def axis2pca(self, axis): 178 | # axis: bs x 45 179 | pca = 
axis - self.hands_mean 180 | pca = pca.mm(self.hands_components_inv) 181 | return pca 182 | 183 | def Rmat2pca(self, R): 184 | # R: bs x 15 x 3 x 3 185 | return self.axis2pca(self.Rmat2axis(R)) 186 | 187 | def Rmat2axis(self, R): 188 | # R: bs x 3 x 3 189 | R = R.view(-1, 3, 3) 190 | temp = (R - R.permute(0, 2, 1)) / 2 191 | L = temp[:, [2, 0, 1], [1, 2, 0]] # bs x 3 192 | sin = torch.norm(L, dim=1, keepdim=False) # bs 193 | L = L / (sin.unsqueeze(-1) + 1e-8) 194 | 195 | temp = (R + R.permute(0, 2, 1)) / 2 196 | temp = temp - torch.eye((3), dtype=R.dtype, device=R.device) 197 | temp2 = torch.matmul(L.unsqueeze(-1), L.unsqueeze(1)) 198 | temp2 = temp2 - torch.eye((3), dtype=R.dtype, device=R.device) 199 | temp = temp[:, 0, 0] + temp[:, 1, 1] + temp[:, 2, 2] 200 | temp2 = temp2[:, 0, 0] + temp2[:, 1, 1] + temp2[:, 2, 2] 201 | cos = 1 - temp / (temp2 + 1e-8) # bs 202 | 203 | sin = torch.clamp(sin, min=-1 + 1e-7, max=1 - 1e-7) 204 | theta = torch.asin(sin) 205 | 206 | # prevent in-place operation 207 | theta2 = torch.zeros_like(theta) 208 | theta2[:] = theta 209 | idx1 = (cos < 0) & (sin > 0) 210 | idx2 = (cos < 0) & (sin < 0) 211 | theta2[idx1] = 3.14159 - theta[idx1] 212 | theta2[idx2] = -3.14159 - theta[idx2] 213 | axis = theta2.unsqueeze(-1) * L 214 | 215 | return axis.view(-1, 45) 216 | 217 | def get_local_frame(self, shape): 218 | # output: frame[..., [0,1,2]] = [splay, bend, twist] 219 | # get local joint frame at zero pose 220 | with torch.no_grad(): 221 | shapeBlendShape = torch.matmul(self.shapedirs, shape.permute(1, 0)).permute(2, 0, 1) 222 | v_shaped = self.v_template + shapeBlendShape # bs * 778 * 3 223 | j_tpose = torch.matmul(self.J_regressor, v_shaped) # bs * 16 * 3 224 | j_tpose_21 = torch.cat((j_tpose, v_shaped[:, [744, 320, 444, 555, 672]]), axis=1) 225 | j_tpose_21 = j_tpose_21[:, self.new_order] 226 | frame = build_mano_frame(j_tpose_21) 227 | return frame # bs x 15 x 3 x 3 228 | 229 | @staticmethod 230 | def buildSE3_batch(R, t): 231 | # R: bs * 3 * 3 232 | # t: bs * 3 * 1 233 | # return: bs * 4 * 4 234 | bs = R.shape[0] 235 | pad = torch.zeros((bs, 1, 4), dtype=R.dtype, device=R.device) 236 | pad[:, 0, 3] = 1.0 237 | temp = torch.cat([R, t], 2) # bs * 3 * 4 238 | return torch.cat([temp, pad], 1) 239 | 240 | @staticmethod 241 | def SE3_apply(SE3, v): 242 | # SE3: bs * 4 * 4 243 | # v: bs * 3 244 | # return: bs * 3 245 | bs = v.shape[0] 246 | pad = torch.ones((bs, 1), dtype=v.dtype, device=v.device) 247 | temp = torch.cat([v, pad], 1).unsqueeze(2) # bs * 4 * 1 248 | return SE3.bmm(temp)[:, :3, 0] 249 | 250 | def forward(self, root_rotation, pose, shape, trans=None, scale=None): 251 | # input 252 | # root_rotation : bs * 3 * 3 253 | # pose : bs * ncomps or bs * 15 * 3 * 3 254 | # shape : bs * 10 255 | # trans : bs * 3 or None 256 | # scale : bs or None 257 | bs = root_rotation.shape[0] 258 | 259 | if self.use_pca: 260 | rotation_mat = self.pca2Rmat(pose) 261 | else: 262 | rotation_mat = pose 263 | 264 | shapeBlendShape = torch.matmul(self.shapedirs, shape.permute(1, 0)).permute(2, 0, 1) 265 | v_shaped = self.v_template + shapeBlendShape # bs * 778 * 3 266 | 267 | j_tpose = torch.matmul(self.J_regressor, v_shaped) # bs * 16 * 3 268 | 269 | Imat = torch.eye(3, dtype=rotation_mat.dtype, device=rotation_mat.device).repeat(bs, 15, 1, 1) 270 | pose_shape = rotation_mat.view(bs, -1) - Imat.view(bs, -1) # bs * 135 271 | poseBlendShape = torch.matmul(self.posedirs, pose_shape.permute(1, 0)).permute(2, 0, 1) 272 | v_tpose = v_shaped + poseBlendShape # bs * 778 * 3 273 | 274 | SE3_j = 
[] 275 | R = root_rotation 276 | t = (torch.eye(3, dtype=pose.dtype, device=pose.device).repeat(bs, 1, 1) - R).bmm(j_tpose[:, 0].unsqueeze(2)) 277 | SE3_j.append(self.buildSE3_batch(R, t)) 278 | for i in range(1, 16): 279 | R = rotation_mat[:, i - 1] 280 | t = (torch.eye(3, dtype=pose.dtype, device=pose.device).repeat(bs, 1, 1) - R).bmm(j_tpose[:, i].unsqueeze(2)) 281 | SE3_local = self.buildSE3_batch(R, t) 282 | SE3_j.append(torch.matmul(SE3_j[self.parent[i]], SE3_local)) 283 | SE3_j = torch.stack(SE3_j, dim=1) # bs * 16 * 4 * 4 284 | 285 | j_withoutTips = [] 286 | j_withoutTips.append(j_tpose[:, 0]) 287 | for i in range(1, 16): 288 | j_withoutTips.append(self.SE3_apply(SE3_j[:, self.parent[i]], j_tpose[:, i])) 289 | 290 | # there is no boardcast matmul for sparse matrix for now (pytorch 1.6.0) 291 | SE3_v = torch.matmul(self.weights, SE3_j.view(bs, 16, 16)).view(bs, -1, 4, 4) # bs * 778 * 4 * 4 292 | 293 | v_output = SE3_v[:, :, :3, :3].matmul(v_tpose.unsqueeze(3)) + SE3_v[:, :, :3, 3:4] 294 | v_output = v_output[:, :, :, 0] # bs * 778 * 3 295 | 296 | jList = j_withoutTips + [v_output[:, 745], v_output[:, 317], v_output[:, 444], v_output[:, 556], v_output[:, 673]] 297 | 298 | j_output = torch.stack(jList, dim=1) 299 | j_output = j_output[:, self.new_order] 300 | 301 | if self.center_idx is not None: 302 | center = j_output[:, self.center_idx:(self.center_idx + 1)] 303 | v_output = v_output - center 304 | j_output = j_output - center 305 | 306 | if scale is not None: 307 | scale = scale.unsqueeze(1).unsqueeze(2) # bs * 1 * 1 308 | v_output = v_output * scale 309 | j_output = j_output * scale 310 | 311 | if trans is not None: 312 | trans = trans.unsqueeze(1) # bs * 1 * 3 313 | v_output = v_output + trans 314 | j_output = j_output + trans 315 | 316 | if self.new_skel: 317 | j_output[:, 5] = (v_output[:, 63] + v_output[:, 144]) / 2 318 | j_output[:, 9] = (v_output[:, 271] + v_output[:, 220]) / 2 319 | j_output[:, 13] = (v_output[:, 148] + v_output[:, 290]) / 2 320 | j_output[:, 17] = (v_output[:, 770] + v_output[:, 83]) / 2 321 | 322 | return v_output, j_output 323 | 324 | 325 | if __name__ == '__main__': 326 | convert_mano_pkl('models/MANO_RIGHT.pkl', 'MANO_RIGHT.pkl') 327 | convert_mano_pkl('models/MANO_LEFT.pkl', 'MANO_LEFT.pkl') 328 | 329 | mano = ManoLayer(manoPath='models/MANO_RIGHT.pkl', center_idx=9, use_pca=True) 330 | pose = torch.rand((10, 30)) 331 | shape = torch.rand((10, 10)) 332 | rotation = torch.rand((10, 3)) 333 | root_rotation = rodrigues_batch(rotation) 334 | trans = torch.rand((10, 3)) 335 | scale = torch.rand((10)) 336 | v, j = mano(root_rotation=root_rotation, 337 | pose=pose, 338 | shape=shape, 339 | trans=trans, 340 | scale=scale) 341 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import pickle 9 | import numpy as np 10 | 11 | from dataset.dataset_utils import IMG_SIZE 12 | from models.encoder import load_encoder 13 | from models.decoder import load_decoder 14 | 15 | from utils.config import load_cfg 16 | 17 | 18 | class HandNET_GCN(nn.Module): 19 | def __init__(self, encoder, mid_model, decoder): 20 | super(HandNET_GCN, self).__init__() 21 | self.encoder = encoder 22 | self.mid_model = mid_model 23 | self.decoder = decoder 24 | 25 | 
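# -- minimal inference sketch (paths and preprocessing are assumptions, not part of this
#    file): load_model() below builds encoder + mid_model + decoder and loads
#    MODEL_PARAM.MODEL_PRETRAIN_PATH when that checkpoint exists.
#        model = load_model('misc/model/config.yaml').eval()
#        img = torch.rand(1, 3, 256, 256)      # stand-in for a normalized 256x256 crop
#        with torch.no_grad():
#            result, paramsDict, handDictList, otherInfo = model(img)
#        result['verts3d']['left']              # (1, 778, 3) MANO-ordered vertices
#        otherInfo['mask'], otherInfo['dense']  # encoder's segmentation / dense maps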
def forward(self, img): 26 | hms, mask, dp, img_fmaps, hms_fmaps, dp_fmaps = self.encoder(img) 27 | global_feature, fmaps = self.mid_model(img_fmaps, hms_fmaps, dp_fmaps) 28 | result, paramsDict, handDictList, otherInfo = self.decoder(global_feature, fmaps) 29 | 30 | if hms is not None: 31 | otherInfo['hms'] = hms 32 | if mask is not None: 33 | otherInfo['mask'] = mask 34 | if dp is not None: 35 | otherInfo['dense'] = dp 36 | 37 | return result, paramsDict, handDictList, otherInfo 38 | 39 | 40 | def load_model(cfg): 41 | if isinstance(cfg, str): 42 | cfg = load_cfg(cfg) 43 | encoder, mid_model = load_encoder(cfg) 44 | decoder = load_decoder(cfg, mid_model.get_info()) 45 | model = HandNET_GCN(encoder, mid_model, decoder) 46 | 47 | abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 48 | path = os.path.join(abspath, str(cfg.MODEL_PARAM.MODEL_PRETRAIN_PATH)) 49 | if os.path.exists(path): 50 | state = torch.load(path, map_location='cpu') 51 | print('load model params from {}'.format(path)) 52 | try: 53 | model.load_state_dict(state) 54 | except: 55 | state2 = {} 56 | for k, v in state.items(): 57 | state2[k[7:]] = v 58 | model.load_state_dict(state2) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /models/model_attn/DualGraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from .gcn import GraphLayer 7 | from .img_attn import img_ex 8 | from .inter_attn import inter_attn 9 | 10 | 11 | def graph_upsample(x, p): 12 | if p > 1: 13 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 14 | x = nn.Upsample(scale_factor=p)(x) # B x F x (V*p) 15 | x = x.permute(0, 2, 1).contiguous() # x = B x (V*p) x F 16 | return x 17 | else: 18 | return x 19 | 20 | 21 | class DualGraphLayer(nn.Module): 22 | def __init__(self, 23 | verts_in_dim=256, 24 | verts_out_dim=256, 25 | graph_L_Left=None, 26 | graph_L_Right=None, 27 | graph_k=2, 28 | graph_layer_num=4, 29 | img_size=64, 30 | img_f_dim=256, 31 | grid_size=8, 32 | grid_f_dim=128, 33 | n_heads=4, 34 | dropout=0.01): 35 | super().__init__() 36 | self.verts_num = graph_L_Left.shape[0] 37 | self.verts_in_dim = verts_in_dim 38 | self.img_size = img_size 39 | self.img_f_dim = img_f_dim 40 | 41 | self.position_embeddings = nn.Embedding(self.verts_num, self.verts_in_dim) 42 | 43 | self.graph_left = GraphLayer(verts_in_dim, verts_out_dim, 44 | graph_L_Left, graph_k, graph_layer_num, 45 | dropout) 46 | self.graph_right = GraphLayer(verts_in_dim, verts_out_dim, 47 | graph_L_Right, graph_k, graph_layer_num, 48 | dropout) 49 | 50 | self.img_ex_left = img_ex(img_size, img_f_dim, 51 | grid_size, grid_f_dim, 52 | verts_out_dim, 53 | n_heads=n_heads, 54 | dropout=dropout) 55 | self.img_ex_right = img_ex(img_size, img_f_dim, 56 | grid_size, grid_f_dim, 57 | verts_out_dim, 58 | n_heads=n_heads, 59 | dropout=dropout) 60 | self.attn = inter_attn(verts_out_dim, n_heads=n_heads, dropout=dropout) 61 | 62 | def forward(self, Lf, Rf, img_f): 63 | BS1, V, f = Lf.shape 64 | assert V == self.verts_num 65 | assert f == self.verts_in_dim 66 | BS2, V, f = Rf.shape 67 | assert V == self.verts_num 68 | assert f == self.verts_in_dim 69 | BS3, C, H, W = img_f.shape 70 | assert C == self.img_f_dim 71 | assert H == self.img_size 72 | assert W == self.img_size 73 | assert BS1 == BS2 74 | assert BS2 == BS3 75 | BS = BS1 76 | 77 | position_ids = torch.arange(self.verts_num, dtype=torch.long, 
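# -- note: the same learned per-vertex embedding table is added to both the left and
#    right feature sets below, so corresponding vertices of the two hands carry identical
#    positional codes when they reach the graph, image-attention and inter-hand
#    attention blocks.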
device=Lf.device) 78 | position_ids = position_ids.unsqueeze(0).repeat(BS, 1) 79 | position_embeddings = self.position_embeddings(position_ids) 80 | Lf = Lf + position_embeddings 81 | Rf = Rf + position_embeddings 82 | 83 | Lf = self.graph_left(Lf) 84 | Rf = self.graph_right(Rf) 85 | 86 | Lf = self.img_ex_left(img_f, Lf) 87 | Rf = self.img_ex_right(img_f, Rf) 88 | 89 | Lf, Rf = self.attn(Lf, Rf) 90 | 91 | return Lf, Rf 92 | 93 | 94 | class DualGraph(nn.Module): 95 | def __init__(self, 96 | verts_in_dim=[512, 256, 128], 97 | verts_out_dim=[256, 128, 64], 98 | graph_L_Left=None, 99 | graph_L_Right=None, 100 | graph_k=[2, 2, 2], 101 | graph_layer_num=[4, 4, 4], 102 | img_size=[16, 32, 64], 103 | img_f_dim=[256, 256, 256], 104 | grid_size=[8, 8, 16], 105 | grid_f_dim=[256, 128, 64], 106 | n_heads=4, 107 | dropout=0.01): 108 | super().__init__() 109 | for i in range(len(verts_in_dim) - 1): 110 | assert verts_out_dim[i] == verts_in_dim[i + 1] 111 | for i in range(len(verts_in_dim) - 1): 112 | assert graph_L_Left[i + 1].shape[0] == 2 * graph_L_Left[i].shape[0] 113 | assert graph_L_Right[i + 1].shape[0] == 2 * graph_L_Right[i].shape[0] 114 | 115 | self.layers = nn.ModuleList() 116 | for i in range(len(verts_in_dim)): 117 | self.layers.append(DualGraphLayer(verts_in_dim=verts_in_dim[i], 118 | verts_out_dim=verts_out_dim[i], 119 | graph_L_Left=graph_L_Left[i], 120 | graph_L_Right=graph_L_Right[i], 121 | graph_k=graph_k[i], 122 | graph_layer_num=graph_layer_num[i], 123 | img_size=img_size[i], 124 | img_f_dim=img_f_dim[i], 125 | grid_size=grid_size[i], 126 | grid_f_dim=grid_f_dim[i], 127 | n_heads=n_heads, 128 | dropout=dropout)) 129 | 130 | def forward(self, Lf, Rf, img_f_list): 131 | assert len(img_f_list) == len(self.layers) 132 | for i in range(len(self.layers)): 133 | Lf, Rf = self.layers[i](Lf, Rf, img_f_list[i]) 134 | 135 | if i != len(self.layers) - 1: 136 | Lf = graph_upsample(Lf, 2) 137 | Rf = graph_upsample(Rf, 2) 138 | 139 | return Lf, Rf 140 | -------------------------------------------------------------------------------- /models/model_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .DualGraph import DualGraph 2 | 3 | __all__ = ['DualGraph'] -------------------------------------------------------------------------------- /models/model_attn/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def weights_init(layer): 8 | classname = layer.__class__.__name__ 9 | # print(classname) 10 | if classname.find('Conv2d') != -1: 11 | nn.init.xavier_uniform_(layer.weight.data) 12 | elif classname.find('Linear') != -1: 13 | nn.init.xavier_uniform_(layer.weight.data) 14 | if layer.bias is not None: 15 | nn.init.constant_(layer.bias.data, 0.0) 16 | 17 | 18 | def sparse_python_to_torch(sp_python): 19 | L = sp_python.tocoo() 20 | indices = np.column_stack((L.row, L.col)).T 21 | indices = indices.astype(np.int64) 22 | indices = torch.from_numpy(indices) 23 | indices = indices.type(torch.LongTensor) 24 | L_data = L.data.astype(np.float32) 25 | L_data = torch.from_numpy(L_data) 26 | L_data = L_data.type(torch.FloatTensor) 27 | L = torch.sparse.FloatTensor(indices, L_data, torch.Size(L.shape)) 28 | # if torch.cuda.is_available(): 29 | # L = L.cuda() 30 | 31 | return L 32 | 33 | 34 | def graph_conv_cheby(x, cl, L, K=3): 35 | # parameters 36 | # B = batch size 37 | # V = nb vertices 38 | # Fin = nb input 
features 39 | # Fout = nb output features 40 | # K = Chebyshev order & support size 41 | B, V, Fin = x.size() 42 | B, V, Fin = int(B), int(V), int(Fin) 43 | 44 | # transform to Chebyshev basis 45 | x0 = x.permute(1, 2, 0).contiguous() # V x Fin x B 46 | x0 = x0.view([V, Fin * B]) # V x Fin*B 47 | x = x0.unsqueeze(0) # 1 x V x Fin*B 48 | 49 | def concat(x, x_): 50 | x_ = x_.unsqueeze(0) # 1 x V x Fin*B 51 | return torch.cat((x, x_), 0) # K x V x Fin*B 52 | 53 | if K > 1: 54 | x1 = torch.mm(L, x0) # V x Fin*B 55 | x = torch.cat((x, x1.unsqueeze(0)), 0) # 2 x V x Fin*B 56 | for k in range(2, K): 57 | x2 = 2 * torch.mm(L, x1) - x0 58 | x = torch.cat((x, x2.unsqueeze(0)), 0) # M x Fin*B 59 | x0, x1 = x1, x2 60 | 61 | x = x.view([K, V, Fin, B]) # K x V x Fin x B 62 | x = x.permute(3, 1, 2, 0).contiguous() # B x V x Fin x K 63 | x = x.view([B * V, Fin * K]) # B*V x Fin*K 64 | 65 | # Compose linearly Fin features to get Fout features 66 | x = cl(x) # B*V x Fout 67 | x = x.view([B, V, -1]) # B x V x Fout 68 | 69 | return x 70 | 71 | 72 | class GCN_ResBlock(nn.Module): 73 | # x______________conv + norm (optianal)_____________ x ____activate 74 | # \____conv____activate____norm____conv____norm____/ 75 | def __init__(self, in_dim, out_dim, mid_dim, 76 | graph_L, graph_k, 77 | drop_out=0.01): 78 | super(GCN_ResBlock, self).__init__() 79 | if isinstance(graph_L, np.ndarray): 80 | self.register_buffer('graph_L', 81 | torch.from_numpy(graph_L).float(), 82 | persistent=False) 83 | else: 84 | self.register_buffer('graph_L', 85 | sparse_python_to_torch(graph_L).to_dense(), 86 | persistent=False) 87 | 88 | self.graph_k = graph_k 89 | self.in_dim = in_dim 90 | 91 | self.norm1 = nn.LayerNorm(in_dim, eps=1e-6) 92 | self.fc1 = nn.Linear(in_dim * graph_k, mid_dim) 93 | self.norm2 = nn.LayerNorm(out_dim, eps=1e-6) 94 | self.fc2 = nn.Linear(mid_dim * graph_k, out_dim) 95 | self.dropout = nn.Dropout(drop_out) 96 | self.shortcut = nn.Linear(in_dim, out_dim) 97 | self.norm3 = nn.LayerNorm(out_dim, eps=1e-6) 98 | 99 | def forward(self, x): 100 | # x : B x V x f 101 | assert x.shape[-1] == self.in_dim 102 | 103 | x1 = F.relu(self.norm1(x)) 104 | x1 = graph_conv_cheby(x, self.fc1, self.graph_L, K=self.graph_k) 105 | x1 = F.relu(self.norm2(x1)) 106 | x1 = graph_conv_cheby(x1, self.fc2, self.graph_L, K=self.graph_k) 107 | x1 = self.dropout(x1) 108 | x2 = self.shortcut(x) 109 | 110 | return self.norm3(x1 + x2) 111 | 112 | 113 | class GraphLayer(nn.Module): 114 | def __init__(self, 115 | in_dim=256, 116 | out_dim=256, 117 | graph_L=None, 118 | graph_k=2, 119 | graph_layer_num=3, 120 | drop_out=0.01): 121 | super().__init__() 122 | assert graph_k > 1 123 | 124 | self.GCN_blocks = nn.ModuleList() 125 | self.GCN_blocks.append(GCN_ResBlock(in_dim, out_dim, out_dim, graph_L, graph_k, drop_out)) 126 | for i in range(graph_layer_num - 1): 127 | self.GCN_blocks.append(GCN_ResBlock(out_dim, out_dim, out_dim, graph_L, graph_k, drop_out)) 128 | 129 | for m in self.modules(): 130 | weights_init(m) 131 | 132 | def forward(self, verts_f): 133 | for i in range(len(self.GCN_blocks)): 134 | verts_f = self.GCN_blocks[i](verts_f) 135 | if i != (len(self.GCN_blocks) - 1): 136 | verts_f = F.relu(verts_f) 137 | 138 | return verts_f 139 | -------------------------------------------------------------------------------- /models/model_attn/img_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .self_attn import SelfAttn 6 
| 7 | 8 | def weights_init(layer): 9 | classname = layer.__class__.__name__ 10 | # print(classname) 11 | if classname.find('Conv2d') != -1: 12 | nn.init.xavier_uniform_(layer.weight.data) 13 | elif classname.find('Linear') != -1: 14 | nn.init.xavier_uniform_(layer.weight.data) 15 | if layer.bias is not None: 16 | nn.init.constant_(layer.bias.data, 0.0) 17 | 18 | 19 | class MLP_res_block(nn.Module): 20 | def __init__(self, in_dim, hid_dim, dropout=0.1): 21 | super().__init__() 22 | self.layer_norm = nn.LayerNorm(in_dim, eps=1e-6) 23 | self.fc1 = nn.Linear(in_dim, hid_dim) 24 | self.fc2 = nn.Linear(hid_dim, in_dim) 25 | 26 | self.dropout1 = nn.Dropout(dropout) 27 | self.dropout2 = nn.Dropout(dropout) 28 | 29 | def _ff_block(self, x): 30 | x = self.fc2(self.dropout1(F.relu(self.fc1(x)))) 31 | return self.dropout2(x) 32 | 33 | def forward(self, x): 34 | x = x + self._ff_block(self.layer_norm(x)) 35 | return x 36 | 37 | 38 | class img_feat_to_grid(nn.Module): 39 | def __init__(self, img_size, img_f_dim, grid_size, grid_f_dim, n_heads=4, dropout=0.01): 40 | super().__init__() 41 | self.img_f_dim = img_f_dim 42 | self.img_size = img_size 43 | self.grid_f_dim = grid_f_dim 44 | self.grid_size = grid_size 45 | self.position_embeddings = nn.Embedding(grid_size * grid_size, grid_f_dim) 46 | 47 | patch_size = img_size // grid_size 48 | self.proj = nn.Conv2d(img_f_dim, grid_f_dim, kernel_size=patch_size, stride=patch_size) 49 | self.self_attn = SelfAttn(grid_f_dim, n_heads=n_heads, hid_dim=grid_f_dim, dropout=dropout) 50 | 51 | def forward(self, img): 52 | bs = img.shape[0] 53 | assert img.shape[1] == self.img_f_dim 54 | assert img.shape[2] == self.img_size 55 | assert img.shape[3] == self.img_size 56 | 57 | position_ids = torch.arange(self.grid_size * self.grid_size, dtype=torch.long, device=img.device) 58 | position_ids = position_ids.unsqueeze(0).repeat(bs, 1) 59 | position_embeddings = self.position_embeddings(position_ids) 60 | 61 | grid_feat = F.relu(self.proj(img)) 62 | grid_feat = grid_feat.view(bs, self.grid_f_dim, -1).transpose(-1, -2) 63 | grid_feat = grid_feat + position_embeddings 64 | 65 | grid_feat = self.self_attn(grid_feat) 66 | 67 | return grid_feat 68 | 69 | 70 | class img_attn(nn.Module): 71 | def __init__(self, verts_f_dim, img_f_dim, n_heads=4, d_q=None, d_v=None, dropout=0.1): 72 | super().__init__() 73 | self.img_f_dim = img_f_dim 74 | self.verts_f_dim = verts_f_dim 75 | 76 | self.fc = nn.Linear(img_f_dim, verts_f_dim) 77 | self.Attn = SelfAttn(verts_f_dim, n_heads=n_heads, hid_dim=verts_f_dim, dropout=dropout) 78 | 79 | def forward(self, verts_f, img_f): 80 | assert verts_f.shape[2] == self.verts_f_dim 81 | assert img_f.shape[2] == self.img_f_dim 82 | assert verts_f.shape[0] == img_f.shape[0] 83 | V = verts_f.shape[1] 84 | 85 | img_f = self.fc(img_f) 86 | 87 | x = torch.cat([verts_f, img_f], dim=1) 88 | x = self.Attn(x) 89 | 90 | verts_f = x[:, :V] 91 | 92 | return verts_f 93 | 94 | 95 | class img_ex(nn.Module): 96 | def __init__(self, img_size, img_f_dim, 97 | grid_size, grid_f_dim, 98 | verts_f_dim, 99 | n_heads=4, 100 | dropout=0.01): 101 | super().__init__() 102 | self.verts_f_dim = verts_f_dim 103 | self.encoder = img_feat_to_grid(img_size, img_f_dim, grid_size, grid_f_dim, n_heads, dropout) 104 | self.attn = img_attn(verts_f_dim, grid_f_dim, n_heads=n_heads, dropout=dropout) 105 | 106 | for m in self.modules(): 107 | weights_init(m) 108 | 109 | def forward(self, img, verts_f): 110 | assert verts_f.shape[2] == self.verts_f_dim 111 | grid_feat = self.encoder(img) 112 | 
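        # Shape flow through this module (a sketch inferred from the modules above, not
        # stated explicitly in the file): img is B x img_f_dim x img_size x img_size;
        # img_feat_to_grid patch-embeds it into B x (grid_size*grid_size) x grid_f_dim
        # tokens with learned position embeddings and one self-attention block. img_attn
        # then projects those tokens to verts_f_dim, concatenates them behind the V vertex
        # tokens, runs SelfAttn over the joint sequence of length V + grid_size*grid_size,
        # and keeps only the first V rows as the updated vertex features.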
verts_f = self.attn(verts_f, grid_feat) 113 | return verts_f 114 | -------------------------------------------------------------------------------- /models/model_attn/inter_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .self_attn import SelfAttn 6 | 7 | 8 | def weights_init(layer): 9 | classname = layer.__class__.__name__ 10 | # print(classname) 11 | if classname.find('Conv2d') != -1: 12 | nn.init.xavier_uniform_(layer.weight.data) 13 | elif classname.find('Linear') != -1: 14 | nn.init.xavier_uniform_(layer.weight.data) 15 | if layer.bias is not None: 16 | nn.init.constant_(layer.bias.data, 0.0) 17 | 18 | 19 | class MLP_res_block(nn.Module): 20 | def __init__(self, in_dim, hid_dim, dropout=0.1): 21 | super().__init__() 22 | self.layer_norm = nn.LayerNorm(in_dim, eps=1e-6) 23 | self.fc1 = nn.Linear(in_dim, hid_dim) 24 | self.fc2 = nn.Linear(hid_dim, in_dim) 25 | 26 | self.dropout1 = nn.Dropout(dropout) 27 | self.dropout2 = nn.Dropout(dropout) 28 | 29 | def _ff_block(self, x): 30 | x = self.fc2(self.dropout1(F.relu(self.fc1(x)))) 31 | return self.dropout2(x) 32 | 33 | def forward(self, x): 34 | x = x + self._ff_block(self.layer_norm(x)) 35 | return x 36 | 37 | 38 | class inter_attn(nn.Module): 39 | def __init__(self, f_dim, n_heads=4, d_q=None, d_v=None, dropout=0.1): 40 | super().__init__() 41 | 42 | self.L_self_attn_layer = SelfAttn(f_dim, n_heads=n_heads, hid_dim=f_dim, dropout=dropout) 43 | self.R_self_attn_layer = SelfAttn(f_dim, n_heads=n_heads, hid_dim=f_dim, dropout=dropout) 44 | self.build_inter_attn(f_dim, n_heads, d_q, d_v, dropout) 45 | 46 | for m in self.modules(): 47 | weights_init(m) 48 | 49 | def build_inter_attn(self, f_dim, n_heads=4, d_q=None, d_v=None, dropout=0.1): 50 | if d_q is None: 51 | d_q = f_dim // n_heads 52 | if d_v is None: 53 | d_v = f_dim // n_heads 54 | 55 | self.n_heads = n_heads 56 | self.d_q = d_q 57 | self.d_v = d_v 58 | self.norm = d_q ** 0.5 59 | self.f_dim = f_dim 60 | 61 | self.dropout1 = nn.Dropout(dropout) 62 | self.dropout2 = nn.Dropout(dropout) 63 | self.w_qs = nn.Linear(f_dim, n_heads * d_q) 64 | self.w_ks = nn.Linear(f_dim, n_heads * d_q) 65 | self.w_vs = nn.Linear(f_dim, n_heads * d_v) 66 | self.fc = nn.Linear(n_heads * d_v, f_dim) 67 | 68 | self.layer_norm1 = nn.LayerNorm(f_dim, eps=1e-6) 69 | self.layer_norm2 = nn.LayerNorm(f_dim, eps=1e-6) 70 | self.ffL = MLP_res_block(f_dim, f_dim, dropout) 71 | self.ffR = MLP_res_block(f_dim, f_dim, dropout) 72 | 73 | def inter_attn(self, Lf, Rf, mask_L2R=None, mask_R2L=None): 74 | BS, V, fdim = Lf.shape 75 | assert fdim == self.f_dim 76 | BS, V, fdim = Rf.shape 77 | assert fdim == self.f_dim 78 | 79 | Lf2 = self.layer_norm1(Lf) 80 | Rf2 = self.layer_norm2(Rf) 81 | 82 | Lq = self.w_qs(Lf2).view(BS, V, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 83 | Lk = self.w_ks(Lf2).view(BS, V, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 84 | Lv = self.w_vs(Lf2).view(BS, V, self.n_heads, self.d_v).transpose(1, 2) # BS x h x V x v 85 | 86 | Rq = self.w_qs(Rf2).view(BS, V, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 87 | Rk = self.w_ks(Rf2).view(BS, V, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 88 | Rv = self.w_vs(Rf2).view(BS, V, self.n_heads, self.d_v).transpose(1, 2) # BS x h x V x v 89 | 90 | attn_R2L = torch.matmul(Lq, Rk.transpose(-1, -2)) / self.norm # bs, h, V, V 91 | attn_L2R = torch.matmul(Rq, Lk.transpose(-1, -2)) / 
self.norm # bs, h, V, V 92 | 93 | if mask_L2R is not None: 94 | attn_L2R = attn_L2R.masked_fill(mask_L2R == 0, -1e9) 95 | if mask_R2L is not None: 96 | attn_R2L = attn_R2L.masked_fill(mask_R2L == 0, -1e9) 97 | 98 | attn_R2L = F.softmax(attn_R2L, dim=-1) # bs, h, V, V 99 | attn_L2R = F.softmax(attn_L2R, dim=-1) # bs, h, V, V 100 | 101 | attn_R2L = self.dropout1(attn_R2L) 102 | attn_L2R = self.dropout1(attn_L2R) 103 | 104 | feat_L2R = torch.matmul(attn_L2R, Lv).transpose(1, 2).contiguous().view(BS, V, -1) 105 | feat_R2L = torch.matmul(attn_R2L, Rv).transpose(1, 2).contiguous().view(BS, V, -1) 106 | 107 | feat_L2R = self.dropout2(self.fc(feat_L2R)) 108 | feat_R2L = self.dropout2(self.fc(feat_R2L)) 109 | 110 | Lf = self.ffL(Lf + feat_R2L) 111 | Rf = self.ffR(Rf + feat_L2R) 112 | 113 | return Lf, Rf 114 | 115 | def forward(self, Lf, Rf, mask_L2R=None, mask_R2L=None): 116 | BS, V, fdim = Lf.shape 117 | assert fdim == self.f_dim 118 | BS, V, fdim = Rf.shape 119 | assert fdim == self.f_dim 120 | 121 | Lf = self.L_self_attn_layer(Lf) 122 | Rf = self.R_self_attn_layer(Rf) 123 | Lf, Rf = self.inter_attn(Lf, Rf, mask_L2R, mask_R2L) 124 | 125 | return Lf, Rf 126 | -------------------------------------------------------------------------------- /models/model_attn/self_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def weights_init(layer): 7 | classname = layer.__class__.__name__ 8 | # print(classname) 9 | if classname.find('Conv2d') != -1: 10 | nn.init.xavier_uniform_(layer.weight.data) 11 | elif classname.find('Linear') != -1: 12 | nn.init.xavier_uniform_(layer.weight.data) 13 | if layer.bias is not None: 14 | nn.init.constant_(layer.bias.data, 0.0) 15 | 16 | 17 | class MLP_res_block(nn.Module): 18 | def __init__(self, in_dim, hid_dim, dropout=0.1): 19 | super().__init__() 20 | self.layer_norm = nn.LayerNorm(in_dim, eps=1e-6) 21 | self.fc1 = nn.Linear(in_dim, hid_dim) 22 | self.fc2 = nn.Linear(hid_dim, in_dim) 23 | 24 | self.dropout1 = nn.Dropout(dropout) 25 | self.dropout2 = nn.Dropout(dropout) 26 | 27 | def _ff_block(self, x): 28 | x = self.fc2(self.dropout1(F.relu(self.fc1(x)))) 29 | return self.dropout2(x) 30 | 31 | def forward(self, x): 32 | x = x + self._ff_block(self.layer_norm(x)) 33 | return x 34 | 35 | 36 | class SelfAttn(nn.Module): 37 | def __init__(self, f_dim, hid_dim=None, n_heads=4, d_q=None, d_v=None, dropout=0.1): 38 | super().__init__() 39 | if d_q is None: 40 | d_q = f_dim // n_heads 41 | if d_v is None: 42 | d_v = f_dim // n_heads 43 | if hid_dim is None: 44 | hid_dim = f_dim 45 | 46 | self.n_heads = n_heads 47 | self.d_q = d_q 48 | self.d_v = d_v 49 | self.norm = d_q ** 0.5 50 | self.f_dim = f_dim 51 | 52 | self.dropout1 = nn.Dropout(dropout) 53 | self.dropout2 = nn.Dropout(dropout) 54 | self.w_qs = nn.Linear(f_dim, n_heads * d_q) 55 | self.w_ks = nn.Linear(f_dim, n_heads * d_q) 56 | self.w_vs = nn.Linear(f_dim, n_heads * d_v) 57 | 58 | self.layer_norm = nn.LayerNorm(f_dim, eps=1e-6) 59 | self.fc = nn.Linear(n_heads * d_v, f_dim) 60 | 61 | self.ff = MLP_res_block(f_dim, hid_dim, dropout) 62 | 63 | def self_attn(self, x): 64 | BS, V, f = x.shape 65 | 66 | q = self.w_qs(x).view(BS, -1, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 67 | k = self.w_ks(x).view(BS, -1, self.n_heads, self.d_q).transpose(1, 2) # BS x h x V x q 68 | v = self.w_vs(x).view(BS, -1, self.n_heads, self.d_v).transpose(1, 2) # BS x h x V x v 69 | 70 | attn = 
torch.matmul(q, k.transpose(-1, -2)) / self.norm # bs, h, V, V 71 | attn = F.softmax(attn, dim=-1) # bs, h, V, V 72 | attn = self.dropout1(attn) 73 | 74 | out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(BS, V, -1) 75 | out = self.dropout2(self.fc(out)) 76 | return out 77 | 78 | def forward(self, x): 79 | BS, V, f = x.shape 80 | assert f == self.f_dim 81 | 82 | x = x + self.self_attn(self.layer_norm(x)) 83 | x = self.ff(x) 84 | 85 | return x 86 | -------------------------------------------------------------------------------- /models/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .fc import build_fc_layer 4 | from .hrnet import get_hrnet, Bottleneck 5 | from .coarsening import build_graph 6 | from .graph_utils import graph_upsample, graph_avg_pool 7 | 8 | __all__ = ['build_fc_layer', 'get_hrnet', 'Bottleneck', 9 | 'build_graph', 'GCN_vert_convert', 'graph_upsample', 'graph_avg_pool', 10 | 'weights_init', 'conv1x1', 'conv3x3', 'deconv3x3'] 11 | 12 | 13 | class noop(nn.Module): 14 | def forward(self, x): 15 | return x 16 | 17 | 18 | def build_activate_layer(actType): 19 | if actType == 'relu': 20 | return nn.ReLU(inplace=True) 21 | elif actType == 'lrelu': 22 | return nn.LeakyReLU(0.1, inplace=True) 23 | elif actType == 'elu': 24 | return nn.ELU(inplace=True) 25 | elif actType == 'sigmoid': 26 | return nn.Sigmoid() 27 | elif actType == 'tanh': 28 | return nn.Tanh() 29 | elif actType == 'noop': 30 | return noop() 31 | else: 32 | raise RuntimeError('no such activate layer!') 33 | 34 | 35 | def weights_init(layer): 36 | classname = layer.__class__.__name__ 37 | # print(classname) 38 | if classname.find('Conv2d') != -1: 39 | nn.init.kaiming_normal_(layer.weight.data) 40 | elif classname.find('Linear') != -1: 41 | nn.init.kaiming_normal_(layer.weight.data) 42 | if layer.bias is not None: 43 | nn.init.constant_(layer.bias.data, 0.0) 44 | 45 | 46 | class Flatten(nn.Module): 47 | def forward(self, x): 48 | return x.view(x.size(0), -1) 49 | 50 | 51 | class unFlatten(nn.Module): 52 | def forward(self, x): 53 | return x.view(x.size(0), -1, 1, 1) 54 | 55 | 56 | def conv1x1(in_channels, out_channels, stride=1, bn_init_zero=False, actFun='relu'): 57 | bn = nn.BatchNorm2d(out_channels) 58 | nn.init.constant_(bn.weight, 0. if bn_init_zero else 1.) 59 | layers = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False), 60 | build_activate_layer(actFun), 61 | bn] 62 | return nn.Sequential(*layers) 63 | 64 | 65 | def conv3x3(in_channels, out_channels, stride=1, bn_init_zero=False, actFun='relu'): 66 | bn = nn.BatchNorm2d(out_channels) 67 | nn.init.constant_(bn.weight, 0. if bn_init_zero else 1.) 68 | layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 69 | build_activate_layer(actFun), 70 | bn] 71 | return nn.Sequential(*layers) 72 | 73 | 74 | def deconv3x3(in_channels, out_channels, stride=1, bn_init_zero=False, actFun='relu'): 75 | bn = nn.BatchNorm2d(out_channels) 76 | nn.init.constant_(bn.weight, 0. if bn_init_zero else 1.) 
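# These builders order each stage as Conv2d -> activation -> BatchNorm2d, and deconv3x3
# upsamples bilinearly before a stride-1 3x3 convolution instead of using a transposed
# convolution. A quick shape check (illustrative sketch only):
#
#   import torch
#   up = deconv3x3(256, 128, stride=2)
#   y = up(torch.rand(1, 256, 16, 16))   # -> 1 x 128 x 32 x 32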
77 | return nn.Sequential( 78 | nn.Upsample(scale_factor=stride, mode='bilinear', align_corners=True), 79 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), 80 | build_activate_layer(actFun), 81 | bn 82 | ) 83 | 84 | 85 | class GCN_vert_convert(): 86 | def __init__(self, vertex_num=1, graph_perm_reverse=[0], graph_perm=[0]): 87 | self.graph_perm_reverse = graph_perm_reverse[:vertex_num] 88 | self.graph_perm = graph_perm 89 | 90 | def vert_to_GCN(self, x): 91 | # x: B x v x f 92 | return x[:, self.graph_perm] 93 | 94 | def GCN_to_vert(self, x): 95 | # x: B x v x f 96 | return x[:, self.graph_perm_reverse] 97 | -------------------------------------------------------------------------------- /models/model_zoo/coarsening.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse 3 | from scipy.sparse.linalg import eigsh 4 | import torch 5 | 6 | 7 | # forked from https://github.com/3d-hand-shape/hand-graph-cnn 8 | 9 | 10 | def laplacian(W, normalized=True): 11 | """Return graph Laplacian""" 12 | 13 | # Degree matrix. 14 | d = W.sum(axis=0) 15 | 16 | # Laplacian matrix. 17 | if not normalized: 18 | D = scipy.sparse.diags(d.A.squeeze(), 0) 19 | L = D - W 20 | else: 21 | d += np.spacing(np.array(0, W.dtype)) 22 | d = 1 / np.sqrt(d) 23 | D = scipy.sparse.diags(d.A.squeeze(), 0) 24 | Imat = scipy.sparse.identity(d.size, dtype=W.dtype) 25 | L = Imat - D * W * D 26 | 27 | assert np.abs(L - L.T).mean() < 1e-9 28 | assert type(L) is scipy.sparse.csr.csr_matrix 29 | return L 30 | 31 | 32 | def rescale_L(L, lmax=2): 33 | """Rescale Laplacian eigenvalues to [-1,1]""" 34 | M, M = L.shape 35 | Imat = scipy.sparse.identity(M, format='csr', dtype=L.dtype) 36 | L /= lmax * 2 # L = 2.0*L / lmax 37 | L -= Imat 38 | return L 39 | 40 | 41 | def lmax_L(L): 42 | """Compute largest Laplacian eigenvalue""" 43 | return eigsh(L, k=1, which='LM', return_eigenvectors=False)[0] 44 | 45 | 46 | # graph coarsening with Heavy Edge Matching 47 | # forked from https://github.com/xbresson/spectral_graph_convnets 48 | def coarsen(A, levels): 49 | graphs, parents = HEM(A, levels) 50 | perms = compute_perm(parents) 51 | 52 | adjacencies = [] 53 | laplacians = [] 54 | for i, A in enumerate(graphs): 55 | M, M = A.shape 56 | 57 | if i < levels: 58 | A = perm_adjacency(A, perms[i]) 59 | 60 | A = A.tocsr() 61 | A.eliminate_zeros() 62 | adjacencies.append(A) 63 | # Mnew, Mnew = A.shape 64 | # print('Layer {0}: M_{0} = |V| = {1} nodes ({2} added), |E| = {3} edges'.format(i, Mnew, Mnew - M, A.nnz // 2)) 65 | 66 | L = laplacian(A, normalized=True) 67 | laplacians.append(L) 68 | 69 | return adjacencies, laplacians, perms[0] if len(perms) > 0 else None 70 | 71 | 72 | def HEM(W, levels, rid=None): 73 | """ 74 | Coarsen a graph multiple times using the Heavy Edge Matching (HEM). 75 | 76 | Input 77 | W: symmetric sparse weight (adjacency) matrix 78 | levels: the number of coarsened graphs 79 | 80 | Output 81 | graph[0]: original graph of size N_1 82 | graph[2]: coarser graph of size N_2 < N_1 83 | graph[levels]: coarsest graph of Size N_levels < ... 
< N_2 < N_1 84 | parents[i] is a vector of size N_i with entries ranging from 1 to N_{i+1} 85 | which indicate the parents in the coarser graph[i+1] 86 | nd_sz{i} is a vector of size N_i that contains the size of the supernode in the graph{i} 87 | 88 | Note 89 | if "graph" is a list of length k, then "parents" will be a list of length k-1 90 | """ 91 | 92 | N, N = W.shape 93 | 94 | if rid is None: 95 | rid = np.random.permutation(range(N)) 96 | 97 | ss = np.array(W.sum(axis=0)).squeeze() 98 | rid = np.argsort(ss) 99 | 100 | parents = [] 101 | degree = W.sum(axis=0) - W.diagonal() 102 | graphs = [] 103 | graphs.append(W) 104 | 105 | # print('Heavy Edge Matching coarsening with Xavier version') 106 | 107 | for _ in range(levels): 108 | 109 | # CHOOSE THE WEIGHTS FOR THE PAIRING 110 | # weights = ones(N,1) # metis weights 111 | weights = degree # graclus weights 112 | # weights = supernode_size # other possibility 113 | weights = np.array(weights).squeeze() 114 | 115 | # PAIR THE VERTICES AND CONSTRUCT THE ROOT VECTOR 116 | idx_row, idx_col, val = scipy.sparse.find(W) 117 | cc = idx_row 118 | rr = idx_col 119 | vv = val 120 | 121 | # TO BE SPEEDUP 122 | if not (list(cc) == list(np.sort(cc))): 123 | tmp = cc 124 | cc = rr 125 | rr = tmp 126 | 127 | cluster_id = HEM_one_level(cc, rr, vv, rid, weights) # cc is ordered 128 | parents.append(cluster_id) 129 | 130 | # COMPUTE THE EDGES WEIGHTS FOR THE NEW GRAPH 131 | nrr = cluster_id[rr] 132 | ncc = cluster_id[cc] 133 | nvv = vv 134 | Nnew = cluster_id.max() + 1 135 | # CSR is more appropriate: row,val pairs appear multiple times 136 | W = scipy.sparse.csr_matrix((nvv, (nrr, ncc)), shape=(Nnew, Nnew)) 137 | W.eliminate_zeros() 138 | 139 | # Add new graph to the list of all coarsened graphs 140 | graphs.append(W) 141 | N, N = W.shape 142 | 143 | # COMPUTE THE DEGREE (OMIT OR NOT SELF LOOPS) 144 | degree = W.sum(axis=0) 145 | # degree = W.sum(axis=0) - W.diagonal() 146 | 147 | # CHOOSE THE ORDER IN WHICH VERTICES WILL BE VISTED AT THE NEXT PASS 148 | # [~, rid]=sort(ss); # arthur strategy 149 | # [~, rid]=sort(supernode_size); # thomas strategy 150 | # rid=randperm(N); # metis/graclus strategy 151 | ss = np.array(W.sum(axis=0)).squeeze() 152 | rid = np.argsort(ss) 153 | 154 | return graphs, parents 155 | 156 | 157 | # Coarsen a graph given by rr,cc,vv. rr is assumed to be ordered 158 | def HEM_one_level(rr, cc, vv, rid, weights): 159 | nnz = rr.shape[0] 160 | N = rr[nnz - 1] + 1 161 | 162 | marked = np.zeros(N, bool) 163 | rowstart = np.zeros(N, np.int32) 164 | rowlength = np.zeros(N, np.int32) 165 | cluster_id = np.zeros(N, np.int32) 166 | 167 | oldval = rr[0] 168 | count = 0 169 | clustercount = 0 170 | 171 | for ii in range(nnz): 172 | rowlength[count] = rowlength[count] + 1 173 | if rr[ii] > oldval: 174 | oldval = rr[ii] 175 | rowstart[count + 1] = ii 176 | count = count + 1 177 | 178 | for ii in range(N): 179 | tid = rid[ii] 180 | if not marked[tid]: 181 | wmax = 0.0 182 | rs = rowstart[tid] 183 | marked[tid] = True 184 | bestneighbor = -1 185 | for jj in range(rowlength[tid]): 186 | nid = cc[rs + jj] 187 | if marked[nid]: 188 | tval = 0.0 189 | else: 190 | 191 | # First approach 192 | if 2 == 1: 193 | tval = vv[rs + jj] * (1.0 / weights[tid] + 1.0 / weights[nid]) 194 | 195 | # Second approach 196 | if 1 == 1: 197 | Wij = vv[rs + jj] 198 | Wii = vv[rowstart[tid]] 199 | Wjj = vv[rowstart[nid]] 200 | di = weights[tid] 201 | dj = weights[nid] 202 | tval = (2. * Wij + Wii + Wjj) * 1. 
/ (di + dj + 1e-9) 203 | 204 | if tval > wmax: 205 | wmax = tval 206 | bestneighbor = nid 207 | 208 | cluster_id[tid] = clustercount 209 | 210 | if bestneighbor > -1: 211 | cluster_id[bestneighbor] = clustercount 212 | marked[bestneighbor] = True 213 | 214 | clustercount += 1 215 | 216 | return cluster_id 217 | 218 | 219 | def compute_perm(parents): 220 | """ 221 | Return a list of indices to reorder the adjacency and data matrices so 222 | that the union of two neighbors from layer to layer forms a binary tree. 223 | """ 224 | 225 | # Order of last layer is random (chosen by the clustering algorithm). 226 | indices = [] 227 | if len(parents) > 0: 228 | M_last = max(parents[-1]) + 1 229 | indices.append(list(range(M_last))) 230 | 231 | for parent in parents[::-1]: 232 | 233 | # Fake nodes go after real ones. 234 | pool_singeltons = len(parent) 235 | 236 | indices_layer = [] 237 | for i in indices[-1]: 238 | indices_node = list(np.where(parent == i)[0]) 239 | assert 0 <= len(indices_node) <= 2 240 | 241 | # Add a node to go with a singelton. 242 | if len(indices_node) == 1: 243 | indices_node.append(pool_singeltons) 244 | pool_singeltons += 1 245 | 246 | # Add two nodes as children of a singelton in the parent. 247 | elif len(indices_node) == 0: 248 | indices_node.append(pool_singeltons + 0) 249 | indices_node.append(pool_singeltons + 1) 250 | pool_singeltons += 2 251 | 252 | indices_layer.extend(indices_node) 253 | indices.append(indices_layer) 254 | 255 | # Sanity checks. 256 | for i, indices_layer in enumerate(indices): 257 | M = M_last * 2 ** i 258 | # Reduction by 2 at each layer (binary tree). 259 | assert len(indices[0] == M) 260 | # The new ordering does not omit an indice. 261 | assert sorted(indices_layer) == list(range(M)) 262 | 263 | return indices[::-1] 264 | 265 | 266 | assert (compute_perm([np.array([4, 1, 1, 2, 2, 3, 0, 0, 3]), np.array([2, 1, 0, 1, 0])]) 267 | == [[3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11], [2, 4, 1, 3, 0, 5], [0, 1, 2]]) 268 | 269 | 270 | def perm_adjacency(A, indices): 271 | """ 272 | Permute adjacency matrix, i.e. exchange node ids, 273 | so that binary unions form the clustering tree. 274 | """ 275 | if indices is None: 276 | return A 277 | 278 | M, M = A.shape 279 | Mnew = len(indices) 280 | A = A.tocoo() 281 | 282 | # Add Mnew - M isolated vertices. 283 | rows = scipy.sparse.coo_matrix((Mnew - M, M), dtype=np.float32) 284 | cols = scipy.sparse.coo_matrix((Mnew, Mnew - M), dtype=np.float32) 285 | A = scipy.sparse.vstack([A, rows]) 286 | A = scipy.sparse.hstack([A, cols]) 287 | 288 | # Permute the rows and the columns. 289 | perm = np.argsort(indices) 290 | A.row = np.array(perm)[A.row] 291 | A.col = np.array(perm)[A.col] 292 | 293 | assert np.abs(A - A.T).mean() < 1e-8 # 1e-9 294 | assert type(A) is scipy.sparse.coo.coo_matrix 295 | return A 296 | 297 | 298 | """need to be modified to adapted to F features""" 299 | 300 | 301 | def perm_data(x, indices): 302 | """ 303 | Permute data matrix, i.e. exchange node ids, 304 | so that binary unions form the clustering tree. 305 | """ 306 | if indices is None: 307 | return x 308 | 309 | M, F = x.shape 310 | Mnew = len(indices) 311 | assert Mnew >= M 312 | xnew = np.empty((Mnew, F)) # """need to be modified to adapted to F features""" 313 | for i, j in enumerate(indices): 314 | # Existing vertex, i.e. real data. 315 | if j < M: 316 | xnew[i, :] = x[j, :] 317 | # Fake vertex because of singeltons. 318 | # They will stay 0 so that max pooling chooses the singelton. 319 | # Or -infty ? 
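        # For instance, with the compute_perm example asserted above, indices =
        # [3, 4, 0, 9, 1, 2, 5, 8, 6, 7, 10, 11] maps an x of shape (9, F) to xnew of
        # shape (12, F); ids 9, 10 and 11 exceed M = 9, so output rows 3, 10 and 11 are
        # the "fake" singleton vertices and are filled with zeros here (for max pooling,
        # -inf would be the natural fill value instead, as the comment above suggests).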
320 | else: 321 | """need to be modified to adapted to F features and negative values""" 322 | """np.full((2, 2), -np.inf)""" 323 | xnew[i, :] = np.zeros((F)) # np.full((F), -np.inf) # 324 | return xnew 325 | 326 | 327 | def perm_index_reverse(indices): 328 | indices_reverse = np.copy(indices) 329 | 330 | for i, j in enumerate(indices): 331 | indices_reverse[j] = i 332 | 333 | return indices_reverse 334 | 335 | 336 | def perm_tri(tri, indices): 337 | """ 338 | tri: T x 3 339 | """ 340 | indices_reverse = perm_index_reverse(indices) 341 | tri_new = np.copy(tri) 342 | for i in range(len(tri)): 343 | tri_new[i, 0] = indices_reverse[tri[i, 0]] 344 | tri_new[i, 1] = indices_reverse[tri[i, 1]] 345 | tri_new[i, 2] = indices_reverse[tri[i, 2]] 346 | 347 | return tri_new 348 | 349 | 350 | def build_adj_mat(faces, num_vertex=None): 351 | """ 352 | :param faces: T x 3 353 | :return: adj: sparse matrix, V x V (torch.sparse.FloatTensor) 354 | """ 355 | if num_vertex is None: 356 | num_vertex = np.max(faces) + 1 357 | 358 | num_tri = faces.shape[0] 359 | edges = np.empty((num_tri * 3, 2)) 360 | for i_tri in range(num_tri): 361 | edges[i_tri * 3] = faces[i_tri, :2] 362 | edges[i_tri * 3 + 1] = faces[i_tri, 1:] 363 | edges[i_tri * 3 + 2] = faces[i_tri, [0, 2]] 364 | 365 | adj = scipy.sparse.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), 366 | shape=(num_vertex, num_vertex), dtype=np.float32) 367 | 368 | adj = adj - (adj > 1) * 1.0 369 | 370 | # build symmetric adjacency matrix 371 | adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) 372 | 373 | # adj = normalize_sparse_mx(adj + sp.eye(adj.shape[0])) 374 | # adj = sparse_mx_to_torch_sparse_tensor(adj) 375 | 376 | return adj 377 | 378 | 379 | def cut_perm(perm_list, level, N): 380 | perm = torch.tensor(perm_list) 381 | perm[perm > (N - 1)] = -1 382 | for ll in range(level): 383 | perm = perm.view(-1, 2**(ll + 1)) 384 | start = 0 385 | mid = perm.shape[1] // 2 386 | end = perm.shape[1] 387 | for i in range(perm.shape[0]): 388 | if perm[i, start] == -1: 389 | perm[i, start:mid] = perm[i, mid:end] 390 | if perm[i, mid] == -1: 391 | perm[i, mid:end] = perm[i, start:mid] 392 | perm = perm.view(-1) 393 | perm = perm.tolist() 394 | return perm 395 | 396 | 397 | def build_graph(faces, coarsening_levels=4): 398 | """ 399 | Build graph for Hand Mesh 400 | """ 401 | joints_num = faces.max() + 1 402 | 403 | # Build adj mat 404 | hand_mesh_adj = build_adj_mat(faces, joints_num) 405 | # Compute coarsened graphs 406 | graph_Adj, graph_L, graph_perm = coarsen(hand_mesh_adj, coarsening_levels) 407 | 408 | graph_mask = torch.from_numpy((np.array(graph_perm) < faces.max() + 1).astype(float)).float() 409 | graph_mask = graph_mask # V 410 | 411 | # Compute max eigenvalue of graph Laplacians, rescale Laplacian 412 | graph_lmax = [] 413 | for i in range(coarsening_levels): 414 | graph_lmax.append(lmax_L(graph_L[i])) 415 | graph_L[i] = rescale_L(graph_L[i], graph_lmax[i]) 416 | 417 | graph_perm_reverse = perm_index_reverse(graph_perm) 418 | graph_perm = cut_perm(graph_perm, coarsening_levels, joints_num) 419 | 420 | graph_dict = {'mesh_faces': faces, 421 | 'mesh_adj': hand_mesh_adj, 422 | 'graph_mask': graph_mask, 423 | 'coarsen_graphs_adj': graph_Adj, 424 | 'coarsen_graphs_L': graph_L, 425 | 'graph_perm': graph_perm, 426 | 'graph_perm_reverse': graph_perm_reverse} 427 | 428 | return graph_dict 429 | -------------------------------------------------------------------------------- /models/model_zoo/fc.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class noop(nn.Module): 5 | def forward(self, x): 6 | return x 7 | 8 | 9 | def build_activate_layer(actType): 10 | if actType == 'relu': 11 | return nn.ReLU(inplace=True) 12 | elif actType == 'lrelu': 13 | return nn.LeakyReLU(0.1, inplace=True) 14 | elif actType == 'elu': 15 | return nn.ELU(inplace=True) 16 | elif actType == 'sigmoid': 17 | return nn.Sigmoid() 18 | elif actType == 'tanh': 19 | return nn.Tanh() 20 | elif actType == 'noop': 21 | return noop() 22 | else: 23 | raise RuntimeError('no such activate layer!') 24 | 25 | 26 | def build_fc_layer(inDim, outDim, actFun='relu', dropout_prob=-1, weight_norm=False): 27 | net = [] 28 | if dropout_prob > 0: 29 | net.append(nn.Dropout(p=dropout_prob)) 30 | if weight_norm: 31 | net.append(nn.utils.weight_norm(nn.Linear(inDim, outDim))) 32 | else: 33 | net.append(nn.Linear(inDim, outDim)) 34 | net.append(build_activate_layer(actFun)) 35 | return nn.Sequential(*net) 36 | -------------------------------------------------------------------------------- /models/model_zoo/graph_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # forked from https://github.com/3d-hand-shape/hand-graph-cnn 7 | 8 | 9 | def sparse_python_to_torch(sp_python): 10 | L = sp_python.tocoo() 11 | indices = np.column_stack((L.row, L.col)).T 12 | indices = indices.astype(np.int64) 13 | indices = torch.from_numpy(indices) 14 | indices = indices.type(torch.LongTensor) 15 | L_data = L.data.astype(np.float32) 16 | L_data = torch.from_numpy(L_data) 17 | L_data = L_data.type(torch.FloatTensor) 18 | L = torch.sparse.FloatTensor(indices, L_data, torch.Size(L.shape)) 19 | # if torch.cuda.is_available(): 20 | # L = L.cuda() 21 | 22 | return L 23 | 24 | 25 | def graph_max_pool(x, p): 26 | if p > 1: 27 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 28 | x = nn.MaxPool1d(p)(x) # B x F x V/p 29 | x = x.permute(0, 2, 1).contiguous() # x = B x V/p x F 30 | return x 31 | else: 32 | return x 33 | 34 | 35 | def graph_avg_pool(x, p): 36 | if p > 1: 37 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 38 | x = nn.AvgPool1d(p)(x) # B x F x V/p 39 | x = x.permute(0, 2, 1).contiguous() # x = B x V/p x F 40 | return x 41 | else: 42 | return x 43 | 44 | # Upsampling of size p. 
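# A quick shape sketch for these pooling / upsampling helpers (illustrative only):
#
#   import torch
#   x = torch.rand(2, 64, 32)         # B x V x F
#   x_down = graph_avg_pool(x, 2)     # 2 x 32 x 32, V halved
#   x_up = graph_upsample(x_down, 2)  # 2 x 64 x 32, V doubled by nearest-neighbour repetition
#
# graph_conv_cheby below additionally expects a (dense or sparse) rescaled Laplacian L of
# shape V x V and a linear layer cl mapping Fin*K inputs to Fout outputs.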
45 | 46 | 47 | def graph_upsample(x, p): 48 | if p > 1: 49 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 50 | x = nn.Upsample(scale_factor=p)(x) # B x F x (V*p) 51 | x = x.permute(0, 2, 1).contiguous() # x = B x (V*p) x F 52 | return x 53 | else: 54 | return x 55 | 56 | 57 | def graph_conv_cheby(x, cl, L, K=3): 58 | # parameters 59 | # B = batch size 60 | # V = nb vertices 61 | # Fin = nb input features 62 | # Fout = nb output features 63 | # K = Chebyshev order & support size 64 | B, V, Fin = x.size() 65 | B, V, Fin = int(B), int(V), int(Fin) 66 | 67 | # transform to Chebyshev basis 68 | x0 = x.permute(1, 2, 0).contiguous() # V x Fin x B 69 | x0 = x0.view([V, Fin * B]) # V x Fin*B 70 | x = x0.unsqueeze(0) # 1 x V x Fin*B 71 | 72 | def concat(x, x_): 73 | x_ = x_.unsqueeze(0) # 1 x V x Fin*B 74 | return torch.cat((x, x_), 0) # K x V x Fin*B 75 | 76 | if K > 1: 77 | x1 = torch.mm(L, x0) # V x Fin*B 78 | x = torch.cat((x, x1.unsqueeze(0)), 0) # 2 x V x Fin*B 79 | for k in range(2, K): 80 | x2 = 2 * torch.mm(L, x1) - x0 81 | x = torch.cat((x, x2.unsqueeze(0)), 0) # M x Fin*B 82 | x0, x1 = x1, x2 83 | 84 | x = x.view([K, V, Fin, B]) # K x V x Fin x B 85 | x = x.permute(3, 1, 2, 0).contiguous() # B x V x Fin x K 86 | x = x.view([B * V, Fin * K]) # B*V x Fin*K 87 | 88 | # Compose linearly Fin features to get Fout features 89 | x = cl(x) # B*V x Fout 90 | x = x.view([B, V, -1]) # B x V x Fout 91 | 92 | return x 93 | 94 | 95 | class Graph_CNN_Feat_Mesh(nn.Module): 96 | def __init__(self, num_input_chan, num_mesh_output_chan, graph_L): 97 | print('Graph ConvNet: feature to mesh') 98 | 99 | super(Graph_CNN_Feat_Mesh, self).__init__() 100 | 101 | self.num_input_chan = num_input_chan 102 | self.num_mesh_output_chan = num_mesh_output_chan 103 | self.graph_L = graph_L 104 | 105 | # parameters 106 | self.CL_F = [64, 32, num_mesh_output_chan] 107 | self.CL_K = [3, 3] 108 | self.layers_per_block = [2, 2] 109 | 110 | self.FC_F = [num_input_chan, 512, self.CL_F[0] * self.graph_L[-1].shape[0]] 111 | 112 | self.fc = nn.Sequential() 113 | for fc_id in range(len(self.FC_F) - 1): 114 | if fc_id == 0: 115 | use_activation = True 116 | else: 117 | use_activation = False 118 | self.fc.add_module('fc_%d' % (fc_id + 1), FCLayer(self.FC_F[fc_id], 119 | self.FC_F[fc_id + 1], use_dropout=False, 120 | use_activation=use_activation)) 121 | 122 | _cl = [] 123 | _bn = [] 124 | for block_i in range(len(self.CL_F) - 1): 125 | for layer_i in range(self.layers_per_block[block_i]): 126 | Fin = self.CL_K[block_i] * self.CL_F[block_i] 127 | 128 | if layer_i is not self.layers_per_block[block_i] - 1: 129 | Fout = self.CL_F[block_i] 130 | else: 131 | Fout = self.CL_F[block_i + 1] 132 | 133 | _cl.append(nn.Linear(Fin, Fout)) 134 | 135 | scale = np.sqrt(2.0 / (Fin + Fout)) 136 | _cl[-1].weight.data.uniform_(-scale, scale) 137 | _cl[-1].bias.data.fill_(0.0) 138 | 139 | if block_i == len(self.CL_F) - 2 and layer_i == self.layers_per_block[block_i] - 1: 140 | _bn.append(None) 141 | else: 142 | _bn.append(nn.BatchNorm1d(Fout)) 143 | 144 | self.cl = nn.ModuleList(_cl) 145 | self.bn = nn.ModuleList(_bn) 146 | 147 | # convert scipy sparse matric L to pytorch 148 | for graph_i in range(len(graph_L)): 149 | self.graph_L[graph_i] = sparse_python_to_torch(self.graph_L[graph_i]) 150 | 151 | def init_weights(self, W, Fin, Fout): 152 | scale = np.sqrt(2.0 / (Fin + Fout)) 153 | W.uniform_(-scale, scale) 154 | 155 | return W 156 | 157 | def graph_conv_cheby(self, x, cl, bn, L, Fout, K): 158 | # parameters 159 | # B = batch size 160 | # V = nb 
vertices 161 | # Fin = nb input features 162 | # Fout = nb output features 163 | # K = Chebyshev order & support size 164 | B, V, Fin = x.size() 165 | B, V, Fin = int(B), int(V), int(Fin) 166 | 167 | # transform to Chebyshev basis 168 | x0 = x.permute(1, 2, 0).contiguous() # V x Fin x B 169 | x0 = x0.view([V, Fin * B]) # V x Fin*B 170 | x = x0.unsqueeze(0) # 1 x V x Fin*B 171 | 172 | def concat(x, x_): 173 | x_ = x_.unsqueeze(0) # 1 x V x Fin*B 174 | return torch.cat((x, x_), 0) # K x V x Fin*B 175 | 176 | if K > 1: 177 | x1 = my_sparse_mm()(L, x0) # V x Fin*B 178 | x = torch.cat((x, x1.unsqueeze(0)), 0) # 2 x V x Fin*B 179 | for k in range(2, K): 180 | x2 = 2 * my_sparse_mm()(L, x1) - x0 181 | x = torch.cat((x, x2.unsqueeze(0)), 0) # M x Fin*B 182 | x0, x1 = x1, x2 183 | 184 | x = x.view([K, V, Fin, B]) # K x V x Fin x B 185 | x = x.permute(3, 1, 2, 0).contiguous() # B x V x Fin x K 186 | x = x.view([B * V, Fin * K]) # B*V x Fin*K 187 | 188 | # Compose linearly Fin features to get Fout features 189 | x = cl(x) # B*V x Fout 190 | if bn is not None: 191 | x = bn(x) # B*V x Fout 192 | x = x.view([B, V, Fout]) # B x V x Fout 193 | 194 | return x 195 | 196 | # Upsampling of size p. 197 | def graph_upsample(self, x, p): 198 | if p > 1: 199 | x = x.permute(0, 2, 1).contiguous() # x = B x F x V 200 | x = nn.Upsample(scale_factor=p)(x) # B x F x (V*p) 201 | x = x.permute(0, 2, 1).contiguous() # x = B x (V*p) x F 202 | return x 203 | else: 204 | return x 205 | 206 | def forward(self, x): 207 | # x: B x num_input_chan 208 | x = self.fc(x) 209 | # x: B x (self.CL_F[0] * self.graph_L[-1].shape[0]) 210 | x = x.view(-1, self.graph_L[-1].shape[0], self.CL_F[0]) 211 | # x: B x 80 x 64 212 | 213 | cl_i = 0 214 | for block_i in range(len(self.CL_F) - 1): 215 | x = self.graph_upsample(x, 2) 216 | x = self.graph_upsample(x, 2) 217 | 218 | for layer_i in range(self.layers_per_block[block_i]): 219 | if layer_i is not self.layers_per_block[block_i] - 1: 220 | Fout = self.CL_F[block_i] 221 | else: 222 | Fout = self.CL_F[block_i + 1] 223 | 224 | x = self.graph_conv_cheby(x, self.cl[cl_i], self.bn[cl_i], self.graph_L[-(block_i * 2 + 3)], 225 | # 2 - block_i*2], 226 | Fout, self.CL_K[block_i]) 227 | if block_i is not len(self.CL_F) - 2 or layer_i is not self.layers_per_block[block_i] - 1: 228 | x = F.relu(x) 229 | 230 | cl_i = cl_i + 1 231 | 232 | return x # x: B x 1280 x 3 233 | -------------------------------------------------------------------------------- /utils/DataProvider.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch.utils.data.distributed import DistributedSampler 3 | 4 | 5 | class DataProvider: 6 | def __init__(self, dataset, batch_size, num_workers=1, dist=False): 7 | self.batch_size = batch_size 8 | self.dataset = dataset 9 | self.dataiter = None 10 | self.iteration = 0 11 | self.epoch = 0 12 | self.num_workers = num_workers 13 | self.dist = dist 14 | self.build() 15 | 16 | def build(self): 17 | if self.dist: 18 | sampler = DistributedSampler(self.dataset, shuffle=True, drop_last=True) 19 | dataloader = DataLoader(self.dataset, batch_size=self.batch_size, sampler=sampler, 20 | num_workers=self.num_workers, drop_last=True, pin_memory=True) 21 | else: 22 | dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, 23 | num_workers=self.num_workers, drop_last=True, pin_memory=True) 24 | self.batch_per_epoch = len(dataloader) 25 | self.dataiter = dataloader.__iter__() 26 | 27 | def 
next(self): 28 | if self.dataiter is None: 29 | self.build() 30 | try: 31 | self.iteration += 1 32 | return self.dataiter.next() 33 | 34 | except StopIteration: 35 | self.epoch += 1 36 | self.build() 37 | self.iteration = 1 38 | return self.dataiter.next() 39 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | import os 3 | 4 | _C = CN(new_allowed=True) 5 | 6 | 7 | def get_cfg_defaults(): 8 | """Get a yacs CfgNode object with default values for my_project.""" 9 | # Return a clone so that the defaults will not be altered 10 | # This is for the "local variable" use pattern 11 | defaults_abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), 'defaults.yaml')) 12 | _C.merge_from_file(defaults_abspath) 13 | _C.set_new_allowed(False) 14 | return _C.clone() 15 | 16 | 17 | def load_cfg(path=None): 18 | cfg = get_cfg_defaults() 19 | if path is not None: 20 | cfg.merge_from_file(path) 21 | return cfg 22 | 23 | # Alternatively, provide a way to import the defaults as 24 | # a global singleton: 25 | # cfg = _C # users can `from config import cfg` 26 | 27 | 28 | # if __name__ == '__main__': 29 | # cfg = get_cfg_defaults() 30 | # with open('test.yaml', 'w') as file: 31 | # file.write(cfg.dump()) 32 | -------------------------------------------------------------------------------- /utils/defaults.yaml: -------------------------------------------------------------------------------- 1 | SEED: 25 2 | MISC: 3 | MANO_PATH: "misc/mano" 4 | GRAPH_LEFT_DICT_PATH: "misc/graph_left.pkl" 5 | GRAPH_RIGHT_DICT_PATH: "misc/graph_right.pkl" 6 | DENSE_COLOR: "misc/v_color.pkl" 7 | MANO_SEG_PATH: "misc/mano_seg.pkl" 8 | UPSAMPLE_PATH: "misc/upsample.pkl" 9 | MODEL: 10 | ENCODER_TYPE: "resnet50" 11 | DECONV_DIMS: [256, 256, 256, 256] 12 | IMG_DIMS: [256, 128, 64] 13 | GCN_IN_DIM: [512, 256, 128] 14 | GCN_OUT_DIM: [256, 128, 64] 15 | ENCODER_PRETRAIN_PATH: "none" 16 | freeze_upsample: True 17 | graph_k: 2 18 | graph_layer_num: 4 19 | MODEL_PARAM: 20 | MODEL_PRETRAIN_PATH: "none" 21 | OPTIM_PATH: "none" 22 | LrSc_PATH: "none" 23 | DATASET: 24 | INTERHAND_PATH: "./interhand2.6m/" 25 | DATA_AUGMENT: 26 | THETA: 90 27 | SCALE: 0.1 28 | UV: 0.0 29 | TRAIN: 30 | DIST_PORT: 12345 31 | OPTIM: 'adam' 32 | current_epoch: 0 33 | lr_decay_step: 150 34 | lr_decay_gamma: 0.1 35 | warm_up: 3 36 | EPOCHS: 200 37 | BATCH_SIZE: 64 38 | LR: 1.0e-4 39 | dropout: 0.05 40 | LOSS_WEIGHT: 41 | AUX: 42 | DENSEPOSE: 30 43 | MASK: 500 44 | HMS: 100 45 | DATA: 46 | LABEL_3D: 100 47 | LABEL_2D: 50 48 | GRAPH: 49 | NORM: 50 | EDGE: 2000 51 | NORMAL: 10 52 | NORM_EPOCH: 50 53 | NORM: 54 | UPSAMPLE: 1.0 55 | TB: 56 | SHOW_GAP: 200 57 | SAVE_DIR: "./output/log/exp" 58 | SAVE: 59 | SAVE_GAP: 10 60 | SAVE_DIR: "./output/model/exp" 61 | 62 | -------------------------------------------------------------------------------- /utils/lr_sc.py: -------------------------------------------------------------------------------- 1 | from functools import partial, wraps 2 | import warnings 3 | import math 4 | from torch.optim.optimizer import Optimizer 5 | 6 | # forked form : https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html 7 | 8 | class _LRScheduler(object): 9 | ''' 10 | from pytorch 1.2.0 (torch.optim.lr_sceduler) 11 | ''' 12 | 13 | def __init__(self, optimizer, last_epoch=-1): 14 | if not isinstance(optimizer, Optimizer): 15 | raise TypeError('{} is not an 
Optimizer'.format( 16 | type(optimizer).__name__)) 17 | self.optimizer = optimizer 18 | if last_epoch == -1: 19 | for group in optimizer.param_groups: 20 | group.setdefault('initial_lr', group['lr']) 21 | last_epoch = 0 22 | else: 23 | for i, group in enumerate(optimizer.param_groups): 24 | if 'initial_lr' not in group: 25 | raise KeyError("param 'initial_lr' is not specified " 26 | "in param_groups[{}] when resuming an optimizer".format(i)) 27 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 28 | self.last_epoch = last_epoch 29 | 30 | # Following https://github.com/pytorch/pytorch/issues/20124 31 | # We would like to ensure that `lr_scheduler.step()` is called after 32 | # `optimizer.step()` 33 | def with_counter(func, opt): 34 | @wraps(func) 35 | def wrapper(*args, **kwargs): 36 | opt._step_count += 1 37 | return func(*args, **kwargs) 38 | wrapper._with_counter = True 39 | return wrapper 40 | 41 | self.optimizer.step = with_counter(self.optimizer.step, self.optimizer) 42 | self.optimizer._step_count = 0 43 | self._step_count = 0 44 | self.step(last_epoch) 45 | 46 | def state_dict(self): 47 | """Returns the state of the scheduler as a :class:`dict`. 48 | 49 | It contains an entry for every variable in self.__dict__ which 50 | is not the optimizer. 51 | """ 52 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 53 | 54 | def load_state_dict(self, state_dict): 55 | """Loads the schedulers state. 56 | 57 | Arguments: 58 | state_dict (dict): scheduler state. Should be an object returned 59 | from a call to :meth:`state_dict`. 60 | """ 61 | self.__dict__.update(state_dict) 62 | 63 | def get_lr(self): 64 | raise NotImplementedError 65 | 66 | def step(self, epoch=None): 67 | # Raise a warning if old pattern is detected 68 | # https://github.com/pytorch/pytorch/issues/20124 69 | if self._step_count == 1: 70 | if not hasattr(self.optimizer.step, "_with_counter"): 71 | warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " 72 | "initialization. Please, make sure to call `optimizer.step()` before " 73 | "`lr_scheduler.step()`. See more details at " 74 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 75 | 76 | # Just check if there were two first lr_scheduler.step() calls before optimizer.step() 77 | elif self.optimizer._step_count < 1: 78 | warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " 79 | "In PyTorch 1.1.0 and later, you should call them in the opposite order: " 80 | "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " 81 | "will result in PyTorch skipping the first value of the learning rate schedule." 
82 | "See more details at " 83 | "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) 84 | self._step_count += 1 85 | 86 | if epoch is None: 87 | epoch = self.last_epoch + 1 88 | self.last_epoch = epoch 89 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 90 | param_group['lr'] = lr 91 | 92 | 93 | class warmUpScheduler(_LRScheduler): 94 | def __init__(self, optimizer, last_epoch=-1, warm_up_epoch=1000): 95 | self.warm_up_epoch = warm_up_epoch 96 | super(warmUpScheduler, self).__init__(optimizer, last_epoch) 97 | 98 | def get_lr(self): 99 | if self.last_epoch < self.warm_up_epoch: 100 | return [base_lr * (self.last_epoch / self.warm_up_epoch) 101 | for base_lr in self.base_lrs] 102 | else: 103 | return self.base_lrs 104 | 105 | 106 | class StepLR_withWarmUp(_LRScheduler): 107 | def __init__(self, optimizer, last_epoch=-1, init_lr=1e-5, warm_up_epoch=1000, gamma=1, step_size=1000, min_thres=0): 108 | self.step_size = step_size 109 | self.gamma = gamma 110 | self.warm_up_epoch = warm_up_epoch 111 | self.min_thres = min_thres 112 | self.init_lr = init_lr 113 | super(StepLR_withWarmUp, self).__init__(optimizer, last_epoch) 114 | 115 | def get_lr(self): 116 | if self.last_epoch < self.warm_up_epoch: 117 | return [self.init_lr + (base_lr - self.init_lr) * (self.last_epoch / self.warm_up_epoch) 118 | for base_lr in self.base_lrs] 119 | else: 120 | return [base_lr * max(self.gamma ** ((self.last_epoch - self.warm_up_epoch) // self.step_size), self.min_thres) 121 | for base_lr in self.base_lrs] 122 | 123 | 124 | class SGDR_withWarmUp(object): 125 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, warm_up_epoch=1): 126 | if not isinstance(optimizer, Optimizer): 127 | raise TypeError('{} is not an Optimizer'.format( 128 | type(optimizer).__name__)) 129 | if T_0 <= 0 or not isinstance(T_0, int): 130 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 131 | if T_mult < 1 or not isinstance(T_mult, int): 132 | raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) 133 | 134 | self.optimizer = optimizer 135 | if last_epoch == -1: 136 | for group in optimizer.param_groups: 137 | group.setdefault('initial_lr', group['lr']) 138 | last_epoch = 0 139 | else: 140 | for i, group in enumerate(optimizer.param_groups): 141 | if 'initial_lr' not in group: 142 | raise KeyError("param 'initial_lr' is not specified " 143 | "in param_groups[{}] when resuming an optimizer".format(i)) 144 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 145 | self.last_epoch = last_epoch 146 | 147 | # Following https://github.com/pytorch/pytorch/issues/20124 148 | # We would like to ensure that `lr_scheduler.step()` is called after 149 | # `optimizer.step()` 150 | def with_counter(func, opt): 151 | @wraps(func) 152 | def wrapper(*args, **kwargs): 153 | opt._step_count += 1 154 | return func(*args, **kwargs) 155 | wrapper._with_counter = True 156 | return wrapper 157 | 158 | self.optimizer.step = with_counter(self.optimizer.step, self.optimizer) 159 | self.optimizer._step_count = 0 160 | self._step_count = 0 161 | 162 | self.T_0 = T_0 163 | self.T_i = T_0 164 | self.T_mult = T_mult 165 | self.eta_min = eta_min 166 | self.warm_up_epoch = warm_up_epoch 167 | self.T_cur = self.last_epoch 168 | 169 | self.step() 170 | 171 | def state_dict(self): 172 | """Returns the state of the scheduler as a :class:`dict`. 
173 | 174 | It contains an entry for every variable in self.__dict__ which 175 | is not the optimizer. 176 | """ 177 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 178 | 179 | def load_state_dict(self, state_dict): 180 | """Loads the schedulers state. 181 | 182 | Arguments: 183 | state_dict (dict): scheduler state. Should be an object returned 184 | from a call to :meth:`state_dict`. 185 | """ 186 | self.__dict__.update(state_dict) 187 | 188 | def get_lr(self): 189 | if self.last_epoch < self.warm_up_epoch: 190 | return [base_lr * (self.last_epoch / self.warm_up_epoch) 191 | for base_lr in self.base_lrs] 192 | else: 193 | return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 194 | for base_lr in self.base_lrs] 195 | 196 | def step(self): 197 | self.last_epoch += 1 198 | if self.last_epoch < self.warm_up_epoch: 199 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 200 | param_group['lr'] = lr 201 | else: 202 | self.T_cur = self.T_cur + 1 203 | if self.T_cur >= self.T_i: 204 | self.T_cur = self.T_cur - self.T_i 205 | self.T_i = self.T_i * self.T_mult 206 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 207 | param_group['lr'] = lr 208 | -------------------------------------------------------------------------------- /utils/tb_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 as cv 4 | 5 | 6 | MANO_PARENT = [-1, 0, 1, 2, 3, 7 | 0, 5, 6, 7, 8 | 0, 9, 10, 11, 9 | 0, 13, 14, 15, 10 | 0, 17, 18, 19] 11 | 12 | MANO_COLOR = [[100, 100, 100], 13 | [100, 0, 0], 14 | [150, 0, 0], 15 | [200, 0, 0], 16 | [255, 0, 0], 17 | [100, 100, 0], 18 | [150, 150, 0], 19 | [200, 200, 0], 20 | [255, 255, 0], 21 | [0, 100, 50], 22 | [0, 150, 75], 23 | [0, 200, 100], 24 | [0, 255, 125], 25 | [0, 50, 100], 26 | [0, 75, 150], 27 | [0, 100, 200], 28 | [0, 125, 255], 29 | [100, 0, 100], 30 | [150, 0, 150], 31 | [200, 0, 200], 32 | [255, 0, 255]] 33 | 34 | 35 | def draw_mano_joints(img, joints): 36 | for i in range(21): 37 | cv.circle(img, (int(joints[i, 0]), int(joints[i, 1])), 38 | 2, tuple(MANO_COLOR[i]), -1) 39 | for i in range(1, 21): 40 | cv.line(img, 41 | (int(joints[i, 0]), int(joints[i, 1])), 42 | (int(joints[MANO_PARENT[i], 0]), int(joints[MANO_PARENT[i], 1])), 43 | tuple(MANO_COLOR[i]), 44 | 2) 45 | return img 46 | 47 | 48 | class tbUtils(): 49 | @ staticmethod 50 | def draw_verts(writer, name, idx, 51 | imgTensor, vertTensor, 52 | color=(0, 0, 255), 53 | BGR=True, CHW=True): 54 | with torch.no_grad(): 55 | img = torch.clamp(imgTensor, 0, 1) * 255 56 | img = img.detach().cpu().numpy().astype(np.uint8) 57 | 58 | if CHW: 59 | img = img.transpose(1, 2, 0) 60 | img = img.copy() 61 | 62 | if not isinstance(vertTensor, list): 63 | vertTensor = [vertTensor] 64 | if not isinstance(color, list): 65 | color = [color] 66 | 67 | for j in range(len(vertTensor)): 68 | verts2d = vertTensor[j].detach().cpu().long().numpy() # N x 2 69 | for i in range(verts2d.shape[0]): 70 | cv.circle(img, (verts2d[i, 0], verts2d[i, 1]), 1, color[j]) 71 | 72 | if BGR: 73 | img = cv.cvtColor(img, cv.COLOR_BGR2RGB) 74 | 75 | writer.add_image(name, 76 | torch.from_numpy(img).float() / 255, 77 | idx, 78 | dataformats='HWC') 79 | 80 | def draw_MANO_joints(writer, name, idx, 81 | imgTensor, jointsTensor, 82 | BGR=True, CHW=True): 83 | with torch.no_grad(): 84 | img = torch.clamp(imgTensor, 0, 1) * 255 85 | img = 
img.detach().cpu().numpy().astype(np.uint8) 86 | 87 | if CHW: 88 | img = img.transpose(1, 2, 0) 89 | 90 | joints2d = jointsTensor.detach().cpu().long().numpy() # 21 x 2 91 | img = img.copy() 92 | 93 | img = draw_mano_joints(img, joints2d) 94 | 95 | if BGR: 96 | img = cv.cvtColor(img, cv.COLOR_BGR2RGB) 97 | 98 | writer.add_image(name, 99 | torch.from_numpy(img).float() / 255, 100 | idx, 101 | dataformats='HWC') 102 | 103 | @ staticmethod 104 | def add_image(writer, name, idx, 105 | imgTensor, dataformats='HW', clamp=False): 106 | if clamp: 107 | imgTensor = torch.clamp(imgTensor, 0, 1) 108 | writer.add_image(name, 109 | imgTensor.float(), 110 | idx, 111 | dataformats=dataformats) 112 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | import cv2 as cv 5 | import pickle 6 | import torch 7 | 8 | import os 9 | import sys 10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 11 | 12 | 13 | from utils.config import get_cfg_defaults 14 | from models.model_zoo import build_graph 15 | 16 | 17 | def projection(scale, trans2d, label3d, img_size=256): 18 | scale = scale * img_size 19 | trans2d = trans2d * img_size / 2 + img_size / 2 20 | trans2d = trans2d 21 | 22 | label2d = scale * label3d[:, :2] + trans2d 23 | return label2d 24 | 25 | 26 | def projection_batch(scale, trans2d, label3d, img_size=256): 27 | """orthographic projection 28 | Input: 29 | scale: (B) 30 | trans2d: (B, 2) 31 | label3d: (B x N x 3) 32 | Returns: 33 | (B, N, 2) 34 | """ 35 | scale = scale * img_size # bs 36 | if scale.dim() == 1: 37 | scale = scale.unsqueeze(-1).unsqueeze(-1) 38 | if scale.dim() == 2: 39 | scale = scale.unsqueeze(-1) 40 | trans2d = trans2d * img_size / 2 + img_size / 2 # bs x 2 41 | trans2d = trans2d.unsqueeze(1) 42 | 43 | label2d = scale * label3d[..., :2] + trans2d 44 | return label2d 45 | 46 | 47 | def projection_batch_np(scale, trans2d, label3d, img_size=256): 48 | """orthographic projection (numpy version) 49 | Input: 50 | scale: (B) 51 | trans2d: (B, 2) 52 | label3d: (B x N x 3) 53 | Returns: 54 | (B, N, 2) 55 | """ 56 | scale = scale * img_size # bs 57 | if scale.ndim == 1: 58 | scale = scale[..., np.newaxis, np.newaxis] 59 | if scale.ndim == 2: 60 | scale = scale[..., np.newaxis] 61 | trans2d = trans2d * img_size / 2 + img_size / 2 # bs x 2 62 | trans2d = trans2d[:, np.newaxis, :] 63 | 64 | label2d = scale * label3d[..., :2] + trans2d 65 | return label2d 66 | 67 | 68 | def get_mano_path(): 69 | cfg = get_cfg_defaults() 70 | abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 71 | path = os.path.join(abspath, cfg.MISC.MANO_PATH) 72 | mano_path = {'left': os.path.join(path, 'MANO_LEFT.pkl'), 73 | 'right': os.path.join(path, 'MANO_RIGHT.pkl')} 74 | return mano_path 75 | 76 | 77 | def get_graph_dict_path(): 78 | cfg = get_cfg_defaults() 79 | abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 80 | graph_path = {'left': os.path.join(abspath, cfg.MISC.GRAPH_LEFT_DICT_PATH), 81 | 'right': os.path.join(abspath, cfg.MISC.GRAPH_RIGHT_DICT_PATH)} 82 | return graph_path 83 | 84 | 85 | def get_dense_color_path(): 86 | cfg = get_cfg_defaults() 87 | abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 88 | dense_path = os.path.join(abspath, cfg.MISC.DENSE_COLOR) 89 | return dense_path 90 | 91 | 92 | def get_mano_seg_path(): 93 | cfg = get_cfg_defaults() 94 |
abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 95 | seg_path = os.path.join(abspath, cfg.MISC.MANO_SEG_PATH) 96 | return seg_path 97 | 98 | 99 | def get_upsample_path(): 100 | cfg = get_cfg_defaults() 101 | abspath = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 102 | upsample_path = os.path.join(abspath, cfg.MISC.UPSAMPLE_PATH) 103 | return upsample_path 104 | 105 | 106 | def build_mano_graph(): 107 | graph_path = get_graph_dict_path() 108 | mano_path = get_mano_path() 109 | for hand_type in ['left', 'right']: 110 | if not os.path.exists(graph_path[hand_type]): 111 | manoData = pickle.load(open(mano_path[hand_type], 'rb'), encoding='latin1') 112 | faces = manoData['f'] 113 | graph_dict = build_graph(faces, coarsening_levels=4) 114 | with open(graph_path[hand_type], 'wb') as file: 115 | pickle.dump(graph_dict, file) 116 | 117 | 118 | class imgUtils(): 119 | @ staticmethod 120 | def pad2squre(img, color=None): 121 | if img.shape[0] > img.shape[1]: 122 | W = img.shape[0] - img.shape[1] 123 | else: 124 | W = img.shape[1] - img.shape[0] 125 | W1 = int(W / 2) 126 | W2 = W - W1 127 | if color is None: 128 | if img.shape[2] == 3: 129 | color = (0, 0, 0) 130 | else: 131 | color = 0 132 | if img.shape[0] > img.shape[1]: 133 | return cv.copyMakeBorder(img, 0, 0, W1, W2, cv.BORDER_CONSTANT, value=color) 134 | else: 135 | return cv.copyMakeBorder(img, W1, W2, 0, 0, cv.BORDER_CONSTANT, value=color) 136 | 137 | @ staticmethod 138 | def cut2squre(img): 139 | if img.shape[0] > img.shape[1]: 140 | s = int((img.shape[0] - img.shape[1]) / 2) 141 | return img[s:(s + img.shape[1])] 142 | else: 143 | s = int((img.shape[1] - img.shape[0]) / 2) 144 | return img[:, s:(s + img.shape[0])] 145 | 146 | @ staticmethod 147 | def get_scale_mat(center, scale=1.0): 148 | scaleMat = np.zeros((3, 3), dtype='float32') 149 | scaleMat[0, 0] = scale 150 | scaleMat[1, 1] = scale 151 | scaleMat[2, 2] = 1.0 152 | t = np.matmul((np.identity(3, dtype='float32') - scaleMat), center) 153 | scaleMat[0, 2] = t[0] 154 | scaleMat[1, 2] = t[1] 155 | return scaleMat 156 | 157 | @ staticmethod 158 | def get_rotation_mat(center, theta=0): 159 | t = theta * (3.14159 / 180) 160 | rotationMat = np.zeros((3, 3), dtype='float32') 161 | rotationMat[0, 0] = math.cos(t) 162 | rotationMat[0, 1] = -math.sin(t) 163 | rotationMat[1, 0] = math.sin(t) 164 | rotationMat[1, 1] = math.cos(t) 165 | rotationMat[2, 2] = 1.0 166 | t = np.matmul((np.identity(3, dtype='float32') - rotationMat), center) 167 | rotationMat[0, 2] = t[0] 168 | rotationMat[1, 2] = t[1] 169 | return rotationMat 170 | 171 | @ staticmethod 172 | def get_rotation_mat3d(theta=0): 173 | t = theta * (3.14159 / 180) 174 | rotationMat = np.zeros((3, 3), dtype='float32') 175 | rotationMat[0, 0] = math.cos(t) 176 | rotationMat[0, 1] = -math.sin(t) 177 | rotationMat[1, 0] = math.sin(t) 178 | rotationMat[1, 1] = math.cos(t) 179 | rotationMat[2, 2] = 1.0 180 | return rotationMat 181 | 182 | @ staticmethod 183 | def get_affine_mat(theta=0, scale=1.0, 184 | u=0, v=0, 185 | height=480, width=640): 186 | center = np.array([width / 2, height / 2, 1], dtype='float32') 187 | rotationMat = imgUtils.get_rotation_mat(center, theta) 188 | scaleMat = imgUtils.get_scale_mat(center, scale) 189 | trans = np.identity(3, dtype='float32') 190 | trans[0, 2] = u 191 | trans[1, 2] = v 192 | affineMat = np.matmul(scaleMat, rotationMat) 193 | affineMat = np.matmul(trans, affineMat) 194 | return affineMat 195 | 196 | @staticmethod 197 | def img_trans(theta, scale, u, v, img): 198 | 
size = img.shape[0] 199 | u = int(u * size / 2) 200 | v = int(v * size / 2) 201 | affineMat = imgUtils.get_affine_mat(theta=theta, scale=scale, 202 | u=u, v=v, 203 | height=256, width=256) 204 | return cv.warpAffine(src=img, 205 | M=affineMat[0:2, :], 206 | dsize=(256, 256), 207 | dst=img, 208 | flags=cv.INTER_LINEAR, 209 | borderMode=cv.BORDER_REPLICATE, 210 | borderValue=(0, 0, 0) 211 | ) 212 | 213 | @staticmethod 214 | def data_augmentation(theta, scale, u, v, 215 | img_list=None, label2d_list=None, label3d_list=None, 216 | R=None, 217 | img_size=224): 218 | affineMat = imgUtils.get_affine_mat(theta=theta, scale=scale, 219 | u=u, v=v, 220 | height=img_size, width=img_size) 221 | if img_list is not None: 222 | img_list_out = [] 223 | for img in img_list: 224 | img_list_out.append(cv.warpAffine(src=img, 225 | M=affineMat[0:2, :], 226 | dsize=(img_size, img_size))) 227 | else: 228 | img_list_out = None 229 | 230 | if label2d_list is not None: 231 | label2d_list_out = [] 232 | for label2d in label2d_list: 233 | label2d_list_out.append(np.matmul(label2d, affineMat[0:2, 0:2].T) + affineMat[0:2, 2:3].T) 234 | else: 235 | label2d_list_out = None 236 | 237 | if label3d_list is not None: 238 | label3d_list_out = [] 239 | R_delta = imgUtils.get_rotation_mat3d(theta) 240 | for label3d in label3d_list: 241 | label3d_list_out.append(np.matmul(label3d, R_delta.T)) 242 | else: 243 | label3d_list_out = None 244 | 245 | if R is not None: 246 | R_delta = imgUtils.get_rotation_mat3d(theta) 247 | R = np.matmul(R_delta, R) 248 | else: 249 | R = None 250 | 251 | return img_list_out, label2d_list_out, label3d_list_out, R 252 | 253 | @ staticmethod 254 | def add_noise(img, noise=0.00, scale=255.0, alpha=0.3, beta=0.05): 255 | # add brightness noise & add random gaussian noise 256 | a = np.random.uniform(1 - alpha, 1 + alpha, 3) 257 | b = scale * beta * (2 * random.random() - 1) 258 | img = a * img + b + scale * np.random.normal(loc=0.0, scale=noise, size=img.shape) 259 | img = np.clip(img, 0, scale) 260 | return img 261 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import torch 4 | 5 | import sys 6 | import os 7 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | from models.manolayer import ManoLayer 10 | from utils.config import get_cfg_defaults 11 | from utils.utils import projection_batch, get_mano_path, get_dense_color_path 12 | 13 | 14 | # Data structures and functions for rendering 15 | from pytorch3d.structures import Meshes 16 | from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene 17 | from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib 18 | from pytorch3d.renderer import ( 19 | look_at_view_transform, 20 | PerspectiveCameras, 21 | OrthographicCameras, 22 | PointLights, 23 | DirectionalLights, 24 | Materials, 25 | RasterizationSettings, 26 | MeshRenderer, 27 | MeshRasterizer, 28 | SoftPhongShader, 29 | HardPhongShader, 30 | TexturesUV, 31 | TexturesVertex, 32 | HardFlatShader, 33 | HardGouraudShader, 34 | AmbientLights, 35 | SoftSilhouetteShader 36 | ) 37 | 38 | 39 | class Renderer(): 40 | def __init__(self, img_size, device='cpu'): 41 | self.img_size = img_size 42 | self.raster_settings = RasterizationSettings( 43 | image_size=img_size, 44 | blur_radius=0.0, 45 | faces_per_pixel=1 46 | ) 47 | 48 | self.amblights = 
AmbientLights(device=device) 49 | self.point_lights = PointLights(location=[[0, 0, -1.0]], device=device) 50 | 51 | self.renderer_rgb = MeshRenderer( 52 | rasterizer=MeshRasterizer(raster_settings=self.raster_settings), 53 | shader=HardPhongShader(device=device) 54 | ) 55 | self.device = device 56 | 57 | def build_camera(self, cameras=None, 58 | scale=None, trans2d=None): 59 | if scale is not None and trans2d is not None: 60 | bs = scale.shape[0] 61 | R = torch.tensor([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]).repeat(bs, 1, 1).to(scale.dtype) 62 | T = torch.tensor([0, 0, 10]).repeat(bs, 1).to(scale.dtype) 63 | return OrthographicCameras(focal_length=2 * scale.to(self.device), 64 | principal_point=-trans2d.to(self.device), 65 | R=R.to(self.device), 66 | T=T.to(self.device), 67 | in_ndc=True, 68 | device=self.device) 69 | if cameras is not None: 70 | # cameras: bs x 3 x 3 71 | fs = -torch.stack((cameras[:, 0, 0], cameras[:, 1, 1]), dim=-1) * 2 / self.img_size 72 | pps = -cameras[:, :2, -1] * 2 / self.img_size + 1 73 | return PerspectiveCameras(focal_length=fs.to(self.device), 74 | principal_point=pps.to(self.device), 75 | in_ndc=True, 76 | device=self.device 77 | ) 78 | 79 | def build_texture(self, uv_verts=None, uv_faces=None, texture=None, 80 | v_color=None): 81 | if uv_verts is not None and uv_faces is not None and texture is not None: 82 | return TexturesUV(texture.to(self.device), uv_faces.to(self.device), uv_verts.to(self.device)) 83 | if v_color is not None: 84 | return TexturesVertex(verts_features=v_color.to(self.device)) 85 | 86 | def render(self, verts, faces, cameras, textures, amblights=False, 87 | lights=None): 88 | if lights is None: 89 | if amblights: 90 | lights = self.amblights 91 | else: 92 | lights = self.point_lights 93 | mesh = Meshes(verts=verts.to(self.device), faces=faces.to(self.device), textures=textures) 94 | output = self.renderer_rgb(mesh, cameras=cameras, lights=lights) 95 | alpha = output[..., 3] 96 | img = output[..., :3] / 255 97 | return img, alpha 98 | 99 | 100 | class mano_renderer(Renderer): 101 | def __init__(self, mano_path=None, dense_path=None, img_size=224, device='cpu'): 102 | super(mano_renderer, self).__init__(img_size, device) 103 | if mano_path is None: 104 | mano_path = get_mano_path() 105 | if dense_path is None: 106 | dense_path = get_dense_color_path() 107 | 108 | self.mano = ManoLayer(mano_path, center_idx=9, use_pca=True) 109 | self.mano.to(self.device) 110 | self.faces_np = self.mano.get_faces().astype(np.int64) 111 | self.faces = torch.from_numpy(self.faces_np).to(self.device).unsqueeze(0) 112 | 113 | with open(dense_path, 'rb') as file: 114 | dense_coor = pickle.load(file) 115 | self.dense_coor = torch.from_numpy(dense_coor) * 255 116 | 117 | def render_rgb(self, cameras=None, scale=None, trans2d=None, 118 | R=None, pose=None, shape=None, trans=None, 119 | v3d=None, 120 | uv_verts=None, uv_faces=None, texture=None, v_color=(255, 255, 255), 121 | amblights=False): 122 | if v3d is None: 123 | v3d, _ = self.mano(R, pose, shape, trans=trans) 124 | bs = v3d.shape[0] 125 | vNum = v3d.shape[1] 126 | 127 | if not isinstance(v_color, torch.Tensor): 128 | v_color = torch.tensor(v_color) 129 | v_color = v_color.expand(bs, vNum, 3).to(v3d) 130 | 131 | return self.render(v3d, self.faces.repeat(bs, 1, 1), 132 | self.build_camera(cameras, scale, trans2d), 133 | self.build_texture(uv_verts, uv_faces, texture, v_color), 134 | amblights) 135 | 136 | def render_densepose(self, cameras=None, scale=None, trans2d=None, 137 | R=None, pose=None, shape=None, 
trans=None, 138 | v3d=None): 139 | if v3d is None: 140 | v3d, _ = self.mano(R, pose, shape, trans=trans) 141 | bs = v3d.shape[0] 142 | vNum = v3d.shape[1] 143 | 144 | return self.render(v3d, self.faces.repeat(bs, 1, 1), 145 | self.build_camera(cameras, scale, trans2d), 146 | self.build_texture(v_color=self.dense_coor.expand(bs, vNum, 3).to(v3d)), 147 | True) 148 | 149 | 150 | class mano_two_hands_renderer(Renderer): 151 | def __init__(self, mano_path=None, dense_path=None, img_size=224, device='cpu'): 152 | super(mano_two_hands_renderer, self).__init__(img_size, device) 153 | if mano_path is None: 154 | mano_path = get_mano_path() 155 | if dense_path is None: 156 | dense_path = get_dense_color_path() 157 | 158 | self.mano = {'right': ManoLayer(mano_path['right'], center_idx=None), 159 | 'left': ManoLayer(mano_path['left'], center_idx=None)} 160 | self.mano['left'].to(self.device) 161 | self.mano['right'].to(self.device) 162 | 163 | left_faces = torch.from_numpy(self.mano['left'].get_faces().astype(np.int64)).to(self.device).unsqueeze(0) 164 | right_faces = torch.from_numpy(self.mano['right'].get_faces().astype(np.int64)).to(self.device).unsqueeze(0) 165 | left_faces = right_faces[..., [1, 0, 2]]  # left-hand faces: right-hand faces with reversed winding, keeping normals outward 166 | 167 | self.faces = torch.cat((left_faces, right_faces + 778), dim=1) 168 | 169 | with open(dense_path, 'rb') as file: 170 | dense_coor = pickle.load(file) 171 | self.dense_coor = torch.from_numpy(dense_coor) * 255 172 | 173 | def render_rgb(self, cameras=None, scale=None, trans2d=None, 174 | v3d_left=None, v3d_right=None, 175 | uv_verts=None, uv_faces=None, texture=None, v_color=None, 176 | amblights=False, 177 | lights=None): 178 | bs = v3d_left.shape[0] 179 | vNum = v3d_left.shape[1] 180 | 181 | if v_color is None: 182 | v_color = torch.zeros((778 * 2, 3)) 183 | v_color[:778, 0] = 204 184 | v_color[:778, 1] = 153 185 | v_color[:778, 2] = 0 186 | v_color[778:, 0] = 102 187 | v_color[778:, 1] = 102 188 | v_color[778:, 2] = 255 189 | 190 | if not isinstance(v_color, torch.Tensor): 191 | v_color = torch.tensor(v_color) 192 | v_color = v_color.expand(bs, 2 * vNum, 3).float().to(self.device) 193 | 194 | v3d = torch.cat((v3d_left, v3d_right), dim=1) 195 | 196 | return self.render(v3d, 197 | self.faces.repeat(bs, 1, 1), 198 | self.build_camera(cameras, scale, trans2d), 199 | self.build_texture(uv_verts, uv_faces, texture, v_color), 200 | amblights, 201 | lights) 202 | 203 | def render_rgb_orth(self, scale_left=None, trans2d_left=None, 204 | scale_right=None, trans2d_right=None, 205 | v3d_left=None, v3d_right=None, 206 | uv_verts=None, uv_faces=None, texture=None, v_color=None, 207 | amblights=False): 208 | scale = scale_left 209 | trans2d = trans2d_left 210 | 211 | s = scale_right / scale_left 212 | d = -(trans2d_left - trans2d_right) / 2 / scale_left.unsqueeze(-1) 213 | 214 | s = s.unsqueeze(-1).unsqueeze(-1) 215 | d = d.unsqueeze(1) 216 | v3d_right = s * v3d_right 217 | v3d_right[..., :2] = v3d_right[..., :2] + d 218 | 219 | # scale = (scale_left + scale_right) / 2 220 | # trans2d = (trans2d_left + trans2d_right) / 2 221 | 222 | return self.render_rgb(scale=scale, trans2d=trans2d, 223 | v3d_left=v3d_left, v3d_right=v3d_right, 224 | uv_verts=uv_verts, uv_faces=uv_faces, texture=texture, v_color=v_color, 225 | amblights=amblights) 226 | 227 | def render_mask(self, cameras=None, scale=None, trans2d=None, 228 | v3d_left=None, v3d_right=None): 229 | v_color = torch.zeros((778 * 2, 3)) 230 | v_color[:778, 2] = 255 231 | v_color[778:, 1] = 255 232 | rgb, mask = self.render_rgb(cameras, scale,
trans2d, 233 | v3d_left, v3d_right, 234 | v_color=v_color, 235 | amblights=True) 236 | return rgb 237 | 238 | def render_densepose(self, cameras=None, scale=None, trans2d=None, 239 | v3d_left=None, v3d_right=None,): 240 | bs = v3d_left.shape[0] 241 | vNum = v3d_left.shape[1] 242 | 243 | v3d = torch.cat((v3d_left, v3d_right), dim=1) 244 | 245 | v_color = torch.cat((self.dense_coor, self.dense_coor), dim=0) 246 | 247 | return self.render(v3d, 248 | self.faces.repeat(bs, 1, 1), 249 | self.build_camera(cameras, scale, trans2d), 250 | self.build_texture(v_color=v_color.expand(bs, 2 * vNum, 3).to(v3d_left)), 251 | True) 252 | --------------------------------------------------------------------------------
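The usage sketches below are illustrative only; they show how the utilities in `utils/lr_sc.py`, `utils/tb_utils.py`, `utils/utils.py` and `utils/vis_utils.py` are meant to be driven, with made-up shapes and hyper-parameters rather than the repository's actual training configuration, and they assume the scripts are run from the repository root so the `utils` package is importable.

The warm-up schedulers are `_LRScheduler`-style objects, so they are advanced by calling `step()` once per training iteration after `optimizer.step()`. In this sketch the model, optimizer, warm-up length, decay factor and step size are placeholders:

```
# Minimal usage sketch (illustrative values): linear warm-up from init_lr to the
# optimizer's base lr over the first warm_up_epoch steps, then step decay.
import torch
from utils.lr_sc import StepLR_withWarmUp

model = torch.nn.Linear(10, 2)                      # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = StepLR_withWarmUp(optimizer,
                              init_lr=1e-5,         # lr at the first step
                              warm_up_epoch=1000,   # warm-up length, in iterations
                              gamma=0.5,            # decay factor after warm-up
                              step_size=10000)      # decay interval, in iterations

for it in range(30000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    scheduler.step()                                # one scheduler step per iteration
```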
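`tbUtils` wraps the OpenCV drawing helpers so intermediate predictions can be written straight to TensorBoard during training. A small sketch with dummy tensors; the log directory and tag names are arbitrary choices, not taken from the repository:

```
# Sketch (dummy data, arbitrary tag names): log a vertex overlay and a MANO
# joint skeleton to TensorBoard.
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.tb_utils import tbUtils

writer = SummaryWriter('output/tb_example')           # hypothetical log dir
img = torch.rand(3, 256, 256)                         # CHW image in [0, 1]
verts2d = torch.rand(778, 2) * 256                    # projected vertices, pixels
joints2d = torch.rand(21, 2) * 256                    # projected MANO joints, pixels

tbUtils.draw_verts(writer, 'train/verts', 0, img, verts2d, color=(0, 0, 255))
tbUtils.draw_MANO_joints(writer, 'train/joints', 0, img, joints2d)
writer.close()
```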
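`projection_batch` (and its numpy twin) implement the orthographic camera used throughout the code: a 3D point maps to pixel coordinates as `p = scale * img_size * X_xy + trans2d * img_size / 2 + img_size / 2`, where `scale` and `trans2d` are the normalized camera parameters. A quick sanity check with assumed shapes:

```
# Sketch: a point at the 3D origin should land at the centre of a 256 x 256 crop.
import torch
from utils.utils import projection_batch

B, N = 2, 21
scale = torch.full((B,), 0.5)            # normalized orthographic scale
trans2d = torch.zeros(B, 2)              # normalized 2D translation (0 = image centre)
joints3d = torch.randn(B, N, 3)

joints2d = projection_batch(scale, trans2d, joints3d, img_size=256)
print(joints2d.shape)                    # torch.Size([2, 21, 2])
print(projection_batch(scale, trans2d, torch.zeros(B, 1, 3))[0, 0])  # tensor([128., 128.])
```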
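`imgUtils.data_augmentation` applies one shared random rotation / scale / translation to the image crop, its 2D labels and its 3D labels so they stay consistent. The sampling ranges below are assumptions for illustration, not the values used for InterHand2.6M training:

```
# Sketch (assumed ranges and dummy data): jointly augment an image, 2D keypoints
# and root-relative 3D joints with the same affine transform.
import numpy as np
from utils.utils import imgUtils

img = np.zeros((256, 256, 3), dtype=np.uint8)        # dummy 256 x 256 crop
label2d = np.random.rand(21, 2) * 256                # 21 x 2 pixel keypoints
label3d = np.random.randn(21, 3)                     # 21 x 3 root-relative joints

theta = np.random.uniform(-90, 90)                   # rotation, degrees
scale = np.random.uniform(0.9, 1.1)
u, v = np.random.uniform(-10, 10, 2)                 # translation, pixels

imgs, label2ds, label3ds, _ = imgUtils.data_augmentation(
    theta, scale, u, v,
    img_list=[img], label2d_list=[label2d], label3d_list=[label3d],
    img_size=256)
aug_img, aug_2d, aug_3d = imgs[0], label2ds[0], label3ds[0]
```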
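`mano_two_hands_renderer.render_rgb` renders a left/right vertex pair under the same orthographic parametrisation (`scale`, `trans2d`) as `projection_batch` and returns an RGB image plus an alpha mask. The sketch below uses random placeholder vertices and camera values just to show the call; real vertices come from the network or `ManoLayer`, and it assumes the `./misc` assets (MANO models, `v_color.pkl`) are already in place and pytorch3d is installed:

```
# Sketch (placeholder vertices and camera values): render a two-hand mesh.
import torch
from utils.vis_utils import mano_two_hands_renderer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
renderer = mano_two_hands_renderer(img_size=256, device=device)

B = 1
v3d_left = torch.randn(B, 778, 3, device=device) * 0.1    # placeholder MANO vertices
v3d_right = torch.randn(B, 778, 3, device=device) * 0.1
scale = torch.full((B,), 5.0, device=device)               # orthographic scale
trans2d = torch.zeros(B, 2, device=device)                 # orthographic translation

img, alpha = renderer.render_rgb(scale=scale, trans2d=trans2d,
                                 v3d_left=v3d_left, v3d_right=v3d_right)
print(img.shape, alpha.shape)    # (B, 256, 256, 3) and (B, 256, 256)
```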