├── MRT ├── README.md ├── Modules.py ├── Layers.py ├── SubLayers.py └── Models.py ├── mocap ├── README.md ├── vis.py ├── mix_mocap.py ├── preprocess_mocap.py └── amc_parser.py ├── mupots3d ├── README.md ├── vis.py └── preprocess_mupots.py ├── utils.py ├── discriminator_data.py ├── data.py ├── README.md ├── train_mrt.py └── test.py /MRT/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mocap/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mupots3d/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def disc_l2_loss(disc_value): 6 | 7 | k = disc_value.shape[0] 8 | return torch.sum((disc_value - 1.0) ** 2) * 1.0 / k 9 | 10 | 11 | def adv_disc_l2_loss(real_disc_value, fake_disc_value): 12 | 13 | ka = real_disc_value.shape[0] 14 | kb = fake_disc_value.shape[0] 15 | lb, la = torch.sum(fake_disc_value ** 2) / kb, torch.sum((real_disc_value - 1) ** 2) / ka 16 | return la, lb, la + lb -------------------------------------------------------------------------------- /MRT/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ScaledDotProductAttention(nn.Module): 6 | ''' Scaled Dot-Product Attention ''' 7 | 8 | def __init__(self, temperature, attn_dropout=0.1): 9 | super().__init__() 10 | self.temperature = temperature 11 | self.dropout = nn.Dropout(attn_dropout) 12 | 13 | def forward(self, q, k, v, mask=None): 14 | 15 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 16 | 17 | if mask is not None: 18 | attn = attn.masked_fill(mask == 0, -1e9) 19 | 20 | attn = self.dropout(F.softmax(attn, dim=-1)) 21 | output = torch.matmul(attn, v) 22 | 23 | return output, attn 24 | -------------------------------------------------------------------------------- /discriminator_data.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import numpy as np 4 | 5 | class D_DATA(data.Dataset): 6 | def __init__(self,joints=15): 7 | 8 | self.data=np.load('./mocap/discriminator_3_120_mocap.npy',allow_pickle=True) 9 | 10 | self.len=len(self.data) 11 | 12 | 13 | def __getitem__(self, index): 14 | 15 | input_seq=self.data[index][:,:30,:][:,::2,:] 16 | output_seq=self.data[index][:,30:,:][:,::2,:] 17 | last_input=input_seq[:,-1:,:] 18 | output_seq=np.concatenate([last_input,output_seq],axis=1) 19 | 20 | return input_seq,output_seq 21 | 22 | 23 | 24 | def __len__(self): 25 | return self.len 26 | 27 | -------------------------------------------------------------------------------- /mupots3d/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | 5 | data=np.load('mupots_120_3persons.npy',allow_pickle=True) 6 | eg=1 7 | data_list=data[eg] 8 | 9 | data_list=data_list.reshape(-1,120,15,3) 10 | 11 | 12 | 13 | 14 | body_edges = np.array( 15 | 
[[0,1], [1,2],[2,3],[0,4], 16 | [4,5],[5,6],[0,7],[7,8],[7,9],[9,10],[10,11],[7,12],[12,13],[13,14]] 17 | ) 18 | 19 | 20 | fig = plt.figure(figsize=(10, 4.5)) 21 | 22 | ax = fig.add_subplot(111, projection='3d') 23 | 24 | plt.ion() 25 | 26 | 27 | length_=data_list.shape[1] 28 | 29 | i=0 30 | while i < length_: 31 | ax.lines = [] 32 | for j in range(len(data_list)): 33 | 34 | xs=data_list[j,i,:,0] 35 | ys=data_list[j,i,:,1] 36 | zs=data_list[j,i,:,2] 37 | #print(xs) 38 | ax.plot( zs,xs, ys, 'y.') 39 | 40 | 41 | plot_edge=True 42 | if plot_edge: 43 | for edge in body_edges: 44 | x=[data_list[j,i,edge[0],0],data_list[j,i,edge[1],0]] 45 | y=[data_list[j,i,edge[0],1],data_list[j,i,edge[1],1]] 46 | z=[data_list[j,i,edge[0],2],data_list[j,i,edge[1],2]] 47 | if i>=30: 48 | ax.plot(z,x, y, 'green') 49 | else: 50 | ax.plot(z,x, y, 'blue') 51 | 52 | 53 | ax.set_xlim3d([-2 , 2]) 54 | ax.set_ylim3d([-2 , 2]) 55 | ax.set_zlim3d([-0, 2]) 56 | #ax.set_axis_off() 57 | ax.set_xlabel("x") 58 | ax.set_ylabel("y") 59 | ax.set_zlabel("z") 60 | 61 | plt.pause(0.01) 62 | i += 1 63 | 64 | 65 | plt.ioff() 66 | plt.show() -------------------------------------------------------------------------------- /MRT/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch.nn as nn 3 | import torch 4 | from MRT.SubLayers import MultiHeadAttention, PositionwiseFeedForward 5 | 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | 10 | class EncoderLayer(nn.Module): 11 | ''' Compose with two layers ''' 12 | 13 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 14 | super(EncoderLayer, self).__init__() 15 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 16 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 17 | 18 | def forward(self, enc_input, slf_attn_mask=None): 19 | enc_output, enc_slf_attn = self.slf_attn( 20 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 21 | enc_output = self.pos_ffn(enc_output) 22 | return enc_output, enc_slf_attn 23 | 24 | 25 | class DecoderLayer(nn.Module): 26 | ''' Compose with three layers ''' 27 | 28 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 29 | super(DecoderLayer, self).__init__() 30 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 31 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 32 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 33 | 34 | def forward( 35 | self, dec_input, enc_output, 36 | slf_attn_mask=None, dec_enc_attn_mask=None): 37 | #dec_output, dec_slf_attn = self.slf_attn( 38 | # dec_input, dec_input, dec_input, mask=slf_attn_mask) 39 | dec_output, dec_enc_attn = self.enc_attn( 40 | dec_input, enc_output, enc_output, mask=dec_enc_attn_mask) 41 | dec_output = self.pos_ffn(dec_output) 42 | return dec_output, None, dec_enc_attn#dec_slf_attn, dec_enc_attn 43 | 44 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import numpy as np 4 | 5 | class DATA(data.Dataset): 6 | def __init__(self): 7 | 8 | self.data=np.load('./mocap/train_3_120_mocap.npy',allow_pickle=True) 9 | 10 | self.len=len(self.data) 11 | 12 | # if joints==15: 13 | # use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] 14 | # self.data=self.data.reshape(self.data.shape[0],3,-1,31,3) 15 
| # self.data=self.data[:,:,:,use,:] 16 | # self.data=self.data.reshape(self.data.shape[0],3,-1,45) 17 | 18 | 19 | def __getitem__(self, index): 20 | 21 | input_seq=self.data[index][:,:30,:][:,::2,:]#input, 30 fps to 15 fps 22 | output_seq=self.data[index][:,30:,:][:,::2,:]#output, 30 fps to 15 fps 23 | last_input=input_seq[:,-1:,:] 24 | output_seq=np.concatenate([last_input,output_seq],axis=1) 25 | 26 | return input_seq,output_seq 27 | 28 | 29 | 30 | def __len__(self): 31 | return self.len 32 | 33 | 34 | 35 | class TESTDATA(data.Dataset): 36 | def __init__(self,dataset='mocap'): 37 | 38 | if dataset=='mocap': 39 | self.data=np.load('./mocap/test_3_120_mocap.npy',allow_pickle=True) 40 | 41 | 42 | # use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] 43 | # self.data=self.data 44 | # self.data=self.data.reshape(self.data.shape[0],self.data.shape[1],-1,31,3) 45 | # self.data=self.data[:,:,:,use,:] 46 | # self.data=self.data.reshape(self.data.shape[0],self.data.shape[1],-1,45) 47 | 48 | if dataset=='mupots': 49 | self.data=np.load('./mupots3d/mupots_120_3persons.npy',allow_pickle=True) 50 | 51 | self.len=len(self.data) 52 | 53 | def __getitem__(self, index): 54 | 55 | input_seq=self.data[index][:,:30,:][:,::2,:]#input, 30 fps to 15 fps 56 | output_seq=self.data[index][:,30:,:][:,::2,:]#output, 30 fps to 15 fps 57 | last_input=input_seq[:,-1:,:] 58 | output_seq=np.concatenate([last_input,output_seq],axis=1) 59 | 60 | return input_seq,output_seq 61 | 62 | def __len__(self): 63 | return self.len 64 | -------------------------------------------------------------------------------- /mocap/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | 5 | data=np.load('two_train_4seconds_2.npy',allow_pickle=True) 6 | eg=1 7 | data_list=data[eg] 8 | 9 | data_list=data_list.reshape(-1,120,31,3) 10 | data_list=data_list*0.1*1.8/3 # scale 11 | #no need to scale if using the mix_mocap data 12 | 13 | body_edges = np.array( 14 | [[0,1], [1,2],[2,3],[3,4], 15 | [4,5],[0,6],[6,7],[7,8],[8,9],[9,10],[0,11],[11,12],[12,13],[13,14],[14,15],[15,16],[13,17],[17,18],[18,19],[19,20],[21,22],[20,23],[13,24],[24,25],[25,26],[26,27],[27,28],[28,29],[27,30]] 16 | ) 17 | 18 | ''' 19 | if use the 15 joints in common 20 | use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] 21 | data_list=data_list.reshape(-1,120,15,3) 22 | data_list=data_list[:,:,[0,1,4,7,2,5,8,12,15,16,18,20,17,19,21],:] 23 | body_edges = np.array( 24 | [[0,1], [1,2],[2,3],[0,4], 25 | [4,5],[5,6],[0,7],[7,8],[7,9],[9,10],[10,11],[7,12],[12,13],[13,14]] 26 | ) 27 | ''' 28 | 29 | fig = plt.figure(figsize=(10, 4.5)) 30 | 31 | ax = fig.add_subplot(111, projection='3d') 32 | 33 | plt.ion() 34 | 35 | 36 | length_=data_list.shape[1] 37 | 38 | i=0 39 | while i < length_: 40 | ax.lines = [] 41 | for j in range(len(data_list)): 42 | 43 | xs=data_list[j,i,:,0] 44 | ys=data_list[j,i,:,1] 45 | zs=data_list[j,i,:,2] 46 | #print(xs) 47 | ax.plot( zs,xs, ys, 'y.') 48 | 49 | 50 | plot_edge=True 51 | if plot_edge: 52 | for edge in body_edges: 53 | x=[data_list[j,i,edge[0],0],data_list[j,i,edge[1],0]] 54 | y=[data_list[j,i,edge[0],1],data_list[j,i,edge[1],1]] 55 | z=[data_list[j,i,edge[0],2],data_list[j,i,edge[1],2]] 56 | if i>=30: 57 | ax.plot(z,x, y, 'green') 58 | else: 59 | ax.plot(z,x, y, 'blue') 60 | 61 | 62 | ax.set_xlim3d([-2 , 2]) 63 | ax.set_ylim3d([-2 , 2]) 64 | ax.set_zlim3d([-0, 2]) 65 | #ax.set_axis_off() 66 | ax.set_xlabel("x") 67 | 
ax.set_ylabel("y") 68 | ax.set_zlabel("z") 69 | 70 | plt.pause(0.01) 71 | i += 1 72 | 73 | 74 | plt.ioff() 75 | plt.show() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MRT 2 | 3 | This is an implementation of the NeurIPS'21 paper "Multi-Person 3D Motion Prediction with Multi-Range Transformers". 4 | 5 | Please check our [paper](https://arxiv.org/pdf/2111.12073.pdf) and the [project webpage](https://jiashunwang.github.io/MRT/) for more details. 6 | 7 | We will also provide the code to fit our skeleton representation data to [SMPL](https://smpl.is.tue.mpg.de/) data. 8 | 9 | ## Citation 10 | 11 | If you find our code or paper useful, please consider citing: 12 | ``` 13 | @article{wang2021multi, 14 | title={Multi-Person 3D Motion Prediction with Multi-Range Transformers}, 15 | author={Wang, Jiashun and Xu, Huazhe and Narasimhan, Medhini and Wang, Xiaolong}, 16 | journal={Advances in Neural Information Processing Systems}, 17 | volume={34}, 18 | year={2021} 19 | } 20 | ``` 21 | 22 | ## Dependencies 23 | 24 | Requirements: 25 | - python3.6 26 | - pytorch==1.1.0 27 | - [torch_dct](https://github.com/zh217/torch-dct) 28 | - [AMCParser](https://github.com/CalciferZh/AMCParser) 29 | 30 | ## Datasets 31 | We provide the data preprocessing code of [CMU-Mocap](http://mocap.cs.cmu.edu/) and [MuPoTS-3D](http://vcai.mpi-inf.mpg.de/projects/SingleShotMultiPerson/) (others are coming soon). 32 | For CMU-Mocap, the directory tree looks like 33 | ``` 34 | mocap 35 | ├── amc_parser.py 36 | ├── mix_mocap.py 37 | ├── preprocess_mocap.py 38 | ├── vis.py 39 | └── all_asfamc 40 | └── subjects 41 | ├── 01 42 | ... 43 | ``` 44 | After downloading the original data, please try 45 | ``` 46 | python ./mocap/preprocess_mocap.py 47 | python ./mocap/mix_mocap.py 48 | ``` 49 | For MuPoTS-3D, the directory tree looks like 50 | ``` 51 | mupots3d 52 | ├── preprocess_mupots.py 53 | ├── vis.py 54 | └── data 55 | ├── TS1 56 | ... 57 | ``` 58 | After downloading the original data, please try 59 | ``` 60 | python ./mupots3d/preprocess_mupots.py 61 | ``` 62 | 63 | ## Training 64 | To train our model, please try 65 | ``` 66 | python train_mrt.py 67 | ``` 68 | 69 | ## Evaluation and visualization 70 | We provide the evaluation and visualization code in `test.py`. 71 | 72 | ## Acknowledgement 73 | This work was supported, in part, by grants from DARPA LwLL, NSF CCF-2112665 (TILOS), NSF 1730158 CI-New: Cognitive Hardware and Software Ecosystem Community Infrastructure (CHASE-CI), NSF ACI-1541349 CC\*DNI Pacific Research Platform, and gifts from Qualcomm, TuSimple and Picsart. 74 | Part of our code is based on [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch) and [AMCParser](https://github.com/CalciferZh/AMCParser). Many thanks!
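As a quick sanity check after the CMU-Mocap preprocessing, the sketch below (an illustrative addition, not one of the repository scripts; run it from the repository root) loads the generated training file and one sample from the `DATA` loader in `data.py`. The expected shapes follow the comments in `mix_mocap.py` and `train_mrt.py`.
```
import numpy as np
from data import DATA

# (n_sequences, 3 persons, 120 frames = 30 fps x 4 s, 45 = 15 joints x 3 coordinates)
raw = np.load('./mocap/train_3_120_mocap.npy', allow_pickle=True)
print(raw.shape)

dataset = DATA()
input_seq, output_seq = dataset[0]   # numpy arrays for one 3-person sequence
print(input_seq.shape)   # (3, 15, 45): 1 s of observed motion, downsampled from 30 fps to 15 fps
print(output_seq.shape)  # (3, 46, 45): last observed frame + 3 s of future motion at 15 fps
```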
75 | 76 | -------------------------------------------------------------------------------- /MRT/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from MRT.Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | ''' Multi-Head Attention module ''' 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 21 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 22 | 23 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 24 | 25 | self.dropout = nn.Dropout(dropout) 26 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 27 | 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 33 | 34 | residual = q 35 | 36 | # Pass through the pre-attention projection: b x lq x (n*dv) 37 | # Separate different heads: b x lq x n x dv 38 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 39 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 40 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 41 | 42 | # Transpose for attention dot product: b x n x lq x dv 43 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 44 | 45 | if mask is not None: 46 | mask = mask.unsqueeze(1) # For head axis broadcasting. 
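        # (added note) unsqueeze(1) gives the mask a singleton head axis, e.g. (b, 1, lq, lk) or
        # (b, 1, 1, lk), so one mask broadcasts across the n_head axis of the (b, n_head, lq, lk)
        # attention scores inside ScaledDotProductAttention.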
47 | 48 | q, attn = self.attention(q, k, v, mask=mask) 49 | 50 | # Transpose to move the head dimension back: b x lq x n x dv 51 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 52 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 53 | q = self.dropout(self.fc(q)) 54 | q += residual 55 | 56 | q = self.layer_norm(q) 57 | 58 | return q, attn 59 | 60 | 61 | 62 | 63 | class PositionwiseFeedForward(nn.Module): 64 | ''' A two-feed-forward-layer module ''' 65 | 66 | def __init__(self, d_in, d_hid, dropout=0.1): 67 | super().__init__() 68 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 69 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 70 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 71 | self.dropout = nn.Dropout(dropout) 72 | 73 | def forward(self, x): 74 | 75 | residual = x 76 | 77 | x = self.w_2(F.relu(self.w_1(x))) 78 | x = self.dropout(x) 79 | x += residual 80 | 81 | x = self.layer_norm(x) 82 | 83 | return x 84 | -------------------------------------------------------------------------------- /mupots3d/preprocess_mupots.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import scipy.io as io 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from mpl_toolkits.mplot3d import Axes3D 6 | import trimesh 7 | 8 | ################ 9 | # 3 persons 10 | final_data=[] 11 | for j in range(1,21,1): 12 | 13 | 14 | data=io.loadmat('./data/TS'+str(j)+'/annot.mat')['annotations'] 15 | if data.shape[1]!=3: 16 | continue 17 | 18 | 19 | #print(j) 20 | 21 | print(data.shape) 22 | 23 | v_total=[] 24 | for i in range(len(data)): 25 | v1=data[i][0][0][0][1].transpose(1,0) 26 | rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])[:3,:3] 27 | v1=np.matmul(v1,rot) 28 | v2=data[i][1][0][0][1].transpose(1,0) 29 | #rot = trimesh.transformations.rotation_matrix(np.radians(90), [1, 0, 0])[:3,:3] 30 | v2=np.matmul(v2,rot) 31 | v3=data[i][2][0][0][1].transpose(1,0) 32 | #rot = trimesh.transformations.rotation_matrix(np.radians(90), [1, 0, 0])[:3,:3] 33 | v3=np.matmul(v3,rot) 34 | v=np.concatenate([v1,v2,v3]).reshape(3,17,3) 35 | v_total.append(v) 36 | temp=np.array(v_total) 37 | temp=temp.swapaxes(0,1) 38 | temp[:,:,:,1]=temp[:,:,:,1]-np.min(temp[:,:,:,1]) #foot on ground 39 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) #center 40 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) #center 41 | use=[14,11,12,13,8,9,10,1,0,5,6,7,2,3,4] #used joints and order 42 | temp_data=temp[:,:,use,:] 43 | 44 | for i in range(0,temp_data.shape[1],15): #down sample 45 | 46 | if (i+120)>temp_data.shape[1]: 47 | break 48 | final_data.append(temp_data[:,i:i+120,:,:]) 49 | #print(j) 50 | final_data=np.concatenate(final_data)*0.017*0.1*1.8/3 # scale 51 | final_data=final_data.reshape(-1,3,120,45) # n, 3 persons, 30 fps 4 seconds, 15 joints xyz coordinates 52 | 53 | np.save('mupots_120_3persons.npy',final_data) 54 | 55 | ################ 56 | # 2 persons 57 | 58 | final_data=[] 59 | for j in range(1,21,1): 60 | 61 | 62 | data=io.loadmat('./data/TS'+str(j)+'/annot.mat')['annotations'] 63 | 64 | if data.shape[1]==3: 65 | continue 66 | 67 | #print(j) 68 | 69 | print(data.shape) 70 | 71 | v_total=[] 72 | for i in range(len(data)): 73 | v1=data[i][0][0][0][1].transpose(1,0) 74 | rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])[:3,:3] 75 | v1=np.matmul(v1,rot) 76 | v2=data[i][1][0][0][1].transpose(1,0) 77 | #rot = 
trimesh.transformations.rotation_matrix(np.radians(90), [1, 0, 0])[:3,:3] 78 | v2=np.matmul(v2,rot) 79 | 80 | v=np.concatenate([v1,v2]).reshape(2,17,3) 81 | v_total.append(v) 82 | temp=np.array(v_total) 83 | temp=temp.swapaxes(0,1) 84 | temp[:,:,:,1]=temp[:,:,:,1]-np.min(temp[:,:,:,1]) #foot on ground 85 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) #center 86 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) #center 87 | use=[14,11,12,13,8,9,10,1,0,5,6,7,2,3,4] #used joints and order 88 | temp_data=temp[:,:,use,:] 89 | 90 | for i in range(0,temp_data.shape[1],15): #down sample 91 | 92 | if (i+120)>temp_data.shape[1]: 93 | break 94 | final_data.append(temp_data[:,i:i+120,:,:]) 95 | #print(j) 96 | final_data=np.concatenate(final_data)*0.017*0.1*1.8/3 # scale 97 | final_data=final_data.reshape(-1,2,120,45) # n, 2 persons, 30 fps 4 seconds, 15 joints xyz coordinates 98 | 99 | np.save('mupots_120_2persons.npy',final_data) -------------------------------------------------------------------------------- /train_mrt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | import torch_dct as dct #https://github.com/zh217/torch-dct 5 | import time 6 | 7 | from MRT.Models import Transformer,Discriminator 8 | from utils import disc_l2_loss,adv_disc_l2_loss 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | from torch.nn import init 12 | 13 | 14 | 15 | from data import DATA 16 | dataset = DATA() 17 | batch_size=64 18 | 19 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) 20 | 21 | from discriminator_data import D_DATA 22 | real_=D_DATA() 23 | 24 | real_motion_dataloader = torch.utils.data.DataLoader(real_, batch_size=batch_size, shuffle=True) 25 | real_motion_all=list(enumerate(real_motion_dataloader)) 26 | 27 | device='cuda' 28 | 29 | 30 | 31 | model = Transformer(d_word_vec=128, d_model=128, d_inner=1024, 32 | n_layers=3, n_head=8, d_k=64, d_v=64,device=device).to(device) 33 | 34 | discriminator = Discriminator(d_word_vec=45, d_model=45, d_inner=256, 35 | n_layers=3, n_head=8, d_k=32, d_v=32,device=device).to(device) 36 | 37 | lrate=0.0003 38 | lrate2=0.0005 39 | 40 | params = [ 41 | {"params": model.parameters(), "lr": lrate} 42 | ] 43 | optimizer = optim.Adam(params) 44 | params_d = [ 45 | {"params": discriminator.parameters(), "lr": lrate} 46 | ] 47 | optimizer_d = optim.Adam(params_d) 48 | 49 | 50 | for epoch in range(100): 51 | total_loss=0 52 | 53 | for j,data in enumerate(dataloader,0): 54 | 55 | use=None 56 | input_seq,output_seq=data 57 | input_seq=torch.tensor(input_seq,dtype=torch.float32).to(device) # batch, N_person, 15 (15 fps 1 second), 45 (15joints xyz) 58 | output_seq=torch.tensor(output_seq,dtype=torch.float32).to(device) # batch, N_persons, 46 (last frame of input + future 3 seconds), 45 (15joints xyz) 59 | 60 | # first 1 second predict future 1 second 61 | input_=input_seq.view(-1,15,input_seq.shape[-1]) # batch x n_person ,15: 15 fps, 1 second, 45: 15joints x 3 62 | 63 | output_=output_seq.view(output_seq.shape[0]*output_seq.shape[1],-1,input_seq.shape[-1]) 64 | 65 | input_ = dct.dct(input_) 66 | 67 | rec_=model.forward(input_[:,1:15,:]-input_[:,:14,:],dct.idct(input_[:,-1:,:]),input_seq,use) 68 | 69 | rec=dct.idct(rec_) 70 | 71 | # first 2 seconds predict 1 second 72 | new_input=torch.cat([input_[:,1:15,:]-input_[:,:14,:],dct.dct(rec_)],dim=-2) 73 | 74 | 
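        # (added note) The two stages below reuse the same model progressively during training:
        # the ground-truth future frames (output_seq[:, :, 1:16], then output_seq[:, :, 1:31]) are
        # appended to the observed second, the longer sequence is DCT-transformed, and the model
        # predicts one further second of displacements; rec, new_rec and new_new_rec are then
        # concatenated and supervised against the full 3-second ground-truth displacements.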
new_input_seq=torch.cat([input_seq,output_seq[:,:,1:16]],dim=-2) 75 | new_input_=dct.dct(new_input_seq.reshape(-1,30,45)) 76 | new_rec_=model.forward(new_input_[:,1:,:]-new_input_[:,:29,:],dct.idct(new_input_[:,-1:,:]),new_input_seq,use) 77 | 78 | new_rec=dct.idct(new_rec_) 79 | 80 | # first 3 seconds predict 1 second 81 | new_new_input_seq=torch.cat([input_seq,output_seq[:,:,1:31]],dim=-2) 82 | new_new_input_=dct.dct(new_new_input_seq.reshape(-1,45,45)) 83 | new_new_rec_=model.forward(new_new_input_[:,1:,:]-new_new_input_[:,:44,:],dct.idct(new_new_input_[:,-1:,:]),new_new_input_seq,use) 84 | 85 | new_new_rec=dct.idct(new_new_rec_) 86 | 87 | rec=torch.cat([rec,new_rec,new_new_rec],dim=-2) 88 | 89 | results=output_[:,:1,:] 90 | for i in range(1,31+15): 91 | results=torch.cat([results,output_[:,:1,:]+torch.sum(rec[:,:i,:],dim=1,keepdim=True)],dim=1) 92 | results=results[:,1:,:] 93 | 94 | loss=torch.mean((rec[:,:,:]-(output_[:,1:46,:]-output_[:,:45,:]))**2) 95 | 96 | 97 | if (j+1)%2==0: 98 | 99 | fake_motion=results 100 | 101 | disc_loss=disc_l2_loss(discriminator(fake_motion)) 102 | loss=loss+0.0005*disc_loss 103 | 104 | fake_motion=fake_motion.detach() 105 | 106 | real_motion=real_motion_all[int(j/2)][1][1] 107 | real_motion=real_motion.view(-1,46,45)[:,1:46,:].float().to(device) 108 | 109 | fake_disc_value = discriminator(fake_motion) 110 | real_disc_value = discriminator(real_motion) 111 | 112 | d_motion_disc_real, d_motion_disc_fake, d_motion_disc_loss = adv_disc_l2_loss(real_disc_value, fake_disc_value) 113 | 114 | optimizer_d.zero_grad() 115 | d_motion_disc_loss.backward() 116 | optimizer_d.step() 117 | 118 | 119 | optimizer.zero_grad() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | total_loss=total_loss+loss 124 | 125 | print('epoch:',epoch,'loss:',total_loss/(j+1)) 126 | if (epoch+1)%5==0: 127 | save_path=f'./saved_model/{epoch}.model' 128 | torch.save(model.state_dict(),save_path) 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /mocap/mix_mocap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | two_train=np.load('two_train_4seconds_2.npy',allow_pickle=True) 5 | one_train=np.load('one_train_4seconds_30.npy',allow_pickle=True) 6 | 7 | print(two_train.shape) 8 | print(one_train.shape) 9 | 10 | 11 | # 3000 sequences have 2 subjects and 1 single subject, 3000 sequences have 3 single subject 12 | 13 | two_sample=np.random.choice(len(two_train),3000) 14 | one_sample=np.random.choice(len(one_train),3000+3000*3) 15 | 16 | one=one_sample[:3000] #mix with two 17 | 18 | one_1=one_sample[3000:6000] 19 | one_2=one_sample[6000:9000] 20 | one_3=one_sample[9000:12000] 21 | 22 | data=[] 23 | for i in range(6000): 24 | #3000 sequences have 2 subjects and 1 single subject 25 | if i<3000: 26 | two_person=two_train[two_sample[i]] 27 | one_person=one_train[one[i]] 28 | 29 | #random initialization 30 | two_person[:,:,:,[0,2]]=two_person[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 31 | one_person[:,:,:,[0,2]]=one_person[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 32 | temp=np.concatenate([one_person,two_person]) 33 | #put the whole scene into the center 34 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) 35 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) 36 | temp=temp.reshape(3,120,-1) 37 | data.append(temp) 38 | 39 | #3000 sequences have 3 single subject 40 | else: 41 | 
one_person_1=one_train[one_1[i-3000]] 42 | one_person_2=one_train[one_2[i-3000]] 43 | one_person_3=one_train[one_3[i-3000]] 44 | one_person_1[:,:,:,[0,2]]=one_person_1[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 45 | one_person_2[:,:,:,[0,2]]=one_person_2[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 46 | one_person_3[:,:,:,[0,2]]=one_person_3[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 47 | temp=np.concatenate([one_person_1,one_person_2,one_person_3]) 48 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) 49 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) 50 | temp=temp.reshape(3,120,-1) 51 | data.append(temp) 52 | 53 | data=np.array(data) # 6000 sequences, 3 persons, 120 (30 fps 4 seconds), 93 joints xyz (31x3) 54 | print(data.shape) 55 | 56 | use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] #used joints and order 57 | data=data.reshape(data.shape[0],3,-1,31,3) 58 | data=data[:,:,:,use,:] 59 | data=data.reshape(data.shape[0],3,-1,45) 60 | #In order to mix the data from different sources, we scale different data respectively in this code. 61 | #This may make the result slightly different from the table in the paper. 62 | data=data*0.1*1.8/3 63 | np.save('train_3_120_mocap.npy',data) 64 | 65 | 66 | 67 | ########################################################################### 68 | 69 | #test data 70 | 71 | two_test=np.load('two_test_4seconds_2.npy',allow_pickle=True) 72 | one_test=np.load('one_test_4seconds_30.npy',allow_pickle=True) 73 | 74 | print(two_test.shape) 75 | print(one_test.shape) 76 | 77 | #400 sequences have 2 subjects and 1 single subject 78 | #400 sequences have 3 single subject 79 | 80 | two_sample=np.random.choice(len(two_test),400) 81 | one_sample=np.random.choice(len(one_test),400+400*3) 82 | 83 | one_1=one_sample[400:800] 84 | one_2=one_sample[800:1200] 85 | one_3=one_sample[1200:1600] 86 | 87 | data=[] 88 | for i in range(800): 89 | #800 sequences have 2 subjects and 1 single subject 90 | if i<400: 91 | two_person=two_test[two_sample[i]] 92 | one_person=one_test[one_sample[i]] 93 | two_person[:,:,:,[0,2]]=two_person[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 94 | one_person[:,:,:,[0,2]]=one_person[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 95 | temp=np.concatenate([one_person,two_person]) 96 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) 97 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) 98 | temp=temp.reshape(3,120,-1) 99 | data.append(temp) 100 | 101 | 102 | else: 103 | one_person_1=one_test[one_1[i-400]] 104 | one_person_2=one_test[one_2[i-400]] 105 | one_person_3=one_test[one_3[i-400]] 106 | one_person_1[:,:,:,[0,2]]=one_person_1[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 107 | one_person_2[:,:,:,[0,2]]=one_person_2[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 108 | one_person_3[:,:,:,[0,2]]=one_person_3[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 109 | temp=np.concatenate([one_person_1,one_person_2,one_person_3]) 110 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) 111 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) 112 | temp=temp.reshape(3,120,-1) 113 | data.append(temp) 114 | 115 | data=np.array(data) 116 | 117 | use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] #used joints and order 118 | data=data.reshape(data.shape[0],3,-1,31,3) 119 | data=data[:,:,:,use,:] 120 | 
data=data.reshape(data.shape[0],3,-1,45) 121 | data=data*0.1*1.8/3 # scale 122 | print(data.shape) 123 | np.save('test_3_120_mocap.npy',data) 124 | 125 | 126 | 127 | ########################################################################### 128 | 129 | #discriminator data 130 | one_train=np.load('one_train_4seconds_30.npy',allow_pickle=True) 131 | print(one_train.shape) 132 | 133 | # 6000 have 3 single subject 134 | 135 | one_sample=np.random.choice(len(one_train),6000*3) 136 | 137 | 138 | 139 | one_1=one_sample[:6000] 140 | one_2=one_sample[6000:12000] 141 | one_3=one_sample[12000:] 142 | 143 | data=[] 144 | for i in range(6000): 145 | 146 | one_person_1=one_train[one_1[i]] 147 | one_person_2=one_train[one_2[i]] 148 | one_person_3=one_train[one_3[i]] 149 | one_person_1[:,:,:,[0,2]]=one_person_1[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 150 | one_person_2[:,:,:,[0,2]]=one_person_2[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 151 | one_person_3[:,:,:,[0,2]]=one_person_3[:,:,:,[0,2]]+np.array([np.random.randint(-50,50),np.random.randint(-50,50)]) 152 | temp=np.concatenate([one_person_1,one_person_2,one_person_3]) 153 | temp[:,:,:,0]=temp[:,:,:,0]-np.mean(temp[:,:,:,0]) 154 | temp[:,:,:,2]=temp[:,:,:,2]-np.mean(temp[:,:,:,2]) 155 | temp=temp.reshape(3,120,-1) 156 | data.append(temp) 157 | 158 | data=np.array(data) 159 | 160 | use=[0,1,2,3,6,7,8,14,16,17,18,20,24,25,27] #used joints and order 161 | data=data.reshape(data.shape[0],3,-1,31,3) 162 | data=data[:,:,:,use,:] 163 | data=data.reshape(data.shape[0],3,-1,45) 164 | data=data*0.1*1.8/3 # scale 165 | print(data.shape) 166 | np.save('discriminator_3_120_mocap.npy',data) 167 | -------------------------------------------------------------------------------- /mocap/preprocess_mocap.py: -------------------------------------------------------------------------------- 1 | from amc_parser import * 2 | import numpy as np 3 | import os 4 | 5 | #two subjects data 6 | 7 | data=[] 8 | test_data=[] 9 | for ii in range(4): 10 | 11 | # 18 19 20 21 22 23 33 34 are two subjects data 12 | if ii==0: 13 | A='18' 14 | B='19' 15 | if ii==1: 16 | A='20' 17 | B='21' 18 | if ii==2: 19 | A='22' 20 | B='23' 21 | if ii==3: 22 | A='33' 23 | B='34' 24 | 25 | motion_list_A_All=[] 26 | motion_list_A_test=[] 27 | asf_path = './all_asfamc/subjects/'+A+'/'+A+'.asf' 28 | iii=0 29 | for each in sorted(os.listdir('./all_asfamc/subjects/'+A+'/')): 30 | if each[-3:]!='amc': 31 | continue 32 | print(each) 33 | amc_path = './all_asfamc/subjects/'+A+'/'+each 34 | joints = parse_asf(asf_path) 35 | motions = parse_amc(amc_path) 36 | length=len(motions) 37 | 38 | if (iii%4==1) and (ii!=3): #just an example 39 | print('test') 40 | motion_list_A=[] 41 | for i in range(0,length,4): 42 | frame_idx = i 43 | joints['root'].set_motion(motions[frame_idx]) 44 | joints_list=[] 45 | for joint in joints.values(): 46 | xyz=np.array([joint.coordinate[0],\ 47 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 48 | joints_list.append(xyz) 49 | motion_list_A.append(np.array(joints_list)) 50 | motion_list_A_test.append(motion_list_A) 51 | 52 | 53 | else: 54 | if ii==3 and iii%4==1: 55 | continue 56 | 57 | print('train') 58 | motion_list_A=[] 59 | for i in range(0,length,4): 60 | frame_idx = i 61 | joints['root'].set_motion(motions[frame_idx]) 62 | joints_list=[] 63 | for joint in joints.values(): 64 | xyz=np.array([joint.coordinate[0],\ 65 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 66 | joints_list.append(xyz) 67 | 
motion_list_A.append(np.array(joints_list)) 68 | motion_list_A_All.append(motion_list_A) 69 | 70 | iii=iii+1 71 | 72 | motion_list_B_All=[] 73 | motion_list_B_test=[] 74 | asf_path_2 = './all_asfamc/subjects/'+B+'/'+B+'.asf' 75 | iii=0 76 | for each in sorted(os.listdir('./all_asfamc/subjects/'+B+'/')): 77 | if each[-3:]!='amc': 78 | continue 79 | print(each) 80 | amc_path_2 = './all_asfamc/subjects/'+B+'/'+each 81 | joints_2 = parse_asf(asf_path_2) 82 | motions_2 = parse_amc(amc_path_2) 83 | length=len(motions_2) 84 | 85 | if (iii%4==1) and (ii!=3): 86 | print('test') 87 | motion_list_B=[] 88 | for i in range(0,length,4): 89 | frame_idx = i 90 | joints_2['root'].set_motion(motions_2[frame_idx]) 91 | joints_list_2=[] 92 | for joint in joints_2.values(): 93 | xyz=np.array([joint.coordinate[0],\ 94 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 95 | joints_list_2.append(xyz) 96 | motion_list_B.append(np.array(joints_list_2)) 97 | motion_list_B_test.append(motion_list_B) 98 | 99 | else: 100 | if ii==3 and iii%4==1: 101 | continue 102 | 103 | print('train') 104 | motion_list_B=[] 105 | for i in range(0,length,4): 106 | frame_idx = i 107 | joints_2['root'].set_motion(motions_2[frame_idx]) 108 | joints_list_2=[] 109 | for joint in joints_2.values(): 110 | xyz=np.array([joint.coordinate[0],\ 111 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 112 | joints_list_2.append(xyz) 113 | motion_list_B.append(np.array(joints_list_2)) 114 | motion_list_B_All.append(motion_list_B) 115 | 116 | iii=iii+1 117 | 118 | scene_length=len(motion_list_B_All) 119 | 120 | #print(scene_length) 121 | for i in range(scene_length): 122 | motion_list_A=np.array(motion_list_A_All[i]) 123 | motion_list_B=np.array(motion_list_B_All[i]) 124 | #print(motion_list_A.shape[0]) 125 | for j in range(0,motion_list_A.shape[0],2): 126 | 127 | if j+120>motion_list_A.shape[0]: 128 | break 129 | A_=np.expand_dims(np.array(motion_list_A[j:j+120]),0) 130 | B_=np.expand_dims(np.array(motion_list_B[j:j+120]),0) 131 | motion=np.concatenate([A_,B_]) 132 | data.append(motion) 133 | 134 | scene_length=len(motion_list_B_test) 135 | for i in range(scene_length): 136 | motion_list_A=np.array(motion_list_A_test[i]) 137 | motion_list_B=np.array(motion_list_B_test[i]) 138 | #print(motion_list_A.shape[0]) 139 | for j in range(0,motion_list_A.shape[0],2): #down sample 140 | 141 | if j+120>motion_list_A.shape[0]: 142 | break 143 | A_=np.expand_dims(np.array(motion_list_A[j:j+120]),0) # 120: 30 fps, 4 seconds 144 | B_=np.expand_dims(np.array(motion_list_B[j:j+120]),0) 145 | motion=np.concatenate([A_,B_]) 146 | test_data.append(motion) 147 | print(ii) 148 | 149 | np.save('two_train_4seconds_2.npy',np.array(data)) 150 | np.save('two_test_4seconds_2.npy',np.array(test_data)) 151 | 152 | ######################################################################## 153 | 154 | #one subject data 155 | 156 | data=[] 157 | test_data=[] 158 | for ii in sorted(os.listdir('./all_asfamc/subjects/')): 159 | 160 | motion_list_A_All=[] 161 | motion_list_A_test=[] 162 | asf_path = './all_asfamc/subjects/'+ii+'/'+ii+'.asf' 163 | iii=0 164 | for each in sorted(os.listdir('./all_asfamc/subjects/'+ii+'/')): 165 | if each[-3:]!='amc': 166 | continue 167 | amc_path = './all_asfamc/subjects/'+ii+'/'+each 168 | joints = parse_asf(asf_path) 169 | motions = parse_amc(amc_path) 170 | length=len(motions) 171 | if iii%4!=1: 172 | print('train') 173 | motion_list_A=[] 174 | for i in range(0,length,4): 175 | frame_idx = i 176 | joints['root'].set_motion(motions[frame_idx]) 177 
| joints_list=[] 178 | for joint in joints.values(): 179 | xyz=np.array([joint.coordinate[0],\ 180 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 181 | joints_list.append(xyz) 182 | motion_list_A.append(np.array(joints_list)) 183 | motion_list_A_All.append(motion_list_A) 184 | else: 185 | print('test') 186 | motion_list_A=[] 187 | for i in range(0,length,4): 188 | frame_idx = i 189 | joints['root'].set_motion(motions[frame_idx]) 190 | joints_list=[] 191 | for joint in joints.values(): 192 | xyz=np.array([joint.coordinate[0],\ 193 | joint.coordinate[1],joint.coordinate[2]]).squeeze(1) 194 | joints_list.append(xyz) 195 | motion_list_A.append(np.array(joints_list)) 196 | motion_list_A_test.append(motion_list_A) 197 | iii=iii+1 198 | scene_length=len(motion_list_A_All) 199 | for i in range(scene_length): 200 | motion_list_A=np.array(motion_list_A_All[i]) 201 | for j in range(0,motion_list_A.shape[0],30): #down sample 202 | if (j+120)>motion_list_A.shape[0]: 203 | break 204 | A=np.expand_dims(np.array(motion_list_A[j:j+120]),0) 205 | data.append(A) 206 | 207 | scene_length=len(motion_list_A_test) 208 | for i in range(scene_length): 209 | motion_list_A=np.array(motion_list_A_test[i]) 210 | for j in range(0,motion_list_A.shape[0],30): 211 | if (j+120)>motion_list_A.shape[0]: 212 | break 213 | A=np.expand_dims(np.array(motion_list_A[j:j+120]),0) 214 | test_data.append(A) 215 | print(ii) 216 | np.save('one_train_4seconds_30.npy',np.array(data)) 217 | np.save('one_test_4seconds_30.npy',np.array(test_data)) 218 | 219 | 220 | #use mix_mocap.py to mix two subjects and one subject 221 | -------------------------------------------------------------------------------- /mocap/amc_parser.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from transforms3d.euler import euler2mat 4 | from mpl_toolkits.mplot3d import Axes3D 5 | 6 | 7 | class Joint: 8 | def __init__(self, name, direction, length, axis, dof, limits): 9 | """ 10 | Definition of basic joint. The joint also contains the information of the 11 | bone between it's parent joint and itself. Refer 12 | [here](https://research.cs.wisc.edu/graphics/Courses/cs-838-1999/Jeff/ASF-AMC.html) 13 | for detailed description for asf files. 14 | 15 | Parameter 16 | --------- 17 | name: Name of the joint defined in the asf file. There should always be one 18 | root joint. String. 19 | 20 | direction: Default direction of the joint(bone). The motions are all defined 21 | based on this default pose. 22 | 23 | length: Length of the bone. 24 | 25 | axis: Axis of rotation for the bone. 26 | 27 | dof: Degree of freedom. Specifies the number of motion channels and in what 28 | order they appear in the AMC file. 
29 | 30 | limits: Limits on each of the channels in the dof specification 31 | 32 | """ 33 | self.name = name 34 | self.direction = np.reshape(direction, [3, 1]) 35 | self.length = length 36 | axis = np.deg2rad(axis) 37 | self.C = euler2mat(*axis) 38 | self.Cinv = np.linalg.inv(self.C) 39 | self.limits = np.zeros([3, 2]) 40 | for lm, nm in zip(limits, dof): 41 | if nm == 'rx': 42 | self.limits[0] = lm 43 | elif nm == 'ry': 44 | self.limits[1] = lm 45 | else: 46 | self.limits[2] = lm 47 | self.parent = None 48 | self.children = [] 49 | self.coordinate = None 50 | self.matrix = None 51 | 52 | def set_motion(self, motion): 53 | if self.name == 'root': 54 | self.coordinate = np.reshape(np.array(motion['root'][:3]), [3, 1]) 55 | rotation = np.deg2rad(motion['root'][3:]) 56 | self.matrix = self.C.dot(euler2mat(*rotation)).dot(self.Cinv) 57 | else: 58 | idx = 0 59 | rotation = np.zeros(3) 60 | for axis, lm in enumerate(self.limits): 61 | if not np.array_equal(lm, np.zeros(2)): 62 | rotation[axis] = motion[self.name][idx] 63 | idx += 1 64 | rotation = np.deg2rad(rotation) 65 | self.matrix = self.parent.matrix.dot(self.C).dot(euler2mat(*rotation)).dot(self.Cinv) 66 | self.coordinate = self.parent.coordinate + self.length * self.matrix.dot(self.direction) 67 | for child in self.children: 68 | child.set_motion(motion) 69 | 70 | def draw(self): 71 | joints = self.to_dict() 72 | fig = plt.figure() 73 | ax = Axes3D(fig) 74 | 75 | ax.set_xlim3d(-50, 10) 76 | ax.set_ylim3d(-20, 40) 77 | ax.set_zlim3d(-20, 40) 78 | 79 | xs, ys, zs = [], [], [] 80 | for joint in joints.values(): 81 | xs.append(joint.coordinate[0, 0]) 82 | ys.append(joint.coordinate[1, 0]) 83 | zs.append(joint.coordinate[2, 0]) 84 | plt.plot(zs, xs, ys, 'b.') 85 | 86 | for joint in joints.values(): 87 | child = joint 88 | if child.parent is not None: 89 | parent = child.parent 90 | xs = [child.coordinate[0, 0], parent.coordinate[0, 0]] 91 | ys = [child.coordinate[1, 0], parent.coordinate[1, 0]] 92 | zs = [child.coordinate[2, 0], parent.coordinate[2, 0]] 93 | plt.plot(zs, xs, ys, 'r') 94 | plt.show() 95 | 96 | def to_dict(self): 97 | ret = {self.name: self} 98 | for child in self.children: 99 | ret.update(child.to_dict()) 100 | return ret 101 | 102 | def pretty_print(self): 103 | print('===================================') 104 | print('joint: %s' % self.name) 105 | print('direction:') 106 | print(self.direction) 107 | print('limits:', self.limits) 108 | print('parent:', self.parent) 109 | print('children:', self.children) 110 | 111 | 112 | def read_line(stream, idx): 113 | if idx >= len(stream): 114 | return None, idx 115 | line = stream[idx].strip().split() 116 | idx += 1 117 | return line, idx 118 | 119 | 120 | def parse_asf(file_path): 121 | '''read joint data only''' 122 | with open(file_path) as f: 123 | content = f.read().splitlines() 124 | 125 | for idx, line in enumerate(content): 126 | # meta infomation is ignored 127 | if line == ':bonedata': 128 | content = content[idx+1:] 129 | break 130 | 131 | # read joints 132 | joints = {'root': Joint('root', np.zeros(3), 0, np.zeros(3), [], [])} 133 | idx = 0 134 | while True: 135 | # the order of each section is hard-coded 136 | 137 | line, idx = read_line(content, idx) 138 | 139 | if line[0] == ':hierarchy': 140 | break 141 | 142 | assert line[0] == 'begin' 143 | 144 | line, idx = read_line(content, idx) 145 | assert line[0] == 'id' 146 | 147 | line, idx = read_line(content, idx) 148 | assert line[0] == 'name' 149 | name = line[1] 150 | 151 | line, idx = read_line(content, idx) 152 | 
assert line[0] == 'direction' 153 | direction = np.array([float(axis) for axis in line[1:]]) 154 | 155 | # skip length 156 | line, idx = read_line(content, idx) 157 | assert line[0] == 'length' 158 | length = float(line[1]) 159 | 160 | line, idx = read_line(content, idx) 161 | assert line[0] == 'axis' 162 | assert line[4] == 'XYZ' 163 | 164 | axis = np.array([float(axis) for axis in line[1:-1]]) 165 | 166 | dof = [] 167 | limits = [] 168 | 169 | line, idx = read_line(content, idx) 170 | if line[0] == 'dof': 171 | dof = line[1:] 172 | for i in range(len(dof)): 173 | line, idx = read_line(content, idx) 174 | if i == 0: 175 | assert line[0] == 'limits' 176 | line = line[1:] 177 | assert len(line) == 2 178 | mini = float(line[0][1:]) 179 | maxi = float(line[1][:-1]) 180 | limits.append((mini, maxi)) 181 | 182 | line, idx = read_line(content, idx) 183 | 184 | assert line[0] == 'end' 185 | joints[name] = Joint( 186 | name, 187 | direction, 188 | length, 189 | axis, 190 | dof, 191 | limits 192 | ) 193 | 194 | # read hierarchy 195 | assert line[0] == ':hierarchy' 196 | 197 | line, idx = read_line(content, idx) 198 | 199 | assert line[0] == 'begin' 200 | 201 | while True: 202 | line, idx = read_line(content, idx) 203 | if line[0] == 'end': 204 | break 205 | assert len(line) >= 2 206 | for joint_name in line[1:]: 207 | joints[line[0]].children.append(joints[joint_name]) 208 | for nm in line[1:]: 209 | joints[nm].parent = joints[line[0]] 210 | 211 | return joints 212 | 213 | 214 | def parse_amc(file_path): 215 | with open(file_path) as f: 216 | content = f.read().splitlines() 217 | 218 | for idx, line in enumerate(content): 219 | if line == ':DEGREES': 220 | content = content[idx+1:] 221 | break 222 | 223 | frames = [] 224 | idx = 0 225 | line, idx = read_line(content, idx) 226 | assert line[0].isnumeric(), line 227 | EOF = False 228 | while not EOF: 229 | joint_degree = {} 230 | while True: 231 | line, idx = read_line(content, idx) 232 | if line is None: 233 | EOF = True 234 | break 235 | if line[0].isnumeric(): 236 | break 237 | joint_degree[line[0]] = [float(deg) for deg in line[1:]] 238 | frames.append(joint_degree) 239 | return frames 240 | 241 | 242 | def test_all(): 243 | import os 244 | lv0 = './data' 245 | lv1s = os.listdir(lv0) 246 | for lv1 in lv1s: 247 | lv2s = os.listdir('/'.join([lv0, lv1])) 248 | asf_path = '%s/%s/%s.asf' % (lv0, lv1, lv1) 249 | print('parsing %s' % asf_path) 250 | joints = parse_asf(asf_path) 251 | motions = parse_amc('./nopose.amc') 252 | joints['root'].set_motion(motions[0]) 253 | joints['root'].draw() 254 | 255 | # for lv2 in lv2s: 256 | # if lv2.split('.')[-1] != 'amc': 257 | # continue 258 | # amc_path = '%s/%s/%s' % (lv0, lv1, lv2) 259 | # print('parsing amc %s' % amc_path) 260 | # motions = parse_amc(amc_path) 261 | # for idx, motion in enumerate(motions): 262 | # print('setting motion %d' % idx) 263 | # joints['root'].set_motion(motion) 264 | 265 | 266 | if __name__ == '__main__': 267 | #test_all() 268 | asf_path = '../all_asfamc/subjects/20/20.asf' 269 | amc_path = '../all_asfamc/subjects/20/20_06.amc' 270 | joints = parse_asf(asf_path) 271 | motions = parse_amc(amc_path) 272 | length=len(motions) 273 | 274 | asf_path_2 = '../all_asfamc/subjects/21/21.asf' 275 | amc_path_2 = '../all_asfamc/subjects/21/21_06.amc' 276 | joints_2 = parse_asf(asf_path_2) 277 | motions_2 = parse_amc(amc_path_2) 278 | 279 | #joints['root'].draw() 280 | fig = plt.figure() 281 | 282 | ax = fig.add_subplot(111, projection='3d') 283 | 284 | plt.ion() 285 | i=0 286 | 287 | while i < 
length: 288 | frame_idx = i 289 | 290 | joints['root'].set_motion(motions[frame_idx]) 291 | ax.lines = [] 292 | 293 | xs, ys, zs = [], [], [] 294 | for joint in joints.values(): 295 | xs.append(joint.coordinate[0, 0]) 296 | ys.append(joint.coordinate[1, 0]) 297 | zs.append(joint.coordinate[2, 0]) 298 | 299 | ax.plot(zs, xs, ys, 'b.') 300 | 301 | for joint in joints.values(): 302 | child = joint 303 | if child.parent is not None: 304 | parent = child.parent 305 | xs = [child.coordinate[0, 0], parent.coordinate[0, 0]] 306 | ys = [child.coordinate[1, 0], parent.coordinate[1, 0]] 307 | zs = [child.coordinate[2, 0], parent.coordinate[2, 0]] 308 | ax.plot(zs, xs, ys, 'r') 309 | 310 | 311 | joints_2['root'].set_motion(motions_2[frame_idx]) 312 | #ax.lines = [] 313 | 314 | xs, ys, zs = [], [], [] 315 | for joint in joints_2.values(): 316 | xs.append(joint.coordinate[0, 0]) 317 | ys.append(joint.coordinate[1, 0]) 318 | zs.append(joint.coordinate[2, 0]) 319 | 320 | ax.plot(zs, xs, ys, 'b.') 321 | 322 | for joint in joints_2.values(): 323 | child = joint 324 | if child.parent is not None: 325 | parent = child.parent 326 | xs = [child.coordinate[0, 0], parent.coordinate[0, 0]] 327 | ys = [child.coordinate[1, 0], parent.coordinate[1, 0]] 328 | zs = [child.coordinate[2, 0], parent.coordinate[2, 0]] 329 | ax.plot(zs, xs, ys, 'r') 330 | 331 | ax.set_xlim3d(-50, 50) 332 | ax.set_ylim3d(-50, 50) 333 | ax.set_zlim3d(-10, 40) 334 | 335 | ax.set_xlabel("x") 336 | ax.set_ylabel("y") 337 | ax.set_zlabel("z") 338 | 339 | plt.pause(0.001) 340 | #print(i) 341 | i += 4 342 | if i >= length: 343 | i = 0 344 | plt.ioff() 345 | plt.show() -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_dct as dct 4 | import time 5 | from MRT.Models import Transformer 6 | 7 | 8 | 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import Axes3D 11 | 12 | 13 | import numpy as np 14 | import os 15 | 16 | from data import TESTDATA 17 | 18 | dataset_name='mupots' 19 | 20 | test_dataset = TESTDATA(dataset=dataset_name) 21 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) 22 | 23 | device='cpu' 24 | 25 | batch_size=1 26 | 27 | 28 | model = Transformer(d_word_vec=128, d_model=128, d_inner=1024, 29 | n_layers=3, n_head=8, d_k=64, d_v=64,device=device).to(device) 30 | 31 | 32 | 33 | 34 | plot=False 35 | gt=False 36 | 37 | 38 | model.load_state_dict(torch.load('./saved_model/29.model',map_location=device)) 39 | 40 | 41 | body_edges = np.array( 42 | [[0,1], [1,2],[2,3],[0,4], 43 | [4,5],[5,6],[0,7],[7,8],[7,9],[9,10],[10,11],[7,12],[12,13],[13,14]] 44 | ) 45 | 46 | 47 | losses=[] 48 | 49 | total_loss=0 50 | loss_list1=[] 51 | loss_list2=[] 52 | loss_list3=[] 53 | with torch.no_grad(): 54 | model.eval() 55 | loss_list=[] 56 | for jjj,data in enumerate(test_dataloader,0): 57 | print(jjj) 58 | #if jjj!=20: 59 | # continue 60 | input_seq,output_seq=data 61 | 62 | input_seq=torch.tensor(input_seq,dtype=torch.float32).to(device) 63 | output_seq=torch.tensor(output_seq,dtype=torch.float32).to(device) 64 | n_joints=int(input_seq.shape[-1]/3) 65 | use=[input_seq.shape[1]] 66 | 67 | input_=input_seq.view(-1,15,input_seq.shape[-1]) 68 | 69 | 70 | output_=output_seq.view(output_seq.shape[0]*output_seq.shape[1],-1,input_seq.shape[-1]) 71 | 72 | input_ = dct.dct(input_) 73 | output__ = dct.dct(output_[:,:,:]) 74 | 75 | 76 | 
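        # (added note) As in training, the model takes per-frame differences of the DCT-transformed
        # input (a displacement representation) together with the last observed pose recovered by
        # dct.idct; its output is converted back with dct.idct and cumulatively summed onto the
        # last observed frame to obtain absolute poses.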
rec_=model.forward(input_[:,1:15,:]-input_[:,:14,:],dct.idct(input_[:,-1:,:]),input_seq,use) 77 | 78 | rec=dct.idct(rec_) 79 | 80 | results=output_[:,:1,:] 81 | for i in range(1,16): 82 | results=torch.cat([results,output_[:,:1,:]+torch.sum(rec[:,:i,:],dim=1,keepdim=True)],dim=1) 83 | results=results[:,1:,:] 84 | 85 | new_input_seq=torch.cat([input_seq,results.reshape(input_seq.shape)],dim=-2) 86 | new_input=dct.dct(new_input_seq.reshape(-1,30,45)) 87 | 88 | new_rec_=model.forward(new_input[:,1:,:]-new_input[:,:-1,:],dct.idct(new_input[:,-1:,:]),new_input_seq,use) 89 | 90 | 91 | new_rec=dct.idct(new_rec_) 92 | 93 | new_results=new_input_seq.reshape(-1,30,45)[:,-1:,:] 94 | for i in range(1,16): 95 | new_results=torch.cat([new_results,new_input_seq.reshape(-1,30,45)[:,-1:,:]+torch.sum(new_rec[:,:i,:],dim=1,keepdim=True)],dim=1) 96 | new_results=new_results[:,1:,:] 97 | 98 | results=torch.cat([results,new_results],dim=-2) 99 | 100 | rec=torch.cat([rec,new_rec],dim=-2) 101 | 102 | results=output_[:,:1,:] 103 | 104 | for i in range(1,16+15): 105 | results=torch.cat([results,output_[:,:1,:]+torch.sum(rec[:,:i,:],dim=1,keepdim=True)],dim=1) 106 | results=results[:,1:,:] 107 | 108 | new_new_input_seq=torch.cat([input_seq,results.unsqueeze(0)],dim=-2) 109 | new_new_input=dct.dct(new_new_input_seq.reshape(-1,45,45)) 110 | 111 | new_new_rec_=model.forward(new_new_input[:,1:,:]-new_new_input[:,:-1,:],dct.idct(new_new_input[:,-1:,:]),new_new_input_seq,use) 112 | 113 | 114 | new_new_rec=dct.idct(new_new_rec_) 115 | rec=torch.cat([rec,new_new_rec],dim=-2) 116 | 117 | results=output_[:,:1,:] 118 | 119 | for i in range(1,31+15): 120 | results=torch.cat([results,output_[:,:1,:]+torch.sum(rec[:,:i,:],dim=1,keepdim=True)],dim=1) 121 | results=results[:,1:,:] 122 | 123 | prediction_1=results[:,:15,:].view(results.shape[0],-1,n_joints,3) 124 | prediction_2=results[:,:30,:].view(results.shape[0],-1,n_joints,3) 125 | prediction_3=results[:,:45,:].view(results.shape[0],-1,n_joints,3) 126 | 127 | gt_1=output_seq[0][:,1:16,:].view(results.shape[0],-1,n_joints,3) 128 | gt_2=output_seq[0][:,1:31,:].view(results.shape[0],-1,n_joints,3) 129 | gt_3=output_seq[0][:,1:46,:].view(results.shape[0],-1,n_joints,3) 130 | 131 | if dataset_name=='mocap': 132 | #match the scale with the paper, also see line 63 in mix_mocap.py 133 | loss1=torch.sqrt(((prediction_1/1.8 - gt_1/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 134 | loss2=torch.sqrt(((prediction_2/1.8 - gt_2/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 135 | loss3=torch.sqrt(((prediction_3/1.8 - gt_3/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 136 | 137 | #pose with align 138 | # loss1=torch.sqrt((((prediction_1 - prediction_1[:,:,0:1,:] - gt_1 + gt_1[:,:,0:1,:])/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 139 | # loss2=torch.sqrt((((prediction_2 - prediction_2[:,:,0:1,:] - gt_2 + gt_2[:,:,0:1,:])/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 140 | # loss3=torch.sqrt((((prediction_3 - prediction_3[:,:,0:1,:] - gt_3 + gt_3[:,:,0:1,:])/1.8) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 141 | 142 | 143 | if dataset_name=='mupots': 144 | loss1=torch.sqrt(((prediction_1 - gt_1) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 145 | loss2=torch.sqrt(((prediction_2 - gt_2) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 146 | loss3=torch.sqrt(((prediction_3 - gt_3) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 
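            # (added note) loss1/2/3 are mean per-joint position errors: the per-joint Euclidean
            # distance averaged over joints, frames and persons for the 1 s (15-frame), 2 s
            # (30-frame) and 3 s (45-frame) horizons; the commented lines below compute the
            # root-aligned ("pose with align") variant instead.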
147 | 148 | #pose with align 149 | # loss1=torch.sqrt(((prediction_1 - prediction_1[:,:,0:1,:] - gt_1 + gt_1[:,:,0:1,:]) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 150 | # loss2=torch.sqrt(((prediction_2 - prediction_2[:,:,0:1,:] - gt_2 + gt_2[:,:,0:1,:]) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 151 | # loss3=torch.sqrt(((prediction_3 - prediction_3[:,:,0:1,:] - gt_3 + gt_3[:,:,0:1,:]) ** 2).sum(dim=-1)).mean(dim=-1).mean(dim=-1).numpy().tolist() 152 | 153 | 154 | loss_list1.append(np.mean(loss1))#+loss1 155 | loss_list2.append(np.mean(loss2))#+loss2 156 | loss_list3.append(np.mean(loss3))#+loss3 157 | 158 | loss=torch.mean((rec[:,:,:]-(output_[:,1:46,:]-output_[:,:45,:]))**2) 159 | losses.append(loss) 160 | 161 | 162 | rec=results[:,:,:] 163 | 164 | rec=rec.reshape(results.shape[0],-1,n_joints,3) 165 | 166 | input_seq=input_seq.view(results.shape[0],15,n_joints,3) 167 | pred=torch.cat([input_seq,rec],dim=1) 168 | output_seq=output_seq.view(results.shape[0],-1,n_joints,3)[:,1:,:,:] 169 | all_seq=torch.cat([input_seq,output_seq],dim=1) 170 | 171 | 172 | pred=pred[:,:,:,:].cpu() 173 | all_seq=all_seq[:,:,:,:].cpu() 174 | 175 | 176 | if plot: 177 | fig = plt.figure(figsize=(10, 4.5)) 178 | fig.tight_layout() 179 | ax = fig.add_subplot(111, projection='3d') 180 | 181 | plt.ion() 182 | 183 | length=45+15 184 | length_=45+15 185 | i=0 186 | 187 | p_x=np.linspace(-10,10,15) 188 | p_y=np.linspace(-10,10,15) 189 | X,Y=np.meshgrid(p_x,p_y) 190 | 191 | 192 | while i < length_: 193 | 194 | ax.lines = [] 195 | 196 | for x_i in range(p_x.shape[0]): 197 | temp_x=[p_x[x_i],p_x[x_i]] 198 | temp_y=[p_y[0],p_y[-1]] 199 | z=[0,0] 200 | ax.plot(temp_x,temp_y,z,color='black',alpha=0.1) 201 | 202 | for y_i in range(p_x.shape[0]): 203 | temp_x=[p_x[0],p_x[-1]] 204 | temp_y=[p_y[y_i],p_y[y_i]] 205 | z=[0,0] 206 | ax.plot(temp_x,temp_y,z,color='black',alpha=0.1) 207 | 208 | for j in range(results.shape[0]): 209 | 210 | xs=pred[j,i,:,0].numpy() 211 | ys=pred[j,i,:,1].numpy() 212 | zs=pred[j,i,:,2].numpy() 213 | 214 | alpha=1 215 | ax.plot( zs,xs, ys, 'y.',alpha=alpha) 216 | 217 | if gt: 218 | x=all_seq[j,i,:,0].numpy() 219 | 220 | y=all_seq[j,i,:,1].numpy() 221 | z=all_seq[j,i,:,2].numpy() 222 | 223 | 224 | ax.plot( z,x, y, 'y.') 225 | 226 | 227 | plot_edge=True 228 | if plot_edge: 229 | for edge in body_edges: 230 | x=[pred[j,i,edge[0],0],pred[j,i,edge[1],0]] 231 | y=[pred[j,i,edge[0],1],pred[j,i,edge[1],1]] 232 | z=[pred[j,i,edge[0],2],pred[j,i,edge[1],2]] 233 | if i>=15: 234 | ax.plot(z,x, y, zdir='z',c='blue',alpha=alpha) 235 | 236 | else: 237 | ax.plot(z,x, y, zdir='z',c='green',alpha=alpha) 238 | 239 | if gt: 240 | x=[all_seq[j,i,edge[0],0],all_seq[j,i,edge[1],0]] 241 | y=[all_seq[j,i,edge[0],1],all_seq[j,i,edge[1],1]] 242 | z=[all_seq[j,i,edge[0],2],all_seq[j,i,edge[1],2]] 243 | 244 | if i>=15: 245 | ax.plot( z,x, y, 'yellow',alpha=0.8) 246 | else: 247 | ax.plot( z, x, y, 'green') 248 | 249 | 250 | ax.set_xlim3d([-3 , 3]) 251 | ax.set_ylim3d([-3 , 3]) 252 | ax.set_zlim3d([0,3]) 253 | # ax.set_xlim3d([-8 , 8]) 254 | # ax.set_ylim3d([-8 , 8]) 255 | # ax.set_zlim3d([0,5]) 256 | # ax.set_xticklabels([]) 257 | # ax.set_yticklabels([]) 258 | # ax.set_zticklabels([]) 259 | ax.set_axis_off() 260 | #ax.patch.set_alpha(1) 261 | #ax.set_aspect('equal') 262 | #ax.set_xlabel("x") 263 | #ax.set_ylabel("y") 264 | #ax.set_zlabel("z") 265 | plt.title(str(i),y=-0.1) 266 | plt.pause(0.1) 267 | i += 1 268 | 269 | 270 | plt.ioff() 271 | plt.show() 272 | plt.close() 273 | 274 | 275 | 276 | 
277 | print('avg 1 second',np.mean(loss_list1))
278 | print('avg 2 seconds',np.mean(loss_list2))
279 | print('avg 3 seconds',np.mean(loss_list3))
280 | 
281 | 
282 | 
--------------------------------------------------------------------------------
/MRT/Models.py:
--------------------------------------------------------------------------------
1 | ''' Define the Transformer model '''
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from MRT.Layers import EncoderLayer, DecoderLayer
6 | import torch.nn.functional as F
7 | 
8 | 
9 | def get_pad_mask(seq, pad_idx):
10 |     return (seq != pad_idx).unsqueeze(-2)
11 | 
12 | 
13 | def get_subsequent_mask(seq):
14 |     ''' For masking out the subsequent info. '''
15 |     sz_b, len_s, *_ = seq.size()
16 |     subsequent_mask = (1 - torch.triu(
17 |         torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
18 |     return subsequent_mask
19 | 
20 | 
21 | class PositionalEncoding(nn.Module):
22 | 
23 |     def __init__(self, d_hid, n_position=200):
24 |         super(PositionalEncoding, self).__init__()
25 | 
26 |         # Not a parameter
27 |         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
28 |         self.register_buffer('pos_table2', self._get_sinusoid_encoding_table(n_position, d_hid))  # second table, used for the concatenated multi-person sequence (forward2)
29 |         # self.register_buffer('pos_table3', self._get_sinusoid_encoding_table(n_position, d_hid))
30 |     def _get_sinusoid_encoding_table(self, n_position, d_hid):
31 |         ''' Sinusoid position encoding table '''
32 | 
33 |         def get_position_angle_vec(position):
34 |             return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
35 | 
36 |         sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
37 |         sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
38 |         sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
39 | 
40 |         return torch.FloatTensor(sinusoid_table).unsqueeze(0)
41 | 
42 |     def forward(self,x,n_person):
43 |         p=self.pos_table[:,:x.size(1)].clone().detach()
44 |         return x + p
45 | 
46 |     def forward2(self, x, n_person):
47 |         # if x.shape[1]==135:
48 |         #     p=self.pos_table3[:, :int(x.shape[1]/n_person)].clone().detach()
49 |         #     p=p.repeat(1,n_person,1)
50 |         # else:
51 |         p=self.pos_table2[:, :int(x.shape[1]/n_person)].clone().detach()
52 |         p=p.repeat(1,n_person,1)
53 |         return x + p
54 | 
55 | 
56 | class Encoder(nn.Module):
57 |     ''' An encoder model with a self-attention mechanism. '''
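# --- editor's note (not part of the original Models.py) -----------------------
# This Encoder is shared by both ranges of the model below: with
# global_feature=False its forward() adds pos_table to a single person's
# sequence, while with global_feature=True it calls PositionalEncoding.forward2,
# which repeats pos_table2 n_person times so that the concatenated multi-person
# sequence reuses the same temporal indices for every person.
# -------------------------------------------------------------------------------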
58 | 
59 |     def __init__(
60 |             self, d_word_vec, n_layers, n_head, d_k, d_v,
61 |             d_model, d_inner, pad_idx, dropout=0.1, n_position=200, device='cuda'):
62 | 
63 |         super().__init__()
64 |         self.position_embeddings = nn.Embedding(n_position, d_model)
65 |         #self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
66 |         self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
67 |         self.dropout = nn.Dropout(p=dropout)
68 |         self.layer_stack = nn.ModuleList([
69 |             EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
70 |             for _ in range(n_layers)])
71 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
72 |         self.device=device
73 |     def forward(self, src_seq,n_person, src_mask, return_attns=False, global_feature=False):
74 | 
75 |         enc_slf_attn_list = []
76 |         # -- Forward
77 |         #src_seq = self.layer_norm(src_seq)
78 |         if global_feature:
79 |             enc_output = self.dropout(self.position_enc.forward2(src_seq,n_person))
80 |             #enc_output = self.dropout(src_seq)
81 |         else:
82 |             enc_output = self.dropout(self.position_enc(src_seq,n_person))
83 |         #enc_output = self.layer_norm(enc_output)
84 |         #enc_output=self.dropout(src_seq+position_embeddings)
85 |         #enc_output = self.dropout(self.layer_norm(enc_output))
86 |         for enc_layer in self.layer_stack:
87 |             enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
88 |             enc_slf_attn_list += [enc_slf_attn] if return_attns else []
89 | 
90 |         if return_attns:
91 |             return enc_output, enc_slf_attn_list
92 | 
93 | 
94 |         return enc_output,
95 | 
96 | 
97 | class Decoder(nn.Module):
98 | 
99 |     def __init__(
100 |             self, d_word_vec, n_layers, n_head, d_k, d_v,
101 |             d_model, d_inner, pad_idx, n_position=200, dropout=0.1,device='cuda'):
102 | 
103 |         super().__init__()
104 | 
105 |         #self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
106 |         self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
107 |         self.dropout = nn.Dropout(p=dropout)
108 |         self.layer_stack = nn.ModuleList([
109 |             DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
110 |             for _ in range(n_layers)])
111 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
112 |         self.device=device
113 | 
114 |     def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
115 | 
116 |         dec_slf_attn_list, dec_enc_attn_list = [], []
117 | 
118 |         dec_output = (trg_seq)
119 | 
120 |         for dec_layer in self.layer_stack:
121 |             dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
122 |                 dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
123 |             dec_slf_attn_list += [dec_slf_attn] if return_attns else []
124 |             dec_enc_attn_list += [dec_enc_attn] if return_attns else []
125 | 
126 |         if return_attns:
127 |             return dec_output, dec_slf_attn_list, dec_enc_attn_list
128 |         return dec_output, dec_enc_attn_list
129 | 
130 | class Transformer(nn.Module):
131 |     ''' A sequence-to-sequence model with an attention mechanism. '''
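# --- editor's note (not part of the original Models.py) -----------------------
# Multi-range layout used by forward() below:
#   encoder         local range: one person's frame-to-frame differences (14 x 45),
#                   projected to d_model by proj.
#   encoder_global  global range: the 15-frame sequences of all persons, flattened
#                   to (n_person * 15) x 45 and projected by proj2.
#   decoder         a single query token (the last input frame, via proj2) attends
#                   to the concatenation of both encoder outputs; each global token
#                   is offset by exp(-c), where c is the mean squared distance
#                   between that frame and the query person's last pose.
#   l1 / l2 / proj_inverse then expand the decoder output into 15 frames of
#                   45 (= 15 joints x 3) predicted values.
# -------------------------------------------------------------------------------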
132 | 
133 |     def __init__(
134 |             self, src_pad_idx=1, trg_pad_idx=1,
135 |             d_word_vec=64, d_model=64, d_inner=512,
136 |             n_layers=3, n_head=8, d_k=32, d_v=32, dropout=0.2, n_position=100,
137 |             device='cuda'):
138 | 
139 |         super().__init__()
140 | 
141 |         self.device=device
142 | 
143 |         self.d_model=d_model
144 |         self.src_pad_idx, self.trg_pad_idx = src_pad_idx, trg_pad_idx
145 |         self.proj=nn.Linear(45,d_model)  # 45 = 15 joints x 3 coordinates
146 |         self.proj2=nn.Linear(45,d_model)
147 |         self.proj_inverse=nn.Linear(d_model,45)
148 |         self.l1=nn.Linear(d_model, d_model*4)
149 |         self.l2=nn.Linear(d_model*4, d_model*15)   # expands the single decoder token to 15 output frames
150 | 
151 |         self.dropout = nn.Dropout(p=dropout)
152 | 
153 |         self.encoder = Encoder(                    # local-range encoder
154 |             n_position=n_position,
155 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
156 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
157 |             pad_idx=src_pad_idx, dropout=dropout, device=self.device)
158 | 
159 |         self.encoder_global = Encoder(             # global-range encoder over all persons
160 |             n_position=n_position,
161 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
162 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
163 |             pad_idx=src_pad_idx, dropout=dropout, device=self.device)
164 | 
165 |         self.decoder = Decoder(
166 |             n_position=n_position,
167 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
168 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
169 |             pad_idx=trg_pad_idx, dropout=dropout, device=self.device)
170 | 
171 | 
172 | 
173 |         for p in self.parameters():
174 |             if p.dim() > 1:
175 |                 nn.init.xavier_uniform_(p)
176 | 
177 |         assert d_model == d_word_vec, \
178 |             'To facilitate the residual connections, \
179 |             the dimensions of all module outputs shall be the same.'
180 | 
181 |     def forward_local(self, src_seq, trg_seq, input_seq,use=None):
182 | 
183 |         # use only the local-range encoder
184 | 
185 |         n_person=input_seq.shape[1]
186 | 
187 |         src_mask = (torch.ones([src_seq.shape[0],1,src_seq.shape[1]])==True).to(self.device)
188 | 
189 | 
190 |         trg_mask = (torch.ones([trg_seq.shape[0],1,trg_seq.shape[1]])==True).to(self.device) & get_subsequent_mask(trg_seq).to(self.device)  # computed but not passed to the decoder below
191 | 
192 | 
193 |         src_seq_=self.proj(src_seq)
194 |         trg_seq_=self.proj2(trg_seq)
195 | 
196 |         enc_output, *_=self.encoder(src_seq_,n_person, src_mask)
197 |         dec_output, *_=self.decoder(trg_seq_, None, enc_output, None)
198 | 
199 |         dec_output=self.l1(dec_output)
200 |         dec_output=self.l2(dec_output)
201 |         dec_output=dec_output.view(dec_output.shape[0],-1,self.d_model)
202 | 
203 |         dec_output=self.proj_inverse(dec_output)
204 | 
205 |         return dec_output
206 | 
207 | 
208 | 
209 |     def forward(self, src_seq, trg_seq, input_seq, use=None):
210 |         '''
211 |         src_seq:   local-range input, one person's frame-to-frame differences
212 |         trg_seq:   local-range query, the last frame of that person's input sequence
213 |         input_seq: global-range input, the sequences of all persons in the scene
214 |         '''
215 |         n_person=input_seq.shape[1]
216 | 
217 |         #src_mask = (torch.ones([src_seq.shape[0],1,src_seq.shape[1]])==True).to(self.device)
218 | 
219 |         src_seq_=self.proj(src_seq)
220 |         trg_seq_=self.proj2(trg_seq)
221 | 
222 |         enc_output, *_ = self.encoder(src_seq_, n_person, None)
223 | 
224 |         others=input_seq[:,:,:,:].view(input_seq.shape[0],-1,45)
225 |         others_=self.proj2(others)
226 |         mask_other=None
227 |         mask_dec=None
228 | 
229 |         #mask_other=torch.zeros([others.shape[0],1,others_.shape[1]]).to(self.device).long()
230 |         #for i in range(len(use)):
231 |         #    mask_other[i][0][:use[i]*15]=1
232 | 
233 |         enc_others,*_=self.encoder_global(others_,n_person, mask_other, global_feature=True)
234 |         enc_others=enc_others.unsqueeze(1).expand(input_seq.shape[0],input_seq.shape[1],-1,self.d_model)
235 | 
236 |         enc_others=enc_others.reshape(enc_output.shape[0],-1,self.d_model)
237 |         #mask_other=mask_other.unsqueeze(1).expand(input_seq.shape[0],input_seq.shape[1],1,-1)
238 |         #mask_other=mask_other.reshape(enc_others.shape[0],1,-1)
239 |         #mask_dec=torch.cat([src_mask*1,mask_other.long()],dim=-1)
240 | 
241 | 
242 |         temp_a=input_seq.unsqueeze(1).repeat(1,input_seq.shape[1],1,1,1)
243 |         temp_b=input_seq[:,:,-1:,:].unsqueeze(2).repeat(1,1,input_seq.shape[1],1,1)
244 |         c=torch.mean((temp_a-temp_b)**2,dim=-1)   # mean squared distance between every frame of each person and the query person's last pose
245 |         c=c.reshape(c.shape[0]*c.shape[1],c.shape[2]*c.shape[3],1)
246 | 
247 | 
248 |         enc_output=torch.cat([enc_output,enc_others+torch.exp(-c)],dim=1)   # global tokens receive an additive proximity bias exp(-c)
249 |         dec_output, dec_attention,*_ = self.decoder(trg_seq_[:,:1,:], None, enc_output, mask_dec)
250 | 
251 | 
252 |         dec_output= self.l1(dec_output)
253 |         dec_output= self.l2(dec_output)
254 |         dec_output=dec_output.view(dec_output.shape[0],15,self.d_model)
255 | 
256 |         dec_output=self.proj_inverse(dec_output)
257 | 
258 |         return dec_output#,dec_attention
259 | 
260 | 
261 | 
262 | class Discriminator(nn.Module):
263 |     def __init__(
264 |             self, src_pad_idx=1, trg_pad_idx=1,
265 |             d_word_vec=128, d_model=128, d_inner=1024,
266 |             n_layers=3, n_head=8, d_k=64, d_v=64, dropout=0.2, n_position=50,
267 |             device='cuda'):
268 | 
269 |         super().__init__()
270 |         self.device=device
271 |         self.d_model=d_model
272 |         self.encoder = Encoder(
273 |             n_position=n_position,
274 |             d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
275 |             n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
276 |             pad_idx=src_pad_idx, dropout=dropout, device=self.device)
277 | 
278 |         self.fc=nn.Linear(45,1)
279 | 
280 | 
281 |     def forward(self, x):
282 |         x, *_ = self.encoder(x,n_person=None, src_mask=None)
283 |         x=self.fc(x)
284 |         x=x.view(-1,1)
285 |         return x
--------------------------------------------------------------------------------
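A minimal shape-check sketch (not part of the repository) for the Transformer defined in MRT/Models.py: it builds the model on the CPU with its default constructor arguments and runs one forward pass on random tensors shaped like the test-time inputs (1 scene, 3 people, 15 observed frames, 15 joints x 3 = 45 values per frame). The DCT preprocessing used in test.py is deliberately omitted; only the tensor shapes and the call signature are illustrated, and the variable names below are illustrative rather than taken from the repository.

import torch
from MRT.Models import Transformer

# device='cpu' overrides the default 'cuda' so the sketch runs without a GPU
model = Transformer(device='cpu')

batch, n_person, frames, dims = 1, 3, 15, 45
input_seq = torch.rand(batch, n_person, frames, dims)        # global range: all people
local = input_seq.reshape(batch * n_person, frames, dims)    # one sequence per person
src_seq = local[:, 1:, :] - local[:, :-1, :]                 # 14 frame-to-frame differences
trg_seq = local[:, -1:, :]                                   # last observed frame as the decoder query

with torch.no_grad():
    out = model(src_seq, trg_seq, input_seq)

print(out.shape)  # torch.Size([3, 15, 45]): 15 predicted 45-dim frames per person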