├── util
│   ├── utils_utils.py
│   ├── grab.py
│   ├── loss_func.py
│   ├── opt.py
│   └── data_utils.py
├── .gitattributes
├── model_others
│   ├── GCN.py
│   └── EAI.py
├── run.sh
├── run_train.sh
├── requirement.yaml
├── README.md
├── test.py
└── train.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | 
3 | source ~/.bashrc
4 | 
5 | if [ -z "$GPU_NUM" ]; then
6 |     GPU_NUM=1
7 | fi
8 | 
9 | export MODEL_TYPE=EAI
10 | export LR_ONE_EPOCH=0.001
11 | export BATCHSIZE_ONE_EPOCH=64
12 | 
13 | 
14 | echo "GPU_NUM : " $GPU_NUM
15 | echo "MODEL_TYPE : " $MODEL_TYPE
16 | echo "LR : " $LR_ONE_EPOCH
17 | echo "BATCHSIZE_ONE_EPOCH : " $BATCHSIZE_ONE_EPOCH
18 | 
19 | EXPID=TRAIN_modeltype_${MODEL_TYPE}_batchsize_${BATCHSIZE_ONE_EPOCH}_lr_$LR_ONE_EPOCH
20 | TESTEXPID=TEST_modeltype_${MODEL_TYPE}_batchsize_${BATCHSIZE_ONE_EPOCH}_lr_$LR_ONE_EPOCH
21 | echo "EXPID : " $EXPID
22 | 
23 | python -u \
24 |     test.py \
25 |     --input_n 30 \
26 |     --output_n 30 \
27 |     --all_n 60 \
28 |     --lr $LR_ONE_EPOCH \
29 |     --train_batch $((BATCHSIZE_ONE_EPOCH*GPU_NUM)) \
30 |     --model_type $MODEL_TYPE \
31 |     --is_exp \
32 |     --is_using_saved_file \
33 |     --exp $EXPID \
34 |     --is_using_noTpose2
35 | 
--------------------------------------------------------------------------------
/util/utils_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | 
4 | 
5 | class AccumLoss(object):
6 |     def __init__(self):
7 |         self.val = 0
8 |         self.avg = 0
9 |         self.sum = 0
10 |         self.count = 0
11 | 
12 |     def update(self, val, n=1):
13 |         self.val = val
14 |         self.sum += val
15 |         self.count += n
16 |         self.avg = self.sum / self.count
17 | 
18 | 
19 | def lr_decay(optimizer, lr_now, gamma):
20 |     lr = lr_now * gamma
21 |     for param_group in optimizer.param_groups:
22 |         param_group['lr'] = lr
23 |     return lr
24 | 
25 | def save_ckpt(state, ckpt_path, is_best=True, file_name=['ckpt_best.pth.tar', 'ckpt_last.pth.tar']):
26 |     file_path = os.path.join(ckpt_path, file_name[1])
27 |     torch.save(state, file_path)
28 |     if is_best:
29 |         file_path = os.path.join(ckpt_path, file_name[0])
30 |         torch.save(state, file_path)
31 | 
--------------------------------------------------------------------------------
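A minimal usage sketch for the helpers in `util/utils_utils.py` above (the dummy model is illustrative only): `AccumLoss` keeps a running sum and per-sample mean, `lr_decay` rescales every optimizer parameter group, and `save_ckpt` always writes the "last" checkpoint, plus the "best" one when `is_best` is set.

```python
import torch
from util.utils_utils import AccumLoss, lr_decay, save_ckpt

meter = AccumLoss()
for batch_loss in [0.9, 0.7, 0.5]:
    meter.update(batch_loss * 32, n=32)            # accumulate a batch of 32 samples
print(meter.avg)                                   # per-sample running mean -> 0.7

model = torch.nn.Linear(4, 4)                      # dummy model, illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr = lr_decay(optimizer, lr_now=1e-3, gamma=0.96)  # every group's lr -> 0.00096
save_ckpt({'state_dict': model.state_dict(), 'lr': lr},
          ckpt_path='.', is_best=False)            # writes ./ckpt_last.pth.tar only
```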
/run_train.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | 
3 | source ~/.bashrc
4 | 
5 | function rand(){
6 |     min=$1
7 |     max=$(($2-$min+1))
8 |     num=$(date +%s%N)
9 |     echo $(($num%$max+$min))
10 | }
11 | 
12 | if [ -z "$GPU_NUM" ]; then
13 |     GPU_NUM=1
14 | fi
15 | 
16 | export MODEL_TYPE=EAI
17 | export LR_ONE_EPOCH=0.001
18 | export BATCHSIZE_ONE_EPOCH=64
19 | export DATAPATH=./Dataset_GRAB
20 | 
21 | echo "GPU_NUM : " $GPU_NUM
22 | echo "MODEL_TYPE : " $MODEL_TYPE
23 | echo "LR : " $LR_ONE_EPOCH
24 | echo "BATCHSIZE_ONE_EPOCH : " $BATCHSIZE_ONE_EPOCH
25 | echo "DATAPATH : " $DATAPATH
26 | 
27 | rnd=$(($(rand 1 1000)+4231))
28 | echo "RANDOM_PORT : " $rnd
29 | 
30 | EXPID=TRAIN_modeltype_${MODEL_TYPE}_batchsize_${BATCHSIZE_ONE_EPOCH}_lr_$LR_ONE_EPOCH
31 | TESTEXPID=TEST_modeltype_${MODEL_TYPE}_batchsize_${BATCHSIZE_ONE_EPOCH}_lr_$LR_ONE_EPOCH
32 | echo "EXPID : " $EXPID
33 | 
34 | python -u -m torch.distributed.launch \
35 |     --nproc_per_node=$GPU_NUM \
36 |     --master_port=$rnd \
37 |     train.py \
38 |     --input_n 30 \
39 |     --output_n 30 \
40 |     --all_n 60 \
41 |     --lr $LR_ONE_EPOCH \
42 |     --train_batch $((BATCHSIZE_ONE_EPOCH*GPU_NUM)) \
43 |     --model_type $MODEL_TYPE \
44 |     --grab_data_dict $DATAPATH \
45 |     --is_exp \
46 |     --is_using_saved_file \
47 |     --is_using_noTpose2 \
48 |     --is_boneloss \
49 |     --exp $EXPID
50 | 
--------------------------------------------------------------------------------
/util/grab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | from tqdm import tqdm
5 | 
6 | 
7 | class Grab(Dataset):
8 |     def __init__(self, path_to_data, input_n, output_n, split=0, using_saved_file=True, using_noTpose2=False, norm=True, debug=False, opt=None, using_raw=False):
9 |         tra_val_test = ['train', 'val', 'test']
10 |         pad_idx = np.repeat([input_n - 1], output_n)  # repeat the last observed frame output_n times
11 |         i_idx = np.append(np.arange(0, input_n), pad_idx)
12 | 
13 |         data_size = {}
14 |         data_size[0] = (176384, 60, 55, 3)
15 |         data_size[1] = (52255, 60, 55, 3)
16 | 
17 |         if using_saved_file:
18 |             if split==2:
19 |                 sampled_seq = np.load('{}/grab_{}.npy'.format(path_to_data,tra_val_test[split]))
20 |             else:
21 |                 # print('>>> remove the first and last 1 second')
22 |                 # sampled_seq = np.load('{}/grab_dataloader_normalized_noTpose2_{}.npy'.format(path_to_data,tra_val_test[split]))
23 |                 tmp_bin_size = data_size[split]
24 |                 tmp_seq = np.memmap('{}/grab_dataloader_normalized_noTpose2_{}.bin'.format(path_to_data,tra_val_test[split]), dtype=np.float32, shape=tmp_bin_size)
25 |                 tem_res = np.frombuffer(tmp_seq, dtype=np.float32)
26 |                 sampled_seq = tem_res.reshape(tmp_bin_size)
27 | 
28 |         self.input_pose = torch.from_numpy(sampled_seq[:, i_idx])
29 |         print("input", self.input_pose.shape)
30 |         self.target_pose = torch.from_numpy(sampled_seq)
31 |         print("target", self.target_pose.shape)
32 | 
33 |         import gc
34 |         del sampled_seq
35 |         gc.collect()
36 |         return
37 | 
38 |     def gen_data(self):  # NOTE: unused; batch_samples starts empty, so this generator never yields
39 |         for input in self.input_pose:
40 |             batch_samples = []
41 |             while len(batch_samples) > 0:
42 |                 yield batch_samples.pop()
43 |     def __len__(self):
44 |         return np.shape(self.input_pose)[0]
45 | 
46 |     def __getitem__(self, item):
47 |         return self.input_pose[item], self.target_pose[item]
48 | 
--------------------------------------------------------------------------------
/requirement.yaml:
--------------------------------------------------------------------------------
1 | name: xcai
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - _libgcc_mutex=0.1=main
6 |   - _openmp_mutex=5.1=1_gnu
7 |   - ca-certificates=2022.07.19=h06a4308_0
8 |   - certifi=2022.6.15=py38h06a4308_0
9 |   - ld_impl_linux-64=2.38=h1181459_1
10 |   - libffi=3.3=he6710b0_2
11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h5eee18b_3 15 | - openssl=1.1.1q=h7f8727e_0 16 | - pip=22.1.2=py38h06a4308_0 17 | - python=3.8.13=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=63.4.1=py38h06a4308_0 20 | - sqlite=3.39.2=h5082296_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.5=h7f8727e_1 24 | - zlib=1.2.12=h5eee18b_3 25 | - pip: 26 | - asttokens==2.1.0 27 | - backcall==0.2.0 28 | - contourpy==1.0.5 29 | - cycler==0.11.0 30 | - debugpy==1.6.3 31 | - decorator==5.1.1 32 | - entrypoints==0.4 33 | - executing==1.2.0 34 | - fonttools==4.37.2 35 | - imageio==2.22.0 36 | - ipykernel==6.17.1 37 | - ipython==8.6.0 38 | - jedi==0.18.1 39 | - jupyter-client==7.4.7 40 | - jupyter-core==5.0.0 41 | - kiwisolver==1.4.4 42 | - matplotlib==3.3.0 43 | - matplotlib-inline==0.1.6 44 | - nest-asyncio==1.5.6 45 | - numpy==1.23.3 46 | - packaging==21.3 47 | - pandas==1.4.4 48 | - parso==0.8.3 49 | - pbd==0.9 50 | - pexpect==4.8.0 51 | - pickleshare==0.7.5 52 | - pillow==9.2.0 53 | - platformdirs==2.5.4 54 | - prefetch-generator==1.0.1 55 | - progress==1.6 56 | - prompt-toolkit==3.0.32 57 | - psutil==5.9.4 58 | - ptyprocess==0.7.0 59 | - pure-eval==0.2.2 60 | - pygments==2.13.0 61 | - pyparsing==3.0.9 62 | - python-dateutil==2.8.2 63 | - pytz==2022.2.1 64 | - pyzmq==24.0.1 65 | - six==1.16.0 66 | - stack-data==0.6.1 67 | - torch==1.8.0 68 | - torchvision==0.9.0 69 | - tornado==6.2 70 | - tqdm==4.64.1 71 | - traitlets==5.5.0 72 | - typing-extensions==4.3.0 73 | - wcwidth==0.2.5 74 | prefix: /data/dingpengxiang/.conda/envs/xcai 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Expressive Forecasting of 3D Whole-body Human Motions (AAAI2024) 2 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2312.11972) 3 | 4 | Pengxiang Ding, Qiongjie Cui, Min Zhang, Mengyuan Liu, Haofan Wang, Donglin Wang 5 | 6 | 7 | 8 | 9 | ## Abstract 10 | Human motion forecasting, with the goal of estimating future human behavior over a period of time, is a fundamental task in many real-world applications. 11 | Existing work typically concentrates on foretelling the major joints of the human body without considering the delicate movements of the human hands. 12 | In practical applications, hand gestures play an important role in human communication with the real world, and express the primary intentions of human beings. 13 | In this work, we propose a new Encoding-Alignment-Interaction (EAI) framework to address expressive forecasting of 3D whole-body human motions, which aims to predict coarse- (body joints) and fine-grained (gestures) activities cooperatively. 14 | To our knowledge, this meaningful topic has not been explored before. 15 | Specifically, our model mainly involves two key constituents: cross-context alignment (XCA) and cross-context interaction (XCI). 16 | Considering the heterogeneous information within the whole-body, the former aims to align the latent features of various human components, while the latter focuses on effectively capturing the 17 | context interaction among the human components. 18 | We conduct extensive experiments on a newly-introduced large-scale benchmark and achieve state-of-the-art performance. 19 | 20 | 21 | ## Installation 22 | 1. 
Clone this repository 23 | `$ git clone https://github.com/Dingpx/EAI.git` 24 | 25 | 2. Initialize conda environment 26 | `$ conda env create -f requirement.yaml` 27 | 28 | ## Datasets 29 | ### GRAB data 30 | Updated: You can download our [processed data](https://drive.google.com/drive/folders/1o5wfHCkCTwOJrXs8dhGoRFoE1y4q42CO?usp=drive_link) 31 | 32 | TODO: 33 | - The whole process of [GRAB](https://grab.is.tue.mpg.de/) will be updated soon. 34 | 35 | 36 | 37 | ## Training 38 | Run `$ bash run_train.sh`. 39 | 40 | ## Evaluation 41 | Run `$ bash run.sh`. 42 | 43 | 44 | ## Cite our work: 45 | ``` 46 | @article{ding2023expressive, 47 | title={Expressive Forecasting of 3D Whole-body Human Motions}, 48 | author={Ding, Pengxiang and Cui, Qiongjie and Zhang, Min and Liu, Mengyuan and Wang, Haofan and Wang, Donglin}, 49 | journal={arXiv preprint arXiv:2312.11972}, 50 | year={2023} 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /model_others/GCN.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | 5 | 6 | import torch.nn as nn 7 | import torch 8 | from torch.nn.parameter import Parameter 9 | import math 10 | 11 | 12 | 13 | class GraphConvolution(nn.Module): 14 | 15 | def __init__(self, in_features, out_features, bias=True, node_n=48): 16 | super(GraphConvolution, self).__init__() 17 | self.in_features = in_features 18 | self.out_features = out_features 19 | 20 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 21 | self.att = Parameter(torch.FloatTensor(node_n, node_n)) 22 | if bias: 23 | self.bias = Parameter(torch.FloatTensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | stdv = 1. 
/ math.sqrt(self.weight.size(1))
30 |         self.weight.data.uniform_(-stdv, stdv)
31 |         self.att.data.uniform_(-stdv, stdv)
32 |         if self.bias is not None:
33 |             self.bias.data.uniform_(-stdv, stdv)
34 | 
35 |     def forward(self, x):
36 |         support = torch.matmul(x, self.weight)  # mix the feature dimension
37 |         y = torch.matmul(self.att, support)  # learnable dense adjacency mixes the nodes
38 |         if self.bias is not None:
39 |             return y + self.bias
40 |         else:
41 |             return y
42 |     def __repr__(self):
43 |         return self.__class__.__name__ + ' (' \
44 |                + str(self.in_features) + ' -> ' \
45 |                + str(self.out_features) + ')'
46 | 
47 | 
48 | 
49 | class GC_Block(nn.Module):
50 |     def __init__(self, in_features, p_dropout, bias=True, node_n=48):
51 | 
52 |         super(GC_Block, self).__init__()
53 |         self.in_features = in_features
54 |         self.out_features = in_features
55 |         self.gc1 = GraphConvolution(in_features, in_features, node_n=node_n, bias=bias)
56 |         self.bn1 = nn.BatchNorm1d(node_n * in_features)
57 | 
58 |         self.gc2 = GraphConvolution(in_features, in_features, node_n=node_n, bias=bias)
59 |         self.bn2 = nn.BatchNorm1d(node_n * in_features)
60 |         self.do = nn.Dropout(p_dropout)
61 |         self.act_f = nn.Tanh()
62 | 
63 |     def forward(self, x):
64 |         y = self.gc1(x)
65 |         b, n, f = y.shape
66 |         y = self.bn1(y.view(b, -1)).view(b, n, f)
67 |         y = self.act_f(y)
68 |         y = self.do(y)
69 | 
70 |         y = self.gc2(y)
71 |         b, n, f = y.shape
72 |         y = self.bn2(y.view(b, -1)).view(b, n, f)
73 |         y = self.act_f(y)
74 |         y = self.do(y)
75 | 
76 |         return y + x  # residual connection
77 | 
78 |     def __repr__(self):
79 |         return self.__class__.__name__ + ' (' \
80 |                + str(self.in_features) + ' -> ' \
81 |                + str(self.out_features) + ')'
82 | 
83 | 
84 | class GCN(nn.Module):
85 |     def __init__(self, input_feature, hidden_feature, p_dropout, num_stage=1, node_n=48):
86 | 
87 |         super(GCN, self).__init__()
88 |         self.num_stage = num_stage
89 |         self.gc1 = GraphConvolution(input_feature, hidden_feature, node_n=node_n)
90 |         self.bn1 = nn.BatchNorm1d(node_n * hidden_feature)
91 | 
92 |         self.gcbs = []
93 |         for i in range(num_stage):
94 |             self.gcbs.append(GC_Block(hidden_feature, p_dropout=p_dropout, node_n=node_n))
95 | 
96 |         self.gcbs = nn.ModuleList(self.gcbs)
97 |         self.gc7 = GraphConvolution(hidden_feature, input_feature, node_n=node_n)
98 |         self.do = nn.Dropout(p_dropout)
99 |         self.act_f = nn.Tanh()
100 | 
101 |     def forward(self, x):
102 |         y = self.gc1(x)
103 |         b, n, f = y.shape
104 |         y = self.bn1(y.view(b, -1)).view(b, n, f)
105 |         y = self.act_f(y)
106 |         y = self.do(y)
107 |         for i in range(self.num_stage):
108 |             y = self.gcbs[i](y)
109 |         y = self.gc7(y)
110 |         y = y + x  # global residual
111 | 
112 |         return y
113 | 
--------------------------------------------------------------------------------
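For orientation, a minimal forward-pass sketch of the `GCN` backbone above. The 165-node graph (55 GRAB joints × 3 coordinates) and the 60 input features (DCT coefficients for `all_n = 60` frames) are assumptions mirroring how the scripts configure the models; EAI itself splits the graph into body and hand parts via `GCN_EAI`.

```python
import torch
from model_others.GCN import GCN

model = GCN(input_feature=60, hidden_feature=256, p_dropout=0.5,
            num_stage=12, node_n=165).eval()   # eval(): BatchNorm1d uses running stats
x = torch.randn(2, 165, 60)                    # (batch, nodes, DCT coefficients)
y = model(x)                                   # 12 residual GC blocks + global skip
print(y.shape)                                 # torch.Size([2, 165, 60])
```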
/util/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .data_utils import Grab_Skeleton_55
4 | 
5 | 
6 | def poses_loss(y_out, out_poses):
7 |     y_out = y_out.reshape(-1, 3)
8 |     out_poses = out_poses.reshape(-1, 3)
9 |     loss = torch.mean(torch.norm(y_out - out_poses, 2, 1))
10 |     return loss
11 | 
12 | def joint_loss(y_out, out_poses):
13 |     y_out = y_out.reshape(-1, 3)
14 |     out_poses = out_poses.reshape(-1, 3)
15 |     return torch.mean(torch.norm(y_out - out_poses, 2, 1))
16 | 
17 | def relative_hand_loss(y_out, out_poses):
18 | 
19 |     y_out_rel_lhand = y_out[:,:,25:40,:] - y_out[:,:,20:21,:]
20 |     y_out_rel_rhand = y_out[:,:,40:,:] - y_out[:,:,21:22,:]
21 | 
22 |     out_poses_rel_lhand = out_poses[:,:,25:40,:] - out_poses[:,:,20:21,:]
23 |     out_poses_rel_rhand = out_poses[:,:,40:,:] - out_poses[:,:,21:22,:]
24 | 
25 |     loss_rel_lhand = joint_loss(y_out_rel_lhand, out_poses_rel_lhand)
26 |     loss_rel_rhand = joint_loss(y_out_rel_rhand, out_poses_rel_rhand)
27 | 
28 |     return loss_rel_lhand + loss_rel_rhand
29 | 
30 | def joint_body_loss(y_out, out_poses):
31 | 
32 |     y_out_wrist = y_out[:,:,20:22,:]
33 |     out_poses_wrist = out_poses[:,:,20:22,:]
34 | 
35 |     y_out_wrist = y_out_wrist.reshape(-1, 3)
36 |     out_poses_wrist = out_poses_wrist.reshape(-1, 3)
37 | 
38 |     l_wrist = torch.mean(torch.norm(y_out_wrist - out_poses_wrist, 2, 1))
39 | 
40 |     y_out = y_out.reshape(-1, 3)
41 |     out_poses = out_poses.reshape(-1, 3)
42 |     return torch.mean(torch.norm(y_out - out_poses, 2, 1)), l_wrist
43 | 
44 | def joint_body_loss_test(y_out, out_poses):
45 | 
46 |     y_out_wrist = y_out[:,20:22,:]
47 |     out_poses_wrist = out_poses[:,20:22,:]
48 | 
49 |     y_out_wrist = y_out_wrist.reshape(-1, 3)
50 |     out_poses_wrist = out_poses_wrist.reshape(-1, 3)
51 | 
52 |     l_wrist = torch.mean(torch.norm(y_out_wrist - out_poses_wrist, 2, 1))
53 | 
54 |     y_out = y_out.reshape(-1, 3)
55 |     out_poses = out_poses.reshape(-1, 3)
56 |     return torch.mean(torch.norm(y_out - out_poses, 2, 1)), l_wrist
57 | 
58 | def bone_length_error(joints, input_bone_lengths, skeleton_cls):
59 |     bone_lengths = calculate_bone_lengths(joints, skeleton_cls)
60 |     return np.sum(np.abs(np.array(input_bone_lengths) - bone_lengths))
61 | 
62 | def calculate_bone_lengths(joints, skeleton_cls):
63 |     return np.array([np.linalg.norm(joints[bone[0]] - joints[bone[1]] + 0.001) for bone in skeleton_cls.bones])  # 0.001 keeps the norm away from exactly zero
64 | 
65 | def bone_loss(raw, predict, device):
66 | 
67 |     raw_bone_length = cal_bone_loss(raw, device)
68 |     pred_bone_length = cal_bone_loss(predict, device)
69 | 
70 |     diff = torch.abs(pred_bone_length - raw_bone_length)
71 |     loss = torch.mean(diff)
72 | 
73 |     return loss
74 | 
75 | def cal_bone_loss(x, device):
76 |     # KCS-style bone representation: one difference vector per bone
77 |     batch_num = x.size()[0]
78 |     frame_num = x.size()[1]
79 |     joint_num = x.size()[2]
80 | 
81 |     Ct = get_matrix(device)
82 | 
83 |     x_ = x.transpose(2, 3)  # b, t, 3, 55
84 |     x_ = torch.matmul(x_, Ct)  # b, t, 3, 54
85 |     bone_length = torch.norm(x_, 2, 2)  # b, t, 54
86 | 
87 |     return bone_length
88 | 
89 | def get_matrix(device, type='all'):
90 | 
91 |     S_of_lhand = [20, 20, 20, 20, 20, 37, 38, 25, 26, 28, 29, 34, 35, 31, 32]
92 |     S_of_rhand = [21, 21, 21, 21, 21, 52, 53, 40, 41, 43, 44, 49, 50, 46, 47]
93 |     S_of_body = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 18, 17, 19, 15, 15, 15]
94 |     E_of_lhand = [37, 25, 28, 34, 31, 38, 39, 26, 27, 29, 30, 35, 36, 32, 33]
95 |     E_of_rhand = [40, 43, 49, 46, 52, 53, 54, 41, 42, 44, 45, 50, 51, 47, 48]
96 |     E_of_body = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 19, 21, 22, 23, 24]
97 | 
98 |     if type == 'all':
99 |         E = np.hstack((E_of_body, E_of_lhand, E_of_rhand))
100 |         S = np.hstack((S_of_body, S_of_lhand, S_of_rhand))
101 |         matrix = torch.zeros([55, 54])
102 |     elif type == 'lhand':
103 |         E = np.asarray(E_of_lhand)  # asarray: plain lists have no .shape for the loop below
104 |         S = np.asarray(S_of_lhand)
105 |         matrix = torch.zeros([55, 15])
106 |     elif type == 'rhand':
107 |         E = np.asarray(E_of_rhand)
108 |         S = np.asarray(S_of_rhand)
109 |         matrix = torch.zeros([55, 15])
110 |     elif type == 'body':
111 |         E = np.asarray(E_of_body)
112 |         S = np.asarray(S_of_body)
113 |         matrix = torch.zeros([55, 24])
114 | 
115 |     for i in range(S.shape[0]):
116 |         matrix[S[i].tolist(), i] = 1
117 |         matrix[E[i].tolist(), i] = -1
118 | 
119 |     matrix = matrix.to(device)
120 | 
121 |     return matrix
122 | 
--------------------------------------------------------------------------------
/util/opt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pprint import pprint
4 | 
5 | 
6 | class Options:
7 |     def __init__(self):
8 |         self.parser = argparse.ArgumentParser()
9 |         self.opt = None
10 | 
11 |     def _initial(self):
12 |         # ===============================================================
13 |         #                     General options
14 |         # ===============================================================
15 | 
16 |         self.parser.add_argument('--device', type=str, default='cuda:0',
17 |                                  help='device to run on')
18 |         self.parser.add_argument('--grab_data_dict', type=str, default='./Dataset_GRAB/', help='path to GRAB dataset')
19 |         self.parser.add_argument('--exp', type=str, default='test', help='ID of experiment')
20 |         self.parser.add_argument('--ckpt', type=str, default='checkpoint/', help='path to save checkpoint')
21 |         self.parser.add_argument('--model_type', type=str, default='LTD', help='model type')
22 | 
23 | 
24 |         # ===============================================================
25 |         #                     Model options
26 |         # ===============================================================
27 |         self.parser.add_argument('--max_norm', dest='max_norm', action='store_true',
28 |                                  help='maxnorm constraint to weights')
29 |         self.parser.add_argument('--linear_size', type=int, default=256, help='size of each model layer')
30 |         self.parser.add_argument('--num_stage', type=int, default=12, help='# layers in linear model')
31 |         self.parser.add_argument('--num_body', type=int, default=25, help='# of body joints')
32 |         self.parser.add_argument('--num_lh', type=int, default=15, help='# of left-hand joints')
33 |         self.parser.add_argument('--num_rh', type=int, default=15, help='# of right-hand joints')
34 | 
35 |         # ===============================================================
36 |         #                     Running options
37 |         # ===============================================================
38 |         self.parser.add_argument('--lr', type=float, default=1.0e-3)
39 |         self.parser.add_argument('--lr_decay', type=int, default=2, help='every lr_decay epoch do lr decay')
40 |         self.parser.add_argument('--lr_gamma', type=float, default=0.96)
41 |         self.parser.add_argument('--input_n', type=int, default=30, help='observed seq length')
42 |         self.parser.add_argument('--output_n', type=int, default=30, help='future seq length')
43 |         self.parser.add_argument('--all_n', type=int, default=60, help='number of DCT coeff. preserved for 3D')
44 |         self.parser.add_argument('--actions', type=str, default='all', help='actions to use')
45 |         self.parser.add_argument('--epochs', type=int, default=50)
46 |         self.parser.add_argument('--dropout', type=float, default=0.5, help='dropout probability, 1.0 to make no dropout')
47 |         self.parser.add_argument('--train_batch', type=int, default=64)
48 |         self.parser.add_argument('--val_batch', type=int, default=128)
49 |         self.parser.add_argument('--test_batch', type=int, default=128)
50 |         self.parser.add_argument('--job', type=int, default=0, help='subprocesses to use for data loading')
51 |         self.parser.add_argument('--seed', type=int, default=1024, help='random seed')
52 |         self.parser.add_argument("--local_rank", type=int, help="local rank")
53 |         self.parser.add_argument('--W_pg', type=float, default=0.6, help='The weight of information propagation between part')
54 |         self.parser.add_argument('--W_p', type=float, default=0.6, help='The weight of part on the whole body')
55 | 
56 |         self.parser.add_argument('--is_load', dest='is_load', action='store_true', help='whether to load existing model')
57 |         self.parser.add_argument('--is_debug', dest='is_debug', action='store_true', help='whether to debug')
58 |         self.parser.add_argument('--is_exp', dest='is_exp', action='store_true', help='whether to save different model')
59 |         self.parser.add_argument('--sample_rate', type=int, default=2, help='frame sampling rate')
60 |         self.parser.add_argument('--is_norm_dct', dest='is_norm_dct', action='store_true',
61 |                                  help='whether to normalize the dct coeff')
62 |         self.parser.add_argument('--is_norm', dest='is_norm', action='store_true',
63 |                                  help='whether to normalize the angles/3d coordinates')
64 |         self.parser.add_argument('--is_using_saved_file', dest='is_using_saved_file', action='store_true',
65 |                                  help='whether to load the preprocessed .bin/.npy sequences')
66 |         self.parser.add_argument('--is_hand_norm', dest='is_hand_norm', action='store_true', help='express hand joints relative to the wrists before the DCT')
67 |         self.parser.add_argument('--is_hand_norm_split', dest='is_hand_norm_split', action='store_true', help='')
68 |         self.parser.add_argument('--is_part', dest='is_part', action='store_true', help='')
69 |         self.parser.add_argument('--part_type', type=str, default='lhand', help='')
70 |         self.parser.add_argument('--is_boneloss', dest='is_boneloss', action='store_true', help='add the bone-length loss term')
71 |         self.parser.add_argument('--is_weighted_jointloss', dest='is_weighted_jointloss', action='store_true', help='')
72 |         self.parser.add_argument('--is_using_noTpose2', dest='is_using_noTpose2', action='store_true', help='use the sequences with T-poses trimmed')
73 |         self.parser.add_argument('--is_using_raw', dest='is_using_raw', action='store_true', help='')
74 | 
75 |         self.parser.add_argument('--J', type=int, default=1, help='The number of wavelet filters')
76 |         self.parser.set_defaults(max_norm=True)
77 | 
78 | 
79 |     def _print(self):
80 |         print("\n==================Options=================")
81 |         pprint(vars(self.opt), indent=4)
82 |         print("==========================================\n")
83 | 
84 |     def parse(self):
85 |         self._initial()
86 |         self.opt = self.parser.parse_args()
87 | 
88 |         return self.opt
89 | 
--------------------------------------------------------------------------------
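`Options` is a thin argparse wrapper; both entry points call `Options().parse()` and read the fields directly. A minimal sketch (the `demo.py` name is hypothetical):

```python
from util.opt import Options

# e.g. python demo.py --input_n 30 --output_n 30 --all_n 60 --model_type EAI --is_exp
opt = Options().parse()
print(opt.input_n, opt.output_n, opt.all_n)   # 30 30 60
print(opt.model_type, opt.is_exp)             # EAI True
```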
/test.py:
--------------------------------------------------------------------------------
1 | from multiprocessing.util import is_exiting  # unused (leftover auto-import)
2 | import numpy
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim
6 | from torch.utils.data import DataLoader
7 | from torch.autograd import Variable
8 | import numpy as np
9
| import pandas as pd 10 | import os 11 | from util.grab import Grab 12 | 13 | from model_others.EAI import GCN_EAI 14 | 15 | from util.opt import Options 16 | import util.data_utils as data_utils 17 | from util import utils_utils as utils 18 | from util import loss_func 19 | from tqdm import tqdm 20 | import pdb 21 | import shutil 22 | import random 23 | 24 | from torch.utils.data import DataLoader 25 | from prefetch_generator import BackgroundGenerator 26 | 27 | import torch.distributed as dist 28 | from torch.nn.parallel import DistributedDataParallel as DDP 29 | from torch.utils.data.distributed import DistributedSampler 30 | 31 | def setup_DDP(backend="nccl", verbose=False): 32 | rank = int(os.environ["RANK"]) 33 | local_rank = int(os.environ["LOCAL_RANK"]) 34 | world_size = int(os.environ["WORLD_SIZE"]) 35 | dist.init_process_group(backend=backend) 36 | device = torch.device("cuda:{}".format(local_rank)) 37 | if verbose: 38 | print(f"local rank: {local_rank}, global rank: {rank}, world size: {world_size}") 39 | return rank, local_rank, world_size, device 40 | 41 | def print_only_rank0(log): 42 | if dist.get_rank() == 0: 43 | print(log) 44 | 45 | class DataLoaderX(DataLoader): 46 | def __iter__(self): 47 | return BackgroundGenerator(super().__iter__()) 48 | 49 | def main(opt): 50 | 51 | is_cuda = torch.cuda.is_available() 52 | input_n = opt.input_n 53 | output_n = opt.output_n 54 | all_n = input_n + output_n 55 | 56 | print(">>> creating model") 57 | model = GCN_EAI(input_feature=all_n, hidden_feature=opt.linear_size, p_dropout=opt.dropout, num_stage=opt.num_stage, lh_node_n=opt.num_lh*3, rh_node_n=opt.num_rh*3,b_node_n=opt.num_body*3) 58 | model_name = '{}'.format(opt.model_type) 59 | 60 | if is_cuda: 61 | model.cuda() 62 | 63 | dct_trans_funcs = { 64 | 'Norm': get_dct_norm, 65 | 'No_Norm': get_dct, 66 | } 67 | idct_trans_funcs = { 68 | 'Norm': get_idct_norm, 69 | 'No_Norm': get_idct, 70 | } 71 | 72 | if opt.is_hand_norm: 73 | dct_trans = dct_trans_funcs['Norm'] 74 | idct_trans = idct_trans_funcs['Norm'] 75 | else: 76 | dct_trans = dct_trans_funcs['No_Norm'] 77 | idct_trans = idct_trans_funcs['No_Norm'] 78 | 79 | train_expid = opt.exp 80 | test_expid = 'TEST_' + opt.exp[6:] 81 | 82 | test_script_name = os.path.basename(__file__).split('.')[0] 83 | train_script_name = 'train_' + test_script_name[5:] 84 | train_script_name = "ckpt_eai_dct_n{:d}_out{:d}_dctn{:d}".format(input_n, output_n, all_n) 85 | test_script_name = "ckpt_eai_dct_n{:d}_out{:d}_dctn{:d}".format(input_n, output_n, all_n) 86 | 87 | train_ckpt_path = './checkpoint/{}/{}_best.pth.tar'.format(train_expid,train_script_name) 88 | test_csv_path = './checkpoint/{}'.format(test_expid) 89 | 90 | print(">>> loading ckpt len from '{}'".format(train_ckpt_path)) 91 | ckpt = torch.load(train_ckpt_path) if is_cuda else torch.load(train_ckpt_path, map_location='cpu') 92 | 93 | lr = ckpt['lr'] 94 | start_epoch = ckpt['epoch'] 95 | train_loss = ckpt['train_loss'] 96 | 97 | new_ckpt_state_dict = {} 98 | for i in ckpt['state_dict'].keys(): 99 | new_ckpt_state_dict[i[7:]] = ckpt['state_dict'][i] 100 | 101 | model.load_state_dict(new_ckpt_state_dict) 102 | print(">>> ckpt len loaded (epoch: {} | err: {})".format(start_epoch, train_loss)) 103 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 104 | 105 | test_dataset = Grab(path_to_data=opt.grab_data_dict, input_n=input_n, output_n=output_n, split=2, debug= opt.is_debug, using_saved_file=opt.is_using_saved_file,using_noTpose2=opt.is_using_noTpose2) 106 | test_loader = 
DataLoader(test_dataset, batch_size=opt.test_batch, shuffle=False, num_workers=0,pin_memory=True) 107 | test_loss, test_body_loss, test_lhand_loss, test_rhand_loss,test_lhand_rel_loss,test_rhand_rel_loss,_ = test_split(test_loader, model=model, device=device,dct_trans=dct_trans, idct_trans=idct_trans) 108 | 109 | eval_frame = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30] 110 | print('>>> Frames |>>> whole body |>>> main body |>>> left hand |>>> right hand|>>> rel left hand |>>> rel right hand') 111 | for i, f in enumerate(eval_frame): 112 | print('>>> {} |>>> {:.3f} |>>> {:.3f} |>>> {:.3f} |>>> {:.3f} |>>> {:.3f} |>>> {:.3f} '\ 113 | .format(f, test_loss[i],test_body_loss[i],test_lhand_loss[i],test_rhand_loss[i],test_lhand_rel_loss[i],test_rhand_rel_loss[i])) 114 | 115 | 116 | def get_dct(out_joints): 117 | batch, frame, node, dim = out_joints.data.shape 118 | dct_m_in, _ = get_dct_matrix(frame) 119 | input_joints = out_joints.transpose(0, 1).reshape(frame, -1).contiguous() 120 | input_dct_seq = np.matmul((dct_m_in[0:frame, :]), input_joints) 121 | input_dct_seq = torch.as_tensor(input_dct_seq) 122 | input_joints = input_dct_seq.reshape(frame, batch, -1).permute(1, 2, 0).contiguous() 123 | return input_joints 124 | 125 | def get_idct(y_out, out_joints, device): 126 | batch, frame, node, dim = out_joints.data.shape 127 | _, idct_m = get_dct_matrix(frame) 128 | idct_m = torch.from_numpy(idct_m).float().to(device) 129 | outputs_t = y_out.view(-1, frame).transpose(1, 0) 130 | outputs_p3d = torch.matmul(idct_m[:, 0:frame], outputs_t) 131 | outputs_p3d = outputs_p3d.reshape(frame, batch, -1, dim).contiguous().transpose(0, 1) 132 | pred_3d = outputs_p3d 133 | targ_3d = out_joints 134 | return pred_3d, targ_3d 135 | 136 | def get_dct_norm(out_joints): 137 | 138 | batch, frame, node, dim = out_joints.data.shape 139 | 140 | out_joints[:,:,25:40,:] = out_joints[:,:,25:40,:] - out_joints[:,:,20:21,:] 141 | out_joints[:,:,40:,:] = out_joints[:,:,40:,:] - out_joints[:,:,21:22,:] 142 | 143 | dct_m_in, _ = get_dct_matrix(frame) 144 | input_joints = out_joints.transpose(0, 1).reshape(frame, -1).contiguous() 145 | input_dct_seq = np.matmul((dct_m_in[0:frame, :]), input_joints) 146 | input_dct_seq = torch.as_tensor(input_dct_seq) 147 | input_joints = input_dct_seq.reshape(frame, batch, -1).permute(1, 2, 0).contiguous() 148 | return input_joints 149 | 150 | def get_idct_norm(y_out, out_joints, device): 151 | batch, frame, node, dim = out_joints.data.shape 152 | _, idct_m = get_dct_matrix(frame) 153 | idct_m = torch.from_numpy(idct_m).float().to(device) 154 | outputs_t = y_out.view(-1, frame).transpose(1, 0) 155 | # 50,32*55*3 156 | outputs_p3d = torch.matmul(idct_m[:, 0:frame], outputs_t) 157 | outputs_p3d = outputs_p3d.reshape(frame, batch, -1, dim).contiguous().transpose(0, 1) 158 | # 32,162,50 159 | 160 | outputs_p3d[:,:,25:40,:] = outputs_p3d[:,:,25:40,:] + outputs_p3d[:,:,20:21,:] 161 | outputs_p3d[:,:,40:,:] = outputs_p3d[:,:,40:,:] + outputs_p3d[:,:,21:22,:] 162 | 163 | pred_3d = outputs_p3d 164 | targ_3d = out_joints 165 | return pred_3d, targ_3d 166 | 167 | 168 | def test_split(test_loader, model, device, dct_trans,idct_trans): 169 | N = 0 170 | 171 | eval_frame = [32, 35, 38, 41, 44, 47, 50, 53, 56, 59] 172 | model.eval() 173 | t_posi = np.zeros(len(eval_frame)) 174 | 175 | t_body_posi = np.zeros(len(eval_frame)) 176 | t_lhand_posi = np.zeros(len(eval_frame)) 177 | t_rhand_posi = np.zeros(len(eval_frame)) 178 | 179 | t_lhand_rel_posi = np.zeros(len(eval_frame)) 180 | t_rhand_rel_posi = 
np.zeros(len(eval_frame)) 181 | 182 | with torch.no_grad(): 183 | for i, (input_pose, target_pose) in enumerate(test_loader): 184 | model_input = dct_trans(input_pose) 185 | n = input_pose.shape[0] 186 | if torch.cuda.is_available(): 187 | model_input = model_input.to(device).float() 188 | target_pose = target_pose.to(device).float() 189 | out_pose,_,_,_ = model(model_input) 190 | pred_3d, targ_3d = idct_trans(y_out=out_pose, out_joints=target_pose, device=device) 191 | 192 | rel_pred_3d = pred_3d.clone() 193 | rel_targ_3d = targ_3d.clone() 194 | 195 | rel_pred_3d[:,:,25:40] = rel_pred_3d[:,:,25:40] - rel_pred_3d[:,:,20:21] 196 | rel_pred_3d[:,:,40:] = rel_pred_3d[:,:,40:] - rel_pred_3d[:,:,21:22] 197 | 198 | rel_targ_3d[:,:,25:40] = rel_targ_3d[:,:,25:40] - rel_targ_3d[:,:,20:21] 199 | rel_targ_3d[:,:,40:] = rel_targ_3d[:,:,40:] - rel_targ_3d[:,:,21:22] 200 | 201 | for k in np.arange(0, len(eval_frame)): 202 | j = eval_frame[k] 203 | 204 | test_out, test_joints = pred_3d[:, j, :, :], targ_3d[:, j, :, :] 205 | loss_wholebody, _ = loss_func.joint_body_loss_test(test_out, test_joints) 206 | t_posi[k] += loss_wholebody.cpu().data.numpy() * n * 100 207 | 208 | test_body_out, test_body_joints = pred_3d[:, j, :25, :], targ_3d[:, j, :25, :] 209 | t_body_posi[k] += loss_func.joint_loss(test_body_out, test_body_joints).cpu().data.numpy() * n * 100 210 | 211 | test_lhand_out, test_lhand_joints = pred_3d[:, j, 25:40, :], targ_3d[:, j, 25:40, :] 212 | t_lhand_posi[k] += loss_func.joint_loss(test_lhand_out, test_lhand_joints).cpu().data.numpy() * n * 100 213 | 214 | test_rhand_out, test_rhand_joints = pred_3d[:, j, 40:, :], targ_3d[:, j, 40:, :] 215 | t_rhand_posi[k] += loss_func.joint_loss(test_rhand_out, test_rhand_joints).cpu().data.numpy() * n * 100 216 | 217 | test_lhand_rel_out, test_lhand_rel_joints = rel_pred_3d[:, j, 25:40, :], rel_targ_3d[:, j, 25:40, :] 218 | t_lhand_rel_posi[k] += loss_func.joint_loss(test_lhand_rel_out, test_lhand_rel_joints).cpu().data.numpy() * n * 100 219 | 220 | test_rhand_rel_out, test_rhand_rel_joints = rel_pred_3d[:, j, 40:, :], rel_targ_3d[:, j, 40:, :] 221 | t_rhand_rel_posi[k] += loss_func.joint_loss(test_rhand_rel_out, test_rhand_rel_joints).cpu().data.numpy() * n * 100 222 | 223 | N += n 224 | return t_posi / N,t_body_posi / N,t_lhand_posi / N,t_rhand_posi / N,t_lhand_rel_posi / N,t_rhand_rel_posi / N, N 225 | 226 | 227 | def get_dct_matrix(N): 228 | dct_m = np.eye(N) 229 | for k in np.arange(N): 230 | for i in np.arange(N): 231 | w = np.sqrt(2 / N) 232 | if k == 0: 233 | w = np.sqrt(1 / N) 234 | dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N) 235 | idct_m = np.linalg.inv(dct_m) 236 | return dct_m, idct_m 237 | 238 | def setup_seed(seed): 239 | torch.manual_seed(seed) 240 | torch.cuda.manual_seed_all(seed) 241 | np.random.seed(seed) 242 | random.seed(seed) 243 | torch.backends.cudnn.deterministic = True 244 | 245 | if __name__ == "__main__": 246 | option = Options().parse() 247 | main(option) 248 | 249 | -------------------------------------------------------------------------------- /util/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | from copy import copy 5 | import torch.nn as nn 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | to_cpu = lambda tensor: tensor.detach().cpu().numpy() 9 | 10 | 11 | def parse_npz(npz, allow_pickle=True): 12 | npz = np.load(npz, allow_pickle=allow_pickle) 13 | npz = {k: 
npz[k].item() for k in npz.files}
14 |     return DotDict(npz)
15 | 
16 | 
17 | def params2torch(params, dtype=torch.float32):
18 |     return {k: torch.from_numpy(v).type(dtype) for k, v in params.items()}
19 | 
20 | 
21 | def prepare_params(params, frame_mask, dtype=np.float32):
22 |     return {k: v[frame_mask].astype(dtype) for k, v in params.items()}
23 | 
24 | 
25 | def DotDict(in_dict):
26 |     out_dict = copy(in_dict)
27 |     for k, v in out_dict.items():
28 |         if isinstance(v, dict):
29 |             out_dict[k] = DotDict(v)
30 |     return dotdict(out_dict)
31 | 
32 | 
33 | class dotdict(dict):
34 |     __getattr__ = dict.get
35 |     __setattr__ = dict.__setitem__
36 |     __delattr__ = dict.__delitem__
37 | 
38 | 
39 | def append2dict(source, data):
40 |     for k in data.keys():
41 |         if isinstance(data[k], list):
42 |             source[k] += [v.astype(np.float32) for v in data[k]]  # the elements, not the list, carry astype
43 |         else:
44 |             source[k].append(data[k].astype(np.float32))
45 | 
46 | 
47 | def np2torch(item):
48 |     out = {}
49 |     for k, v in item.items():
50 |         if isinstance(v, list) and not v:  # skip empty lists (arrays would break an == [] check)
51 |             continue
52 |         if isinstance(v, list):
53 |             try:
54 |                 out[k] = torch.from_numpy(np.concatenate(v))
55 |             except:
56 |                 out[k] = torch.from_numpy(np.array(v))
57 |         elif isinstance(v, dict):
58 |             out[k] = np2torch(v)
59 |         else:
60 |             out[k] = torch.from_numpy(v)
61 |     return out
62 | 
63 | 
64 | def to_tensor(array, dtype=torch.float32):
65 |     if not torch.is_tensor(array):
66 |         array = torch.tensor(array)
67 |     return array.to(dtype)
68 | 
69 | 
70 | def to_np(array, dtype=np.float32):
71 |     if 'scipy.sparse' in str(type(array)):
72 |         array = np.array(array.todense(), dtype=dtype)
73 |     elif torch.is_tensor(array):
74 |         array = array.detach().cpu().numpy()
75 |     return array
76 | 
77 | 
78 | def makepath(desired_path, isfile=False):
79 |     import os
80 |     if isfile:
81 |         if not os.path.exists(os.path.dirname(desired_path)): os.makedirs(os.path.dirname(desired_path))
82 |     else:
83 |         if not os.path.exists(desired_path): os.makedirs(desired_path)
84 |     return desired_path
85 | 
86 | 
87 | def makelogger(log_dir, mode='w'):
88 |     makepath(log_dir, isfile=True)
89 |     logger = logging.getLogger()
90 |     logger.setLevel(logging.INFO)
91 | 
92 |     ch = logging.StreamHandler()
93 |     ch.setLevel(logging.INFO)
94 | 
95 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
96 | 
97 |     ch.setFormatter(formatter)
98 | 
99 |     logger.addHandler(ch)
100 | 
101 |     fh = logging.FileHandler('%s' % log_dir, mode=mode)
102 |     fh.setFormatter(formatter)
103 |     logger.addHandler(fh)
104 | 
105 |     return logger
106 | 
107 | 
108 | def euler(rots, order='xyz', units='deg'):
109 |     rots = np.asarray(rots)
110 |     single_val = False if len(rots.shape) > 1 else True
111 |     rots = rots.reshape(-1, 3)
112 |     rotmats = []
113 | 
114 |     for xyz in rots:
115 |         if units == 'deg':
116 |             xyz = np.radians(xyz)
117 |         r = np.eye(3)
118 |         for theta, axis in zip(xyz, order):
119 |             c = np.cos(theta)
120 |             s = np.sin(theta)
121 |             if axis == 'x':
122 |                 r = np.dot(np.array([[1, 0, 0], [0, c, -s], [0, s, c]]), r)
123 |             if axis == 'y':
124 |                 r = np.dot(np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]), r)
125 |             if axis == 'z':
126 |                 r = np.dot(np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]), r)
127 |         rotmats.append(r)
128 |     rotmats = np.stack(rotmats).astype(np.float32)
129 |     if single_val:
130 |         return rotmats[0]
131 |     else:
132 |         return rotmats
133 | 
134 | 
135 | def create_video(path, fps=30, name='movie'):
136 |     import os
137 |     import subprocess
138 | 
139 |     src = os.path.join(path, '%*.png')
140 |     movie_path = os.path.join(path, '%s.mp4' % name)
141 |     i = 0
142 |     while
os.path.isfile(movie_path): 143 | movie_path = os.path.join(path, '%s_%02d.mp4' % (name, i)) 144 | i += 1 145 | 146 | cmd = 'ffmpeg -f image2 -r %d -i %s -b:v 6400k -pix_fmt yuv420p %s' % (fps, src, movie_path) 147 | 148 | subprocess.call(cmd.split(' ')) 149 | while not os.path.exists(movie_path): 150 | continue 151 | 152 | import torch 153 | import os 154 | 155 | 156 | class AccumLoss(object): 157 | def __init__(self): 158 | self.val = 0 159 | self.avg = 0 160 | self.sum = 0 161 | self.count = 0 162 | 163 | def update(self, val, n=1): 164 | self.val = val 165 | self.sum += val 166 | self.count += n 167 | self.avg = self.sum / self.count 168 | 169 | 170 | def lr_decay(optimizer, lr_now, gamma): 171 | lr = lr_now * gamma 172 | for param_group in optimizer.param_groups: 173 | param_group['lr'] = lr 174 | return lr 175 | 176 | 177 | def save_ckpt(state, ckpt_path, file_name=['ckpt_best.pth.tar']): 178 | file_path = os.path.join(ckpt_path, file_name[0]) 179 | torch.save(state, file_path) 180 | 181 | 182 | contact_ids = {'Body': 1, 183 | 'L_Thigh': 2, 184 | 'R_Thigh': 3, 185 | 'Spine': 4, 186 | 'L_Calf': 5, 187 | 'R_Calf': 6, 188 | 'Spine1': 7, 189 | 'L_Foot': 8, 190 | 'R_Foot': 9, 191 | 'Spine2': 10, 192 | 'L_Toes': 11, 193 | 'R_Toes': 12, 194 | 'Neck': 13, 195 | 'L_Shoulder': 14, 196 | 'R_Shoulder': 15, 197 | 'Head': 16, 198 | 'L_UpperArm': 17, 199 | 'R_UpperArm': 18, 200 | 'L_ForeArm': 19, 201 | 'R_ForeArm': 20, 202 | 'L_Hand': 21, 203 | 'R_Hand': 22, 204 | 'Jaw': 23, 205 | 'L_Eye': 24, 206 | 'R_Eye': 25, 207 | 'L_Index1': 26, 208 | 'L_Index2': 27, 209 | 'L_Index3': 28, 210 | 'L_Middle1': 29, 211 | 'L_Middle2': 30, 212 | 'L_Middle3': 31, 213 | 'L_Pinky1': 32, 214 | 'L_Pinky2': 33, 215 | 'L_Pinky3': 34, 216 | 'L_Ring1': 35, 217 | 'L_Ring2': 36, 218 | 'L_Ring3': 37, 219 | 'L_Thumb1': 38, 220 | 'L_Thumb2': 39, 221 | 'L_Thumb3': 40, 222 | 'R_Index1': 41, 223 | 'R_Index2': 42, 224 | 'R_Index3': 43, 225 | 'R_Middle1': 44, 226 | 'R_Middle2': 45, 227 | 'R_Middle3': 46, 228 | 'R_Pinky1': 47, 229 | 'R_Pinky2': 48, 230 | 'R_Pinky3': 49, 231 | 'R_Ring1': 50, 232 | 'R_Ring2': 51, 233 | 'R_Ring3': 52, 234 | 'R_Thumb1': 53, 235 | 'R_Thumb2': 54, 236 | 'R_Thumb3': 55} 237 | 238 | 239 | def normal_init_(layer, mean_, sd_, bias, norm_bias=True): 240 | """Intialization of layers with normal distribution with mean and bias""" 241 | classname = layer.__class__.__name__ 242 | if classname.find('Linear') != -1: 243 | layer.weight.data.normal_(mean_, sd_) 244 | if norm_bias: 245 | layer.bias.data.normal_(bias, 0.05) 246 | else: 247 | layer.bias.data.fill_(bias) 248 | 249 | 250 | def weight_init( 251 | module, 252 | mean_=0, 253 | sd_=0.004, 254 | bias=0.0, 255 | norm_bias=False, 256 | init_fn_=normal_init_): 257 | """Initialization of layers with normal distribution""" 258 | moduleclass = module.__class__.__name__ 259 | try: 260 | for layer in module: 261 | if layer.__class__.__name__ == 'Sequential': 262 | for l in layer: 263 | init_fn_(l, mean_, sd_, bias, norm_bias) 264 | else: 265 | init_fn_(layer, mean_, sd_, bias, norm_bias) 266 | except TypeError: 267 | init_fn_(module, mean_, sd_, bias, norm_bias) 268 | 269 | 270 | def xavier_init_(layer, mean_, sd_, bias, norm_bias=True): 271 | classname = layer.__class__.__name__ 272 | if classname.find('Linear') != -1: 273 | print('[INFO] (xavier_init) Initializing layer {}'.format(classname)) 274 | nn.init.xavier_uniform_(layer.weight.data) 275 | if norm_bias: 276 | layer.bias.data.normal_(0, 0.05) 277 | else: 278 | layer.bias.data.zero_() 279 | 280 | 281 | def 
create_dir_tree(base_dir): 282 | dir_tree = ['models', 'tf_logs', 'config', 'std_log'] 283 | for dir_ in dir_tree: 284 | os.makedirs(os.path.join(base_dir, dir_), exist_ok=True) 285 | 286 | 287 | def create_look_ahead_mask(seq_length, is_nonautoregressive=False): 288 | """Generates a binary mask to prevent to use future context in a sequence.""" 289 | if is_nonautoregressive: 290 | return np.zeros((seq_length, seq_length), dtype=np.float32) 291 | x = np.ones((seq_length, seq_length), dtype=np.float32) 292 | mask = np.triu(x, 1).astype(np.float32) 293 | return mask # (seq_len, seq_len) 294 | 295 | 296 | RED = (0, 1, 1) 297 | ORANGE = (20/360, 1, 1) 298 | YELLOW = (60/360, 1, 1) 299 | GREEN = (100/360, 1, 1) 300 | CYAN = (175/360, 1, 1) 301 | BLUE = (210/360, 1, 1) 302 | 303 | RED_DARKER = (0, 1, 0.25) 304 | ORANGE_DARKER = (20/360, 1, 0.25) 305 | YELLOW_DARKER = (60/360, 1, 0.25) 306 | GREEN_DARKER = (100/360, 1, 0.25) 307 | CYAN_DARKER = (175/360, 1, 0.25) 308 | BLUE_DARKER = (210/360, 1, 0.25) 309 | class Grab_Skeleton_55: 310 | num_joints = 55 311 | start_joints = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 18, 17, 19, 15, 15, 15, 312 | 21, 21, 21, 21, 21, 52, 53, 40, 41, 43, 44, 49, 50, 46, 47, 313 | 20, 20, 20, 20, 20, 37, 38, 25, 26, 28, 29, 34, 35, 31, 32] 314 | end_joints = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 19, 21, 22, 23, 24, 315 | 40, 43, 49, 46, 52, 53, 54, 41, 42, 44, 45, 50, 51, 47, 48, 316 | 37, 25, 28, 34, 31, 38, 39, 26, 27, 29, 30, 35, 36, 32, 33, 317 | ] 318 | bones = list(zip(start_joints,end_joints)) 319 | 320 | 321 | def define_actions(action='all'): 322 | if action == 'all': 323 | return ['airplane-fly-1', 'airplane-lift-1', 'airplane-pass-1', 'alarmclock-lift-1', 'alarmclock-pass-1', 324 | 'alarmclock-see-1', 'apple-eat-1', 'apple-pass-1', 'banana-eat-1', 'banana-lift-1', 'banana-pass-1', 325 | 'banana-peel-1', 'banana-peel-2', 'binoculars-lift-1', 'binoculars-pass-1', 'binoculars-see-1', 326 | 'bowl-drink-1', 327 | 'bowl-drink-2', 'bowl-lift-1', 'bowl-pass-1', 'camera-browse-1', 'camera-pass-1', 328 | 'camera-takepicture-1', 329 | 'camera-takepicture-2', 'camera-takepicture-3', 'cubelarge-inspect-1', 'cubelarge-lift-1', 330 | 'cubelarge-pass-1', 331 | 'cubemedium-inspect-1', 'cubemedium-lift-1', 'cubemedium-pass-1', 'cubesmall-inspect-1', 332 | 'cubesmall-lift-1', 333 | 'cubesmall-pass-1', 'cup-drink-1', 'cup-drink-2', 'cup-lift-1', 'cup-pass-1', 'cup-pour-1', 334 | 'cylinderlarge-inspect-1', 'cylinderlarge-lift-1', 'cylinderlarge-pass-1', 'cylindermedium-inspect-1', 335 | 'cylindermedium-pass-1', 'cylindersmall-inspect-1', 'cylindersmall-pass-1', 'doorknob-lift-1', 336 | 'doorknob-use-1', 'doorknob-use-2', 'duck-pass-1', 'elephant-inspect-1', 'elephant-pass-1', 337 | 'eyeglasses-wear-1', 'flashlight-on-1', 'flashlight-on-2', 'flute-pass-1', 'flute-play-1', 338 | 'fryingpan-cook-1', 339 | 'fryingpan-cook-2', 'gamecontroller-lift-1', 'gamecontroller-pass-1', 'gamecontroller-play-1', 340 | 'hammer-lift-1', 341 | 'hammer-pass-1', 'hammer-use-1', 'hammer-use-2', 'hammer-use-3', 'hand-inspect-1', 'hand-lift-1', 342 | 'hand-pass-1', 'hand-shake-1', 'headphones-lift-1', 'headphones-pass-1', 'headphones-use-1', 343 | 'knife-chop-1', 344 | 'knife-pass-1', 'knife-peel-1', 'lightbulb-pass-1', 'lightbulb-screw-1', 'mouse-lift-1', 'mouse-pass-1', 345 | 'mouse-use-1', 'mug-drink-1', 'mug-drink-2', 'mug-lift-1', 'mug-pass-1', 'mug-toast-1', 'phone-call-1', 346 | 'phone-lift-1', 'phone-pass-1', 'piggybank-pass-1', 'piggybank-use-1', 
'pyramidlarge-pass-1',
347 |                 'pyramidmedium-inspect-1', 'pyramidmedium-lift-1', 'pyramidmedium-pass-1', 'pyramidsmall-inspect-1',
348 |                 'scissors-pass-1', 'scissors-use-1', 'spherelarge-inspect-1', 'spherelarge-lift-1',
349 |                 'spherelarge-pass-1',
350 |                 'spheremedium-inspect-1', 'spheremedium-lift-1', 'spheremedium-pass-1', 'spheresmall-inspect-1',
351 |                 'spheresmall-pass-1', 'stamp-lift-1', 'stamp-pass-1', 'stamp-stamp-1', 'stanfordbunny-inspect-1',
352 |                 'stanfordbunny-lift-1', 'stanfordbunny-pass-1', 'stapler-lift-1', 'stapler-pass-1', 'stapler-staple-1',
353 |                 'stapler-staple-2', 'teapot-pass-1', 'teapot-pour-1', 'teapot-pour-2', 'toothpaste-lift-1',
354 |                 'toothpaste-pass-1', 'toothpaste-squeeze-1', 'toothpaste-squeeze-2', 'toruslarge-inspect-1',
355 |                 'toruslarge-lift-1', 'toruslarge-pass-1', 'torusmedium-inspect-1', 'torusmedium-lift-1',
356 |                 'torusmedium-pass-1',
357 |                 'torussmall-inspect-1', 'torussmall-lift-1', 'torussmall-pass-1', 'train-lift-1', 'train-pass-1',
358 |                 'train-play-1', 'watch-pass-1', 'waterbottle-drink-1', 'waterbottle-pass-1', 'waterbottle-pour-1',
359 |                 'wineglass-drink-1', 'wineglass-drink-2', 'wineglass-lift-1', 'wineglass-pass-1', 'wineglass-toast-1']
360 |     else:
361 |         return action
362 | 
363 | if __name__ == "__main__":
364 |     skeleton = Grab_Skeleton_55
365 |     print(skeleton.bones)
366 | 
--------------------------------------------------------------------------------
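A small sketch of how the `Grab_Skeleton_55` bone list above can be consumed (the random pose is placeholder data). The 54 `(start, end)` pairs cover the 24 body bones plus 15 per hand:

```python
import numpy as np
from util.data_utils import Grab_Skeleton_55

joints = np.random.rand(55, 3)                       # one pose; placeholder data
lengths = np.array([np.linalg.norm(joints[s] - joints[e])
                    for s, e in Grab_Skeleton_55.bones])
print(len(Grab_Skeleton_55.bones), lengths.shape)    # 54 (54,)
```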
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | #!/usr/bin/env python
3 | # -*- coding: utf-8 -*-
4 | """
5 | @IDE: PyCharm
6 | @author: dpx
7 | @contact: dingpx2015@gmail.com
8 | @time: September 2022
9 | Copyright (c), xiaohongshu
10 | 
11 | @Desc:
12 | 
13 | """
14 | from multiprocessing.util import is_exiting  # unused (leftover auto-import)
15 | import os
16 | import pdb
17 | import numpy
18 | import torch
19 | import shutil
20 | import random
21 | import torch.optim
22 | import numpy as np
23 | import pandas as pd
24 | import torch.nn as nn
25 | from tqdm import tqdm
26 | from util import loss_func
27 | from util.opt import Options
28 | from util.grab import Grab
29 | from torch.autograd import Variable
30 | from util import utils_utils as utils
31 | from model_others.EAI import GCN_EAI
32 | from torch.utils.data import DataLoader
33 | import torch.distributed as dist
34 | from torch.nn.parallel import DistributedDataParallel as DDP
35 | from torch.utils.data.distributed import DistributedSampler
36 | 
37 | 
38 | 
39 | def main(opt, rank, local_rank, world_size, device):
40 | 
41 |     # initialize hyper-parameters
42 |     setup_seed(opt.seed)
43 |     input_n = opt.input_n
44 |     output_n = opt.output_n
45 |     all_n = input_n + output_n
46 |     start_epoch = 0
47 |     err_best = 10000
48 |     lr_now = opt.lr
49 | 
50 |     # load the datasets
51 |     print(">>> loading train_data")
52 |     train_dataset = Grab(path_to_data=opt.grab_data_dict, input_n=input_n, output_n=output_n, split=0, debug=opt.is_debug, using_saved_file=opt.is_using_saved_file, using_noTpose2=opt.is_using_noTpose2)
53 |     print(">>> loading val_data")
54 |     val_dataset = Grab(path_to_data=opt.grab_data_dict, input_n=input_n, output_n=output_n, split=1, debug=opt.is_debug, using_saved_file=opt.is_using_saved_file, using_noTpose2=opt.is_using_noTpose2)
55 |     print(">>> making dataloader")
56 | 
57 |     # data handling for multi-GPU distributed training
58 |     batch_size = opt.train_batch // world_size  # [*] // world_size
59 |     train_sampler = DistributedSampler(train_dataset, shuffle=True)  # [*]
60 |     val_sampler = DistributedSampler(val_dataset, shuffle=False)  # [*]
61 |     train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)  # [*] sampler=...
62 |     val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)  # [*] sampler=...
63 |     print(">>> train data {}".format(train_dataset.__len__()))
64 |     print(">>> validation data {}".format(val_dataset.__len__()))
65 | 
66 |     # build the model
67 |     print(">>> creating model")
68 |     model = GCN_EAI(input_feature=all_n, hidden_feature=opt.linear_size, p_dropout=opt.dropout, num_stage=opt.num_stage, lh_node_n=opt.num_lh*3, rh_node_n=opt.num_rh*3, b_node_n=opt.num_body*3)
69 |     model_name = '{}'.format(opt.model_type)
70 |     if opt.is_exp:
71 |         ckpt = opt.ckpt + opt.exp
72 |     else:
73 |         ckpt = opt.ckpt + model_name
74 | 
75 |     # move the model onto the GPUs
76 |     is_cuda = torch.cuda.is_available()
77 |     if is_cuda:
78 |         if_find_unused_parameters = False
79 |         model = model.to(device)
80 |         model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=if_find_unused_parameters)  # [*] DDP(...)
81 |     print_only_rank0(">>> total params: {:.2f}M".format(sum(p.numel() for p in model.parameters()) / 1000000.0))
82 | 
83 |     # build the optimizer
84 |     optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr*world_size)
85 | 
86 |     # continue from checkpoint
87 |     script_name = "eai_dct_n{:d}_out{:d}_dctn{:d}".format(input_n, output_n, all_n)
88 |     print_only_rank0(">>> is_load {}".format(opt.is_load))
89 |     if opt.is_load:
90 |         model_path_len = '{}/ckpt_{}_best.pth.tar'.format(ckpt, script_name)
91 |         print_only_rank0(">>> loading ckpt len from '{}'".format(model_path_len))
92 |         if is_cuda:
93 |             ckpt_model = torch.load(model_path_len)
94 |         else:
95 |             ckpt_model = torch.load(model_path_len, map_location='cpu')
96 |         start_epoch = ckpt_model['epoch']
97 |         err_best = ckpt_model['train_loss']
98 |         lr_now = ckpt_model['lr']
99 |         model.load_state_dict(ckpt_model['state_dict'])
100 |         optimizer.load_state_dict(ckpt_model['optimizer'])
101 |         print_only_rank0(">>> ckpt len loaded (epoch: {} | err: {})".format(start_epoch, err_best))
102 |     else:
103 |         print_only_rank0(">>> loading ckpt from scratch")
104 |         # create/overwrite the checkpoint directory
105 |         if dist.get_rank() == 0:
106 |             if os.path.exists(ckpt):
107 |                 shutil.rmtree(ckpt)
108 |             os.makedirs(ckpt, exist_ok=True)
109 | 
110 |     # start training
111 |     print(">>> err_best", err_best)
112 | 
113 |     dct_trans_funcs = {
114 |         'Norm': get_dct_norm,
115 |         'No_Norm': get_dct,
116 |     }
117 |     idct_trans_funcs = {
118 |         'Norm': get_idct_norm,
119 |         'No_Norm': get_idct,
120 |     }
121 | 
122 |     # flag: whether to normalize the hands relative to the wrists
123 |     print('>>> whether hand norm:{}'.format(opt.is_hand_norm))
124 |     if opt.is_hand_norm:
125 |         dct_trans = dct_trans_funcs['Norm']
126 |         idct_trans = idct_trans_funcs['Norm']
127 |     else:
128 |         dct_trans = dct_trans_funcs['No_Norm']
129 |         idct_trans = idct_trans_funcs['No_Norm']
130 | 
131 |     # training loop
132 |     for epoch in range(start_epoch, opt.epochs):
133 | 
134 |         # re-seed the distributed samplers for this epoch
135 |         train_loader.sampler.set_epoch(epoch)
136 |         val_loader.sampler.set_epoch(epoch)
137 | 
138 |         # learning-rate decay schedule
139 |         if (epoch + 1) % opt.lr_decay == 0:  # decay every lr_decay=2 epochs
140 |             lr_now = utils.lr_decay(optimizer, lr_now, opt.lr_gamma)  # lr_gamma=0.96 is the decay factor
141 |         print_only_rank0('=====================================')
142 |         print_only_rank0('>>> epoch: {} | lr: {:.6f}'.format(epoch + 1, lr_now))
143 | 
144 |         # initialize the csv log for this epoch
145 |         ret_log = np.array([epoch + 1])
146 |         head = np.array(['epoch'])
147 | 
148 |         # train
149 |         lr_now, t_l = train(train_loader, model, optimizer, device=device, lr_now=lr_now, max_norm=opt.max_norm, dct_trans=dct_trans, idct_trans=idct_trans, is_boneloss=opt.is_boneloss, is_weighted_jointloss=opt.is_weighted_jointloss)
150 |         # training results
151 |         print_only_rank0("train_loss:{}".format(t_l))
152 |         ret_log = np.append(ret_log, [lr_now, t_l])
153 |         head = np.append(head, ['lr', 't_l'])
154 | 
155 |         # validate
156 |         v_loss = validate(val_loader, model, device=device, dct_trans=dct_trans, idct_trans=idct_trans)
157 |         # short-term results
158 |         print_only_rank0("v_loss:{}".format(v_loss))
159 |         ret_log = np.append(ret_log, [v_loss])
160 |         head = np.append(head, ['v_loss'])
161 | 
162 | ########################################################################################################################
163 |         # below: saving the (short-term) checkpoint
164 |         if not np.isnan(v_loss):  # only compare when v_loss is a valid number
165 |             is_best = v_loss < err_best  # err_best starts at 10000
166 |             err_best = min(v_loss, err_best)
167 |         else:
168 |             is_best = False
169 |         ret_log = np.append(ret_log, is_best)  # row content
170 |         head = np.append(head, ['is_best'])  # header
171 |         df = pd.DataFrame(np.expand_dims(ret_log, axis=0))  # one epoch becomes one row of the csv log
172 |         if not os.path.exists(ckpt):
173 |             os.makedirs(ckpt)
174 |         if epoch == start_epoch:
175 |             df.to_csv(ckpt + '/' + script_name + '.csv', header=head, index=False)
176 |         else:
177 |             with open(ckpt + '/' + script_name + '.csv', 'a') as f:
178 |                 df.to_csv(f, header=False, index=False)
179 |         file_name = ['ckpt_' + script_name + '_epoch_{}.pth.tar'.format(epoch+1), 'ckpt_']
180 | 
181 |         if dist.get_rank() == 0:
182 |             file_name = ['ckpt_' + script_name + '_best.pth.tar', 'ckpt_' + script_name + '_last.pth.tar']
183 |             utils.save_ckpt({'epoch': epoch + 1,
184 |                              'lr': lr_now,
185 |                              'train_loss': t_l,
186 |                              'state_dict': model.state_dict(),
187 |                              'optimizer': optimizer.state_dict()},
188 |                             ckpt_path=ckpt,
189 |                             is_best=is_best,
190 |                             file_name=file_name)
191 | 
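# A minimal sketch (illustrative helper, not called anywhere) of the DDP scaling
# used in main() above: the global --train_batch is split evenly across ranks,
# while the AdamW learning rate is scaled up linearly with the number of ranks;
# world_size=4 is just an example value.
def _demo_ddp_scaling(global_batch=256, base_lr=1.0e-3, world_size=4):
    per_rank_batch = global_batch // world_size  # what each rank's DataLoader receives
    scaled_lr = base_lr * world_size             # linear LR scaling rule
    return per_rank_batch, scaled_lr             # -> (64, 0.004)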
192 | 
193 | def train(train_loader, model, optimizer, device, lr_now, max_norm, dct_trans, idct_trans, is_boneloss, is_weighted_jointloss):
194 |     print_only_rank0("entering train")
195 |     # initialization
196 |     iter_num = 0
197 |     t_l = utils.AccumLoss()
198 |     model.train()
199 | 
200 |     for (input_pose, target_pose) in tqdm(train_loader):
201 |         # load the batch
202 |         model_input = dct_trans(input_pose)
203 |         n = input_pose.shape[0]  # batch size
204 |         if torch.cuda.is_available():
205 |             model_input = model_input.to(device).float()
206 |             target_pose = target_pose.to(device).float()
207 | 
208 |         # forward pass
209 |         out_pose, mmdloss_ab, mmdloss_ac, mmdloss_bc = model(model_input)
210 |         pred_3d, targ_3d = idct_trans(y_out=out_pose, out_joints=target_pose, device=device)
211 | 
212 |         # loss computation
213 |         if is_weighted_jointloss:
214 |             loss_jt = loss_func.weighted_joint_loss(pred_3d, targ_3d, ratio=0.6)  # NOTE: not defined in util/loss_func.py in this snapshot
215 |         else:
216 |             loss_jt = loss_func.joint_loss(pred_3d, targ_3d)
217 |         loss_pjt = loss_func.relative_hand_loss(pred_3d, targ_3d)
218 | 
219 |         if is_boneloss:
220 |             loss_bl = loss_func.bone_loss(pred_3d, targ_3d, device)
221 |             loss = loss_jt + 0.1 * loss_bl + 0.1 * loss_pjt
222 |         else:
223 |             loss = loss_jt
224 |         loss = loss + 0.001 * (mmdloss_ab + mmdloss_ac + mmdloss_bc)
225 | 
226 |         # backward pass
227 |         optimizer.zero_grad()  # reset the parameter gradients to zero
228 |         loss.backward()
229 |         if max_norm:
230 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
231 |         optimizer.step()  # update the parameters from their gradients and the lr
232 | 
233 |         # update the running loss
234 |         t_l.update(loss.cpu().data.numpy() * n, n)
235 | 
236 |     return lr_now, t_l.avg
237 | 
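# A minimal sketch (illustrative helper, not called anywhere) of the objective
# assembled in train() above with the default joint loss and --is_boneloss on:
# whole-body joint L2 error plus 0.1-weighted bone-length and wrist-relative
# hand terms, plus the 0.001-weighted MMD alignment losses returned by GCN_EAI.
def _demo_total_loss(loss_jt, loss_bl, loss_pjt, mmd_ab, mmd_ac, mmd_bc):
    loss = loss_jt + 0.1 * loss_bl + 0.1 * loss_pjt
    return loss + 0.001 * (mmd_ab + mmd_ac + mmd_bc)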
228 |         loss.backward()
229 |         if max_norm:
230 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)  # max_norm acts as an on/off flag; the clip value is fixed to 1
231 |         optimizer.step()  # update the parameters from their gradients
232 | 
233 |         # accumulate the running loss
234 |         t_l.update(loss.cpu().data.numpy() * n, n)
235 | 
236 |     return lr_now, t_l.avg
237 | 
238 | def validate(val_loader, model, device, dct_trans, idct_trans):
239 |     print_only_rank0(">>> entering validation")
240 |     # initialise
241 |     t_l = utils.AccumLoss()
242 |     model.eval()
243 | 
244 |     for i, (input_pose, target_pose) in enumerate(val_loader):
245 | 
246 |         # prepare the input (3D -> DCT)
247 |         model_input = dct_trans(input_pose)
248 |         n = input_pose.shape[0]  # batch size
249 |         if torch.cuda.is_available():
250 |             model_input = model_input.to(device).float()
251 |             target_pose = target_pose.to(device).float()
252 | 
253 |         # forward pass
254 |         out_pose, _, _, _ = model(model_input)
255 | 
256 |         # DCT back to 3D coordinates
257 |         pred_3d, targ_3d = idct_trans(y_out=out_pose, out_joints=target_pose, device=device)
258 | 
259 |         # short-term metric used for checkpoint selection
260 | 
261 | 
262 |         loss = loss_func.joint_loss(pred_3d, targ_3d)
263 |         t_l.update(loss.cpu().data.numpy() * n, n)
264 | 
265 | 
266 | 
267 |     return t_l.avg
268 | 
269 | # 1-D DCT transform
270 | def get_dct_matrix(N):
271 |     dct_m = np.eye(N)  # N x N identity, overwritten entry by entry below
272 |     for k in np.arange(N):
273 |         for i in np.arange(N):
274 |             w = np.sqrt(2 / N)
275 |             if k == 0:
276 |                 w = np.sqrt(1 / N)
277 |             dct_m[k, i] = w * np.cos(np.pi * (i + 1 / 2) * k / N)
278 |     idct_m = np.linalg.inv(dct_m)  # the inverse gives the IDCT matrix
279 |     return dct_m, idct_m
280 | 
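# The matrix built above is the orthonormal DCT-II basis, so the IDCT is both
# its inverse and its transpose. A minimal self-check sketch, assuming numpy
# is imported as np and get_dct_matrix is in scope:
#
#     dct_m, idct_m = get_dct_matrix(60)
#     x = np.random.randn(60)                          # a dummy 60-frame signal
#     assert np.allclose(idct_m @ (dct_m @ x), x)      # lossless round trip
#     assert np.allclose(dct_m @ dct_m.T, np.eye(60))  # orthonormal rows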
281 | # set the random seeds
282 | def setup_seed(seed):
283 |     torch.manual_seed(seed)
284 |     torch.cuda.manual_seed_all(seed)
285 |     np.random.seed(seed)
286 |     random.seed(seed)
287 |     # deterministic cuDNN makes runs reproducible but slows training
288 |     torch.backends.cudnn.deterministic = True
289 | 
290 | # initialisation for multi-GPU distributed training
291 | def setup_DDP(backend="nccl", verbose=False):
292 |     """
293 |     We don't set MASTER_ADDR and MASTER_PORT here, e.g.:
294 |         # os.environ['MASTER_ADDR'] = 'localhost'
295 |         # os.environ['MASTER_PORT'] = '12355'
296 |     because the launcher supplies them automatically at startup.
297 |     E.g. they can be set with:
298 |         python -m torch.distributed.launch --master_addr="192.168.1.201" --master_port=23456 ...
299 | 
300 |     rank and world_size are likewise read from the environment, so they are not passed to dist.init_process_group() explicitly.
301 | 
302 |     :param backend: communication backend ("nccl" for GPUs; "gloo" on Windows/macOS)
303 |     :param verbose: whether to print this process's rank layout
304 |     :return: (rank, local_rank, world_size, device)
305 |     """
306 |     rank = int(os.environ["RANK"])
307 |     local_rank = int(os.environ["LOCAL_RANK"])
308 |     world_size = int(os.environ["WORLD_SIZE"])
309 |     # if the OS is Windows or macOS, use gloo instead of nccl
310 |     dist.init_process_group(backend=backend)
311 |     # set the distributed device
312 |     device = torch.device("cuda:{}".format(local_rank))
313 |     if verbose:
314 |         print(f"local rank: {local_rank}, global rank: {rank}, world size: {world_size}")
315 |     return rank, local_rank, world_size, device
316 | 
317 | # in multi-GPU distributed training, only print from the rank-0 process
318 | def print_only_rank0(log):
319 |     if dist.get_rank() == 0:
320 |         print(log)
321 | 
322 | # in pelvis-relative coordinates: 3D -> DCT
323 | def get_dct(out_joints):
324 |     batch, frame, node, dim = out_joints.data.shape
325 |     dct_m_in, _ = get_dct_matrix(frame)
326 |     input_joints = out_joints.transpose(0, 1).reshape(frame, -1).contiguous()
327 |     input_dct_seq = np.matmul((dct_m_in[0:frame, :]), input_joints)
328 |     input_dct_seq = torch.as_tensor(input_dct_seq)
329 |     input_joints = input_dct_seq.reshape(frame, batch, -1).permute(1, 2, 0).contiguous()
330 |     return input_joints
331 | 
332 | # in pelvis-relative coordinates: DCT -> 3D
333 | def get_idct(y_out, out_joints, device):
334 |     batch, frame, node, dim = out_joints.data.shape
335 |     _, idct_m = get_dct_matrix(frame)
336 |     idct_m = torch.from_numpy(idct_m).float().to(device)
337 |     outputs_t = y_out.view(-1, frame).transpose(1, 0)
338 |     outputs_p3d = torch.matmul(idct_m[:, 0:frame], outputs_t)
339 |     outputs_p3d = outputs_p3d.reshape(frame, batch, -1, dim).contiguous().transpose(0, 1)
340 |     pred_3d = outputs_p3d
341 |     targ_3d = out_joints
342 |     return pred_3d, targ_3d
343 | 
344 | # body joints in pelvis-relative coordinates: 3D -> DCT; hand joints in wrist-relative coordinates: 3D -> DCT
345 | def get_dct_norm(out_joints):
346 |     batch, frame, node, dim = out_joints.data.shape
347 |     out_joints[:, :, 25:40, :] = out_joints[:, :, 25:40, :] - out_joints[:, :, 20:21, :]
348 |     out_joints[:, :, 40:, :] = out_joints[:, :, 40:, :] - out_joints[:, :, 21:22, :]
349 |     dct_m_in, _ = get_dct_matrix(frame)
350 |     input_joints = out_joints.transpose(0, 1).reshape(frame, -1).contiguous()
351 |     input_dct_seq = np.matmul((dct_m_in[0:frame, :]), input_joints)
352 |     input_dct_seq = torch.as_tensor(input_dct_seq)
353 |     input_joints = input_dct_seq.reshape(frame, batch, -1).permute(1, 2, 0).contiguous()
354 |     return input_joints
355 | 
356 | # body joints in pelvis-relative coordinates: DCT -> 3D; hand joints in wrist-relative coordinates: DCT -> 3D
357 | def get_idct_norm(y_out, out_joints, device):
358 |     batch, frame, node, dim = out_joints.data.shape
359 |     _, idct_m = get_dct_matrix(frame)
360 |     idct_m = torch.from_numpy(idct_m).float().to(device)
361 |     outputs_t = y_out.view(-1, frame).transpose(1, 0)
362 |     outputs_p3d = torch.matmul(idct_m[:, 0:frame], outputs_t)
363 |     outputs_p3d = outputs_p3d.reshape(frame, batch, -1, dim).contiguous().transpose(0, 1)
364 |     outputs_p3d[:, :, 25:40, :] = outputs_p3d[:, :, 25:40, :] + outputs_p3d[:, :, 20:21, :]
365 |     outputs_p3d[:, :, 40:, :] = outputs_p3d[:, :, 40:, :] + outputs_p3d[:, :, 21:22, :]
366 |     pred_3d = outputs_p3d
367 |     targ_3d = out_joints
368 |     return pred_3d, targ_3d
369 | 
370 | 
371 | if __name__ == "__main__":
372 |     option = Options().parse()
373 |     # initialise DDP
374 |     rank, local_rank, world_size, device = setup_DDP(verbose=True)
375 |     main(option, rank, local_rank, world_size, device)
376 | 
--------------------------------------------------------------------------------
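Taken together, get_dct and get_idct form a lossless pair: the full DCT spectrum is kept (dct_m_in[0:frame, :] selects every row), so the inverse transform reconstructs the input up to floating-point error. A minimal round-trip sketch, assuming the two functions above are in scope, tensors stay on the CPU, and the illustrative 55-joint, 60-frame shapes:

import torch

poses = torch.randn(2, 60, 55, 3)           # (batch, frame, joint, xyz), pelvis-relative
dct_feats = get_dct(poses).float()          # (batch, joint*xyz, frame); cast to float32 as train() does
pred_3d, targ_3d = get_idct(y_out=dct_feats, out_joints=poses, device='cpu')
assert torch.allclose(pred_3d, poses, atol=1e-4)  # round trip within float32 tolerance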
/model_others/EAI.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | 
4 | import torch.nn as nn
5 | import torch
6 | from torch.nn.parameter import Parameter
7 | import math, copy
8 | from torch.nn import functional as F
9 | 
10 | from model_others.GCN import *
11 | import util.data_utils as utils
12 | 
13 | class TempSoftmaxFusion_2(nn.Module):
14 |     def __init__(self, channels, detach_inputs=False, detach_feature=False):
15 |         super(TempSoftmaxFusion_2, self).__init__()
16 |         self.detach_inputs = detach_inputs
17 |         self.detach_feature = detach_feature
18 |         layers = []
19 |         for l in range(0, len(channels) - 1):
20 |             layers.append(nn.Linear(channels[l], channels[l+1]))
21 |             if l < len(channels) - 2:
22 |                 layers.append(nn.ReLU())
23 |         self.layers = nn.Sequential(*layers)
24 |         self.register_parameter('temperature', nn.Parameter(torch.ones(1)))
25 | 
26 |     def forward(self, x, y, work=True):
27 |         b, n, f = x.shape
28 |         x = x.reshape(-1, f)
29 |         y = y.reshape(-1, f)
30 |         f_in = torch.cat([x, y], dim=1)
31 |         if self.detach_inputs:
32 |             f_in = f_in.detach()
33 |         f_temp = self.layers(f_in)
34 |         f_weight = F.softmax(f_temp * self.temperature, dim=1)
35 |         if self.detach_feature:
36 |             x = x.detach()
37 |             y = y.detach()
38 |         f_out = f_weight[:, [0]] * x + f_weight[:, [1]] * y
39 |         f_out = f_out.view(b, -1, f)
40 |         return f_out
41 | 
42 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):  # multi-bandwidth Gaussian (RBF) kernel matrix
43 |     n_samples = int(source.size()[0]) + int(target.size()[0])
44 |     total = torch.cat([source, target], dim=0)
45 | 
46 |     total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
47 |     total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
48 |     L2_distance = ((total0 - total1) ** 2).sum(2)
49 |     if fix_sigma:
50 |         bandwidth = fix_sigma
51 |     else:
52 |         bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
53 |     bandwidth /= kernel_mul ** (kernel_num // 2)
54 |     bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
55 |     kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
56 |     return sum(kernel_val)
57 | 
58 | def get_mmdloss(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):  # biased MMD^2 estimate between two feature batches
59 |     batch_size = int(source.size()[0])
60 |     kernels = guassian_kernel(source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
61 |     XX = kernels[:batch_size, :batch_size]
62 |     YY = kernels[batch_size:, batch_size:]
63 |     XY = kernels[:batch_size, batch_size:]
64 |     YX = kernels[batch_size:, :batch_size]
65 |     loss = torch.mean(XX + YY - XY - YX)
66 |     return loss
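# A minimal usage sketch for get_mmdloss, assuming two pooled feature batches
# with matching batch sizes (shapes are illustrative); identical inputs give
# exactly zero, and the biased MMD^2 estimate is always non-negative:
#
#     feats_a = torch.randn(16, 60)
#     feats_b = torch.randn(16, 60)
#     assert get_mmdloss(feats_a, feats_a).item() == 0.0  # same distribution
#     print(get_mmdloss(feats_a, feats_b).item())         # >= 0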
67 | 
68 | class MultiheadAttention(nn.Module):
69 | 
70 |     def __init__(self, embed_dim, num_heads, attn_dropout=0.,
71 |                  bias=True, add_bias_kv=False, add_zero_attn=False):
72 |         super().__init__()
73 |         self.embed_dim = embed_dim
74 |         self.num_heads = num_heads
75 |         self.attn_dropout = attn_dropout
76 |         self.head_dim = embed_dim // num_heads
77 |         self.scaling = self.head_dim ** -0.5
78 | 
79 |         self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
80 |         self.register_parameter('in_proj_bias', None)
81 |         if bias:
82 |             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
83 |         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
84 | 
85 |         if add_bias_kv:
86 |             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
87 |             self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
88 |         else:
89 |             self.bias_k = self.bias_v = None
90 | 
91 |         self.add_zero_attn = add_zero_attn
92 | 
93 |         self.reset_parameters()
94 | 
95 |     def reset_parameters(self):
96 |         nn.init.xavier_uniform_(self.in_proj_weight)
97 |         nn.init.xavier_uniform_(self.out_proj.weight)
98 |         if self.in_proj_bias is not None:
99 |             nn.init.constant_(self.in_proj_bias, 0.)
100 |             nn.init.constant_(self.out_proj.bias, 0.)
101 |         if self.bias_k is not None:
102 |             nn.init.xavier_normal_(self.bias_k)
103 |         if self.bias_v is not None:
104 |             nn.init.xavier_normal_(self.bias_v)
105 | 
106 |     def forward(self, query, key, value, pde_qk, attn_mask=None):  # pde_qk is accepted for interface parity but unused
107 |         qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
108 |         kv_same = key.data_ptr() == value.data_ptr()
109 | 
110 |         query = query.transpose(0, 1)
111 |         key = key.transpose(0, 1)
112 |         value = value.transpose(0, 1)
113 | 
114 |         tgt_len, bsz, embed_dim = query.size()
115 |         assert embed_dim == self.embed_dim
116 |         assert list(query.size()) == [tgt_len, bsz, embed_dim]
117 |         assert key.size() == value.size()
118 | 
119 |         saved_state = None  # unused; retained from the reference implementation
120 | 
121 |         if qkv_same:
122 |             q, k, v = self.in_proj_qkv(query)
123 |         elif kv_same:
124 |             q = self.in_proj_q(query)
125 | 
126 |             if key is None:
127 |                 assert value is None
128 |                 k = v = None
129 |             else:
130 |                 k, v = self.in_proj_kv(key)
131 |         else:
132 |             q = self.in_proj_q(query)
133 |             k = self.in_proj_k(key)
134 |             v = self.in_proj_v(value)
135 |         q = q * self.scaling
136 | 
137 |         if self.bias_k is not None:
138 |             assert self.bias_v is not None
139 |             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
140 |             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
141 |             if attn_mask is not None:
142 |                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
143 | 
144 |         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
145 |         if k is not None:
146 |             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
147 |         if v is not None:
148 |             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
149 | 
150 |         src_len = k.size(1)
151 | 
152 |         if self.add_zero_attn:
153 |             src_len += 1
154 |             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
155 |             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
156 |             if attn_mask is not None:
157 |                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
158 | 
159 |         attn_weights = torch.bmm(q, k.transpose(1, 2))
160 |         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
161 | 
162 |         if attn_mask is not None:
163 |             try:
164 |                 attn_weights += attn_mask.unsqueeze(0)
165 |             except:
166 |                 print(attn_weights.shape)
167 |                 print(attn_mask.unsqueeze(0).shape)
168 |                 assert False
169 | 
170 |         attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
171 |         attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)
172 |         attn = torch.bmm(attn_weights, v)
173 |         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
174 |         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
175 |         attn = self.out_proj(attn)
176 |         attn = attn.transpose(0, 1)
177 |         attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
178 |         attn_weights = attn_weights.sum(dim=1) / self.num_heads  # average the weights over heads
179 | 
180 |         return attn, attn_weights
181 | 
182 |     def in_proj_qkv(self, query):
183 |         return self._in_proj(query).chunk(3, dim=-1)
184 | 
185 |     def in_proj_kv(self, key):
186 |         return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
187 | 
188 |     def in_proj_q(self, query, **kwargs):
189 |         return self._in_proj(query, end=self.embed_dim, **kwargs)
190 | 
191 |     def in_proj_k(self, key):
192 |         return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
193 | 
194 |     def in_proj_v(self, value):
195 |         return self._in_proj(value, start=2 * self.embed_dim)
196 | 
197 |     def _in_proj(self, input, start=0, end=None, **kwargs):
198 |         weight = kwargs.get('weight', self.in_proj_weight)
199 |         bias = kwargs.get('bias', self.in_proj_bias)
200 |         weight = weight[start:end, :]
201 |         if bias is not None:
202 |             bias = bias[start:end]
203 |         return F.linear(input, weight, bias)
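# A minimal usage sketch for this batch-first MultiheadAttention, assuming
# embed_dim divisible by num_heads (shapes are illustrative):
#
#     mha = MultiheadAttention(embed_dim=60, num_heads=3)
#     q = torch.randn(2, 75, 60)             # (batch, query length, embed_dim)
#     kv = torch.randn(2, 48, 60)            # (batch, key/value length, embed_dim)
#     attn, attn_weights = mha(q, kv, kv, None)
#     print(attn.shape, attn_weights.shape)  # (2, 75, 60), (2, 75, 48)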
204 | 
205 | class SelfAttention_block(nn.Module):
206 |     def __init__(self, input_feature=60, hidden_feature=256, p_dropout=0.5, num_stage=12, node_n=15*3):
207 |         super(SelfAttention_block, self).__init__()
208 |         self.gcn = GCN(input_feature, hidden_feature, p_dropout, num_stage, node_n)
209 | 
210 |     def forward(self, x):
211 |         y = self.gcn(x)
212 |         return y
213 | 
214 | class CrossAttention_block(nn.Module):
215 |     def __init__(self,
216 |                  input_dim=60,
217 |                  head_num=3,
218 |                  dim_ffn=256,
219 |                  dropout=0.2,
220 |                  init_fn=utils.normal_init_):
221 |         super(CrossAttention_block, self).__init__()
222 |         self._model_dim = input_dim
223 |         self._dim_ffn = dim_ffn
224 |         self._relu = nn.ReLU()
225 |         self._dropout_layer = nn.Dropout(dropout)
226 |         self.inner_att = MultiheadAttention(input_dim, head_num, attn_dropout=dropout)
227 |         self._linear1 = nn.Linear(self._model_dim, self._dim_ffn)
228 |         self._linear2 = nn.Linear(self._dim_ffn, self._model_dim)
229 |         self._norm2 = nn.LayerNorm(self._model_dim, eps=1e-5)
230 | 
231 |         utils.weight_init(self._linear1, init_fn_=init_fn)
232 |         utils.weight_init(self._linear2, init_fn_=init_fn)
233 | 
234 |     def forward(self, x, y, pdm_xy=None):
235 |         query = x
236 |         key = y
237 |         value = y
238 |         attn_output, _ = self.inner_att(
239 |             query,
240 |             key,
241 |             value,
242 |             pdm_xy
243 |         )
244 |         norm_attn_ = self._dropout_layer(attn_output) + query
245 |         norm_attn = self._norm2(norm_attn_)
246 |         output = self._linear1(norm_attn)
247 |         output = self._relu(output)
248 |         output = self._dropout_layer(output)
249 |         output = self._linear2(output)
250 |         output = self._dropout_layer(output) + norm_attn_
251 |         return output
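# CrossAttention_block lets one body part's features (the query x) attend to
# another part's features (the key/value y); the output keeps the query's
# shape. A minimal sketch with illustrative shapes:
#
#     ca = CrossAttention_block(input_dim=60, head_num=3)
#     body = torch.randn(2, 75, 60)   # query stream
#     hand = torch.randn(2, 48, 60)   # key/value stream
#     out = ca(body, hand)            # -> (2, 75, 60)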
252 | 
253 | class DA_Norm(nn.Module):
254 |     def __init__(self, num_features):
255 |         super().__init__()
256 | 
257 |         shape = (1, 1, num_features)
258 |         shape2 = (1, 1, num_features)
259 | 
260 |         self.gamma = nn.Parameter(torch.ones(shape))
261 |         self.beta = nn.Parameter(torch.zeros(shape))
262 |         self.gamma2 = nn.Parameter(torch.ones(shape))
263 |         self.beta2 = nn.Parameter(torch.zeros(shape))
264 | 
265 |         moving_mean = torch.zeros(shape2)
266 |         moving_var = torch.zeros(shape2)
267 |         moving_mean2 = torch.zeros(shape2)
268 |         moving_var2 = torch.zeros(shape2)
269 | 
270 |         self.register_buffer("moving_mean", moving_mean)
271 |         self.register_buffer("moving_var", moving_var)
272 |         self.register_buffer("moving_mean2", moving_mean2)
273 |         self.register_buffer("moving_var2", moving_var2)
274 |         self.weight = nn.Parameter(torch.zeros(1))
275 | 
276 |     def forward(self, X, X2):
277 |         if self.moving_mean.device != X.device:
278 |             self.moving_mean = self.moving_mean.to(X.device)
279 |             self.moving_var = self.moving_var.to(X.device)
280 |             self.moving_mean2 = self.moving_mean2.to(X.device)
281 |             self.moving_var2 = self.moving_var2.to(X.device)
282 |         Y, Y2, self.moving_mean, self.moving_var, self.moving_mean2, self.moving_var2 = batch_norm(X, X2, self.gamma, self.beta, self.moving_mean, self.moving_var, self.gamma2, self.beta2, self.moving_mean2, self.moving_var2, self.weight, eps=1e-5, momentum=0.9)
283 | 
284 |         return Y, Y2
285 | 
286 | def batch_norm(X, X2, gamma, beta, moving_mean, moving_var, gamma2, beta2, moving_mean2, moving_var2, weight, eps, momentum):
287 | 
288 |     if not torch.is_grad_enabled():
289 |         # inference: normalise with the running statistics
290 |         X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
291 |         X_hat2 = (X2 - moving_mean2) / torch.sqrt(moving_var2 + eps)
292 |     else:
293 |         weight = (torch.sigmoid(weight) + 1) / 2  # learned blend factor in (0.5, 1)
294 | 
295 |         mean = X.mean(dim=(0, 1), keepdim=True)
296 |         var = ((X - mean) ** 2).mean(dim=(0, 1), keepdim=True)
297 | 
298 |         mean2 = X2.mean(dim=(0, 1), keepdim=True)
299 |         var2 = ((X2 - mean2) ** 2).mean(dim=(0, 1), keepdim=True)
300 | 
301 |         mean_fa = weight * mean + (1 - weight) * mean2
302 |         mean_fb = weight * mean2 + (1 - weight) * mean
303 | 
304 |         var_fa = weight * var + (1 - weight) * var2
305 |         var_fb = weight * var2 + (1 - weight) * var
306 | 
307 |         X_hat = (X - mean_fa) / torch.sqrt(var_fa + eps)
308 |         X_hat2 = (X2 - mean_fb) / torch.sqrt(var_fb + eps)
309 | 
310 |         moving_mean = momentum * moving_mean + (1.0 - momentum) * mean_fa
311 |         moving_var = momentum * moving_var + (1.0 - momentum) * var_fa
312 | 
313 |         moving_mean2 = momentum * moving_mean2 + (1.0 - momentum) * mean_fb
314 |         moving_var2 = momentum * moving_var2 + (1.0 - momentum) * var_fb
315 | 
316 |     Y = gamma * X_hat + beta
317 |     Y2 = gamma2 * X_hat2 + beta2
318 | 
319 |     return Y, Y2, moving_mean.data, moving_var.data, moving_mean2.data, moving_var2.data
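# A minimal smoke-test sketch for DA_Norm/batch_norm: the two streams may have
# different node counts, since statistics are taken over the (batch, node)
# dims and broadcast over the feature dim (shapes are illustrative):
#
#     danorm = DA_Norm(num_features=60)
#     body = torch.randn(16, 75, 60)   # (batch, nodes, features)
#     hand = torch.randn(16, 48, 60)
#     out_body, out_hand = danorm(body, hand)   # shapes are preserved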
320 | 
321 | class Alignment_block(nn.Module):
322 |     def __init__(self,
323 |                  input_dim=256,
324 |                  head_num=8,
325 |                  dim_ffn=256,
326 |                  dropout=0.2,
327 |                  init_fn=utils.normal_init_,
328 |                  src_len1=None,
329 |                  src_len2=None):
330 |         super(Alignment_block, self).__init__()
331 |         self._model_dim = input_dim
332 |         self._dim_ffn = dim_ffn
333 | 
334 |     def forward(self, x, x2, x3, mmd_flag=False):
335 |         # calculating the pairwise MMD losses between the three streams
336 |         output_sa_x = x
337 |         output_sa_x2 = x2
338 |         output_sa_x3 = x3
339 |         if mmd_flag:
340 |             xa_f = torch.mean(output_sa_x, 1)
341 |             xb_f = torch.mean(output_sa_x2, 1)
342 |             xc_f = torch.mean(output_sa_x3, 1)
343 |             mmdlossab = get_mmdloss(xa_f, xb_f)
344 |             mmdlossbc = get_mmdloss(xb_f, xc_f)
345 |             mmdlossac = get_mmdloss(xc_f, xa_f)
346 |         else:
347 |             mmdlossab = 0
348 |             mmdlossbc = 0
349 |             mmdlossac = 0
350 | 
351 |         return output_sa_x, output_sa_x2, output_sa_x3, mmdlossab, mmdlossbc, mmdlossac
352 | 
353 | class GCN_EAI(nn.Module):
354 |     def __init__(self, input_feature=60, hidden_feature=256, p_dropout=0.5, num_stage=12, lh_node_n=15*3, rh_node_n=15*3, b_node_n=25*3):
355 |         super(GCN_EAI, self).__init__()
356 | 
357 |         # Individual Encoder
358 |         num_stage_encoder = 12
359 |         self.body_encoder = SelfAttention_block(input_feature, hidden_feature, p_dropout, num_stage_encoder, node_n=b_node_n)
360 |         self.lhand_encoder = SelfAttention_block(input_feature, hidden_feature, p_dropout, num_stage_encoder, node_n=lh_node_n+3)
361 |         self.rhand_encoder = SelfAttention_block(input_feature, hidden_feature, p_dropout, num_stage_encoder, node_n=rh_node_n+3)
362 | 
363 |         # Distribution Norm
364 |         self._normab = DA_Norm(input_feature)
365 |         self._normbc = DA_Norm(input_feature)
366 |         self._normca = DA_Norm(input_feature)
367 | 
368 |         # Feature Alignment
369 |         self.align_num_layers = 1
370 |         head_num = 3
371 |         self._align_layers = nn.ModuleList([])
372 |         for i in range(self.align_num_layers):
373 |             self._align_layers.append(Alignment_block(head_num=head_num, input_dim=input_feature, src_len1=b_node_n, src_len2=lh_node_n+3))
374 | 
375 |         # Semantic Interaction
376 |         self.ca_num_layers = 5
377 |         self._inter_body_lhand_layers = nn.ModuleList([])
378 |         self._inter_body_rhand_layers = nn.ModuleList([])
379 |         self._inter_lhand_body_layers = nn.ModuleList([])
380 |         self._inter_lhand_rhand_layers = nn.ModuleList([])
381 |         self._inter_rhand_lhand_layers = nn.ModuleList([])
382 |         self._inter_rhand_body_layers = nn.ModuleList([])
383 |         for i in range(self.ca_num_layers):
384 |             self._inter_body_lhand_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
385 |             self._inter_body_rhand_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
386 |             self._inter_lhand_body_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
387 |             self._inter_lhand_rhand_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
388 |             self._inter_rhand_lhand_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
389 |             self._inter_rhand_body_layers.append(CrossAttention_block(head_num=head_num, input_dim=input_feature))
390 | 
391 |         # Physical Interaction
392 |         self.fusion_lwrist = TempSoftmaxFusion_2(channels=[input_feature*6, input_feature, 2])
393 |         self.fusion_rwrist = TempSoftmaxFusion_2(channels=[input_feature*6, input_feature, 2])
394 | 
395 |         # Decoder
396 |         self.body_decoder = nn.Linear(input_feature*3, input_feature)
397 |         self.lhand_decoder = nn.Linear(input_feature*3, input_feature)
398 |         self.rhand_decoder = nn.Linear(input_feature*3, input_feature)
399 |         self.rwrist_decoder = nn.Linear(input_feature*3, input_feature)
400 |         self.lwrist_decoder = nn.Linear(input_feature*3, input_feature)
401 |         utils.weight_init(self.body_decoder, init_fn_=utils.normal_init_)
402 |         utils.weight_init(self.lhand_decoder, init_fn_=utils.normal_init_)
403 |         utils.weight_init(self.rhand_decoder, init_fn_=utils.normal_init_)
404 |         utils.weight_init(self.rwrist_decoder, init_fn_=utils.normal_init_)
405 |         utils.weight_init(self.lwrist_decoder, init_fn_=utils.normal_init_)
406 | 
407 |     def forward(self, x, action=None, pde_ml=None, pde_lm=None, pde_mr=None, pde_rm=None, pde_lr=None, pde_rl=None):
408 | 
409 |         # data processing & wrist replication
410 |         b, n, f = x.shape
411 |         whole_body_x = x.view(b, -1, 3, f)
412 |         lwrist = whole_body_x[:, 20:21].detach()
413 |         rwrist = whole_body_x[:, 21:22].detach()
414 |         b_x = whole_body_x[:, :25].view(b, -1, f)
415 |         lh_x = torch.cat((lwrist, whole_body_x[:, 25:40]), 1)
416 |         lh_x = lh_x.view(b, -1, f)
417 |         rh_x = torch.cat((rwrist, whole_body_x[:, 40:]), 1)
418 |         rh_x = rh_x.view(b, -1, f)
419 | 
420 |         # Encoding
421 |         hbody = self.body_encoder(b_x)
422 |         lhand = self.lhand_encoder(lh_x)
423 |         rhand = self.rhand_encoder(rh_x)
424 | 
425 |         # Distribution Normalization
426 |         hbody1, lhand1 = self._normab(hbody, lhand)
427 |         lhand1, rhand1 = self._normbc(lhand1, rhand)
428 |         rhand1, hbody1 = self._normca(rhand1, hbody1)
429 | 
430 |         # Feature Alignment
431 |         hbody2, rhand2, lhand2, mmdloss_ab, mmdloss_ac, mmdloss_bc = self._align_layers[0](hbody1, rhand1, lhand1, mmd_flag=True)
432 | 
433 |         # Semantic Interaction
434 |         rhand_2_hbody = hbody2
435 |         lhand_2_hbody = hbody2
436 |         lhand_2_rhand = rhand2
437 |         hbody_2_rhand = rhand2
438 |         rhand_2_lhand = lhand2
439 |         hbody_2_lhand = lhand2
440 | 
441 |         for i in range(self.ca_num_layers):
442 |             rhand_2_hbody = self._inter_body_rhand_layers[i](rhand_2_hbody, rhand2)
443 |             lhand_2_hbody = self._inter_body_lhand_layers[i](lhand_2_hbody, lhand2)
444 | 
445 |             lhand_2_rhand = self._inter_rhand_lhand_layers[i](lhand_2_rhand, lhand2)
446 |             hbody_2_rhand = self._inter_rhand_body_layers[i](hbody_2_rhand, hbody2)
447 | 
448 |             rhand_2_lhand = self._inter_lhand_rhand_layers[i](rhand_2_lhand, rhand2)
449 |             hbody_2_lhand = self._inter_lhand_body_layers[i](hbody_2_lhand, hbody2)
450 | 
451 |         # Feature Concat
452 |         fusion_body = torch.cat((hbody, rhand_2_hbody, lhand_2_hbody), dim=2)
453 |         fusion_rhand = torch.cat((rhand, lhand_2_rhand, hbody_2_rhand), dim=2)
454 |         fusion_lhand = torch.cat((lhand, rhand_2_lhand, hbody_2_lhand), dim=2)
455 | 
456 |         # Physical Interaction
457 |         b, n, f1 = fusion_body.shape
458 |         hbody_lwrist = fusion_body.view(b, -1, 3, f1)[:, 20:21].view(b, -1, f1)
459 |         hbody_rwrist = fusion_body.view(b, -1, 3, f1)[:, 21:22].view(b, -1, f1)
460 |         lhand_lwrist = fusion_lhand.view(b, -1, 3, f1)[:, :1].view(b, -1, f1)
461 |         rhand_rwrist = fusion_rhand.view(b, -1, 3, f1)[:, :1].view(b, -1, f1)
462 |         fusion_lwrist = self.fusion_lwrist(hbody_lwrist, lhand_lwrist)
463 |         fusion_rwrist = self.fusion_rwrist(hbody_rwrist, rhand_rwrist)
464 | 
465 |         hbody_no_wrist = torch.cat((fusion_body.view(b, -1, 3, f1)[:, :20], fusion_body.view(b, -1, 3, f1)[:, 22:]), 1).view(b, -1, f1)
466 |         lhand_no_wrist = fusion_lhand.view(b, -1, 3, f1)[:, 1:].view(b, -1, f1)
467 |         rhand_no_wrist = fusion_rhand.view(b, -1, 3, f1)[:, 1:].view(b, -1, f1)
468 | 
469 |         # Decoding
470 |         hbody_no_wrist = self.body_decoder(hbody_no_wrist)
471 |         lhand_no_wrist = self.lhand_decoder(lhand_no_wrist)
472 |         rhand_no_wrist = self.rhand_decoder(rhand_no_wrist)
473 |         fusion_lwrist = self.lwrist_decoder(fusion_lwrist)
474 |         fusion_rwrist = self.rwrist_decoder(fusion_rwrist)
475 | 
476 |         hbody_no_wrist = hbody_no_wrist.view(b, -1, 3, f)
477 |         lhand_no_wrist = lhand_no_wrist.view(b, -1, 3, f)
478 |         rhand_no_wrist = rhand_no_wrist.view(b, -1, 3, f)
479 |         fusion_lwrist = fusion_lwrist.view(b, -1, 3, f)
480 |         fusion_rwrist = fusion_rwrist.view(b, -1, 3, f)
481 |         output = torch.cat([hbody_no_wrist[:, :20], fusion_lwrist, fusion_rwrist, hbody_no_wrist[:, 20:], lhand_no_wrist, rhand_no_wrist], 1).view(b, -1, f) + x  # residual connection over the input
482 | 
483 |         return output, mmdloss_ab, mmdloss_ac, mmdloss_bc
484 | 
--------------------------------------------------------------------------------
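As a quick end-to-end smoke test, GCN_EAI can be driven with a random DCT-domain tensor. A minimal sketch, run from the repo root, assuming the default layout of 55 joints (25 body joints with the wrists at indices 20 and 21, plus 15 joints per hand), a 60-frame window, and that the GCN backbone in model_others/GCN.py preserves its (batch, node, feature) input shape, as its use above implies:

import torch
from model_others.EAI import GCN_EAI

model = GCN_EAI(input_feature=60, hidden_feature=256, p_dropout=0.5,
                num_stage=12, lh_node_n=15*3, rh_node_n=15*3, b_node_n=25*3)

x = torch.randn(2, 55 * 3, 60)   # (batch, joints*xyz, DCT coefficients)
out, mmd_ab, mmd_ac, mmd_bc = model(x)
print(out.shape)                 # torch.Size([2, 165, 60]), same layout as the input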