├── .gitignore
├── README.md
├── asset
├── comparison_sota_DexYCB.png
├── comparison_sota_HO3D.png
└── model.png
├── common
├── base.py
├── logger.py
├── nets
│ ├── backbone.py
│ ├── cbam.py
│ ├── hand_head.py
│ ├── mano_head.py
│ ├── regressor.py
│ └── transformer.py
├── timer.py
└── utils
│ ├── __init__.py
│ ├── camera.py
│ ├── dir.py
│ ├── fitting.py
│ ├── mano.py
│ ├── manopth
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── environment.yml
│ ├── examples
│ │ ├── manopth_demo.py
│ │ └── manopth_mindemo.py
│ ├── mano
│ │ ├── __init__.py
│ │ └── webuser
│ │ │ ├── __init__.py
│ │ │ ├── lbs.py
│ │ │ ├── posemapper.py
│ │ │ ├── serialization.py
│ │ │ ├── smpl_handpca_wrapper_HAND_only.py
│ │ │ └── verts.py
│ ├── manopth
│ │ ├── __init__.py
│ │ ├── argutils.py
│ │ ├── demo.py
│ │ ├── manolayer.py
│ │ ├── rodrigues_layer.py
│ │ ├── rot6d.py
│ │ ├── rotproj.py
│ │ └── tensutils.py
│ ├── setup.py
│ └── test
│ │ └── test_demo.py
│ ├── optimizers
│ ├── __init__.py
│ ├── lbfgs_ls.py
│ └── optim_factory.py
│ ├── preprocessing.py
│ ├── transforms.py
│ └── vis.py
├── data
├── DEX_YCB
│ └── DEX_YCB.py
└── HO3D
│ └── HO3D.py
├── demo
├── demo.py
├── demo_fitting.py
├── fitting_input.png
├── hand_bbox.png
├── hand_image.png
├── input.png
└── output.obj
├── main
├── config.py
├── model.py
├── test.py
└── train.py
└── requiremets.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | *.npy
12 | #*.png
13 | *.jpg
14 | *.out
15 | output
16 | data/*/data
17 |
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 |
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 | ### macOS template
116 | # General
117 | .DS_Store
118 | .AppleDouble
119 | .LSOverride
120 |
121 | # Icon must end with two \r
122 | Icon
123 |
124 | # Thumbnails
125 | ._*
126 |
127 | # Files that might appear in the root of a volume
128 | .DocumentRevisions-V100
129 | .fseventsd
130 | .Spotlight-V100
131 | .TemporaryItems
132 | .Trashes
133 | .VolumeIcon.icns
134 | .com.apple.timemachine.donotpresent
135 |
136 | # Directories potentially created on remote AFP share
137 | .AppleDB
138 | .AppleDesktop
139 | Network Trash Folder
140 | Temporary Items
141 | .apdisk
142 | ### JetBrains template
143 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
144 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
145 |
146 | # User-specific stuff
147 | .idea/**/workspace.xml
148 | .idea/**/tasks.xml
149 | .idea/**/dictionaries
150 | .idea/**/shelf
151 |
152 | # Sensitive or high-churn files
153 | .idea/**/dataSources/
154 | .idea/**/dataSources.ids
155 | .idea/**/dataSources.local.xml
156 | .idea/**/sqlDataSources.xml
157 | .idea/**/dynamic.xml
158 | .idea/**/uiDesigner.xml
159 | .idea/**/dbnavigator.xml
160 |
161 | # Gradle
162 | .idea/**/gradle.xml
163 | .idea/**/libraries
164 |
165 | # CMake
166 | cmake-build-debug/
167 | cmake-build-release/
168 |
169 | # Mongo Explorer plugin
170 | .idea/**/mongoSettings.xml
171 |
172 | # File-based project format
173 | *.iws
174 |
175 | # IntelliJ
176 | out/
177 |
178 | # mpeltonen/sbt-idea plugin
179 | .idea_modules/
180 |
181 | # JIRA plugin
182 | atlassian-ide-plugin.xml
183 |
184 | # Cursive Clojure plugin
185 | .idea/replstate.xml
186 |
187 | # Crashlytics plugin (for Android Studio and IntelliJ)
188 | com_crashlytics_export_strings.xml
189 | crashlytics.properties
190 | crashlytics-build.properties
191 | fabric.properties
192 |
193 | # Editor-based Rest Client
194 | .idea/httpRequests
195 |
196 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HandOccNet: Occlusion-Robust 3D Hand Mesh Estimation Network
2 |
3 | ## Introduction
4 | This repository is the offical [Pytorch](https://pytorch.org/) implementation of **[HandOccNet: Occlusion-Robust 3D Hand Mesh Estimation Network (CVPR 2022)](https://arxiv.org/abs/2203.14564)**. Below is the overall pipeline of HandOccNet.
5 | 
6 |
7 | ## Quick demo
8 | * Install **[PyTorch](https://pytorch.org)** and Python >= 3.7.4 and run `sh requirements.sh`.
9 | * Download `snapshot_demo.pth.tar` from [here](https://drive.google.com/drive/folders/1OlyV-qbzOmtQYdzV6dbQX4OtAU5ajBOa?usp=sharing) and place at `demo` folder.
10 | * Prepare `input.jpg` at `demo` folder.
11 | * Download `MANO_RIGHT.pkl` from [here](https://mano.is.tue.mpg.de/) and place at `common/utils/manopth/mano/models`.
12 | * Go to `demo` folder and edit `bbox` in [here](https://github.com/namepllet/HandOccNet/blob/185492e0e5b08c47e37039c5d67e3f2b099a6f9e/demo/demo.py#L61).
13 | * Run `python demo.py --gpu 0` if you want to run on gpu 0.
14 | * You can see `hand_bbox.png`, `hand_image.png`, and `output.obj`.
15 | * Run `python demo_fitting.py --gpu 0 --depth 0.5` if you want to get the hand mesh's translation from the camera. The depth argument is initialization for the optimization.
16 | * You can see `fitting_input_3d_mesh.json` that contains the translation and MANO parameters, `fitting_input_3dmesh.obj`, `fitting_input_2d_prediction.png`, and `fitting_input_projection.png`.
17 |
18 | ## Directory
19 | ### Root
20 | The `${ROOT}` is described as below.
21 | ```
22 | ${ROOT}
23 | |-- data
24 | |-- demo
25 | |-- common
26 | |-- main
27 | |-- output
28 | ```
29 | * `data` contains data loading codes and soft links to images and annotations directories.
30 | * `demo` contains demo codes.
31 | * `common` contains kernel codes for HandOccNet.
32 | * `main` contains high-level codes for training or testing the network.
33 | * `output` contains log, trained models, visualized outputs, and test result.
34 |
35 | ### Data
36 | You need to follow directory structure of the `data` as below.
37 | ```
38 | ${ROOT}
39 | |-- data
40 | | |-- HO3D
41 | | | |-- data
42 | | | | |-- train
43 | | | | | |-- ABF10
44 | | | | | |-- ......
45 | | | | |-- evaluation
46 | | | | |-- annotations
47 | | | | | |-- HO3D_train_data.json
48 | | | | | |-- HO3D_evaluation_data.json
49 | | |-- DEX_YCB
50 | | | |-- data
51 | | | | |-- 20200709-subject-01
52 | | | | |-- ......
53 | | | | |-- annotations
54 | | | | | |--DEX_YCB_s0_train_data.json
55 | | | | | |--DEX_YCB_s0_test_data.json
56 | ```
57 | * Download HO3D(version 2) data and annotation files [[data](https://www.tugraz.at/institute/icg/research/team-lepetit/research-projects/hand-object-3d-pose-annotation/)][[annotation files](https://drive.google.com/drive/folders/1pmRpgv38PXvlLOODtoxpTYnIpYTkNV6b?usp=sharing)]
58 | * Download DexYCB data and annotation files [[data](https://dex-ycb.github.io/)][[annotation files](https://drive.google.com/drive/folders/1pmRpgv38PXvlLOODtoxpTYnIpYTkNV6b?usp=sharing)]
59 |
60 | ### Pytorch MANO layer
61 | * For the MANO layer, I used [manopth](https://github.com/hassony2/manopth). The repo is already included in `common/utils/manopth`.
62 | * Download `MANO_RIGHT.pkl` from [here](https://mano.is.tue.mpg.de/) and place at `common/utils/manopth/mano/models`.
63 |
64 | ### Output
65 | You need to follow the directory structure of the `output` folder as below.
66 | ```
67 | ${ROOT}
68 | |-- output
69 | | |-- log
70 | | |-- model_dump
71 | | |-- result
72 | | |-- vis
73 | ```
74 | * Creating `output` folder as soft link form is recommended instead of folder form because it would take large storage capacity.
75 | * `log` folder contains training log file.
76 | * `model_dump` folder contains saved checkpoints for each epoch.
77 | * `result` folder contains final estimation files generated in the testing stage.
78 | * `vis` folder contains visualized results.
79 |
80 | ## Running HandOccNet
81 | ### Start
82 | * Install **[PyTorch](https://pytorch.org)** and Python >= 3.7.4 and run `sh requirements.sh`.
83 | * In the `main/config.py`, you can change settings of the model including dataset to use and input size and so on.
84 |
85 | ### Train
86 | In the `main` folder, set trainset in `config.py` (as 'HO3D' or 'DEX_YCB') and run
87 | ```bash
88 | python train.py --gpu 0-3
89 | ```
90 | to train HandOccNet on the GPU 0,1,2,3. `--gpu 0,1,2,3` can be used instead of `--gpu 0-3`.
91 |
92 | ### Test
93 | Place trained model at the `output/model_dump/`.
94 |
95 | In the `main` folder, set testset in `config.py` (as 'HO3D' or 'DEX_YCB') and run
96 | ```bash
97 | python test.py --gpu 0-3 --test_epoch {test epoch}
98 | ```
99 | to test HandOccNet on the GPU 0,1,2,3 with {test epoch}th epoch trained model. `--gpu 0,1,2,3` can be used instead of `--gpu 0-3`.
100 |
101 | * For the HO3D dataset, pred{test epoch}.zip will be generated in `output/result` folder. You can upload it to the [codalab challenge](https://competitions.codalab.org/competitions/22485) and see the results.
102 | * Our trained model can be downloaded from [here](https://drive.google.com/drive/folders/1OlyV-qbzOmtQYdzV6dbQX4OtAU5ajBOa?usp=sharing)
103 |
104 | ## Results
105 | Here I report the performance of the HandOccNet.
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 | ## Reference
115 | ```
116 | @InProceedings{Park_2022_CVPR_HandOccNet,
117 | author = {Park, JoonKyu and Oh, Yeonguk and Moon, Gyeongsik and Choi, Hongsuk and Lee, Kyoung Mu},
118 | title = {HandOccNet: Occlusion-Robust 3D Hand Mesh Estimation Network},
119 | booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)},
120 | year = {2022}
121 | }
122 | ```
123 | ## Acknowledgements
124 | For this project, we relied on research codes from:
125 | * [I2L-MeshNet_RELEASE](https://github.com/mks0601/I2L-MeshNet_RELEASE)
126 | * [Semi-Hand-Object](https://github.com/stevenlsw/Semi-Hand-Object)
127 | * [attention-module](https://github.com/Jongchan/attention-module)
128 |
--------------------------------------------------------------------------------
/asset/comparison_sota_DexYCB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/asset/comparison_sota_DexYCB.png
--------------------------------------------------------------------------------
/asset/comparison_sota_HO3D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/asset/comparison_sota_HO3D.png
--------------------------------------------------------------------------------
/asset/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/asset/model.png
--------------------------------------------------------------------------------
/common/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import math
4 | import time
5 | import glob
6 | import abc
7 | from torch.utils.data import DataLoader
8 | import torch.optim
9 | import torchvision.transforms as transforms
10 | from timer import Timer
11 | from logger import colorlogger
12 | from torch.nn.parallel.data_parallel import DataParallel
13 | from config import cfg
14 | from model import get_model
15 |
16 | # dynamic dataset import
17 | exec('from ' + cfg.trainset + ' import ' + cfg.trainset)
18 | exec('from ' + cfg.testset + ' import ' + cfg.testset)
19 |
20 | class Base(object):
21 | __metaclass__ = abc.ABCMeta
22 |
23 | def __init__(self, log_name='logs.txt'):
24 |
25 | self.cur_epoch = 0
26 |
27 | # timer
28 | self.tot_timer = Timer()
29 | self.gpu_timer = Timer()
30 | self.read_timer = Timer()
31 |
32 | # logger
33 | self.logger = colorlogger(cfg.log_dir, log_name=log_name)
34 |
35 | @abc.abstractmethod
36 | def _make_batch_generator(self):
37 | return
38 |
39 | @abc.abstractmethod
40 | def _make_model(self):
41 | return
42 |
43 | class Trainer(Base):
44 | def __init__(self):
45 | super(Trainer, self).__init__(log_name = 'train_logs.txt')
46 |
47 | def get_optimizer(self, model):
48 | model_params = filter(lambda p: p.requires_grad, model.parameters())
49 | optimizer = torch.optim.Adam(model_params, lr=cfg.lr)
50 | return optimizer
51 |
52 | def save_model(self, state, epoch):
53 | file_path = osp.join(cfg.model_dir,'snapshot_{}.pth.tar'.format(str(epoch)))
54 | torch.save(state, file_path)
55 | self.logger.info("Write snapshot into {}".format(file_path))
56 |
57 | def load_model(self, model, optimizer):
58 | model_file_list = glob.glob(osp.join(cfg.model_dir,'*.pth.tar'))
59 | cur_epoch = max([int(file_name[file_name.find('snapshot_') + 9 : file_name.find('.pth.tar')]) for file_name in model_file_list])
60 | ckpt_path = osp.join(cfg.model_dir, 'snapshot_' + str(cur_epoch) + '.pth.tar')
61 | ckpt = torch.load(ckpt_path)
62 | start_epoch = ckpt['epoch'] + 1
63 | model.load_state_dict(ckpt['network'], strict=False)
64 | #optimizer.load_state_dict(ckpt['optimizer'])
65 |
66 | self.logger.info('Load checkpoint from {}'.format(ckpt_path))
67 | return start_epoch, model, optimizer
68 |
69 | def set_lr(self, epoch):
70 | for e in cfg.lr_dec_epoch:
71 | if epoch < e:
72 | break
73 | if epoch < cfg.lr_dec_epoch[-1]:
74 | idx = cfg.lr_dec_epoch.index(e)
75 | for g in self.optimizer.param_groups:
76 | g['lr'] = cfg.lr * (cfg.lr_dec_factor ** idx)
77 | else:
78 | for g in self.optimizer.param_groups:
79 | g['lr'] = cfg.lr * (cfg.lr_dec_factor ** len(cfg.lr_dec_epoch))
80 |
81 | def get_lr(self):
82 | for g in self.optimizer.param_groups:
83 | cur_lr = g['lr']
84 | return cur_lr
85 |
86 | def _make_batch_generator(self):
87 | # data load and construct batch generator
88 | self.logger.info("Creating dataset...")
89 | train_dataset = eval(cfg.trainset)(transforms.ToTensor(), "train")
90 |
91 | self.itr_per_epoch = math.ceil(len(train_dataset) / cfg.num_gpus / cfg.train_batch_size)
92 | self.batch_generator = DataLoader(dataset=train_dataset, batch_size=cfg.num_gpus*cfg.train_batch_size, shuffle=True, num_workers=cfg.num_thread, pin_memory=True)
93 |
94 | def _make_model(self):
95 | # prepare network
96 | self.logger.info("Creating graph and optimizer...")
97 | model = get_model('train')
98 |
99 | model = DataParallel(model).cuda()
100 | optimizer = self.get_optimizer(model)
101 | if cfg.continue_train:
102 | start_epoch, model, optimizer = self.load_model(model, optimizer)
103 | else:
104 | start_epoch = 0
105 | model.train()
106 |
107 | self.start_epoch = start_epoch
108 | self.model = model
109 | self.optimizer = optimizer
110 |
111 | class Tester(Base):
112 | def __init__(self, test_epoch):
113 | self.test_epoch = int(test_epoch)
114 | super(Tester, self).__init__(log_name = 'test_logs.txt')
115 |
116 | def _make_batch_generator(self):
117 | # data load and construct batch generator
118 | self.logger.info("Creating dataset...")
119 | self.test_dataset = eval(cfg.testset)(transforms.ToTensor(), "test")
120 | self.batch_generator = DataLoader(dataset=self.test_dataset, batch_size=cfg.num_gpus*cfg.test_batch_size, shuffle=False, num_workers=cfg.num_thread, pin_memory=True)
121 |
122 | def _make_model(self):
123 | model_path = os.path.join(cfg.model_dir, 'snapshot_%d.pth.tar' % self.test_epoch)
124 | assert os.path.exists(model_path), 'Cannot find model at ' + model_path
125 | self.logger.info('Load checkpoint from {}'.format(model_path))
126 |
127 | # prepare network
128 | self.logger.info("Creating graph...")
129 | model = get_model('test')
130 | model = DataParallel(model).cuda()
131 | ckpt = torch.load(model_path)
132 | model.load_state_dict(ckpt['network'], strict=False)
133 | model.eval()
134 |
135 | self.model = model
136 |
137 | def _evaluate(self, outs, cur_sample_idx):
138 | eval_result = self.test_dataset.evaluate(outs, cur_sample_idx)
139 | return eval_result
140 |
141 | def _print_eval_result(self, test_epoch):
142 | self.test_dataset.print_eval_result(test_epoch)
--------------------------------------------------------------------------------
/common/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | OK = '\033[92m'
5 | WARNING = '\033[93m'
6 | FAIL = '\033[91m'
7 | END = '\033[0m'
8 |
9 | PINK = '\033[95m'
10 | BLUE = '\033[94m'
11 | GREEN = OK
12 | RED = FAIL
13 | WHITE = END
14 | YELLOW = WARNING
15 |
16 | class colorlogger():
17 | def __init__(self, log_dir, log_name='train_logs.txt'):
18 | # set log
19 | self._logger = logging.getLogger(log_name)
20 | self._logger.setLevel(logging.INFO)
21 | log_file = os.path.join(log_dir, log_name)
22 | if not os.path.exists(log_dir):
23 | os.makedirs(log_dir)
24 | file_log = logging.FileHandler(log_file, mode='a')
25 | file_log.setLevel(logging.INFO)
26 | console_log = logging.StreamHandler()
27 | console_log.setLevel(logging.INFO)
28 | formatter = logging.Formatter(
29 | "{}%(asctime)s{} %(message)s".format(GREEN, END),
30 | "%m-%d %H:%M:%S")
31 | file_log.setFormatter(formatter)
32 | console_log.setFormatter(formatter)
33 | self._logger.addHandler(file_log)
34 | self._logger.addHandler(console_log)
35 |
36 | def debug(self, msg):
37 | self._logger.debug(str(msg))
38 |
39 | def info(self, msg):
40 | self._logger.info(str(msg))
41 |
42 | def warning(self, msg):
43 | self._logger.warning(WARNING + 'WRN: ' + str(msg) + END)
44 |
45 | def critical(self, msg):
46 | self._logger.critical(RED + 'CRI: ' + str(msg) + END)
47 |
48 | def error(self, msg):
49 | self._logger.error(RED + 'ERR: ' + str(msg) + END)
--------------------------------------------------------------------------------
/common/nets/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | from torchvision import ops
6 | import torch
7 |
8 | from nets.cbam import SpatialGate
9 |
10 | class FPN(nn.Module):
11 | def __init__(self, pretrained=True):
12 | super(FPN, self).__init__()
13 | self.in_planes = 64
14 |
15 | resnet = resnet50(pretrained=pretrained)
16 |
17 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels
18 |
19 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.leakyrelu, resnet.maxpool)
20 | self.layer1 = nn.Sequential(resnet.layer1)
21 | self.layer2 = nn.Sequential(resnet.layer2)
22 | self.layer3 = nn.Sequential(resnet.layer3)
23 | self.layer4 = nn.Sequential(resnet.layer4)
24 |
25 | # Smooth layers
26 | #self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
27 | self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
28 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
29 |
30 | # Lateral layers
31 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
32 | self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)
33 | self.latlayer3 = nn.Conv2d( 256, 256, kernel_size=1, stride=1, padding=0)
34 |
35 | # Attention Module
36 | self.attention_module = SpatialGate()
37 |
38 | self.pool = nn.AvgPool2d(2, stride=2)
39 |
40 | def _upsample_add(self, x, y):
41 | _, _, H, W = y.size()
42 | return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y
43 |
44 | def forward(self, x):
45 | # Bottom-up
46 | c1 = self.layer0(x)
47 | c2 = self.layer1(c1)
48 | c3 = self.layer2(c2)
49 | c4 = self.layer3(c3)
50 | c5 = self.layer4(c4)
51 | # Top-down
52 | p5 = self.toplayer(c5)
53 | p4 = self._upsample_add(p5, self.latlayer1(c4))
54 | p3 = self._upsample_add(p4, self.latlayer2(c3))
55 | p2 = self._upsample_add(p3, self.latlayer3(c2))
56 | # Smooth
57 | #p4 = self.smooth1(p4)
58 | p3 = self.smooth2(p3)
59 | p2 = self.smooth3(p2)
60 |
61 | # Attention
62 | p2 = self.pool(p2)
63 | primary_feats, secondary_feats = self.attention_module(p2)
64 |
65 | return primary_feats, secondary_feats
66 |
67 |
68 | class ResNet(nn.Module):
69 | def __init__(self, block, layers, num_classes=1000):
70 | self.inplanes = 64
71 | super(ResNet, self).__init__()
72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
73 | self.bn1 = nn.BatchNorm2d(64)
74 | self.leakyrelu = nn.LeakyReLU(inplace=True)
75 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
76 | self.layer1 = self._make_layer(block, 64, layers[0])
77 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
78 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
79 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
80 | self.avgpool = nn.AvgPool2d(7, stride=1)
81 | self.fc = nn.Linear(512 * block.expansion, num_classes)
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
86 | elif isinstance(m, nn.BatchNorm2d):
87 | nn.init.constant_(m.weight, 1)
88 | nn.init.constant_(m.bias, 0)
89 |
90 | def _make_layer(self, block, planes, blocks, stride=1):
91 | downsample = None
92 | if stride != 1 or self.inplanes != planes * block.expansion:
93 | downsample = nn.Sequential(
94 | nn.Conv2d(self.inplanes, planes * block.expansion,
95 | kernel_size=1, stride=stride, bias=False),
96 | nn.BatchNorm2d(planes * block.expansion))
97 | layers = []
98 | layers.append(block(self.inplanes, planes, stride, downsample))
99 | self.inplanes = planes * block.expansion
100 | for i in range(1, blocks):
101 | layers.append(block(self.inplanes, planes))
102 |
103 | return nn.Sequential(*layers)
104 |
105 | def forward(self, x):
106 | x = self.conv1(x)
107 | x = self.bn1(x)
108 | x = self.leakyrelu(x)
109 | x = self.maxpool(x)
110 |
111 | x = self.layer1(x)
112 | x = self.layer2(x)
113 | x = self.layer3(x)
114 | x = self.layer4(x)
115 |
116 | x = x.mean(3).mean(2)
117 | x = x.view(x.size(0), -1)
118 | x = self.fc(x)
119 | return x
120 |
121 |
122 | def resnet50(pretrained=False, **kwargs):
123 | """Constructs a ResNet-50 model Encoder"""
124 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
125 | if pretrained:
126 | model.load_state_dict(model_zoo.load_url("https://download.pytorch.org/models/resnet50-19c8e357.pth"))
127 | return model
128 |
129 |
130 | def conv3x3(in_planes, out_planes, stride=1):
131 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
132 |
133 |
134 | class BasicBlock(nn.Module):
135 | expansion = 1
136 |
137 | def __init__(self, inplanes, planes, stride=1, downsample=None):
138 | super(BasicBlock, self).__init__()
139 | self.conv1 = conv3x3(inplanes, planes, stride)
140 | self.bn1 = nn.BatchNorm2d(planes)
141 | self.leakyrelu = nn.LeakyReLU(inplace=True)
142 | self.conv2 = conv3x3(planes, planes)
143 | self.bn2 = nn.BatchNorm2d(planes)
144 | self.downsample = downsample
145 | self.stride = stride
146 |
147 | def forward(self, x):
148 | residual = x
149 |
150 | out = self.conv1(x)
151 | out = self.bn1(out)
152 | out = self.leakyrelu(out)
153 |
154 | out = self.conv2(out)
155 | out = self.bn2(out)
156 |
157 | if self.downsample is not None:
158 | residual = self.downsample(x)
159 |
160 | out += residual
161 | out = self.leakyrelu(out)
162 |
163 | return out
164 |
165 |
166 | class Bottleneck(nn.Module):
167 | expansion = 4
168 |
169 | def __init__(self, inplanes, planes, stride=1, downsample=None):
170 | super(Bottleneck, self).__init__()
171 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
172 | self.bn1 = nn.BatchNorm2d(planes)
173 | self.conv2 = nn.Conv2d(
174 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
175 | )
176 | self.bn2 = nn.BatchNorm2d(planes)
177 | self.conv3 = nn.Conv2d(
178 | planes, planes * self.expansion, kernel_size=1, bias=False
179 | )
180 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
181 | self.leakyrelu = nn.LeakyReLU(inplace=True)
182 | self.downsample = downsample
183 | self.stride = stride
184 |
185 | def forward(self, x):
186 | residual = x
187 |
188 | out = self.conv1(x)
189 | out = self.bn1(out)
190 | out = self.leakyrelu(out)
191 |
192 | out = self.conv2(out)
193 | out = self.bn2(out)
194 | out = self.leakyrelu(out)
195 |
196 | out = self.conv3(out)
197 | out = self.bn3(out)
198 |
199 | if self.downsample is not None:
200 | residual = self.downsample(x)
201 |
202 | out += residual
203 | out = self.leakyrelu(out)
204 |
205 | return out
--------------------------------------------------------------------------------
/common/nets/cbam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 | nn.ReLU(),
34 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60 | return x * scale
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 | scale = F.sigmoid(x_out) # broadcasting
82 | return x*scale, x*(1-scale)
83 |
84 | class CBAM(nn.Module):
85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 | super(CBAM, self).__init__()
87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 | self.no_spatial=no_spatial
89 | if not no_spatial:
90 | self.SpatialGate = SpatialGate()
91 | def forward(self, x):
92 | x_out = self.ChannelGate(x)
93 | if not self.no_spatial:
94 | x_out = self.SpatialGate(x_out)
95 | return x_out
--------------------------------------------------------------------------------
/common/nets/hand_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class hand_regHead(nn.Module):
6 | def __init__(self, roi_res=32, joint_nb=21, stacks=1, channels=256, blocks=1):
7 | """
8 | Args:
9 | inr_res: input image size
10 | joint_nb: hand joint num
11 | """
12 | super(hand_regHead, self).__init__()
13 |
14 | # hand head
15 | self.out_res = roi_res
16 | self.joint_nb = joint_nb
17 |
18 | self.channels = channels
19 | self.blocks = blocks
20 | self.stacks = stacks
21 |
22 | self.betas = nn.Parameter(torch.ones((self.joint_nb, 1), dtype=torch.float32))
23 |
24 | center_offset = 0.5
25 | vv, uu = torch.meshgrid(torch.arange(self.out_res).float(), torch.arange(self.out_res).float())
26 | uu, vv = uu + center_offset, vv + center_offset
27 | self.register_buffer("uu", uu / self.out_res)
28 | self.register_buffer("vv", vv / self.out_res)
29 |
30 | self.softmax = nn.Softmax(dim=2)
31 | block = Bottleneck
32 | self.features = self.channels // block.expansion
33 |
34 | hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
35 | for i in range(self.stacks):
36 | hg.append(Hourglass(block, self.blocks, self.features, 4))
37 | res.append(self.make_residual(block, self.channels, self.features, self.blocks))
38 | fc.append(BasicBlock(self.channels, self.channels, kernel_size=1))
39 | score.append(nn.Conv2d(self.channels, self.joint_nb, kernel_size=1, bias=True))
40 | if i < self.stacks - 1:
41 | fc_.append(nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True))
42 | score_.append(nn.Conv2d(self.joint_nb, self.channels, kernel_size=1, bias=True))
43 |
44 | self.hg = nn.ModuleList(hg)
45 | self.res = nn.ModuleList(res)
46 | self.fc = nn.ModuleList(fc)
47 | self.score = nn.ModuleList(score)
48 | self.fc_ = nn.ModuleList(fc_)
49 | self.score_ = nn.ModuleList(score_)
50 |
51 | def make_residual(self, block, inplanes, planes, blocks, stride=1):
52 | skip = None
53 | if stride != 1 or inplanes != planes * block.expansion:
54 | skip = nn.Sequential(
55 | nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=True))
56 | layers = []
57 | layers.append(block(inplanes, planes, stride, skip))
58 | for i in range(1, blocks):
59 | layers.append(block(inplanes, planes))
60 | return nn.Sequential(*layers)
61 |
62 | def spatial_softmax(self, latents):
63 | latents = latents.view((-1, self.joint_nb, self.out_res ** 2))
64 | latents = latents * self.betas
65 | heatmaps = self.softmax(latents)
66 | heatmaps = heatmaps.view(-1, self.joint_nb, self.out_res, self.out_res)
67 | return heatmaps
68 |
69 | def generate_output(self, heatmaps):
70 | predictions = torch.stack((
71 | torch.sum(torch.sum(heatmaps * self.uu, dim=2), dim=2),
72 | torch.sum(torch.sum(heatmaps * self.vv, dim=2), dim=2)), dim=2)
73 | return predictions
74 |
75 | def forward(self, x):
76 | out, encoding, preds = [], [], []
77 | for i in range(self.stacks):
78 | y = self.hg[i](x)
79 | y = self.res[i](y)
80 | y = self.fc[i](y)
81 | latents = self.score[i](y)
82 | heatmaps= self.spatial_softmax(latents)
83 | out.append(heatmaps)
84 | predictions = self.generate_output(heatmaps)
85 | preds.append(predictions)
86 | if i < self.stacks - 1:
87 | fc_ = self.fc_[i](y)
88 | score_ = self.score_[i](heatmaps)
89 | x = x + fc_ + score_
90 | encoding.append(x)
91 | else:
92 | encoding.append(y)
93 | return out, encoding, preds
94 |
95 |
96 | class BasicBlock(nn.Module):
97 | def __init__(self, in_planes, out_planes, kernel_size,groups=1):
98 | super(BasicBlock, self).__init__()
99 | self.block = nn.Sequential(
100 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
101 | stride=1, padding=((kernel_size - 1) // 2),
102 | groups=groups,bias=True),
103 | nn.BatchNorm2d(out_planes),
104 | nn.LeakyReLU(inplace=True)
105 | )
106 |
107 | def forward(self, x):
108 | return self.block(x)
109 |
110 |
111 | class Residual(nn.Module):
112 | def __init__(self, numIn, numOut):
113 | super(Residual, self).__init__()
114 | self.numIn = numIn
115 | self.numOut = numOut
116 | self.bn = nn.BatchNorm2d(self.numIn)
117 | self.leakyrelu = nn.LeakyReLU(inplace=True)
118 | self.conv1 = nn.Conv2d(self.numIn, self.numOut // 2, bias=True, kernel_size=1)
119 | self.bn1 = nn.BatchNorm2d(self.numOut // 2)
120 | self.conv2 = nn.Conv2d(self.numOut // 2, self.numOut // 2, bias=True, kernel_size=3, stride=1, padding=1)
121 | self.bn2 = nn.BatchNorm2d(self.numOut // 2)
122 | self.conv3 = nn.Conv2d(self.numOut // 2, self.numOut, bias=True, kernel_size=1)
123 |
124 | if self.numIn != self.numOut:
125 | self.conv4 = nn.Conv2d(self.numIn, self.numOut, bias=True, kernel_size=1)
126 |
127 | def forward(self, x):
128 | residual = x
129 | out = self.bn(x)
130 | out = self.leakyrelu(out)
131 | out = self.conv1(out)
132 | out = self.bn1(out)
133 | out = self.leakyrelu(out)
134 | out = self.conv2(out)
135 | out = self.bn2(out)
136 | out = self.leakyrelu(out)
137 | out = self.conv3(out)
138 |
139 | if self.numIn != self.numOut:
140 | residual = self.conv4(x)
141 |
142 | return out + residual
143 |
144 |
145 | class Bottleneck(nn.Module):
146 | expansion = 2
147 |
148 | def __init__(self, inplanes, planes, stride=1, skip=None, groups=1):
149 | super(Bottleneck, self).__init__()
150 |
151 | self.bn1 = nn.BatchNorm2d(inplanes)
152 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True, groups=groups)
153 | self.bn2 = nn.BatchNorm2d(planes)
154 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
155 | padding=1, bias=True, groups=groups)
156 | self.bn3 = nn.BatchNorm2d(planes)
157 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=True, groups=groups)
158 | self.leakyrelu = nn.LeakyReLU(inplace=True) # negative_slope=0.01
159 | self.skip = skip
160 | self.stride = stride
161 |
162 | def forward(self, x):
163 | residual = x
164 |
165 | out = self.bn1(x)
166 | out = self.leakyrelu(out)
167 | out = self.conv1(out)
168 |
169 | out = self.bn2(out)
170 | out = self.leakyrelu(out)
171 | out = self.conv2(out)
172 |
173 | out = self.bn3(out)
174 | out = self.leakyrelu(out)
175 | out = self.conv3(out)
176 |
177 | if self.skip is not None:
178 | residual = self.skip(x)
179 |
180 | out += residual
181 |
182 | return out
183 |
184 |
185 | class Hourglass(nn.Module):
186 | def __init__(self, block, num_blocks, planes, depth):
187 |
188 | super(Hourglass, self).__init__()
189 | self.depth = depth
190 | self.block = block
191 | self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
192 |
193 | def _make_residual(self, block, num_blocks, planes):
194 |
195 | layers = []
196 | for i in range(0, num_blocks):
197 | # channel changes: planes*block.expansion->planes->2*planes
198 | layers.append(block(planes * block.expansion, planes))
199 | return nn.Sequential(*layers)
200 |
201 | def _make_hour_glass(self, block, num_blocks, planes, depth):
202 | hg = []
203 | for i in range(depth):
204 | res = []
205 | for j in range(3):
206 | # 3 residual modules composed of a residual unit
207 | # <2*planes><2*planes>
208 | res.append(self._make_residual(block, num_blocks, planes))
209 | if i == 0:
210 | # i=0 in a recursive construction build the basic network path
211 | # see: low2 = self.hg[n-1][3](low1)
212 | # <2*planes><2*planes>
213 | res.append(self._make_residual(block, num_blocks, planes))
214 | hg.append(nn.ModuleList(res))
215 | return nn.ModuleList(hg)
216 |
217 | def _hour_glass_forward(self, n, x):
218 | up1 = self.hg[n - 1][0](x) # skip branches
219 | low1 = F.max_pool2d(x, 2, stride=2)
220 | low1 = self.hg[n - 1][1](low1)
221 |
222 | if n > 1:
223 | low2 = self._hour_glass_forward(n - 1, low1)
224 | else:
225 | low2 = self.hg[n - 1][3](low1) # only for depth=1 basic path of the hourglass network
226 | low3 = self.hg[n - 1][2](low2)
227 | up2 = F.interpolate(low3, scale_factor=2) # scale_factor=2 should be consistent with F.max_pool2d(2,stride=2)
228 | out = up1 + up2
229 | return out
230 |
231 | def forward(self, x):
232 | # depth: order of the hourglass network
233 | # do network forward recursively
234 | return self._hour_glass_forward(self.depth, x)
235 |
236 |
237 | class hand_Encoder(nn.Module):
238 | def __init__(self, num_heatmap_chan=21, num_feat_chan=256, size_input_feature=(32, 32),
239 | nRegBlock=4, nRegModules=2):
240 | super(hand_Encoder, self).__init__()
241 |
242 | self.num_heatmap_chan = num_heatmap_chan
243 | self.num_feat_chan = num_feat_chan
244 | self.size_input_feature = size_input_feature
245 |
246 | self.nRegBlock = nRegBlock
247 | self.nRegModules = nRegModules
248 |
249 | self.heatmap_conv = nn.Conv2d(self.num_heatmap_chan, self.num_feat_chan,
250 | bias=True, kernel_size=1, stride=1)
251 | self.encoding_conv = nn.Conv2d(self.num_feat_chan, self.num_feat_chan,
252 | bias=True, kernel_size=1, stride=1)
253 |
254 | reg = []
255 | for i in range(self.nRegBlock):
256 | for j in range(self.nRegModules):
257 | reg.append(Residual(self.num_feat_chan, self.num_feat_chan))
258 |
259 | self.reg = nn.ModuleList(reg)
260 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
261 | self.downsample_scale = 2 ** self.nRegBlock
262 |
263 | # fc layers
264 | self.num_feat_out = self.num_feat_chan * (size_input_feature[0] * size_input_feature[1] // (self.downsample_scale ** 2))
265 |
266 | def forward(self, hm_list, encoding_list):
267 | x = self.heatmap_conv(hm_list[-1]) + self.encoding_conv(encoding_list[-1])
268 | if len(encoding_list) > 1:
269 | x = x + encoding_list[-2]
270 |
271 | # x: B x num_feat_chan x 32 x 32
272 | for i in range(self.nRegBlock):
273 | for j in range(self.nRegModules):
274 | x = self.reg[i * self.nRegModules + j](x)
275 | x = self.maxpool(x)
276 |
277 | # x: B x num_feat_chan x 2 x 2
278 | out = x.view(x.size(0), -1)
279 |
280 | return out
281 |
--------------------------------------------------------------------------------
/common/nets/mano_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from utils.mano import MANO
5 | mano = MANO()
6 |
7 | def batch_rodrigues(theta):
8 | # theta N x 3
9 | l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
10 | angle = torch.unsqueeze(l1norm, -1)
11 | normalized = torch.div(theta, angle)
12 | angle = angle * 0.5
13 | v_cos = torch.cos(angle)
14 | v_sin = torch.sin(angle)
15 | quat = torch.cat([v_cos, v_sin * normalized], dim=1)
16 |
17 | return quat2mat(quat)
18 |
19 |
20 | def quat2mat(quat):
21 | """Convert quaternion coefficients to rotation matrix.
22 | """
23 | norm_quat = quat
24 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
25 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
26 |
27 | B = quat.size(0)
28 |
29 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
30 | wx, wy, wz = w * x, w * y, w * z
31 | xy, xz, yz = x * y, x * z, y * z
32 |
33 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
34 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
35 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
36 | return rotMat
37 |
38 |
39 | def quat2aa(quaternion):
40 | """Convert quaternion vector to angle axis of rotation."""
41 | if not torch.is_tensor(quaternion):
42 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
43 | type(quaternion)))
44 |
45 | if not quaternion.shape[-1] == 4:
46 | raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
47 | .format(quaternion.shape))
48 | # unpack input and compute conversion
49 | q1 = quaternion[..., 1]
50 | q2 = quaternion[..., 2]
51 | q3 = quaternion[..., 3]
52 | sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3
53 |
54 | sin_theta = torch.sqrt(sin_squared_theta)
55 | cos_theta = quaternion[..., 0]
56 | two_theta = 2.0 * torch.where(
57 | cos_theta < 0.0,
58 | torch.atan2(-sin_theta, -cos_theta),
59 | torch.atan2(sin_theta, cos_theta))
60 |
61 | k_pos = two_theta / sin_theta
62 | k_neg = 2.0 * torch.ones_like(sin_theta)
63 | k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
64 |
65 | angle_axis = torch.zeros_like(quaternion)[..., :3]
66 | angle_axis[..., 0] += q1 * k
67 | angle_axis[..., 1] += q2 * k
68 | angle_axis[..., 2] += q3 * k
69 | return angle_axis
70 |
71 |
72 | def mat2quat(rotation_matrix, eps=1e-6):
73 | """Convert 3x4 rotation matrix to 4d quaternion vector"""
74 | if not torch.is_tensor(rotation_matrix):
75 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
76 | type(rotation_matrix)))
77 |
78 | if len(rotation_matrix.shape) > 3:
79 | raise ValueError(
80 | "Input size must be a three dimensional tensor. Got {}".format(
81 | rotation_matrix.shape))
82 | if not rotation_matrix.shape[-2:] == (3, 4):
83 | raise ValueError(
84 | "Input size must be a N x 3 x 4 tensor. Got {}".format(
85 | rotation_matrix.shape))
86 |
87 | rmat_t = torch.transpose(rotation_matrix, 1, 2)
88 |
89 | mask_d2 = rmat_t[:, 2, 2] < eps
90 |
91 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
92 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
93 |
94 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
95 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
96 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
97 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
98 | t0_rep = t0.repeat(4, 1).t()
99 |
100 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
101 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
102 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
103 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
104 | t1_rep = t1.repeat(4, 1).t()
105 |
106 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
107 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
108 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
109 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
110 | t2_rep = t2.repeat(4, 1).t()
111 |
112 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
113 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
114 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
115 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
116 | t3_rep = t3.repeat(4, 1).t()
117 |
118 | mask_c0 = mask_d2 * mask_d0_d1
119 | mask_c1 = mask_d2 * ~mask_d0_d1
120 | mask_c2 = ~mask_d2 * mask_d0_nd1
121 | mask_c3 = ~mask_d2 * ~mask_d0_nd1
122 | mask_c0 = mask_c0.view(-1, 1).type_as(q0)
123 | mask_c1 = mask_c1.view(-1, 1).type_as(q1)
124 | mask_c2 = mask_c2.view(-1, 1).type_as(q2)
125 | mask_c3 = mask_c3.view(-1, 1).type_as(q3)
126 |
127 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
128 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
129 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
130 | q *= 0.5
131 | return q
132 |
133 |
134 | def rot6d2mat(x):
135 | """Convert 6D rotation representation to 3x3 rotation matrix.
136 | Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
137 | """
138 | a1 = x[:, 0:3]
139 | a2 = x[:, 3:6]
140 | b1 = F.normalize(a1)
141 | b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
142 | b3 = torch.cross(b1, b2, dim=1)
143 | return torch.stack((b1, b2, b3), dim=-1)
144 |
145 |
146 | def mat2aa(rotation_matrix):
147 | """Convert 3x4 rotation matrix to Rodrigues vector"""
148 |
149 | def convert_points_to_homogeneous(points):
150 | if not torch.is_tensor(points):
151 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
152 | type(points)))
153 | if len(points.shape) < 2:
154 | raise ValueError("Input must be at least a 2D tensor. Got {}".format(
155 | points.shape))
156 |
157 | return F.pad(points, (0, 1), "constant", 1.0)
158 |
159 | if rotation_matrix.shape[1:] == (3, 3):
160 | rotation_matrix = convert_points_to_homogeneous(rotation_matrix)
161 | quaternion = mat2quat(rotation_matrix)
162 | aa = quat2aa(quaternion)
163 | aa[torch.isnan(aa)] = 0.0
164 | return aa
165 |
166 |
167 | class mano_regHead(nn.Module):
168 | def __init__(self, mano_layer=mano.layer, feature_size=1024, mano_neurons=[1024, 512]):
169 | super(mano_regHead, self).__init__()
170 |
171 | # 6D representation of rotation matrix
172 | self.pose6d_size = 16 * 6
173 | self.mano_pose_size = 16 * 3
174 |
175 | # Base Regression Layers
176 | mano_base_neurons = [feature_size] + mano_neurons
177 | base_layers = []
178 | for layer_idx, (inp_neurons, out_neurons) in enumerate(
179 | zip(mano_base_neurons[:-1], mano_base_neurons[1:])):
180 | base_layers.append(nn.Linear(inp_neurons, out_neurons))
181 | base_layers.append(nn.LeakyReLU(inplace=True))
182 | self.mano_base_layer = nn.Sequential(*base_layers)
183 | # Pose layers
184 | self.pose_reg = nn.Linear(mano_base_neurons[-1], self.pose6d_size)
185 | # Shape layers
186 | self.shape_reg = nn.Linear(mano_base_neurons[-1], 10)
187 |
188 | self.mano_layer = mano_layer
189 |
190 | def forward(self, features, gt_mano_params=None):
191 | mano_features = self.mano_base_layer(features)
192 | pred_mano_pose_6d = self.pose_reg(mano_features)
193 |
194 | pred_mano_pose_rotmat = rot6d2mat(pred_mano_pose_6d.view(-1, 6)).view(-1, 16, 3, 3).contiguous()
195 | pred_mano_shape = self.shape_reg(mano_features)
196 | pred_mano_pose = mat2aa(pred_mano_pose_rotmat.view(-1, 3, 3)).contiguous().view(-1, self.mano_pose_size)
197 | pred_verts, pred_joints, pred_manojoints2cam = self.mano_layer(th_pose_coeffs=pred_mano_pose, th_betas=pred_mano_shape)
198 |
199 | pred_verts /= 1000
200 | pred_joints /= 1000
201 |
202 | pred_mano_results = {
203 | "verts3d": pred_verts,
204 | "joints3d": pred_joints,
205 | "mano_shape": pred_mano_shape,
206 | "mano_pose": pred_mano_pose_rotmat,
207 | "mano_pose_aa": pred_mano_pose,
208 | "manojoints2cam": pred_manojoints2cam
209 | }
210 |
211 | if gt_mano_params is not None:
212 | gt_mano_shape = gt_mano_params[:, self.mano_pose_size:]
213 | gt_mano_pose = gt_mano_params[:, :self.mano_pose_size].contiguous()
214 | gt_mano_pose_rotmat = batch_rodrigues(gt_mano_pose.view(-1, 3)).view(-1, 16, 3, 3)
215 | gt_verts, gt_joints = self.mano_layer(th_pose_coeffs=gt_mano_pose, th_betas=gt_mano_shape)
216 |
217 | gt_verts /= 1000
218 | gt_joints /= 1000
219 |
220 | gt_mano_results = {
221 | "verts3d": gt_verts,
222 | "joints3d": gt_joints,
223 | "mano_shape": gt_mano_shape,
224 | "mano_pose": gt_mano_pose_rotmat}
225 | else:
226 | gt_mano_results = None
227 |
228 | return pred_mano_results, gt_mano_results
229 |
--------------------------------------------------------------------------------
/common/nets/regressor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from utils.mano import MANO
5 | from nets.hand_head import hand_regHead, hand_Encoder
6 | from nets.mano_head import mano_regHead
7 |
8 | class Regressor(nn.Module):
9 | def __init__(self):
10 | super(Regressor, self).__init__()
11 | self.hand_regHead = hand_regHead()
12 | self.hand_Encoder = hand_Encoder()
13 | self.mano_regHead = mano_regHead()
14 |
15 | def forward(self, feats, gt_mano_params=None):
16 | out_hm, encoding, preds_joints_img = self.hand_regHead(feats)
17 | mano_encoding = self.hand_Encoder(out_hm, encoding)
18 | pred_mano_results, gt_mano_results = self.mano_regHead(mano_encoding, gt_mano_params)
19 |
20 | return pred_mano_results, gt_mano_results, preds_joints_img
21 |
--------------------------------------------------------------------------------
/common/nets/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | class Transformer(nn.Module):
7 | def __init__(self, inp_res=32, dim=256, depth=2, num_heads=4, mlp_ratio=4., injection=True):
8 | super().__init__()
9 |
10 | self.injection=injection
11 |
12 | self.layers = nn.ModuleList([])
13 | for _ in range(depth):
14 | self.layers.append(Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, injection=injection))
15 |
16 | if self.injection:
17 | self.conv1 = nn.Sequential(
18 | nn.Conv2d(dim*2, dim, 3, padding=1),
19 | nn.ReLU(),
20 | nn.Conv2d(dim, dim, 3, padding=1),
21 | )
22 | self.conv2 = nn.Sequential(
23 | nn.Conv2d(dim*2, dim, 1, padding=0),
24 | )
25 |
26 | def forward(self, query, key):
27 | output = query
28 | for i, layer in enumerate(self.layers):
29 | output = layer(query=output, key=key)
30 |
31 | if self.injection:
32 | output = torch.cat([key, output], dim=1)
33 | output = self.conv1(output) + self.conv2(output)
34 |
35 | return output
36 |
37 | class Mlp(nn.Module):
38 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
39 | super().__init__()
40 | out_features = out_features or in_features
41 | hidden_features = hidden_features or in_features
42 | self.fc1 = nn.Linear(in_features, hidden_features)
43 | self.act = act_layer()
44 | self.fc2 = nn.Linear(hidden_features, out_features)
45 | self.drop = nn.Dropout(drop)
46 | self._init_weights()
47 |
48 | def forward(self, x):
49 | x = self.fc1(x)
50 | x = self.act(x)
51 | x = self.drop(x)
52 | x = self.fc2(x)
53 | x = self.drop(x)
54 | return x
55 |
56 | def _init_weights(self):
57 | nn.init.xavier_uniform_(self.fc1.weight)
58 | nn.init.xavier_uniform_(self.fc2.weight)
59 | nn.init.normal_(self.fc1.bias, std=1e-6)
60 | nn.init.normal_(self.fc2.bias, std=1e-6)
61 |
62 |
63 | class Attention(nn.Module):
64 | def __init__(self, dim, num_heads=1):
65 | super().__init__()
66 | self.num_heads = num_heads
67 | head_dim = dim // num_heads
68 | self.scale = head_dim ** -0.5
69 | self.sigmoid = nn.Sigmoid()
70 |
71 | def forward(self, query, key, value, query2, key2, use_sigmoid):
72 | B, N, C = query.shape
73 | query = query.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
74 | key = key.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
75 | value = value.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
76 | attn = torch.matmul(query, key.transpose(-2, -1)) * self.scale
77 | attn = attn.softmax(dim=-1)
78 |
79 | if use_sigmoid:
80 | query2 = query2.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
81 | key2 = key2.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
82 | attn2 = torch.matmul(query2, key2.transpose(-2, -1)) * self.scale
83 | attn2 = torch.sum(attn2, dim=-1)
84 | attn2 = self.sigmoid(attn2)
85 | attn = attn * attn2.unsqueeze(3)
86 |
87 | x = torch.matmul(attn, value).transpose(1, 2).reshape(B, N, C)
88 | return x
89 |
90 | class Block(nn.Module):
91 |
92 | def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm, injection=True):
93 | super().__init__()
94 |
95 | self.injection = injection
96 |
97 | self.channels = dim
98 |
99 | self.encode_value = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
100 | self.encode_query = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
101 | self.encode_key = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
102 |
103 | if self.injection:
104 | self.encode_query2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
105 | self.encode_key2 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, stride=1, padding=0)
106 |
107 | self.attn = Attention(dim, num_heads=num_heads)
108 | self.norm2 = norm_layer(dim)
109 | mlp_hidden_dim = int(dim * mlp_ratio)
110 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
111 | self.q_embedding = nn.Parameter(torch.randn(1, 256, 32, 32))
112 | self.k_embedding = nn.Parameter(torch.randn(1, 256, 32, 32))
113 |
114 | def with_pos_embed(self, tensor, pos):
115 | return tensor if pos is None else tensor + pos
116 |
117 | def forward(self, query, key, query_embed=None, key_embed=None):
118 | b, c, h, w = query.shape
119 | query_embed = repeat(self.q_embedding, '() n c d -> b n c d', b = b)
120 | key_embed = repeat(self.k_embedding, '() n c d -> b n c d', b = b)
121 |
122 | q_embed = self.with_pos_embed(query, query_embed)
123 | k_embed = self.with_pos_embed(key, key_embed)
124 |
125 | v = self.encode_value(key).view(b, self.channels, -1)
126 | v = v.permute(0, 2, 1)
127 |
128 | q = self.encode_query(q_embed).view(b, self.channels, -1)
129 | q = q.permute(0, 2, 1)
130 |
131 | k = self.encode_key(k_embed).view(b, self.channels, -1)
132 | k = k.permute(0, 2, 1)
133 |
134 | query = query.view(b, self.channels, -1).permute(0, 2, 1)
135 |
136 | if self.injection:
137 | q2 = self.encode_query2(q_embed).view(b, self.channels, -1)
138 | q2 = q2.permute(0, 2, 1)
139 |
140 | k2 = self.encode_key2(k_embed).view(b, self.channels, -1)
141 | k2 = k2.permute(0, 2, 1)
142 |
143 | query = self.attn(query=q, key=k, value=v,query2 = q2, key2 = k2, use_sigmoid=True)
144 | else:
145 | q2 = None
146 | k2 = None
147 |
148 | query = query + self.attn(query=q, key=k, value=v, query2 = q2, key2 = k2, use_sigmoid=False)
149 |
150 | query = query + self.mlp(self.norm2(query))
151 | query = query.permute(0, 2, 1).contiguous().view(b, self.channels, h, w)
152 |
153 | return query
154 |
--------------------------------------------------------------------------------
/common/timer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 |
8 | import time
9 |
10 | class Timer(object):
11 | """A simple timer."""
12 | def __init__(self):
13 | self.total_time = 0.
14 | self.calls = 0
15 | self.start_time = 0.
16 | self.diff = 0.
17 | self.average_time = 0.
18 | self.warm_up = 0
19 |
20 | def tic(self):
21 | # using time.time instead of time.clock because time time.clock
22 | # does not normalize for multithreading
23 | self.start_time = time.time()
24 |
25 | def toc(self, average=True):
26 | self.diff = time.time() - self.start_time
27 | if self.warm_up < 10:
28 | self.warm_up += 1
29 | return self.diff
30 | else:
31 | self.total_time += self.diff
32 | self.calls += 1
33 | self.average_time = self.total_time / self.calls
34 |
35 | if average:
36 | return self.average_time
37 | else:
38 | return self.diff
--------------------------------------------------------------------------------
/common/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/common/utils/__init__.py
--------------------------------------------------------------------------------
/common/utils/camera.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 |
22 | from collections import namedtuple
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 |
29 | PerspParams = namedtuple('ModelOutput',
30 | ['rotation', 'translation', 'center',
31 | 'focal_length'])
32 |
33 | def transform_mat(R: torch.tensor, t: torch.tensor) -> torch.Tensor:
34 | ''' Creates a batch of transformation matrices
35 | Args:
36 | - R: Bx3x3 array of a batch of rotation matrices
37 | - t: Bx3x1 array of a batch of translation vectors
38 | Returns:
39 | - T: Bx4x4 Transformation matrix
40 | '''
41 | # No padding left or right, only add an extra row
42 | return torch.cat([F.pad(R, [0, 0, 0, 1]),
43 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
44 |
45 |
46 |
47 | def create_camera(camera_type='persp', **kwargs):
48 | if camera_type.lower() == 'persp':
49 | return PerspectiveCamera(**kwargs)
50 | else:
51 | raise ValueError('Uknown camera type: {}'.format(camera_type))
52 |
53 |
54 | class PerspectiveCamera(nn.Module):
55 |
56 | FOCAL_LENGTH = 500
57 |
58 | def __init__(self, rotation=None, translation=None,
59 | focal_length_x=None, focal_length_y=None,
60 | batch_size=1,
61 | center=None, dtype=torch.float32, **kwargs):
62 | super(PerspectiveCamera, self).__init__()
63 | self.name = ''
64 | self.batch_size = batch_size
65 | self.dtype = dtype
66 | # Make a buffer so that PyTorch does not complain when creating
67 | # the camera matrix
68 | self.register_buffer('zero',
69 | torch.zeros([batch_size], dtype=dtype))
70 |
71 | if focal_length_x is None or type(focal_length_x) == float:
72 | focal_length_x = torch.full(
73 | [batch_size],
74 | self.FOCAL_LENGTH if focal_length_x is None else
75 | focal_length_x,
76 | dtype=dtype)
77 |
78 | if focal_length_y is None or type(focal_length_y) == float:
79 | focal_length_y = torch.full(
80 | [batch_size],
81 | self.FOCAL_LENGTH if focal_length_y is None else
82 | focal_length_y,
83 | dtype=dtype)
84 |
85 | self.register_buffer('focal_length_x', focal_length_x)
86 | self.register_buffer('focal_length_y', focal_length_y)
87 |
88 | if center is None:
89 | center = torch.zeros([batch_size, 2], dtype=dtype)
90 | self.register_buffer('center', center)
91 |
92 | if rotation is None:
93 | rotation = torch.eye(
94 | 3, dtype=dtype).unsqueeze(dim=0).repeat(batch_size, 1, 1)
95 |
96 | rotation = nn.Parameter(rotation, requires_grad=True)
97 | self.register_parameter('rotation', rotation)
98 |
99 | if translation is None:
100 | translation = torch.zeros([batch_size, 3], dtype=dtype)
101 |
102 | translation = nn.Parameter(translation,
103 | requires_grad=True)
104 | self.register_parameter('translation', translation)
105 |
106 | def forward(self, points):
107 | device = points.device
108 |
109 | with torch.no_grad():
110 | camera_mat = torch.zeros([self.batch_size, 2, 2],
111 | dtype=self.dtype, device=points.device)
112 | camera_mat[:, 0, 0] = self.focal_length_x
113 | camera_mat[:, 1, 1] = self.focal_length_y
114 |
115 | camera_transform = transform_mat(self.rotation,
116 | self.translation.unsqueeze(dim=-1))
117 | homog_coord = torch.ones(list(points.shape)[:-1] + [1],
118 | dtype=points.dtype,
119 | device=device)
120 | # Convert the points to homogeneous coordinates
121 | points_h = torch.cat([points, homog_coord], dim=-1)
122 |
123 | projected_points = torch.einsum('bki,bji->bjk',
124 | [camera_transform, points_h])
125 |
126 | img_points = torch.div(projected_points[:, :, :2],
127 | projected_points[:, :, 2].unsqueeze(dim=-1))
128 | img_points = torch.einsum('bki,bji->bjk', [camera_mat, img_points]) \
129 | + self.center.unsqueeze(dim=1)
130 | return img_points
--------------------------------------------------------------------------------
/common/utils/dir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | def make_folder(folder_name):
5 | if not os.path.exists(folder_name):
6 | os.makedirs(folder_name)
7 |
8 | def add_pypath(path):
9 | if path not in sys.path:
10 | sys.path.insert(0, path)
--------------------------------------------------------------------------------
/common/utils/fitting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 |
7 | def to_tensor(tensor, dtype=torch.float32):
8 | if torch.Tensor == type(tensor):
9 | return tensor.clone().detach()
10 | else:
11 | return torch.tensor(tensor, dtype=dtype)
12 |
13 |
14 | def rel_change(prev_val, curr_val):
15 | return (prev_val - curr_val) / max([np.abs(prev_val), np.abs(curr_val), 1])
16 |
17 | class FittingMonitor(object):
18 | def __init__(self, summary_steps=1, visualize=False,
19 | maxiters=300, ftol=1e-10, gtol=1e-09,
20 | body_color=(1.0, 1.0, 0.9, 1.0),
21 | model_type='mano',
22 | **kwargs):
23 | super(FittingMonitor, self).__init__()
24 |
25 | self.maxiters = maxiters
26 | self.ftol = ftol
27 | self.gtol = gtol
28 |
29 | self.visualize = visualize
30 | self.summary_steps = summary_steps
31 | self.body_color = body_color
32 | self.model_type = model_type
33 |
34 | def __enter__(self):
35 | self.steps = 0
36 |
37 | return self
38 |
39 | def __exit__(self, exception_type, exception_value, traceback):
40 | pass
41 |
42 | def set_colors(self, vertex_color):
43 | batch_size = self.colors.shape[0]
44 |
45 | self.colors = np.tile(
46 | np.array(vertex_color).reshape(1, 3),
47 | [batch_size, 1])
48 |
49 | def run_fitting(self, optimizer, closure, params,
50 | **kwargs):
51 | ''' Helper function for running an optimization process
52 | Parameters
53 | ----------
54 | optimizer: torch.optim.Optimizer
55 | The PyTorch optimizer object
56 | closure: function
57 | The function used to calculate the gradients
58 | params: list
59 | List containing the parameters that will be optimized
60 |
61 | Returns
62 | -------
63 | loss: float
64 | The final loss value
65 | '''
66 | prev_loss = None
67 | for n in range(self.maxiters):
68 | loss = optimizer.step(closure)
69 |
70 | if torch.isnan(loss).sum() > 0:
71 | print('NaN loss value, stopping!')
72 | break
73 |
74 | if torch.isinf(loss).sum() > 0:
75 | print('Infinite loss value, stopping!')
76 | break
77 |
78 | if n > 0 and prev_loss is not None and self.ftol > 0:
79 | loss_rel_change = rel_change(prev_loss, loss.item())
80 |
81 | if loss_rel_change <= self.ftol:
82 | break
83 |
84 | if all([torch.abs(var.grad.view(-1).max()).item() < self.gtol
85 | for var in params if var.grad is not None]):
86 | break
87 |
88 | prev_loss = loss.item()
89 |
90 | return prev_loss
91 |
92 | def create_fitting_closure(self,
93 | optimizer,
94 | camera=None,
95 | joint_cam=None,
96 | joint_img=None,
97 | hand_translation=None,
98 | hand_scale=None,
99 | loss=None,
100 | joints_conf=None,
101 | joint_weights=None,
102 | create_graph=False,
103 | **kwargs):
104 |
105 | def fitting_func(backward=True):
106 | if backward:
107 | optimizer.zero_grad()
108 |
109 | total_loss = loss(camera=camera,
110 | joint_cam=joint_cam,
111 | joint_img=joint_img,
112 | hand_translation=hand_translation,
113 | hand_scale =hand_scale,
114 | **kwargs)
115 |
116 | if backward:
117 | total_loss.backward(create_graph=create_graph)
118 |
119 | self.steps += 1
120 |
121 | return total_loss
122 |
123 | return fitting_func
124 |
125 |
126 |
127 |
128 | class ScaleTranslationLoss(nn.Module):
129 |
130 | def __init__(self, init_joints_idxs, trans_estimation=None,
131 | reduction='sum',
132 | data_weight=1.0,
133 | depth_loss_weight=1e3, dtype=torch.float32,
134 | **kwargs):
135 | super(ScaleTranslationLoss, self).__init__()
136 | self.dtype = dtype
137 |
138 | if trans_estimation is not None:
139 | self.register_buffer(
140 | 'trans_estimation',
141 | to_tensor(trans_estimation, dtype=dtype))
142 | else:
143 | self.trans_estimation = trans_estimation
144 |
145 | self.register_buffer('data_weight',
146 | torch.tensor(data_weight, dtype=dtype))
147 | self.register_buffer(
148 | 'init_joints_idxs',
149 | to_tensor(init_joints_idxs, dtype=torch.long))
150 | self.register_buffer('depth_loss_weight',
151 | torch.tensor(depth_loss_weight, dtype=dtype))
152 |
153 | def reset_loss_weights(self, loss_weight_dict):
154 | for key in loss_weight_dict:
155 | if hasattr(self, key):
156 | weight_tensor = getattr(self, key)
157 | weight_tensor = torch.tensor(loss_weight_dict[key],
158 | dtype=weight_tensor.dtype,
159 | device=weight_tensor.device)
160 | setattr(self, key, weight_tensor)
161 |
162 | def forward(self, camera, joint_cam, joint_img, hand_translation, hand_scale, **kwargs):
163 |
164 | projected_joints = camera(
165 | hand_scale * joint_cam + hand_translation)
166 |
167 | joint_error = \
168 | torch.index_select(joint_img, 1, self.init_joints_idxs) - \
169 | torch.index_select(projected_joints, 1, self.init_joints_idxs)
170 | joint_loss = torch.sum(joint_error.abs()) * self.data_weight ** 2
171 |
172 | depth_loss = 0.0
173 | if (self.depth_loss_weight.item() > 0 and self.trans_estimation is not None):
174 | depth_loss = self.depth_loss_weight * torch.sum((
175 | hand_translation[2] - self.trans_estimation[2]).abs() ** 2)
176 |
177 | return joint_loss + depth_loss
--------------------------------------------------------------------------------
/common/utils/mano.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os.path as osp
4 | import json
5 | from config import cfg
6 |
7 | import sys
8 | sys.path.insert(0, cfg.mano_path)
9 | import manopth
10 | from manopth.manolayer import ManoLayer
11 |
12 | class MANO(object):
13 | def __init__(self):
14 | # TEMP
15 | self.left_layer = ManoLayer(mano_root=osp.join(cfg.mano_path, 'mano', 'models'), flat_hand_mean=False, use_pca=False, side='left') # load right hand MANO model
16 | self.layer = self.get_layer()
17 | self.vertex_num = 778
18 | self.face = self.layer.th_faces.numpy()
19 | self.joint_regressor = self.layer.th_J_regressor.numpy()
20 |
21 | self.joint_num = 21
22 | self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinly_4')
23 | self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
24 | self.root_joint_idx = self.joints_name.index('Wrist')
25 |
26 | # add fingertips to joint_regressor
27 | self.fingertip_vertex_idx = [728, 353, 442, 576, 694] # mesh vertex idx
28 |
29 | thumbtip_onehot = np.array([1 if i == 728 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
30 | indextip_onehot = np.array([1 if i == 353 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
31 | middletip_onehot = np.array([1 if i == 442 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
32 | ringtip_onehot = np.array([1 if i == 576 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
33 | pinkytip_onehot = np.array([1 if i == 694 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
34 |
35 | self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
36 | self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
37 |
38 | def get_layer(self):
39 | return ManoLayer(mano_root=osp.join(cfg.mano_path, 'mano', 'models'), flat_hand_mean=False, use_pca=False, side='right') # load right hand MANO model
--------------------------------------------------------------------------------
/common/utils/manopth/.gitignore:
--------------------------------------------------------------------------------
1 | *.sw*
2 | *.bak
3 | *_bak.py
4 |
5 | .cache/
6 | __pycache__/
7 | build/
8 | dist/
9 | manopth_hassony2.egg-info/
10 |
11 | mano/models
12 | assets/mano_layer.svg
13 |
--------------------------------------------------------------------------------
/common/utils/manopth/README.md:
--------------------------------------------------------------------------------
1 | Manopth
2 | =======
3 |
4 | [MANO](http://mano.is.tue.mpg.de) layer for [PyTorch](https://pytorch.org/) (tested with v0.4 and v1.x)
5 |
6 | ManoLayer is a differentiable PyTorch layer that deterministically maps from pose and shape parameters to hand joints and vertices.
7 | It can be integrated into any architecture as a differentiable layer to predict hand meshes.
8 |
9 | 
10 |
11 | ManoLayer takes **batched** hand pose and shape vectors and outputs corresponding hand joints and vertices.
12 |
13 | The code is mostly a PyTorch port of the original [MANO](http://mano.is.tue.mpg.de) model from [chumpy](https://github.com/mattloper/chumpy) to [PyTorch](https://pytorch.org/).
14 | It therefore builds directly upon the work of Javier Romero, Dimitrios Tzionas and Michael J. Black.
15 |
16 | This layer was developped and used for the paper *Learning joint reconstruction of hands and manipulated objects* for CVPR19.
17 | See [project page](https://github.com/hassony2/obman) and [demo+training code](https://github.com/hassony2/obman_train).
18 |
19 |
20 | It [reuses](https://github.com/hassony2/manopth/blob/master/manopth/rodrigues_layer.py) [part of the great code](https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py) from the [Pytorch layer for the SMPL body model](https://github.com/MandyMo/pytorch_HMR/blob/master/README.md) by Zhang Xiong ([MandyMo](https://github.com/MandyMo)) to compute the rotation utilities !
21 |
22 | It also includes in `mano/webuser` partial content of files from the original [MANO](http://mano.is.tue.mpg.de) code ([posemapper.py](mano/webuser/posemapper.py), [serialization.py](mano/webuser/serialization.py), [lbs.py](mano/webuser/lbs.py), [verts.py](mano/webuser/verts.py), [smpl_handpca_wrapper_HAND_only.py](mano/webuser/smpl_handpca_wrapper_HAND_only.py)).
23 |
24 | If you find this code useful for your research, consider citing:
25 |
26 | - the original [MANO](http://mano.is.tue.mpg.de) publication:
27 |
28 | ```
29 | @article{MANO:SIGGRAPHASIA:2017,
30 | title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together},
31 | author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.},
32 | journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
33 | publisher = {ACM},
34 | month = nov,
35 | year = {2017},
36 | url = {http://doi.acm.org/10.1145/3130800.3130883},
37 | month_numeric = {11}
38 | }
39 | ```
40 |
41 | - the publication this PyTorch port was developped for:
42 |
43 | ```
44 | @INPROCEEDINGS{hasson19_obman,
45 | title = {Learning joint reconstruction of hands and manipulated objects},
46 | author = {Hasson, Yana and Varol, G{\"u}l and Tzionas, Dimitris and Kalevatykh, Igor and Black, Michael J. and Laptev, Ivan and Schmid, Cordelia},
47 | booktitle = {CVPR},
48 | year = {2019}
49 | }
50 | ```
51 |
52 | The training code associated with this paper, compatible with manopth can be found [here](https://github.com/hassony2/obman_train). The release includes a model trained on a variety of hand datasets.
53 |
54 | # Installation
55 |
56 | ## Get code and dependencies
57 |
58 | - `git clone https://github.com/hassony2/manopth`
59 | - `cd manopth`
60 | - Install the dependencies listed in [environment.yml](environment.yml)
61 | - In an existing conda environment, `conda env update -f environment.yml`
62 | - In a new environment, `conda env create -f environment.yml`, will create a conda environment named `manopth`
63 |
64 | ## Download MANO pickle data-structures
65 |
66 | - Go to [MANO website](http://mano.is.tue.mpg.de/)
67 | - Create an account by clicking *Sign Up* and provide your information
68 | - Download Models and Code (the downloaded file should have the format `mano_v*_*.zip`). Note that all code and data from this download falls under the [MANO license](http://mano.is.tue.mpg.de/license).
69 | - unzip and copy the `models` folder into the `manopth/mano` folder
70 | - Your folder structure should look like this:
71 | ```
72 | manopth/
73 | mano/
74 | models/
75 | MANO_LEFT.pkl
76 | MANO_RIGHT.pkl
77 | ...
78 | manopth/
79 | __init__.py
80 | ...
81 | ```
82 |
83 | To check that everything is going well, run `python examples/manopth_mindemo.py`, which should generate from a random hand using the MANO layer !
84 |
85 | ## Install `manopth` package
86 |
87 | To be able to import and use `ManoLayer` in another project, go to your `manopth` folder and run `pip install .`
88 |
89 |
90 | `cd /path/to/other/project`
91 |
92 | You can now use `from manopth import ManoLayer` in this other project!
93 |
94 | # Usage
95 |
96 | ## Minimal usage script
97 |
98 | See [examples/manopth_mindemo.py](examples/manopth_mindemo.py)
99 |
100 | Simple forward pass with random pose and shape parameters through MANO layer
101 |
102 | ```python
103 | import torch
104 | from manopth.manolayer import ManoLayer
105 | from manopth import demo
106 |
107 | batch_size = 10
108 | # Select number of principal components for pose space
109 | ncomps = 6
110 |
111 | # Initialize MANO layer
112 | mano_layer = ManoLayer(mano_root='mano/models', use_pca=True, ncomps=ncomps)
113 |
114 | # Generate random shape parameters
115 | random_shape = torch.rand(batch_size, 10)
116 | # Generate random pose parameters, including 3 values for global axis-angle rotation
117 | random_pose = torch.rand(batch_size, ncomps + 3)
118 |
119 | # Forward pass through MANO layer
120 | hand_verts, hand_joints = mano_layer(random_pose, random_shape)
121 | demo.display_hand({'verts': hand_verts, 'joints': hand_joints}, mano_faces=mano_layer.th_faces)
122 | ```
123 |
124 | Result :
125 |
126 | 
127 |
128 | ## Demo
129 |
130 | With more options, forward and backward pass, and a loop for quick profiling, look at [examples/manopth_demo.py](examples/manopth_demo.py).
131 |
132 | You can run it locally with:
133 |
134 | `python examples/manopth_demo.py`
135 |
136 |
--------------------------------------------------------------------------------
/common/utils/manopth/environment.yml:
--------------------------------------------------------------------------------
1 | name: manopth
2 |
3 | dependencies:
4 | - opencv
5 | - python=3.7
6 | - matplotlib
7 | - numpy
8 | - pytorch
9 | - tqdm
10 | - git
11 | - pip:
12 | - git+https://github.com/hassony2/chumpy.git
13 |
--------------------------------------------------------------------------------
/common/utils/manopth/examples/manopth_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from matplotlib import pyplot as plt
4 | from mpl_toolkits.mplot3d import Axes3D
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from manopth import argutils
9 | from manopth.manolayer import ManoLayer
10 | from manopth.demo import display_hand
11 |
12 | if __name__ == '__main__':
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--batch_size', default=1, type=int)
15 | parser.add_argument('--cuda', action='store_true')
16 | parser.add_argument(
17 | '--no_display',
18 | action='store_true',
19 | help="Disable display output of ManoLayer given random inputs")
20 | parser.add_argument('--side', default='left', choices=['left', 'right'])
21 | parser.add_argument('--random_shape', action='store_true', help="Random hand shape")
22 | parser.add_argument('--rand_mag', type=float, default=1, help="Controls pose variability")
23 | parser.add_argument(
24 | '--flat_hand_mean',
25 | action='store_true',
26 | help="Use flat hand as mean instead of average hand pose")
27 | parser.add_argument(
28 | '--iters',
29 | type=int,
30 | default=1,
31 | help=
32 | "Use for quick profiling of forward and backward pass accross ManoLayer"
33 | )
34 | parser.add_argument('--mano_root', default='mano/models')
35 | parser.add_argument('--root_rot_mode', default='axisang', choices=['rot6d', 'axisang'])
36 | parser.add_argument('--no_pca', action='store_true', help="Give axis-angle or rotation matrix as inputs instead of PCA coefficients")
37 | parser.add_argument('--joint_rot_mode', default='axisang', choices=['rotmat', 'axisang'], help="Joint rotation inputs")
38 | parser.add_argument(
39 | '--mano_ncomps', default=6, type=int, help="Number of PCA components")
40 | args = parser.parse_args()
41 |
42 | argutils.print_args(args)
43 |
44 | layer = ManoLayer(
45 | flat_hand_mean=args.flat_hand_mean,
46 | side=args.side,
47 | mano_root=args.mano_root,
48 | ncomps=args.mano_ncomps,
49 | use_pca=not args.no_pca,
50 | root_rot_mode=args.root_rot_mode,
51 | joint_rot_mode=args.joint_rot_mode)
52 | if args.root_rot_mode == 'axisang':
53 | rot = 3
54 | else:
55 | rot = 6
56 | print(rot)
57 | if args.no_pca:
58 | args.mano_ncomps = 45
59 |
60 | # Generate random pose coefficients
61 | pose_params = args.rand_mag * torch.rand(args.batch_size, args.mano_ncomps + rot)
62 | pose_params.requires_grad = True
63 | if args.random_shape:
64 | shape = torch.rand(args.batch_size, 10)
65 | else:
66 | shape = torch.zeros(1) # Hack to act like None for PyTorch JIT
67 | if args.cuda:
68 | pose_params = pose_params.cuda()
69 | shape = shape.cuda()
70 | layer.cuda()
71 |
72 | # Loop for forward/backward quick profiling
73 | for idx in tqdm(range(args.iters)):
74 | # Forward pass
75 | verts, Jtr = layer(pose_params, th_betas=shape)
76 |
77 | # Backward pass
78 | loss = torch.norm(verts)
79 | loss.backward()
80 |
81 | if not args.no_display:
82 | verts, Jtr = layer(pose_params, th_betas=shape)
83 | joints = Jtr.cpu().detach()
84 | verts = verts.cpu().detach()
85 |
86 | # Draw obtained vertices and joints
87 | display_hand({
88 | 'verts': verts,
89 | 'joints': joints
90 | },
91 | mano_faces=layer.th_faces)
92 |
--------------------------------------------------------------------------------
/common/utils/manopth/examples/manopth_mindemo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from manopth.manolayer import ManoLayer
3 | from manopth import demo
4 |
5 | batch_size = 10
6 | # Select number of principal components for pose space
7 | ncomps = 6
8 |
9 | # Initialize MANO layer
10 | mano_layer = ManoLayer(
11 | mano_root='mano/models', use_pca=True, ncomps=ncomps, flat_hand_mean=False)
12 |
13 | # Generate random shape parameters
14 | random_shape = torch.rand(batch_size, 10)
15 | # Generate random pose parameters, including 3 values for global axis-angle rotation
16 | random_pose = torch.rand(batch_size, ncomps + 3)
17 |
18 | # Forward pass through MANO layer
19 | hand_verts, hand_joints = mano_layer(random_pose, random_shape)
20 | demo.display_hand({
21 | 'verts': hand_verts,
22 | 'joints': hand_joints
23 | },
24 | mano_faces=mano_layer.th_faces)
25 |
--------------------------------------------------------------------------------
/common/utils/manopth/mano/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/common/utils/manopth/mano/__init__.py
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/common/utils/manopth/mano/webuser/__init__.py
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/lbs.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3 | This software is provided for research purposes only.
4 | By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5 |
6 | More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7 | For comments or questions, please email us at: mano@tue.mpg.de
8 |
9 |
10 | About this file:
11 | ================
12 | This file defines a wrapper for the loading functions of the MANO model.
13 |
14 | Modules included:
15 | - load_model:
16 | loads the MANO model from a given file location (i.e. a .pkl file location),
17 | or a dictionary object.
18 |
19 | '''
20 |
21 |
22 | from mano.webuser.posemapper import posemap
23 | import chumpy
24 | import numpy as np
25 |
26 |
27 | def global_rigid_transformation(pose, J, kintree_table, xp):
28 | results = {}
29 | pose = pose.reshape((-1, 3))
30 | id_to_col = {kintree_table[1, i]: i for i in range(kintree_table.shape[1])}
31 | parent = {
32 | i: id_to_col[kintree_table[0, i]]
33 | for i in range(1, kintree_table.shape[1])
34 | }
35 |
36 | if xp == chumpy:
37 | from mano.webuser.posemapper import Rodrigues
38 | rodrigues = lambda x: Rodrigues(x)
39 | else:
40 | import cv2
41 | rodrigues = lambda x: cv2.Rodrigues(x)[0]
42 |
43 | with_zeros = lambda x: xp.vstack((x, xp.array([[0.0, 0.0, 0.0, 1.0]])))
44 | results[0] = with_zeros(
45 | xp.hstack((rodrigues(pose[0, :]), J[0, :].reshape((3, 1)))))
46 |
47 | for i in range(1, kintree_table.shape[1]):
48 | results[i] = results[parent[i]].dot(
49 | with_zeros(
50 | xp.hstack((rodrigues(pose[i, :]), ((J[i, :] - J[parent[i], :]
51 | ).reshape((3, 1)))))))
52 |
53 | pack = lambda x: xp.hstack([np.zeros((4, 3)), x.reshape((4, 1))])
54 |
55 | results = [results[i] for i in sorted(results.keys())]
56 | results_global = results
57 |
58 | if True:
59 | results2 = [
60 | results[i] - (pack(results[i].dot(xp.concatenate(((J[i, :]), 0)))))
61 | for i in range(len(results))
62 | ]
63 | results = results2
64 | result = xp.dstack(results)
65 | return result, results_global
66 |
67 |
68 | def verts_core(pose, v, J, weights, kintree_table, want_Jtr=False, xp=chumpy):
69 | A, A_global = global_rigid_transformation(pose, J, kintree_table, xp)
70 | T = A.dot(weights.T)
71 |
72 | rest_shape_h = xp.vstack((v.T, np.ones((1, v.shape[0]))))
73 |
74 | v = (T[:, 0, :] * rest_shape_h[0, :].reshape(
75 | (1, -1)) + T[:, 1, :] * rest_shape_h[1, :].reshape(
76 | (1, -1)) + T[:, 2, :] * rest_shape_h[2, :].reshape(
77 | (1, -1)) + T[:, 3, :] * rest_shape_h[3, :].reshape((1, -1))).T
78 |
79 | v = v[:, :3]
80 |
81 | if not want_Jtr:
82 | return v
83 | Jtr = xp.vstack([g[:3, 3] for g in A_global])
84 | return (v, Jtr)
85 |
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/posemapper.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3 | This software is provided for research purposes only.
4 | By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5 |
6 | More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7 | For comments or questions, please email us at: mano@tue.mpg.de
8 |
9 |
10 | About this file:
11 | ================
12 | This file defines a wrapper for the loading functions of the MANO model.
13 |
14 | Modules included:
15 | - load_model:
16 | loads the MANO model from a given file location (i.e. a .pkl file location),
17 | or a dictionary object.
18 |
19 | '''
20 |
21 |
22 | import chumpy as ch
23 | import numpy as np
24 | import cv2
25 |
26 |
27 | class Rodrigues(ch.Ch):
28 | dterms = 'rt'
29 |
30 | def compute_r(self):
31 | return cv2.Rodrigues(self.rt.r)[0]
32 |
33 | def compute_dr_wrt(self, wrt):
34 | if wrt is self.rt:
35 | return cv2.Rodrigues(self.rt.r)[1].T
36 |
37 |
38 | def lrotmin(p):
39 | if isinstance(p, np.ndarray):
40 | p = p.ravel()[3:]
41 | return np.concatenate(
42 | [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
43 | for pp in p.reshape((-1, 3))]).ravel()
44 | if p.ndim != 2 or p.shape[1] != 3:
45 | p = p.reshape((-1, 3))
46 | p = p[1:]
47 | return ch.concatenate([(Rodrigues(pp) - ch.eye(3)).ravel()
48 | for pp in p]).ravel()
49 |
50 |
51 | def posemap(s):
52 | if s == 'lrotmin':
53 | return lrotmin
54 | else:
55 | raise Exception('Unknown posemapping: %s' % (str(s), ))
56 |
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/serialization.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3 | This software is provided for research purposes only.
4 | By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5 |
6 | More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7 | For comments or questions, please email us at: mano@tue.mpg.de
8 |
9 |
10 | About this file:
11 | ================
12 | This file defines a wrapper for the loading functions of the MANO model.
13 |
14 | Modules included:
15 | - load_model:
16 | loads the MANO model from a given file location (i.e. a .pkl file location),
17 | or a dictionary object.
18 |
19 | '''
20 |
21 |
22 | __all__ = ['load_model', 'save_model']
23 |
24 | import numpy as np
25 | import pickle
26 | import chumpy as ch
27 | from chumpy.ch import MatVecMult
28 | from mano.webuser.posemapper import posemap
29 | from mano.webuser.verts import verts_core
30 |
31 | def ready_arguments(fname_or_dict):
32 |
33 | if not isinstance(fname_or_dict, dict):
34 | dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
35 | else:
36 | dd = fname_or_dict
37 |
38 | backwards_compatibility_replacements(dd)
39 |
40 | want_shapemodel = 'shapedirs' in dd
41 | nposeparms = dd['kintree_table'].shape[1] * 3
42 |
43 | if 'trans' not in dd:
44 | dd['trans'] = np.zeros(3)
45 | if 'pose' not in dd:
46 | dd['pose'] = np.zeros(nposeparms)
47 | if 'shapedirs' in dd and 'betas' not in dd:
48 | dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
49 |
50 | for s in [
51 | 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
52 | 'betas', 'J'
53 | ]:
54 | if (s in dd) and not hasattr(dd[s], 'dterms'):
55 | dd[s] = ch.array(dd[s])
56 |
57 | if want_shapemodel:
58 | dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
59 | v_shaped = dd['v_shaped']
60 | J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
61 | J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
62 | J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
63 | dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
64 | dd['v_posed'] = v_shaped + dd['posedirs'].dot(
65 | posemap(dd['bs_type'])(dd['pose']))
66 | else:
67 | dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(
68 | posemap(dd['bs_type'])(dd['pose']))
69 |
70 | return dd
71 |
72 |
73 | def load_model(fname_or_dict):
74 | dd = ready_arguments(fname_or_dict)
75 |
76 | args = {
77 | 'pose': dd['pose'],
78 | 'v': dd['v_posed'],
79 | 'J': dd['J'],
80 | 'weights': dd['weights'],
81 | 'kintree_table': dd['kintree_table'],
82 | 'xp': ch,
83 | 'want_Jtr': True,
84 | 'bs_style': dd['bs_style']
85 | }
86 |
87 | result, Jtr = verts_core(**args)
88 | result = result + dd['trans'].reshape((1, 3))
89 | result.J_transformed = Jtr + dd['trans'].reshape((1, 3))
90 |
91 | for k, v in dd.items():
92 | setattr(result, k, v)
93 |
94 | return result
95 |
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/smpl_handpca_wrapper_HAND_only.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3 | This software is provided for research purposes only.
4 | By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5 |
6 | More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7 | For comments or questions, please email us at: mano@tue.mpg.de
8 |
9 |
10 | About this file:
11 | ================
12 | This file defines a wrapper for the loading functions of the MANO model.
13 |
14 | Modules included:
15 | - load_model:
16 | loads the MANO model from a given file location (i.e. a .pkl file location),
17 | or a dictionary object.
18 |
19 | '''
20 |
21 |
22 | def ready_arguments(fname_or_dict, posekey4vposed='pose'):
23 | import numpy as np
24 | import pickle
25 | import chumpy as ch
26 | from chumpy.ch import MatVecMult
27 | from mano.webuser.posemapper import posemap
28 |
29 | if not isinstance(fname_or_dict, dict):
30 | dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
31 | # dd = pickle.load(open(fname_or_dict, 'rb'))
32 | else:
33 | dd = fname_or_dict
34 |
35 | want_shapemodel = 'shapedirs' in dd
36 | nposeparms = dd['kintree_table'].shape[1] * 3
37 |
38 | if 'trans' not in dd:
39 | dd['trans'] = np.zeros(3)
40 | if 'pose' not in dd:
41 | dd['pose'] = np.zeros(nposeparms)
42 | if 'shapedirs' in dd and 'betas' not in dd:
43 | dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
44 |
45 | for s in [
46 | 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
47 | 'betas', 'J'
48 | ]:
49 | if (s in dd) and not hasattr(dd[s], 'dterms'):
50 | dd[s] = ch.array(dd[s])
51 |
52 | assert (posekey4vposed in dd)
53 | if want_shapemodel:
54 | dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
55 | v_shaped = dd['v_shaped']
56 | J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
57 | J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
58 | J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
59 | dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
60 | pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
61 | dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
62 | else:
63 | pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
64 | dd_add = dd['posedirs'].dot(pose_map_res)
65 | dd['v_posed'] = dd['v_template'] + dd_add
66 |
67 | return dd
68 |
69 |
70 | def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None):
71 | ''' This model loads the fully articulable HAND SMPL model,
72 | and replaces the pose DOFS by ncomps from PCA'''
73 |
74 | from mano.webuser.verts import verts_core
75 | import numpy as np
76 | import chumpy as ch
77 | import pickle
78 | import scipy.sparse as sp
79 | np.random.seed(1)
80 |
81 | if not isinstance(fname_or_dict, dict):
82 | smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
83 | # smpl_data = pickle.load(open(fname_or_dict, 'rb'))
84 | else:
85 | smpl_data = fname_or_dict
86 |
87 | rot = 3 # for global orientation!!!
88 |
89 | hands_components = smpl_data['hands_components']
90 | hands_mean = np.zeros(hands_components.shape[
91 | 1]) if flat_hand_mean else smpl_data['hands_mean']
92 | hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps]
93 |
94 | selected_components = np.vstack((hands_components[:ncomps]))
95 | hands_mean = hands_mean.copy()
96 |
97 | pose_coeffs = ch.zeros(rot + selected_components.shape[0])
98 | full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components)
99 |
100 | smpl_data['fullpose'] = ch.concatenate((pose_coeffs[:rot],
101 | hands_mean + full_hand_pose))
102 | smpl_data['pose'] = pose_coeffs
103 |
104 | Jreg = smpl_data['J_regressor']
105 | if not sp.issparse(Jreg):
106 | smpl_data['J_regressor'] = (sp.csc_matrix(
107 | (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape))
108 |
109 | # slightly modify ready_arguments to make sure that it uses the fullpose
110 | # (which will NOT be pose) for the computation of posedirs
111 | dd = ready_arguments(smpl_data, posekey4vposed='fullpose')
112 |
113 | # create the smpl formula with the fullpose,
114 | # but expose the PCA coefficients as smpl.pose for compatibility
115 | args = {
116 | 'pose': dd['fullpose'],
117 | 'v': dd['v_posed'],
118 | 'J': dd['J'],
119 | 'weights': dd['weights'],
120 | 'kintree_table': dd['kintree_table'],
121 | 'xp': ch,
122 | 'want_Jtr': True,
123 | 'bs_style': dd['bs_style'],
124 | }
125 |
126 | result_previous, meta = verts_core(**args)
127 |
128 | result = result_previous + dd['trans'].reshape((1, 3))
129 | result.no_translation = result_previous
130 |
131 | if meta is not None:
132 | for field in ['Jtr', 'A', 'A_global', 'A_weighted']:
133 | if (hasattr(meta, field)):
134 | setattr(result, field, getattr(meta, field))
135 |
136 | setattr(result, 'Jtr', meta)
137 | if hasattr(result, 'Jtr'):
138 | result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3))
139 |
140 | for k, v in dd.items():
141 | setattr(result, k, v)
142 |
143 | if v_template is not None:
144 | result.v_template[:] = v_template
145 |
146 | return result
147 |
148 |
149 | if __name__ == '__main__':
150 | load_model()
151 |
--------------------------------------------------------------------------------
/common/utils/manopth/mano/webuser/verts.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3 | This software is provided for research purposes only.
4 | By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5 |
6 | More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7 | For comments or questions, please email us at: mano@tue.mpg.de
8 |
9 |
10 | About this file:
11 | ================
12 | This file defines a wrapper for the loading functions of the MANO model.
13 |
14 | Modules included:
15 | - load_model:
16 | loads the MANO model from a given file location (i.e. a .pkl file location),
17 | or a dictionary object.
18 |
19 | '''
20 |
21 |
22 | import chumpy
23 | import mano.webuser.lbs as lbs
24 | from mano.webuser.posemapper import posemap
25 | import scipy.sparse as sp
26 | from chumpy.ch import MatVecMult
27 |
28 |
29 | def ischumpy(x):
30 | return hasattr(x, 'dterms')
31 |
32 |
33 | def verts_decorated(trans,
34 | pose,
35 | v_template,
36 | J_regressor,
37 | weights,
38 | kintree_table,
39 | bs_style,
40 | f,
41 | bs_type=None,
42 | posedirs=None,
43 | betas=None,
44 | shapedirs=None,
45 | want_Jtr=False):
46 |
47 | for which in [
48 | trans, pose, v_template, weights, posedirs, betas, shapedirs
49 | ]:
50 | if which is not None:
51 | assert ischumpy(which)
52 |
53 | v = v_template
54 |
55 | if shapedirs is not None:
56 | if betas is None:
57 | betas = chumpy.zeros(shapedirs.shape[-1])
58 | v_shaped = v + shapedirs.dot(betas)
59 | else:
60 | v_shaped = v
61 |
62 | if posedirs is not None:
63 | v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
64 | else:
65 | v_posed = v_shaped
66 |
67 | v = v_posed
68 |
69 | if sp.issparse(J_regressor):
70 | J_tmpx = MatVecMult(J_regressor, v_shaped[:, 0])
71 | J_tmpy = MatVecMult(J_regressor, v_shaped[:, 1])
72 | J_tmpz = MatVecMult(J_regressor, v_shaped[:, 2])
73 | J = chumpy.vstack((J_tmpx, J_tmpy, J_tmpz)).T
74 | else:
75 | assert (ischumpy(J))
76 |
77 | assert (bs_style == 'lbs')
78 | result, Jtr = lbs.verts_core(
79 | pose, v, J, weights, kintree_table, want_Jtr=True, xp=chumpy)
80 |
81 | tr = trans.reshape((1, 3))
82 | result = result + tr
83 | Jtr = Jtr + tr
84 |
85 | result.trans = trans
86 | result.f = f
87 | result.pose = pose
88 | result.v_template = v_template
89 | result.J = J
90 | result.J_regressor = J_regressor
91 | result.weights = weights
92 | result.kintree_table = kintree_table
93 | result.bs_style = bs_style
94 | result.bs_type = bs_type
95 | if posedirs is not None:
96 | result.posedirs = posedirs
97 | result.v_posed = v_posed
98 | if shapedirs is not None:
99 | result.shapedirs = shapedirs
100 | result.betas = betas
101 | result.v_shaped = v_shaped
102 | if want_Jtr:
103 | result.J_transformed = Jtr
104 | return result
105 |
106 |
107 | def verts_core(pose,
108 | v,
109 | J,
110 | weights,
111 | kintree_table,
112 | bs_style,
113 | want_Jtr=False,
114 | xp=chumpy):
115 |
116 | if xp == chumpy:
117 | assert (hasattr(pose, 'dterms'))
118 | assert (hasattr(v, 'dterms'))
119 | assert (hasattr(J, 'dterms'))
120 | assert (hasattr(weights, 'dterms'))
121 |
122 | assert (bs_style == 'lbs')
123 | result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)
124 | return result
125 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/__init__.py:
--------------------------------------------------------------------------------
1 | name = 'manopth'
2 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/argutils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import pickle
4 | import subprocess
5 | import sys
6 |
7 |
8 | def print_args(args):
9 | opts = vars(args)
10 | print('======= Options ========')
11 | for k, v in sorted(opts.items()):
12 | print('{}: {}'.format(k, v))
13 | print('========================')
14 |
15 |
16 | def save_args(args, save_folder, opt_prefix='opt', verbose=True):
17 | opts = vars(args)
18 | # Create checkpoint folder
19 | if not os.path.exists(save_folder):
20 | os.makedirs(save_folder, exist_ok=True)
21 |
22 | # Save options
23 | opt_filename = '{}.txt'.format(opt_prefix)
24 | opt_path = os.path.join(save_folder, opt_filename)
25 | with open(opt_path, 'a') as opt_file:
26 | opt_file.write('====== Options ======\n')
27 | for k, v in sorted(opts.items()):
28 | opt_file.write(
29 | '{option}: {value}\n'.format(option=str(k), value=str(v)))
30 | opt_file.write('=====================\n')
31 | opt_file.write('launched {} at {}\n'.format(
32 | str(sys.argv[0]), str(datetime.datetime.now())))
33 |
34 | # Add git info
35 | label = subprocess.check_output(["git", "describe",
36 | "--always"]).strip()
37 | if subprocess.call(
38 | ["git", "branch"],
39 | stderr=subprocess.STDOUT,
40 | stdout=open(os.devnull, 'w')) == 0:
41 | opt_file.write('=== Git info ====\n')
42 | opt_file.write('{}\n'.format(label))
43 | commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
44 | opt_file.write('commit : {}\n'.format(commit.strip()))
45 |
46 | opt_picklename = '{}.pkl'.format(opt_prefix)
47 | opt_picklepath = os.path.join(save_folder, opt_picklename)
48 | with open(opt_picklepath, 'wb') as opt_file:
49 | pickle.dump(opts, opt_file)
50 | if verbose:
51 | print('Saved options to {}'.format(opt_path))
52 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/demo.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | from mpl_toolkits.mplot3d import Axes3D
3 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
4 | import numpy as np
5 | import torch
6 |
7 | from manopth.manolayer import ManoLayer
8 |
9 |
10 | def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'):
11 | nfull_comps = ncomps + 3 # Add global orientation dims to PCA
12 | random_pcapose = torch.rand(batch_size, nfull_comps)
13 | mano_layer = ManoLayer(mano_root=mano_root)
14 | verts, joints = mano_layer(random_pcapose)
15 | return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces}
16 |
17 |
18 | def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True):
19 | """
20 | Displays hand batch_idx in batch of hand_info, hand_info as returned by
21 | generate_random_hand
22 | """
23 | if ax is None:
24 | fig = plt.figure()
25 | ax = fig.add_subplot(111, projection='3d')
26 | verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][
27 | batch_idx]
28 | if mano_faces is None:
29 | ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1)
30 | else:
31 | mesh = Poly3DCollection(verts[mano_faces], alpha=alpha)
32 | face_color = (141 / 255, 184 / 255, 226 / 255)
33 | edge_color = (50 / 255, 50 / 255, 50 / 255)
34 | mesh.set_edgecolor(edge_color)
35 | mesh.set_facecolor(face_color)
36 | ax.add_collection3d(mesh)
37 | ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
38 | cam_equal_aspect_3d(ax, verts.numpy())
39 | if show:
40 | plt.show()
41 |
42 |
43 | def cam_equal_aspect_3d(ax, verts, flip_x=False):
44 | """
45 | Centers view on cuboid containing hand and flips y and z axis
46 | and fixes azimuth
47 | """
48 | extents = np.stack([verts.min(0), verts.max(0)], axis=1)
49 | sz = extents[:, 1] - extents[:, 0]
50 | centers = np.mean(extents, axis=1)
51 | maxsize = max(abs(sz))
52 | r = maxsize / 2
53 | if flip_x:
54 | ax.set_xlim(centers[0] + r, centers[0] - r)
55 | else:
56 | ax.set_xlim(centers[0] - r, centers[0] + r)
57 | # Invert y and z axis
58 | ax.set_ylim(centers[1] + r, centers[1] - r)
59 | ax.set_zlim(centers[2] + r, centers[2] - r)
60 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/manolayer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from torch.nn import Module
6 |
7 | from mano.webuser.smpl_handpca_wrapper_HAND_only import ready_arguments
8 | from manopth import rodrigues_layer, rotproj, rot6d
9 | from manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack,
10 | subtract_flat_id, make_list)
11 |
12 |
13 | class ManoLayer(Module):
14 | __constants__ = [
15 | 'use_pca', 'rot', 'ncomps', 'ncomps', 'kintree_parents', 'check',
16 | 'side', 'center_idx', 'joint_rot_mode'
17 | ]
18 |
19 | def __init__(self,
20 | center_idx=None,
21 | flat_hand_mean=True,
22 | ncomps=6,
23 | side='right',
24 | mano_root='mano/models',
25 | use_pca=True,
26 | root_rot_mode='axisang',
27 | joint_rot_mode='axisang',
28 | robust_rot=False):
29 | """
30 | Args:
31 | center_idx: index of center joint in our computations,
32 | if -1 centers on estimate of palm as middle of base
33 | of middle finger and wrist
34 | flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
35 | flat hand, else match average hand pose
36 | mano_root: path to MANO pkl files for left and right hand
37 | ncomps: number of PCA components form pose space (<45)
38 | side: 'right' or 'left'
39 | use_pca: Use PCA decomposition for pose space.
40 | joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
41 | """
42 | super().__init__()
43 |
44 | self.center_idx = center_idx
45 | self.robust_rot = robust_rot
46 | if root_rot_mode == 'axisang':
47 | self.rot = 3
48 | else:
49 | self.rot = 6
50 | self.flat_hand_mean = flat_hand_mean
51 | self.side = side
52 | self.use_pca = use_pca
53 | self.joint_rot_mode = joint_rot_mode
54 | self.root_rot_mode = root_rot_mode
55 | if use_pca:
56 | self.ncomps = ncomps
57 | else:
58 | self.ncomps = 45
59 |
60 | if side == 'right':
61 | self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
62 | elif side == 'left':
63 | self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')
64 |
65 | smpl_data = ready_arguments(self.mano_path)
66 |
67 | hands_components = smpl_data['hands_components']
68 |
69 | self.smpl_data = smpl_data
70 |
71 | self.register_buffer('th_betas',
72 | torch.Tensor(smpl_data['betas'].r).unsqueeze(0))
73 | self.register_buffer('th_shapedirs',
74 | torch.Tensor(smpl_data['shapedirs'].r))
75 | self.register_buffer('th_posedirs',
76 | torch.Tensor(smpl_data['posedirs'].r))
77 | self.register_buffer(
78 | 'th_v_template',
79 | torch.Tensor(smpl_data['v_template'].r).unsqueeze(0))
80 | self.register_buffer(
81 | 'th_J_regressor',
82 | torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
83 | self.register_buffer('th_weights',
84 | torch.Tensor(smpl_data['weights'].r))
85 | self.register_buffer('th_faces',
86 | torch.Tensor(smpl_data['f'].astype(np.int32)).long())
87 |
88 | # Get hand mean
89 | hands_mean = np.zeros(hands_components.shape[1]
90 | ) if flat_hand_mean else smpl_data['hands_mean']
91 | hands_mean = hands_mean.copy()
92 | th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
93 |
94 | if self.use_pca or self.joint_rot_mode == 'axisang':
95 | # Save as axis-angle
96 | self.register_buffer('th_hands_mean', th_hands_mean)
97 | selected_components = hands_components[:ncomps]
98 | self.register_buffer('th_selected_comps',
99 | torch.Tensor(selected_components))
100 | else:
101 | th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
102 | th_hands_mean.view(15, 3)).reshape(15, 3, 3)
103 | self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)
104 |
105 | # Kinematic chain params
106 | self.kintree_table = smpl_data['kintree_table']
107 | parents = list(self.kintree_table[0].tolist())
108 | self.kintree_parents = parents
109 |
110 | def forward(self,
111 | th_pose_coeffs,
112 | th_betas=torch.zeros(1),
113 | th_trans=torch.zeros(1),
114 | root_palm=torch.Tensor([0]),
115 | share_betas=torch.Tensor([0]),
116 | ):
117 | """
118 | Args:
119 | th_trans (Tensor (batch_size x ncomps)): if provided, applies trans to joints and vertices
120 | th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters for hand shape
121 | else centers on root joint (9th joint)
122 | root_palm: return palm as hand root instead of wrist
123 | """
124 | # if len(th_pose_coeffs) == 0:
125 | # return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)
126 |
127 | batch_size = th_pose_coeffs.shape[0]
128 | # Get axis angle from PCA components and coefficients
129 | if self.use_pca or self.joint_rot_mode == 'axisang':
130 | # Remove global rot coeffs
131 | th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
132 | self.ncomps]
133 | if self.use_pca:
134 | # PCA components --> axis angles
135 | th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
136 | else:
137 | th_full_hand_pose = th_hand_pose_coeffs
138 |
139 | # Concatenate back global rot
140 | th_full_pose = torch.cat([
141 | th_pose_coeffs[:, :self.rot],
142 | self.th_hands_mean + th_full_hand_pose
143 | ], 1)
144 |
145 | if self.root_rot_mode == 'axisang':
146 | # compute rotation matrixes from axis-angle while skipping global rotation
147 | th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
148 | root_rot = th_rot_map[:, :9].view(batch_size, 3, 3)
149 | th_rot_map = th_rot_map[:, 9:]
150 | th_pose_map = th_pose_map[:, 9:]
151 | else:
152 | # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6
153 | th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:])
154 | if self.robust_rot:
155 | root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
156 | else:
157 | root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6])
158 | else:
159 | assert th_pose_coeffs.dim() == 4, (
160 | 'When not self.use_pca, '
161 | 'th_pose_coeffs should have 4 dims, got {}'.format(
162 | th_pose_coeffs.dim()))
163 | assert th_pose_coeffs.shape[2:4] == (3, 3), (
164 | 'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two'
165 | 'last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
166 | th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
167 | th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
168 | th_pose_map = subtract_flat_id(th_rot_map)
169 | root_rot = th_pose_rots[:, 0]
170 |
171 | # Full axis angle representation with root joint
172 | if th_betas is None or th_betas.numel() == 1:
173 | th_v_shaped = torch.matmul(self.th_shapedirs,
174 | self.th_betas.transpose(1, 0)).permute(
175 | 2, 0, 1) + self.th_v_template
176 | th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
177 | batch_size, 1, 1)
178 |
179 | else:
180 | if share_betas:
181 | th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10)
182 | th_v_shaped = torch.matmul(self.th_shapedirs,
183 | th_betas.transpose(1, 0)).permute(
184 | 2, 0, 1) + self.th_v_template
185 | th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
186 | # th_pose_map should have shape 20x135
187 |
188 | th_v_posed = th_v_shaped + torch.matmul(
189 | self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
190 | # Final T pose with transformation done !
191 |
192 | # Global rigid transformation
193 |
194 | root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
195 | root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2))
196 |
197 | all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3)
198 | lev1_idxs = [1, 4, 7, 10, 13]
199 | lev2_idxs = [2, 5, 8, 11, 14]
200 | lev3_idxs = [3, 6, 9, 12, 15]
201 | lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
202 | lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
203 | lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
204 | lev1_j = th_j[:, lev1_idxs]
205 | lev2_j = th_j[:, lev2_idxs]
206 | lev3_j = th_j[:, lev3_idxs]
207 |
208 | # From base to tips
209 | # Get lev1 results
210 | all_transforms = [root_trans.unsqueeze(1)]
211 | lev1_j_rel = lev1_j - root_j.transpose(1, 2)
212 | lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
213 | root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4)
214 | lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt)
215 | all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4))
216 |
217 | # Get lev2 results
218 | lev2_j_rel = lev2_j - lev1_j
219 | lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
220 | lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt)
221 | all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4))
222 |
223 | # Get lev3 results
224 | lev3_j_rel = lev3_j - lev2_j
225 | lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4))
226 | lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt)
227 | all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4))
228 |
229 | reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
230 | th_results = torch.cat(all_transforms, 1)[:, reorder_idxs]
231 | th_results_global = th_results
232 |
233 | joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2)
234 | tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3))
235 | th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1)
236 |
237 | th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))
238 |
239 | th_rest_shape_h = torch.cat([
240 | th_v_posed.transpose(2, 1),
241 | torch.ones((batch_size, 1, th_v_posed.shape[1]),
242 | dtype=th_T.dtype,
243 | device=th_T.device),
244 | ], 1)
245 |
246 | th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
247 | th_verts = th_verts[:, :, :3]
248 | th_jtr = th_results_global[:, :, :3, 3]
249 | # In addition to MANO reference joints we sample vertices on each finger
250 | # to serve as finger tips
251 | if self.side == 'right':
252 | tips = th_verts[:, [745, 317, 444, 556, 673]]
253 | else:
254 | tips = th_verts[:, [745, 317, 445, 556, 673]]
255 | if bool(root_palm):
256 | palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
257 | th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
258 | th_jtr = torch.cat([th_jtr, tips], 1)
259 |
260 | # Reorder joints to match visualization utilities
261 | th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
262 |
263 | if th_trans is None or bool(torch.norm(th_trans) == 0):
264 | if self.center_idx is not None:
265 | center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
266 | th_jtr = th_jtr - center_joint
267 | th_verts = th_verts - center_joint
268 | else:
269 | th_jtr = th_jtr + th_trans.unsqueeze(1)
270 | th_verts = th_verts + th_trans.unsqueeze(1)
271 |
272 | # Scale to milimeters
273 | th_verts = th_verts * 1000
274 | th_jtr = th_jtr * 1000
275 | return th_verts, th_jtr, th_results_global
276 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/rodrigues_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py
3 | which is part of a PyTorch port of SMPL.
4 | Thanks to Zhang Xiong (MandyMo) for making this great code available on github !
5 | """
6 |
7 | import argparse
8 | from torch.autograd import gradcheck
9 | import torch
10 | from torch.autograd import Variable
11 |
12 | from manopth import argutils
13 |
14 |
15 | def quat2mat(quat):
16 | """Convert quaternion coefficients to rotation matrix.
17 | Args:
18 | quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
19 | Returns:
20 | Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
21 | """
22 | norm_quat = quat
23 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
24 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:,
25 | 2], norm_quat[:,
26 | 3]
27 |
28 | batch_size = quat.size(0)
29 |
30 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
31 | wx, wy, wz = w * x, w * y, w * z
32 | xy, xz, yz = x * y, x * z, y * z
33 |
34 | rotMat = torch.stack([
35 | w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
36 | w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
37 | w2 - x2 - y2 + z2
38 | ],
39 | dim=1).view(batch_size, 3, 3)
40 | return rotMat
41 |
42 |
43 | def batch_rodrigues(axisang):
44 | #axisang N x 3
45 | axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
46 | angle = torch.unsqueeze(axisang_norm, -1)
47 | axisang_normalized = torch.div(axisang, angle)
48 | angle = angle * 0.5
49 | v_cos = torch.cos(angle)
50 | v_sin = torch.sin(angle)
51 | quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
52 | rot_mat = quat2mat(quat)
53 | rot_mat = rot_mat.view(rot_mat.shape[0], 9)
54 | return rot_mat
55 |
56 |
57 | def th_get_axis_angle(vector):
58 | angle = torch.norm(vector, 2, 1)
59 | axes = vector / angle.unsqueeze(1)
60 | return axes, angle
61 |
62 |
63 | if __name__ == '__main__':
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('--batch_size', default=1, type=int)
66 | parser.add_argument('--cuda', action='store_true')
67 | args = parser.parse_args()
68 |
69 | argutils.print_args(args)
70 |
71 | n_components = 6
72 | rot = 3
73 | inputs = torch.rand(args.batch_size, rot)
74 | inputs_var = Variable(inputs.double(), requires_grad=True)
75 | if args.cuda:
76 | inputs = inputs.cuda()
77 | # outputs = batch_rodrigues(inputs)
78 | test_function = gradcheck(batch_rodrigues, (inputs_var, ))
79 | print('batch test passed !')
80 |
81 | inputs = torch.rand(rot)
82 | inputs_var = Variable(inputs.double(), requires_grad=True)
83 | test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, ))
84 | print('th_cv2_rod test passed')
85 |
86 | inputs = torch.rand(rot)
87 | inputs_var = Variable(inputs.double(), requires_grad=True)
88 | test_th = gradcheck(th_cv2_rod.apply, (inputs_var, ))
89 | print('th_cv2_rod_id test passed !')
90 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/rot6d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def compute_rotation_matrix_from_ortho6d(poses):
5 | """
6 | Code from
7 | https://github.com/papagina/RotationContinuity
8 | On the Continuity of Rotation Representations in Neural Networks
9 | Zhou et al. CVPR19
10 | https://zhouyisjtu.github.io/project_rotation/rotation.html
11 | """
12 | x_raw = poses[:, 0:3] # batch*3
13 | y_raw = poses[:, 3:6] # batch*3
14 |
15 | x = normalize_vector(x_raw) # batch*3
16 | z = cross_product(x, y_raw) # batch*3
17 | z = normalize_vector(z) # batch*3
18 | y = cross_product(z, x) # batch*3
19 |
20 | x = x.view(-1, 3, 1)
21 | y = y.view(-1, 3, 1)
22 | z = z.view(-1, 3, 1)
23 | matrix = torch.cat((x, y, z), 2) # batch*3*3
24 | return matrix
25 |
26 | def robust_compute_rotation_matrix_from_ortho6d(poses):
27 | """
28 | Instead of making 2nd vector orthogonal to first
29 | create a base that takes into account the two predicted
30 | directions equally
31 | """
32 | x_raw = poses[:, 0:3] # batch*3
33 | y_raw = poses[:, 3:6] # batch*3
34 |
35 | x = normalize_vector(x_raw) # batch*3
36 | y = normalize_vector(y_raw) # batch*3
37 | middle = normalize_vector(x + y)
38 | orthmid = normalize_vector(x - y)
39 | x = normalize_vector(middle + orthmid)
40 | y = normalize_vector(middle - orthmid)
41 | # Their scalar product should be small !
42 | # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001
43 | z = normalize_vector(cross_product(x, y))
44 |
45 | x = x.view(-1, 3, 1)
46 | y = y.view(-1, 3, 1)
47 | z = z.view(-1, 3, 1)
48 | matrix = torch.cat((x, y, z), 2) # batch*3*3
49 | # Check for reflection in matrix ! If found, flip last vector TODO
50 | assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0
51 | return matrix
52 |
53 |
54 | def normalize_vector(v):
55 | batch = v.shape[0]
56 | v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
57 | v_mag = torch.max(v_mag, v.new([1e-8]))
58 | v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
59 | v = v/v_mag
60 | return v
61 |
62 |
63 | def cross_product(u, v):
64 | batch = u.shape[0]
65 | i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
66 | j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
67 | k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
68 |
69 | out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1)
70 |
71 | return out
72 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/rotproj.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def batch_rotprojs(batches_rotmats):
5 | proj_rotmats = []
6 | for batch_idx, batch_rotmats in enumerate(batches_rotmats):
7 | proj_batch_rotmats = []
8 | for rot_idx, rotmat in enumerate(batch_rotmats):
9 | # GPU implementation of svd is VERY slow
10 | # ~ 2 10^-3 per hit vs 5 10^-5 on cpu
11 | U, S, V = rotmat.cpu().svd()
12 | rotmat = torch.matmul(U, V.transpose(0, 1))
13 | orth_det = rotmat.det()
14 | # Remove reflection
15 | if orth_det < 0:
16 | rotmat[:, 2] = -1 * rotmat[:, 2]
17 |
18 | rotmat = rotmat.cuda()
19 | proj_batch_rotmats.append(rotmat)
20 | proj_rotmats.append(torch.stack(proj_batch_rotmats))
21 | return torch.stack(proj_rotmats)
22 |
--------------------------------------------------------------------------------
/common/utils/manopth/manopth/tensutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from manopth import rodrigues_layer
4 |
5 |
6 | def th_posemap_axisang(pose_vectors):
7 | rot_nb = int(pose_vectors.shape[1] / 3)
8 | pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3)
9 | rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped)
10 | rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9)
11 | pose_maps = subtract_flat_id(rot_mats)
12 | return pose_maps, rot_mats
13 |
14 |
15 | def th_with_zeros(tensor):
16 | batch_size = tensor.shape[0]
17 | padding = tensor.new([0.0, 0.0, 0.0, 1.0])
18 | padding.requires_grad = False
19 |
20 | concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)]
21 | cat_res = torch.cat(concat_list, 1)
22 | return cat_res
23 |
24 |
25 | def th_pack(tensor):
26 | batch_size = tensor.shape[0]
27 | padding = tensor.new_zeros((batch_size, 4, 3))
28 | padding.requires_grad = False
29 | pack_list = [padding, tensor]
30 | pack_res = torch.cat(pack_list, 2)
31 | return pack_res
32 |
33 |
34 | def subtract_flat_id(rot_mats):
35 | # Subtracts identity as a flattened tensor
36 | rot_nb = int(rot_mats.shape[1] / 9)
37 | id_flat = torch.eye(
38 | 3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat(
39 | rot_mats.shape[0], rot_nb)
40 | # id_flat.requires_grad = False
41 | results = rot_mats - id_flat
42 | return results
43 |
44 |
45 | def make_list(tensor):
46 | # type: (List[int]) -> List[int]
47 | return tensor
48 |
--------------------------------------------------------------------------------
/common/utils/manopth/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | import warnings
3 |
4 | DEPENDENCY_PACKAGE_NAMES = ["matplotlib", "torch", "tqdm", "numpy", "cv2",
5 | "chumpy"]
6 |
7 |
8 | def check_dependencies():
9 | missing_dependencies = []
10 | for package_name in DEPENDENCY_PACKAGE_NAMES:
11 | try:
12 | __import__(package_name)
13 | except ImportError:
14 | missing_dependencies.append(package_name)
15 |
16 | if missing_dependencies:
17 | warnings.warn(
18 | 'Missing dependencies: {}. We recommend you follow '
19 | 'the installation instructions at '
20 | 'https://github.com/hassony2/manopth#installation'.format(
21 | missing_dependencies))
22 |
23 |
24 | with open("README.md", "r") as fh:
25 | long_description = fh.read()
26 |
27 | check_dependencies()
28 |
29 | setup(
30 | name="manopth",
31 | version="0.0.1",
32 | author="Yana Hasson",
33 | author_email="yana.hasson.inria@gmail.com",
34 | packages=find_packages(exclude=('tests',)),
35 | python_requires=">=3.5.0",
36 | description="PyTorch mano layer",
37 | long_description=long_description,
38 | long_description_content_type="text/markdown",
39 | url="https://github.com/hassony2/manopth",
40 | classifiers=[
41 | "Programming Language :: Python :: 3",
42 | "License :: OSI Approved :: GNU GENERAL PUBLIC LICENSE",
43 | "Operating System :: OS Independent",
44 | ],
45 | )
46 |
--------------------------------------------------------------------------------
/common/utils/manopth/test/test_demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from manopth.demo import generate_random_hand
4 |
5 |
6 | def test_generate_random_hand():
7 | batch_size = 3
8 | hand_info = generate_random_hand(batch_size=batch_size, ncomps=6)
9 | verts = hand_info['verts']
10 | joints = hand_info['joints']
11 | assert verts.shape == (batch_size, 778, 3)
12 | assert joints.shape == (batch_size, 21, 3)
13 |
--------------------------------------------------------------------------------
/common/utils/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
--------------------------------------------------------------------------------
/common/utils/optimizers/lbfgs_ls.py:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of L-BFGS with Strong Wolfe line search
2 | # Will be removed once https://github.com/pytorch/pytorch/pull/8824
3 | # is merged
4 |
5 | import torch
6 | from functools import reduce
7 |
8 | from torch.optim import Optimizer
9 |
10 |
11 | def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
12 | # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
13 | # Compute bounds of interpolation area
14 | if bounds is not None:
15 | xmin_bound, xmax_bound = bounds
16 | else:
17 | xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
18 |
19 | # Code for most common case: cubic interpolation of 2 points
20 | # w/ function and derivative values for both
21 | # Solution in this case (where x2 is the farthest point):
22 | # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
23 | # d2 = sqrt(d1^2 - g1*g2);
24 | # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
25 | # t_new = min(max(min_pos,xmin_bound),xmax_bound);
26 | d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
27 | d2_square = d1 ** 2 - g1 * g2
28 | if d2_square >= 0:
29 | d2 = d2_square.sqrt()
30 | if x1 <= x2:
31 | min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
32 | else:
33 | min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
34 | return min(max(min_pos, xmin_bound), xmax_bound)
35 | else:
36 | return (xmin_bound + xmax_bound) / 2.
37 |
38 |
39 | def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
40 | max_iter=20,
41 | max_ls=25):
42 | # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
43 | d_norm = d.abs().max()
44 | g = g.clone()
45 | # evaluate objective and gradient using initial step
46 | f_new, g_new = obj_func(x, t, d)
47 | ls_func_evals = 1
48 | gtd_new = g_new.dot(d)
49 |
50 | # bracket an interval containing a point satisfying the Wolfe criteria
51 | t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
52 | done = False
53 | ls_iter = 0
54 | while ls_iter < max_ls:
55 | # check conditions
56 | if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
57 | bracket = [t_prev, t]
58 | bracket_f = [f_prev, f_new]
59 | bracket_g = [g_prev, g_new.clone()]
60 | bracket_gtd = [gtd_prev, gtd_new]
61 | break
62 |
63 | if abs(gtd_new) <= -c2 * gtd:
64 | bracket = [t]
65 | bracket_f = [f_new]
66 | bracket_g = [g_new]
67 | bracket_gtd = [gtd_new]
68 | done = True
69 | break
70 |
71 | if gtd_new >= 0:
72 | bracket = [t_prev, t]
73 | bracket_f = [f_prev, f_new]
74 | bracket_g = [g_prev, g_new.clone()]
75 | bracket_gtd = [gtd_prev, gtd_new]
76 | break
77 |
78 | # interpolate
79 | min_step = t + 0.01 * (t - t_prev)
80 | max_step = t * 10
81 | tmp = t
82 | t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
83 | bounds=(min_step, max_step))
84 |
85 | # next step
86 | t_prev = tmp
87 | f_prev = f_new
88 | g_prev = g_new.clone()
89 | gtd_prev = gtd_new
90 | f_new, g_new = obj_func(x, t, d)
91 | ls_func_evals += 1
92 | gtd_new = g_new.dot(d)
93 | ls_iter += 1
94 |
95 | # reached max number of iterations?
96 | if ls_iter == max_ls:
97 | bracket = [0, t]
98 | bracket_f = [f, f_new]
99 | bracket_g = [g, g_new]
100 | bracket_gtd = [gtd, gtd_new]
101 |
102 | # zoom phase: we now have a point satisfying the criteria, or
103 | # a bracket around it. We refine the bracket until we find the
104 | # exact point satisfying the criteria
105 | insuf_progress = False
106 | # find high and low points in bracket
107 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
108 | while not done and ls_iter < max_iter:
109 | # compute new trial value
110 | t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
111 | bracket[1], bracket_f[1], bracket_gtd[1])
112 |
113 | # test what we are making sufficient progress
114 | eps = 0.1 * (max(bracket) - min(bracket))
115 | if min(max(bracket) - t, t - min(bracket)) < eps:
116 | # interpolation close to boundary
117 | if insuf_progress or t >= max(bracket) or t <= min(bracket):
118 | # evaluate at 0.1 away from boundary
119 | if abs(t - max(bracket)) < abs(t - min(bracket)):
120 | t = max(bracket) - eps
121 | else:
122 | t = min(bracket) + eps
123 | insuf_progress = False
124 | else:
125 | insuf_progress = True
126 | else:
127 | insuf_progress = False
128 |
129 | # Evaluate new point
130 | f_new, g_new = obj_func(x, t, d)
131 | ls_func_evals += 1
132 | gtd_new = g_new.dot(d)
133 | ls_iter += 1
134 |
135 | if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
136 | # Armijo condition not satisfied or not lower than lowest point
137 | bracket[high_pos] = t
138 | bracket_f[high_pos] = f_new
139 | bracket_g[high_pos] = g_new.clone()
140 | bracket_gtd[high_pos] = gtd_new
141 | low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
142 | else:
143 | if abs(gtd_new) <= -c2 * gtd:
144 | # Wolfe conditions satisfied
145 | done = True
146 | elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
147 | # old high becomes new low
148 | bracket[high_pos] = bracket[low_pos]
149 | bracket_f[high_pos] = bracket_f[low_pos]
150 | bracket_g[high_pos] = bracket_g[low_pos]
151 | bracket_gtd[high_pos] = bracket_gtd[low_pos]
152 |
153 | # new point becomes new low
154 | bracket[low_pos] = t
155 | bracket_f[low_pos] = f_new
156 | bracket_g[low_pos] = g_new.clone()
157 | bracket_gtd[low_pos] = gtd_new
158 |
159 | # line-search bracket is so small
160 | if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
161 | break
162 |
163 | # return stuff
164 | t = bracket[low_pos]
165 | f_new = bracket_f[low_pos]
166 | g_new = bracket_g[low_pos]
167 | return f_new, g_new, t, ls_func_evals
168 |
169 |
170 | # LBFGS with strong Wolfe line search introduces in PR #8824
171 | # Will be removed once merged with master
172 | class LBFGS(Optimizer):
173 | """Implements L-BFGS algorithm, heavily inspired by `minFunc
174 | `.
175 | .. warning::
176 | This optimizer doesn't support per-parameter options and parameter
177 | groups (there can be only one).
178 | .. warning::
179 | Right now all parameters have to be on a single device. This will be
180 | improved in the future.
181 | .. note::
182 | This is a very memory intensive optimizer (it requires additional
183 | ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
184 | try reducing the history size, or use a different algorithm.
185 | Arguments:
186 | lr (float): learning rate (default: 1)
187 | max_iter (int): maximal number of iterations per optimization step
188 | (default: 20)
189 | max_eval (int): maximal number of function evaluations per optimization
190 | step (default: max_iter * 1.25).
191 | tolerance_grad (float): termination tolerance on first order optimality
192 | (default: 1e-5).
193 | tolerance_change (float): termination tolerance on function
194 | value/parameter changes (default: 1e-9).
195 | history_size (int): update history size (default: 100).
196 | line_search_fn (str): either 'strong_Wolfe' or None (default: None).
197 | """
198 |
199 | def __init__(self, params, lr=1, max_iter=20, max_eval=None,
200 | tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
201 | line_search_fn=None):
202 | if max_eval is None:
203 | max_eval = max_iter * 5 // 4
204 | defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
205 | tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
206 | history_size=history_size, line_search_fn=line_search_fn)
207 | super(LBFGS, self).__init__(params, defaults)
208 |
209 | if len(self.param_groups) != 1:
210 | raise ValueError("LBFGS doesn't support per-parameter options "
211 | "(parameter groups)")
212 |
213 | self._params = self.param_groups[0]['params']
214 | self._numel_cache = None
215 |
216 | def _numel(self):
217 | if self._numel_cache is None:
218 | self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
219 | return self._numel_cache
220 |
221 | def _gather_flat_grad(self):
222 | views = []
223 | for p in self._params:
224 | if p.grad is None:
225 | view = p.new(p.numel()).zero_()
226 | elif p.grad.is_sparse:
227 | view = p.grad.to_dense().view(-1)
228 | else:
229 | view = p.grad.view(-1)
230 | views.append(view)
231 | return torch.cat(views, 0)
232 |
233 | def _add_grad(self, step_size, update):
234 | offset = 0
235 | for p in self._params:
236 | numel = p.numel()
237 | # view as to avoid deprecated pointwise semantics
238 | p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
239 | offset += numel
240 | assert offset == self._numel()
241 |
242 | def _clone_param(self):
243 | return [p.clone() for p in self._params]
244 |
245 | def _set_param(self, params_data):
246 | for p, pdata in zip(self._params, params_data):
247 | p.data.copy_(pdata)
248 |
249 | def _directional_evaluate(self, closure, x, t, d):
250 | self._add_grad(t, d)
251 | loss = float(closure())
252 | flat_grad = self._gather_flat_grad()
253 | self._set_param(x)
254 | return loss, flat_grad
255 |
256 | def step(self, closure):
257 | """Performs a single optimization step.
258 | Arguments:
259 | closure (callable): A closure that reevaluates the model
260 | and returns the loss.
261 | """
262 | assert len(self.param_groups) == 1
263 |
264 | group = self.param_groups[0]
265 | lr = group['lr']
266 | max_iter = group['max_iter']
267 | max_eval = group['max_eval']
268 | tolerance_grad = group['tolerance_grad']
269 | tolerance_change = group['tolerance_change']
270 | line_search_fn = group['line_search_fn']
271 | history_size = group['history_size']
272 |
273 | # NOTE: LBFGS has only global state, but we register it as state for
274 | # the first param, because this helps with casting in load_state_dict
275 | state = self.state[self._params[0]]
276 | state.setdefault('func_evals', 0)
277 | state.setdefault('n_iter', 0)
278 |
279 | # evaluate initial f(x) and df/dx
280 | orig_loss = closure()
281 | loss = float(orig_loss)
282 | current_evals = 1
283 | state['func_evals'] += 1
284 |
285 | flat_grad = self._gather_flat_grad()
286 | opt_cond = flat_grad.abs().max() <= tolerance_grad
287 |
288 | # optimal condition
289 | if opt_cond:
290 | return orig_loss
291 |
292 | # tensors cached in state (for tracing)
293 | d = state.get('d')
294 | t = state.get('t')
295 | old_dirs = state.get('old_dirs')
296 | old_stps = state.get('old_stps')
297 | ro = state.get('ro')
298 | H_diag = state.get('H_diag')
299 | prev_flat_grad = state.get('prev_flat_grad')
300 | prev_loss = state.get('prev_loss')
301 |
302 | n_iter = 0
303 | # optimize for a max of max_iter iterations
304 | while n_iter < max_iter:
305 | # keep track of nb of iterations
306 | n_iter += 1
307 | state['n_iter'] += 1
308 |
309 | ############################################################
310 | # compute gradient descent direction
311 | ############################################################
312 | if state['n_iter'] == 1:
313 | d = flat_grad.neg()
314 | old_dirs = []
315 | old_stps = []
316 | ro = []
317 | H_diag = 1
318 | else:
319 | # do lbfgs update (update memory)
320 | y = flat_grad.sub(prev_flat_grad)
321 | s = d.mul(t)
322 | ys = y.dot(s) # y*s
323 | if ys > 1e-10:
324 | # updating memory
325 | if len(old_dirs) == history_size:
326 | # shift history by one (limited-memory)
327 | old_dirs.pop(0)
328 | old_stps.pop(0)
329 | ro.pop(0)
330 |
331 | # store new direction/step
332 | old_dirs.append(y)
333 | old_stps.append(s)
334 | ro.append(1. / ys)
335 |
336 | # update scale of initial Hessian approximation
337 | H_diag = ys / y.dot(y) # (y*y)
338 |
339 | # compute the approximate (L-BFGS) inverse Hessian
340 | # multiplied by the gradient
341 | num_old = len(old_dirs)
342 |
343 | if 'al' not in state:
344 | state['al'] = [None] * history_size
345 | al = state['al']
346 |
347 | # iteration in L-BFGS loop collapsed to use just one buffer
348 | q = flat_grad.neg()
349 | for i in range(num_old - 1, -1, -1):
350 | al[i] = old_stps[i].dot(q) * ro[i]
351 | q.add_(-al[i], old_dirs[i])
352 |
353 | # multiply by initial Hessian
354 | # r/d is the final direction
355 | d = r = torch.mul(q, H_diag)
356 | for i in range(num_old):
357 | be_i = old_dirs[i].dot(r) * ro[i]
358 | r.add_(al[i] - be_i, old_stps[i])
359 |
360 | if prev_flat_grad is None:
361 | prev_flat_grad = flat_grad.clone()
362 | else:
363 | prev_flat_grad.copy_(flat_grad)
364 | prev_loss = loss
365 |
366 | ############################################################
367 | # compute step length
368 | ############################################################
369 | # reset initial guess for step size
370 | if state['n_iter'] == 1:
371 | t = min(1., 1. / flat_grad.abs().sum()) * lr
372 | else:
373 | t = lr
374 |
375 | # directional derivative
376 | gtd = flat_grad.dot(d) # g * d
377 |
378 | # directional derivative is below tolerance
379 | if gtd > -tolerance_change:
380 | break
381 |
382 | # optional line search: user function
383 | ls_func_evals = 0
384 | if line_search_fn is not None:
385 | # perform line search, using user function
386 | if line_search_fn != "strong_Wolfe":
387 | raise RuntimeError("only 'strong_Wolfe' is supported")
388 | else:
389 | x_init = self._clone_param()
390 |
391 | def obj_func(x, t, d):
392 | return self._directional_evaluate(closure, x, t, d)
393 | loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
394 | loss,
395 | flat_grad,
396 | gtd,
397 | max_iter=max_iter)
398 | self._add_grad(t, d)
399 | opt_cond = flat_grad.abs().max() <= tolerance_grad
400 | else:
401 | # no line search, simply move with fixed-step
402 | self._add_grad(t, d)
403 | if n_iter != max_iter:
404 | # re-evaluate function only if not in last iteration
405 | # the reason we do this: in a stochastic setting,
406 | # no use to re-evaluate that function here
407 | loss = float(closure())
408 | flat_grad = self._gather_flat_grad()
409 | opt_cond = flat_grad.abs().max() <= tolerance_grad
410 | ls_func_evals = 1
411 |
412 | # update func eval
413 | current_evals += ls_func_evals
414 | state['func_evals'] += ls_func_evals
415 |
416 | ############################################################
417 | # check conditions
418 | ############################################################
419 | if n_iter == max_iter:
420 | break
421 |
422 | if current_evals >= max_eval:
423 | break
424 |
425 | # optimal condition
426 | if opt_cond:
427 | break
428 |
429 | # lack of progress
430 | if d.mul(t).abs().max() <= tolerance_change:
431 | break
432 |
433 | if abs(loss - prev_loss) < tolerance_change:
434 | break
435 |
436 | state['d'] = d
437 | state['t'] = t
438 | state['old_dirs'] = old_dirs
439 | state['old_stps'] = old_stps
440 | state['ro'] = ro
441 | state['H_diag'] = H_diag
442 | state['prev_flat_grad'] = prev_flat_grad
443 | state['prev_loss'] = prev_loss
444 |
445 | return orig_loss
--------------------------------------------------------------------------------
/common/utils/optimizers/optim_factory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 |
19 | from __future__ import absolute_import
20 | from __future__ import print_function
21 | from __future__ import division
22 |
23 | import torch.optim as optim
24 | from .lbfgs_ls import LBFGS as LBFGSLs
25 |
26 |
27 | def create_optimizer(parameters, optim_type='lbfgs',
28 | lr=1e-3,
29 | momentum=0.9,
30 | use_nesterov=True,
31 | beta1=0.9,
32 | beta2=0.999,
33 | epsilon=1e-8,
34 | use_locking=False,
35 | weight_decay=0.0,
36 | centered=False,
37 | rmsprop_alpha=0.99,
38 | maxiters=20,
39 | gtol=1e-6,
40 | ftol=1e-9,
41 | **kwargs):
42 | ''' Creates the optimizer
43 | '''
44 | if optim_type == 'adam':
45 | return (optim.Adam(parameters, lr=lr, betas=(beta1, beta2),
46 | weight_decay=weight_decay),
47 | False)
48 | elif optim_type == 'lbfgs':
49 | return (optim.LBFGS(parameters, lr=lr, max_iter=maxiters), False)
50 | elif optim_type == 'lbfgsls':
51 | return LBFGSLs(parameters, lr=lr, max_iter=maxiters,
52 | line_search_fn='strong_Wolfe'), False
53 | elif optim_type == 'rmsprop':
54 | return (optim.RMSprop(parameters, lr=lr, epsilon=epsilon,
55 | alpha=rmsprop_alpha,
56 | weight_decay=weight_decay,
57 | momentum=momentum, centered=centered),
58 | False)
59 | elif optim_type == 'sgd':
60 | return (optim.SGD(parameters, lr=lr, momentum=momentum,
61 | weight_decay=weight_decay,
62 | nesterov=use_nesterov),
63 | False)
64 | else:
65 | raise ValueError('Optimizer {} not supported!'.format(optim_type))
--------------------------------------------------------------------------------
/common/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import random
4 | from config import cfg
5 | import math
6 | import torchvision
7 |
8 | def load_img(path, order='RGB'):
9 | img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
10 | if not isinstance(img, np.ndarray):
11 | raise IOError("Fail to read %s" % path)
12 |
13 | if order=='RGB':
14 | img = img[:,:,::-1].copy()
15 |
16 | img = img.astype(np.float32)
17 | return img
18 |
19 | def get_bbox(joint_img, joint_valid, expansion_factor=1.0):
20 |
21 | x_img, y_img = joint_img[:,0], joint_img[:,1]
22 | x_img = x_img[joint_valid==1]; y_img = y_img[joint_valid==1];
23 | xmin = min(x_img); ymin = min(y_img); xmax = max(x_img); ymax = max(y_img);
24 |
25 | x_center = (xmin+xmax)/2.; width = (xmax-xmin)*expansion_factor;
26 | xmin = x_center - 0.5*width
27 | xmax = x_center + 0.5*width
28 |
29 | y_center = (ymin+ymax)/2.; height = (ymax-ymin)*expansion_factor;
30 | ymin = y_center - 0.5*height
31 | ymax = y_center + 0.5*height
32 |
33 | bbox = np.array([xmin, ymin, xmax - xmin, ymax - ymin]).astype(np.float32)
34 | return bbox
35 |
36 | def process_bbox(bbox, img_width, img_height, expansion_factor=1.25):
37 | # sanitize bboxes
38 | x, y, w, h = bbox
39 | x1 = np.max((0, x))
40 | y1 = np.max((0, y))
41 | x2 = np.min((img_width - 1, x1 + np.max((0, w - 1))))
42 | y2 = np.min((img_height - 1, y1 + np.max((0, h - 1))))
43 | if w*h > 0 and x2 >= x1 and y2 >= y1:
44 | bbox = np.array([x1, y1, x2-x1, y2-y1])
45 | else:
46 | return None
47 |
48 | # aspect ratio preserving bbox
49 | w = bbox[2]
50 | h = bbox[3]
51 | c_x = bbox[0] + w/2.
52 | c_y = bbox[1] + h/2.
53 | aspect_ratio = cfg.input_img_shape[1]/cfg.input_img_shape[0]
54 | if w > aspect_ratio * h:
55 | h = w / aspect_ratio
56 | elif w < aspect_ratio * h:
57 | w = h * aspect_ratio
58 | bbox[2] = w*expansion_factor
59 | bbox[3] = h*expansion_factor
60 | bbox[0] = c_x - bbox[2]/2.
61 | bbox[1] = c_y - bbox[3]/2.
62 |
63 | return bbox
64 |
65 | def get_aug_config():
66 | scale_factor = 0.25
67 | rot_factor = 30
68 | color_factor = 0.2
69 |
70 | scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0
71 | rot = np.clip(np.random.randn(), -2.0,
72 | 2.0) * rot_factor if random.random() <= 0.6 else 0
73 | c_up = 1.0 + color_factor
74 | c_low = 1.0 - color_factor
75 | color_scale = np.array([random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)])
76 |
77 | return scale, rot, color_scale
78 |
79 | def augmentation(img, bbox, data_split, do_flip=False):
80 | if data_split == 'train':
81 | scale, rot, color_scale = get_aug_config()
82 | else:
83 | scale, rot, color_scale = 1.0, 0.0, np.array([1,1,1])
84 | img, trans, inv_trans = generate_patch_image(img, bbox, scale, rot, do_flip, cfg.input_img_shape)
85 |
86 | img = np.clip(img * color_scale[None,None,:], 0, 255)
87 | return img, trans, inv_trans, rot, scale
88 |
89 | def generate_patch_image(cvimg, bbox, scale, rot, do_flip, out_shape):
90 | img = cvimg.copy()
91 | img_height, img_width, img_channels = img.shape
92 |
93 | bb_c_x = float(bbox[0] + 0.5*bbox[2])
94 | bb_c_y = float(bbox[1] + 0.5*bbox[3])
95 | bb_width = float(bbox[2])
96 | bb_height = float(bbox[3])
97 |
98 | if do_flip:
99 | img = img[:, ::-1, :]
100 | bb_c_x = img_width - bb_c_x - 1
101 |
102 | trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot)
103 | img_patch = cv2.warpAffine(img, trans, (int(out_shape[1]), int(out_shape[0])), flags=cv2.INTER_LINEAR)
104 | img_patch = img_patch.astype(np.float32)
105 | inv_trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot, inv=True)
106 |
107 | return img_patch, trans, inv_trans
108 |
109 | def rotate_2d(pt_2d, rot_rad):
110 | x = pt_2d[0]
111 | y = pt_2d[1]
112 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
113 | xx = x * cs - y * sn
114 | yy = x * sn + y * cs
115 | return np.array([xx, yy], dtype=np.float32)
116 |
117 | def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
118 | # augment size with scale
119 | src_w = src_width * scale
120 | src_h = src_height * scale
121 | src_center = np.array([c_x, c_y], dtype=np.float32)
122 |
123 | # augment rotation
124 | rot_rad = np.pi * rot / 180
125 | src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
126 | src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
127 |
128 | dst_w = dst_width
129 | dst_h = dst_height
130 | dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
131 | dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
132 | dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
133 |
134 | src = np.zeros((3, 2), dtype=np.float32)
135 | src[0, :] = src_center
136 | src[1, :] = src_center + src_downdir
137 | src[2, :] = src_center + src_rightdir
138 |
139 | dst = np.zeros((3, 2), dtype=np.float32)
140 | dst[0, :] = dst_center
141 | dst[1, :] = dst_center + dst_downdir
142 | dst[2, :] = dst_center + dst_rightdir
143 |
144 | if inv:
145 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
146 | else:
147 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
148 |
149 | trans = trans.astype(np.float32)
150 | return trans
151 |
--------------------------------------------------------------------------------
/common/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from config import cfg
4 |
5 | def cam2pixel(cam_coord, f, c):
6 | x = cam_coord[:,0] / cam_coord[:,2] * f[0] + c[0]
7 | y = cam_coord[:,1] / cam_coord[:,2] * f[1] + c[1]
8 | z = cam_coord[:,2]
9 | return np.stack((x,y,z),1)
10 |
11 | def pixel2cam(pixel_coord, f, c):
12 | x = (pixel_coord[:,0] - c[0]) / f[0] * pixel_coord[:,2]
13 | y = (pixel_coord[:,1] - c[1]) / f[1] * pixel_coord[:,2]
14 | z = pixel_coord[:,2]
15 | return np.stack((x,y,z),1)
16 |
17 | def world2cam(world_coord, R, t):
18 | cam_coord = np.dot(R, world_coord.transpose(1,0)).transpose(1,0) + t.reshape(1,3)
19 | return cam_coord
20 |
21 | def cam2world(cam_coord, R, t):
22 | world_coord = np.dot(np.linalg.inv(R), (cam_coord - t.reshape(1,3)).transpose(1,0)).transpose(1,0)
23 | return world_coord
24 |
25 | def rigid_transform_3D(A, B):
26 | n, dim = A.shape
27 | centroid_A = np.mean(A, axis = 0)
28 | centroid_B = np.mean(B, axis = 0)
29 | H = np.dot(np.transpose(A - centroid_A), B - centroid_B) / n
30 | U, s, V = np.linalg.svd(H)
31 | R = np.dot(np.transpose(V), np.transpose(U))
32 | if np.linalg.det(R) < 0:
33 | s[-1] = -s[-1]
34 | V[2] = -V[2]
35 | R = np.dot(np.transpose(V), np.transpose(U))
36 |
37 | varP = np.var(A, axis=0).sum()
38 | c = 1/varP * np.sum(s)
39 |
40 | t = -np.dot(c*R, np.transpose(centroid_A)) + np.transpose(centroid_B)
41 | return c, R, t
42 |
43 | def rigid_align(A, B):
44 | c, R, t = rigid_transform_3D(A, B)
45 | A2 = np.transpose(np.dot(c*R, np.transpose(A))) + t
46 | return A2
47 |
48 | def transform_joint_to_other_db(src_joint, src_name, dst_name):
49 | src_joint_num = len(src_name)
50 | dst_joint_num = len(dst_name)
51 |
52 | new_joint = np.zeros(((dst_joint_num,) + src_joint.shape[1:]), dtype=np.float32)
53 | for src_idx in range(len(src_name)):
54 | name = src_name[src_idx]
55 | if name in dst_name:
56 | dst_idx = dst_name.index(name)
57 | new_joint[dst_idx] = src_joint[src_idx]
58 |
59 | return new_joint
--------------------------------------------------------------------------------
/common/utils/vis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from mpl_toolkits.mplot3d import Axes3D
5 | import matplotlib.pyplot as plt
6 | import matplotlib as mpl
7 | os.environ["PYOPENGL_PLATFORM"] = "egl"
8 | import pyrender
9 | import trimesh
10 |
11 | def vis_keypoints_with_skeleton(img, kps, kps_lines, kp_thresh=0.4, alpha=1):
12 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
13 | cmap = plt.get_cmap('rainbow')
14 | colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
15 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
16 |
17 | # Perform the drawing on a copy of the image, to allow for blending.
18 | kp_mask = np.copy(img)
19 |
20 | # Draw the keypoints.
21 | for l in range(len(kps_lines)):
22 | i1 = kps_lines[l][0]
23 | i2 = kps_lines[l][1]
24 | p1 = kps[0, i1].astype(np.int32), kps[1, i1].astype(np.int32)
25 | p2 = kps[0, i2].astype(np.int32), kps[1, i2].astype(np.int32)
26 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
27 | cv2.line(
28 | kp_mask, p1, p2,
29 | color=colors[l], thickness=2, lineType=cv2.LINE_AA)
30 | if kps[2, i1] > kp_thresh:
31 | cv2.circle(
32 | kp_mask, p1,
33 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
34 | if kps[2, i2] > kp_thresh:
35 | cv2.circle(
36 | kp_mask, p2,
37 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)
38 |
39 | # Blend the keypoints.
40 | return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
41 |
42 | def vis_keypoints(img, kps, alpha=1):
43 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
44 | cmap = plt.get_cmap('rainbow')
45 | colors = [cmap(i) for i in np.linspace(0, 1, len(kps) + 2)]
46 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
47 |
48 | # Perform the drawing on a copy of the image, to allow for blending.
49 | kp_mask = np.copy(img)
50 |
51 | # Draw the keypoints.
52 | for i in range(len(kps)):
53 | p = kps[i][0].astype(np.int32), kps[i][1].astype(np.int32)
54 | cv2.circle(kp_mask, p, radius=3, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
55 |
56 | # Blend the keypoints.
57 | return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)
58 |
59 | def vis_mesh(img, mesh_vertex, alpha=0.5):
60 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
61 | cmap = plt.get_cmap('rainbow')
62 | colors = [cmap(i) for i in np.linspace(0, 1, len(mesh_vertex))]
63 | colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]
64 |
65 | # Perform the drawing on a copy of the image, to allow for blending.
66 | mask = np.copy(img)
67 |
68 | # Draw the mesh
69 | for i in range(len(mesh_vertex)):
70 | p = mesh_vertex[i][0].astype(np.int32), mesh_vertex[i][1].astype(np.int32)
71 | cv2.circle(mask, p, radius=1, color=colors[i], thickness=-1, lineType=cv2.LINE_AA)
72 |
73 | # Blend the keypoints.
74 | return cv2.addWeighted(img, 1.0 - alpha, mask, alpha, 0)
75 |
76 | def vis_3d_skeleton(kpt_3d, kpt_3d_vis, kps_lines, filename=None):
77 |
78 | fig = plt.figure()
79 | ax = fig.add_subplot(111, projection='3d')
80 |
81 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
82 | cmap = plt.get_cmap('rainbow')
83 | colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
84 | colors = [np.array((c[2], c[1], c[0])) for c in colors]
85 |
86 | for l in range(len(kps_lines)):
87 | i1 = kps_lines[l][0]
88 | i2 = kps_lines[l][1]
89 | x = np.array([kpt_3d[i1,0], kpt_3d[i2,0]])
90 | y = np.array([kpt_3d[i1,1], kpt_3d[i2,1]])
91 | z = np.array([kpt_3d[i1,2], kpt_3d[i2,2]])
92 |
93 | if kpt_3d_vis[i1,0] > 0 and kpt_3d_vis[i2,0] > 0:
94 | ax.plot(x, z, -y, c=colors[l], linewidth=2)
95 | if kpt_3d_vis[i1,0] > 0:
96 | ax.scatter(kpt_3d[i1,0], kpt_3d[i1,2], -kpt_3d[i1,1], c=colors[l], marker='o')
97 | if kpt_3d_vis[i2,0] > 0:
98 | ax.scatter(kpt_3d[i2,0], kpt_3d[i2,2], -kpt_3d[i2,1], c=colors[l], marker='o')
99 |
100 | if filename is None:
101 | ax.set_title('3D vis')
102 | else:
103 | ax.set_title(filename)
104 |
105 | ax.set_xlabel('X Label')
106 | ax.set_ylabel('Z Label')
107 | ax.set_zlabel('Y Label')
108 | ax.legend()
109 |
110 | #plt.show()
111 | #cv2.waitKey(0)
112 |
113 | plt.savefig(filename)
114 |
115 | def save_obj(v, f, file_name='output.obj'):
116 | obj_file = open(file_name, 'w')
117 | for i in range(len(v)):
118 | obj_file.write('v ' + str(v[i][0]) + ' ' + str(v[i][1]) + ' ' + str(v[i][2]) + '\n')
119 | for i in range(len(f)):
120 | obj_file.write('f ' + str(f[i][0]+1) + '/' + str(f[i][0]+1) + ' ' + str(f[i][1]+1) + '/' + str(f[i][1]+1) + ' ' + str(f[i][2]+1) + '/' + str(f[i][2]+1) + '\n')
121 | obj_file.close()
122 |
123 | def render_mesh(img, mesh, face, cam_param):
124 | # mesh
125 | mesh = trimesh.Trimesh(mesh, face)
126 | rot = trimesh.transformations.rotation_matrix(
127 | np.radians(180), [1, 0, 0])
128 | mesh.apply_transform(rot)
129 | material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0))
130 | mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
131 | scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
132 | scene.add(mesh, 'mesh')
133 |
134 | focal, princpt = cam_param['focal'], cam_param['princpt']
135 | camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
136 | scene.add(camera)
137 |
138 | # renderer
139 | renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)
140 |
141 | # light
142 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
143 | light_pose = np.eye(4)
144 | light_pose[:3, 3] = np.array([0, -1, 1])
145 | scene.add(light, pose=light_pose)
146 | light_pose[:3, 3] = np.array([0, 1, 1])
147 | scene.add(light, pose=light_pose)
148 | light_pose[:3, 3] = np.array([1, 1, 2])
149 | scene.add(light, pose=light_pose)
150 |
151 | # render
152 | rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
153 | rgb = rgb[:,:,:3].astype(np.float32)
154 | valid_mask = (depth > 0)[:,:,None]
155 |
156 | # save to image
157 | img = rgb * valid_mask*0.5 + img #* (1-valid_mask)
158 | return img
--------------------------------------------------------------------------------
/data/DEX_YCB/DEX_YCB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import cv2
6 | import random
7 | import json
8 | import math
9 | import copy
10 | from pycocotools.coco import COCO
11 | from config import cfg
12 | from utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation
13 | from utils.transforms import world2cam, cam2pixel, pixel2cam, rigid_align, transform_joint_to_other_db
14 | from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton, render_mesh, vis_3d_skeleton
15 | from utils.mano import MANO
16 | mano = MANO()
17 |
18 | # # TEMP; test set
19 | # target_img_list = {
20 | # 1: ['20200820-subject-03/20200820_135508/836212060125/color_000030.jpg', '20200820-subject-03/20200820_135508/836212060125/color_000060.jpg', '20200903-subject-04/20200903_103828/836212060125/color_000030.jpg', '20200903-subject-04/20200903_103828/836212060125/color_000060.jpg', '20200908-subject-05/20200908_143535/836212060125/color_000060.jpg', '20200908-subject-05/20200908_143535/932122060857/color_000030.jpg', '20200918-subject-06/20200918_113137/836212060125/color_000030.jpg', '20200918-subject-06/20200918_113137/836212060125/color_000060.jpg',
21 | # '20200928-subject-07/20200928_143500/836212060125/color_000060.jpg', '20200928-subject-07/20200928_143500/932122060857/color_000030.jpg', '20201002-subject-08/20201002_104827/836212060125/color_000060.jpg', '20201002-subject-08/20201002_104827/932122060861/color_000030.jpg', '20201015-subject-09/20201015_142844/836212060125/color_000030.jpg', '20201015-subject-09/20201015_142844/836212060125/color_000060.jpg', '20201015-subject-09/20201015_142844/841412060263/color_000000.jpg', '20201022-subject-10/20201022_110947/840412060917/color_000060.jpg', '20201022-subject-10/20201022_110947/932122060857/color_000030.jpg'],
22 | # 2: ['20200820-subject-03/20200820_135810/836212060125/color_000060.jpg', '20200820-subject-03/20200820_135810/839512060362/color_000030.jpg', '20200903-subject-04/20200903_104115/836212060125/color_000030.jpg', '20200903-subject-04/20200903_104115/839512060362/color_000060.jpg', '20200908-subject-05/20200908_143832/836212060125/color_000030.jpg', '20200908-subject-05/20200908_143832/836212060125/color_000060.jpg', '20200918-subject-06/20200918_113405/839512060362/color_000060.jpg', '20200918-subject-06/20200918_113405/840412060917/color_000030.jpg', '20200928-subject-07/20200928_143727/839512060362/color_000060.jpg', '20200928-subject-07/20200928_143727/840412060917/color_000030.jpg', '20201002-subject-08/20201002_105058/836212060125/color_000060.jpg', '20201002-subject-08/20201002_105058/840412060917/color_000030.jpg', '20201015-subject-09/20201015_143113/836212060125/color_000030.jpg', '20201015-subject-09/20201015_143113/836212060125/color_000060.jpg', '20201015-subject-09/20201015_143113/840412060917/color_000000.jpg', '20201022-subject-10/20201022_111144/840412060917/color_000030.jpg', '20201022-subject-10/20201022_111144/840412060917/color_000060.jpg'],
23 | # 10: ['20200820-subject-03/20200820_142158/932122060861/color_000030.jpg', '20200820-subject-03/20200820_142158/932122060861/color_000060.jpg', '20200903-subject-04/20200903_110342/836212060125/color_000060.jpg', '20200908-subject-05/20200908_145938/836212060125/color_000060.jpg', '20200908-subject-05/20200908_145938/839512060362/color_000030.jpg', '20200918-subject-06/20200918_115139/839512060362/color_000060.jpg', '20200918-subject-06/20200918_115139/840412060917/color_000030.jpg', '20200928-subject-07/20200928_153732/836212060125/color_000030.jpg', '20200928-subject-07/20200928_153732/932122060857/color_000060.jpg', '20201002-subject-08/20201002_110854/836212060125/color_000060.jpg', '20201015-subject-09/20201015_145212/836212060125/color_000030.jpg', '20201015-subject-09/20201015_145212/839512060362/color_000060.jpg'],
24 | # 15: ['20200820-subject-03/20200820_143802/836212060125/color_000060.jpg', '20200820-subject-03/20200820_143802/840412060917/color_000030.jpg', '20200903-subject-04/20200903_112724/836212060125/color_000060.jpg', '20200903-subject-04/20200903_112724/841412060263/color_000030.jpg', '20200908-subject-05/20200908_151328/836212060125/color_000060.jpg', '20200908-subject-05/20200908_151328/840412060917/color_000030.jpg', '20200918-subject-06/20200918_120310/836212060125/color_000030.jpg', '20200918-subject-06/20200918_120310/836212060125/color_000060.jpg', '20200928-subject-07/20200928_154943/836212060125/color_000030.jpg', '20200928-subject-07/20200928_154943/836212060125/color_000060.jpg', '20201002-subject-08/20201002_112045/836212060125/color_000030.jpg', '20201002-subject-08/20201002_112045/836212060125/color_000060.jpg', '20201015-subject-09/20201015_150413/836212060125/color_000030.jpg', '20201015-subject-09/20201015_150413/836212060125/color_000060.jpg', '20201022-subject-10/20201022_113909/836212060125/color_000060.jpg']
25 | # }
26 | # # TEMP; val set
27 | # # target_img_list = {
28 | # # 1: ['20200709-subject-01/20200709_142123/836212060125/color_000030.jpg', '20200709-subject-01/20200709_142123/836212060125/color_000060.jpg', '20200813-subject-02/20200813_145612/836212060125/color_000030.jpg', '20200813-subject-02/20200813_145612/836212060125/color_000060.jpg'],
29 | # # 2: ['20200709-subject-01/20200709_142446/840412060917/color_000030.jpg', '20200709-subject-01/20200709_142446/840412060917/color_000060.jpg', '20200813-subject-02/20200813_145920/836212060125/color_000030.jpg', '20200813-subject-02/20200813_145920/836212060125/color_000060.jpg'],
30 | # # 10: ['20200709-subject-01/20200709_145743/839512060362/color_000060.jpg', '20200709-subject-01/20200709_145743/932122061900/color_000030.jpg', '20200813-subject-02/20200813_152842/836212060125/color_000060.jpg', '20200813-subject-02/20200813_152842/841412060263/color_000030.jpg'],
31 | # # 15: ['20200709-subject-01/20200709_151632/836212060125/color_000060.jpg', '20200709-subject-01/20200709_151632/840412060917/color_000030.jpg', '20200813-subject-02/20200813_154408/836212060125/color_000030.jpg', '20200813-subject-02/20200813_154408/836212060125/color_000060.jpg'],
32 |
33 | # # }
34 |
35 | # target_img_list_sum = []
36 | # for key, val in target_img_list.items():
37 | # target_img_list_sum.extend(val)
38 |
39 | #with open('/home/hongsuk.c/Projects/HandOccNet/main/novel_object_test_list.json', 'r') as f:
40 | # target_img_list_sum = json.load(f)
41 | #print("TARGET LENGTH: ", len(target_img_list_sum))
42 |
43 | class DEX_YCB(torch.utils.data.Dataset):
44 | def __init__(self, transform, data_split):
45 | self.transform = transform
46 | self.data_split = data_split if data_split == 'train' else 'val'
47 | self.root_dir = osp.join('..', 'data', 'DEX_YCB', 'data')
48 | self.annot_path = osp.join(self.root_dir, 'annotations')
49 | self.root_joint_idx = 0
50 |
51 | self.datalist = self.load_data()
52 | if self.data_split != 'train':
53 | self.eval_result = [[],[]] #[mpjpe_list, pa-mpjpe_list]
54 | print("TEST DATA LEN: ", len(self.datalist))
55 | def load_data(self):
56 | db = COCO(osp.join(self.annot_path, "DEX_YCB_s0_{}_data.json".format(self.data_split)))
57 |
58 | datalist = []
59 | for aid in db.anns.keys():
60 | ann = db.anns[aid]
61 | image_id = ann['image_id']
62 | img = db.loadImgs(image_id)[0]
63 | img_path = osp.join(self.root_dir, img['file_name'])
64 | img_shape = (img['height'], img['width'])
65 | if self.data_split == 'train':
66 | joints_coord_cam = np.array(ann['joints_coord_cam'], dtype=np.float32) # meter
67 | cam_param = {k:np.array(v, dtype=np.float32) for k,v in ann['cam_param'].items()}
68 | joints_coord_img = np.array(ann['joints_img'], dtype=np.float32)
69 | hand_type = ann['hand_type']
70 |
71 | bbox = get_bbox(joints_coord_img[:,:2], np.ones_like(joints_coord_img[:,0]), expansion_factor=1.5)
72 | bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.0)
73 |
74 | if bbox is None:
75 | continue
76 |
77 | mano_pose = np.array(ann['mano_param']['pose'], dtype=np.float32)
78 | mano_shape = np.array(ann['mano_param']['shape'], dtype=np.float32)
79 |
80 | data = {"img_path": img_path, "img_shape": img_shape, "joints_coord_cam": joints_coord_cam, "joints_coord_img": joints_coord_img,
81 | "bbox": bbox, "cam_param": cam_param, "mano_pose": mano_pose, "mano_shape": mano_shape, "hand_type": hand_type}
82 | else:
83 | # if '/'.join(img_path.split('/')[-4:]) not in target_img_list_sum:
84 | # continue
85 |
86 |
87 |
88 | joints_coord_cam = np.array(ann['joints_coord_cam'], dtype=np.float32)
89 | root_joint_cam = copy.deepcopy(joints_coord_cam[0])
90 | joints_coord_img = np.array(ann['joints_img'], dtype=np.float32)
91 | hand_type = ann['hand_type']
92 |
93 | if False and hand_type == 'left':
94 |
95 | # mano_pose = np.array(ann['mano_param']['pose'], dtype=np.float32)
96 | # mano_shape = np.array(ann['mano_param']['shape'], dtype=np.float32)
97 |
98 | # vertices, joints, manojoints2cam = mano.left_layer(torch.from_numpy(mano_pose)[None, :], torch.from_numpy(mano_shape)[None, :])
99 | # vertices = vertices[0].numpy()
100 | # # save_obj(vertices, mano.left_layer.th_faces.numpy(), 'org_left.obj')
101 | # joints = joints[0].numpy()
102 | # joints /= 1000
103 | # joints = joints - joints[0:1] + root_joint_cam[None, :]
104 | # focal, princpt = ann['cam_param']['focal'], ann['cam_param']['princpt']
105 | # proj_joints = cam2pixel(joints, focal, princpt)
106 | # img = cv2.imread(img_path)
107 | # vis_img = vis_keypoints(img, proj_joints)
108 | # cv2.imshow('check cam', vis_img)
109 | # cv2.waitKey(0)
110 | # import pdb; pdb.set_trace()
111 |
112 | # mano_pose = mano_pose.reshape(-1,3)
113 | # mano_pose[:,1:] *= -1
114 | # mano_pose = mano_pose.reshape(-1)
115 | # vertices, joints, _ = mano.layer(torch.from_numpy(mano_pose)[None, :], torch.from_numpy(mano_shape)[None, :])
116 | # joints = joints[0].numpy()
117 | # joints /= 1000
118 | # joints = joints - joints[0:1]
119 | # joints[:, 0] *= -1
120 | # joints = joints + root_joint_cam[None, :]
121 |
122 | # focal, princpt = ann['cam_param']['focal'], ann['cam_param']['princpt']
123 | # proj_joints = cam2pixel(joints, focal, princpt)
124 | # img = cv2.imread(img_path)
125 | # vis_img = vis_keypoints(img, proj_joints)
126 | # cv2.imshow('check flip', vis_img)
127 | # cv2.waitKey(0)
128 | # import pdb; pdb.set_trace()
129 |
130 | import pdb; pdb.set_trace()
131 |
132 |
133 |
134 | bbox = get_bbox(joints_coord_img[:,:2], np.ones_like(joints_coord_img[:,0]), expansion_factor=1.5)
135 | bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.0)
136 | if bbox is None:
137 | bbox = np.array([0,0,img['width']-1, img['height']-1], dtype=np.float32)
138 |
139 | cam_param = {k:np.array(v, dtype=np.float32) for k,v in ann['cam_param'].items()}
140 |
141 |
142 | data = {"img_path": img_path, "img_shape": img_shape, "joints_coord_cam": joints_coord_cam, "root_joint_cam": root_joint_cam,
143 | "bbox": bbox, "cam_param": cam_param, "image_id": image_id, 'hand_type': hand_type}
144 |
145 | datalist.append(data)
146 | return datalist
147 |
148 | def __len__(self):
149 | return len(self.datalist)
150 |
151 | def __getitem__(self, idx):
152 | data = copy.deepcopy(self.datalist[idx])
153 | img_path, img_shape, bbox = data['img_path'], data['img_shape'], data['bbox']
154 | hand_type = data['hand_type']
155 | do_flip = (hand_type == 'left')
156 |
157 | # img
158 | img = load_img(img_path)
159 | orig_img = copy.deepcopy(img)[:,:,::-1]
160 | img, img2bb_trans, bb2img_trans, rot, scale = augmentation(img, bbox, self.data_split, do_flip=do_flip)
161 | img = self.transform(img.astype(np.float32))/255.
162 |
163 | if self.data_split == 'train':
164 | ## 2D joint coordinate
165 | joints_img = data['joints_coord_img']
166 | if do_flip:
167 | joints_img[:,0] = img_shape[1] - joints_img[:,0] - 1
168 | joints_img_xy1 = np.concatenate((joints_img[:,:2], np.ones_like(joints_img[:,:1])),1)
169 | joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]
170 | # normalize to [0,1]
171 | joints_img[:,0] /= cfg.input_img_shape[1]
172 | joints_img[:,1] /= cfg.input_img_shape[0]
173 |
174 | ## 3D joint camera coordinate
175 | joints_coord_cam = data['joints_coord_cam']
176 | root_joint_cam = copy.deepcopy(joints_coord_cam[self.root_joint_idx])
177 | joints_coord_cam -= joints_coord_cam[self.root_joint_idx,None,:] # root-relative
178 | if do_flip:
179 | joints_coord_cam[:,0] *= -1
180 |
181 | # 3D data rotation augmentation
182 | rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
183 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
184 | [0, 0, 1]], dtype=np.float32)
185 | joints_coord_cam = np.dot(rot_aug_mat, joints_coord_cam.transpose(1,0)).transpose(1,0)
186 |
187 | ## mano parameter
188 | mano_pose, mano_shape = data['mano_pose'], data['mano_shape']
189 |
190 | # 3D data rotation augmentation
191 | mano_pose = mano_pose.reshape(-1,3)
192 | if do_flip:
193 | mano_pose[:,1:] *= -1
194 | root_pose = mano_pose[self.root_joint_idx,:]
195 | root_pose, _ = cv2.Rodrigues(root_pose)
196 | root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat,root_pose))
197 | mano_pose[self.root_joint_idx] = root_pose.reshape(3)
198 | mano_pose = mano_pose.reshape(-1)
199 |
200 | inputs = {'img': img}
201 | targets = {'joints_img': joints_img, 'joints_coord_cam': joints_coord_cam, 'mano_pose': mano_pose, 'mano_shape': mano_shape}
202 | meta_info = {'root_joint_cam': root_joint_cam}
203 |
204 | else:
205 | root_joint_cam = data['root_joint_cam']
206 | inputs = {'img': img}
207 | targets = {}
208 | meta_info = {'root_joint_cam': root_joint_cam, 'img_path': img_path}
209 |
210 | return inputs, targets, meta_info
211 |
212 | def evaluate(self, outs, cur_sample_idx):
213 | annots = self.datalist
214 | sample_num = len(outs)
215 | for n in range(sample_num):
216 | annot = annots[cur_sample_idx + n]
217 |
218 | out = outs[n]
219 |
220 | joints_out = out['joints_coord_cam']
221 |
222 | # root centered
223 | joints_out -= joints_out[self.root_joint_idx]
224 |
225 | # flip back to left hand
226 | if annot['hand_type'] == 'left':
227 | joints_out[:,0] *= -1
228 |
229 | # root align
230 | gt_root_joint_cam = annot['root_joint_cam']
231 | joints_out += gt_root_joint_cam
232 |
233 | # GT and rigid align
234 | joints_gt = annot['joints_coord_cam']
235 | joints_out_aligned = rigid_align(joints_out, joints_gt)
236 |
237 | # m to mm
238 | joints_out *= 1000
239 | joints_out_aligned *= 1000
240 | joints_gt *= 1000
241 |
242 | self.eval_result[0].append(np.sqrt(np.sum((joints_out - joints_gt)**2,1)).mean())
243 | self.eval_result[1].append(np.sqrt(np.sum((joints_out_aligned - joints_gt)**2,1)).mean())
244 |
245 | def print_eval_result(self, test_epoch):
246 | print('MPJPE : %.2f mm' % np.mean(self.eval_result[0]))
247 | print('PA MPJPE : %.2f mm' % np.mean(self.eval_result[1]))
248 |
249 | """
250 | def evaluate(self, outs, cur_sample_idx):
251 | annots = self.datalist
252 | sample_num = len(outs)
253 | for n in range(sample_num):
254 | annot = annots[cur_sample_idx + n]
255 |
256 | out = outs[n]
257 |
258 | verts_out = out['mesh_coord_cam']
259 | joints_out = out['joints_coord_cam']
260 |
261 | # root centered
262 | verts_out -= joints_out[self.root_joint_idx]
263 | joints_out -= joints_out[self.root_joint_idx]
264 |
265 | # flip back to left hand
266 | if annot['hand_type'] == 'left':
267 | verts_out[:,0] *= -1
268 | joints_out[:,0] *= -1
269 |
270 | # root align
271 | gt_root_joint_cam = annot['root_joint_cam']
272 | verts_out += gt_root_joint_cam
273 | joints_out += gt_root_joint_cam
274 |
275 | # m to mm
276 | verts_out *= 1000
277 | joints_out *= 1000
278 |
279 | self.eval_result[0].append(joints_out)
280 | self.eval_result[1].append(verts_out)
281 |
282 | def print_eval_result(self, test_epoch):
283 | output_file_path = osp.join(cfg.result_dir, "DEX_RESULTS_EPOCH{}.txt".format(test_epoch))
284 |
285 | with open(output_file_path, 'w') as output_file:
286 | for i, pred_joints in enumerate(self.eval_result[0]):
287 | image_id = self.datalist[i]['image_id']
288 | output_file.write(str(image_id) + ',' + ','.join(pred_joints.ravel().astype(str).tolist()) + '\n')
289 | """
290 |
--------------------------------------------------------------------------------
/data/HO3D/HO3D.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import torch
5 | import cv2
6 | import random
7 | import json
8 | import math
9 | import copy
10 | from pycocotools.coco import COCO
11 | from config import cfg
12 | from utils.preprocessing import load_img, get_bbox, process_bbox, generate_patch_image, augmentation
13 | from utils.transforms import world2cam, cam2pixel, pixel2cam, rigid_align, transform_joint_to_other_db
14 | from utils.vis import vis_keypoints, vis_mesh, save_obj, vis_keypoints_with_skeleton
15 | from utils.mano import MANO
16 | mano = MANO()
17 |
18 | class HO3D(torch.utils.data.Dataset):
19 | def __init__(self, transform, data_split):
20 | self.transform = transform
21 | self.data_split = data_split if data_split == 'train' else 'evaluation'
22 | self.root_dir = osp.join('..', 'data', 'HO3D', 'data')
23 | self.annot_path = osp.join(self.root_dir, 'annotations')
24 | self.root_joint_idx = 0
25 |
26 | self.datalist = self.load_data()
27 | if self.data_split != 'train':
28 | self.eval_result = [[],[]] #[pred_joints_list, pred_verts_list]
29 | self.joints_name = ('Wrist', 'Index_1', 'Index_2', 'Index_3', 'Middle_1', 'Middle_2', 'Middle_3', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Ring_1', 'Ring_2', 'Ring_3', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_4', 'Middle_4', 'Ring_4', 'Pinly_4')
30 |
31 | def load_data(self):
32 | db = COCO(osp.join(self.annot_path, "HO3D_{}_data.json".format(self.data_split)))
33 | # db = COCO(osp.join(self.annot_path, 'HO3Dv3_partial_test_multiseq_coco.json'))
34 |
35 | datalist = []
36 | for aid in db.anns.keys():
37 | ann = db.anns[aid]
38 | image_id = ann['image_id']
39 | img = db.loadImgs(image_id)[0]
40 | img_path = osp.join(self.root_dir, self.data_split, img['file_name'])
41 | # TEMP
42 | # img_path = osp.join(self.root_dir, 'train', img['sequence_name'], 'rgb', img['file_name'])
43 |
44 | img_shape = (img['height'], img['width'])
45 | if self.data_split == 'train':
46 | joints_coord_cam = np.array(ann['joints_coord_cam'], dtype=np.float32) # meter
47 | cam_param = {k:np.array(v, dtype=np.float32) for k,v in ann['cam_param'].items()}
48 | joints_coord_img = cam2pixel(joints_coord_cam, cam_param['focal'], cam_param['princpt'])
49 | bbox = get_bbox(joints_coord_img[:,:2], np.ones_like(joints_coord_img[:,0]), expansion_factor=1.5)
50 | bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.0)
51 | if bbox is None:
52 | continue
53 |
54 | mano_pose = np.array(ann['mano_param']['pose'], dtype=np.float32)
55 | mano_shape = np.array(ann['mano_param']['shape'], dtype=np.float32)
56 |
57 | data = {"img_path": img_path, "img_shape": img_shape, "joints_coord_cam": joints_coord_cam, "joints_coord_img": joints_coord_img,
58 | "bbox": bbox, "cam_param": cam_param, "mano_pose": mano_pose, "mano_shape": mano_shape}
59 | else:
60 | root_joint_cam = np.array(ann['root_joint_cam'], dtype=np.float32)
61 | cam_param = {k:np.array(v, dtype=np.float32) for k,v in ann['cam_param'].items()}
62 | # TEMP
63 | # root_joint_cam = np.zeros(0)
64 | # cam_param = np.zeros(0)
65 | bbox = np.array(ann['bbox'], dtype=np.float32)
66 | bbox = process_bbox(bbox, img['width'], img['height'], expansion_factor=1.5)
67 |
68 | data = {"img_path": img_path, "img_shape": img_shape, "root_joint_cam": root_joint_cam,
69 | "bbox": bbox, "cam_param": cam_param}
70 |
71 | datalist.append(data)
72 |
73 | return datalist
74 |
75 | def __len__(self):
76 | return len(self.datalist)
77 |
78 | def __getitem__(self, idx):
79 | data = copy.deepcopy(self.datalist[idx])
80 | img_path, img_shape, bbox = data['img_path'], data['img_shape'], data['bbox']
81 |
82 | # img
83 | img = load_img(img_path)
84 | img, img2bb_trans, bb2img_trans, rot, scale = augmentation(img, bbox, self.data_split, do_flip=False)
85 | img = self.transform(img.astype(np.float32))/255.
86 |
87 | if self.data_split == 'train':
88 | ## 2D joint coordinate
89 | joints_img = data['joints_coord_img']
90 | joints_img_xy1 = np.concatenate((joints_img[:,:2], np.ones_like(joints_img[:,:1])),1)
91 | joints_img = np.dot(img2bb_trans, joints_img_xy1.transpose(1,0)).transpose(1,0)[:,:2]
92 | # normalize to [0,1]
93 | joints_img[:,0] /= cfg.input_img_shape[1]
94 | joints_img[:,1] /= cfg.input_img_shape[0]
95 |
96 | ## 3D joint camera coordinate
97 | joints_coord_cam = data['joints_coord_cam']
98 | root_joint_cam = copy.deepcopy(joints_coord_cam[self.root_joint_idx])
99 | joints_coord_cam -= joints_coord_cam[self.root_joint_idx,None,:] # root-relative
100 | # 3D data rotation augmentation
101 | rot_aug_mat = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
102 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
103 | [0, 0, 1]], dtype=np.float32)
104 | joints_coord_cam = np.dot(rot_aug_mat, joints_coord_cam.transpose(1,0)).transpose(1,0)
105 |
106 | ## mano parameter
107 | mano_pose, mano_shape = data['mano_pose'], data['mano_shape']
108 | # 3D data rotation augmentation
109 | mano_pose = mano_pose.reshape(-1,3)
110 | root_pose = mano_pose[self.root_joint_idx,:]
111 | root_pose, _ = cv2.Rodrigues(root_pose)
112 | root_pose, _ = cv2.Rodrigues(np.dot(rot_aug_mat,root_pose))
113 | mano_pose[self.root_joint_idx] = root_pose.reshape(3)
114 | mano_pose = mano_pose.reshape(-1)
115 |
116 | inputs = {'img': img}
117 | targets = {'joints_img': joints_img, 'joints_coord_cam': joints_coord_cam, 'mano_pose': mano_pose, 'mano_shape': mano_shape}
118 | meta_info = {'root_joint_cam': root_joint_cam}
119 |
120 | else:
121 | root_joint_cam = data['root_joint_cam']
122 | inputs = {'img': img}
123 | targets = {}
124 | meta_info = {'root_joint_cam': root_joint_cam, 'img_path': img_path}
125 |
126 | return inputs, targets, meta_info
127 |
128 |
129 | def evaluate(self, outs, cur_sample_idx):
130 | annots = self.datalist
131 | sample_num = len(outs)
132 | for n in range(sample_num):
133 | annot = annots[cur_sample_idx + n]
134 |
135 | out = outs[n]
136 |
137 | verts_out = out['mesh_coord_cam']
138 | joints_out = out['joints_coord_cam']
139 |
140 | # root align
141 | gt_root_joint_cam = annot['root_joint_cam']
142 | verts_out = verts_out - joints_out[self.root_joint_idx] + gt_root_joint_cam
143 | joints_out = joints_out - joints_out[self.root_joint_idx] + gt_root_joint_cam
144 |
145 | # convert to openGL coordinate system.
146 | verts_out *= np.array([1, -1, -1])
147 | joints_out *= np.array([1, -1, -1])
148 |
149 | # convert joint ordering from MANO to HO3D.
150 | joints_out = transform_joint_to_other_db(joints_out, mano.joints_name, self.joints_name)
151 |
152 | self.eval_result[0].append(joints_out.tolist())
153 | self.eval_result[1].append(verts_out.tolist())
154 |
155 | def print_eval_result(self, test_epoch):
156 | output_json_file = osp.join(cfg.result_dir, 'pred{}.json'.format(test_epoch))
157 | output_zip_file = osp.join(cfg.result_dir, 'pred{}.zip'.format(test_epoch))
158 |
159 | with open(output_json_file, 'w') as f:
160 | json.dump(self.eval_result, f)
161 | print('Dumped %d joints and %d verts predictions to %s' % (len(self.eval_result[0]), len(self.eval_result[1]), output_json_file))
162 |
163 | cmd = 'zip -j ' + output_zip_file + ' ' + output_json_file
164 | print(cmd)
165 | os.system(cmd)
166 |
167 |
--------------------------------------------------------------------------------
/demo/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import os.path as osp
4 | import argparse
5 | import numpy as np
6 | import cv2
7 | import torch
8 | import torchvision.transforms as transforms
9 | from torch.nn.parallel.data_parallel import DataParallel
10 | import torch.backends.cudnn as cudnn
11 |
12 | sys.path.insert(0, osp.join('..', 'main'))
13 | sys.path.insert(0, osp.join('..', 'common'))
14 | from config import cfg
15 | from model import get_model
16 | from utils.preprocessing import load_img, process_bbox, generate_patch_image
17 | from utils.vis import save_obj
18 | from utils.mano import MANO
19 | mano = MANO()
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--gpu', type=str, dest='gpu_ids')
24 | args = parser.parse_args()
25 |
26 | # test gpus
27 | if not args.gpu_ids:
28 | assert 0, print("Please set proper gpu ids")
29 |
30 | if '-' in args.gpu_ids:
31 | gpus = args.gpu_ids.split('-')
32 | gpus[0] = int(gpus[0])
33 | gpus[1] = int(gpus[1]) + 1
34 | args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus))))
35 |
36 | return args
37 |
38 | # argument parsing
39 | args = parse_args()
40 | cfg.set_args(args.gpu_ids)
41 | cudnn.benchmark = True
42 |
43 | # snapshot load
44 | model_path = './snapshot_demo.pth.tar'
45 | assert osp.exists(model_path), 'Cannot find model at ' + model_path
46 | print('Load checkpoint from {}'.format(model_path))
47 | model = get_model('test')
48 |
49 | model = DataParallel(model).cuda()
50 | ckpt = torch.load(model_path)
51 | model.load_state_dict(ckpt['network'], strict=False)
52 | model.eval()
53 |
54 | # prepare input image
55 | transform = transforms.ToTensor()
56 | img_path = 'input.png'
57 | original_img = load_img(img_path)
58 | original_img_height, original_img_width = original_img.shape[:2]
59 |
60 | # prepare bbox
61 | bbox = [340.8, 232.0, 20.7, 20.7] # xmin, ymin, width, height
62 |
63 | bbox = process_bbox(bbox, original_img_width, original_img_height)
64 | img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, cfg.input_img_shape)
65 | img = transform(img.astype(np.float32))/255
66 | img = img.cuda()[None,:,:,:]
67 |
68 | # forward
69 | inputs = {'img': img}
70 | targets = {}
71 | meta_info = {}
72 | with torch.no_grad():
73 | out = model(inputs, targets, meta_info, 'test')
74 | img = (img[0].cpu().numpy().transpose(1,2,0)*255).astype(np.uint8) # cfg.input_img_shape[1], cfg.input_img_shape[0], 3
75 | verts_out = out['mesh_coord_cam'][0].cpu().numpy()
76 |
77 | # bbox for input hand image
78 | bbox_vis = np.array(bbox, int)
79 | bbox_vis[2:] += bbox_vis[:2]
80 | cvimg = cv2.rectangle(original_img.copy(), bbox_vis[:2], bbox_vis[2:], (255,0,0), 3)
81 | cv2.imwrite('hand_bbox.png', cvimg[:,:,::-1])
82 |
83 | ## input hand image
84 | cv2.imwrite('hand_image.png', img[:,:,::-1])
85 |
86 | # save mesh (obj)
87 | save_obj(verts_out*np.array([1,-1,-1]), mano.face, 'output.obj')
88 |
--------------------------------------------------------------------------------
/demo/demo_fitting.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | import os
4 | import os.path as osp
5 | import argparse
6 | import json
7 | import numpy as np
8 | import cv2
9 | import torch
10 | from PIL import Image
11 | import torchvision.transforms as transforms
12 | from torch.nn.parallel.data_parallel import DataParallel
13 | import torch.backends.cudnn as cudnn
14 | from tqdm import tqdm
15 |
16 | sys.path.insert(0, osp.join('..', 'main'))
17 | sys.path.insert(0, osp.join('..', 'common'))
18 | from config import cfg
19 | from model import get_model
20 | from utils.preprocessing import load_img, process_bbox, generate_patch_image
21 | from utils.vis import save_obj, vis_keypoints_with_skeleton
22 | from utils.mano import MANO
23 | from utils.camera import PerspectiveCamera
24 | mano = MANO()
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--gpu', type=str, dest='gpu_ids')
29 | parser.add_argument('--depth', type=float, default='0.5')
30 |
31 | args = parser.parse_args()
32 |
33 | # test gpus
34 | if not args.gpu_ids:
35 | assert 0, print("Please set proper gpu ids")
36 |
37 | if '-' in args.gpu_ids:
38 | gpus = args.gpu_ids.split('-')
39 | gpus[0] = int(gpus[0])
40 | gpus[1] = int(gpus[1]) + 1
41 | args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus))))
42 |
43 | return args
44 |
45 | def load_camera(cam_path, cam_idx='0'):
46 | with open(cam_path, 'r') as f:
47 | cam_data = json.load(f)
48 |
49 | camera = PerspectiveCamera()
50 |
51 | camera.focal_length_x = torch.full([1], cam_data[cam_idx]['fx'])
52 | camera.focal_length_y = torch.full([1], cam_data[cam_idx]['fy'])
53 | camera.center = torch.tensor(
54 | [cam_data[cam_idx]['cx'], cam_data[cam_idx]['cy']]).unsqueeze(0)
55 | # only intrinsics
56 | # rotation, _ = cv2.Rodrigues(
57 | # np.array(cam_data[cam_idx]['rvec'], dtype=np.float32))
58 | # camera.rotation.data = torch.from_numpy(rotation).unsqueeze(0)
59 | # camera.translation.data = torch.tensor(
60 | # cam_data[cam_idx]['tvec']).unsqueeze(0) / 1000.
61 | camera.rotation.requires_grad = False
62 | camera.translation.requires_grad = False
63 | camera.name = str(cam_idx)
64 |
65 | return camera
66 |
67 | if __name__ == '__main__':
68 | # argument parsing
69 | args = parse_args()
70 | cfg.set_args(args.gpu_ids)
71 | cudnn.benchmark = True
72 | transform = transforms.ToTensor()
73 |
74 | # hard coding
75 | save_dir = './'
76 | init_depth = args.depth
77 | img_path = 'fitting_input.jpg'
78 | bbox = [300, 330, 90, 50]#[340.8, 232.0, 20.7, 20.7] # xmin, ymin, width, height
79 |
80 | # model snapshot load
81 | model_path = './snapshot_demo.pth.tar'
82 | assert osp.exists(model_path), 'Cannot find model at ' + model_path
83 | print('Load checkpoint from {}'.format(model_path))
84 | model = get_model('test')
85 |
86 | model = DataParallel(model).cuda()
87 | ckpt = torch.load(model_path)
88 | model.load_state_dict(ckpt['network'], strict=False)
89 | model.eval()
90 |
91 | # prepare input image
92 | transform = transforms.ToTensor()
93 | original_img = load_img(img_path)
94 | original_img_height, original_img_width = original_img.shape[:2]
95 |
96 | # prepare bbox
97 | bbox = process_bbox(bbox, original_img_width, original_img_height)
98 | img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, cfg.input_img_shape)
99 | img = transform(img.astype(np.float32))/255
100 | img = img.cuda()[None,:,:,:]
101 |
102 | # get camera for projection
103 | camera = PerspectiveCamera()
104 | camera.rotation.requires_grad = False
105 | camera.translation.requires_grad = False
106 | camera.center[0, 0] = original_img.shape[1] / 2
107 | camera.center[0, 1] = original_img.shape[0] / 2
108 | camera.cuda()
109 |
110 | # forward pass to the model
111 | inputs = {'img': img} # cfg.input_img_shape[1], cfg.input_img_shape[0], 3
112 | targets = {}
113 | meta_info = {}
114 | with torch.no_grad():
115 | out = model(inputs, targets, meta_info, 'test')
116 | img = (img[0].cpu().numpy().transpose(1, 2, 0)*255).astype(np.uint8) #
117 | verts_out = out['mesh_coord_cam'][0].cpu().numpy()
118 |
119 | # get hand mesh's scale and translation by fitting joint cam to joint img
120 | joint_img, joint_cam = out['joints_coord_img'], out['joints_coord_cam']
121 |
122 | # denormalize joint_img from 0 ~ 1 to actual 0 ~ original height and width
123 | H, W = img.shape[:2]
124 | joint_img[:, :, 0] *= W
125 | joint_img[:, :, 1] *= H
126 | torch_bb2img_trans = torch.tensor(bb2img_trans).to(joint_img)
127 | homo_joint_img = torch.cat([joint_img, torch.ones_like(joint_img[:, :, :1])], dim=2)
128 | org_res_joint_img = homo_joint_img @ torch_bb2img_trans.transpose(0, 1)
129 |
130 | # depth initialization
131 | depth_map = None #np.asarray(Image.open(depth_path))
132 | hand_scale, hand_translation = model.module.get_mesh_scale_trans(
133 | org_res_joint_img, joint_cam, init_scale=1., init_depth=init_depth, camera=camera, depth_map=depth_map)
134 |
135 | np_joint_img = org_res_joint_img[0].cpu().numpy()
136 | np_joint_img = np.concatenate([np_joint_img, np.ones_like(np_joint_img[:, :1])], axis=1)
137 | vis_img = original_img.astype(np.uint8)[:, :, ::-1]
138 | pred_joint_img_overlay = vis_keypoints_with_skeleton(vis_img, np_joint_img.T, mano.skeleton)
139 | # cv2.imshow('2d prediction', pred_joint_img_overlay)
140 | save_path = osp.join(
141 | save_dir, f'{osp.basename(img_path)[:-4]}_2d_prediction.png')
142 |
143 | cv2.imwrite(save_path, pred_joint_img_overlay)
144 | projected_joints = camera(
145 | hand_scale * joint_cam + hand_translation)
146 | np_joint_img = projected_joints[0].detach().cpu().numpy()
147 | np_joint_img = np.concatenate([np_joint_img, np.ones_like(np_joint_img[:, :1])], axis=1)
148 |
149 | vis_img = original_img.astype(np.uint8)[:, :, ::-1]
150 | pred_joint_img_overlay = vis_keypoints_with_skeleton(vis_img, np_joint_img.T, mano.skeleton)
151 | # cv2.imshow('projection', pred_joint_img_overlay)
152 | # cv2.waitKey(0)
153 | save_path = osp.join(save_dir, f'{osp.basename(img_path)[:-4]}_projection.png')
154 | cv2.imwrite(save_path, pred_joint_img_overlay)
155 |
156 | # data to save
157 | data_to_save = {
158 | 'hand_scale': hand_scale.detach().cpu().numpy().tolist(), # 1
159 | 'hand_translation': hand_translation.detach().cpu().numpy().tolist(), # 3
160 | 'mano_pose': out['mano_pose'][0].detach().cpu().numpy().tolist(), # 48
161 | 'mano_shape': out['mano_shape'][0].detach().cpu().numpy().tolist(), # 10
162 | }
163 | save_path = osp.join(
164 | save_dir, f'{osp.basename(img_path)[:-4]}_3dmesh.json')
165 | with open(save_path, 'w') as f:
166 | json.dump(data_to_save, f)
167 |
168 | # # bbox for input hand image
169 | # bbox_vis = np.array(bbox, int)
170 | # bbox_vis[2:] += bbox_vis[:2]
171 | # cvimg = cv2.rectangle(original_img.copy(),
172 | # bbox_vis[:2], bbox_vis[2:], (255, 0, 0), 3)
173 | # cv2.imwrite(f'{osp.basename(img_path)[:-4]}_hand_bbox.png', cvimg[:, :, ::-1])
174 | # ## input hand image
175 | # cv2.imwrite(f'{osp.basename(img_path)[:-4]}_hand_image.png', img[:, :, ::-1])
176 |
177 | # save mesh (obj)
178 | save_path = osp.join(
179 | save_dir, f'{osp.basename(img_path)[:-4]}_3dmesh.obj')
180 | save_obj(verts_out*np.array([1, -1, -1]),
181 | mano.face, save_path)
--------------------------------------------------------------------------------
/demo/fitting_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/demo/fitting_input.png
--------------------------------------------------------------------------------
/demo/hand_bbox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/demo/hand_bbox.png
--------------------------------------------------------------------------------
/demo/hand_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/demo/hand_image.png
--------------------------------------------------------------------------------
/demo/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/namepllet/HandOccNet/65ba997c9ce88947f70453fa0ac6c1e7d052660f/demo/input.png
--------------------------------------------------------------------------------
/main/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import numpy as np
5 |
6 | class Config:
7 |
8 | ## dataset
9 | # HO3D, DEX_YCB
10 | trainset = 'HO3D'
11 | testset = 'DEX_YCB'
12 |
13 | ## input, output
14 | input_img_shape = (256,256)
15 |
16 | ## training config
17 | if trainset == 'HO3D':
18 | lr_dec_epoch = [10*i for i in range(1,7)]
19 | end_epoch = 70
20 | lr = 1e-4
21 | lr_dec_factor = 0.7
22 | elif trainset == 'DEX_YCB':
23 | lr_dec_epoch = [i for i in range(1,25)]
24 | end_epoch = 25
25 | lr = 1e-4
26 | lr_dec_factor = 0.9
27 | train_batch_size = 16 # per GPU
28 | lambda_mano_verts = 1e4
29 | lambda_mano_joints = 1e4
30 | lambda_mano_pose = 10
31 | lambda_mano_shape = 0.1
32 | lambda_joints_img = 100
33 | ckpt_freq = 10
34 |
35 | ## testing config
36 | test_batch_size = 64
37 |
38 | ## others
39 | num_thread = 20
40 | gpu_ids = '0'
41 | num_gpus = 1
42 | continue_train = False
43 |
44 | ## directory
45 | cur_dir = osp.dirname(os.path.abspath(__file__))
46 | root_dir = osp.join(cur_dir, '..')
47 | data_dir = osp.join(root_dir, 'data')
48 | output_dir = osp.join(root_dir, 'output')
49 | model_dir = osp.join(output_dir, 'model_dump')
50 | vis_dir = osp.join(output_dir, 'vis')
51 | log_dir = osp.join(output_dir, 'log')
52 | result_dir = osp.join(output_dir, 'result')
53 | mano_path = osp.join(root_dir, 'common', 'utils', 'manopth')
54 |
55 | def set_args(self, gpu_ids, continue_train=False):
56 | self.gpu_ids = gpu_ids
57 | self.num_gpus = len(self.gpu_ids.split(','))
58 | self.continue_train = continue_train
59 | os.environ["CUDA_VISIBLE_DEVICES"] = self.gpu_ids
60 | print('>>> Using GPU: {}'.format(self.gpu_ids))
61 |
62 | cfg = Config()
63 |
64 | sys.path.insert(0, osp.join(cfg.root_dir, 'common'))
65 | from utils.dir import add_pypath, make_folder
66 | add_pypath(osp.join(cfg.data_dir))
67 | add_pypath(osp.join(cfg.data_dir, cfg.trainset))
68 | add_pypath(osp.join(cfg.data_dir, cfg.testset))
69 | make_folder(cfg.model_dir)
70 | make_folder(cfg.vis_dir)
71 | make_folder(cfg.log_dir)
72 | make_folder(cfg.result_dir)
73 |
--------------------------------------------------------------------------------
/main/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from nets.backbone import FPN
5 | from nets.transformer import Transformer
6 | from nets.regressor import Regressor
7 | from utils.mano import MANO
8 | from utils.fitting import ScaleTranslationLoss, FittingMonitor
9 | from utils.optimizers import optim_factory
10 | from utils.camera import PerspectiveCamera
11 | from config import cfg
12 | import math
13 |
14 | class Model(nn.Module):
15 | def __init__(self, backbone, FIT, SET, regressor):
16 | super(Model, self).__init__()
17 | self.backbone = backbone
18 | self.FIT = FIT
19 | self.SET = SET
20 | self.regressor = regressor
21 |
22 | self.fitting_loss = ScaleTranslationLoss(list(range(0, 21))) # fitting joint indices
23 |
24 |
25 | def forward(self, inputs, targets, meta_info, mode):
26 | p_feats, s_feats = self.backbone(inputs['img']) # primary, secondary feats
27 | feats = self.FIT(s_feats, p_feats)
28 | feats = self.SET(feats, feats)
29 |
30 | if mode == 'train':
31 | gt_mano_params = torch.cat([targets['mano_pose'], targets['mano_shape']], dim=1)
32 | else:
33 | gt_mano_params = None
34 | pred_mano_results, gt_mano_results, preds_joints_img = self.regressor(feats, gt_mano_params)
35 |
36 | if mode == 'train':
37 | # loss functions
38 | loss = {}
39 | loss['mano_verts'] = cfg.lambda_mano_verts * F.mse_loss(pred_mano_results['verts3d'], gt_mano_results['verts3d'])
40 | loss['mano_joints'] = cfg.lambda_mano_joints * F.mse_loss(pred_mano_results['joints3d'], gt_mano_results['joints3d'])
41 | loss['mano_pose'] = cfg.lambda_mano_pose * F.mse_loss(pred_mano_results['mano_pose'], gt_mano_results['mano_pose'])
42 | loss['mano_shape'] = cfg.lambda_mano_shape * F.mse_loss(pred_mano_results['mano_shape'], gt_mano_results['mano_shape'])
43 | loss['joints_img'] = cfg.lambda_joints_img * F.mse_loss(preds_joints_img[0], targets['joints_img'])
44 | return loss
45 |
46 | else:
47 | # test output
48 | out = {}
49 | out['joints_coord_img'] = preds_joints_img[0]
50 | out['mano_pose'] = pred_mano_results['mano_pose_aa']
51 | out['mano_shape'] = pred_mano_results['mano_shape']
52 | out['joints_coord_cam'] = pred_mano_results['joints3d']
53 | out['mesh_coord_cam'] = pred_mano_results['verts3d']
54 | out['manojoints2cam'] = pred_mano_results['manojoints2cam']
55 | out['mano_pose_aa'] = pred_mano_results['mano_pose_aa']
56 |
57 | return out
58 |
59 | def get_mesh_scale_trans(self, pred_joint_img, pred_joint_cam, init_scale=1., init_depth=1., camera=None, depth_map=None):
60 | """
61 | pred_joint_img: (batch_size, 21, 2)
62 | pred_joint_cam: (batch_size, 21, 3)
63 | """
64 | if camera is None:
65 | camera = PerspectiveCamera()
66 |
67 | dtype, device = pred_joint_cam.dtype, pred_joint_cam.device
68 | hand_scale = torch.tensor([init_scale / 1.0], dtype=dtype, device=device, requires_grad=False)
69 | hand_translation = torch.tensor([0, 0, init_depth], dtype=dtype, device=device, requires_grad=True)
70 | if depth_map is not None:
71 | tensor_depth = torch.tensor(depth_map, device=device, dtype=dtype)[
72 | None, None, :, :]
73 | grid = pred_joint_img.clone()
74 | grid[:, :, 0] /= tensor_depth.shape[-1]
75 | grid[:, :, 1] /= tensor_depth.shape[-2]
76 | grid = 2 * grid - 1
77 | joints_depth = torch.nn.functional.grid_sample(
78 | tensor_depth, grid[:, None, :, :]) # (1, 1, 1, 21)
79 | joints_depth = joints_depth.reshape(1, 21, 1)
80 | hand_translation = torch.tensor(
81 | [0, 0, joints_depth[0, cfg.fitting_joint_idxs, 0].mean() / 1000.], device=device, requires_grad=True)
82 |
83 | # intended only for demo mesh rendering
84 | batch_size = 1
85 | self.fitting_loss.trans_estimation = hand_translation.clone()
86 |
87 | params = []
88 | params.append(hand_translation)
89 | params.append(hand_scale)
90 | optimizer, create_graph = optim_factory.create_optimizer(
91 | params, optim_type='lbfgsls', lr=1.0e-1)
92 |
93 | # optimization
94 | print("[Fitting]: fitting the hand scale and translation...")
95 | with FittingMonitor(batch_size=batch_size) as monitor:
96 | fit_camera = monitor.create_fitting_closure(
97 | optimizer, camera, pred_joint_cam, pred_joint_img, hand_translation, hand_scale, self.fitting_loss, create_graph=create_graph)
98 |
99 | loss_val = monitor.run_fitting(
100 | optimizer, fit_camera, params)
101 |
102 |
103 | print(f"[Fitting]: fitting finished with loss of {loss_val}")
104 | print(f"Scale: {hand_scale.detach().cpu().numpy()}, Translation: {hand_translation.detach().cpu().numpy()}")
105 | return hand_scale, hand_translation
106 |
107 | def init_weights(m):
108 | if type(m) == nn.ConvTranspose2d:
109 | nn.init.normal_(m.weight,std=0.001)
110 | elif type(m) == nn.Conv2d:
111 | nn.init.normal_(m.weight,std=0.001)
112 | nn.init.constant_(m.bias, 0)
113 | elif type(m) == nn.BatchNorm2d:
114 | nn.init.constant_(m.weight,1)
115 | nn.init.constant_(m.bias,0)
116 | elif type(m) == nn.Linear:
117 | nn.init.normal_(m.weight,std=0.01)
118 | nn.init.constant_(m.bias,0)
119 |
120 | def get_model(mode):
121 | backbone = FPN(pretrained=True)
122 | FIT = Transformer(injection=True) # feature injecting transformer
123 | SET = Transformer(injection=False) # self enhancing transformer
124 | regressor = Regressor()
125 |
126 | if mode == 'train':
127 | FIT.apply(init_weights)
128 | SET.apply(init_weights)
129 | regressor.apply(init_weights)
130 |
131 | model = Model(backbone, FIT, SET, regressor)
132 |
133 | return model
--------------------------------------------------------------------------------
/main/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch.backends.cudnn as cudnn
6 | from config import cfg
7 | from base import Tester
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--gpu', type=str, dest='gpu_ids')
12 | parser.add_argument('--test_epoch', type=str, dest='test_epoch')
13 | args = parser.parse_args()
14 |
15 | if not args.gpu_ids:
16 | assert 0, "Please set propoer gpu ids"
17 |
18 | if '-' in args.gpu_ids:
19 | gpus = args.gpu_ids.split('-')
20 | gpus[0] = int(gpus[0])
21 | gpus[1] = int(gpus[1]) + 1
22 | args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus))))
23 |
24 | assert args.test_epoch, 'Test epoch is required.'
25 | return args
26 |
27 | def main():
28 |
29 | args = parse_args()
30 | cfg.set_args(args.gpu_ids)
31 | cudnn.benchmark = True
32 |
33 | tester = Tester(args.test_epoch)
34 | tester._make_batch_generator()
35 | tester._make_model()
36 |
37 | eval_result = {}
38 | cur_sample_idx = 0
39 | for itr, (inputs, targets, meta_info) in enumerate(tqdm(tester.batch_generator)):
40 |
41 | # forward
42 | with torch.no_grad():
43 | out = tester.model(inputs, targets, meta_info, 'test')
44 |
45 | # save output
46 | out = {k: v.cpu().numpy() for k,v in out.items()}
47 | for k,v in out.items(): batch_size = out[k].shape[0]
48 | out = [{k: v[bid] for k,v in out.items()} for bid in range(batch_size)]
49 |
50 | # evaluate
51 | tester._evaluate(out, cur_sample_idx)
52 | cur_sample_idx += len(out)
53 |
54 | tester._print_eval_result(args.test_epoch)
55 |
56 | if __name__ == "__main__":
57 | main()
--------------------------------------------------------------------------------
/main/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from config import cfg
3 | import torch
4 | from base import Trainer
5 | import torch.backends.cudnn as cudnn
6 |
7 | def parse_args():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--gpu', type=str, dest='gpu_ids')
10 | parser.add_argument('--continue', dest='continue_train', action='store_true')
11 | args = parser.parse_args()
12 |
13 | if not args.gpu_ids:
14 | assert 0, "Please set propoer gpu ids"
15 |
16 | if '-' in args.gpu_ids:
17 | gpus = args.gpu_ids.split('-')
18 | gpus[0] = int(gpus[0])
19 | gpus[1] = int(gpus[1]) + 1
20 | args.gpu_ids = ','.join(map(lambda x: str(x), list(range(*gpus))))
21 |
22 | return args
23 |
24 | def main():
25 |
26 | # argument parse and create log
27 | args = parse_args()
28 | cfg.set_args(args.gpu_ids, args.continue_train)
29 | cudnn.benchmark = True
30 |
31 | trainer = Trainer()
32 | trainer._make_batch_generator()
33 | trainer._make_model()
34 |
35 | # train
36 | for epoch in range(trainer.start_epoch, cfg.end_epoch):
37 |
38 | trainer.set_lr(epoch)
39 | trainer.tot_timer.tic()
40 | trainer.read_timer.tic()
41 | for itr, (inputs, targets, meta_info) in enumerate(trainer.batch_generator):
42 | trainer.read_timer.toc()
43 | trainer.gpu_timer.tic()
44 |
45 | # forward
46 | trainer.optimizer.zero_grad()
47 | loss = trainer.model(inputs, targets, meta_info, 'train')
48 | loss = {k:loss[k].mean() for k in loss}
49 |
50 | # backward
51 | sum(loss[k] for k in loss).backward()
52 | trainer.optimizer.step()
53 | trainer.gpu_timer.toc()
54 | screen = [
55 | 'Epoch %d/%d itr %d/%d:' % (epoch, cfg.end_epoch, itr, trainer.itr_per_epoch),
56 | 'lr: %g' % (trainer.get_lr()),
57 | 'speed: %.2f(%.2fs r%.2f)s/itr' % (
58 | trainer.tot_timer.average_time, trainer.gpu_timer.average_time, trainer.read_timer.average_time),
59 | '%.2fh/epoch' % (trainer.tot_timer.average_time / 3600. * trainer.itr_per_epoch),
60 | ]
61 | screen += ['%s: %.4f' % ('loss_' + k, v.detach()) for k,v in loss.items()]
62 | trainer.logger.info(' '.join(screen))
63 |
64 | trainer.tot_timer.toc()
65 | trainer.tot_timer.tic()
66 | trainer.read_timer.tic()
67 |
68 | if (epoch+1)%cfg.ckpt_freq== 0 or epoch+1 == cfg.end_epoch:
69 | trainer.save_model({
70 | 'epoch': epoch,
71 | 'network': trainer.model.state_dict(),
72 | 'optimizer': trainer.optimizer.state_dict(),
73 | }, epoch+1)
74 |
75 |
76 | if __name__ == "__main__":
77 | main()
78 |
--------------------------------------------------------------------------------
/requiremets.sh:
--------------------------------------------------------------------------------
1 | pip install numpy==1.17.4 torch==1.9.1 torchvision==0.10.1 einops chumpy opencv-python pycocotools pyrender tqdm
2 |
3 |
4 |
--------------------------------------------------------------------------------