├── utils ├── util.py ├── config.py └── urdf2graph.py ├── configs ├── train │ ├── hand.yaml │ ├── arm.yaml │ └── yumi.yaml └── inference │ ├── hand.yaml │ ├── arm.yaml │ └── yumi.yaml ├── LICENSE ├── .gitignore ├── README.md ├── test.py ├── train.py ├── main.py ├── models ├── model.py ├── kinematics.py └── loss.py ├── inference.py └── dataset.py /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def create_folder(folder): 4 | if not os.path.exists(folder): 5 | os.makedirs(folder) 6 | -------------------------------------------------------------------------------- /configs/train/hand.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 3 | SOURCE_NAME: "SignWithHand" 4 | SOURCE_PATH: "./data/source/sign-hand/train" 5 | TARGET_NAME: "InspireHand" 6 | TARGET_PATH: "./data/target/yumi-with-hands" 7 | TEST: 8 | SOURCE_NAME: "SignWithHand" 9 | SOURCE_PATH: "./data/source/sign-hand/test" 10 | TARGET_NAME: "InspireHand" 11 | TARGET_PATH: "./data/target/yumi-with-hands" 12 | MODEL: 13 | NAME: "HandNet" 14 | HYPER: 15 | EPOCHS: 10 16 | BATCH_SIZE: 16 17 | LEARNING_RATE: 0.0001 18 | LOSS: 19 | FIN: True 20 | REG: True 21 | OTHERS: 22 | SAVE: "./saved/models/hand" 23 | LOG: "./saved/log/hand" 24 | SUMMARY: "./saved/runs/hand" 25 | LOG_INTERVAL: 100 -------------------------------------------------------------------------------- /configs/inference/hand.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TEST: 3 | SOURCE_NAME: "SignWithHand" 4 | SOURCE_PATH: "./data/source/sign-hand/test" 5 | TARGET_NAME: "InspireHand" 6 | TARGET_PATH: "./data/target/yumi-with-hands" 7 | MODEL: 8 | NAME: "HandNet" 9 | CHECKPOINT: "./saved/models/hand/best_model_epoch_0009.pth" 10 | HYPER: 11 | EPOCHS: 100 12 | BATCH_SIZE: 16 13 | LEARNING_RATE: 0.0001 14 | LOSS: 15 | FIN: True 16 | REG: True 17 | INFERENCE: 18 | MOTION: 19 | SOURCE: './data/source/sign-hand/test/h5/yumi_intro_YuMi.h5' 20 | KEY: '我-wo' 21 | H5: 22 | BOOL: True 23 | PATH: './saved/h5' 24 | OTHERS: 25 | LOG: "./saved/log/hand" 26 | SUMMARY: "./saved/runs/hand" 27 | LOG_INTERVAL: 100 -------------------------------------------------------------------------------- /configs/train/arm.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 3 | SOURCE_NAME: "SignDataset" 4 | SOURCE_PATH: "./data/source/sign/train" 5 | TARGET_NAME: "YumiDataset" 6 | TARGET_PATH: "./data/target/yumi" 7 | TEST: 8 | SOURCE_NAME: "SignDataset" 9 | SOURCE_PATH: "./data/source/sign/test" 10 | TARGET_NAME: "YumiDataset" 11 | TARGET_PATH: "./data/target/yumi" 12 | MODEL: 13 | NAME: "ArmNet" 14 | HYPER: 15 | EPOCHS: 10 16 | BATCH_SIZE: 16 17 | LEARNING_RATE: 0.0001 18 | LOSS: 19 | EE: True 20 | VEC: True 21 | ORI: True 22 | COL: True 23 | COL_THRESHOLD: 0.15 24 | REG: True 25 | OTHERS: 26 | SAVE: "./saved/models/arm" 27 | LOG: "./saved/log" 28 | SUMMARY: "./saved/runs" 29 | LOG_INTERVAL: 100 -------------------------------------------------------------------------------- /configs/train/yumi.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 3 | SOURCE_NAME: "SignAll" 4 | SOURCE_PATH: "./data/source/sign-all/train" 5 | TARGET_NAME: "YumiAll" 6 | TARGET_PATH: "./data/target/yumi-all" 7 | TEST: 8 | SOURCE_NAME: "SignAll" 9 | SOURCE_PATH: "./data/source/sign-all/test" 10 | TARGET_NAME: "YumiAll" 11 | 
TARGET_PATH: "./data/target/yumi-all" 12 | MODEL: 13 | NAME: "YumiNet" 14 | HYPER: 15 | EPOCHS: 10 16 | BATCH_SIZE: 16 17 | LEARNING_RATE: 0.0001 18 | LOSS: 19 | EE: True 20 | VEC: True 21 | ORI: True 22 | FIN: True 23 | COL: True 24 | COL_THRESHOLD: 0.15 25 | REG: True 26 | OTHERS: 27 | SAVE: "./saved/models/yumi" 28 | LOG: "./saved/log/yumi" 29 | SUMMARY: "./saved/runs/yumi" 30 | LOG_INTERVAL: 100 -------------------------------------------------------------------------------- /configs/inference/arm.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TEST: 3 | SOURCE_NAME: "SignDataset" 4 | SOURCE_PATH: "./data/source/sign/test" 5 | TARGET_NAME: "YumiDataset" 6 | TARGET_PATH: "./data/target/yumi" 7 | MODEL: 8 | NAME: "ArmNet" 9 | CHECKPOINT: "./saved/models/arm/best_model_epoch_0000.pth" 10 | HYPER: 11 | EPOCHS: 1000 12 | BATCH_SIZE: 16 13 | LEARNING_RATE: 0.001 14 | LOSS: 15 | EE: True 16 | VEC: True 17 | ORI: True 18 | COL: True 19 | COL_THRESHOLD: 0.15 20 | REG: True 21 | INFERENCE: 22 | MOTION: 23 | SOURCE: './data/source/sign/test/h5/yumi_intro_YuMi.h5' 24 | KEY: '我-wo' 25 | H5: 26 | BOOL: True 27 | PATH: './saved/h5' 28 | OTHERS: 29 | SAVE: "./saved/models/arm" 30 | LOG: "./saved/log" 31 | SUMMARY: "./saved/runs" 32 | LOG_INTERVAL: 100 33 | -------------------------------------------------------------------------------- /configs/inference/yumi.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TEST: 3 | SOURCE_NAME: "SignAll" 4 | SOURCE_PATH: "./data/source/sign-all/test" 5 | TARGET_NAME: "YumiAll" 6 | TARGET_PATH: "./data/target/yumi-all" 7 | MODEL: 8 | NAME: "YumiNet" 9 | CHECKPOINT: "./saved/models/yumi/best_model_epoch_0003.pth" 10 | HYPER: 11 | EPOCHS: 200 12 | BATCH_SIZE: 16 13 | LEARNING_RATE: 0.00003 14 | LOSS: 15 | EE: True 16 | VEC: True 17 | ORI: True 18 | FIN: True 19 | COL: True 20 | COL_THRESHOLD: 0.15 21 | REG: True 22 | # INFERENCE: 23 | # MOTION: 24 | # SOURCE: './data/source/sign-all/test/h5/yumi_intro_YuMi.h5' 25 | # KEY: '我-wo' 26 | # H5: 27 | # BOOL: True 28 | # PATH: './saved/h5' 29 | OTHERS: 30 | LOG: "./saved/log/yumi" 31 | SUMMARY: "./saved/runs/yumi" 32 | LOG_INTERVAL: 100 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Haodong Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | 5 | _C.DATASET = CN() 6 | _C.DATASET.TRAIN = CN() 7 | _C.DATASET.TRAIN.SOURCE_NAME = None 8 | _C.DATASET.TRAIN.SOURCE_PATH = None 9 | _C.DATASET.TRAIN.TARGET_NAME = None 10 | _C.DATASET.TRAIN.TARGET_PATH = None 11 | _C.DATASET.TEST = CN() 12 | _C.DATASET.TEST.SOURCE_NAME = None 13 | _C.DATASET.TEST.SOURCE_PATH = None 14 | _C.DATASET.TEST.TARGET_NAME = None 15 | _C.DATASET.TEST.TARGET_PATH = None 16 | 17 | _C.MODEL = CN() 18 | _C.MODEL.NAME = None 19 | _C.MODEL.CHECKPOINT = None 20 | 21 | _C.HYPER = CN() 22 | _C.HYPER.EPOCHS = None 23 | _C.HYPER.BATCH_SIZE = None 24 | _C.HYPER.LEARNING_RATE = None 25 | 26 | _C.LOSS = CN() 27 | _C.LOSS.EE = False 28 | _C.LOSS.VEC = False 29 | _C.LOSS.COL = False 30 | _C.LOSS.COL_THRESHOLD = None 31 | _C.LOSS.LIM = False 32 | _C.LOSS.ORI = False 33 | _C.LOSS.FIN = False 34 | _C.LOSS.REG = False 35 | 36 | _C.INFERENCE = CN() 37 | _C.INFERENCE.MOTION = CN() 38 | _C.INFERENCE.MOTION.SOURCE = None 39 | _C.INFERENCE.MOTION.KEY = None 40 | _C.INFERENCE.H5 = CN() 41 | _C.INFERENCE.H5.BOOL = None 42 | _C.INFERENCE.H5.PATH = None 43 | 44 | _C.OTHERS = CN() 45 | _C.OTHERS.SAVE = None 46 | _C.OTHERS.LOG = None 47 | _C.OTHERS.SUMMARY = None 48 | _C.OTHERS.LOG_INTERVAL = None 49 | 50 | cfg = _C 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # project dirs 132 | data/ 133 | saved/ 134 | data 135 | saved 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Retargeting 2 | 3 | Code for the paper "Kinematic Motion Retargeting via Neural Latent Optimization for Learning Sign Language" 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-2103.08882-00ff00.svg)](https://arxiv.org/abs/2103.08882) 6 | [![YouTube](https://img.shields.io/badge/YouTube-Video-green.svg)](https://www.youtube.com/watch?v=pX-uie3vLMA) 7 | [![Bilibili](https://img.shields.io/badge/Bilibili-Video-blue.svg)](https://www.bilibili.com/video/BV1mh411Q7BR?share_source=copy_web) 8 | 9 | ## Prerequisite 10 | 11 | - [**PyTorch**](https://pytorch.org/get-started/locally/) Tensors and Dynamic neural networks in Python with strong GPU acceleration 12 | - [**pytorch_geometric**](https://github.com/rusty1s/pytorch_geometric) Geometric Deep Learning Extension Library for PyTorch 13 | - [**Kornia**](https://github.com/kornia/kornia) a differentiable computer vision library for PyTorch. 14 | - [**HDF5 for Python**](https://docs.h5py.org/en/stable/) The h5py package is a Pythonic interface to the HDF5 binary data format. 15 | 16 | 17 | ## Dataset 18 | 19 | The Chinese sign language dataset can be downloaded [here](https://www.jianguoyun.com/p/DYm5RzMQ74eHChj_lJ0E). 20 | 21 | ## Model 22 | 23 | The pretrained model can be downloaded [here](https://www.jianguoyun.com/p/DSl6o3EQy96PCBiN750E). 24 | 25 | ## Get Started 26 | 27 | **Training** 28 | ```bash 29 | CUDA_VISIBLE_DEVICES=0 python main.py --cfg './configs/train/yumi.yaml' 30 | ``` 31 | 32 | **Inference** 33 | ```bash 34 | CUDA_VISIBLE_DEVICES=0 python inference.py --cfg './configs/inference/yumi.yaml' 35 | ``` 36 | 37 | ## Simulation Experiment 38 | 39 | 40 | 41 | We build the simulation environment using pybullet, and the code is in this [repository](https://github.com/0aqz0/yumi-gym). 42 | 43 | After inference is done, the motion retargeting results are stored in a h5 file. Then run the sample code [here](https://github.com/0aqz0/yumi-gym/tree/master/examples). 44 | 45 | ## Real-World Experiment 46 | 47 | Real-world experiments could be conducted on ABB's YuMi dual-arm collaborative robot equipped with Inspire-Robotics' dexterous hands. 48 | 49 | We release the code in this [repository](https://github.com/0aqz0/yumi-control), please follow the instructions. 50 | 51 | ## Citation 52 | 53 | If you find this project useful in your research, please cite this paper. 
54 | 55 | ``` 56 | @article{zhang2022kinematic, 57 | title={Kinematic Motion Retargeting via Neural Latent Optimization for Learning Sign Language}, 58 | author={Zhang, Haodong and Li, Weijie and Liu, Jiangpin and Chen, Zexi and Cui, Yuxiang and Wang, Yue and Xiong, Rong}, 59 | journal={IEEE Robotics and Automation Letters}, 60 | year={2022}, 61 | publisher={IEEE} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Batch 3 | from models.loss import calculate_all_loss 4 | import time 5 | 6 | def test_epoch(model, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, dataloader, target_skeleton, epoch, logger, log_interval, writer, device): 7 | logger.info("Testing Epoch {}".format(epoch+1).center(60, '-')) 8 | start_time = time.time() 9 | 10 | model.eval() 11 | all_losses = [] 12 | ee_losses = [] 13 | vec_losses = [] 14 | col_losses = [] 15 | lim_losses = [] 16 | ori_losses = [] 17 | fin_losses = [] 18 | reg_losses = [] 19 | 20 | with torch.no_grad(): 21 | for batch_idx, data_list in enumerate(dataloader): 22 | for target_idx, target in enumerate(target_skeleton): 23 | # fetch target 24 | target_list = [target for data in data_list] 25 | 26 | # forward 27 | z, target_ang, target_pos, target_rot, target_global_pos, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos = model(Batch.from_data_list(data_list).to(device), Batch.from_data_list(target_list).to(device)) 28 | 29 | # calculate all loss 30 | loss = calculate_all_loss(data_list, target_list, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, 31 | z, target_ang, target_pos, target_rot, target_global_pos, l_hand_pos, r_hand_pos, all_losses, ee_losses, vec_losses, col_losses, lim_losses, ori_losses, fin_losses, reg_losses) 32 | 33 | # Compute average loss 34 | test_loss = sum(all_losses)/len(all_losses) 35 | ee_loss = sum(ee_losses)/len(ee_losses) 36 | vec_loss = sum(vec_losses)/len(vec_losses) 37 | col_loss = sum(col_losses)/len(col_losses) 38 | lim_loss = sum(lim_losses)/len(lim_losses) 39 | ori_loss = sum(ori_losses)/len(ori_losses) 40 | fin_loss = sum(fin_losses)/len(fin_losses) 41 | reg_loss = sum(reg_losses)/len(reg_losses) 42 | # Log 43 | writer.add_scalars('testing_loss', {'test': test_loss}, epoch+1) 44 | writer.add_scalars('end_effector_loss', {'test': ee_loss}, epoch+1) 45 | writer.add_scalars('vector_loss', {'test': vec_loss}, epoch+1) 46 | writer.add_scalars('collision_loss', {'test': col_loss}, epoch+1) 47 | writer.add_scalars('joint_limit_loss', {'test': lim_loss}, epoch+1) 48 | writer.add_scalars('orientation_loss', {'test': ori_loss}, epoch+1) 49 | writer.add_scalars('finger_loss', {'test': fin_loss}, epoch+1) 50 | writer.add_scalars('regularization_loss', {'test': reg_loss}, epoch+1) 51 | end_time = time.time() 52 | logger.info("Epoch {:04d} | Testing Time {:.2f} s | Avg Testing Loss {:.6f} | Avg EE Loss {:.6f} | Avg Vec Loss {:.6f} | Avg Col Loss {:.6f} | Avg Lim Loss {:.6f} | Avg Ori Loss {:.6f} | Avg Fin Loss {:.6f} | Avg Reg Loss {:.6f}".format(epoch+1, end_time-start_time, test_loss, ee_loss, vec_loss, col_loss, lim_loss, ori_loss, fin_loss, reg_loss)) 53 | 54 | return test_loss 55 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Batch 3 | from models.loss import calculate_all_loss 4 | import time 5 | 6 | def train_epoch(model, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, optimizer, dataloader, target_skeleton, epoch, logger, log_interval, writer, device, z_all=None, ang_all=None): 7 | logger.info("Training Epoch {}".format(epoch+1).center(60, '-')) 8 | start_time = time.time() 9 | 10 | model.train() 11 | all_losses = [] 12 | ee_losses = [] 13 | vec_losses = [] 14 | col_losses = [] 15 | lim_losses = [] 16 | ori_losses = [] 17 | fin_losses = [] 18 | reg_losses = [] 19 | 20 | for batch_idx, data_list in enumerate(dataloader): 21 | for target_idx, target in enumerate(target_skeleton): 22 | # zero gradient 23 | optimizer.zero_grad() 24 | 25 | # fetch target 26 | target_list = [target for data in data_list] 27 | 28 | # forward 29 | if z_all is not None: 30 | z = z_all[batch_idx] 31 | _, target_ang, target_pos, target_rot, target_global_pos, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos = model.decode(z, Batch.from_data_list(target_list).to(device)) 32 | else: 33 | z, target_ang, target_pos, target_rot, target_global_pos, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos = model(Batch.from_data_list(data_list).to(device), Batch.from_data_list(target_list).to(device)) 34 | 35 | # calculate all loss 36 | loss = calculate_all_loss(data_list, target_list, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, 37 | z, target_ang, target_pos, target_rot, target_global_pos, l_hand_pos, r_hand_pos, all_losses, ee_losses, vec_losses, col_losses, lim_losses, ori_losses, fin_losses, reg_losses) 38 | 39 | # backward 40 | loss.backward() 41 | 42 | # gradient clipping 43 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10) 44 | 45 | # optimize 46 | optimizer.step() 47 | 48 | # log 49 | if (batch_idx + 1) % log_interval == 0: 50 | logger.info("epoch {:04d} | iteration {:05d} | EE {:.6f} | Vec {:.6f} | Col {:.6f} | Lim {:.6f} | Ori {:.6f} | Fin {:.6f} | Reg {:.6f}".format(epoch+1, batch_idx+1, ee_losses[-1], vec_losses[-1], col_losses[-1], lim_losses[-1], ori_losses[-1], fin_losses[-1], reg_losses[-1])) 51 | 52 | # Compute average loss 53 | train_loss = sum(all_losses)/len(all_losses) 54 | ee_loss = sum(ee_losses)/len(ee_losses) 55 | vec_loss = sum(vec_losses)/len(vec_losses) 56 | col_loss = sum(col_losses)/len(col_losses) 57 | lim_loss = sum(lim_losses)/len(lim_losses) 58 | ori_loss = sum(ori_losses)/len(ori_losses) 59 | fin_loss = sum(fin_losses)/len(fin_losses) 60 | reg_loss = sum(reg_losses)/len(reg_losses) 61 | # Log 62 | writer.add_scalars('training_loss', {'train': train_loss}, epoch+1) 63 | writer.add_scalars('end_effector_loss', {'train': ee_loss}, epoch+1) 64 | writer.add_scalars('vector_loss', {'train': vec_loss}, epoch+1) 65 | writer.add_scalars('collision_loss', {'train': col_loss}, epoch+1) 66 | writer.add_scalars('joint_limit_loss', {'train': lim_loss}, epoch+1) 67 | writer.add_scalars('orientation_loss', {'train': ori_loss}, epoch+1) 68 | writer.add_scalars('finger_loss', {'train': fin_loss}, epoch+1) 69 | writer.add_scalars('regularization_loss', {'train': reg_loss}, epoch+1) 70 | end_time = time.time() 71 | logger.info("Epoch {:04d} | Training Time {:.2f} s | Avg Training Loss {:.6f} | Avg EE Loss {:.6f} | Avg Vec Loss {:.6f} | Avg Col Loss {:.6f} | Avg Lim Loss {:.6f} | Avg Ori 
Loss {:.6f} | Avg Fin Loss {:.6f} | Avg Reg Loss {:.6f}".format(epoch+1, end_time-start_time, train_loss, ee_loss, vec_loss, col_loss, lim_loss, ori_loss, fin_loss, reg_loss)) 72 | 73 | return train_loss 74 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch_geometric.transforms as transforms 5 | from torch_geometric.data import DataListLoader 6 | from tensorboardX import SummaryWriter 7 | 8 | from models import model 9 | from models.loss import CollisionLoss, JointLimitLoss, RegLoss 10 | import dataset 11 | from dataset import Normalize 12 | from train import train_epoch 13 | from test import test_epoch 14 | from utils.config import cfg 15 | from utils.util import create_folder 16 | 17 | import os 18 | import logging 19 | import argparse 20 | from datetime import datetime 21 | 22 | # Argument parse 23 | parser = argparse.ArgumentParser(description='Command line arguments') 24 | parser.add_argument('--cfg', default='configs/train/yumi.yaml', type=str, help='Path to configuration file') 25 | args = parser.parse_args() 26 | 27 | # Configurations parse 28 | cfg.merge_from_file(args.cfg) 29 | cfg.freeze() 30 | print(cfg) 31 | 32 | # Create folder 33 | create_folder(cfg.OTHERS.SAVE) 34 | create_folder(cfg.OTHERS.LOG) 35 | create_folder(cfg.OTHERS.SUMMARY) 36 | 37 | # Create logger & tensorboard writer 38 | logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[logging.FileHandler(os.path.join(cfg.OTHERS.LOG, "{:%Y-%m-%d_%H-%M-%S}.log".format(datetime.now()))), logging.StreamHandler()]) 39 | logger = logging.getLogger() 40 | writer = SummaryWriter(os.path.join(cfg.OTHERS.SUMMARY, "{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))) 41 | 42 | # Device setting 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | 45 | 46 | if __name__ == '__main__': 47 | # Load data 48 | pre_transform = transforms.Compose([Normalize()]) 49 | train_set = getattr(dataset, cfg.DATASET.TRAIN.SOURCE_NAME)(root=cfg.DATASET.TRAIN.SOURCE_PATH, pre_transform=pre_transform) 50 | train_loader = DataListLoader(train_set, batch_size=cfg.HYPER.BATCH_SIZE, shuffle=True, num_workers=16, pin_memory=True) 51 | train_target = sorted([target for target in getattr(dataset, cfg.DATASET.TRAIN.TARGET_NAME)(root=cfg.DATASET.TRAIN.TARGET_PATH)], key=lambda target : target.skeleton_type) 52 | test_set = getattr(dataset, cfg.DATASET.TEST.SOURCE_NAME)(root=cfg.DATASET.TEST.SOURCE_PATH, pre_transform=pre_transform) 53 | test_loader = DataListLoader(test_set, batch_size=cfg.HYPER.BATCH_SIZE, shuffle=True, num_workers=16, pin_memory=True) 54 | test_target = sorted([target for target in getattr(dataset, cfg.DATASET.TEST.TARGET_NAME)(root=cfg.DATASET.TEST.TARGET_PATH)], key=lambda target : target.skeleton_type) 55 | 56 | # Create model 57 | model = getattr(model, cfg.MODEL.NAME)().to(device) 58 | 59 | # Load checkpoint 60 | if cfg.MODEL.CHECKPOINT is not None: 61 | model.load_state_dict(torch.load(cfg.MODEL.CHECKPOINT)) 62 | 63 | # Create loss criterion 64 | # end effector loss 65 | ee_criterion = nn.MSELoss() if cfg.LOSS.EE else None 66 | # vector similarity loss 67 | vec_criterion = nn.MSELoss() if cfg.LOSS.VEC else None 68 | # collision loss 69 | col_criterion = CollisionLoss(cfg.LOSS.COL_THRESHOLD) if cfg.LOSS.COL else None 70 | # joint limit loss 71 | lim_criterion = JointLimitLoss() if 
cfg.LOSS.LIM else None 72 | # end effector orientation loss 73 | ori_criterion = nn.MSELoss() if cfg.LOSS.ORI else None 74 | # finger similarity loss 75 | fin_criterion = nn.MSELoss() if cfg.LOSS.FIN else None 76 | # regularization loss 77 | reg_criterion = RegLoss() if cfg.LOSS.REG else None 78 | 79 | # Create optimizer 80 | optimizer = optim.Adam(model.parameters(), lr=cfg.HYPER.LEARNING_RATE) 81 | 82 | best_loss = float('Inf') 83 | 84 | for epoch in range(cfg.HYPER.EPOCHS): 85 | # Start training 86 | train_loss = train_epoch(model, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, 87 | optimizer, train_loader, train_target, epoch, logger, cfg.OTHERS.LOG_INTERVAL, writer, device) 88 | 89 | # Start testing 90 | test_loss = test_epoch(model, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, 91 | test_loader, test_target, epoch, logger, cfg.OTHERS.LOG_INTERVAL, writer, device) 92 | 93 | # Save model 94 | if test_loss < best_loss: 95 | best_loss = test_loss 96 | torch.save(model.state_dict(), os.path.join(cfg.OTHERS.SAVE, "best_model_epoch_{:04d}.pth".format(epoch))) 97 | logger.info("Epoch {} Model Saved".format(epoch+1).center(60, '-')) 98 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn.conv import MessagePassing 5 | 6 | import os, inspect, sys 7 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 8 | sys.path.insert(0, currentdir) 9 | from kinematics import ForwardKinematicsURDF, ForwardKinematicsAxis 10 | 11 | 12 | class SpatialBasicBlock(MessagePassing): 13 | def __init__(self, in_channels, out_channels, edge_channels, aggr='add', batch_norm=False, bias=True, **kwargs): 14 | super(SpatialBasicBlock, self).__init__(aggr=aggr, **kwargs) 15 | self.batch_norm = batch_norm 16 | # network architecture 17 | self.lin = nn.Linear(2*in_channels + edge_channels, out_channels, bias=bias) 18 | self.upsample = nn.Linear(in_channels, out_channels, bias=bias) 19 | self.bn = nn.BatchNorm1d(out_channels) 20 | 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | self.lin.reset_parameters() 25 | self.bn.reset_parameters() 26 | 27 | def forward(self, x, edge_index, edge_attr=None): 28 | if isinstance(x, torch.Tensor): 29 | x = (x, x) 30 | 31 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr) 32 | out = self.bn(out) if self.batch_norm else out 33 | out += self.upsample(x[1]) 34 | return out 35 | 36 | def message(self, x_i, x_j, edge_attr): 37 | z = torch.cat([x_i, x_j, edge_attr], dim=-1) 38 | return F.leaky_relu(self.lin(z)) 39 | 40 | 41 | class Encoder(torch.nn.Module): 42 | def __init__(self, channels, dim): 43 | super(Encoder, self).__init__() 44 | self.conv1 = SpatialBasicBlock(in_channels=channels, out_channels=16, edge_channels=dim) 45 | self.conv2 = SpatialBasicBlock(in_channels=16, out_channels=32, edge_channels=dim) 46 | self.conv3 = SpatialBasicBlock(in_channels=32, out_channels=64, edge_channels=dim) 47 | 48 | def forward(self, x, edge_index, edge_attr): 49 | """ 50 | Keyword arguments: 51 | x -- joint angles [num_nodes, num_node_features] 52 | edge_index -- edge index [2, num_edges] 53 | edge_attr -- edge features [num_edges, num_edge_features] 54 | """ 55 | out = self.conv1(x, 
edge_index, edge_attr) 56 | out = self.conv2(out, edge_index, edge_attr) 57 | out = self.conv3(out, edge_index, edge_attr) 58 | return out 59 | 60 | 61 | class Decoder(torch.nn.Module): 62 | def __init__(self, channels, dim): 63 | super(Decoder, self).__init__() 64 | self.conv1 = SpatialBasicBlock(in_channels=64+2, out_channels=32, edge_channels=dim) 65 | self.conv2 = SpatialBasicBlock(in_channels=32, out_channels=16, edge_channels=dim) 66 | self.conv3 = SpatialBasicBlock(in_channels=16, out_channels=channels, edge_channels=dim) 67 | 68 | def forward(self, x, edge_index, edge_attr, lower, upper): 69 | """ 70 | Keyword arguments: 71 | x -- joint angles [num_nodes, num_node_features] 72 | edge_index -- edge index [2, num_edges] 73 | edge_attr -- edge features [num_edges, num_edge_features] 74 | """ 75 | x = torch.cat([x, lower, upper], dim=1) 76 | out = self.conv1(x, edge_index, edge_attr) 77 | out = self.conv2(out, edge_index, edge_attr) 78 | out = self.conv3(out, edge_index, edge_attr).tanh() 79 | return out 80 | 81 | 82 | class ArmNet(torch.nn.Module): 83 | def __init__(self): 84 | super(ArmNet, self).__init__() 85 | self.encoder = Encoder(6, 3) 86 | self.transform = nn.Sequential( 87 | nn.Linear(6*64, 14*64), 88 | nn.Tanh(), 89 | ) 90 | self.decoder = Decoder(1, 6) 91 | self.fk = ForwardKinematicsURDF() 92 | 93 | def forward(self, data, target): 94 | return self.decode(self.encode(data), target) 95 | 96 | def encode(self, data): 97 | z = self.encoder(data.x, data.edge_index, data.edge_attr) 98 | z = self.transform(z.view(data.num_graphs, -1, 64).view(data.num_graphs, -1)).view(data.num_graphs, -1, 64).view(-1, 64) 99 | return z 100 | 101 | def decode(self, z, target): 102 | ang = self.decoder(z, target.edge_index, target.edge_attr, target.lower, target.upper) 103 | ang = target.lower + (target.upper - target.lower)*(ang + 1)/2 104 | pos, rot, global_pos = self.fk(ang, target.parent, target.offset, target.num_graphs) 105 | return z, ang, pos, rot, global_pos, None, None, None, None 106 | 107 | 108 | class HandNet(torch.nn.Module): 109 | def __init__(self): 110 | super(HandNet, self).__init__() 111 | self.encoder = Encoder(3, 3) 112 | self.transform = nn.Sequential( 113 | nn.Linear(17*64, 18*64), 114 | nn.Tanh(), 115 | ) 116 | self.decoder = Decoder(1, 6) 117 | self.fk = ForwardKinematicsAxis() 118 | 119 | def forward(self, data, target): 120 | return self.decode(self.encode(data), target) 121 | 122 | def encode(self, data): 123 | x = torch.cat([data.l_hand_x, data.r_hand_x], dim=0) 124 | edge_index = torch.cat([data.l_hand_edge_index, data.r_hand_edge_index+data.l_hand_x.size(0)], dim=1) 125 | edge_attr = torch.cat([data.l_hand_edge_attr, data.r_hand_edge_attr], dim=0) 126 | z = self.encoder(x, edge_index, edge_attr) 127 | z = self.transform(z.view(2*data.num_graphs, -1, 64).view(2*data.num_graphs, -1)).view(2*data.num_graphs, -1, 64).view(-1, 64) 128 | # l_hand_z = self.encoder(data.l_hand_x, data.l_hand_edge_index, data.l_hand_edge_attr) 129 | # l_hand_z = self.transform(l_hand_z.view(data.num_graphs, -1, 64).view(data.num_graphs, -1)).view(data.num_graphs, -1, 64).view(-1, 64) 130 | # r_hand_z = self.encoder(data.r_hand_x, data.r_hand_edge_index, data.r_hand_edge_attr) 131 | # r_hand_z = self.transform(r_hand_z.view(data.num_graphs, -1, 64).view(data.num_graphs, -1)).view(data.num_graphs, -1, 64).view(-1, 64) 132 | # z = torch.cat([l_hand_z, r_hand_z], dim=0) 133 | return z 134 | 135 | def decode(self, z, target): 136 | edge_index = torch.cat([target.hand_edge_index, 
target.hand_edge_index+z.size(0)//2], dim=1) 137 | edge_attr = torch.cat([target.hand_edge_attr, target.hand_edge_attr], dim=0) 138 | lower = torch.cat([target.hand_lower, target.hand_lower], dim=0) 139 | upper = torch.cat([target.hand_upper, target.hand_upper], dim=0) 140 | offset = torch.cat([target.hand_offset, target.hand_offset], dim=0) 141 | parent = torch.cat([target.hand_parent, target.hand_parent], dim=0) 142 | num_graphs = 2*target.num_graphs 143 | axis = torch.cat([target.hand_axis, target.hand_axis], dim=0) 144 | 145 | hand_ang = self.decoder(z, edge_index, edge_attr, lower, upper) 146 | hand_ang = lower + (upper - lower)*(hand_ang + 1)/2 147 | hand_pos, _, _ = self.fk(hand_ang, parent, offset, num_graphs, axis) 148 | 149 | half = hand_ang.size(0)//2 150 | l_hand_ang, r_hand_ang = hand_ang[:half, :], hand_ang[half:, :] 151 | l_hand_pos, r_hand_pos = hand_pos[:half, :], hand_pos[half:, :] 152 | # half = z.shape[0] // 2 153 | # l_hand_z, r_hand_z = z[:half, :], z[half:, :] 154 | 155 | # l_hand_ang = self.decoder(l_hand_z, target.hand_edge_index, target.hand_edge_attr, target.hand_lower, target.hand_upper) 156 | # l_hand_ang = target.hand_lower + (target.hand_upper - target.hand_lower)*(l_hand_ang + 1)/2 157 | # l_hand_pos, _, _ = self.fk(l_hand_ang, target.hand_parent, target.hand_offset, target.num_graphs, target.hand_axis) 158 | 159 | # r_hand_ang = self.decoder(r_hand_z, target.hand_edge_index, target.hand_edge_attr, target.hand_lower, target.hand_upper) 160 | # r_hand_ang = target.hand_lower + (target.hand_upper - target.hand_lower)*(r_hand_ang + 1)/2 161 | # r_hand_pos, _, _ = self.fk(r_hand_ang, target.hand_parent, target.hand_offset, target.num_graphs, target.hand_axis) 162 | return z, None, None, None, None, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos 163 | 164 | 165 | class YumiNet(torch.nn.Module): 166 | def __init__(self): 167 | super(YumiNet, self).__init__() 168 | self.arm_net = ArmNet() 169 | self.hand_net = HandNet() 170 | 171 | def forward(self, data, target): 172 | return self.decode(self.encode(data), target) 173 | 174 | def encode(self, data): 175 | arm_z = self.arm_net.encode(data) 176 | hand_z = self.hand_net.encode(data) 177 | z = torch.cat([arm_z, hand_z], dim=0) 178 | return z 179 | 180 | def decode(self, z, target): 181 | half = target.num_nodes 182 | arm_z, hand_z = z[:half, :], z[half:, :] 183 | _, ang, pos, rot, global_pos, _, _, _, _ = self.arm_net.decode(arm_z, target) 184 | _, _, _, _, _, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos = self.hand_net.decode(hand_z, target) 185 | return z, ang, pos, rot, global_pos, l_hand_ang, l_hand_pos, r_hand_ang, r_hand_pos 186 | 187 | -------------------------------------------------------------------------------- /models/kinematics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ 5 | Forward Kinematics for URDF 6 | """ 7 | class ForwardKinematicsURDF(nn.Module): 8 | def __init__(self): 9 | super(ForwardKinematicsURDF, self).__init__() 10 | 11 | def forward(self, x, parent, offset, num_graphs, axis='z', order='xyz'): 12 | """ 13 | x -- joint angles [num_graphs*num_nodes, 1] 14 | parent -- node parent [num_graphs*num_nodes] 15 | offset -- node origin(xyzrpy) [num_graphs*num_nodes, 6] 16 | num_graphs -- number of graphs 17 | axis -- rotation axis for rotation x 18 | order -- rotation order for init rotation 19 | """ 20 | x = x.view(num_graphs, -1) # [batch_size, num_nodes] 21 | parent = parent.view(num_graphs, -1)[0] # 
[num_nodes] the same batch, the same topology 22 | offset = offset.view(num_graphs, -1, 6) # [batch_size, num_nodes, 6] 23 | xyz = offset[:, :, :3] # [batch_size, num_nodes, 3] 24 | rpy = offset[:, :, 3:] # [batch_size, num_nodes, 3] 25 | 26 | positions = torch.empty(x.shape[0], x.shape[1], 3, device=x.device) # [batch_size, num_nodes, 3] 27 | global_positions = torch.empty(x.shape[0], x.shape[1], 3, device=x.device) # [batch_size, num_nodes, 3] 28 | rot_matrices = torch.empty(x.shape[0], x.shape[1], 3, 3, device=x.device) # [batch_size, num_nodes, 3, 3] 29 | transform = self.transform_from_axis(x, axis) # [batch_size, num_nodes, 3, 3] 30 | rpy_transform = self.transform_from_euler(rpy, order) # [batch_size, num_nodes, 3, 3] 31 | 32 | # iterate all nodes 33 | for node_idx in range(x.shape[1]): 34 | # serach parent 35 | parent_idx = parent[node_idx] 36 | 37 | # position 38 | if parent_idx != -1: 39 | positions[:, node_idx, :] = torch.bmm(rot_matrices[:, parent_idx, :, :], xyz[:, node_idx, :].unsqueeze(2)).squeeze() + positions[:, parent_idx, :] 40 | global_positions[:, node_idx, :] = torch.bmm(rot_matrices[:, parent_idx, :, :], xyz[:, node_idx, :].unsqueeze(2)).squeeze() + global_positions[:, parent_idx, :] 41 | rot_matrices[:, node_idx, :, :] = torch.bmm(rot_matrices[:, parent_idx, :, :].clone(), torch.bmm(rpy_transform[:, node_idx, :, :], transform[:, node_idx, :, :])) 42 | else: 43 | positions[:, node_idx, :] = torch.zeros(3) # xyz[:, node_idx, :] 44 | global_positions[:, node_idx, :] = xyz[:, node_idx, :] 45 | rot_matrices[:, node_idx, :, :] = torch.bmm(rpy_transform[:, node_idx, :, :], transform[:, node_idx, :, :]) 46 | 47 | return positions.view(-1, 3), rot_matrices.view(-1, 3, 3), global_positions.view(-1, 3) 48 | 49 | @staticmethod 50 | def transform_from_euler(rotation, order): 51 | transform = torch.matmul(ForwardKinematicsURDF.transform_from_axis(rotation[..., 2], order[2]), 52 | ForwardKinematicsURDF.transform_from_axis(rotation[..., 1], order[1])) 53 | transform = torch.matmul(transform, 54 | ForwardKinematicsURDF.transform_from_axis(rotation[..., 0], order[0])) 55 | return transform 56 | 57 | @staticmethod 58 | def transform_from_axis(euler, axis): 59 | transform = torch.empty(euler.shape[0:2] + (3, 3), device=euler.device) # [batch_size, num_nodes, 3, 3] 60 | cos = torch.cos(euler) 61 | sin = torch.sin(euler) 62 | cord = ord(axis) - ord('x') 63 | 64 | transform[..., cord, :] = transform[..., :, cord] = 0 65 | transform[..., cord, cord] = 1 66 | 67 | if axis == 'x': 68 | transform[..., 1, 1] = transform[..., 2, 2] = cos 69 | transform[..., 1, 2] = -sin 70 | transform[..., 2, 1] = sin 71 | if axis == 'y': 72 | transform[..., 0, 0] = transform[..., 2, 2] = cos 73 | transform[..., 0, 2] = sin 74 | transform[..., 2, 0] = -sin 75 | if axis == 'z': 76 | transform[..., 0, 0] = transform[..., 1, 1] = cos 77 | transform[..., 0, 1] = -sin 78 | transform[..., 1, 0] = sin 79 | 80 | return transform 81 | 82 | 83 | """ 84 | Forward Kinematics with Different Axes 85 | """ 86 | class ForwardKinematicsAxis(nn.Module): 87 | def __init__(self): 88 | super(ForwardKinematicsAxis, self).__init__() 89 | 90 | def forward(self, x, parent, offset, num_graphs, axis, order='xyz'): 91 | """ 92 | x -- joint angles [num_graphs*num_nodes, 1] 93 | parent -- node parent [num_graphs*num_nodes] 94 | offset -- node origin(xyzrpy) [num_graphs*num_nodes, 6] 95 | num_graphs -- number of graphs 96 | axis -- rotation axis for rotation x 97 | order -- rotation order for init rotation 98 | """ 99 | x = 
x.view(num_graphs, -1) # [batch_size, num_nodes] 100 | parent = parent.view(num_graphs, -1)[0] # [num_nodes] the same batch, the same topology 101 | axis = axis.view(num_graphs, -1, 3)[0] # [num_nodes, 3] the same batch, the same topology 102 | axis_norm = torch.norm(axis, dim=-1) 103 | # print(x.shape, axis.shape) 104 | x = x * axis_norm # filter no rotation node 105 | offset = offset.view(num_graphs, -1, 6) # [batch_size, num_nodes, 6] 106 | xyz = offset[:, :, :3] # [batch_size, num_nodes, 3] 107 | rpy = offset[:, :, 3:] # [batch_size, num_nodes, 3] 108 | 109 | positions = torch.empty(x.shape[0], x.shape[1], 3, device=x.device) # [batch_size, num_nodes, 3] 110 | global_positions = torch.empty(x.shape[0], x.shape[1], 3, device=x.device) # [batch_size, num_nodes, 3] 111 | rot_matrices = torch.empty(x.shape[0], x.shape[1], 3, 3, device=x.device) # [batch_size, num_nodes, 3, 3] 112 | transform = self.transform_from_multiple_axis(x, axis) # [batch_size, num_nodes, 3, 3] 113 | rpy_transform = self.transform_from_euler(rpy, order) # [batch_size, num_nodes, 3, 3] 114 | 115 | # iterate all nodes 116 | for node_idx in range(x.shape[1]): 117 | # serach parent 118 | parent_idx = parent[node_idx] 119 | 120 | # position 121 | if parent_idx != -1: 122 | positions[:, node_idx, :] = torch.bmm(rot_matrices[:, parent_idx, :, :], xyz[:, node_idx, :].unsqueeze(2)).squeeze() + positions[:, parent_idx, :] 123 | global_positions[:, node_idx, :] = torch.bmm(rot_matrices[:, parent_idx, :, :], xyz[:, node_idx, :].unsqueeze(2)).squeeze() + global_positions[:, parent_idx, :] 124 | rot_matrices[:, node_idx, :, :] = torch.bmm(rot_matrices[:, parent_idx, :, :].clone(), torch.bmm(rpy_transform[:, node_idx, :, :], transform[:, node_idx, :, :])) 125 | else: 126 | positions[:, node_idx, :] = torch.zeros(3) # xyz[:, node_idx, :] 127 | global_positions[:, node_idx, :] = xyz[:, node_idx, :] 128 | rot_matrices[:, node_idx, :, :] = torch.bmm(rpy_transform[:, node_idx, :, :], transform[:, node_idx, :, :]) 129 | 130 | return positions.view(-1, 3), rot_matrices.view(-1, 3, 3), global_positions.view(-1, 3) 131 | 132 | @staticmethod 133 | def transform_from_euler(rotation, order): 134 | transform = torch.matmul(ForwardKinematicsAxis.transform_from_single_axis(rotation[..., 2], order[2]), 135 | ForwardKinematicsAxis.transform_from_single_axis(rotation[..., 1], order[1])) 136 | transform = torch.matmul(transform, 137 | ForwardKinematicsAxis.transform_from_single_axis(rotation[..., 0], order[0])) 138 | return transform 139 | 140 | @staticmethod 141 | def transform_from_single_axis(euler, axis): 142 | transform = torch.empty(euler.shape[0:2] + (3, 3), device=euler.device) # [batch_size, num_nodes, 3, 3] 143 | cos = torch.cos(euler) 144 | sin = torch.sin(euler) 145 | cord = ord(axis) - ord('x') 146 | 147 | transform[..., cord, :] = transform[..., :, cord] = 0 148 | transform[..., cord, cord] = 1 149 | 150 | if axis == 'x': 151 | transform[..., 1, 1] = transform[..., 2, 2] = cos 152 | transform[..., 1, 2] = -sin 153 | transform[..., 2, 1] = sin 154 | if axis == 'y': 155 | transform[..., 0, 0] = transform[..., 2, 2] = cos 156 | transform[..., 0, 2] = sin 157 | transform[..., 2, 0] = -sin 158 | if axis == 'z': 159 | transform[..., 0, 0] = transform[..., 1, 1] = cos 160 | transform[..., 0, 1] = -sin 161 | transform[..., 1, 0] = sin 162 | 163 | return transform 164 | 165 | @staticmethod 166 | def transform_from_multiple_axis(euler, axis): 167 | transform = torch.empty(euler.shape[0:2] + (3, 3), device=euler.device) # [batch_size, num_nodes, 
3, 3] 168 | cos = torch.cos(euler) 169 | sin = torch.sin(euler) 170 | n1 = axis[..., 0] 171 | n2 = axis[..., 1] 172 | n3 = axis[..., 2] 173 | 174 | transform[..., 0, 0] = cos + n1 * n1 * (1 - cos) 175 | transform[..., 0, 1] = n1 * n2 * (1 - cos) - n3 * sin 176 | transform[..., 0, 2] = n1 * n3 * (1 - cos) + n2 * sin 177 | transform[..., 1, 0] = n1 * n2 * (1 - cos) + n3 * sin 178 | transform[..., 1, 1] = cos + n2 * n2 * (1 - cos) 179 | transform[..., 1, 2] = n2 * n3 * (1 - cos) - n1 * sin 180 | transform[..., 2, 0] = n1 * n3 * (1 - cos) - n2 * sin 181 | transform[..., 2, 1] = n2 * n3 * (1 - cos) + n1 * sin 182 | transform[..., 2, 2] = cos + n3 * n3 * (1 - cos) 183 | 184 | return transform -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch_geometric.transforms as transforms 5 | from torch_geometric.data import Batch, DataListLoader 6 | from tensorboardX import SummaryWriter 7 | import matplotlib.pyplot as plt 8 | import matplotlib.animation as animation 9 | import mpl_toolkits.mplot3d.axes3d as p3 10 | import numpy as np 11 | import h5py 12 | import argparse 13 | import logging 14 | import time 15 | import os 16 | import copy 17 | from datetime import datetime 18 | 19 | import dataset 20 | from dataset import Normalize, parse_h5, parse_h5_hand, parse_all 21 | from models import model 22 | from models.loss import CollisionLoss, JointLimitLoss, RegLoss 23 | from train import train_epoch 24 | from utils.config import cfg 25 | from utils.util import create_folder 26 | 27 | # Argument parse 28 | parser = argparse.ArgumentParser(description='Inference with trained model') 29 | parser.add_argument('--cfg', default='configs/inference/yumi.yaml', type=str, help='Path to configuration file') 30 | args = parser.parse_args() 31 | 32 | # Configurations parse 33 | cfg.merge_from_file(args.cfg) 34 | cfg.freeze() 35 | print(cfg) 36 | 37 | # Create folder 38 | create_folder(cfg.OTHERS.LOG) 39 | create_folder(cfg.OTHERS.SUMMARY) 40 | 41 | # Create logger & tensorboard writer 42 | logging.basicConfig(level=logging.INFO, format="%(message)s", handlers=[logging.FileHandler(os.path.join(cfg.OTHERS.LOG, "{:%Y-%m-%d_%H-%M-%S}.log".format(datetime.now()))), logging.StreamHandler()]) 43 | logger = logging.getLogger() 44 | writer = SummaryWriter(os.path.join(cfg.OTHERS.SUMMARY, "{:%Y-%m-%d_%H-%M-%S}".format(datetime.now()))) 45 | 46 | # Device setting 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | 49 | if __name__ == '__main__': 50 | # Load data 51 | pre_transform = transforms.Compose([Normalize()]) 52 | if cfg.INFERENCE.MOTION.KEY: 53 | # inference single key 54 | print('Inference single key {}'.format(cfg.INFERENCE.MOTION.KEY)) 55 | # test_data, l_hand_angle, r_hand_angle = parse_h5(filename=cfg.INFERENCE.MOTION.SOURCE, selected_key=cfg.INFERENCE.MOTION.KEY) 56 | # test_data = parse_h5_hand(filename=cfg.INFERENCE.MOTION.SOURCE, selected_key=cfg.INFERENCE.MOTION.KEY) 57 | test_data = parse_all(filename=cfg.INFERENCE.MOTION.SOURCE, selected_key=cfg.INFERENCE.MOTION.KEY) 58 | test_data = [pre_transform(data) for data in test_data] 59 | indices = [idx for idx in range(0, len(test_data), cfg.HYPER.BATCH_SIZE)] 60 | test_loader = [test_data]#[test_data[idx: idx+cfg.HYPER.BATCH_SIZE] for idx in indices] 61 | hf = h5py.File(os.path.join(cfg.INFERENCE.H5.PATH, 'source.h5'), 'w') 
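# this block dumps the pre-transformed source key-point positions of the selected motion to
# source.h5 (path taken from cfg.INFERENCE.H5.PATH); judging from the slicing of source_pos
# ([T, num_joints, 3]) below, the first three joints appear to belong to the left arm and the
# remaining ones to the right arm. A minimal read-back sketch, assuming only h5py:
#   with h5py.File('./saved/h5/source.h5', 'r') as f:
#       l_pos = f['group1/l_joint_pos'][:]   # [T, 3, 3]
#       r_pos = f['group1/r_joint_pos'][:]   # [T, num_joints - 3, 3]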
62 | g1 = hf.create_group('group1') 63 | source_pos = torch.stack([data.pos for data in test_data], dim=0) 64 | g1.create_dataset('l_joint_pos', data=source_pos[:, :3]) 65 | g1.create_dataset('r_joint_pos', data=source_pos[:, 3:]) 66 | hf.close() 67 | print('Source H5 file saved!') 68 | else: 69 | # inference all 70 | print('Inference all') 71 | test_set = getattr(dataset, cfg.DATASET.TEST.SOURCE_NAME)(root=cfg.DATASET.TEST.SOURCE_PATH, pre_transform=pre_transform) 72 | test_loader = DataListLoader(test_set, batch_size=cfg.HYPER.BATCH_SIZE, shuffle=False, num_workers=16, pin_memory=True) 73 | test_target = sorted([target for target in getattr(dataset, cfg.DATASET.TEST.TARGET_NAME)(root=cfg.DATASET.TEST.TARGET_PATH)], key=lambda target : target.skeleton_type) 74 | 75 | # Create model 76 | model = getattr(model, cfg.MODEL.NAME)().to(device) 77 | 78 | # Load checkpoint 79 | if cfg.MODEL.CHECKPOINT is not None: 80 | model.load_state_dict(torch.load(cfg.MODEL.CHECKPOINT)) 81 | 82 | # training set z mean & std 83 | # train_set = getattr(dataset, "SignAll")(root="./data/source/sign-all/train", pre_transform=pre_transform) 84 | # train_loader = DataListLoader(train_set, batch_size=cfg.HYPER.BATCH_SIZE, shuffle=True, num_workers=16, pin_memory=True) 85 | # train_target = sorted([target for target in getattr(dataset, "YumiAll")(root="./data/target/yumi-all")], key=lambda target : target.skeleton_type) 86 | # model.eval() 87 | # z_train = [] 88 | # for batch_idx, data_list in enumerate(train_loader): 89 | # for target_idx, target in enumerate(train_target): 90 | # # fetch target 91 | # target_list = [target for data in data_list] 92 | # # forward 93 | # z = model.encode(Batch.from_data_list(data_list).to(device)).detach() 94 | # z_train.append(z) 95 | # z_train = torch.cat(z_train, dim=0) 96 | # mean = z_train.mean(0) 97 | # std = z_train.std(0) 98 | # print(z_train.shape, mean.shape, std.shape) 99 | # print(mean, std) 100 | 101 | # store initial z 102 | encode_start_time = time.time() 103 | model.eval() 104 | z_all = [] 105 | for batch_idx, data_list in enumerate(test_loader): 106 | for target_idx, target in enumerate(test_target): 107 | # forward 108 | z = model.encode(Batch.from_data_list(data_list).to(device)).detach() 109 | # target_batch = Batch.from_data_list([target for data in data_list]) 110 | # target_nodes = target_batch.x.size(0)+2*target_batch.hand_x.size(0) 111 | # z = torch.empty(target_nodes, 64).normal_(mean=0, std=0.1).to(device) 112 | # z = torch.zeros(target_nodes, 64).to(device) 113 | # z = torch.stack([torch.normal(mean=mean, std=std) for _ in range(target_nodes)], dim=0).to(device) 114 | z.requires_grad = True 115 | z_all.append(z) 116 | encode_end_time = time.time() 117 | print('encode time {} ms'.format((encode_end_time - encode_start_time)*1000)) 118 | # Create loss criterion 119 | # end effector loss 120 | ee_criterion = nn.MSELoss() if cfg.LOSS.EE else None 121 | # vector similarity loss 122 | vec_criterion = nn.MSELoss() if cfg.LOSS.VEC else None 123 | # collision loss 124 | col_criterion = CollisionLoss(cfg.LOSS.COL_THRESHOLD) if cfg.LOSS.COL else None 125 | # joint limit loss 126 | lim_criterion = JointLimitLoss() if cfg.LOSS.LIM else None 127 | # end effector orientation loss 128 | ori_criterion = nn.MSELoss() if cfg.LOSS.ORI else None 129 | # finger similarity loss 130 | fin_criterion = nn.MSELoss() if cfg.LOSS.FIN else None 131 | # regularization loss 132 | reg_criterion = RegLoss() if cfg.LOSS.REG else None 133 | 134 | # Create optimizer 135 | optimizer = 
optim.Adam(z_all, lr=cfg.HYPER.LEARNING_RATE) 136 | 137 | best_loss = float('Inf') 138 | best_z_all = copy.deepcopy(z_all) 139 | best_cnt = 0 140 | start_time = time.time() 141 | 142 | # latent optimization 143 | decode_start_time = time.time() 144 | for epoch in range(cfg.HYPER.EPOCHS): 145 | train_loss = train_epoch(model, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, optimizer, test_loader, test_target, epoch, logger, cfg.OTHERS.LOG_INTERVAL, writer, device, z_all) 146 | if cfg.INFERENCE.MOTION.KEY: 147 | # Save model 148 | if train_loss > best_loss: 149 | best_cnt += 1 150 | else: 151 | best_cnt = 0 152 | best_loss = train_loss 153 | best_z_all = copy.deepcopy(z_all) 154 | if best_cnt == 5: 155 | logger.info("Interation Finished") 156 | break 157 | print(best_cnt) 158 | decode_end_time = time.time() 159 | print('decode time {} ms'.format((decode_end_time - decode_start_time)*1000)) 160 | if cfg.INFERENCE.MOTION.KEY: 161 | # store final results 162 | model.eval() 163 | pos_all = [] 164 | ang_all = [] 165 | l_hand_ang_all = [] 166 | r_hand_ang_all = [] 167 | for batch_idx, data_list in enumerate(test_loader): 168 | for target_idx, target in enumerate(test_target): 169 | # fetch target 170 | target_list = [target for data in data_list] 171 | # fetch z 172 | z = best_z_all[batch_idx] 173 | # forward 174 | _, target_ang, _, _, target_global_pos, l_hand_ang, _, r_hand_ang, _ = model.decode(z, Batch.from_data_list(target_list).to(z.device)) 175 | 176 | if target_global_pos is not None and target_ang is not None: 177 | pos_all.append(target_global_pos) 178 | ang_all.append(target_ang) 179 | if l_hand_ang is not None and r_hand_ang is not None: 180 | l_hand_ang_all.append(l_hand_ang) 181 | r_hand_ang_all.append(r_hand_ang) 182 | 183 | if cfg.INFERENCE.H5.BOOL: 184 | hf = h5py.File(os.path.join(cfg.INFERENCE.H5.PATH, 'inference.h5'), 'w') 185 | g1 = hf.create_group('group1') 186 | if pos_all and ang_all: 187 | pos = torch.cat(pos_all, dim=0).view(len(test_data), -1, 3).detach().cpu().numpy() # [T, joint_num, xyz] 188 | ang = torch.cat(ang_all, dim=0).view(len(test_data), -1).detach().cpu().numpy() 189 | g1.create_dataset('l_joint_pos', data=pos[:, :7]) 190 | g1.create_dataset('r_joint_pos', data=pos[:, 7:]) 191 | g1.create_dataset('l_joint_angle', data=ang[:, :7]) 192 | g1.create_dataset('r_joint_angle', data=ang[:, 7:]) 193 | # g1.create_dataset('l_glove_angle', data=l_hand_angle) 194 | # g1.create_dataset('r_glove_angle', data=r_hand_angle) 195 | if l_hand_ang_all and r_hand_ang_all: 196 | l_hand_angle = torch.cat(l_hand_ang_all, dim=0).view(len(test_data), -1).detach().cpu().numpy() 197 | r_hand_angle = torch.cat(r_hand_ang_all, dim=0).view(len(test_data), -1).detach().cpu().numpy() 198 | # remove zeros 199 | l_hand_angle = np.concatenate([l_hand_angle[:,1:3],l_hand_angle[:,4:6],l_hand_angle[:,7:9],l_hand_angle[:,10:12],l_hand_angle[:,13:17]], axis=1) 200 | r_hand_angle = np.concatenate([r_hand_angle[:,1:3],r_hand_angle[:,4:6],r_hand_angle[:,7:9],r_hand_angle[:,10:12],r_hand_angle[:,13:17]], axis=1) 201 | g1.create_dataset('l_glove_angle', data=l_hand_angle) 202 | g1.create_dataset('r_glove_angle', data=r_hand_angle) 203 | hf.close() 204 | print('Target H5 file saved!') 205 | -------------------------------------------------------------------------------- /utils/urdf2graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.data import 
Data 4 | from urdfpy import URDF, matrix_to_xyz_rpy 5 | import math 6 | 7 | 8 | """ 9 | convert Yumi URDF to graph 10 | """ 11 | def yumi2graph(urdf_file, cfg): 12 | # load URDF 13 | robot = URDF.load(urdf_file) 14 | 15 | # parse joint params 16 | joints = {} 17 | for joint in robot.joints: 18 | # joint atributes 19 | joints[joint.name] = {'type': joint.joint_type, 'axis': joint.axis, 20 | 'parent': joint.parent, 'child': joint.child, 21 | 'origin': matrix_to_xyz_rpy(joint.origin), 22 | 'lower': joint.limit.lower if joint.limit else 0, 23 | 'upper': joint.limit.upper if joint.limit else 0} 24 | 25 | # debug msg 26 | # for name, attr in joints.items(): 27 | # print(name, attr) 28 | 29 | # skeleton type & topology type 30 | skeleton_type = 0 31 | topology_type = 0 32 | 33 | # collect edge index & edge feature 34 | joints_name = cfg['joints_name'] 35 | joints_index = {name: i for i, name in enumerate(joints_name)} 36 | edge_index = [] 37 | edge_attr = [] 38 | for edge in cfg['edges']: 39 | parent, child = edge 40 | # add edge index 41 | edge_index.append(torch.LongTensor([joints_index[parent], joints_index[child]])) 42 | # add edge attr 43 | edge_attr.append(torch.Tensor(joints[child]['origin'])) 44 | edge_index = torch.stack(edge_index, dim=0) 45 | edge_index = edge_index.permute(1, 0) 46 | edge_attr = torch.stack(edge_attr, dim=0) 47 | # print(edge_index, edge_attr, edge_index.shape, edge_attr.shape) 48 | 49 | # number of nodes 50 | num_nodes = len(joints_name) 51 | 52 | # end effector mask 53 | ee_mask = torch.zeros(len(joints_name), 1).bool() 54 | for ee in cfg['end_effectors']: 55 | ee_mask[joints_index[ee]] = True 56 | 57 | # shoulder mask 58 | sh_mask = torch.zeros(len(joints_name), 1).bool() 59 | for sh in cfg['shoulders']: 60 | sh_mask[joints_index[sh]] = True 61 | 62 | # elbow mask 63 | el_mask = torch.zeros(len(joints_name), 1).bool() 64 | for el in cfg['elbows']: 65 | el_mask[joints_index[el]] = True 66 | 67 | # node parent 68 | parent = -torch.ones(len(joints_name)).long() 69 | for edge in edge_index.permute(1, 0): 70 | parent[edge[1]] = edge[0] 71 | 72 | # node offset 73 | offset = torch.stack([torch.Tensor(joints[joint]['origin']) for joint in joints_name], dim=0) 74 | # change root offset to store init pose 75 | init_pose = {} 76 | fk = robot.link_fk() 77 | for link, matrix in fk.items(): 78 | init_pose[link.name] = matrix_to_xyz_rpy(matrix) 79 | origin = torch.zeros(6) 80 | for root in cfg['root_name']: 81 | offset[joints_index[root]] = torch.Tensor(init_pose[joints[root]['child']]) 82 | origin[:3] += offset[joints_index[root]][:3] 83 | origin /= 2 84 | # move relative to origin 85 | for root in cfg['root_name']: 86 | offset[joints_index[root]] -= origin 87 | # print(offset, offset.shape) 88 | 89 | # dist to root 90 | root_dist = torch.zeros(len(joints_name), 1) 91 | for node_idx in range(len(joints_name)): 92 | dist = 0 93 | current_idx = node_idx 94 | while current_idx != -1: 95 | origin = offset[current_idx] 96 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 97 | dist += offsets_mod 98 | current_idx = parent[current_idx] 99 | root_dist[node_idx] = dist 100 | # print(root_dist, root_dist.shape) 101 | 102 | # dist to shoulder 103 | shoulder_dist = torch.zeros(len(joints_name), 1) 104 | for node_idx in range(len(joints_name)): 105 | dist = 0 106 | current_idx = node_idx 107 | while current_idx != -1 and joints_name[current_idx] not in cfg['shoulders']: 108 | origin = offset[current_idx] 109 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 
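# accumulate this link's Euclidean length; the walk follows parent[] up the kinematic chain and
# stops when a shoulder joint is reached or when the chain runs out (current_idx == -1),
# analogous to the root_dist loop above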
110 | dist += offsets_mod 111 | current_idx = parent[current_idx] 112 | shoulder_dist[node_idx] = dist 113 | # print(shoulder_dist, shoulder_dist.shape) 114 | 115 | # dist to elbow 116 | elbow_dist = torch.zeros(len(joints_name), 1) 117 | for node_idx in range(len(joints_name)): 118 | dist = 0 119 | current_idx = node_idx 120 | while current_idx != -1 and joints_name[current_idx] not in cfg['elbows']: 121 | origin = offset[current_idx] 122 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 123 | dist += offsets_mod 124 | current_idx = parent[current_idx] 125 | elbow_dist[node_idx] = dist 126 | # print(elbow_dist, elbow_dist.shape) 127 | 128 | # rotation axis 129 | axis = [torch.Tensor(joints[joint]['axis']) for joint in joints_name] 130 | axis = torch.stack(axis, dim=0) 131 | 132 | # joint limit 133 | lower = [torch.Tensor([joints[joint]['lower']]) for joint in joints_name] 134 | lower = torch.stack(lower, dim=0) 135 | upper = [torch.Tensor([joints[joint]['upper']]) for joint in joints_name] 136 | upper = torch.stack(upper, dim=0) 137 | # print(lower.shape, upper.shape) 138 | 139 | # skeleton 140 | data = Data(x=torch.zeros(num_nodes, 1), 141 | edge_index=edge_index, 142 | edge_attr=edge_attr, 143 | skeleton_type=skeleton_type, 144 | topology_type=topology_type, 145 | ee_mask=ee_mask, 146 | sh_mask=sh_mask, 147 | el_mask=el_mask, 148 | root_dist=root_dist, 149 | shoulder_dist=shoulder_dist, 150 | elbow_dist=elbow_dist, 151 | num_nodes=num_nodes, 152 | parent=parent, 153 | offset=offset, 154 | axis=axis, 155 | lower=lower, 156 | upper=upper) 157 | 158 | # test forward kinematics 159 | # print(joints_name) 160 | # result = robot.link_fk(cfg={joint:0.0 for joint in joints_name}) 161 | # for link, matrix in result.items(): 162 | # print(link.name, matrix) 163 | # import os, sys, inspect 164 | # currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 165 | # parentdir = os.path.dirname(currentdir) 166 | # sys.path.insert(0,parentdir) 167 | # from models.kinematics import ForwardKinematicsURDF 168 | # fk = ForwardKinematicsURDF() 169 | # pos = fk(data.x, data.parent, data.offset, 1) 170 | 171 | # # visualize 172 | # import matplotlib.pyplot as plt 173 | # fig = plt.figure() 174 | # ax = fig.add_subplot(111, projection='3d') 175 | # ax.set_xlabel('X') 176 | # ax.set_ylabel('Y') 177 | # ax.set_zlabel('Z') 178 | 179 | # # plot 3D lines 180 | # for edge in edge_index.permute(1, 0): 181 | # line_x = [pos[edge[0]][0], pos[edge[1]][0]] 182 | # line_y = [pos[edge[0]][1], pos[edge[1]][1]] 183 | # line_z = [pos[edge[0]][2], pos[edge[1]][2]] 184 | # plt.plot(line_x, line_y, line_z, 'royalblue', marker='o') 185 | # # plt.show() 186 | # plt.savefig('foo.png') 187 | 188 | return data 189 | 190 | """ 191 | convert Inspire Hand URDF graph 192 | """ 193 | def hand2graph(urdf_file, cfg): 194 | # load URDF 195 | robot = URDF.load(urdf_file) 196 | 197 | # parse joint params 198 | joints = {} 199 | for joint in robot.joints: 200 | # joint atributes 201 | joints[joint.name] = {'type': joint.joint_type, 'axis': joint.axis, 202 | 'parent': joint.parent, 'child': joint.child, 203 | 'origin': matrix_to_xyz_rpy(joint.origin), 204 | 'lower': joint.limit.lower if joint.limit else 0, 205 | 'upper': joint.limit.upper if joint.limit else 0} 206 | 207 | # debug msg 208 | # for name, attr in joints.items(): 209 | # print(name, attr) 210 | 211 | # skeleton type & topology type 212 | skeleton_type = 0 213 | topology_type = 0 214 | 215 | # collect edge index & edge feature 216 | 
joints_name = cfg['joints_name'] 217 | joints_index = {name: i for i, name in enumerate(joints_name)} 218 | edge_index = [] 219 | edge_attr = [] 220 | for edge in cfg['edges']: 221 | parent, child = edge 222 | # add edge index 223 | edge_index.append(torch.LongTensor([joints_index[parent], joints_index[child]])) 224 | # add edge attr 225 | edge_attr.append(torch.Tensor(joints[child]['origin'])) 226 | edge_index = torch.stack(edge_index, dim=0) 227 | edge_index = edge_index.permute(1, 0) 228 | edge_attr = torch.stack(edge_attr, dim=0) 229 | # print(edge_index, edge_attr, edge_index.shape, edge_attr.shape) 230 | 231 | # number of nodes 232 | num_nodes = len(joints_name) 233 | # print(num_nodes) 234 | 235 | # end effector mask 236 | ee_mask = torch.zeros(len(joints_name), 1).bool() 237 | for ee in cfg['end_effectors']: 238 | ee_mask[joints_index[ee]] = True 239 | # print(ee_mask) 240 | 241 | # elbow mask 242 | el_mask = torch.zeros(len(joints_name), 1).bool() 243 | for el in cfg['elbows']: 244 | el_mask[joints_index[el]] = True 245 | # print(el_mask) 246 | 247 | # node parent 248 | parent = -torch.ones(len(joints_name)).long() 249 | for edge in edge_index.permute(1, 0): 250 | parent[edge[1]] = edge[0] 251 | # print(parent) 252 | 253 | # node offset 254 | offset = [] 255 | for joint in joints_name: 256 | offset.append(torch.Tensor(joints[joint]['origin'])) 257 | offset = torch.stack(offset, dim=0) 258 | # print(offset, offset.shape) 259 | 260 | # dist to root 261 | root_dist = torch.zeros(len(joints_name), 1) 262 | for node_idx in range(len(joints_name)): 263 | dist = 0 264 | current_idx = node_idx 265 | while joints_name[current_idx] != cfg['root_name']: 266 | origin = offset[current_idx] 267 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 268 | dist += offsets_mod 269 | current_idx = parent[current_idx] 270 | root_dist[node_idx] = dist 271 | # print(root_dist, root_dist.shape) 272 | 273 | # dist to elbow 274 | elbow_dist = torch.zeros(len(joints_name), 1) 275 | for node_idx in range(len(joints_name)): 276 | dist = 0 277 | current_idx = node_idx 278 | while joints_name[current_idx] != cfg['root_name'] and joints_name[current_idx] not in cfg['elbows']: 279 | origin = offset[current_idx] 280 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 281 | dist += offsets_mod 282 | current_idx = parent[current_idx] 283 | elbow_dist[node_idx] = dist 284 | # print(elbow_dist, elbow_dist.shape) 285 | 286 | # rotation axis 287 | axis = [torch.Tensor(joints[joint]['axis']) if joint != cfg['root_name'] else torch.zeros(3) for joint in joints_name] 288 | axis = torch.stack(axis, dim=0) 289 | # print(axis, axis.shape) 290 | 291 | # joint limit 292 | lower = [torch.Tensor([joints[joint]['lower']]) if joint != cfg['root_name'] else torch.zeros(1) for joint in joints_name] 293 | lower = torch.stack(lower, dim=0) 294 | upper = [torch.Tensor([joints[joint]['upper']]) if joint != cfg['root_name'] else torch.zeros(1) for joint in joints_name] 295 | upper = torch.stack(upper, dim=0) 296 | # print(lower, upper, lower.shape, upper.shape) 297 | 298 | # skeleton 299 | data = Data(x=torch.zeros(num_nodes, 1), 300 | edge_index=edge_index, 301 | edge_attr=edge_attr, 302 | skeleton_type=skeleton_type, 303 | topology_type=topology_type, 304 | ee_mask=ee_mask, 305 | el_mask=el_mask, 306 | root_dist=root_dist, 307 | elbow_dist=elbow_dist, 308 | num_nodes=num_nodes, 309 | parent=parent, 310 | offset=offset, 311 | axis=axis, 312 | lower=lower, 313 | upper=upper) 314 | # data for arm with hand 315 | 
data.hand_x = data.x 316 | data.hand_edge_index = data.edge_index 317 | data.hand_edge_attr = data.edge_attr 318 | data.hand_ee_mask = data.ee_mask 319 | data.hand_el_mask = data.el_mask 320 | data.hand_root_dist = data.root_dist 321 | data.hand_elbow_dist = data.elbow_dist 322 | data.hand_num_nodes = data.num_nodes 323 | data.hand_parent = data.parent 324 | data.hand_offset = data.offset 325 | data.hand_axis = data.axis 326 | data.hand_lower = data.lower 327 | data.hand_upper = data.upper 328 | # print(data) 329 | 330 | # # test forward kinematics 331 | # result = robot.link_fk(cfg={joint: 0.0 for joint in cfg['joints_name'] if joint != cfg['root_name']}) 332 | # # for link, matrix in result.items(): 333 | # # print(link.name, matrix) 334 | # fk = ForwardKinematicsAxis() 335 | # pos, _, _ = fk(data.x, data.parent, data.offset, 1, data.axis) 336 | # # print(joints_index, pos) 337 | 338 | # # visualize 339 | # import matplotlib.pyplot as plt 340 | # fig = plt.figure() 341 | # ax = fig.add_subplot(111, projection='3d') 342 | # # ax.set_axis_off() 343 | # # ax.view_init(elev=0, azim=90) 344 | # ax.set_xlabel('X') 345 | # ax.set_ylabel('Y') 346 | # ax.set_zlabel('Z') 347 | # ax.set_xlim3d(-0.2,0.) 348 | # ax.set_ylim3d(-0.1,0.1) 349 | # ax.set_zlim3d(-0.1,0.1) 350 | 351 | # # plot 3D lines 352 | # for edge in edge_index.permute(1, 0): 353 | # line_x = [pos[edge[0]][0], pos[edge[1]][0]] 354 | # line_y = [pos[edge[0]][1], pos[edge[1]][1]] 355 | # line_z = [pos[edge[0]][2], pos[edge[1]][2]] 356 | # # line_x = [pos[edge[0]][2], pos[edge[1]][2]] 357 | # # line_y = [pos[edge[0]][0], pos[edge[1]][0]] 358 | # # line_z = [pos[edge[0]][1], pos[edge[1]][1]] 359 | # plt.plot(line_x, line_y, line_z, 'royalblue', marker='o') 360 | # plt.show() 361 | # # plt.savefig('hand.png') 362 | 363 | return data 364 | 365 | if __name__ == '__main__': 366 | yumi_cfg = { 367 | 'joints_name': [ 368 | 'yumi_joint_1_l', 369 | 'yumi_joint_2_l', 370 | 'yumi_joint_7_l', 371 | 'yumi_joint_3_l', 372 | 'yumi_joint_4_l', 373 | 'yumi_joint_5_l', 374 | 'yumi_joint_6_l', 375 | 'yumi_joint_1_r', 376 | 'yumi_joint_2_r', 377 | 'yumi_joint_7_r', 378 | 'yumi_joint_3_r', 379 | 'yumi_joint_4_r', 380 | 'yumi_joint_5_r', 381 | 'yumi_joint_6_r', 382 | ], 383 | 'edges': [ 384 | ['yumi_joint_1_l', 'yumi_joint_2_l'], 385 | ['yumi_joint_2_l', 'yumi_joint_7_l'], 386 | ['yumi_joint_7_l', 'yumi_joint_3_l'], 387 | ['yumi_joint_3_l', 'yumi_joint_4_l'], 388 | ['yumi_joint_4_l', 'yumi_joint_5_l'], 389 | ['yumi_joint_5_l', 'yumi_joint_6_l'], 390 | ['yumi_joint_1_r', 'yumi_joint_2_r'], 391 | ['yumi_joint_2_r', 'yumi_joint_7_r'], 392 | ['yumi_joint_7_r', 'yumi_joint_3_r'], 393 | ['yumi_joint_3_r', 'yumi_joint_4_r'], 394 | ['yumi_joint_4_r', 'yumi_joint_5_r'], 395 | ['yumi_joint_5_r', 'yumi_joint_6_r'], 396 | ], 397 | 'root_name': [ 398 | 'yumi_joint_1_l', 399 | 'yumi_joint_1_r', 400 | ], 401 | 'end_effectors': [ 402 | 'yumi_joint_6_l', 403 | 'yumi_joint_6_r', 404 | ], 405 | 'shoulders': [ 406 | 'yumi_joint_2_l', 407 | 'yumi_joint_2_r', 408 | ], 409 | 'elbows': [ 410 | 'yumi_joint_3_l', 411 | 'yumi_joint_3_r', 412 | ], 413 | } 414 | graph = yumi2graph(urdf_file='./data/target/yumi/yumi.urdf', cfg=yumi_cfg) 415 | print('yumi', graph) 416 | 417 | hand_cfg = { 418 | 'joints_name': [ 419 | 'yumi_link_7_r_joint', 420 | 'Link1', 421 | 'Link11', 422 | 'Link1111', 423 | 'Link2', 424 | 'Link22', 425 | 'Link2222', 426 | 'Link3', 427 | 'Link33', 428 | 'Link3333', 429 | 'Link4', 430 | 'Link44', 431 | 'Link4444', 432 | 'Link5', 433 | 'Link51', 434 | 'Link52', 435 | 
'Link53', 436 | 'Link5555', 437 | ], 438 | 'edges': [ 439 | ['yumi_link_7_r_joint', 'Link1'], 440 | ['Link1', 'Link11'], 441 | ['Link11', 'Link1111'], 442 | ['yumi_link_7_r_joint', 'Link2'], 443 | ['Link2', 'Link22'], 444 | ['Link22', 'Link2222'], 445 | ['yumi_link_7_r_joint', 'Link3'], 446 | ['Link3', 'Link33'], 447 | ['Link33', 'Link3333'], 448 | ['yumi_link_7_r_joint', 'Link4'], 449 | ['Link4', 'Link44'], 450 | ['Link44', 'Link4444'], 451 | ['yumi_link_7_r_joint', 'Link5'], 452 | ['Link5', 'Link51'], 453 | ['Link51', 'Link52'], 454 | ['Link52', 'Link53'], 455 | ['Link53', 'Link5555'], 456 | ], 457 | 'root_name': 'yumi_link_7_r_joint', 458 | 'end_effectors': [ 459 | 'Link1111', 460 | 'Link2222', 461 | 'Link3333', 462 | 'Link4444', 463 | 'Link5555', 464 | ], 465 | 'elbows': [ 466 | 'Link1', 467 | 'Link2', 468 | 'Link3', 469 | 'Link4', 470 | 'Link5', 471 | ], 472 | } 473 | graph = hand2graph(urdf_file='./data/target/yumi-with-hands/yumi_with_hands.urdf', cfg=hand_cfg) 474 | print('hand', graph) 475 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from kornia.geometry.conversions import quaternion_to_rotation_matrix 5 | 6 | """ 7 | Calculate All Loss 8 | """ 9 | def calculate_all_loss(data_list, target_list, ee_criterion, vec_criterion, col_criterion, lim_criterion, ori_criterion, fin_criterion, reg_criterion, 10 | z, target_ang, target_pos, target_rot, target_global_pos, l_hand_pos, r_hand_pos, all_losses=[], ee_losses=[], vec_losses=[], col_losses=[], lim_losses=[], ori_losses=[], fin_losses=[], reg_losses=[]): 11 | # end effector loss 12 | if ee_criterion: 13 | ee_loss = calculate_ee_loss(data_list, target_list, target_pos, ee_criterion)*1000 14 | ee_losses.append(ee_loss.item()) 15 | else: 16 | ee_loss = 0 17 | ee_losses.append(0) 18 | 19 | # vector loss 20 | if vec_criterion: 21 | vec_loss = calculate_vec_loss(data_list, target_list, target_pos, vec_criterion)*100 22 | vec_losses.append(vec_loss.item()) 23 | else: 24 | vec_loss = 0 25 | vec_losses.append(0) 26 | 27 | # collision loss 28 | if col_criterion: 29 | col_loss = col_criterion(target_global_pos.view(len(target_list), -1, 3), target_list[0].edge_index, 30 | target_rot.view(len(target_list), -1, 9), target_list[0].ee_mask)*1000 31 | col_losses.append(col_loss.item()) 32 | else: 33 | col_loss = 0 34 | col_losses.append(0) 35 | 36 | # joint limit loss 37 | if lim_criterion: 38 | lim_loss = calculate_lim_loss(target_list, target_ang, lim_criterion)*10000 39 | lim_losses.append(lim_loss.item()) 40 | else: 41 | lim_loss = 0 42 | lim_losses.append(0) 43 | 44 | # end effector orientation loss 45 | if ori_criterion: 46 | ori_loss = calculate_ori_loss(data_list, target_list, target_rot, ori_criterion)*100 47 | ori_losses.append(ori_loss.item()) 48 | else: 49 | ori_loss = 0 50 | ori_losses.append(0) 51 | 52 | # finger similarity loss 53 | if fin_criterion: 54 | fin_loss = calculate_fin_loss(data_list, target_list, l_hand_pos, r_hand_pos, fin_criterion)*100 55 | fin_losses.append(fin_loss.item()) 56 | else: 57 | fin_loss = 0 58 | fin_losses.append(0) 59 | 60 | # regularization loss 61 | if reg_criterion: 62 | reg_loss = reg_criterion(z.view(len(target_list), -1, 64)) 63 | reg_losses.append(reg_loss.item()) 64 | else: 65 | reg_loss = 0 66 | reg_losses.append(0) 67 | 68 | # total loss 69 | loss = ee_loss + vec_loss + col_loss + 
lim_loss + ori_loss + fin_loss + reg_loss 70 | all_losses.append(loss.item()) 71 | 72 | return loss 73 | 74 | """ 75 | Calculate End Effector Loss 76 | """ 77 | def calculate_ee_loss(data_list, target_list, target_pos, ee_criterion): 78 | target_mask = torch.cat([data.ee_mask for data in target_list]).to(target_pos.device) 79 | source_mask = torch.cat([data.ee_mask for data in data_list]).to(target_pos.device) 80 | target_ee = torch.masked_select(target_pos, target_mask).view(-1, 3) 81 | source_ee = torch.masked_select(torch.cat([data.pos for data in data_list]).to(target_pos.device), source_mask).view(-1, 3) 82 | # normalize 83 | target_root_dist = torch.cat([data.root_dist for data in target_list]).to(target_pos.device) 84 | source_root_dist = torch.cat([data.root_dist for data in data_list]).to(target_pos.device) 85 | target_ee = target_ee / torch.masked_select(target_root_dist, target_mask).unsqueeze(1) 86 | source_ee = source_ee / torch.masked_select(source_root_dist, source_mask).unsqueeze(1) 87 | ee_loss = ee_criterion(target_ee, source_ee) 88 | return ee_loss 89 | 90 | """ 91 | Calculate Vector Loss 92 | """ 93 | def calculate_vec_loss(data_list, target_list, target_pos, vec_criterion): 94 | target_sh_mask = torch.cat([data.sh_mask for data in target_list]).to(target_pos.device) 95 | target_el_mask = torch.cat([data.el_mask for data in target_list]).to(target_pos.device) 96 | target_ee_mask = torch.cat([data.ee_mask for data in target_list]).to(target_pos.device) 97 | source_sh_mask = torch.cat([data.sh_mask for data in data_list]).to(target_pos.device) 98 | source_el_mask = torch.cat([data.el_mask for data in data_list]).to(target_pos.device) 99 | source_ee_mask = torch.cat([data.ee_mask for data in data_list]).to(target_pos.device) 100 | target_sh = torch.masked_select(target_pos, target_sh_mask).view(-1, 3) 101 | target_el = torch.masked_select(target_pos, target_el_mask).view(-1, 3) 102 | target_ee = torch.masked_select(target_pos, target_ee_mask).view(-1, 3) 103 | source_sh = torch.masked_select(torch.cat([data.pos for data in data_list]).to(target_pos.device), source_sh_mask).view(-1, 3) 104 | source_el = torch.masked_select(torch.cat([data.pos for data in data_list]).to(target_pos.device), source_el_mask).view(-1, 3) 105 | source_ee = torch.masked_select(torch.cat([data.pos for data in data_list]).to(target_pos.device), source_ee_mask).view(-1, 3) 106 | # print(target_sh.shape, target_el.shape, target_ee.shape, source_sh.shape, source_el.shape, source_ee.shape) 107 | target_vector1 = target_el - target_sh 108 | target_vector2 = target_ee - target_el 109 | source_vector1 = source_el - source_sh 110 | source_vector2 = source_ee - source_el 111 | # print(target_vector1.shape, target_vector2.shape, source_vector1.shape, source_vector2.shape, (target_vector1*source_vector1).sum(-1).shape) 112 | # normalize 113 | target_shoulder_dist = torch.cat([data.shoulder_dist for data in target_list]).to(target_pos.device) 114 | target_elbow_dist = torch.cat([data.elbow_dist for data in target_list]).to(target_pos.device)/2 115 | source_shoulder_dist = torch.cat([data.shoulder_dist for data in data_list]).to(target_pos.device) 116 | source_elbow_dist = torch.cat([data.elbow_dist for data in data_list]).to(target_pos.device)/2 117 | normalize_target_vector1 = target_vector1 / torch.masked_select(target_shoulder_dist, target_el_mask).unsqueeze(1) 118 | normalize_source_vector1 = source_vector1 / torch.masked_select(source_shoulder_dist, source_el_mask).unsqueeze(1) 119 | vector1_loss = 
vec_criterion(normalize_target_vector1, normalize_source_vector1) 120 | normalize_target_vector2 = target_vector2 / torch.masked_select(target_elbow_dist, target_ee_mask).unsqueeze(1) 121 | normalize_source_vector2 = source_vector2 / torch.masked_select(source_elbow_dist, source_ee_mask).unsqueeze(1) 122 | vector2_loss = vec_criterion(normalize_target_vector2, normalize_source_vector2) 123 | vec_loss = vector2_loss#(vector1_loss + vector2_loss)*100 124 | return vec_loss 125 | 126 | """ 127 | Calculate Joint Limit Loss 128 | """ 129 | def calculate_lim_loss(target_list, target_ang, lim_criterion): 130 | target_lower = torch.cat([data.lower for data in target_list]).to(target_ang.device) 131 | target_upper = torch.cat([data.upper for data in target_list]).to(target_ang.device) 132 | lim_loss = lim_criterion(target_ang, target_lower, target_upper) 133 | return lim_loss 134 | 135 | """ 136 | Calculate Orientation Loss 137 | """ 138 | def calculate_ori_loss(data_list, target_list, target_rot, ori_criterion): 139 | target_mask = torch.cat([data.ee_mask for data in target_list]).to(target_rot.device) 140 | source_mask = torch.cat([data.ee_mask for data in data_list]).to(target_rot.device) 141 | target_rot = target_rot.view(-1, 9) 142 | source_rot = quaternion_to_rotation_matrix(torch.cat([data.q for data in data_list]).to(target_rot.device)).view(-1, 9) 143 | target_q = torch.masked_select(target_rot, target_mask) 144 | source_q = torch.masked_select(source_rot, source_mask) 145 | ori_loss = ori_criterion(target_q, source_q) 146 | return ori_loss 147 | 148 | """ 149 | Calculate Finger Similarity Loss 150 | """ 151 | def calculate_fin_loss(data_list, target_list, l_hand_pos, r_hand_pos, ee_criterion): 152 | # left hand 153 | target_el_mask = torch.cat([data.hand_el_mask for data in target_list]).to(l_hand_pos.device) 154 | target_ee_mask = torch.cat([data.hand_ee_mask for data in target_list]).to(l_hand_pos.device) 155 | source_el_mask = torch.cat([data.l_hand_el_mask for data in data_list]).to(l_hand_pos.device) 156 | source_ee_mask = torch.cat([data.l_hand_ee_mask for data in data_list]).to(l_hand_pos.device) 157 | target_el = torch.masked_select(l_hand_pos, target_el_mask).view(-1, 3) 158 | target_ee = torch.masked_select(l_hand_pos, target_ee_mask).view(-1, 3) 159 | source_el = torch.masked_select(torch.cat([data.l_hand_pos for data in data_list]).to(l_hand_pos.device), source_el_mask).view(-1, 3) 160 | source_ee = torch.masked_select(torch.cat([data.l_hand_pos for data in data_list]).to(l_hand_pos.device), source_ee_mask).view(-1, 3) 161 | target_vector = target_ee - target_el 162 | source_vector = source_ee - source_el 163 | # normalize 164 | target_elbow_dist = torch.cat([data.hand_elbow_dist for data in target_list]).to(l_hand_pos.device) 165 | source_elbow_dist = torch.cat([data.l_hand_elbow_dist for data in data_list]).to(l_hand_pos.device) 166 | normalize_target_vector = target_vector / torch.masked_select(target_elbow_dist, target_ee_mask).unsqueeze(1) 167 | normalize_source_vector = source_vector / torch.masked_select(source_elbow_dist, source_ee_mask).unsqueeze(1) 168 | l_fin_loss = ee_criterion(normalize_target_vector, normalize_source_vector) 169 | 170 | # right hand 171 | target_el_mask = torch.cat([data.hand_el_mask for data in target_list]).to(r_hand_pos.device) 172 | target_ee_mask = torch.cat([data.hand_ee_mask for data in target_list]).to(r_hand_pos.device) 173 | source_el_mask = torch.cat([data.r_hand_el_mask for data in data_list]).to(r_hand_pos.device) 174 | source_ee_mask 
= torch.cat([data.r_hand_ee_mask for data in data_list]).to(r_hand_pos.device) 175 | target_el = torch.masked_select(r_hand_pos, target_el_mask).view(-1, 3) 176 | target_ee = torch.masked_select(r_hand_pos, target_ee_mask).view(-1, 3) 177 | source_el = torch.masked_select(torch.cat([data.r_hand_pos for data in data_list]).to(r_hand_pos.device), source_el_mask).view(-1, 3) 178 | source_ee = torch.masked_select(torch.cat([data.r_hand_pos for data in data_list]).to(r_hand_pos.device), source_ee_mask).view(-1, 3) 179 | target_vector = target_ee - target_el 180 | source_vector = source_ee - source_el 181 | # normalize 182 | target_elbow_dist = torch.cat([data.hand_elbow_dist for data in target_list]).to(r_hand_pos.device) 183 | source_elbow_dist = torch.cat([data.r_hand_elbow_dist for data in data_list]).to(r_hand_pos.device) 184 | normalize_target_vector = target_vector / torch.masked_select(target_elbow_dist, target_ee_mask).unsqueeze(1) 185 | normalize_source_vector = source_vector / torch.masked_select(source_elbow_dist, source_ee_mask).unsqueeze(1) 186 | r_fin_loss = ee_criterion(normalize_target_vector, normalize_source_vector) 187 | 188 | fin_loss = (l_fin_loss + r_fin_loss)/2 189 | return fin_loss 190 | 191 | """ 192 | Collision Loss 193 | """ 194 | class CollisionLoss(nn.Module): 195 | def __init__(self, threshold, mode='capsule-capsule'): 196 | super(CollisionLoss, self).__init__() 197 | self.threshold = threshold 198 | self.mode = mode 199 | 200 | def forward(self, pos, edge_index, rot, ee_mask): 201 | """ 202 | Keyword arguments: 203 | pos -- joint positions [batch_size, num_nodes, 3] 204 | edge_index -- edge index [2, num_edges] 205 | """ 206 | batch_size = pos.shape[0] 207 | num_nodes = pos.shape[1] 208 | num_edges = edge_index.shape[1] 209 | 210 | # sphere-sphere detection 211 | if self.mode == 'sphere-sphere': 212 | l_sphere = pos[:, :num_nodes//2, :] 213 | r_sphere = pos[:, num_nodes//2:, :] 214 | l_sphere = l_sphere.unsqueeze(1).expand(batch_size, num_nodes//2, num_nodes//2, 3) 215 | r_sphere = r_sphere.unsqueeze(2).expand(batch_size, num_nodes//2, num_nodes//2, 3) 216 | dist_square = torch.sum(torch.pow(l_sphere - r_sphere, 2), dim=-1) 217 | mask = (dist_square < self.threshold**2) & (dist_square > 0) 218 | loss = torch.sum(torch.exp(-1*torch.masked_select(dist_square, mask)))/batch_size 219 | 220 | # sphere-capsule detection 221 | if self.mode == 'sphere-capsule': 222 | # capsule p0 & p1 223 | p0 = pos[:, edge_index[0], :] 224 | p1 = pos[:, edge_index[1], :] 225 | # print(edge_index.shape, p0.shape, p1.shape) 226 | 227 | # left sphere & right capsule 228 | l_sphere = pos[:, :num_nodes//2, :] 229 | r_capsule_p0 = p0[:, num_edges//2:, :] 230 | r_capsule_p1 = p1[:, num_edges//2:, :] 231 | dist_square_1 = self.sphere_capsule_dist_square(l_sphere, r_capsule_p0, r_capsule_p1, batch_size, num_nodes, num_edges) 232 | 233 | # left capsule & right sphere 234 | r_sphere = pos[:, num_nodes//2:, :] 235 | l_capsule_p0 = p0[:, :num_edges//2, :] 236 | l_capsule_p1 = p1[:, :num_edges//2, :] 237 | dist_square_2 = self.sphere_capsule_dist_square(r_sphere, l_capsule_p0, l_capsule_p1, batch_size, num_nodes, num_edges) 238 | 239 | # calculate loss 240 | dist_square = torch.cat([dist_square_1, dist_square_2]) 241 | mask = (dist_square < self.threshold**2) & (dist_square > 0) 242 | loss = torch.sum(torch.exp(-1*torch.masked_select(dist_square, mask)))/batch_size 243 | 244 | # capsule-capsule detection 245 | if self.mode == 'capsule-capsule': 246 | # capsule p0 & p1 247 | p0 = pos[:, edge_index[0], 
:] 248 | p1 = pos[:, edge_index[1], :] 249 | # left capsule 250 | l_capsule_p0 = p0[:, :num_edges//2, :] 251 | l_capsule_p1 = p1[:, :num_edges//2, :] 252 | # right capsule 253 | r_capsule_p0 = p0[:, num_edges//2:, :] 254 | r_capsule_p1 = p1[:, num_edges//2:, :] 255 | # add capsule for left hand & right hand(for yumi) 256 | ee_pos = torch.masked_select(pos, ee_mask.to(pos.device)).view(-1, 3) 257 | ee_rot = torch.masked_select(rot, ee_mask.to(pos.device)).view(-1, 3, 3) 258 | offset = torch.Tensor([[[0], [0], [0.2]]]).repeat(ee_rot.size(0), 1, 1).to(pos.device) 259 | hand_pos = torch.bmm(ee_rot, offset).squeeze() + ee_pos 260 | l_ee_pos = ee_pos[::2, :].unsqueeze(1) 261 | l_hand_pos = hand_pos[::2, :].unsqueeze(1) 262 | r_ee_pos = ee_pos[1::2, :].unsqueeze(1) 263 | r_hand_pos = hand_pos[1::2, :].unsqueeze(1) 264 | l_capsule_p0 = torch.cat([l_capsule_p0, l_ee_pos], dim=1) 265 | l_capsule_p1 = torch.cat([l_capsule_p1, l_hand_pos], dim=1) 266 | r_capsule_p0 = torch.cat([r_capsule_p0, r_ee_pos], dim=1) 267 | r_capsule_p1 = torch.cat([r_capsule_p1, r_hand_pos], dim=1) 268 | num_edges += 2 269 | # print(l_capsule_p0.shape, l_capsule_p1.shape, r_capsule_p0.shape, r_capsule_p1.shape) 270 | # calculate loss 271 | dist_square = self.capsule_capsule_dist_square(l_capsule_p0, l_capsule_p1, r_capsule_p0, r_capsule_p1, batch_size, num_edges) 272 | mask = (dist_square < 0.1**2) & (dist_square > 0) 273 | mask[:, 6, 6] = (dist_square[:, 6, 6] < self.threshold**2) & (dist_square[:, 6, 6] > 0) 274 | loss = torch.sum(torch.exp(-1*torch.masked_select(dist_square, mask)))/batch_size 275 | 276 | return loss 277 | 278 | def sphere_capsule_dist_square(self, sphere, capsule_p0, capsule_p1, batch_size, num_nodes, num_edges): 279 | # condition 1: p0 is the closest point 280 | vec_p0_p1 = capsule_p1 - capsule_p0 # vector p0-p1 [batch_size, num_edges//2, 3] 281 | vec_p0_pr = sphere.unsqueeze(2).expand(batch_size, num_nodes//2, num_edges//2, 3) - capsule_p0.unsqueeze(1).expand(batch_size, num_nodes//2, num_edges//2, 3) # vector p0-pr [batch_size, num_nodes//2, num_edges//2, 3] 282 | vec_mul_p0 = torch.mul(vec_p0_p1.unsqueeze(1).expand(batch_size, num_nodes//2, num_edges//2, 3), vec_p0_pr).sum(dim=-1) # vector p0-p1 * vector p0-pr [batch_size, num_nodes//2, num_edges//2] 283 | dist_square_p0 = torch.masked_select(vec_p0_pr.norm(dim=-1)**2, vec_mul_p0 <= 0) 284 | # print(dist_square_p0.shape) 285 | 286 | # condition 2: p1 is the closest point 287 | vec_p1_p0 = capsule_p0 - capsule_p1 # vector p1-p0 [batch_size, num_edges//2, 3] 288 | vec_p1_pr = sphere.unsqueeze(2).expand(batch_size, num_nodes//2, num_edges//2, 3) - capsule_p1.unsqueeze(1).expand(batch_size, num_nodes//2, num_edges//2, 3) # vector p1-pr [batch_size, num_nodes//2, num_edges//2, 3] 289 | vec_mul_p1 = torch.mul(vec_p1_p0.unsqueeze(1).expand(batch_size, num_nodes//2, num_edges//2, 3), vec_p1_pr).sum(dim=-1) # vector p1-p0 * vector p1-pr [batch_size, num_nodes//2, num_edges//2] 290 | dist_square_p1 = torch.masked_select(vec_p1_pr.norm(dim=-1)**2, vec_mul_p1 <= 0) 291 | # print(dist_square_p1.shape) 292 | 293 | # condition 3: closest point in p0-p1 segement 294 | d = vec_mul_p0 / vec_p0_p1.norm(dim=-1).unsqueeze(1).expand(batch_size, num_nodes//2, num_edges//2) # vector p0-p1 * vector p0-pr / |vector p0-p1| [batch_size, num_nodes//2, num_edges//2] 295 | dist_square_middle = vec_p0_pr.norm(dim=-1)**2 - d**2 # distance square [batch_size, num_nodes//2, num_edges//2] 296 | dist_square_middle = torch.masked_select(dist_square_middle, (vec_mul_p0 > 0) & (vec_mul_p1 > 
0)) 297 | # print(dist_square_middle.shape) 298 | 299 | return torch.cat([dist_square_p0, dist_square_p1, dist_square_middle]) 300 | 301 | def capsule_capsule_dist_square(self, capsule_p0, capsule_p1, capsule_q0, capsule_q1, batch_size, num_edges): 302 | # expand left capsule 303 | capsule_p0 = capsule_p0.unsqueeze(1).expand(batch_size, num_edges//2, num_edges//2, 3) 304 | capsule_p1 = capsule_p1.unsqueeze(1).expand(batch_size, num_edges//2, num_edges//2, 3) 305 | # expand right capsule 306 | capsule_q0 = capsule_q0.unsqueeze(2).expand(batch_size, num_edges//2, num_edges//2, 3) 307 | capsule_q1 = capsule_q1.unsqueeze(2).expand(batch_size, num_edges//2, num_edges//2, 3) 308 | # basic variables 309 | a = torch.mul(capsule_p1 - capsule_p0, capsule_p1 - capsule_p0).sum(dim=-1) 310 | b = torch.mul(capsule_p1 - capsule_p0, capsule_q1 - capsule_q0).sum(dim=-1) 311 | c = torch.mul(capsule_q1 - capsule_q0, capsule_q1 - capsule_q0).sum(dim=-1) 312 | d = torch.mul(capsule_p1 - capsule_p0, capsule_p0 - capsule_q0).sum(dim=-1) 313 | e = torch.mul(capsule_q1 - capsule_q0, capsule_p0 - capsule_q0).sum(dim=-1) 314 | f = torch.mul(capsule_p0 - capsule_q0, capsule_p0 - capsule_q0).sum(dim=-1) 315 | # initialize s, t to zero 316 | s = torch.zeros(batch_size, num_edges//2, num_edges//2).to(capsule_p0.device) 317 | t = torch.zeros(batch_size, num_edges//2, num_edges//2).to(capsule_p0.device) 318 | one = torch.ones(batch_size, num_edges//2, num_edges//2).to(capsule_p0.device) 319 | # calculate coefficient 320 | det = a * c - b**2 321 | bte = b * e 322 | ctd = c * d 323 | ate = a * e 324 | btd = b * d 325 | # nonparallel segments 326 | # region 6 327 | s = torch.where((det > 0) & (bte <= ctd) & (e <= 0) & (-d >= a), one, s) 328 | s = torch.where((det > 0) & (bte <= ctd) & (e <= 0) & (-d < a) & (-d > 0), -d/a, s) 329 | # region 5 330 | t = torch.where((det > 0) & (bte <= ctd) & (e > 0) & (e < c), e/c, t) 331 | # region 4 332 | s = torch.where((det > 0) & (bte <= ctd) & (e > 0) & (e >= c) & (b - d >= a), one, s) 333 | s = torch.where((det > 0) & (bte <= ctd) & (e > 0) & (e >= c) & (b - d < a) & (b - d > 0), (b-d)/a, s) 334 | t = torch.where((det > 0) & (bte <= ctd) & (e > 0) & (e >= c), one, t) 335 | # region 8 336 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e <= 0) & (-d > 0) & (-d < a), -d/a, s) 337 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e <= 0) & (-d > 0) & (-d >= a), one, s) 338 | # region 1 339 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e > 0) & (b + e < c), one, s) 340 | t = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e > 0) & (b + e < c), (b+e)/c, t) 341 | # region 2 342 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e > 0) & (b + e >= c) & (b - d > 0) & (b - d < a), (b-d)/a, s) 343 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e > 0) & (b + e >= c) & (b - d > 0) & (b - d >= a), one, s) 344 | t = torch.where((det > 0) & (bte > ctd) & (bte - ctd >= det) & (b + e > 0) & (b + e >= c), one, t) 345 | # region 7 346 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate <= btd) & (-d > 0) & (-d >= a), one, s) 347 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate <= btd) & (-d > 0) & (-d < a), -d/a, s) 348 | # region 3 349 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate > btd) & (ate - btd >= det) & (b - d > 0) & (b - d >= a), one, s) 350 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate > btd) 
& (ate - btd >= det) & (b - d > 0) & (b - d < a), (b-d)/a, s) 351 | t = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate > btd) & (ate - btd >= det), one, t) 352 | # region 0 353 | s = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate > btd) & (ate - btd < det), (bte-ctd)/det, s) 354 | t = torch.where((det > 0) & (bte > ctd) & (bte - ctd < det) & (ate > btd) & (ate - btd < det), (ate-btd)/det, t) 355 | # parallel segments 356 | # e <= 0 357 | s = torch.where((det <= 0) & (e <= 0) & (-d > 0) & (-d >= a), one, s) 358 | s = torch.where((det <= 0) & (e <= 0) & (-d > 0) & (-d < a), -d/a, s) 359 | # e >= c 360 | s = torch.where((det <= 0) & (e > 0) & (e >= c) & (b - d > 0) & (b - d >= a), one, s) 361 | s = torch.where((det <= 0) & (e > 0) & (e >= c) & (b - d > 0) & (b - d < a), (b-d)/a, s) 362 | t = torch.where((det <= 0) & (e > 0) & (e >= c), one, t) 363 | # 0 < e < c 364 | t = torch.where((det <= 0) & (e > 0) & (e < c), e/c, t) 365 | # print(s, t) 366 | s = s.unsqueeze(-1).expand(batch_size, num_edges//2, num_edges//2, 3).detach() 367 | t = t.unsqueeze(-1).expand(batch_size, num_edges//2, num_edges//2, 3).detach() 368 | w = capsule_p0 - capsule_q0 + s*(capsule_p1 - capsule_p0) - t*(capsule_q1 - capsule_q0) 369 | dist_square = torch.mul(w, w).sum(dim=-1) 370 | return dist_square 371 | 372 | """ 373 | Joint Limit Loss 374 | """ 375 | class JointLimitLoss(nn.Module): 376 | def __init__(self): 377 | super(JointLimitLoss, self).__init__() 378 | 379 | def forward(self, ang, lower, upper): 380 | """ 381 | Keyword auguments: 382 | ang -- joint angles [batch_size*num_nodes, num_node_features] 383 | """ 384 | # calculate mask with limit 385 | lower_mask = ang < lower 386 | upper_mask = ang > upper 387 | 388 | # calculate final loss 389 | lower_loss = torch.sum(torch.masked_select(lower - ang, lower_mask)) 390 | upper_loss = torch.sum(torch.masked_select(ang - upper, upper_mask)) 391 | loss = (lower_loss + upper_loss)/ang.shape[0] 392 | 393 | return loss 394 | 395 | """ 396 | Regularization Loss 397 | """ 398 | class RegLoss(nn.Module): 399 | def __init__(self): 400 | super(RegLoss, self).__init__() 401 | 402 | def forward(self, z): 403 | # calculate final loss 404 | batch_size = z.shape[0] 405 | loss = torch.mean(torch.norm(z.view(batch_size, -1), dim=1).pow(2)) 406 | 407 | return loss 408 | 409 | 410 | if __name__ == '__main__': 411 | fake_sphere = torch.Tensor([[[2,3,0]]]) 412 | fake_capsule_p0 = torch.Tensor([[[0,0,0]]]) 413 | fake_capsule_p1 = torch.Tensor([[[1,0,0]]]) 414 | col_loss = CollisionLoss(threshold=1.0) 415 | # print(col_loss.sphere_capsule_dist_square(fake_sphere, fake_capsule_p0, fake_capsule_p1, 1, 2, 2)) 416 | fake_capsule_p0 = torch.Tensor([[[0,0,0]]]) 417 | fake_capsule_p1 = torch.Tensor([[[1,0,0]]]) 418 | fake_capsule_q0 = torch.Tensor([[[-10,0,0]]]) 419 | fake_capsule_q1 = torch.Tensor([[[-9,2,0]]]) 420 | print(col_loss.capsule_capsule_dist_square(fake_capsule_p0, fake_capsule_p1, fake_capsule_q0, fake_capsule_q1, 1, 2)) 421 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.transforms as transforms 3 | from torch_geometric.data import Data as OldData 4 | from torch_geometric.data import InMemoryDataset 5 | 6 | import os 7 | import math 8 | import numpy as np 9 | from numpy.linalg import inv 10 | from scipy.spatial.transform import Rotation as R 11 | from utils.urdf2graph import 
yumi2graph, hand2graph 12 | import h5py 13 | 14 | 15 | class Data(OldData): 16 | def __inc__(self, key, value): 17 | if key == 'edge_index': 18 | return self.num_nodes 19 | elif key == 'l_hand_edge_index': 20 | return self.l_hand_num_nodes 21 | elif key == 'r_hand_edge_index': 22 | return self.r_hand_num_nodes 23 | else: 24 | return 0 25 | 26 | """ 27 | Normalize by a constant coefficient 28 | """ 29 | class Normalize(object): 30 | def __call__(self, data, coeff=100.0): 31 | if hasattr(data, 'x'): 32 | data.x = data.x/coeff 33 | if hasattr(data, 'l_hand_x'): 34 | data.l_hand_x = data.l_hand_x/coeff 35 | if hasattr(data, 'r_hand_x'): 36 | data.r_hand_x = data.r_hand_x/coeff 37 | return data 38 | 39 | def __repr__(self): 40 | return '{}()'.format(self.__class__.__name__) 41 | 42 | """ 43 | Target Dataset for Yumi Manipulator 44 | """ 45 | class YumiDataset(InMemoryDataset): 46 | yumi_cfg = { 47 | 'joints_name': [ 48 | 'yumi_joint_1_l', 49 | 'yumi_joint_2_l', 50 | 'yumi_joint_7_l', 51 | 'yumi_joint_3_l', 52 | 'yumi_joint_4_l', 53 | 'yumi_joint_5_l', 54 | 'yumi_joint_6_l', 55 | 'yumi_joint_1_r', 56 | 'yumi_joint_2_r', 57 | 'yumi_joint_7_r', 58 | 'yumi_joint_3_r', 59 | 'yumi_joint_4_r', 60 | 'yumi_joint_5_r', 61 | 'yumi_joint_6_r', 62 | ], 63 | 'edges': [ 64 | ['yumi_joint_1_l', 'yumi_joint_2_l'], 65 | ['yumi_joint_2_l', 'yumi_joint_7_l'], 66 | ['yumi_joint_7_l', 'yumi_joint_3_l'], 67 | ['yumi_joint_3_l', 'yumi_joint_4_l'], 68 | ['yumi_joint_4_l', 'yumi_joint_5_l'], 69 | ['yumi_joint_5_l', 'yumi_joint_6_l'], 70 | ['yumi_joint_1_r', 'yumi_joint_2_r'], 71 | ['yumi_joint_2_r', 'yumi_joint_7_r'], 72 | ['yumi_joint_7_r', 'yumi_joint_3_r'], 73 | ['yumi_joint_3_r', 'yumi_joint_4_r'], 74 | ['yumi_joint_4_r', 'yumi_joint_5_r'], 75 | ['yumi_joint_5_r', 'yumi_joint_6_r'], 76 | ], 77 | 'root_name': [ 78 | 'yumi_joint_1_l', 79 | 'yumi_joint_1_r', 80 | ], 81 | 'end_effectors': [ 82 | 'yumi_joint_6_l', 83 | 'yumi_joint_6_r', 84 | ], 85 | 'shoulders': [ 86 | 'yumi_joint_2_l', 87 | 'yumi_joint_2_r', 88 | ], 89 | 'elbows': [ 90 | 'yumi_joint_3_l', 91 | 'yumi_joint_3_r', 92 | ], 93 | } 94 | def __init__(self, root, transform=None, pre_transform=None): 95 | super(YumiDataset, self).__init__(root, transform, pre_transform) 96 | self.data, self.slices = torch.load(self.processed_paths[0]) 97 | 98 | @property 99 | def raw_file_names(self): 100 | self._raw_file_names = [os.path.join(self.root, file) for file in os.listdir(self.root) if file.endswith('.urdf')] 101 | return self._raw_file_names 102 | 103 | @property 104 | def processed_file_names(self): 105 | return ['data.pt'] 106 | 107 | def process(self): 108 | data_list = [] 109 | for file in self.raw_file_names: 110 | data_list.append(yumi2graph(file, self.yumi_cfg)) 111 | 112 | if self.pre_filter is not None: 113 | data_list = [data for data in data_list if self.pre_filter(data)] 114 | 115 | if self.pre_transform is not None: 116 | data_list = [self.pre_transform(data) for data in data_list] 117 | 118 | data, slices = self.collate(data_list) 119 | torch.save((data, slices), self.processed_paths[0]) 120 | 121 | 122 | """ 123 | Map glove data to inspire hand data 124 | """ 125 | def linear_map(x_, min_, max_, min_hat, max_hat): 126 | x_hat = 1.0 * (x_ - min_) / (max_ - min_) * (max_hat - min_hat) + min_hat 127 | return x_hat 128 | 129 | def map_glove_to_inspire_hand(glove_angles): 130 | 131 | ### This function linearly maps the Wiseglove angle measurement to Inspire hand's joint angles. 
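    # A brief illustration of the scaling (values taken from glove_start/glove_final and
    # hand_start/hand_final below): linear_map rescales a raw glove reading x_ from its
    # measured range [min_, max_] into the target joint range [min_hat, max_hat]:
    #     x_hat = (x_ - min_) / (max_ - min_) * (max_hat - min_hat) + min_hat
    # e.g. glove channel 3 spans 0..90 deg and drives hand joint 0 (Link1, 0.0..-1.6 rad),
    # so a reading of 90 deg maps to -1.6 rad and a reading of 45 deg maps to -0.8 rad.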
132 | 133 | ## preparation, specify the range for linear scaling 134 | hand_start = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3, 0.1, 0.0, 0.0]) # already in radians 135 | hand_final = np.array([-1.6, -1.6, -1.6, -1.6, -1.6, -1.6, -1.6, -1.6, -0.75, 0.0, -0.2, -0.15]) 136 | glove_start = np.array([0, 0, 53, 0, 0, 22, 0, 0, 22, 0, 0, 35, 0, 0])# * pi / 180.0 # degrees to radians 137 | glove_final = np.array([45, 100, 0, 90, 120, 0, 90, 120, 0, 90, 120, 0, 90, 120])# * pi / 180.0 138 | length = glove_angles.shape[0] 139 | hand_angles = np.zeros((length, 12)) # 12 joints 140 | 141 | ## Iterate to map angles 142 | for i in range(length): 143 | # four fingers' extension/flexion (abduction/adduction are dropped) 144 | hand_angles[i, 0] = linear_map(glove_angles[i, 3], glove_start[3], glove_final[3], hand_start[0], hand_final[0]) # Link1 (joint name) 145 | hand_angles[i, 1] = linear_map(glove_angles[i, 4], glove_start[4], glove_final[4], hand_start[1], hand_final[1]) # Link11 146 | hand_angles[i, 2] = linear_map(glove_angles[i, 6], glove_start[6], glove_final[6], hand_start[2], hand_final[2]) # Link2 147 | hand_angles[i, 3] = linear_map(glove_angles[i, 7], glove_start[7], glove_final[7], hand_start[3], hand_final[3]) # Link22 148 | hand_angles[i, 4] = linear_map(glove_angles[i, 9], glove_start[9], glove_final[9], hand_start[4], hand_final[4]) # Link3 149 | hand_angles[i, 5] = linear_map(glove_angles[i, 10], glove_start[10], glove_final[10], hand_start[5], hand_final[5]) # Link33 150 | hand_angles[i, 6] = linear_map(glove_angles[i, 12], glove_start[12], glove_final[12], hand_start[6], hand_final[6]) # Link4 151 | hand_angles[i, 7] = linear_map(glove_angles[i, 13], glove_start[13], glove_final[13], hand_start[7], hand_final[7]) # Link44 152 | 153 | # thumb 154 | hand_angles[i, 8] = (hand_start[8] + hand_final[8]) / 2.0 # Link5 (rotation about z axis), fixed!
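        # Note: Link5 (thumb rotation about the z axis) is not driven by any glove channel;
        # it is pinned to the midpoint of its range, (0.3 + (-0.75)) / 2.0 = -0.225 rad.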
155 | hand_angles[i, 9] = linear_map(glove_angles[i, 2], glove_start[2], glove_final[2], hand_start[9], hand_final[9]) # Link 51 156 | hand_angles[i, 10] = linear_map(glove_angles[i, 0], glove_start[0], glove_final[0], hand_start[10], hand_final[10]) # Link 52 157 | hand_angles[i, 11] = linear_map(glove_angles[i, 1], glove_start[1], glove_final[1], hand_start[11], hand_final[11]) # Link 53 158 | 159 | return hand_angles 160 | 161 | """ 162 | Parse H5 File 163 | """ 164 | def parse_h5(filename, selected_key=None): 165 | data_list = [] 166 | h5_file = h5py.File(filename, 'r') 167 | # print(filename, h5_file.keys(), len(h5_file.keys())) 168 | if selected_key is None: 169 | keys = h5_file.keys() 170 | else: 171 | keys = [selected_key] 172 | for key in keys: 173 | if '语句' in key and selected_key is None: 174 | print('Skipping'+key) 175 | continue 176 | # glove data 177 | l_glove_angle = h5_file[key + '/l_glove_angle'][:] 178 | r_glove_angle = h5_file[key + '/r_glove_angle'][:] 179 | l_hand_angle = map_glove_to_inspire_hand(l_glove_angle) 180 | r_hand_angle = map_glove_to_inspire_hand(r_glove_angle) 181 | # position data 182 | l_shoulder_pos = h5_file[key + '/l_up_pos'][:] 183 | r_shoulder_pos = h5_file[key + '/r_up_pos'][:] 184 | l_elbow_pos = h5_file[key + '/l_fr_pos'][:] 185 | r_elbow_pos = h5_file[key + '/r_fr_pos'][:] 186 | l_wrist_pos = h5_file[key + '/l_hd_pos'][:] 187 | r_wrist_pos = h5_file[key + '/r_hd_pos'][:] 188 | # quaternion data 189 | l_shoulder_quat = R.from_quat(h5_file[key + '/l_up_quat'][:]) 190 | r_shoulder_quat = R.from_quat(h5_file[key + '/r_up_quat'][:]) 191 | l_elbow_quat = R.from_quat(h5_file[key + '/l_fr_quat'][:]) 192 | r_elbow_quat = R.from_quat(h5_file[key + '/r_fr_quat'][:]) 193 | l_wrist_quat = R.from_quat(h5_file[key + '/l_hd_quat'][:]) 194 | r_wrist_quat = R.from_quat(h5_file[key + '/r_hd_quat'][:]) 195 | # rotation matrix data 196 | l_shoulder_matrix = l_shoulder_quat.as_matrix() 197 | r_shoulder_matrix = r_shoulder_quat.as_matrix() 198 | l_elbow_matrix = l_elbow_quat.as_matrix() 199 | r_elbow_matrix = r_elbow_quat.as_matrix() 200 | l_wrist_matrix = l_wrist_quat.as_matrix() 201 | r_wrist_matrix = r_wrist_quat.as_matrix() 202 | # transform to local coordinates 203 | # l_wrist_matrix = l_wrist_matrix * inv(l_elbow_matrix) 204 | # r_wrist_matrix = r_wrist_matrix * inv(r_elbow_matrix) 205 | # l_elbow_matrix = l_elbow_matrix * inv(l_shoulder_matrix) 206 | # r_elbow_matrix = r_elbow_matrix * inv(r_shoulder_matrix) 207 | # l_shoulder_matrix = l_shoulder_matrix * inv(l_shoulder_matrix) 208 | # r_shoulder_matrix = r_shoulder_matrix * inv(r_shoulder_matrix) 209 | # euler data 210 | l_shoulder_euler = R.from_matrix(l_shoulder_matrix).as_euler('zyx', degrees=True) 211 | r_shoulder_euler = R.from_matrix(r_shoulder_matrix).as_euler('zyx', degrees=True) 212 | l_elbow_euler = R.from_matrix(l_elbow_matrix).as_euler('zyx', degrees=True) 213 | r_elbow_euler = R.from_matrix(r_elbow_matrix).as_euler('zyx', degrees=True) 214 | l_wrist_euler = R.from_matrix(l_wrist_matrix).as_euler('zyx', degrees=True) 215 | r_wrist_euler = R.from_matrix(r_wrist_matrix).as_euler('zyx', degrees=True) 216 | 217 | total_frames = l_shoulder_pos.shape[0] 218 | for t in range(total_frames): 219 | data = parse_arm(l_shoulder_euler[t], l_elbow_euler[t], l_wrist_euler[t], r_shoulder_euler[t], r_elbow_euler[t], r_wrist_euler[t], 220 | l_shoulder_pos[t], l_elbow_pos[t], l_wrist_pos[t], r_shoulder_pos[t], r_elbow_pos[t], r_wrist_pos[t], 221 | l_shoulder_quat[t], l_elbow_quat[t], l_wrist_quat[t], 
r_shoulder_quat[t], r_elbow_quat[t], r_wrist_quat[t]) 222 | data_list.append(data) 223 | return data_list, l_hand_angle, r_hand_angle 224 | 225 | def parse_arm(l_shoulder_euler, l_elbow_euler, l_wrist_euler, r_shoulder_euler, r_elbow_euler, r_wrist_euler, 226 | l_shoulder_pos, l_elbow_pos, l_wrist_pos, r_shoulder_pos, r_elbow_pos, r_wrist_pos, 227 | l_shoulder_quat, l_elbow_quat, l_wrist_quat, r_shoulder_quat, r_elbow_quat, r_wrist_quat): 228 | # x 229 | x = torch.stack([torch.from_numpy(l_shoulder_euler), 230 | torch.from_numpy(l_elbow_euler), 231 | torch.from_numpy(l_wrist_euler), 232 | torch.from_numpy(r_shoulder_euler), 233 | torch.from_numpy(r_elbow_euler), 234 | torch.from_numpy(r_wrist_euler)], dim=0).float() 235 | # number of nodes 236 | num_nodes = 6 237 | # edge index 238 | edge_index = torch.LongTensor([[0, 1, 3, 4], 239 | [1, 2, 4, 5]]) 240 | # position 241 | pos = torch.stack([torch.from_numpy(l_shoulder_pos), 242 | torch.from_numpy(l_elbow_pos), 243 | torch.from_numpy(l_wrist_pos), 244 | torch.from_numpy(r_shoulder_pos), 245 | torch.from_numpy(r_elbow_pos), 246 | torch.from_numpy(r_wrist_pos)], dim=0).float() 247 | # edge attributes 248 | edge_attr = [] 249 | for edge in edge_index.permute(1, 0): 250 | parent = edge[0] 251 | child = edge[1] 252 | edge_attr.append(pos[child] - pos[parent]) 253 | edge_attr = torch.stack(edge_attr, dim=0) 254 | # skeleton type & topology type 255 | skeleton_type = 0 256 | topology_type = 0 257 | # end effector mask 258 | ee_mask = torch.zeros(num_nodes, 1).bool() 259 | ee_mask[2] = ee_mask[5] = True 260 | # shoulder mask 261 | sh_mask = torch.zeros(num_nodes, 1).bool() 262 | sh_mask[0] = sh_mask[3] = True 263 | # elbow mask 264 | el_mask = torch.zeros(num_nodes, 1).bool() 265 | el_mask[1] = el_mask[4] = True 266 | # parent 267 | parent = torch.LongTensor([-1, 0, 1, -1, 3, 4]) 268 | # offset 269 | offset = torch.zeros(num_nodes, 3) 270 | for node_idx in range(num_nodes): 271 | if parent[node_idx] != -1: 272 | offset[node_idx] = pos[node_idx] - pos[parent[node_idx]] 273 | else: 274 | offset[node_idx] = pos[node_idx] 275 | # distance to root 276 | root_dist = torch.zeros(num_nodes, 1) 277 | for node_idx in range(num_nodes): 278 | dist = 0 279 | current_idx = node_idx 280 | while current_idx != -1: 281 | origin = offset[current_idx] 282 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 283 | dist += offsets_mod 284 | current_idx = parent[current_idx] 285 | root_dist[node_idx] = dist 286 | # distance to shoulder 287 | shoulder_dist = torch.zeros(num_nodes, 1) 288 | for node_idx in range(num_nodes): 289 | dist = 0 290 | current_idx = node_idx 291 | while current_idx != -1 and current_idx != 0 and current_idx != 3: 292 | origin = offset[current_idx] 293 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 294 | dist += offsets_mod 295 | current_idx = parent[current_idx] 296 | shoulder_dist[node_idx] = dist 297 | # distance to elbow 298 | elbow_dist = torch.zeros(num_nodes, 1) 299 | for node_idx in range(num_nodes): 300 | dist = 0 301 | current_idx = node_idx 302 | while current_idx != -1 and current_idx != 1 and current_idx != 4: 303 | origin = offset[current_idx] 304 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 305 | dist += offsets_mod 306 | current_idx = parent[current_idx] 307 | elbow_dist[node_idx] = dist 308 | # quaternion 309 | q = torch.stack([torch.from_numpy(l_shoulder_quat.as_quat()), 310 | torch.from_numpy(l_elbow_quat.as_quat()), 311 | torch.from_numpy(l_wrist_quat.as_quat()), 312 | 
torch.from_numpy(r_shoulder_quat.as_quat()), 313 | torch.from_numpy(r_elbow_quat.as_quat()), 314 | torch.from_numpy(r_wrist_quat.as_quat())], dim=0).float() 315 | data = Data(x=torch.cat([x,pos], dim=-1), 316 | edge_index=edge_index, 317 | edge_attr=edge_attr, 318 | pos=pos, 319 | q=q, 320 | skeleton_type=skeleton_type, 321 | topology_type=topology_type, 322 | ee_mask=ee_mask, 323 | sh_mask=sh_mask, 324 | el_mask=el_mask, 325 | root_dist=root_dist, 326 | shoulder_dist=shoulder_dist, 327 | elbow_dist=elbow_dist, 328 | num_nodes=num_nodes, 329 | parent=parent, 330 | offset=offset) 331 | # print(data) 332 | return data 333 | 334 | """ 335 | Source Dataset for Sign Language 336 | """ 337 | class SignDataset(InMemoryDataset): 338 | def __init__(self, root, transform=None, pre_transform=None): 339 | super(SignDataset, self).__init__(root, transform, pre_transform) 340 | self.data, self.slices = torch.load(self.processed_paths[0]) 341 | 342 | @property 343 | def raw_file_names(self): 344 | data_path = os.path.join(self.root, 'h5') 345 | self._raw_file_names = [os.path.join(data_path, file) for file in os.listdir(data_path)] 346 | return self._raw_file_names 347 | 348 | @property 349 | def processed_file_names(self): 350 | return ['data.pt'] 351 | 352 | def process(self): 353 | data_list = [] 354 | for file in self.raw_file_names: 355 | data, _, _ = parse_h5(file) 356 | data_list.extend(data) 357 | 358 | if self.pre_filter is not None: 359 | data_list = [data for data in data_list if self.pre_filter(data)] 360 | 361 | if self.pre_transform is not None: 362 | data_list = [self.pre_transform(data) for data in data_list] 363 | 364 | data, slices = self.collate(data_list) 365 | torch.save((data, slices), self.processed_paths[0]) 366 | 367 | 368 | """ 369 | parse h5 with hand 370 | """ 371 | def parse_h5_hand(filename, selected_key=None): 372 | data_list = [] 373 | h5_file = h5py.File(filename, 'r') 374 | if selected_key is None: 375 | keys = h5_file.keys() 376 | else: 377 | keys = [selected_key] 378 | for key in keys: 379 | if '语句' in key and selected_key is None: 380 | print('Skipping'+key) 381 | continue 382 | # glove data 383 | l_glove_pos = h5_file[key + '/l_glove_pos'][:] 384 | r_glove_pos = h5_file[key + '/r_glove_pos'][:] 385 | # insert zero for root 386 | total_frames = l_glove_pos.shape[0] 387 | l_glove_pos = np.concatenate([np.zeros((total_frames, 1, 3)), l_glove_pos], axis=1) 388 | r_glove_pos = np.concatenate([np.zeros((total_frames, 1, 3)), r_glove_pos], axis=1) 389 | # print(l_glove_pos.shape, r_glove_pos.shape) 390 | # switch dimensions 391 | l_glove_pos = np.stack([-l_glove_pos[..., 2], -l_glove_pos[..., 1], -l_glove_pos[..., 0]], axis=-1) 392 | r_glove_pos = np.stack([-r_glove_pos[..., 2], -r_glove_pos[..., 1], -r_glove_pos[..., 0]], axis=-1) 393 | 394 | for t in range(total_frames): 395 | data = parse_glove_pos(l_glove_pos[t]) 396 | data.l_hand_x = data.x 397 | data.l_hand_edge_index = data.edge_index 398 | data.l_hand_edge_attr = data.edge_attr 399 | data.l_hand_pos = data.pos 400 | data.l_hand_ee_mask = data.ee_mask 401 | data.l_hand_el_mask = data.el_mask 402 | data.l_hand_root_dist = data.root_dist 403 | data.l_hand_elbow_dist = data.elbow_dist 404 | data.l_hand_num_nodes = data.num_nodes 405 | data.l_hand_parent = data.parent 406 | data.l_hand_offset = data.offset 407 | 408 | r_hand_data = parse_glove_pos(r_glove_pos[t]) 409 | data.r_hand_x = r_hand_data.x 410 | data.r_hand_edge_index = r_hand_data.edge_index 411 | data.r_hand_edge_attr = r_hand_data.edge_attr 412 | 
data.r_hand_pos = r_hand_data.pos 413 | data.r_hand_ee_mask = r_hand_data.ee_mask 414 | data.r_hand_el_mask = r_hand_data.el_mask 415 | data.r_hand_root_dist = r_hand_data.root_dist 416 | data.r_hand_elbow_dist = r_hand_data.elbow_dist 417 | data.r_hand_num_nodes = r_hand_data.num_nodes 418 | data.r_hand_parent = r_hand_data.parent 419 | data.r_hand_offset = r_hand_data.offset 420 | 421 | data_list.append(data) 422 | return data_list 423 | 424 | def parse_glove_pos(glove_pos): 425 | # x 426 | x = torch.from_numpy(glove_pos).float() 427 | 428 | # number of nodes 429 | num_nodes = 17 430 | 431 | # edge index 432 | edge_index = torch.LongTensor([[0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11, 0, 13, 14, 15], 433 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]]) 434 | 435 | # position 436 | pos = torch.from_numpy(glove_pos).float() 437 | 438 | # edge attributes 439 | edge_attr = [] 440 | for edge in edge_index.permute(1, 0): 441 | parent = edge[0] 442 | child = edge[1] 443 | edge_attr.append(pos[child] - pos[parent]) 444 | edge_attr = torch.stack(edge_attr, dim=0) 445 | 446 | # skeleton type & topology type 447 | skeleton_type = 0 448 | topology_type = 0 449 | 450 | # end effector mask 451 | ee_mask = torch.zeros(num_nodes, 1).bool() 452 | ee_mask[3] = ee_mask[6] = ee_mask[9] = ee_mask[12] = ee_mask[16] = True 453 | 454 | # elbow mask 455 | el_mask = torch.zeros(num_nodes, 1).bool() 456 | el_mask[1] = el_mask[4] = el_mask[7] = el_mask[10] = el_mask[13] = True 457 | 458 | # parent 459 | parent = torch.LongTensor([-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11, 0, 13, 14, 15]) 460 | 461 | # offset 462 | offset = torch.zeros(num_nodes, 3) 463 | for node_idx in range(num_nodes): 464 | if parent[node_idx] != -1: 465 | offset[node_idx] = pos[node_idx] - pos[parent[node_idx]] 466 | 467 | # distance to root 468 | root_dist = torch.zeros(num_nodes, 1) 469 | for node_idx in range(num_nodes): 470 | dist = 0 471 | current_idx = node_idx 472 | while parent[current_idx] != -1: 473 | origin = offset[current_idx] 474 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 475 | dist += offsets_mod 476 | current_idx = parent[current_idx] 477 | root_dist[node_idx] = dist 478 | 479 | # distance to elbow 480 | elbow_dist = torch.zeros(num_nodes, 1) 481 | for node_idx in range(num_nodes): 482 | dist = 0 483 | current_idx = node_idx 484 | while current_idx != -1 and not el_mask[current_idx]: 485 | origin = offset[current_idx] 486 | offsets_mod = math.sqrt(origin[0]**2+origin[1]**2+origin[2]**2) 487 | dist += offsets_mod 488 | current_idx = parent[current_idx] 489 | elbow_dist[node_idx] = dist 490 | 491 | data = Data(x=x, 492 | edge_index=edge_index, 493 | edge_attr=edge_attr, 494 | pos=pos, 495 | skeleton_type=skeleton_type, 496 | topology_type=topology_type, 497 | ee_mask=ee_mask, 498 | el_mask=el_mask, 499 | root_dist=root_dist, 500 | elbow_dist=elbow_dist, 501 | num_nodes=num_nodes, 502 | parent=parent, 503 | offset=offset) 504 | # print(data) 505 | return data 506 | 507 | 508 | """ 509 | Source Dataset for Sign Language with Hand 510 | """ 511 | class SignWithHand(InMemoryDataset): 512 | def __init__(self, root, transform=None, pre_transform=None): 513 | super(SignWithHand, self).__init__(root, transform, pre_transform) 514 | self.data, self.slices = torch.load(self.processed_paths[0]) 515 | 516 | @property 517 | def raw_file_names(self): 518 | data_path = os.path.join(self.root, 'h5') 519 | self._raw_file_names = [os.path.join(data_path, file) for file in os.listdir(data_path)] 520 | return 
self._raw_file_names 521 | 522 | @property 523 | def processed_file_names(self): 524 | return ['data.pt'] 525 | 526 | def process(self): 527 | data_list = [] 528 | for file in self.raw_file_names: 529 | data = parse_h5_hand(file) 530 | data_list.extend(data) 531 | 532 | if self.pre_filter is not None: 533 | data_list = [data for data in data_list if self.pre_filter(data)] 534 | 535 | if self.pre_transform is not None: 536 | data_list = [self.pre_transform(data) for data in data_list] 537 | 538 | data, slices = self.collate(data_list) 539 | torch.save((data, slices), self.processed_paths[0]) 540 | 541 | 542 | """ 543 | Target Dataset for Inspire Hand 544 | """ 545 | class InspireHand(InMemoryDataset): 546 | hand_cfg = { 547 | 'joints_name': [ 548 | 'yumi_link_7_r_joint', 549 | 'Link1', 550 | 'Link11', 551 | 'Link1111', 552 | 'Link2', 553 | 'Link22', 554 | 'Link2222', 555 | 'Link3', 556 | 'Link33', 557 | 'Link3333', 558 | 'Link4', 559 | 'Link44', 560 | 'Link4444', 561 | 'Link5', 562 | 'Link51', 563 | 'Link52', 564 | 'Link53', 565 | 'Link5555', 566 | ], 567 | 'edges': [ 568 | ['yumi_link_7_r_joint', 'Link1'], 569 | ['Link1', 'Link11'], 570 | ['Link11', 'Link1111'], 571 | ['yumi_link_7_r_joint', 'Link2'], 572 | ['Link2', 'Link22'], 573 | ['Link22', 'Link2222'], 574 | ['yumi_link_7_r_joint', 'Link3'], 575 | ['Link3', 'Link33'], 576 | ['Link33', 'Link3333'], 577 | ['yumi_link_7_r_joint', 'Link4'], 578 | ['Link4', 'Link44'], 579 | ['Link44', 'Link4444'], 580 | ['yumi_link_7_r_joint', 'Link5'], 581 | ['Link5', 'Link51'], 582 | ['Link51', 'Link52'], 583 | ['Link52', 'Link53'], 584 | ['Link53', 'Link5555'], 585 | ], 586 | 'root_name': 'yumi_link_7_r_joint', 587 | 'end_effectors': [ 588 | 'Link1111', 589 | 'Link2222', 590 | 'Link3333', 591 | 'Link4444', 592 | 'Link5555', 593 | ], 594 | 'elbows': [ 595 | 'Link1', 596 | 'Link2', 597 | 'Link3', 598 | 'Link4', 599 | 'Link5', 600 | ], 601 | } 602 | def __init__(self, root, transform=None, pre_transform=None): 603 | super(InspireHand, self).__init__(root, transform, pre_transform) 604 | self.data, self.slices = torch.load(self.processed_paths[0]) 605 | 606 | @property 607 | def raw_file_names(self): 608 | self._raw_file_names = [os.path.join(self.root, file) for file in os.listdir(self.root) if file.endswith('.urdf')] 609 | return self._raw_file_names 610 | 611 | @property 612 | def processed_file_names(self): 613 | return ['data.pt'] 614 | 615 | def process(self): 616 | data_list = [] 617 | for file in self.raw_file_names: 618 | data_list.append(hand2graph(file, self.hand_cfg)) 619 | 620 | if self.pre_filter is not None: 621 | data_list = [data for data in data_list if self.pre_filter(data)] 622 | 623 | if self.pre_transform is not None: 624 | data_list = [self.pre_transform(data) for data in data_list] 625 | 626 | data, slices = self.collate(data_list) 627 | torch.save((data, slices), self.processed_paths[0]) 628 | 629 | 630 | """ 631 | parse h5 with all data 632 | """ 633 | def parse_all(filename, selected_key=None): 634 | data_list = [] 635 | h5_file = h5py.File(filename, 'r') 636 | if selected_key is None: 637 | keys = h5_file.keys() 638 | else: 639 | keys = [selected_key] 640 | for key in keys: 641 | if '语句' in key and selected_key is None: 642 | print('Skipping'+key) 643 | continue 644 | # position data 645 | l_shoulder_pos = h5_file[key + '/l_up_pos'][:] 646 | r_shoulder_pos = h5_file[key + '/r_up_pos'][:] 647 | l_elbow_pos = h5_file[key + '/l_fr_pos'][:] 648 | r_elbow_pos = h5_file[key + '/r_fr_pos'][:] 649 | l_wrist_pos = h5_file[key + 
'/l_hd_pos'][:] 650 | r_wrist_pos = h5_file[key + '/r_hd_pos'][:] 651 | # quaternion data 652 | l_shoulder_quat = R.from_quat(h5_file[key + '/l_up_quat'][:]) 653 | r_shoulder_quat = R.from_quat(h5_file[key + '/r_up_quat'][:]) 654 | l_elbow_quat = R.from_quat(h5_file[key + '/l_fr_quat'][:]) 655 | r_elbow_quat = R.from_quat(h5_file[key + '/r_fr_quat'][:]) 656 | l_wrist_quat = R.from_quat(h5_file[key + '/l_hd_quat'][:]) 657 | r_wrist_quat = R.from_quat(h5_file[key + '/r_hd_quat'][:]) 658 | # rotation matrix data 659 | l_shoulder_matrix = l_shoulder_quat.as_matrix() 660 | r_shoulder_matrix = r_shoulder_quat.as_matrix() 661 | l_elbow_matrix = l_elbow_quat.as_matrix() 662 | r_elbow_matrix = r_elbow_quat.as_matrix() 663 | l_wrist_matrix = l_wrist_quat.as_matrix() 664 | r_wrist_matrix = r_wrist_quat.as_matrix() 665 | # transform to local coordinates 666 | # l_wrist_matrix = l_wrist_matrix * inv(l_elbow_matrix) 667 | # r_wrist_matrix = r_wrist_matrix * inv(r_elbow_matrix) 668 | # l_elbow_matrix = l_elbow_matrix * inv(l_shoulder_matrix) 669 | # r_elbow_matrix = r_elbow_matrix * inv(r_shoulder_matrix) 670 | # l_shoulder_matrix = l_shoulder_matrix * inv(l_shoulder_matrix) 671 | # r_shoulder_matrix = r_shoulder_matrix * inv(r_shoulder_matrix) 672 | # euler data 673 | l_shoulder_euler = R.from_matrix(l_shoulder_matrix).as_euler('zyx', degrees=True) 674 | r_shoulder_euler = R.from_matrix(r_shoulder_matrix).as_euler('zyx', degrees=True) 675 | l_elbow_euler = R.from_matrix(l_elbow_matrix).as_euler('zyx', degrees=True) 676 | r_elbow_euler = R.from_matrix(r_elbow_matrix).as_euler('zyx', degrees=True) 677 | l_wrist_euler = R.from_matrix(l_wrist_matrix).as_euler('zyx', degrees=True) 678 | r_wrist_euler = R.from_matrix(r_wrist_matrix).as_euler('zyx', degrees=True) 679 | # glove data 680 | l_glove_pos = h5_file[key + '/l_glove_pos'][:] 681 | r_glove_pos = h5_file[key + '/r_glove_pos'][:] 682 | # insert zero for root 683 | total_frames = l_glove_pos.shape[0] 684 | l_glove_pos = np.concatenate([np.zeros((total_frames, 1, 3)), l_glove_pos], axis=1) 685 | r_glove_pos = np.concatenate([np.zeros((total_frames, 1, 3)), r_glove_pos], axis=1) 686 | # print(l_glove_pos.shape, r_glove_pos.shape) 687 | # switch dimensions 688 | l_glove_pos = np.stack([-l_glove_pos[..., 2], -l_glove_pos[..., 1], -l_glove_pos[..., 0]], axis=-1) 689 | r_glove_pos = np.stack([-r_glove_pos[..., 2], -r_glove_pos[..., 1], -r_glove_pos[..., 0]], axis=-1) 690 | 691 | for t in range(total_frames): 692 | data = Data() 693 | l_hand_data = parse_glove_pos(l_glove_pos[t]) 694 | data.l_hand_x = l_hand_data.x 695 | data.l_hand_edge_index = l_hand_data.edge_index 696 | data.l_hand_edge_attr = l_hand_data.edge_attr 697 | data.l_hand_pos = l_hand_data.pos 698 | data.l_hand_ee_mask = l_hand_data.ee_mask 699 | data.l_hand_el_mask = l_hand_data.el_mask 700 | data.l_hand_root_dist = l_hand_data.root_dist 701 | data.l_hand_elbow_dist = l_hand_data.elbow_dist 702 | data.l_hand_num_nodes = l_hand_data.num_nodes 703 | data.l_hand_parent = l_hand_data.parent 704 | data.l_hand_offset = l_hand_data.offset 705 | 706 | r_hand_data = parse_glove_pos(r_glove_pos[t]) 707 | data.r_hand_x = r_hand_data.x 708 | data.r_hand_edge_index = r_hand_data.edge_index 709 | data.r_hand_edge_attr = r_hand_data.edge_attr 710 | data.r_hand_pos = r_hand_data.pos 711 | data.r_hand_ee_mask = r_hand_data.ee_mask 712 | data.r_hand_el_mask = r_hand_data.el_mask 713 | data.r_hand_root_dist = r_hand_data.root_dist 714 | data.r_hand_elbow_dist = r_hand_data.elbow_dist 715 | data.r_hand_num_nodes = 
741 | """
742 | Source Dataset for Sign Language with Arm and Hand Data
743 | """
744 | class SignAll(InMemoryDataset):
745 |     def __init__(self, root, transform=None, pre_transform=None):
746 |         super(SignAll, self).__init__(root, transform, pre_transform)
747 |         self.data, self.slices = torch.load(self.processed_paths[0])
748 | 
749 |     @property
750 |     def raw_file_names(self):
751 |         data_path = os.path.join(self.root, 'h5')
752 |         self._raw_file_names = [os.path.join(data_path, file) for file in os.listdir(data_path)]
753 |         return self._raw_file_names
754 | 
755 |     @property
756 |     def processed_file_names(self):
757 |         return ['data.pt']
758 | 
759 |     def process(self):
760 |         data_list = []
761 |         for file in self.raw_file_names:
762 |             data = parse_all(file)
763 |             data_list.extend(data)
764 | 
765 |         if self.pre_filter is not None:
766 |             data_list = [data for data in data_list if self.pre_filter(data)]
767 | 
768 |         if self.pre_transform is not None:
769 |             data_list = [self.pre_transform(data) for data in data_list]
770 | 
771 |         data, slices = self.collate(data_list)
772 |         torch.save((data, slices), self.processed_paths[0])
773 | 
774 | 
775 | """
776 | Target Dataset for YuMi with Inspire Hands
777 | """
778 | class YumiAll(InMemoryDataset):
779 |     def __init__(self, root, transform=None, pre_transform=None):
780 |         super(YumiAll, self).__init__(root, transform, pre_transform)
781 |         self.data, self.slices = torch.load(self.processed_paths[0])
782 | 
783 |     @property
784 |     def raw_file_names(self):
785 |         self._raw_file_names = [os.path.join(self.root, file) for file in os.listdir(self.root) if file.endswith('.urdf')]
786 |         return self._raw_file_names
787 | 
788 |     @property
789 |     def processed_file_names(self):
790 |         return ['data.pt']
791 | 
792 |     def process(self):
793 |         data_list = []
794 |         for file in self.raw_file_names:
795 |             data = yumi2graph(file, YumiDataset.yumi_cfg)
796 |             hand_data = hand2graph(file, InspireHand.hand_cfg)
797 |             data.hand_x = hand_data.x
798 |             data.hand_edge_index = hand_data.edge_index
799 |             data.hand_edge_attr = hand_data.edge_attr
800 |             data.hand_ee_mask = hand_data.ee_mask
801 |             data.hand_el_mask = hand_data.el_mask
802 |             data.hand_root_dist = hand_data.root_dist
803 |             data.hand_elbow_dist = hand_data.elbow_dist
804 |             data.hand_num_nodes = hand_data.num_nodes
805 |             data.hand_parent = hand_data.parent
806 |             data.hand_offset = hand_data.offset
807 |             data.hand_axis = hand_data.axis
808 |             data.hand_lower = hand_data.lower
809 |             data.hand_upper = hand_data.upper
810 |             data_list.append(data)
811 | 
812 |         if self.pre_filter is not None:
813 |             data_list = [data for data in data_list if self.pre_filter(data)]
814 | 
815 |         if self.pre_transform is not None:
816 |             data_list = [self.pre_transform(data) for data in data_list]
817 | 
818 |         data, slices = self.collate(data_list)
819 |         torch.save((data, slices), self.processed_paths[0])
820 | 
821 | 
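# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the original file): once
# process() has written data.pt under <root>/processed/, each dataset behaves
# like a normal InMemoryDataset and can be batched with PyG's DataLoader. The
# paths mirror configs/train/yumi.yaml; the import is torch_geometric.data on
# PyG 1.x (use torch_geometric.loader.DataLoader on 2.x). Note that the custom
# *_hand_edge_index attributes may need follow_batch / __inc__ handling to be
# incremented correctly when batching.
def _loader_sketch():
    from torch_geometric.data import DataLoader
    source = SignAll(root='./data/source/sign-all/train',
                     pre_transform=transforms.Compose([Normalize()]))
    loader = DataLoader(source, batch_size=16, shuffle=True)
    return next(iter(loader))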
822 | if __name__ == '__main__':
823 |     yumi_dataset = YumiDataset(root='./data/target/yumi')
824 |     sign_dataset = SignDataset(root='./data/source/sign/train', pre_transform=transforms.Compose([Normalize()]))
825 |     inspire_hand = InspireHand(root='./data/target/yumi-with-hands')
826 |     sign_with_hand = SignWithHand(root='./data/source/sign-hand/train', pre_transform=transforms.Compose([Normalize()]))
827 |     sign_all = SignAll(root='./data/source/sign-all/train', pre_transform=transforms.Compose([Normalize()]))
828 |     yumi_all = YumiAll(root='./data/target/yumi-all')
829 | 
--------------------------------------------------------------------------------