├── .gitignore
├── LICENSE
├── MotionFill
│   ├── data
│   │   ├── AMASS_localmotionfill_dataloader.py
│   │   ├── GRAB_end2end_dataloader.py
│   │   └── __pycache__
│   │       └── GRAB_end2end_dataloader.cpython-38.pyc
│   └── models
│       ├── LocalMotionFill.py
│       ├── TrajFill.py
│       ├── __pycache__
│       │   ├── LocalMotionFill.cpython-310.pyc
│       │   ├── LocalMotionFill.cpython-38.pyc
│       │   ├── TrajFill.cpython-310.pyc
│       │   ├── TrajFill.cpython-38.pyc
│       │   ├── fittingop.cpython-38.pyc
│       │   └── motionfill_cvae.cpython-310.pyc
│       └── fittingop.py
├── README.md
├── WholeGraspPose
│   ├── __pycache__
│   │   ├── trainer.cpython-310.pyc
│   │   └── trainer.cpython-38.pyc
│   ├── configs
│   │   ├── WholeGraspPose.yaml
│   │   ├── rhand_weight.npy
│   │   └── verts_per_edge.npy
│   ├── data
│   │   ├── __pycache__
│   │   │   ├── dataloader.cpython-310.pyc
│   │   │   └── dataloader.cpython-38.pyc
│   │   └── dataloader.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── fittingop.cpython-38.pyc
│   │   │   ├── models.cpython-310.pyc
│   │   │   ├── models.cpython-38.pyc
│   │   │   ├── objectmodel.cpython-38.pyc
│   │   │   ├── pointnet.cpython-38.pyc
│   │   │   ├── pointnet_util.cpython-310.pyc
│   │   │   └── pointnet_util.cpython-38.pyc
│   │   ├── fittingop.py
│   │   ├── models.py
│   │   ├── objectmodel.py
│   │   └── pointnet.py
│   └── trainer.py
├── body_utils
│   ├── body_models
│   │   └── VPoser
│   │       ├── .DS_Store
│   │       ├── vposerDecoderWeights.npz
│   │       ├── vposerEncoderWeights.npz
│   │       └── vposerMeanPose.npz
│   ├── body_segments
│   │   ├── L_Hand.json
│   │   ├── L_Leg.json
│   │   ├── R_Hand.json
│   │   ├── R_Leg.json
│   │   ├── back.json
│   │   ├── body_mask.json
│   │   ├── butt.json
│   │   ├── gluteus.json
│   │   └── thighs.json
│   ├── left_heel_verts_id.npy
│   ├── left_toe_verts_id.npy
│   ├── left_whole_foot_verts_id.npy
│   ├── right_heel_verts_id.npy
│   ├── right_toe_verts_id.npy
│   ├── right_whole_foot_verts_id.npy
│   ├── smplx_mano_flame_correspondences
│   │   ├── MANO_SMPLX_vertex_ids.pkl
│   │   └── SMPL-X__FLAME_vertex_ids.npy
│   └── smplx_markerset.json
├── images
│   ├── binoculars-0.jpg
│   ├── binoculars-60-first-view.jpg
│   ├── binoculars-60.jpg
│   ├── binoculars-movie-first-view.gif
│   ├── binoculars-video.gif
│   ├── teaser.png
│   ├── two-stage-pipeline.png
│   ├── wineglass-0.jpg
│   ├── wineglass-60-first-view.jpg
│   ├── wineglass-60.jpg
│   ├── wineglass-movie-first-view.gif
│   └── wineglass-video.gif
├── opt_graspmotion.py
├── opt_grasppose.py
├── train_graspmotion.py
├── train_grasppose.py
├── utils
│   ├── Pivots.py
│   ├── Pivots_torch.py
│   ├── Quaternions.py
│   ├── Quaternions_torch.py
│   ├── __pycache__
│   │   ├── Pivots.cpython-38.pyc
│   │   ├── Pivots_torch.cpython-38.pyc
│   │   ├── Quaternions.cpython-38.pyc
│   │   ├── Quaternions_torch.cpython-38.pyc
│   │   ├── cfg_parser.cpython-310.pyc
│   │   ├── cfg_parser.cpython-38.pyc
│   │   ├── train_helper.cpython-38.pyc
│   │   ├── train_tools.cpython-310.pyc
│   │   ├── train_tools.cpython-38.pyc
│   │   ├── utils.cpython-310.pyc
│   │   ├── utils.cpython-38.pyc
│   │   └── utils_body.cpython-38.pyc
│   ├── cfg_parser.py
│   ├── como
│   │   ├── SSM2.json
│   │   ├── __pycache__
│   │   │   ├── como_smooth.cpython-310.pyc
│   │   │   ├── como_smooth.cpython-38.pyc
│   │   │   ├── como_utils.cpython-310.pyc
│   │   │   └── como_utils.cpython-38.pyc
│   │   ├── como_smooth.py
│   │   ├── como_smooth_model.pkl
│   │   ├── como_utils.py
│   │   ├── my_SSM2.json
│   │   └── preprocess_stats_global_markers
│   │       ├── Xmean.npy
│   │       └── Xstd.npy
│   ├── train_helper.py
│   ├── utils.py
│   └── utils_body.py
└── visualization
    ├── __pycache__
    │   └── visualization_utils.cpython-38.pyc
    ├── vis_motion.py
    ├── vis_pose.py
    └── visualization_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dataset/
2 | results/
3 | smplx/
4 | mano_v1_2/
5 | vposer_v1_0/
6 | pretrained_model/
7 | logs/
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jiahao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MotionFill/data/__pycache__/GRAB_end2end_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/data/__pycache__/GRAB_end2end_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/MotionFill/models/LocalMotionFill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from torch.autograd import Variable
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | class EncBlock(nn.Module):
11 | def __init__(self, nin, nout, downsample=True, kernel=3):
12 | super(EncBlock, self).__init__()
13 | self.downsample = downsample
14 | padding = kernel // 2
15 |
16 | self.main = nn.Sequential(
17 | nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=1, padding=padding, padding_mode='replicate'),
18 | nn.LeakyReLU(0.2),
19 | nn.Conv2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding, padding_mode='replicate'),
20 | nn.LeakyReLU(0.2),
21 | )
22 |
23 | if self.downsample:
24 | self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
25 | else:
26 | self.pooling = nn.MaxPool2d(kernel_size=(3,3), stride=(2, 1), padding=1)
27 |
28 | def forward(self, input):
29 | output = self.main(input)
30 | output = self.pooling(output)
31 | return output
32 |
33 | class DecBlock(nn.Module):
34 | def __init__(self, nin, nout, upsample=True, kernel=3):
35 | super(DecBlock, self).__init__()
36 | self.upsample = upsample
37 |
38 | padding = kernel // 2
39 | if upsample:
40 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=2, padding=padding)
41 | else:
42 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=(2, 1), padding=padding)
43 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding)
44 | self.leaky_relu = nn.LeakyReLU(0.2)
45 |
46 | def forward(self, input, out_size):
47 | output = self.deconv1(input, output_size=out_size)
48 | output = self.leaky_relu(output)
49 | output = self.leaky_relu(self.deconv2(output))
50 | return output
51 |
52 |
53 | class DecBlock_output(nn.Module):
54 | def __init__(self, nin, nout, upsample=True, kernel=3):
55 | super(DecBlock_output, self).__init__()
56 | self.upsample = upsample
57 | padding = kernel // 2
58 |
59 | if upsample:
60 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=2, padding=padding)
61 | else:
62 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=(2, 1), padding=padding)
63 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding)
64 | self.leaky_relu = nn.LeakyReLU(0.2)
65 |
66 |
67 | def forward(self, input, out_size):
68 | output = self.deconv1(input, output_size=out_size)
69 | output = self.leaky_relu(output)
70 | output = self.deconv2(output)
71 | return output
72 |
73 |
74 | class AE(nn.Module):
75 | def __init__(self, downsample=True, in_channel=1, kernel=3):
76 | super(AE, self).__init__()
77 | self.enc_blc1 = EncBlock(nin=in_channel, nout=32, downsample=downsample, kernel=kernel)
78 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample, kernel=kernel)
79 | self.enc_blc3 = EncBlock(nin=64, nout=128, downsample=downsample, kernel=kernel)
80 | self.enc_blc4 = EncBlock(nin=128, nout=256, downsample=downsample, kernel=kernel)
81 | self.enc_blc5 = EncBlock(nin=256, nout=256, downsample=downsample, kernel=kernel)
82 |
83 | self.dec_blc1 = DecBlock(nin=256, nout=256, upsample=downsample, kernel=kernel)
84 | self.dec_blc2 = DecBlock(nin=256, nout=128, upsample=downsample, kernel=kernel)
85 | self.dec_blc3 = DecBlock(nin=128, nout=64, upsample=downsample, kernel=kernel)
86 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample, kernel=kernel)
87 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample, kernel=kernel)
88 |
89 | def forward(self, input):
90 | # input: [bs, c, d, T]
91 | x_down1 = self.enc_blc1(input)
92 | x_down2 = self.enc_blc2(x_down1)
93 | x_down3 = self.enc_blc3(x_down2)
94 | x_down4 = self.enc_blc4(x_down3)
95 | z = self.enc_blc5(x_down4)
96 |
97 | x_up4 = self.dec_blc1(z, x_down4.size())
98 | x_up3 = self.dec_blc2(x_up4, x_down3.size())
99 | x_up2 = self.dec_blc3(x_up3, x_down2.size())
100 | x_up1 = self.dec_blc4(x_up2, x_down1.size())
101 | output = self.dec_blc5(x_up1, input.size())
102 |
103 | return output, z
104 |
105 |
106 | class Flatten(nn.Module):
107 | def forward(self, input):
108 | return input.view(input.size(0), -1)
109 |
110 |
111 | class CNN_Encoder(nn.Module):
112 | def __init__(self, downsample=True, in_channel=1, kernel=3):
113 | super(CNN_Encoder, self).__init__()
114 | self.enc_blc1 = EncBlock(nin=in_channel, nout=32, downsample=downsample, kernel=kernel)
115 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample, kernel=kernel)
116 | self.enc_blc3 = EncBlock(nin=64, nout=128, downsample=downsample, kernel=kernel)
117 | self.enc_blc4 = EncBlock(nin=128, nout=256, downsample=downsample, kernel=kernel)
118 | self.enc_blc5 = EncBlock(nin=256, nout=256, downsample=downsample, kernel=kernel)
119 |
120 | def forward(self, input):
121 | x_down1 = self.enc_blc1(input)
122 | x_down2 = self.enc_blc2(x_down1)
123 | x_down3 = self.enc_blc3(x_down2)
124 | x_down4 = self.enc_blc4(x_down3)
125 | z = self.enc_blc5(x_down4)
126 | size_list = [x_down4.size(), x_down3.size(), x_down2.size(), x_down1.size(), input.size()]
127 | return z, size_list
128 |
129 |
130 | class CNN_Decoder(nn.Module):
131 | def __init__(self, downsample=True, kernel=3):
132 | super(CNN_Decoder, self).__init__()
133 | self.dec_blc1 = DecBlock(nin=512, nout=256, upsample=downsample, kernel=kernel)
134 | self.dec_blc2 = DecBlock(nin=256, nout=128, upsample=downsample, kernel=kernel)
135 | self.dec_blc3 = DecBlock(nin=128, nout=64, upsample=downsample, kernel=kernel)
136 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample, kernel=kernel)
137 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample, kernel=kernel)
138 |
139 | def forward(self, z, size_list):
140 | x_up4 = self.dec_blc1(z, size_list[0])
141 | x_up3 = self.dec_blc2(x_up4, size_list[1])
142 | x_up2 = self.dec_blc3(x_up3, size_list[2])
143 | x_up1 = self.dec_blc4(x_up2, size_list[3])
144 | output = self.dec_blc5(x_up1, size_list[4])
145 | return output
146 |
147 |
148 | class Motion_CNN_CVAE(nn.Module):
149 | def __init__(self, nz, downsample=True, in_channel=1, kernel=3, clip_seconds=2):
150 | super(Motion_CNN_CVAE, self).__init__()
151 | self.nz = nz # dim of latent variables
152 | self.enc_conv_input = CNN_Encoder(downsample, in_channel, kernel)
153 | self.enc_conv_gt = CNN_Encoder(downsample, in_channel, kernel)
154 | self.enc_conv_cat = nn.Sequential(
155 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=kernel, stride=1, padding=kernel//2, padding_mode='replicate'),
156 | nn.LeakyReLU(0.2),
157 | Flatten(),
158 | )
159 | self.enc_mu = nn.Linear(512*8*clip_seconds, self.nz)
160 | self.enc_logvar = nn.Linear(512*8*clip_seconds, self.nz)
161 | self.dec_dense = nn.Linear(self.nz, 256*8*clip_seconds)
162 | self.dec_conv = CNN_Decoder(downsample, kernel)
163 |
164 |
165 | def encode(self, x, y):
166 | # x: [bs, c, d, T] (input)
167 | # y: [bs, c, d, T] (gt)
168 | e_x, _ = self.enc_conv_input(x)
169 | e_y, _ = self.enc_conv_gt(y)
170 | e_xy = torch.cat((e_x, e_y), dim=1)
171 | z = self.enc_conv_cat(e_xy)
172 | z_mu = self.enc_mu(z)
173 | z_logvar = self.enc_logvar(z)
174 | return z_mu, z_logvar
175 |
176 | def reparameterize(self, mu, logvar, eps):
177 | std = torch.exp(0.5*logvar)
178 | return mu + eps * std
179 |
180 | def decode(self, x, z):
181 | e_x, size_list = self.enc_conv_input(x)
182 | d_z_dense = self.dec_dense(z)
183 | d_z = d_z_dense.view(e_x.size(0), e_x.size(1), e_x.size(2), e_x.size(3))
184 | d_xz = torch.cat((e_x, d_z), dim=1)
185 | y_hat = self.dec_conv(d_xz, size_list)
186 | return y_hat
187 |
188 | def forward(self, input, gt=None, is_train=True, z=None, is_twice=None):
189 | # input: [bs, c, d, T]
190 | self.bs = len(input)
191 | if is_train:
192 | mu, logvar = self.encode(input, gt)
193 | eps = torch.randn_like(logvar)
194 | z = self.reparameterize(mu, logvar, eps) #if self.training else mu
195 | else:
196 | if is_twice:
197 | mu, logvar = self.encode(input, gt)
198 | z = mu
199 | else:
200 | if z is None:
201 | z = torch.randn((self.bs, self.nz), device=input.device)
202 | mu = 0
203 | logvar = 1
204 |
205 | pred = self.decode(input, z)
206 | return pred, mu, logvar
207 |
208 | def sample_prior(self, x):
209 | z = torch.randn((x.shape[1], self.nz), device=x.device)
210 | return self.decode(x, z)
211 |
--------------------------------------------------------------------------------
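As a quick sanity check of the shapes used in LocalMotionFill.py, the sketch below runs Motion_CNN_CVAE on dummy tensors. The motion-image size [bs, 1, 256, 64] and nz=512 are assumptions (not the training configuration), chosen so that the five pooling stages reduce the grid to 8 x 2 and match the 512*8*clip_seconds bottleneck with clip_seconds=2; the real dimensions come from the dataloaders, which are not shown in this section.

```python
import torch
# Assumes Motion_CNN_CVAE from MotionFill/models/LocalMotionFill.py is in scope
# (e.g. run from the repo root with that module importable).

model = Motion_CNN_CVAE(nz=512, downsample=True, in_channel=1, kernel=3, clip_seconds=2)

bs = 4
x  = torch.randn(bs, 1, 256, 64)   # conditioning motion "image": [bs, c, d, T]
gt = torch.randn(bs, 1, 256, 64)   # ground-truth motion "image", used only by the encoder

# Training-style pass: encode (x, gt) -> (mu, logvar), reparameterize, decode conditioned on x.
pred, mu, logvar = model(x, gt=gt, is_train=True)
print(pred.shape, mu.shape, logvar.shape)   # [4, 1, 256, 64], [4, 512], [4, 512]

# Test-time pass: z is drawn from the standard normal prior instead of the posterior.
pred_sample, _, _ = model(x, is_train=False)
```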
/MotionFill/models/TrajFill.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from torch.autograd import Variable
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 | class ResBlock(nn.Module):
11 |
12 | def __init__(self,
13 | Fin,
14 | Fout,
15 | n_neurons=256):
16 |
17 | super(ResBlock, self).__init__()
18 | self.Fin = Fin
19 | self.Fout = Fout
20 |
21 | self.fc1 = nn.Linear(Fin, n_neurons)
22 | self.bn1 = nn.BatchNorm1d(n_neurons)
23 |
24 | self.fc2 = nn.Linear(n_neurons, Fout)
25 | self.bn2 = nn.BatchNorm1d(Fout)
26 |
27 | if Fin != Fout:
28 | self.fc3 = nn.Linear(Fin, Fout)
29 |
30 | self.ll = nn.LeakyReLU(negative_slope=0.2)
31 |
32 | def forward(self, x, final_nl=True):
33 | Xin = x if self.Fin == self.Fout else self.ll(self.fc3(x))
34 |
35 | Xout = self.fc1(x) # n_neurons
36 | Xout = self.bn1(Xout)
37 | Xout = self.ll(Xout)
38 |
39 | Xout = self.fc2(Xout)
40 | Xout = self.bn2(Xout)
41 | Xout = Xin + Xout
42 |
43 | if final_nl:
44 | return self.ll(Xout)
45 | return Xout
46 |
47 |
48 | class Traj_MLP_CVAE(nn.Module): # T*4 => T*8 => ResBlock -> z => ResBlock
49 | def __init__(self, nz, feature_dim, T, residual=False, load_path=None):
50 | super(Traj_MLP_CVAE, self).__init__()
51 | self.T = T
52 | self.feature_dim = feature_dim
53 | self.nz = nz
54 | self.residual = residual
55 | self.load_path = load_path
56 |
57 | """MLP"""
58 | self.enc1 = ResBlock(Fin=2*feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T)
59 | self.enc2 = ResBlock(Fin=2*feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T)
60 | self.enc_mu = nn.Linear(2*feature_dim*T, nz)
61 | self.enc_var = nn.Linear(2*feature_dim*T, nz)
62 |
63 | self.dec1 = ResBlock(Fin=nz+feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T)
64 | self.dec2 = ResBlock(Fin=2*feature_dim*T + feature_dim*T, Fout=feature_dim*T, n_neurons=feature_dim*T)
65 |
66 | if self.load_path is not None:
67 | self._load_model()
68 |
69 | def encode(self, x, y):
70 | """ x: [bs, T*feature_dim] """
71 | bs = x.shape[0]
72 | x = torch.cat([x, y], dim=-1)
73 | h = self.enc1(x)
74 | h = self.enc2(h)
75 | z_mu = self.enc_mu(h)
76 | z_logvar = self.enc_var(h)
77 |
78 | return z_mu, z_logvar
79 |
80 | def decode(self, z, y):
81 | """z: [bs, nz]; y: [bs, 2*feature_dim]"""
82 | bs = y.shape[0]
83 | x = torch.cat([z, y], dim=-1)
84 | x = self.dec1(x)
85 | x = torch.cat([x, y], dim=-1)
86 | x = self.dec2(x)
87 |
88 | if self.residual:
89 | x = x + y
90 |
91 | return x.reshape(bs, self.feature_dim, -1)
92 |
93 | def reparameterize(self, mu, logvar, eps):
94 | std = torch.exp(0.5*logvar)
95 | return mu + eps * std
96 |
97 | def forward(self, x, y):
98 | bs = x.shape[0]
99 | # print(x.shape, y.shape)
100 | mu, logvar = self.encode(x, y)
101 | eps = torch.randn_like(logvar)
102 | z = self.reparameterize(mu, logvar, eps) #if self.training else mu
103 | pred = self.decode(z, y)
104 |
105 | return pred, mu, logvar
106 |
107 | def sample(self, y, z=None):
108 | if z is None:
109 | z = torch.randn((y.shape[0], self.nz), device=y.device)
110 | return self.decode(z, y)
111 |
112 | def _load_model(self):
113 | print('Loading Traj_CVAE from {} ...'.format(self.load_path))
114 | assert self.load_path is not None
115 | model_cp = torch.load(self.load_path)
116 | self.load_state_dict(model_cp['model_dict'])
117 |
--------------------------------------------------------------------------------
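A minimal smoke test for Traj_MLP_CVAE under assumed dimensions: feature_dim=4 (as hinted by the "T*4" comment in the class header) and T=62 frames are placeholders, not values taken from the training script. The inputs are the flattened ground-truth trajectory x and the flattened conditioning trajectory y.

```python
import torch
# Assumes Traj_MLP_CVAE from MotionFill/models/TrajFill.py is in scope.

feature_dim, T, bs = 4, 62, 8          # assumed values for illustration only
model = Traj_MLP_CVAE(nz=512, feature_dim=feature_dim, T=T, residual=True)

x = torch.randn(bs, feature_dim * T)   # ground-truth trajectory, flattened to [bs, feature_dim*T]
y = torch.randn(bs, feature_dim * T)   # conditioning trajectory (e.g. start/end interpolation), flattened

pred, mu, logvar = model(x, y)         # pred: [bs, feature_dim, T]
sample = model.sample(y)               # prior sample conditioned on y, same shape as pred
print(pred.shape, sample.shape)
```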
/MotionFill/models/__pycache__/LocalMotionFill.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/LocalMotionFill.cpython-310.pyc
--------------------------------------------------------------------------------
/MotionFill/models/__pycache__/LocalMotionFill.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/LocalMotionFill.cpython-38.pyc
--------------------------------------------------------------------------------
/MotionFill/models/__pycache__/TrajFill.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/TrajFill.cpython-310.pyc
--------------------------------------------------------------------------------
/MotionFill/models/__pycache__/TrajFill.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/TrajFill.cpython-38.pyc
--------------------------------------------------------------------------------
/MotionFill/models/__pycache__/fittingop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/fittingop.cpython-38.pyc
--------------------------------------------------------------------------------
/MotionFill/models/__pycache__/motionfill_cvae.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/motionfill_cvae.cpython-310.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SAGA: Stochastic Whole-Body Grasping with Contact
2 |
3 |
4 |
5 | > [**SAGA: Stochastic Whole-Body Grasping with Contact**](https://jiahaoplus.github.io/SAGA/saga.html)
6 | > **ECCV 2022**
7 | > Yan Wu*, Jiahao Wang*, Yan Zhang, Siwei Zhang, Otmar Hilliges, Fisher Yu, Siyu Tang
8 |
9 | 
10 | This repository is the official implementation for the ECCV 2022 paper: [SAGA: Stochastic Whole-Body Grasping with Contact](https://jiahaoplus.github.io/SAGA/saga.html).
11 |
12 | \[[Project Page](https://jiahaoplus.github.io/SAGA/saga.html) | [Paper](https://arxiv.org/abs/2112.10103)\]
13 |
14 | ## Introduction
15 | Given an object in 3D space and an initial human pose, we aim to generate diverse human motion sequences to approach and grasp the given object. We propose a two-stage pipeline: we first generate the grasping ending pose and then infill the in-between motion.
16 |
17 | 
18 |
19 |
20 | Qualitative results (columns: Input | First-stage result | Second-stage result); the corresponding example images are under /images.
21 |
45 | ## Contents
46 | - [Installation](https://github.com/JiahaoPlus/SAGA#installation)
47 | - [Dataset Preparation](https://github.com/JiahaoPlus/SAGA#Dataset)
48 | - [Pretrained models](https://github.com/JiahaoPlus/SAGA#pretrained-models)
49 | - [Train](https://github.com/JiahaoPlus/SAGA#train)
50 | - [Grasping pose and motion generation for a given object](https://github.com/JiahaoPlus/SAGA#inference) (object position and orientation can be customized)
51 | - [Visualization](https://github.com/JiahaoPlus/SAGA#visualization)
52 |
53 | ## Installation
54 | - Packages
55 | - python>=3.8
56 | - pytorch==1.12.1
57 | - [human-body-prior](https://pypi.org/project/human-body-prior/)
58 | - [SMPLX](https://github.com/vchoutas/smplx)
59 | - [Chamfer Distance](https://github.com/otaheri/chamfer_distance)
60 | - Open3D
61 |
62 | - Body Models
63 | Download the [SMPL-X body model and VPoser v1.0 model](https://smpl-x.is.tue.mpg.de/index.html) and put them under the /body_utils/body_models folder as below:
64 | ```
65 | SAGA
66 | │
67 | └───body_utils
68 | │
69 | └───body_models
70 | │
71 | └───smplx
72 | │ └───SMPLX_FEMALE.npz
73 | │ └───...
74 | │
75 | └───vposer_v1_0
76 | │ └───snapshots
77 | │ └───TR00_E096.pt
78 | │ └───...
79 | │
80 | └───VPoser
81 | │ └───vposerDecoderWeights.npz
82 | │ └───vposerEncoderWeights.npz
83 | │ └───vposerMeanPose.npz
84 | │
85 | └───...
86 | │
87 | └───...
88 | ```
89 |
90 | ## Dataset
92 | Download the [GRAB](https://grab.is.tue.mpg.de/) object meshes.
93 |
94 | Download dataset for the first stage (GraspPose) from [[Google Drive]](https://drive.google.com/uc?export=download&id=1OfSGa3Y1QwkbeXUmAhrfeXtF89qvZj54)
95 |
96 | Download dataset for the second stage (GraspMotion) from [[Google Drive]](https://drive.google.com/uc?export=download&id=1QiouaqunhxKuv0D0QHv1JHlwVU-F6dWm)
97 |
98 | Put them under /dataset as below:
99 | ```
100 | SAGA
101 | │
102 | └───dataset
103 | │
104 | └───GraspPose
105 | │ └───train
106 | │ └───s1
107 | │ └───...
108 | │ └───eval
109 | │ └───s1
110 | │ └───...
111 | │ └───test
112 | │ └───s1
113 | │ └───...
114 | │
115 | └───GraspMotion
116 | │ └───Processed_Traj
117 | │ └───s1
118 | │ └───...
119 | │
120 | └───contact_meshes
121 | │ └───airplane.ply
122 | │ └───...
123 | │
124 | └───...
125 | ```
126 |
127 | ## Pretrained models
128 | Download the pretrained models from [[Google Drive]](https://drive.google.com/uc?export=download&id=1dxzBBOUbRuUAlNQGxnbmWLmtP7bmEE_9). The pretrained models include:
129 | - Stage 1: pretrained WholeGrasp-VAE models for male and female, respectively
130 | - Stage 2: pretrained TrajFill-VAE and LocalMotionFill-VAE (trained end-to-end)
131 |
132 | ## Train
133 | ### First Stage: WholeGrasp-VAE training
134 | ```
135 | python train_grasppose.py --data_path ./dataset/GraspPose --gender male --exp_name male
136 | ```
137 |
138 | ### Second Stage: MotionFill-VAE training
139 | You can first train the TrajFill-VAE and LocalMotionFill-VAE separately (or download the separately trained models from [[Google Drive]](https://drive.google.com/uc?export=download&id=1eyUW7YLmnAj-CHwIMe9qsAs6W63Aw7Ce)), and then train them end-to-end:
140 | ```
141 | python train_graspmotion.py --pretrained_path_traj $PRETRAINED_MODEL_PATH/TrajFill_model_separate_trained.pkl --pretrained_path_motion $PRETRAINED_MODEL_PATH/LocalMotionFill_model_separate_trained.pkl
142 | ```
143 |
144 | ## Inference
145 | ### First Stage: WholeGrasp-VAE sampling + GraspPose-Opt
146 | In the first stage, we generate grasping poses for the given object.
147 | The example command below generates 10 male pose samples for grasping the camera object, with the object's height and orientation randomly set within a reasonable range. You can easily customize this setting.
148 | ```
149 | python opt_grasppose.py --exp_name pretrained_male --gender male --pose_ckpt_path $PRETRAINED_MODEL_PATH/male_grasppose_model.pt --object camera --n_object_samples 10
150 | ```
151 | ### Second Stage: MotionFill-VAE sampling + GraspMotion-Opt
152 | In the second stage, given the ending pose generated in the first stage and a customizable initial human pose, we generate the in-between motion.
153 | The example command below generates male motion samples for grasping the camera object, with the initial human pose and the initial distance from the object randomly set within a reasonable range. You can easily customize this setting.
154 | ```
155 | python opt_graspmotion.py --GraspPose_exp_name pretrained_male --object camera --gender male --traj_ckpt_path $PRETRAINED_MODEL_PATH/TrajFill_model.pkl --motion_ckpt_path $PRETRAINED_MODEL_PATH/LocalMotionFill_model.pkl
156 | ```
157 |
158 | ## Visualization
159 | We provide a visualization script for the generated grasping ending poses, which are saved by default at _/results/$EXP_NAME/GraspPose/$OBJECT/fitting_results.npz_.
160 | ```
161 | cd visualization
162 | python vis_pose.py --exp_name pretrained_male --gender male --object camera
163 | ```
164 |
165 | We provide a visualization script for the generated grasping motion, which is saved by default at _/results/$EXP_NAME/GraspMotion/$OBJECT/fitting_results.npy_, rendered from three viewpoints: the first-person view, the third-person view, and the bird's-eye view.
166 | ```
167 | cd visualization
168 | python vis_motion.py --GraspPose_exp_name pretrained_male --gender male --object camera
169 | ```
170 |
171 | ### Contact
172 | If you have any questions, feel free to contact us:
173 | - Yan Wu: yan.wu@vision.ee.ethz.ch
174 | - Jiahao Wang: jiwang@mpi-inf.mpg.de
175 | ### Citation
176 | ```bibtex
177 | @inproceedings{wu2022saga,
178 | title = {SAGA: Stochastic Whole-Body Grasping with Contact},
179 | author = {Wu, Yan and Wang, Jiahao and Zhang, Yan and Zhang, Siwei and Hilliges, Otmar and Yu, Fisher and Tang, Siyu},
180 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
181 | year = {2022}
182 | }
183 | ```
184 |
--------------------------------------------------------------------------------
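A note on the Visualization section of the README above: the first-stage fitting results are a standard NumPy archive, so they can also be inspected directly. The path below follows the default location stated in the README; the stored keys are defined by the fitting script and are not listed here.

```python
import numpy as np

# Hypothetical inspection of the first-stage output saved by opt_grasppose.py.
results = np.load('results/pretrained_male/GraspPose/camera/fitting_results.npz', allow_pickle=True)
print(list(results.keys()))
```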
/WholeGraspPose/__pycache__/trainer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/__pycache__/trainer.cpython-310.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/__pycache__/trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/__pycache__/trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/configs/WholeGraspPose.yaml:
--------------------------------------------------------------------------------
1 | base_lr: 0.0005
2 | batch_size: 128
3 | best_net: null
4 | bps_size: 4096
5 | c_weights_path: null
6 | cuda_id: 0
7 | dataset_dir: null
8 | kl_coef: 0.005
9 | latentD: 16
10 | log_every_epoch: 10
11 | n_epochs: 100
12 | n_markers: 512
13 | n_neurons: 512
14 | n_workers: 8
15 | reg_coef: 0.0005
16 | # rhm_path: null
17 | seed: 4815
18 | try_num: 0
19 | use_multigpu: false
20 | vpe_path: null
21 | work_dir: null
22 | load_on_ram: False
23 | cond_object_height: True
24 | # gender: male # 'male' / 'female' /
25 | motion_intent: False # 'use'/'offhand'/'pass'/'lift'/ False(all intent)
26 | object_class: ['all'] # ['', ''] / ['all']
27 | robustkl: False
28 | kl_annealing: True
29 | kl_annealing_epoch: 100
30 | marker_weight: 1
31 | foot_weight: 0
32 | collision_weight: 0
33 | consistency_weight: 1
34 | dropout: 0.1
35 | obj_feature: 12
36 | pointnet_hc: 64
37 | continue_train: False
38 | data_representation: 'markers_143' # 'joints' / 'markers_143' / 'markers_214' / 'markers_593'
--------------------------------------------------------------------------------
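For reference, a config like the one above can be read with PyYAML into a plain dictionary; the repository's own loading goes through utils/cfg_parser.py (not shown in this section), so the snippet below is only an illustrative sketch.

```python
import yaml

# Illustrative only: read the training config into a dict
# (the repo itself loads configs via utils/cfg_parser.py).
with open('WholeGraspPose/configs/WholeGraspPose.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['latentD'], cfg['batch_size'], cfg['data_representation'])   # 16 128 markers_143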
/WholeGraspPose/configs/rhand_weight.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/configs/rhand_weight.npy
--------------------------------------------------------------------------------
/WholeGraspPose/configs/verts_per_edge.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/configs/verts_per_edge.npy
--------------------------------------------------------------------------------
/WholeGraspPose/data/__pycache__/dataloader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/data/__pycache__/dataloader.cpython-310.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/data/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/data/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import json
3 | import os
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | from smplx.lbs import batch_rodrigues
9 | from torch.utils import data
10 |
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | to_cpu = lambda tensor: tensor.detach().cpu().numpy()
13 |
14 |
15 | class LoadData(data.Dataset):
16 | def __init__(self,
17 | dataset_dir,
18 | ds_name='train',
19 | gender=None,
20 | motion_intent=None,
21 | object_class=['all'],
22 | dtype=torch.float32,
23 | data_type = 'markers_143'):
24 |
25 | super().__init__()
26 |
27 | print('Preparing {} data...'.format(ds_name.upper()))
28 | self.sbj_idxs = []
29 | self.objs_frames = {}
30 | self.ds_path = os.path.join(dataset_dir, ds_name)
31 | self.gender = gender
32 | self.motion_intent = motion_intent
33 | self.object_class = object_class
34 | self.data_type = data_type
35 |
36 | with open('body_utils/smplx_markerset.json') as f:
37 | markerset = json.load(f)['markersets']
38 | self.markers_idx = []
39 | for marker in markerset:
40 | if marker['type'] not in ['palm_5']: # 'palm_5' contains selected 5 markers per palm, but for training we use 'palm' set where there are 22 markers per palm.
41 | self.markers_idx += list(marker['indices'].values())
42 | print(len(self.markers_idx))
43 | self.ds = self.load_full_data(self.ds_path)
44 |
45 | def load_full_data(self, path):
46 | rec_list = []
47 | output = {}
48 |
49 | markers_list = []
50 | transf_transl_list = []
51 | verts_object_list = []
52 | contacts_object_list = []
53 | normal_object_list = []
54 | transl_object_list = []
55 | global_orient_object_list = []
56 | rotmat_list = []
57 | contacts_markers_list = []
58 | body_list = {}
59 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']:
60 | body_list[key] = []
61 |
62 | subsets_dict = {'male':['s1', 's2', 's8', 's9', 's10'],
63 | 'female': ['s3', 's4', 's5', 's6', 's7']}
64 | subsets = subsets_dict[self.gender]
65 |
66 | print('loading {} dataset: {}'.format(self.gender, subsets))
67 | for subset in subsets:
68 | subset_path = os.path.join(path, subset)
69 | rec_list += [os.path.join(subset_path, i) for i in os.listdir(subset_path)]
70 |
71 | index = 0
72 |
73 | for rec in rec_list:
74 | data = np.load(rec, allow_pickle=True)
75 |
76 | ## select object
77 | obj_name = rec.split('/')[-1].split('_')[0]
78 | if 'all' not in self.object_class:
79 | if obj_name not in self.object_class:
80 | continue
81 |
82 | verts_object_list.append(data['verts_object'])
83 | markers_list.append(data[self.data_type])
84 | transf_transl_list.append(data['transf_transl'])
85 | normal_object_list.append(data['normal_object'])
86 | global_orient_object_list.append(data['global_orient_object'])
87 |
88 | orient = torch.tensor(data['global_orient_object'])
89 | rot_mats = batch_rodrigues(orient.view(-1, 3)).view([orient.shape[0], 9]).numpy()
90 | rotmat_list.append(rot_mats)
91 |
92 | object_contact = data['contact_object']
93 | markers_contact = data['contact_body'][:, self.markers_idx]
94 | object_contact_binary = (object_contact>0).astype(int)
95 | contacts_object_list.append(object_contact_binary)
96 | markers_contact_binary = (markers_contact>0).astype(int)
97 | contacts_markers_list.append(markers_contact_binary)
98 |
99 | # SMPLX parameters (optional)
100 | for key in data['body'][()].keys():
101 | body_list[key].append(data['body'][()][key])
102 |
103 | sbj_id = rec.split('/')[-2]
104 | self.sbj_idxs += [sbj_id]*data['verts_object'].shape[0]
105 | if obj_name in self.objs_frames.keys():
106 | self.objs_frames[obj_name] += list(range(index, index+data['verts_object'].shape[0]))
107 | else:
108 | self.objs_frames[obj_name] = list(range(index, index+data['verts_object'].shape[0]))
109 | index += data['verts_object'].shape[0]
110 | output['transf_transl'] = torch.tensor(np.concatenate(transf_transl_list, axis=0))
111 | output['markers'] = torch.tensor(np.concatenate(markers_list, axis=0)) # (B, 99, 3)
112 | output['verts_object'] = torch.tensor(np.concatenate(verts_object_list, axis=0)) # (B, 2048, 3)
113 | output['contacts_object'] = torch.tensor(np.concatenate(contacts_object_list, axis=0)) # (B, 2048, 3)
114 | output['contacts_markers'] = torch.tensor(np.concatenate(contacts_markers_list, axis=0)) # (B, 2048, 3)
115 | output['normal_object'] = torch.tensor(np.concatenate(normal_object_list, axis=0)) # (B, 2048, 3)
116 | output['global_orient_object'] = torch.tensor(np.concatenate(global_orient_object_list, axis=0)) # (B, 2048, 3)
117 | output['rotmat'] = torch.tensor(np.concatenate(rotmat_list, axis=0)) # (B, 2048, 3)
118 |
119 | # SMPLX parameters
120 | output['smplxparams'] = {}
121 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']:
122 | output['smplxparams'][key] = torch.tensor(np.concatenate(body_list[key], axis=0))
123 |
124 | return output
125 |
126 | def __len__(self):
127 | k = list(self.ds.keys())[0]
128 | return self.ds[k].shape[0]
129 |
130 | def __getitem__(self, idx):
131 |
132 | data_out = {}
133 |
134 | data_out['markers'] = self.ds['markers'][idx]
135 | data_out['contacts_markers'] = self.ds['contacts_markers'][idx]
136 | data_out['verts_object'] = self.ds['verts_object'][idx]
137 | data_out['normal_object'] = self.ds['normal_object'][idx]
138 | data_out['global_orient_object'] = self.ds['global_orient_object'][idx]
139 | data_out['transf_transl'] = self.ds['transf_transl'][idx]
140 | data_out['contacts_object'] = self.ds['contacts_object'][idx]
141 | if len(data_out['verts_object'].shape) == 2:
142 | data_out['feat_object'] = torch.cat([self.ds['normal_object'][idx], self.ds['rotmat'][idx, :6].view(1, 6).repeat(2048, 1)], -1)
143 | else:
144 | data_out['feat_object'] = torch.cat([self.ds['normal_object'][idx], self.ds['rotmat'][idx, :6].view(-1, 1, 6).repeat(1, 2048, 1)], -1)
145 |
146 | """You may want to uncomment it when you need smplxparams!!!"""
147 | data_out['smplxparams'] = {}
148 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']:
149 | data_out['smplxparams'][key] = self.ds['smplxparams'][key][idx]
150 |
151 | ## random rotation augmentation
152 | bsz = 1
153 | theta = torch.FloatTensor(np.random.uniform(-np.pi/6, np.pi/6, bsz))
154 | orient = torch.zeros((bsz, 3))
155 | orient[:, -1] = theta
156 | rot_mats = batch_rodrigues(orient.view(-1, 3)).view([bsz, 3, 3])
157 | if len(data_out['verts_object'].shape) == 3:
158 | data_out['markers'] = torch.matmul(data_out['markers'][:, :, :3], rot_mats.squeeze())
159 | data_out['verts_object'] = torch.matmul(data_out['verts_object'][:, :, :3], rot_mats.squeeze())
160 | data_out['normal_object'][:, :, :3] = torch.matmul(data_out['normal_object'][:, :, :3], rot_mats.squeeze())
161 | else:
162 | data_out['markers'] = torch.matmul(data_out['markers'][:, :3], rot_mats.squeeze())
163 | data_out['verts_object'] = torch.matmul(data_out['verts_object'][:, :3], rot_mats.squeeze())
164 | data_out['normal_object'][:, :3] = torch.matmul(data_out['normal_object'][:, :3], rot_mats.squeeze())
165 |
166 | return data_out
167 |
--------------------------------------------------------------------------------
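LoadData concatenates all selected sequences into memory and serves per-frame samples, so it can be wrapped in a standard torch DataLoader. A rough usage sketch, assuming the dataset layout from the README; batch size and worker count mirror the values in WholeGraspPose.yaml.

```python
from torch.utils.data import DataLoader
# Assumes LoadData from WholeGraspPose/data/dataloader.py is importable from the repo root.

train_set = LoadData('./dataset/GraspPose', ds_name='train', gender='male',
                     motion_intent=False, object_class=['all'])
train_loader = DataLoader(train_set, batch_size=128, num_workers=8,
                          shuffle=True, drop_last=True)

batch = next(iter(train_loader))
# With the default 'markers_143' representation, markers are expected to be [B, 143, 3]
# and object vertices [B, 2048, 3]; contact labels are binary per marker / per vertex.
print(batch['markers'].shape, batch['verts_object'].shape)
```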
/WholeGraspPose/models/__pycache__/fittingop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/fittingop.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/models.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/models.cpython-310.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/objectmodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/objectmodel.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/pointnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/pointnet_util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet_util.cpython-310.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/__pycache__/pointnet_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet_util.cpython-38.pyc
--------------------------------------------------------------------------------
/WholeGraspPose/models/models.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('.')
4 | sys.path.append('..')
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from WholeGraspPose.models.pointnet import (PointNetFeaturePropagation,
10 | PointNetSetAbstraction)
11 |
12 |
13 | class ResBlock(nn.Module):
14 | def __init__(self,
15 | Fin,
16 | Fout,
17 | n_neurons=256):
18 |
19 | super(ResBlock, self).__init__()
20 | self.Fin = Fin
21 | self.Fout = Fout
22 |
23 | self.fc1 = nn.Linear(Fin, n_neurons)
24 | self.bn1 = nn.BatchNorm1d(n_neurons)
25 |
26 | self.fc2 = nn.Linear(n_neurons, Fout)
27 | self.bn2 = nn.BatchNorm1d(Fout)
28 |
29 | if Fin != Fout:
30 | self.fc3 = nn.Linear(Fin, Fout)
31 |
32 | self.ll = nn.LeakyReLU(negative_slope=0.2)
33 |
34 | def forward(self, x, final_nl=True):
35 | Xin = x if self.Fin == self.Fout else self.ll(self.fc3(x))
36 |
37 | Xout = self.fc1(x) # n_neurons
38 | Xout = self.bn1(Xout)
39 | Xout = self.ll(Xout)
40 |
41 | Xout = self.fc2(Xout)
42 | Xout = self.bn2(Xout)
43 | Xout = Xin + Xout
44 |
45 | if final_nl:
46 | return self.ll(Xout)
47 | return Xout
48 |
49 |
50 | class PointNetEncoder(nn.Module):
51 |
52 | def __init__(self,
53 | hc,
54 | in_feature):
55 |
56 | super(PointNetEncoder, self).__init__()
57 | self.hc = hc
58 | self.in_feature = in_feature
59 |
60 | self.enc_sa1 = PointNetSetAbstraction(npoint=256, radius=0.2, nsample=32, in_channel=self.in_feature, mlp=[self.hc, self.hc*2], group_all=False)
61 | self.enc_sa2 = PointNetSetAbstraction(npoint=128, radius=0.25, nsample=64, in_channel=self.hc*2 + 3, mlp=[self.hc*2, self.hc*4], group_all=False)
62 | self.enc_sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=self.hc*4 + 3, mlp=[self.hc*4, self.hc*8], group_all=True)
63 |
64 | def forward(self, l0_xyz, l0_points):
65 |
66 | l1_xyz, l1_points = self.enc_sa1(l0_xyz, l0_points)
67 | l2_xyz, l2_points = self.enc_sa2(l1_xyz, l1_points)
68 | l3_xyz, l3_points = self.enc_sa3(l2_xyz, l2_points)
69 | x = l3_points.view(-1, self.hc*8)
70 |
71 | return l1_xyz, l1_points, l2_xyz, l2_points, l3_xyz, x
72 |
73 | class MarkerNet(nn.Module):
74 | def __init__(self, cfg, n_neurons=1024, in_cond=1024, latentD=16, in_feature=143*3, **kwargs):
75 | super(MarkerNet, self).__init__()
76 |
77 | self.cfg = cfg
78 | ## condition features
79 | self.obj_cond_feature = 1 if cfg.cond_object_height else 0
80 |
81 | self.enc_bn1 = nn.BatchNorm1d(in_feature + self.obj_cond_feature)
82 | self.enc_rb1 = ResBlock(in_cond + in_feature + int(in_feature/3) + self.obj_cond_feature, n_neurons)
83 | self.enc_rb2 = ResBlock(n_neurons + in_cond + in_feature + int(in_feature/3) + self.obj_cond_feature, n_neurons)
84 |
85 | self.dec_rb1 = ResBlock(latentD + in_cond, n_neurons)
86 | self.dec_rb2_xyz = ResBlock(n_neurons + latentD + in_cond + self.obj_cond_feature, n_neurons)
87 | self.dec_rb2_p = ResBlock(n_neurons + latentD + in_cond, n_neurons)
88 |
89 | self.dec_output_xyz = nn.Linear(n_neurons, 143*3)
90 | self.dec_output_p = nn.Linear(n_neurons, 143)
91 | self.p_output = nn.Sigmoid()
92 |
93 | def enc(self, cond_object, markers, contacts_markers, transf_transl):
94 | _, _, _, _, _, object_cond = cond_object
95 |
96 | X = markers.view(markers.shape[0], -1)
97 |
98 | if self.obj_cond_feature == 1:
99 | X = torch.cat([X, transf_transl[:, -1, None]], dim=-1).float()
100 |
101 | X0 = self.enc_bn1(X)
102 |
103 | X0 = torch.cat([X0, contacts_markers.view(-1, 143), object_cond], dim=-1)
104 |
105 |
106 | X = self.enc_rb1(X0, True)
107 | X = self.enc_rb2(torch.cat([X0, X], dim=1), True)
108 |
109 | return X
110 |
111 | def dec(self, Z, cond_object, transf_transl):
112 |
113 | _, _, _, _, _, object_cond = cond_object
114 | X0 = torch.cat([Z, object_cond], dim=1).float()
115 |
116 | X = self.dec_rb1(X0, True)
117 | X_xyz = self.dec_rb2_xyz(torch.cat([X0, X, transf_transl[:, -1, None]], dim=1).float(), True)
118 | X_p = self.dec_rb2_p(torch.cat([X0, X], dim=1).float(), True)
119 |
120 | xyz_pred = self.dec_output_xyz(X_xyz)
121 | p_pred = self.p_output(self.dec_output_p(X_p))
122 |
123 | return xyz_pred, p_pred
124 |
125 | class ContactNet(nn.Module):
126 | def __init__(self, cfg, latentD=16, hc=64, object_feature=6, **kwargs):
127 | super(ContactNet, self).__init__()
128 | self.latentD = latentD
129 | self.hc = hc
130 | self.object_feature = object_feature
131 |
132 | self.enc_pointnet = PointNetEncoder(self.hc, self.object_feature+1)
133 |
134 | self.dec_fc1 = nn.Linear(self.latentD, self.hc*2)
135 | self.dec_bn1 = nn.BatchNorm1d(self.hc*2)
136 | self.dec_drop1 = nn.Dropout(0.1)
137 | self.dec_fc2 = nn.Linear(self.hc*2, self.hc*4)
138 | self.dec_bn2 = nn.BatchNorm1d(self.hc*4)
139 | self.dec_drop2 = nn.Dropout(0.1)
140 | self.dec_fc3 = nn.Linear(self.hc*4, self.hc*8)
141 | self.dec_bn3 = nn.BatchNorm1d(self.hc*8)
142 | self.dec_drop3 = nn.Dropout(0.1)
143 |
144 | self.dec_fc4 = nn.Linear(self.hc*8+self.latentD, self.hc*8)
145 | self.dec_bn4 = nn.BatchNorm1d(self.hc*8)
146 | self.dec_drop4 = nn.Dropout(0.1)
147 |
148 | self.dec_fp3 = PointNetFeaturePropagation(in_channel=self.hc*8+self.hc*4, mlp=[self.hc*8, self.hc*4])
149 | self.dec_fp2 = PointNetFeaturePropagation(in_channel=self.hc*4+self.hc*2, mlp=[self.hc*4, self.hc*2])
150 | self.dec_fp1 = PointNetFeaturePropagation(in_channel=self.hc*2+self.object_feature, mlp=[self.hc*2, self.hc*2])
151 |
152 | self.dec_conv1 = nn.Conv1d(self.hc*2, self.hc*2, 1)
153 | self.dec_conv_bn1 = nn.BatchNorm1d(self.hc*2)
154 | self.dec_conv_drop1 = nn.Dropout(0.1)
155 | self.dec_conv2 = nn.Conv1d(self.hc*2, 1, 1)
156 |
157 | self.dec_output = nn.Sigmoid()
158 |
159 | def enc(self, contacts_object, verts_object, feat_object):
160 | l0_xyz = verts_object[:, :3, :]
161 | l0_points = torch.cat([feat_object, contacts_object], 1) if feat_object is not None else contacts_object
162 | _, _, _, _, _, x = self.enc_pointnet(l0_xyz, l0_points)
163 |
164 | return x
165 |
166 | def dec(self, z, cond_object, verts_object, feat_object):
167 | l0_xyz = verts_object[:, :3, :]
168 | l0_points = feat_object
169 |
170 | l1_xyz, l1_points, l2_xyz, l2_points, l3_xyz, l3_points = cond_object
171 |
172 | l3_points = torch.cat([l3_points, z], 1)
173 | l3_points = self.dec_drop4(F.relu(self.dec_bn4(self.dec_fc4(l3_points)), inplace=True))
174 | l3_points = l3_points.view(l3_points.size()[0], l3_points.size()[1], 1)
175 |
176 | l2_points = self.dec_fp3(l2_xyz, l3_xyz, l2_points, l3_points)
177 | l1_points = self.dec_fp2(l1_xyz, l2_xyz, l1_points, l2_points)
178 | if l0_points is None:
179 | l0_points = self.dec_fp1(l0_xyz, l1_xyz, l0_xyz, l1_points)
180 | else:
181 | l0_points = self.dec_fp1(l0_xyz, l1_xyz, torch.cat([l0_xyz,l0_points],1), l1_points)
182 | feat = F.relu(self.dec_conv_bn1(self.dec_conv1(l0_points)), inplace=True)
183 | x = self.dec_conv_drop1(feat)
184 | x = self.dec_conv2(x)
185 | x = self.dec_output(x)
186 |
187 | return x
188 |
189 |
190 | class FullBodyGraspNet(nn.Module):
191 | def __init__(self, cfg, **kwargs):
192 | super(FullBodyGraspNet, self).__init__()
193 |
194 | self.cfg = cfg
195 | self.latentD = cfg.latentD
196 | self.in_feature_list = {}
197 | self.in_feature_list['joints'] = 127*3
198 | self.in_feature_list['markers_143'] = 143*3
199 | self.in_feature_list['markers_214'] = 214*3
200 | self.in_feature_list['markers_593'] = 593*3
201 |
202 | self.in_feature = self.in_feature_list[cfg.data_representation]
203 |
204 | self.marker_net = MarkerNet(cfg, n_neurons=cfg.n_markers, in_cond=cfg.pointnet_hc*8, latentD=cfg.latentD, in_feature=self.in_feature)
205 | self.contact_net = ContactNet(cfg, latentD=cfg.latentD, hc=cfg.pointnet_hc, object_feature=cfg.obj_feature)
206 |
207 | self.pointnet = PointNetEncoder(hc=cfg.pointnet_hc, in_feature=cfg.obj_feature)
208 | # encoder fusion
209 | self.enc_fusion = ResBlock(cfg.n_markers+self.cfg.pointnet_hc*8, cfg.n_neurons)
210 |
211 | self.enc_mu = nn.Linear(cfg.n_neurons, cfg.latentD)
212 | self.enc_var = nn.Linear(cfg.n_neurons, cfg.latentD)
213 |
214 |
215 | def encode(self, object_cond, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl):
216 | # marker branch
217 | marker_feat = self.marker_net.enc(object_cond, markers, contacts_markers, transf_transl) # [B, n_neurons=1024]
218 |
219 | # contact branch
220 | contact_feat = self.contact_net.enc(contacts_object, verts_object, feat_object) # [B, hc*8]
221 |
222 | # fusion
223 | X = torch.cat([marker_feat, contact_feat], dim=-1)
224 | X = self.enc_fusion(X, True)
225 |
226 | return torch.distributions.normal.Normal(self.enc_mu(X), F.softplus(self.enc_var(X)))
227 |
228 |
229 | def decode(self, Z, object_cond, verts_object, feat_object, transf_transl):
230 |
231 | bs = Z.shape[0]
232 | # marker_branch
233 | markers_xyz_pred, markers_p_pred = self.marker_net.dec(Z, object_cond, transf_transl)
234 |
235 | # contact branch
236 | contact_pred = self.contact_net.dec(Z, object_cond, verts_object, feat_object)
237 |
238 | return markers_xyz_pred.view(bs, -1, 3), markers_p_pred, contact_pred
239 |
240 | def forward(self, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl, **kwargs):
241 | object_cond = self.pointnet(l0_xyz=verts_object, l0_points=feat_object)
242 | z = self.encode(object_cond, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl)
243 | z_s = z.rsample()
244 |
245 | markers_xyz_pred, markers_p_pred, object_p_pred = self.decode(z_s, object_cond, verts_object, feat_object, transf_transl)
246 |
247 | results = {'markers': markers_xyz_pred, 'contacts_markers': markers_p_pred, 'contacts_object': object_p_pred, 'object_code': object_cond[-1], 'mean': z.mean, 'std': z.scale}
248 |
249 | return results
250 |
251 |
252 | def sample(self, verts_object, feat_object, transf_transl, seed=None):
253 | bs = verts_object.shape[0]
254 | if seed is not None:
255 | np.random.seed(seed)
256 | dtype = verts_object.dtype
257 | device = verts_object.device
258 | self.eval()
259 | with torch.no_grad():
260 | Zgen = np.random.normal(0., 1., size=(bs, self.latentD))
261 | Zgen = torch.tensor(Zgen,dtype=dtype).to(device)
262 |
263 | object_cond = self.pointnet(l0_xyz=verts_object, l0_points=feat_object)
264 |
265 | return self.decode(Zgen, object_cond, verts_object, feat_object, transf_transl)
266 |
--------------------------------------------------------------------------------
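FullBodyGraspNet.encode returns a torch Normal posterior whose mean and std are also exposed in the results dict, so the KL regularizer weighted by kl_coef in the config can be written against a standard-normal prior. Below is a sketch of one common formulation; the repository's actual loss lives in WholeGraspPose/trainer.py, which is not shown in this section.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_regularizer(mean, std):
    """KL(q(z|x) || N(0, I)) for the posterior exposed as results['mean'] / results['std']."""
    posterior = Normal(mean, std)
    prior = Normal(torch.zeros_like(mean), torch.ones_like(std))
    # Sum over the latent dimension, average over the batch.
    return kl_divergence(posterior, prior).sum(dim=-1).mean()

# e.g. loss = reconstruction_terms + cfg.kl_coef * kl_regularizer(results['mean'], results['std'])
```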
/WholeGraspPose/models/objectmodel.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from smplx.lbs import batch_rodrigues
7 |
8 | model_output = namedtuple('output', ['vertices', 'vertex_normals', 'global_orient', 'transl'])
9 |
10 | class ObjectModel(nn.Module):
11 |
12 | def __init__(self,
13 | v_template,
14 | normal_template,
15 | batch_size=1,
16 | dtype=torch.float32):
17 |
18 | super(ObjectModel, self).__init__()
19 |
20 |
21 | self.dtype = dtype
22 | self.batch_size = batch_size
23 |
24 |
25 | def forward(self, global_orient=None, transl=None, v_template=None, n_template=None, rotmat=False, **kwargs):
26 |
27 |
28 | if global_orient is None:
29 | global_orient = self.global_orient
30 | if transl is None:
31 | transl = self.transl
32 | if v_template is None:
33 | v_template = self.v_template
34 | if n_template is None:
35 | n_template = self.n_template
36 |
37 | if not rotmat:
38 | rot_mats = batch_rodrigues(global_orient.view(-1, 3)).view([self.batch_size, 3, 3])
39 | else:
40 | rot_mats = global_orient.view([self.batch_size, 3, 3])
41 |
42 | vertices = torch.matmul(v_template, rot_mats) + transl.unsqueeze(dim=1)
43 |
44 | vertex_normals = torch.matmul(n_template, rot_mats)
45 |
46 | output = model_output(vertices=vertices,
47 | vertex_normals = vertex_normals,
48 | global_orient=global_orient,
49 | transl=transl)
50 |
51 | return output
52 |
53 |
--------------------------------------------------------------------------------
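ObjectModel applies a batched Rodrigues rotation and a translation to an object template. Note that this __init__ keeps only dtype and batch_size, so the templates must be passed to forward explicitly. A rough usage sketch with dummy shapes; the 2048-vertex count matches the point counts used elsewhere in the repo but is an assumption here.

```python
import torch
# Assumes ObjectModel from WholeGraspPose/models/objectmodel.py is in scope (smplx must be installed).

B = 4
v_template = torch.randn(B, 2048, 3)    # object vertices, one copy per batch element
n_template = torch.randn(B, 2048, 3)    # per-vertex normals

obj = ObjectModel(v_template, n_template, batch_size=B)

global_orient = torch.zeros(B, 3)       # axis-angle rotation (here: identity)
transl = torch.zeros(B, 3)              # translation

out = obj(global_orient=global_orient, transl=transl,
          v_template=v_template, n_template=n_template)
print(out.vertices.shape, out.vertex_normals.shape)   # [4, 2048, 3] each
```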
/WholeGraspPose/models/pointnet.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def timeit(tag, t):
10 | print("{}: {}s".format(tag, time() - t))
11 | return time()
12 |
13 | def pc_normalize(pc):
14 | l = pc.shape[0]
15 | centroid = np.mean(pc, axis=0)
16 | pc = pc - centroid
17 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
18 | pc = pc / m
19 | return pc
20 |
21 | def square_distance(src, dst):
22 | """
23 | Calculate Euclid distance between each two points.
24 |
25 | src^T * dst = xn * xm + yn * ym + zn * zm;
26 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
27 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
28 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
29 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
30 |
31 | Input:
32 | src: source points, [B, N, C]
33 | dst: target points, [B, M, C]
34 | Output:
35 | dist: per-point square distance, [B, N, M]
36 | """
37 | B, N, _ = src.shape
38 | _, M, _ = dst.shape
39 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
40 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
41 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
42 | return dist
43 |
44 |
45 | def index_points(points, idx):
46 | """
47 |
48 | Input:
49 | points: input points data, [B, N, C]
50 | idx: sample index data, [B, S]
51 | Return:
52 | new_points:, indexed points data, [B, S, C]
53 | """
54 | device = points.device
55 | B = points.shape[0]
56 | view_shape = list(idx.shape)
57 | view_shape[1:] = [1] * (len(view_shape) - 1)
58 | repeat_shape = list(idx.shape)
59 | repeat_shape[0] = 1
60 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
61 | new_points = points[batch_indices, idx, :]
62 | return new_points
63 |
64 |
65 | def farthest_point_sample(xyz, npoint):
66 | """
67 | Input:
68 | xyz: pointcloud data, [B, N, 3]
69 | npoint: number of samples
70 | Return:
71 | centroids: sampled pointcloud index, [B, npoint]
72 | """
73 | device = xyz.device
74 | B, N, C = xyz.shape
75 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
76 | distance = torch.ones(B, N).to(device) * 1e10
77 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
78 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
79 | for i in range(npoint):
80 | centroids[:, i] = farthest
81 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
82 | dist = torch.sum((xyz - centroid) ** 2, -1)
83 | mask = dist < distance
84 | distance[mask] = dist[mask]
85 | farthest = torch.max(distance, -1)[1]
86 | return centroids
87 |
88 |
89 | def query_ball_point(radius, nsample, xyz, new_xyz):
90 | """
91 | Input:
92 | radius: local region radius
93 | nsample: max sample number in local region
94 | xyz: all points, [B, N, 3]
95 | new_xyz: query points, [B, S, 3]
96 | Return:
97 | group_idx: grouped points index, [B, S, nsample]
98 | """
99 | device = xyz.device
100 | B, N, C = xyz.shape
101 | _, S, _ = new_xyz.shape
102 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
103 | sqrdists = square_distance(new_xyz, xyz)
104 | group_idx[sqrdists > radius ** 2] = N
105 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
106 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
107 | mask = group_idx == N
108 | group_idx[mask] = group_first[mask]
109 | return group_idx
110 |
111 |
112 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
113 | """
114 | Input:
115 | npoint:
116 | radius:
117 | nsample:
118 | xyz: input points position data, [B, N, 3]
119 | points: input points data, [B, N, D]
120 | Return:
121 | new_xyz: sampled points position data, [B, npoint, nsample, 3]
122 | new_points: sampled points data, [B, npoint, nsample, 3+D]
123 | """
124 | B, N, C = xyz.shape
125 | S = npoint
126 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
127 | torch.cuda.empty_cache()
128 | new_xyz = index_points(xyz, fps_idx)
129 | torch.cuda.empty_cache()
130 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
131 | torch.cuda.empty_cache()
132 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
133 | torch.cuda.empty_cache()
134 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
135 | torch.cuda.empty_cache()
136 |
137 | if points is not None:
138 | grouped_points = index_points(points, idx)
139 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
140 | else:
141 | new_points = grouped_xyz_norm
142 | if returnfps:
143 | return new_xyz, new_points, grouped_xyz, fps_idx
144 | else:
145 | return new_xyz, new_points
146 |
147 |
148 | def sample_and_group_all(xyz, points):
149 | """
150 | Input:
151 | xyz: input points position data, [B, N, 3]
152 | points: input points data, [B, N, D]
153 | Return:
154 | new_xyz: sampled points position data, [B, 1, 3]
155 | new_points: sampled points data, [B, 1, N, 3+D]
156 | """
157 | device = xyz.device
158 | B, N, C = xyz.shape
159 | new_xyz = torch.zeros(B, 1, C).to(device)
160 | grouped_xyz = xyz.view(B, 1, N, C)
161 | if points is not None:
162 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
163 | else:
164 | new_points = grouped_xyz
165 | return new_xyz, new_points
166 |
167 |
168 | class PointNetSetAbstraction(nn.Module):
169 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
170 | super(PointNetSetAbstraction, self).__init__()
171 | self.npoint = npoint
172 | self.radius = radius
173 | self.nsample = nsample
174 | self.mlp_convs = nn.ModuleList()
175 | self.mlp_bns = nn.ModuleList()
176 | last_channel = in_channel
177 | for out_channel in mlp:
178 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
179 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
180 | last_channel = out_channel
181 | self.group_all = group_all
182 |
183 | def forward(self, xyz, points):
184 | """
185 | Input:
186 | xyz: input points position data, [B, C, N]
187 | points: input points data, [B, D, N]
188 | Return:
189 | new_xyz: sampled points position data, [B, C, S]
190 | new_points_concat: sample points feature data, [B, D', S]
191 | """
192 | xyz = xyz.permute(0, 2, 1)
193 | if points is not None:
194 | points = points.permute(0, 2, 1)
195 | # print('before sample:', points.shape)
196 |
197 | if self.group_all:
198 | new_xyz, new_points = sample_and_group_all(xyz, points)
199 | else:
200 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
201 | # new_xyz: sampled points position data, [B, npoint, C]
202 | # new_points: sampled points data, [B, npoint, nsample, C+D]
203 | # print('after sample:', new_points.shape)
204 |         new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample, npoint]
205 | for i, conv in enumerate(self.mlp_convs):
206 | bn = self.mlp_bns[i]
207 | new_points = F.relu(bn(conv(new_points)), inplace=True)
208 | # print('after conv:', new_points.shape)
209 | new_points = torch.max(new_points, 2)[0]
210 | new_xyz = new_xyz.permute(0, 2, 1)
211 | return new_xyz, new_points
212 |
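   | # Illustrative usage with assumed hyper-parameters (added sketch, not from the original
   | # file); `in_channel` counts the 3 xyz channels plus the per-point feature channels:
   | # >>> sa = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
   | # ...                             in_channel=3 + 6, mlp=[64, 64, 128], group_all=False)
   | # >>> new_xyz, new_feat = sa(torch.rand(8, 3, 1024), torch.rand(8, 6, 1024))
   | # >>> new_xyz.shape, new_feat.shape
   | # (torch.Size([8, 3, 512]), torch.Size([8, 128, 512]))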
213 |
214 | class PointNetSetAbstractionMsg(nn.Module):
215 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
216 | super(PointNetSetAbstractionMsg, self).__init__()
217 | self.npoint = npoint
218 | self.radius_list = radius_list
219 | self.nsample_list = nsample_list
220 | self.conv_blocks = nn.ModuleList()
221 | self.bn_blocks = nn.ModuleList()
222 | for i in range(len(mlp_list)):
223 | convs = nn.ModuleList()
224 | bns = nn.ModuleList()
225 | last_channel = in_channel + 3
226 | for out_channel in mlp_list[i]:
227 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
228 | bns.append(nn.BatchNorm2d(out_channel))
229 | last_channel = out_channel
230 | self.conv_blocks.append(convs)
231 | self.bn_blocks.append(bns)
232 |
233 | def forward(self, xyz, points):
234 | """
235 | Input:
236 | xyz: input points position data, [B, C, N]
237 | points: input points data, [B, D, N]
238 | Return:
239 | new_xyz: sampled points position data, [B, C, S]
240 | new_points_concat: sample points feature data, [B, D', S]
241 | """
242 | xyz = xyz.permute(0, 2, 1)
243 | if points is not None:
244 | points = points.permute(0, 2, 1)
245 |
246 | B, N, C = xyz.shape
247 | S = self.npoint
248 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
249 | new_points_list = []
250 | for i, radius in enumerate(self.radius_list):
251 | K = self.nsample_list[i]
252 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
253 | grouped_xyz = index_points(xyz, group_idx)
254 | grouped_xyz -= new_xyz.view(B, S, 1, C)
255 | if points is not None:
256 | grouped_points = index_points(points, group_idx)
257 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
258 | else:
259 | grouped_points = grouped_xyz
260 |
261 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
262 | for j in range(len(self.conv_blocks[i])):
263 | conv = self.conv_blocks[i][j]
264 | bn = self.bn_blocks[i][j]
265 | grouped_points = F.relu(bn(conv(grouped_points)), inplace=True)
266 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
267 | new_points_list.append(new_points)
268 |
269 | new_xyz = new_xyz.permute(0, 2, 1)
270 | new_points_concat = torch.cat(new_points_list, dim=1)
271 | return new_xyz, new_points_concat
272 |
273 |
274 | class PointNetFeaturePropagation(nn.Module):
275 | def __init__(self, in_channel, mlp):
276 | super(PointNetFeaturePropagation, self).__init__()
277 | self.mlp_convs = nn.ModuleList()
278 | self.mlp_bns = nn.ModuleList()
279 | last_channel = in_channel
280 | for out_channel in mlp:
281 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
282 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
283 | last_channel = out_channel
284 |
285 | def forward(self, xyz1, xyz2, points1, points2):
286 | """
287 | Input:
288 | xyz1: input points position data, [B, C, N]
289 | xyz2: sampled input points position data, [B, C, S]
290 | points1: input points data, [B, D, N]
291 | points2: input points data, [B, D, S]
292 | Return:
293 | new_points: upsampled points data, [B, D', N]
294 | """
295 | xyz1 = xyz1.permute(0, 2, 1)
296 | xyz2 = xyz2.permute(0, 2, 1)
297 |
298 | points2 = points2.permute(0, 2, 1)
299 | B, N, C = xyz1.shape
300 | _, S, _ = xyz2.shape
301 |
302 | if S == 1:
303 | interpolated_points = points2.repeat(1, N, 1)
304 | else:
305 | dists = square_distance(xyz1, xyz2)
306 | dists, idx = dists.sort(dim=-1)
307 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
308 |
309 | dist_recip = 1.0 / (dists + 1e-8)
310 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
311 | weight = dist_recip / norm
312 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
313 |
314 | if points1 is not None:
315 | points1 = points1.permute(0, 2, 1)
316 | new_points = torch.cat([points1, interpolated_points], dim=-1)
317 | else:
318 | new_points = interpolated_points
319 |
320 | new_points = new_points.permute(0, 2, 1)
321 | for i, conv in enumerate(self.mlp_convs):
322 | bn = self.mlp_bns[i]
323 | new_points = F.relu(bn(conv(new_points)), inplace=True)
324 | return new_points
325 |
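   | # The interpolation above is inverse-distance weighting over the 3 nearest sampled
   | # points: w_j = (1/d_j) / sum_k(1/d_k), feat(p) = sum_j w_j * feat_j.
   | # Illustrative usage with assumed channel sizes (added sketch, not from the original file):
   | # >>> fp = PointNetFeaturePropagation(in_channel=128 + 64, mlp=[128, 128])
   | # >>> out = fp(torch.rand(8, 3, 1024), torch.rand(8, 3, 512),
   | # ...          torch.rand(8, 64, 1024), torch.rand(8, 128, 512))
   | # >>> out.shape
   | # torch.Size([8, 128, 1024])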
326 |
--------------------------------------------------------------------------------
/body_utils/body_models/VPoser/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/.DS_Store
--------------------------------------------------------------------------------
/body_utils/body_models/VPoser/vposerDecoderWeights.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerDecoderWeights.npz
--------------------------------------------------------------------------------
/body_utils/body_models/VPoser/vposerEncoderWeights.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerEncoderWeights.npz
--------------------------------------------------------------------------------
/body_utils/body_models/VPoser/vposerMeanPose.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerMeanPose.npz
--------------------------------------------------------------------------------
/body_utils/body_segments/L_Hand.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [3320, 3333, 3335, 3347, 3352, 3357, 3369, 3376, 3389, 3392, 3427, 3445, 3446, 3450, 3453, 3454, 3456, 3458, 3459, 3463, 3468, 3471, 3477, 3478, 3484, 3488, 3498, 3504, 3514, 3526, 3527, 3532, 3535, 3546, 3549, 3551, 3560, 3561, 3566, 3576, 3578, 3591, 3593, 3594, 3595, 3596, 3597, 3601, 3613, 3614, 3615, 3620, 3625, 3626, 3627, 3628, 3632, 3634, 3638, 3640, 3646, 3647, 3650, 3652, 3653, 3654, 3655, 3659, 3664, 3668, 3681, 3685, 3686, 3687, 3690, 3695, 3698, 3708, 3709, 3714, 3715, 3741, 3753, 3756, 3771, 3788, 3823, 3825, 3826, 3831, 3833, 3990, 3991, 3996, 3998, 3999, 4004, 4006, 4011, 4083, 4093, 4130, 4131, 4136, 4145, 4152, 4155, 4166, 4273, 4342, 4359, 4360, 4366, 4376, 4378, 4384, 4406, 4649, 4653, 4696, 4700, 4707, 4708, 4716, 4722, 4773, 4777, 4794, 4804, 4825, 4837, 4948, 4963, 5019, 5022, 5026, 5033, 5042, 5046, 5091, 5133, 5150, 5156, 5168, 5178, 5179, 5180, 5193, 5196, 5224, 5247, 5252, 5258, 5304, 5306, 5365, 5550, 5558, 5564, 5576, 5589, 5616, 5646, 5647, 5648, 5661, 5668, 5670, 5677, 5716, 5721, 5767, 5774, 5789, 5802, 5803, 5809, 5878, 5911, 13854, 13867, 13869, 13871, 13881, 13886, 13891, 13903, 13909, 13922, 13925, 13954, 13960, 13977, 13978, 13982, 13985, 13986, 13988, 13990, 13991, 13995, 14000, 14003, 14009, 14010, 14016, 14020, 14030, 14036, 14046, 14058, 14059, 14064, 14067, 14078, 14081, 14083, 14092, 14093, 14098, 14108, 14110, 14123, 14125, 14126, 14127, 14128, 14129, 14133, 14145, 14146, 14147, 14152, 14157, 14158, 14159, 14160, 14164, 14166, 14170, 14172, 14178, 14182, 14184, 14185, 14186, 14187, 14189, 14191, 14196, 14198, 14200, 14213, 14217, 14218, 14219, 14222, 14223, 14224, 14227, 14230, 14240, 14241, 14246, 14247, 14273, 14285, 14288, 14303, 14320, 14355, 14357, 14358, 14363, 14365, 14522, 14523, 14528, 14531, 14536, 14538, 14543, 14615, 14625, 14662, 14663, 14668, 14684, 14687, 14698, 14804, 14873, 14890, 14891, 14897, 14907, 14915, 14937, 15180, 15184, 15202, 15227, 15231, 15238, 15239, 15247, 15253, 15308, 15324, 15325, 15335, 15356, 15368, 15494, 15550, 15553, 15557, 15573, 15620, 15662, 15679, 15685, 15697, 15707, 15708, 15709, 15722, 15725, 15753, 15772, 15776, 15781, 15787, 15831, 15833, 15835, 15840, 15894, 16079, 16087, 16093, 16105, 16118, 16145, 16175, 16176, 16177, 16190, 16197, 16199, 16206, 16249, 16295, 16302, 16317, 16330, 16331, 16337, 16406, 16439], "verts_ind": [4595, 4596, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606, 4607, 4627, 4628, 4629, 4630, 4631, 4634, 4637, 4639, 4640, 4641, 4642, 4643, 4644, 4656, 4657, 4658, 4659, 4660, 4669, 4670, 4671, 4672, 4688, 4691, 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4702, 4704, 4707, 4708, 4709, 4710, 4711, 4727, 4728, 4729, 4730, 4731, 4732, 4734, 4735, 4736, 4741, 4743, 4744, 4745, 4746, 4753, 4754, 4755, 4756, 4758, 4770, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4800, 4801, 4802, 4803, 4804, 4805, 4836, 4846, 4847, 4850, 4853, 4854, 4859, 4860, 4861, 4862, 4863, 4870, 4877, 4878, 4880, 4881, 4885, 4892, 4894, 4895, 4898, 4903, 4913, 4916, 4940, 4941, 4943, 4944, 4945, 4955, 4956, 4957, 4958, 4959, 4962, 4963, 4965, 4971, 4972, 4973, 4974, 4985, 4987, 4988, 4989, 4992, 4993, 4994, 4995, 5000, 5001, 5002, 5011, 5012, 5013, 5017, 5018, 5025, 5028, 5049, 5050, 5051, 5052, 5053, 5054, 5055, 5056, 5057, 5067, 5068, 5069, 5070, 5071, 5074, 5075, 5077, 5083, 5084, 5085, 5086, 5099, 5102, 5103, 5104, 5105, 5111, 5112, 5115, 5116, 5121, 5122, 5128, 5129, 5138, 5139, 5164, 5165, 5166, 5168, 5178, 5179, 5180, 5181, 5182, 5185, 5188, 5192, 5194, 5195, 5196, 5197, 5198, 5209, 5210, 
5211, 5212, 5215, 5216, 5217, 5218, 5219, 5220, 5221, 5222, 5229, 5230, 5233, 5234, 5235, 5236, 5237, 5239, 5240, 5245, 5246, 5255, 5256, 5282, 5283, 5295, 5296, 5297, 5298, 5299, 5300, 5301, 5302, 5303, 5305, 5314, 5315, 5316, 5318, 5325, 5326, 5347, 5348, 5351, 5353, 5354, 5355, 5356, 5357, 5358, 5359, 5360, 5370, 5371, 5372, 5373, 5380, 5386, 5387, 5388, 5389, 5390, 5391, 5392, 5393]}
--------------------------------------------------------------------------------
/body_utils/body_segments/L_Leg.json:
--------------------------------------------------------------------------------
1 | {"verts_ind": [5774, 5781, 5789, 5790, 5791, 5792, 5793, 5794, 5797, 5805, 5806, 5807, 5808, 5813, 5814, 5815, 5816, 5817, 5818, 5824, 5827, 5830, 5831, 5832, 5839, 5840, 5842, 5843, 5844, 5847, 5850, 5851, 5854, 5855, 5859, 5861, 5862, 5864, 5865, 5869, 5902, 5906, 5907, 5908, 5909, 5910, 5911, 5912, 5913, 5914, 5915, 5916, 5917, 8866, 8867, 8868, 8879, 8880, 8881, 8882, 8883, 8884, 8888, 8889, 8890, 8891, 8897, 8898, 8899, 8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907, 8908, 8909, 8910, 8911, 8912, 8913, 8914, 8915, 8916, 8917, 8919, 8920, 8921, 8922, 8923, 8924, 8925, 8929, 8930, 8934], "faces_ind": [4257, 4259, 4274, 4283, 4284, 4286, 4288, 4307, 4309, 4311, 4314, 4320, 4340, 4354, 4371, 4422, 4462, 4463, 4464, 4467, 4469, 4470, 4471, 4472, 4478, 4530, 4545, 4546, 4547, 4549, 4550, 4551, 4552, 4553, 4554, 4568, 4581, 4616, 4618, 4874, 4875, 4876, 4978, 5004, 5072, 5074, 5075, 5076, 5077, 5204, 5263, 5264, 5265, 5268, 5324, 5353, 5375, 5376, 5386, 5390, 5627, 5684, 5694, 5712, 5783, 5785, 5860, 5867, 14790, 14805, 14808, 14814, 14815, 14817, 14819, 14834, 14840, 14845, 14858, 14859, 14871, 14922, 14953, 14993, 14994, 14995, 14998, 15000, 15001, 15002, 15003, 15009, 15061, 15076, 15077, 15078, 15080, 15081, 15082, 15083, 15084, 15085, 15099, 15112, 15147, 15149, 15405, 15406, 15407, 15434, 15509, 15535, 15602, 15604, 15605, 15606, 15607, 15733, 15793, 15794, 15797, 15853, 15904, 15905, 15915, 15919, 16156, 16213, 16240, 16311, 16313, 16388, 16394, 16395, 16415]}
--------------------------------------------------------------------------------
/body_utils/body_segments/R_Hand.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [6679, 6687, 6692, 6700, 6712, 6713, 6717, 6718, 6719, 6720, 6730, 6732, 6733, 6735, 6736, 6739, 6749, 6750, 6759, 6762, 6790, 6798, 6801, 6803, 6818, 6820, 6821, 6823, 6824, 6830, 6831, 6847, 6854, 6857, 6864, 6874, 6878, 6879, 6882, 6887, 6889, 6894, 6898, 6910, 6942, 6943, 6944, 6945, 6946, 6948, 6951, 6952, 6953, 6954, 6958, 6962, 6963, 6964, 6971, 6972, 6975, 6980, 6987, 6991, 6992, 6999, 7005, 7007, 7019, 7059, 7060, 7061, 7062, 7063, 7070, 7071, 7074, 7079, 7080, 7081, 7082, 7088, 7091, 7100, 7106, 7110, 7117, 7123, 7125, 7137, 7143, 7177, 7178, 7179, 7180, 7183, 7186, 7187, 7188, 7189, 7197, 7198, 7199, 7206, 7209, 7214, 7218, 7220, 7225, 7226, 7229, 7233, 7239, 7241, 7244, 7252, 7253, 7260, 7280, 7293, 7294, 7295, 7296, 7297, 7298, 7299, 7302, 7303, 7304, 7305, 7308, 7309, 7313, 7314, 7315, 7316, 7321, 7322, 7325, 7330, 7331, 7333, 7337, 7364, 7365, 7366, 7379, 7380, 7381, 7382, 7383, 7384, 7385, 7391, 7392, 7400, 7401, 7402, 7403, 7408, 7411, 7412, 7414, 7415, 7416, 7417, 7418, 7419, 7420, 7421, 7422, 7423, 7424, 7426, 7428, 7429, 7437, 7440, 7453, 7462, 7463, 7468, 7479, 7480, 7481, 17207, 17215, 17219, 17220, 17228, 17240, 17241, 17245, 17246, 17247, 17248, 17258, 17260, 17263, 17264, 17277, 17278, 17287, 17290, 17317, 17325, 17328, 17330, 17345, 17347, 17348, 17350, 17351, 17357, 17373, 17380, 17383, 17385, 17390, 17400, 17404, 17405, 17408, 17413, 17415, 17420, 17424, 17435, 17467, 17468, 17469, 17470, 17471, 17473, 17476, 17477, 17478, 17479, 17482, 17483, 17487, 17488, 17489, 17496, 17497, 17500, 17502, 17505, 17512, 17516, 17517, 17524, 17530, 17532, 17544, 17584, 17585, 17586, 17587, 17588, 17595, 17596, 17599, 17604, 17605, 17606, 17613, 17616, 17625, 17631, 17635, 17642, 17648, 17650, 17662, 17669, 17702, 17703, 17704, 17705, 17706, 17708, 17711, 17713, 17714, 17722, 17723, 17724, 17731, 17734, 17739, 17743, 17745, 17746, 17750, 17751, 17754, 17758, 17764, 17766, 17769, 17777, 17778, 17785, 17818, 17819, 17820, 17821, 17822, 17824, 17827, 17828, 17829, 17830, 17833, 17834, 17838, 17839, 17840, 17841, 17846, 17847, 17850, 17855, 17856, 17858, 17862, 17889, 17890, 17891, 17904, 17905, 17906, 17907, 17908, 17909, 17910, 17916, 17917, 17925, 17926, 17927, 17933, 17936, 17937, 17939, 17940, 17941, 17942, 17943, 17944, 17945, 17946, 17947, 17948, 17951, 17954, 17964, 17969, 17977, 17986, 17987, 17992, 18003, 18004], "verts_ind": [7335, 7336, 7337, 7338, 7342, 7360, 7363, 7364, 7365, 7366, 7367, 7370, 7373, 7375, 7377, 7378, 7379, 7380, 7394, 7395, 7405, 7406, 7407, 7408, 7424, 7427, 7428, 7429, 7430, 7431, 7432, 7433, 7434, 7438, 7440, 7441, 7442, 7443, 7444, 7445, 7446, 7447, 7463, 7464, 7465, 7466, 7470, 7471, 7472, 7477, 7478, 7479, 7480, 7481, 7482, 7489, 7491, 7492, 7494, 7506, 7507, 7508, 7509, 7510, 7511, 7512, 7513, 7536, 7537, 7540, 7541, 7565, 7566, 7572, 7582, 7583, 7586, 7589, 7590, 7598, 7599, 7604, 7605, 7606, 7612, 7613, 7614, 7615, 7616, 7617, 7621, 7625, 7628, 7630, 7631, 7634, 7639, 7651, 7652, 7679, 7680, 7681, 7683, 7691, 7692, 7693, 7694, 7695, 7696, 7697, 7698, 7699, 7700, 7701, 7702, 7704, 7705, 7706, 7707, 7708, 7709, 7710, 7721, 7725, 7728, 7729, 7730, 7737, 7738, 7747, 7748, 7753, 7754, 7764, 7789, 7790, 7791, 7793, 7803, 7804, 7805, 7806, 7807, 7810, 7811, 7813, 7814, 7816, 7817, 7818, 7819, 7820, 7821, 7822, 7835, 7838, 7839, 7840, 7847, 7848, 7852, 7857, 7858, 7864, 7865, 7873, 7874, 7875, 7900, 7901, 7902, 7904, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7924, 7927, 7928, 7930, 7931, 7932, 7933, 7934, 7945, 7946, 
7947, 7948, 7951, 7952, 7953, 7954, 7955, 7956, 7957, 7958, 7965, 7966, 7969, 7970, 7971, 7972, 7975, 7976, 7981, 7982, 7989, 7991, 7992, 8016, 8017, 8018, 8019, 8020, 8021, 8031, 8032, 8033, 8034, 8035, 8036, 8037, 8038, 8039, 8040, 8041, 8042, 8044, 8045, 8046, 8050, 8051, 8052, 8054, 8055, 8062, 8065, 8087, 8088, 8089, 8090, 8091, 8092, 8093, 8094, 8104, 8105, 8106, 8107, 8108, 8111, 8112, 8114, 8117, 8118, 8119, 8120, 8121, 8122, 8123, 8124, 8125, 8126, 8127]}
--------------------------------------------------------------------------------
/body_utils/body_segments/R_Leg.json:
--------------------------------------------------------------------------------
1 | {"verts_ind": [8468, 8484, 8485, 8487, 8488, 8499, 8500, 8501, 8502, 8507, 8508, 8509, 8510, 8511, 8512, 8521, 8522, 8523, 8524, 8525, 8526, 8527, 8530, 8533, 8534, 8536, 8537, 8538, 8541, 8543, 8544, 8545, 8546, 8548, 8549, 8555, 8556, 8558, 8559, 8563, 8600, 8601, 8602, 8603, 8604, 8605, 8606, 8607, 8608, 8609, 8610, 8611, 8654, 8655, 8656, 8667, 8668, 8669, 8670, 8671, 8672, 8676, 8677, 8678, 8679, 8685, 8686, 8687, 8688, 8689, 8690, 8691, 8692, 8693, 8694, 8695, 8696, 8697, 8698, 8699, 8700, 8701, 8702, 8703, 8704, 8705, 8706, 8707, 8708, 8709, 8710, 8711, 8712, 8713, 8714, 8715, 8716], "faces_ind": [8100, 8105, 8107, 8108, 8112, 8113, 8117, 8122, 8123, 8132, 8133, 8139, 8149, 8157, 8171, 8182, 8188, 8192, 8227, 8229, 8230, 8231, 8232, 8295, 8296, 8297, 8298, 8299, 8300, 8301, 8302, 8303, 8304, 8305, 8306, 8307, 8308, 8309, 8310, 8311, 8312, 8313, 8314, 8315, 8316, 8317, 8318, 8319, 8320, 8323, 8324, 8325, 8326, 8327, 8328, 8329, 8330, 8331, 8333, 8334, 8335, 8336, 8337, 8558, 8559, 8560, 8561, 8562, 18623, 18628, 18630, 18631, 18635, 18636, 18640, 18645, 18646, 18648, 18656, 18662, 18666, 18672, 18680, 18694, 18703, 18705, 18711, 18750, 18752, 18753, 18754, 18755, 18818, 18819, 18820, 18821, 18822, 18823, 18824, 18825, 18826, 18827, 18828, 18829, 18830, 18831, 18832, 18833, 18834, 18835, 18836, 18837, 18838, 18839, 18840, 18841, 18842, 18843, 18847, 18848, 18849, 18850, 18851, 18852, 18853, 18854, 18856, 18857, 18858, 18859, 18860, 19081, 19082, 19083, 19084]}
--------------------------------------------------------------------------------
/body_utils/body_segments/back.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [1517, 1535, 1593, 1598, 3126, 3195, 3264, 3269, 3272, 3507, 3552, 3739, 3846, 3861, 3866, 3868, 3869, 3871, 3872, 3881, 3882, 3891, 3892, 3893, 3894, 3895, 3896, 3918, 3921, 3922, 3977, 3978, 3981, 4010, 4043, 4099, 4100, 4191, 4224, 4229, 4302, 4573, 4656, 4665, 4675, 4677, 4737, 4748, 4756, 4779, 4787, 4791, 4796, 4800, 4805, 4807, 4811, 4818, 4819, 4820, 4832, 4838, 4844, 4862, 4949, 5009, 5010, 5038, 5055, 5066, 5085, 5121, 5253, 5388, 5599, 5630, 5687, 5691, 5698, 5763, 5799, 5833, 5859, 5862, 5888, 5896, 5965, 5972, 5973, 6024, 6025, 6026, 6159, 6169, 6174, 6175, 6343, 6543, 6552, 6553, 7518, 7526, 7530, 7534, 7648, 7655, 7657, 7681, 7682, 7683, 7684, 7685, 7686, 7687, 7703, 7704, 7705, 7712, 7723, 7724, 7725, 7726, 7727, 7737, 7738, 7739, 7740, 7741, 7742, 7743, 7744, 7745, 7771, 7772, 7773, 7783, 7784, 7786, 7787, 7853, 7854, 7855, 7856, 7931, 7932, 7933, 8038, 8047, 8049, 8447, 8450, 8451, 8465, 8466, 8477, 8492, 8503, 8504, 8505, 8509, 8510, 8511, 8527, 8541, 8591, 8617, 8623, 8662, 8671, 8672, 8673, 8676, 8687, 8749, 8755, 12066, 12084, 12142, 12147, 13660, 13729, 13798, 13803, 13806, 14084, 14271, 14378, 14393, 14398, 14400, 14404, 14413, 14414, 14417, 14423, 14424, 14425, 14426, 14428, 14453, 14454, 14509, 14510, 14513, 14542, 14575, 14631, 14632, 14723, 14755, 14760, 14833, 15104, 15187, 15196, 15206, 15208, 15268, 15273, 15279, 15287, 15310, 15318, 15322, 15327, 15331, 15336, 15338, 15342, 15349, 15350, 15351, 15363, 15369, 15375, 15393, 15480, 15540, 15541, 15569, 15585, 15596, 15614, 15650, 15782, 15917, 15938, 16128, 16159, 16216, 16220, 16227, 16233, 16291, 16327, 16361, 16362, 16387, 16390, 16416, 16493, 16500, 16501, 16552, 16553, 16554, 16687, 16697, 16702, 16703, 16815, 16871, 17071, 17080, 17081, 18042, 18050, 18054, 18058, 18172, 18179, 18181, 18182, 18205, 18206, 18207, 18208, 18209, 18210, 18211, 18227, 18228, 18229, 18236, 18248, 18249, 18250, 18251, 18262, 18263, 18264, 18265, 18266, 18267, 18268, 18269, 18295, 18296, 18297, 18307, 18308, 18310, 18311, 18377, 18378, 18379, 18380, 18455, 18456, 18457, 18561, 18570, 18572, 18969, 18972, 18973, 18988, 18999, 19014, 19025, 19026, 19027, 19031, 19032, 19033, 19049, 19063, 19139, 19184, 19194, 19195, 19198, 19209, 19271, 19277], "verts_ind": [3336, 3337, 3338, 3339, 3352, 3357, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3381, 3384, 3398, 3399, 3426, 3428, 3431, 3516, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3823, 3844, 3845, 3846, 3848, 3872, 3873, 3883, 3886, 3887, 3888, 3891, 3892, 3893, 4068, 4069, 4070, 4071, 4126, 4127, 4128, 4129, 4430, 4431, 4432, 4433, 4443, 4444, 4445, 4452, 4453, 4454, 4455, 4456, 4457, 5402, 5403, 5404, 5405, 5411, 5412, 5413, 5414, 5420, 5421, 5422, 5427, 5428, 5460, 5461, 5462, 5483, 5485, 5498, 5516, 5521, 5523, 5524, 5525, 5526, 5530, 5535, 5536, 5537, 5538, 5545, 5546, 5560, 5561, 5562, 5563, 5564, 5565, 5566, 5567, 5568, 5569, 5598, 5610, 5611, 5631, 5699, 6099, 6100, 6101, 6102, 6115, 6121, 6122, 6123, 6124, 6125, 6126, 6127, 6128, 6142, 6145, 6159, 6160, 6187, 6189, 6192, 6194, 6277, 6279, 6280, 6281, 6282, 6283, 6284, 6285, 6286, 6287, 6580, 6599, 6600, 6601, 6603, 6604, 6623, 6624, 6633, 6636, 6637, 6638, 6639, 6640, 6641, 6812, 6813, 6814, 6870, 6871, 6872, 6873, 7166, 7167, 7168, 7169, 7179, 7180, 7188, 7189, 7190, 7192, 8136, 8137, 8138, 8139, 8145, 8146, 8147, 8148, 8154, 8155, 8156, 8161, 8162, 8192, 8194, 8195, 8196, 8197, 8218, 8223, 8238, 8241, 8242, 8243, 8244, 8245, 8246, 8247, 8248, 8249, 8258, 8259, 8260, 8272, 8274, 8275, 8276, 8277, 
8278, 8279, 8280, 8281, 8307, 8308, 8316, 8325, 8393]}
--------------------------------------------------------------------------------
/body_utils/body_segments/butt.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [1573, 3081, 3125, 3483, 4047, 4049, 4086, 4189, 4190, 4194, 4198, 4199, 4213, 4215, 4237, 4293, 4402, 4511, 4592, 4615, 4745, 4851, 4856, 4859, 4912, 4929, 5124, 5276, 5313, 5339, 5343, 5354, 5414, 5428, 5600, 5605, 5625, 5685, 5775, 5781, 5843, 5874, 6007, 6008, 6009, 6022, 6173, 7899, 7902, 7903, 7908, 7909, 7912, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7923, 7924, 7925, 7926, 7927, 7928, 8012, 8013, 8014, 8015, 8016, 8050, 8061, 8430, 8431, 8441, 8442, 8448, 8521, 8522, 8523, 8524, 8525, 8526, 8539, 8564, 8606, 8732, 12122, 13615, 13659, 13916, 14015, 14579, 14581, 14618, 14721, 14722, 14726, 14730, 14731, 14744, 14746, 14768, 14824, 14933, 15042, 15123, 15146, 15276, 15382, 15390, 15443, 15460, 15653, 15805, 15842, 15868, 15872, 15883, 15943, 15957, 16129, 16134, 16154, 16214, 16303, 16309, 16371, 16402, 16535, 16537, 16550, 16701, 18426, 18427, 18432, 18433, 18436, 18438, 18439, 18440, 18441, 18442, 18443, 18444, 18445, 18446, 18447, 18448, 18449, 18450, 18451, 18452, 18536, 18537, 18538, 18539, 18540, 18573, 18584, 18585, 18952, 18953, 18963, 18964, 18970, 19043, 19044, 19045, 19046, 19047, 19048, 19061, 19086, 19128, 19254], "verts_ind": [3462, 3463, 3464, 3465, 3466, 3468, 3469, 3470, 3471, 3472, 3473, 3483, 3501, 3512, 3513, 3514, 3515, 3867, 3884, 3885, 5574, 5575, 5596, 5613, 5614, 5659, 5661, 5665, 5666, 5667, 5673, 5674, 5675, 5676, 5678, 5680, 5681, 5682, 5683, 5684, 5685, 5686, 5687, 5688, 5689, 5690, 5691, 5692, 5693, 5694, 5695, 5696, 5712, 5713, 5714, 5715, 5934, 6223, 6224, 6225, 6226, 6227, 6229, 6230, 6231, 6232, 6233, 6234, 6273, 6274, 6275, 6276, 6634, 6635, 6651, 7145, 8353, 8354, 8355, 8359, 8360, 8361, 8362, 8363, 8365, 8366, 8367, 8368, 8369, 8370, 8372, 8374, 8375, 8376, 8377, 8378, 8379, 8380, 8381, 8382, 8383, 8384, 8385, 8386, 8387, 8388, 8389, 8390, 8405, 8406, 8407, 8408, 8409]}
--------------------------------------------------------------------------------
/body_utils/body_segments/gluteus.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [1573, 3081, 3125, 3483, 4047, 4049, 4086, 4189, 4190, 4194, 4198, 4199, 4213, 4215, 4237, 4293, 4402, 4511, 4592, 4615, 4745, 4851, 4856, 4859, 4912, 4929, 5124, 5276, 5313, 5339, 5343, 5354, 5414, 5428, 5600, 5605, 5625, 5685, 5775, 5781, 5843, 5874, 6007, 6008, 6009, 6022, 6173, 7899, 7902, 7903, 7908, 7909, 7912, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7923, 7924, 7925, 7926, 7927, 7928, 8012, 8013, 8014, 8015, 8016, 8050, 8061, 8430, 8431, 8441, 8442, 8448, 8521, 8522, 8523, 8524, 8525, 8526, 8539, 8564, 8606, 8732, 12122, 13615, 13659, 13916, 14015, 14579, 14581, 14618, 14721, 14722, 14726, 14730, 14731, 14744, 14746, 14768, 14824, 14933, 15042, 15123, 15146, 15276, 15382, 15390, 15443, 15460, 15653, 15805, 15842, 15868, 15872, 15883, 15943, 15957, 16129, 16134, 16154, 16214, 16303, 16309, 16371, 16402, 16535, 16537, 16550, 16701, 18426, 18427, 18432, 18433, 18436, 18438, 18439, 18440, 18441, 18442, 18443, 18444, 18445, 18446, 18447, 18448, 18449, 18450, 18451, 18452, 18536, 18537, 18538, 18539, 18540, 18573, 18584, 18585, 18952, 18953, 18963, 18964, 18970, 19043, 19044, 19045, 19046, 19047, 19048, 19061, 19086, 19128, 19254], "verts_ind": [3462, 3463, 3464, 3465, 3466, 3468, 3469, 3470, 3471, 3472, 3473, 3483, 3501, 3512, 3513, 3514, 3515, 3867, 3884, 3885, 5574, 5575, 5596, 5613, 5614, 5659, 5661, 5665, 5666, 5667, 5673, 5674, 5675, 5676, 5678, 5680, 5681, 5682, 5683, 5684, 5685, 5686, 5687, 5688, 5689, 5690, 5691, 5692, 5693, 5694, 5695, 5696, 5712, 5713, 5714, 5715, 5934, 6223, 6224, 6225, 6226, 6227, 6229, 6230, 6231, 6232, 6233, 6234, 6273, 6274, 6275, 6276, 6634, 6635, 6651, 7145, 8353, 8354, 8355, 8359, 8360, 8361, 8362, 8363, 8365, 8366, 8367, 8368, 8369, 8370, 8372, 8374, 8375, 8376, 8377, 8378, 8379, 8380, 8381, 8382, 8383, 8384, 8385, 8386, 8387, 8388, 8389, 8390, 8405, 8406, 8407, 8408, 8409]}
--------------------------------------------------------------------------------
/body_utils/body_segments/thighs.json:
--------------------------------------------------------------------------------
1 | {"faces_ind": [1615, 1620, 1627, 3172, 3173, 3174, 3179, 3286, 3386, 4056, 4058, 4070, 4088, 4847, 4886, 4954, 5096, 5296, 5307, 5735, 5820, 5869, 5908, 6055, 6063, 6064, 6075, 6301, 6302, 6304, 6310, 6311, 6312, 6338, 7941, 7952, 7953, 7954, 7958, 7959, 8453, 8531, 12164, 12169, 12176, 13706, 13707, 13708, 13713, 13820, 13919, 14588, 14602, 14620, 15200, 15378, 15417, 15485, 15625, 15825, 15836, 16348, 16397, 16436, 16583, 16591, 16592, 16603, 16829, 16830, 16832, 16838, 16839, 16840, 16866, 18464, 18465, 18476, 18477, 18478, 18482, 18483, 18975, 19053], "verts_ind": [3530, 3531, 3533, 3592, 3596, 3599, 3600, 3601, 3602, 3611, 3612, 3613, 3614, 3615, 3616, 3617, 3618, 3619, 3620, 3622, 3623, 3653, 3654, 3656, 3801, 3802, 3803, 4091, 4092, 4093, 4094, 4095, 4096, 6290, 6291, 6327, 6353, 6356, 6357, 6358, 6359, 6376, 6377, 6378, 6379, 6380, 6381, 6382, 6383, 6384, 6385, 6416, 6417, 6559, 6560, 6561, 6835, 6836, 6837, 6838, 6839, 6840]}
--------------------------------------------------------------------------------
/body_utils/left_heel_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_heel_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/left_toe_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_toe_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/left_whole_foot_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_whole_foot_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/right_heel_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_heel_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/right_toe_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_toe_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/right_whole_foot_verts_id.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_whole_foot_verts_id.npy
--------------------------------------------------------------------------------
/body_utils/smplx_mano_flame_correspondences/MANO_SMPLX_vertex_ids.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/smplx_mano_flame_correspondences/MANO_SMPLX_vertex_ids.pkl
--------------------------------------------------------------------------------
/body_utils/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy
--------------------------------------------------------------------------------
/body_utils/smplx_markerset.json:
--------------------------------------------------------------------------------
1 | {"markersets": [{"distance_from_skin": 0.0095, "indices": {"C7": 4391, "CLAV": 5533, "LANK": 5761, "LBSH": 4509, "LBWT": 5678, "LELB": 4245, "LFRM": 4379, "LFSH": 4515, "LFWT": 5726, "LHEL": 8852, "LIEL": 4258, "LKNE": 3638, "LKNI": 3781, "LSHN": 3705, "LTHI": 3479, "LUPA": 4039, "MBWT": 4297, "MFWT": 5615, "RANK": 8455, "RBSH": 7179, "RBWT": 7145, "RELB": 7028, "RFRM": 7115, "RFSH": 7251, "RFWT": 8421, "RHEL": 8634, "RIEL": 7036, "RKNE": 6401, "RKNI": 6539, "RSHN": 6466, "RTHI": 6352, "RUPA": 6778, "STRN": 5532, "ROWR": 7293, "RIWR": 7274, "LOWR": 4557, "LIWR": 4538, "T10": 5944}, "type": "body"}, {"distance_from_skin": 0.0095, "indices": {"LMT1": 5893, "LMT5": 5899, "LTOE": 5857, "RMT1": 8587, "RMT5": 8593, "RTOE": 8551}, "type": "foot"}, {"distance_from_skin": 0.0095, "indices": {"RFHD": 3035, "LFHD": 2148, "LBHD": 2041, "RBHD": 3076, "ARIEL": 9002}, "type": "head"}, {"distance_from_skin": 0.0002, "indices": {"LIDX1": 4875, "LIDX2": 4897, "LIDX3": 4931, "LMID1": 5014, "LMID2": 5020, "LMID3": 5045, "LPNK1": 5242, "LPNK2": 5250, "LPNK3": 5268, "LRNG1": 5124, "LRNG2": 5131, "LRNG3": 5149, "LTHM2": 4683, "LTHM3": 5321, "LTHM4": 5346, "RIDX1": 7611, "RIDX2": 7633, "RIDX3": 7667, "RMID1": 7750, "RMID2": 7756, "RMID3": 7781, "RPNK1": 7978, "RPNK2": 7984, "RPNK3": 8001, "RRNG1": 7860, "RRNG2": 7867, "RRNG3": 7884, "RTHM2": 7419, "RTHM3": 7602, "RTHM4": 8082}, "type": "finger"}, {"distance_from_skin": 0.0039, "indices": {"LTHM1": 4686, "RTHM1": 7423, "LIHAND": 4748, "LOHAND": 4615, "RIHAND": 7500, "ROHAND": 7351}, "type": "hand"}, {"distance_from_skin": 0.0002, "indices": {"MTH1": 2819, "MTH2": 2813, "MTH3": 8985, "MTH4": 1696, "MTH5": 1703, "MTH6": 1795, "MTH7": 8947, "MTH8": 2898, "CHN1": 8757, "CHN2": 9066}, "type": "face"}, {"distance_from_skin": 0.0002, "indices": {"REYE1": 2383, "REYE2": 2311, "LEYE1": 1043, "LEYE2": 919}, "type": "eyelids"}, {"distance_from_skin": 0, "indices": {"0": 4628, "1": 4641, "2": 4660, "3": 4690, "4": 4691, "5": 4710, "6": 4750, "7": 4885, "8": 4957, "9": 4970, "10": 5001, "11": 5012, "12": 5082, "13": 5111, "14": 5179, "15": 5193, "16": 5229, "17": 5296, "18": 5306, "19": 5315, "20": 5353, "21": 5387, "22": 7357, "23": 7396, "24": 7443, "25": 7446, "26": 7536, "27": 7589, "28": 7618, "29": 7625, "30": 7692, "31": 7706, "32": 7730, "33": 7748, "34": 7789, "35": 7847, "36": 7858, "37": 7924, "38": 7931, "39": 7976, "40": 8039, "41": 8050, "42": 8087, "43": 8122}, "type": "palm"}, {"distance_from_skin": 0, "indices": {"0": 4970, "1": 5082, "2": 5193, "3": 5306, "4": 5353, "5": 7706, "6": 7789, "7": 7924, "8": 8039, "9": 8087}, "type": "palm_5"}]}
--------------------------------------------------------------------------------
/images/binoculars-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-0.jpg
--------------------------------------------------------------------------------
/images/binoculars-60-first-view.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-60-first-view.jpg
--------------------------------------------------------------------------------
/images/binoculars-60.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-60.jpg
--------------------------------------------------------------------------------
/images/binoculars-movie-first-view.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-movie-first-view.gif
--------------------------------------------------------------------------------
/images/binoculars-video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-video.gif
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/teaser.png
--------------------------------------------------------------------------------
/images/two-stage-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/two-stage-pipeline.png
--------------------------------------------------------------------------------
/images/wineglass-0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-0.jpg
--------------------------------------------------------------------------------
/images/wineglass-60-first-view.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-60-first-view.jpg
--------------------------------------------------------------------------------
/images/wineglass-60.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-60.jpg
--------------------------------------------------------------------------------
/images/wineglass-movie-first-view.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-movie-first-view.gif
--------------------------------------------------------------------------------
/images/wineglass-video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-video.gif
--------------------------------------------------------------------------------
/opt_grasppose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | # import smplx
8 | import open3d as o3d
9 | import torch
10 | from smplx.lbs import batch_rodrigues
11 | from tqdm import tqdm
12 |
13 | from utils.cfg_parser import Config
14 | from utils.utils import makelogger, makepath
15 | from WholeGraspPose.models.fittingop import FittingOP
16 | from WholeGraspPose.models.objectmodel import ObjectModel
17 | from WholeGraspPose.trainer import Trainer
18 |
19 |
20 | #### inference
21 | def load_object_data_random(object_name, n_samples):
22 | mesh_base = './dataset/contact_meshes'
23 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, object_name + '.ply'))
24 | obj_mesh_base.compute_vertex_normals()
25 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1)
26 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1)
27 | obj_model = ObjectModel(v_temp, normal_temp, n_samples)
28 |
29 | """Prepare transl/global_orient data"""
30 | """Example: randomly sample object height and orientation"""
31 | transf_transl_list = torch.rand(n_samples) + 0.6 #### can be customized
32 | global_orient_list = (np.pi)*torch.rand(n_samples) - np.pi/2 #### can be customized
33 | transl = torch.zeros(n_samples, 3) # for object model which is centered at object
34 | transf_transl = torch.zeros(n_samples, 3)
35 | transf_transl[:, -1] = transf_transl_list
36 | global_orient = torch.zeros(n_samples, 3)
37 | global_orient[:, -1] = global_orient_list
38 | global_orient_rotmat = batch_rodrigues(global_orient.view(-1, 3)).to(grabpose.device) # [N, 3, 3]
39 |
40 | object_output = obj_model(global_orient_rotmat, transl.to(grabpose.device), v_temp.to(grabpose.device), normal_temp.to(grabpose.device), rotmat=True)
41 | object_verts = object_output[0].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[0].detach().cpu().numpy()
42 | object_normal = object_output[1].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[1].detach().cpu().numpy()
43 |
44 | index = np.linspace(0, object_verts.shape[1], num=2048, endpoint=False,retstep=True,dtype=int)[0]
45 |
46 | verts_object = object_verts[:, index]
47 | normal_object = object_normal[:, index]
48 | global_orient_rotmat_6d = global_orient_rotmat.view(-1, 1, 9)[:, :, :6].detach().cpu().numpy()
49 | feat_object = np.concatenate([normal_object, global_orient_rotmat_6d.repeat(2048, axis=1)], axis=-1)
50 |
51 | verts_object = torch.from_numpy(verts_object).to(grabpose.device)
52 | feat_object = torch.from_numpy(feat_object).to(grabpose.device)
53 | transf_transl = transf_transl.to(grabpose.device)
54 | return {'verts_object':verts_object, 'normal_object': normal_object, 'global_orient':global_orient, 'global_orient_rotmat':global_orient_rotmat, 'feat_object':feat_object, 'transf_transl':transf_transl}
55 |
56 |
57 | def load_object_data_uniform_sample(object_name, n_samples):
58 | mesh_base = './dataset/contact_meshes'
59 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, object_name + '.ply'))
60 | obj_mesh_base.compute_vertex_normals()
61 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1)
62 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1)
63 | obj_model = ObjectModel(v_temp, normal_temp, n_samples)
64 |
65 | """Prepare transl/global_orient data"""
66 | """Example: uniformly sample object height and orientation (can be customized)"""
67 | transf_transl_list = torch.arange(n_samples)*1.0/(n_samples-1) + 0.5
68 | global_orient_list = (2*np.pi)*torch.arange(n_samples)/n_samples
69 | n_samples = transf_transl_list.shape[0] * global_orient_list.shape[0]
70 | transl = torch.zeros(n_samples, 3) # for object model which is centered at object
71 | transf_transl = torch.zeros(n_samples, 3)
72 | transf_transl[:, -1] = transf_transl_list.repeat_interleave(global_orient_list.shape[0])
73 | global_orient = torch.zeros(n_samples, 3)
74 | global_orient[:, -1] = global_orient_list.repeat(transf_transl_list.shape[0]) # [6+6+6.....]
75 | global_orient_rotmat = batch_rodrigues(global_orient.view(-1, 3)).to(grabpose.device) # [N, 3, 3]
76 |
77 | object_output = obj_model(global_orient_rotmat, transl.to(grabpose.device), v_temp.to(grabpose.device), normal_temp.to(grabpose.device), rotmat=True)
78 | object_verts = object_output[0].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[0].detach().cpu().numpy()
79 | object_normal = object_output[1].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[1].detach().cpu().numpy()
80 |
81 | index = np.linspace(0, object_verts.shape[1], num=2048, endpoint=False,retstep=True,dtype=int)[0]
82 |
83 | verts_object = object_verts[:, index]
84 | normal_object = object_normal[:, index]
85 | global_orient_rotmat_6d = global_orient_rotmat.view(-1, 1, 9)[:, :, :6].detach().cpu().numpy()
86 | feat_object = np.concatenate([normal_object, global_orient_rotmat_6d.repeat(2048, axis=1)], axis=-1)
87 |
88 | verts_object = torch.from_numpy(verts_object).to(grabpose.device)
89 | feat_object = torch.from_numpy(feat_object).to(grabpose.device)
90 | transf_transl = transf_transl.to(grabpose.device)
91 | return {'verts_object':verts_object, 'normal_object': normal_object, 'global_orient':global_orient, 'global_orient_rotmat':global_orient_rotmat, 'feat_object':feat_object, 'transf_transl':transf_transl}
92 |
93 | def inference(grabpose, obj, n_samples, n_rand_samples, object_type, save_dir):
94 | """ prepare test object data: [verts_object, feat_object(normal + rotmat), transf_transl] """
95 | ### object centered
96 | # for obj in grabpose.cfg.object_class:
97 | if object_type == 'uniform':
98 | obj_data = load_object_data_uniform_sample(obj, n_samples)
99 | elif object_type == 'random':
100 | obj_data = load_object_data_random(obj, n_samples)
101 | obj_data['feat_object'] = obj_data['feat_object'].permute(0,2,1)
102 | obj_data['verts_object'] = obj_data['verts_object'].permute(0,2,1)
103 |
104 | n_samples_total = obj_data['feat_object'].shape[0]
105 |
106 | markers_gen = []
107 | object_contact_gen = []
108 | markers_contact_gen = []
109 | for i in range(n_samples_total):
110 | sample_results = grabpose.full_grasp_net.sample(obj_data['verts_object'][None, i].repeat(n_rand_samples,1,1), obj_data['feat_object'][None, i].repeat(n_rand_samples,1,1), obj_data['transf_transl'][None, i].repeat(n_rand_samples,1))
111 | markers_gen.append((sample_results[0]+obj_data['transf_transl'][None, i]))
112 | markers_contact_gen.append(sample_results[1])
113 | object_contact_gen.append(sample_results[2])
114 |
115 | markers_gen = torch.cat(markers_gen, dim=0) # [B, N, 3]
116 | object_contact_gen = torch.cat(object_contact_gen, dim=0).squeeze() # [B, 2048]
117 | markers_contact_gen = torch.cat(markers_contact_gen, dim=0) # [B, N]
118 |
119 | output = {}
120 | output['markers_gen'] = markers_gen.detach().cpu().numpy()
121 | output['markers_contact_gen'] = markers_contact_gen.detach().cpu().numpy()
122 | output['object_contact_gen'] = object_contact_gen.detach().cpu().numpy()
123 | output['normal_object'] = obj_data['normal_object']#.repeat(n_rand_samples, axis=0)
124 | output['transf_transl'] = obj_data['transf_transl'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0)
125 | output['global_orient_object'] = obj_data['global_orient'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0)
126 | output['global_orient_object_rotmat'] = obj_data['global_orient_rotmat'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0)
127 | output['verts_object'] = (obj_data['verts_object']+obj_data['transf_transl'].view(-1,3,1).repeat(1,1,2048)).permute(0, 2, 1).detach().cpu().numpy()#.repeat(n_rand_samples, axis=0)
128 |
129 | save_path = os.path.join(save_dir, 'markers_results.npy')
130 | np.save(save_path, output)
131 | print('Saving to {}'.format(save_path))
132 |
133 | return output
134 |
135 | def fitting_data_save(save_data,
136 | markers,
137 | markers_fit,
138 | smplxparams,
139 | gender,
140 | object_contact, body_contact,
141 | object_name, verts_object, global_orient_object, transf_transl_object):
142 | # markers & markers_fit
143 | save_data['markers'].append(markers)
144 | save_data['markers_fit'].append(markers_fit)
145 | # print('markers:', markers.shape)
146 |
147 | # body params
148 | for key in save_data['body'].keys():
149 | # print(key, smplxparams[key].shape)
150 | save_data['body'][key].append(smplxparams[key].detach().cpu().numpy())
151 | # object name & object params
152 | save_data['object_name'].append(object_name)
153 | save_data['gender'].append(gender)
154 | save_data['object']['transl'].append(transf_transl_object)
155 | save_data['object']['global_orient'].append(global_orient_object)
156 | save_data['object']['verts_object'].append(verts_object)
157 |
158 | # contact
159 | save_data['contact']['body'].append(body_contact)
160 | save_data['contact']['object'].append(object_contact)
161 |
162 | #### fitting
163 |
164 | def pose_opt(grabpose, samples_results, n_random_samples, obj, gender, save_dir, logger, device):
165 | # prepare objects
166 | n_samples = len(samples_results['verts_object'])
167 | verts_object = torch.tensor(samples_results['verts_object'])[:n_samples].to(device) # (n, 2048, 3)
168 | normals_object = torch.tensor(samples_results['normal_object'])[:n_samples].to(device) # (n, 2048, 3)
169 |     global_orients_object = torch.tensor(samples_results['global_orient_object'])[:n_samples].to(device) # (n, 3)
170 |     transf_transl_object = torch.tensor(samples_results['transf_transl'])[:n_samples].to(device) # (n, 3)
171 |
172 | # prepare body markers
173 | markers_gen = torch.tensor(samples_results['markers_gen']).to(device) # (n*k, 143, 3)
174 |     object_contacts_gen = torch.tensor(samples_results['object_contact_gen']).view(markers_gen.shape[0], -1, 1).to(device) # (n*k, 2048, 1)
175 |     markers_contacts_gen = torch.tensor(samples_results['markers_contact_gen']).view(markers_gen.shape[0], -1, 1).to(device) # (n*k, 143, 1)
176 |
177 | print('Fitting {} {} samples for {}...'.format(n_samples, cfg.gender, obj.upper()))
178 |
179 | fittingconfig={ 'init_lr_h': 0.008,
180 | 'num_iter': [300,400,500],
181 | 'batch_size': 1,
182 | 'num_markers': 143,
183 | 'device': device,
184 | 'cfg': cfg,
185 | 'verbose': False,
186 | 'hand_ncomps': 24,
187 | 'only_rec': False, # True / False
188 | 'contact_loss': 'contact', # contact / prior / False
189 | 'logger': logger,
190 | 'data_type': 'markers_143',
191 | }
192 | fittingop = FittingOP(fittingconfig)
193 |
194 | save_data_gen = {}
195 | for data in [save_data_gen]:
196 | data['markers'] = []
197 | data['markers_fit'] = []
198 | data['body'] = {}
199 | for key in ['betas', 'transl', 'global_orient', 'body_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose']:
200 | data['body'][key] = []
201 | data['object'] = {}
202 | for key in ['transl', 'global_orient', 'verts_object']:
203 | data['object'][key] = []
204 | data['contact'] = {}
205 | for key in ['body', 'object']:
206 | data['contact'][key] = []
207 | data['gender'] = []
208 | data['object_name'] = []
209 |
210 |
211 | for i in tqdm(range(n_samples)):
212 | # prepare object
213 | vert_object = verts_object[None, i, :, :]
214 | normal_object = normals_object[None, i, :, :]
215 |
216 | marker_gen = markers_gen[i*n_random_samples:(i+1)*n_random_samples, :, :]
217 | object_contact_gen = object_contacts_gen[i*n_random_samples:(i+1)*n_random_samples, :, :]
218 | markers_contact_gen = markers_contacts_gen[i*n_random_samples:(i+1)*n_random_samples, :, :]
219 |
220 | for k in range(n_random_samples):
221 | print('Fitting for {}-th GEN...'.format(k+1))
222 | markers_fit_gen, smplxparams_gen, loss_gen = fittingop.fitting(marker_gen[None, k, :], object_contact_gen[None, k, :], markers_contact_gen[None, k], vert_object, normal_object, gender)
223 | fitting_data_save(save_data_gen,
224 | marker_gen[k, :].detach().cpu().numpy().reshape(1, -1 ,3),
225 | markers_fit_gen[-1].squeeze().reshape(1, -1 ,3),
226 | smplxparams_gen[-1],
227 | gender,
228 | object_contact_gen[k].detach().cpu().numpy().reshape(1, -1), markers_contact_gen[k].detach().cpu().numpy().reshape(1, -1),
229 | obj, vert_object.detach().cpu().numpy(), global_orients_object[i].detach().cpu().numpy(), transf_transl_object[i].detach().cpu().numpy())
230 |
231 |
232 | for data in [save_data_gen]:
233 | # for data in [save_data_gt, save_data_rec, save_data_gen]:
234 | data['markers'] = np.vstack(data['markers'])
235 | data['markers_fit'] = np.vstack(data['markers_fit'])
236 | for key in ['betas', 'transl', 'global_orient', 'body_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose']:
237 | data['body'][key] = np.vstack(data['body'][key])
238 | for key in ['transl', 'global_orient', 'verts_object']:
239 | data['object'][key] = np.vstack(data['object'][key])
240 | for key in ['body', 'object']:
241 | data['contact'][key] = np.vstack(data['contact'][key])
242 |
243 | np.savez(os.path.join(save_dir, 'fitting_results.npz'), **save_data_gen)
244 |
245 | if __name__ == '__main__':
246 |
247 | parser = argparse.ArgumentParser(description='grabpose-Testing')
248 |
249 | parser.add_argument('--data_path', default = './dataset/GraspPose', type=str,
250 | help='The path to the folder that contains grabpose data')
251 |
252 | parser.add_argument('--object', default = None, type=str,
253 | help='object name')
254 |
255 | parser.add_argument('--gender', default=None, type=str,
256 | help='The gender of dataset')
257 |
258 | parser.add_argument('--config_path', default = None, type=str,
259 |                         help='The path to the configuration of the trained grabpose model')
260 |
261 | parser.add_argument('--exp_name', default = None, type=str,
262 | help='experiment name')
263 |
264 | parser.add_argument('--pose_ckpt_path', default = None, type=str,
265 | help='checkpoint path')
266 |
267 | parser.add_argument('--n_object_samples', default = 5, type=int,
268 | help='The number of object samples of this object')
269 |
270 | parser.add_argument('--type_object_samples', default = 'random', type=str,
271 |                         help='For the given object mesh, two modes are provided for sampling object height and orientation: random / uniform')
272 |
273 | parser.add_argument('--n_rand_samples_per_object', default = 1, type=int,
274 |                         help='The number of random whole-body pose samples generated per object')
275 |
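   |     # Example invocation (hypothetical experiment name and checkpoint path, shown only
   |     # for illustration; all flags are defined above):
   |     #   python opt_grasppose.py --exp_name pretrained_male --gender male \
   |     #          --object binoculars --pose_ckpt_path <path/to/checkpoint.pt> \
   |     #          --n_object_samples 5 --n_rand_samples_per_object 1 \
   |     #          --type_object_samples random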
276 | args = parser.parse_args()
277 |
278 | cwd = os.getcwd()
279 |
280 | best_net = os.path.join(cwd, args.pose_ckpt_path)
281 |
282 | vpe_path = '/configs/verts_per_edge.npy'
283 | c_weights_path = cwd + '/WholeGraspPose/configs/rhand_weight.npy'
284 | work_dir = cwd + '/results/{}/GraspPose'.format(args.exp_name)
285 | print(work_dir)
286 | config = {
287 | 'dataset_dir': args.data_path,
288 | 'work_dir':work_dir,
289 | 'vpe_path': vpe_path,
290 | 'c_weights_path': c_weights_path,
291 | 'exp_name': args.exp_name,
292 | 'gender': args.gender,
293 | 'best_net': best_net
294 | }
295 |
296 | cfg_path = 'WholeGraspPose/configs/WholeGraspPose.yaml'
297 | cfg = Config(default_cfg_path=cfg_path, **config)
298 |
299 | save_dir = os.path.join(work_dir, args.object)
300 | if not os.path.exists(save_dir):
301 | os.makedirs(save_dir)
302 |
303 | logger = makelogger(makepath(os.path.join(save_dir, '%s.log' % (args.object)), isfile=True)).info
304 |
305 | grabpose = Trainer(cfg=cfg, inference=True, logger=logger)
306 |
307 | samples_results = inference(grabpose, args.object, args.n_object_samples, args.n_rand_samples_per_object, args.type_object_samples, save_dir)
308 | fitting_results = pose_opt(grabpose, samples_results, args.n_rand_samples_per_object, args.object, cfg.gender, save_dir, logger, grabpose.device)
309 |
310 |
--------------------------------------------------------------------------------
/train_grasppose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | from utils.cfg_parser import Config
6 | from WholeGraspPose.trainer import Trainer
7 |
8 | if __name__ == '__main__':
9 |
10 | parser = argparse.ArgumentParser(description='GrabNet-Training')
11 |
12 | parser.add_argument('--work-dir', default='logs/GraspPose', type=str,
13 |                         help='The working directory where training logs and outputs are saved')
14 |
15 | parser.add_argument('--gender', default=None, type=str,
16 | help='The gender of dataset')
17 |
18 | parser.add_argument('--data_path', default = '/cluster/work/cvl/wuyan/data/GRAB-series/GrabPose_r_fullbody/data', type=str,
19 | help='The path to the folder that contains grabpose data')
20 |
21 | parser.add_argument('--batch-size', default=64, type=int,
22 | help='Training batch size')
23 |
24 | parser.add_argument('--n-workers', default=8, type=int,
25 | help='Number of PyTorch dataloader workers')
26 |
27 | parser.add_argument('--lr', default=5e-4, type=float,
28 | help='Training learning rate')
29 |
30 | parser.add_argument('--kl-coef', default=0.5, type=float,
31 |                         help='KL divergence coefficient for training')
32 |
33 | parser.add_argument('--use-multigpu', default=False,
34 | type=lambda arg: arg.lower() in ['true', '1'],
35 | help='If to use multiple GPUs for training')
36 |
37 | parser.add_argument('--exp_name', default = None, type=str,
38 | help='experiment name')
39 |
40 |
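   |     # Example invocation (hypothetical data path and experiment name, shown only for
   |     # illustration; all flags are defined above):
   |     #   python train_grasppose.py --data_path ./dataset/GraspPose --gender male \
   |     #          --exp_name male_training --batch-size 64 --lr 5e-4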
41 | args = parser.parse_args()
42 |
43 | work_dir = os.path.join(args.work_dir, args.exp_name)
44 |
45 | cwd = os.getcwd()
46 | default_cfg_path = 'WholeGraspPose/configs/WholeGraspPose.yaml'
47 |
48 | cfg = {
49 | 'batch_size': args.batch_size,
50 | 'n_workers': args.n_workers,
51 | 'use_multigpu': args.use_multigpu,
52 | 'kl_coef': args.kl_coef,
53 | 'dataset_dir': args.data_path,
54 | 'base_dir': cwd,
55 | 'work_dir': work_dir,
56 | 'base_lr': args.lr,
57 | 'best_net': None,
58 | 'gender': args.gender,
59 | 'exp_name': args.exp_name,
60 | }
61 |
62 | cfg = Config(default_cfg_path=default_cfg_path, **cfg)
63 | grabpose_trainer = Trainer(cfg=cfg)
64 | grabpose_trainer.fit()
65 |
66 | cfg = grabpose_trainer.cfg
67 | cfg.write_cfg(os.path.join(work_dir, 'TR%02d_%s' % (cfg.try_num, os.path.basename(default_cfg_path))))
68 |
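69 | # Example invocation (illustrative only; the data path and experiment name are placeholders):
70 | #   python train_grasppose.py --data_path <path/to/grabpose/data> --gender male \
71 | #       --exp_name male_grasppose --batch-size 64 --lr 5e-4
72 | # Checkpoints, logs and the resolved config (TR<xx>_WholeGraspPose.yaml) are written to
73 | # logs/GraspPose/<exp_name>.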
--------------------------------------------------------------------------------
/utils/Pivots.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 |
6 | sys.path.append(os.getcwd())
7 | from utils.Quaternions import Quaternions
8 |
9 |
10 | class Pivots:
11 | """
12 | Pivots is an ndarray of angular rotations
13 |
14 | This wrapper provides some functions for
15 | working with pivots.
16 |
17 | These are particularly useful as a number
18 | of atomic operations (such as adding or
19 | subtracting) cannot be achieved using
20 | the standard arithmetic and need to be
21 | defined differently to work correctly
22 | """
23 |
24 | def __init__(self, ps): self.ps = np.array(ps)
25 | def __str__(self): return "Pivots("+ str(self.ps) + ")"
26 | def __repr__(self): return "Pivots("+ repr(self.ps) + ")"
27 |
28 | def __add__(self, other): return Pivots(np.arctan2(np.sin(self.ps + other.ps), np.cos(self.ps + other.ps)))
29 | def __sub__(self, other): return Pivots(np.arctan2(np.sin(self.ps - other.ps), np.cos(self.ps - other.ps)))
30 | def __mul__(self, other): return Pivots(self.ps * other.ps)
31 | def __div__(self, other): return Pivots(self.ps / other.ps)
32 | def __mod__(self, other): return Pivots(self.ps % other.ps)
33 | def __pow__(self, other): return Pivots(self.ps ** other.ps)
34 |
35 | def __lt__(self, other): return self.ps < other.ps
36 | def __le__(self, other): return self.ps <= other.ps
37 | def __eq__(self, other): return self.ps == other.ps
38 | def __ne__(self, other): return self.ps != other.ps
39 | def __ge__(self, other): return self.ps >= other.ps
40 | def __gt__(self, other): return self.ps > other.ps
41 |
42 | def __abs__(self): return Pivots(abs(self.ps))
43 | def __neg__(self): return Pivots(-self.ps)
44 |
45 | def __iter__(self): return iter(self.ps)
46 | def __len__(self): return len(self.ps)
47 |
48 | def __getitem__(self, k): return Pivots(self.ps[k])
49 | def __setitem__(self, k, v): self.ps[k] = v.ps
50 |
51 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape))
52 |
53 | def quaternions(self, plane='xz'):
54 | fa = self._ellipsis()
55 | axises = np.ones(self.ps.shape + (3,))
56 | axises[fa + ("xyz".index(plane[0]),)] = 0.0
57 | axises[fa + ("xyz".index(plane[1]),)] = 0.0
58 | return Quaternions.from_angle_axis(self.ps, axises)
59 |
60 | def directions(self, plane='xz'):
61 | dirs = np.zeros((len(self.ps), 3))
62 | dirs["xyz".index(plane[0])] = np.sin(self.ps)
63 | dirs["xyz".index(plane[1])] = np.cos(self.ps)
64 | return dirs
65 |
66 | def normalized(self):
67 | xs = np.copy(self.ps)
68 | while np.any(xs > np.pi): xs[xs > np.pi] = xs[xs > np.pi] - 2 * np.pi
69 | while np.any(xs < -np.pi): xs[xs < -np.pi] = xs[xs < -np.pi] + 2 * np.pi
70 | return Pivots(xs)
71 |
72 | def interpolate(self, ws):
73 | dir = np.average(self.directions(), weights=ws, axis=0)
74 | return np.arctan2(dir[2], dir[0])
75 |
76 | def copy(self):
77 | return Pivots(np.copy(self.ps))
78 |
79 | @property
80 | def shape(self):
81 | return self.ps.shape
82 |
83 | @classmethod
84 | def from_quaternions(cls, qs, forward='z', plane='xz'):
85 | ds = np.zeros(qs.shape + (3,))
86 | ds[...,'xyz'.index(forward)] = 1.0
87 | return Pivots.from_directions(qs * ds, plane=plane)
88 |
89 | @classmethod
90 | def from_directions(cls, ds, plane='xz'):
91 | ys = ds[...,'xyz'.index(plane[0])]
92 | xs = ds[...,'xyz'.index(plane[1])]
93 | return Pivots(np.arctan2(ys, xs))
94 |
95 |
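96 | # Minimal usage sketch (illustrative, not part of the original module): pivot angles wrap
97 | # back into [-pi, pi] under addition/subtraction, unlike plain float arithmetic.
98 | #   >>> a, b = Pivots(np.array([3.0])), Pivots(np.array([1.0]))
99 | #   >>> (a + b).ps                                   # ~ -2.28, i.e. 4.0 wrapped by arctan2(sin, cos)
100 | #   >>> Pivots.from_directions(a.directions()).ps    # ~ 3.0, recovered from unit direction vectors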
--------------------------------------------------------------------------------
/utils/Pivots_torch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 |
6 | sys.path.append(os.getcwd())
7 | from utils.Quaternions_torch import Quaternions_torch
8 |
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 |
11 | class Pivots_torch:
12 | """
13 | Pivots is an ndarray of angular rotations
14 |
15 | This wrapper provides some functions for
16 | working with pivots.
17 |
18 | These are particularly useful as a number
19 | of atomic operations (such as adding or
20 | subtracting) cannot be achieved using
21 | the standard arithmetic and need to be
22 | defined differently to work correctly
23 | """
24 |
25 | def __init__(self, ps): self.ps = torch.tensor(ps).to(device)
26 | def __str__(self): return "Pivots("+ str(self.ps) + ")"
27 | def __repr__(self): return "Pivots("+ repr(self.ps) + ")"
28 |
29 | def __add__(self, other): return Pivots_torch(torch.atan2(torch.sin(self.ps + other.ps), torch.cos(self.ps + other.ps)))
30 | def __sub__(self, other): return Pivots_torch(torch.atan2(torch.sin(self.ps - other.ps), torch.cos(self.ps - other.ps)))
31 | def __mul__(self, other): return Pivots_torch(self.ps * other.ps)
32 | def __div__(self, other): return Pivots_torch(self.ps / other.ps)
33 | def __mod__(self, other): return Pivots_torch(self.ps % other.ps)
34 | def __pow__(self, other): return Pivots_torch(self.ps ** other.ps)
35 |
36 | def __lt__(self, other): return self.ps < other.ps
37 | def __le__(self, other): return self.ps <= other.ps
38 | def __eq__(self, other): return self.ps == other.ps
39 | def __ne__(self, other): return self.ps != other.ps
40 | def __ge__(self, other): return self.ps >= other.ps
41 | def __gt__(self, other): return self.ps > other.ps
42 |
43 | def __abs__(self): return Pivots_torch(torch.abs(self.ps))
44 | def __neg__(self): return Pivots_torch(-self.ps)
45 |
46 | def __iter__(self): return iter(self.ps)
47 | def __len__(self): return len(self.ps)
48 |
49 | def __getitem__(self, k): return Pivots_torch(self.ps[k])
50 | def __setitem__(self, k, v): self.ps[k] = v.ps
51 |
52 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape))
53 |
54 | def quaternions(self, plane='xz'):
55 | fa = self._ellipsis()
56 | axises = torch.ones(self.ps.shape + (3,)).to(device)
57 | axises[fa + ("xyz".index(plane[0]),)] = 0.0
58 | axises[fa + ("xyz".index(plane[1]),)] = 0.0
59 | return Quaternions_torch.from_angle_axis(self.ps, axises)
60 |
61 | def directions(self, plane='xz'):
62 | dirs = torch.zeros((len(self.ps), 3)).to(device)
63 | dirs["xyz".index(plane[0])] = torch.sin(self.ps)
64 | dirs["xyz".index(plane[1])] = torch.cos(self.ps)
65 | return dirs
66 |
67 | def normalized(self):
68 | xs = self.ps.clone()
69 | while torch.any(xs > torch.pi): xs[xs > torch.pi] = xs[xs > torch.pi] - 2 * torch.pi
70 | while torch.any(xs < -torch.pi): xs[xs < -torch.pi] = xs[xs < -torch.pi] + 2 * torch.pi
71 | return Pivots_torch(xs)
72 |
73 | # def interpolate(self, ws):
74 | # dir = np.average(self.directions, weights=ws, axis=0)
75 | # return torch.atan2(dir[2], dir[0])
76 |
77 | def clone(self):
78 | return Pivots_torch((self.ps).clone())
79 |
80 | @property
81 | def shape(self):
82 | return self.ps.shape
83 |
84 | @classmethod
85 | def from_quaternions(cls, qs, forward='z', plane='xz'):
86 | ds = torch.zeros(qs.shape + (3,)).to(device)
87 | ds[...,'xyz'.index(forward)] = 1.0
88 | return Pivots_torch.from_directions(qs * ds, plane=plane)
89 |
90 | @classmethod
91 | def from_directions(cls, ds, plane='xz'):
92 | ys = ds[...,'xyz'.index(plane[0])]
93 | xs = ds[...,'xyz'.index(plane[1])]
94 | return Pivots_torch(torch.atan2(ys, xs))
95 |
96 |
--------------------------------------------------------------------------------
/utils/Quaternions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Quaternions:
4 | """
5 | Quaternions is a wrapper around a numpy ndarray
6 | that allows it to act as if it were an ndarray of
7 | a quaternion data type.
8 |
9 | Therefore addition, subtraction, multiplication,
10 | division, negation, absolute, are all defined
11 | in terms of quaternion operations such as quaternion
12 | multiplication.
13 |
14 | This allows for much neater code and many routines
15 | which conceptually do the same thing to be written
16 | in the same way for point data and for rotation data.
17 |
18 | The Quaternions class has been designed such that it
19 | should support broadcasting and slicing in all of the
20 | usual ways.
21 | """
22 |
23 | def __init__(self, qs):
24 | if isinstance(qs, np.ndarray):
25 |
26 | if len(qs.shape) == 1: qs = np.array([qs])
27 | self.qs = qs
28 | return
29 |
30 | if isinstance(qs, Quaternions):
31 | self.qs = qs.qs
32 | return
33 |
34 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs))
35 |
36 | def __str__(self): return "Quaternions("+ str(self.qs) + ")"
37 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
38 |
39 | """ Helper Methods for Broadcasting and Data extraction """
40 |
41 | @classmethod
42 | def _broadcast(cls, sqs, oqs, scalar=False):
43 |
44 | if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1])
45 |
46 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1])
47 | os = np.array(oqs.shape)
48 |
49 | if len(ss) != len(os):
50 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
51 |
52 | if np.all(ss == os): return sqs, oqs
53 |
54 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))):
55 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
56 |
57 | sqsn, oqsn = sqs.copy(), oqs.copy()
58 |
59 | for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a)
60 | for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a)
61 |
62 | return sqsn, oqsn
63 |
64 | """ Adding Quaterions is just Defined as Multiplication """
65 |
66 | def __add__(self, other): return self * other
67 | def __sub__(self, other): return self / other
68 |
69 | """ Quaterion Multiplication """
70 |
71 | def __mul__(self, other):
72 | """
73 | Quaternion multiplication has three main methods.
74 |
75 | When multiplying a Quaternions array by Quaternions
76 | normal quaternion multiplication is performed.
77 |
78 | When multiplying a Quaternions array by a vector
79 | array of the same shape, where the last axis is 3,
80 | it is assumed to be a Quaternion by 3D-Vector
81 | multiplication and the 3D-Vectors are rotated
82 | in space by the Quaternions.
83 |
84 | When multiplying a Quaternions array by a scalar
85 | or vector of different shape it is assumed to be
86 | a Quaternions by Scalars multiplication and the
87 | Quaternions are scaled using Slerp and the identity
88 | quaternions.
89 | """
90 |
91 | """ If Quaternions type do Quaternions * Quaternions """
92 | if isinstance(other, Quaternions):
93 |
94 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs)
95 |
96 | q0 = sqs[...,0]; q1 = sqs[...,1];
97 | q2 = sqs[...,2]; q3 = sqs[...,3];
98 | r0 = oqs[...,0]; r1 = oqs[...,1];
99 | r2 = oqs[...,2]; r3 = oqs[...,3];
100 |
101 | qs = np.empty(sqs.shape)
102 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
103 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
104 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
105 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
106 |
107 | return Quaternions(qs)
108 |
109 | """ If array type do Quaternions * Vectors """
110 | if isinstance(other, np.ndarray) and other.shape[-1] == 3:
111 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1))
112 | return (self * (vs * -self)).imaginaries
113 |
114 | """ If float do Quaternions * Scalars """
115 | if isinstance(other, np.ndarray) or isinstance(other, float):
116 | return Quaternions.slerp(Quaternions.id_like(self), self, other)
117 |
118 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
119 |
120 | def __div__(self, other):
121 | """
122 | When a Quaternion type is supplied, division is defined
123 | as multiplication by the inverse of that Quaternion.
124 |
125 | When a scalar or vector is supplied it is defined
126 | as multiplication of one over the supplied value.
127 | Essentially a scaling.
128 | """
129 |
130 | if isinstance(other, Quaternions): return self * (-other)
131 | if isinstance(other, np.ndarray): return self * (1.0 / other)
132 | if isinstance(other, float): return self * (1.0 / other)
133 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other)))
134 |
135 | def __eq__(self, other): return self.qs == other.qs
136 | def __ne__(self, other): return self.qs != other.qs
137 |
138 | def __neg__(self):
139 | """ Invert Quaternions """
140 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]]))
141 |
142 | def __abs__(self):
143 | """ Unify Quaternions To Single Pole """
144 | qabs = self.normalized().copy()
145 | top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1)
146 | bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1)
147 | qabs.qs[top < bot] = -qabs.qs[top < bot]
148 | return qabs
149 |
150 | def __iter__(self): return iter(self.qs)
151 | def __len__(self): return len(self.qs)
152 |
153 | def __getitem__(self, k): return Quaternions(self.qs[k])
154 | def __setitem__(self, k, v): self.qs[k] = v.qs
155 |
156 | @property
157 | def lengths(self):
158 | return np.sum(self.qs**2.0, axis=-1)**0.5
159 |
160 | @property
161 | def reals(self):
162 | return self.qs[...,0]
163 |
164 | @property
165 | def imaginaries(self):
166 | return self.qs[...,1:4]
167 |
168 | @property
169 | def shape(self): return self.qs.shape[:-1]
170 |
171 | def repeat(self, n, **kwargs):
172 | return Quaternions(self.qs.repeat(n, **kwargs))
173 |
174 | def normalized(self):
175 | return Quaternions(self.qs / self.lengths[...,np.newaxis])
176 |
177 | def log(self):
178 | norm = abs(self.normalized())
179 | imgs = norm.imaginaries
180 | lens = np.sqrt(np.sum(imgs**2, axis=-1))
181 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10)
182 | return imgs * lens[...,np.newaxis]
183 |
184 | def constrained(self, axis):
185 |
186 | rl = self.reals
187 | im = np.sum(axis * self.imaginaries, axis=-1)
188 |
189 | t1 = -2 * np.arctan2(rl, im) + np.pi
190 | t2 = -2 * np.arctan2(rl, im) - np.pi
191 |
192 | top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0))
193 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0))
194 | img = self.dot(top) > self.dot(bot)
195 |
196 | ret = top.copy()
197 | ret[ img] = top[ img]
198 | ret[~img] = bot[~img]
199 | return ret
200 |
201 | def constrained_x(self): return self.constrained(np.array([1,0,0]))
202 | def constrained_y(self): return self.constrained(np.array([0,1,0]))
203 | def constrained_z(self): return self.constrained(np.array([0,0,1]))
204 |
205 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1)
206 |
207 | def copy(self): return Quaternions(np.copy(self.qs))
208 |
209 | def reshape(self, s):
210 | self.qs = self.qs.reshape(s)
211 | return self
212 |
213 | def interpolate(self, ws):
214 | return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws))
215 |
216 | def euler(self, order='xyz'):
217 |
218 | q = self.normalized().qs
219 | q0 = q[...,0]
220 | q1 = q[...,1]
221 | q2 = q[...,2]
222 | q3 = q[...,3]
223 | es = np.zeros(self.shape + (3,))
224 |
225 | if order == 'xyz':
226 | es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
227 | es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1))
228 | es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
229 | elif order == 'yzx':
230 | es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0)
231 | es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0)
232 | es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1))
233 | else:
234 | raise NotImplementedError('Cannot convert from ordering %s' % order)
235 |
236 | """
237 |
238 | # These conversion don't appear to work correctly for Maya.
239 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
240 |
241 | if order == 'xyz':
242 | es[fa + (0,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
243 | es[fa + (1,)] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1))
244 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
245 | elif order == 'yzx':
246 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
247 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1))
248 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
249 | elif order == 'zxy':
250 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
251 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1))
252 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
253 | elif order == 'xzy':
254 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
255 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1))
256 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
257 | elif order == 'yxz':
258 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3)
259 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1))
260 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
261 | elif order == 'zyx':
262 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
263 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1))
264 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
265 | else:
266 | raise KeyError('Unknown ordering %s' % order)
267 |
268 | """
269 |
270 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp
271 | # Use this class and convert from matrix
272 |
273 | return es
274 |
275 |
276 | def average(self):
277 |
278 | if len(self.shape) == 1:
279 |
280 | # numpy.core.umath_tests has been deprecated/removed in recent NumPy; np.matmul is equivalent here
281 | system = np.matmul(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0)
282 | w, v = np.linalg.eigh(system)
283 | qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1)
284 | return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))])
285 |
286 | else:
287 |
288 | raise NotImplementedError('Cannot average multi-dimensional Quaternions')
289 |
290 | def angle_axis(self):
291 |
292 | norm = self.normalized()
293 | s = np.sqrt(1 - (norm.reals**2.0))
294 | s[s == 0] = 0.001
295 |
296 | angles = 2.0 * np.arccos(norm.reals)
297 | axis = norm.imaginaries / s[...,np.newaxis]
298 |
299 | return angles, axis
300 |
301 |
302 | def transforms(self):
303 |
304 | qw = self.qs[...,0]
305 | qx = self.qs[...,1]
306 | qy = self.qs[...,2]
307 | qz = self.qs[...,3]
308 |
309 | x2 = qx + qx; y2 = qy + qy; z2 = qz + qz;
310 | xx = qx * x2; yy = qy * y2; wx = qw * x2;
311 | xy = qx * y2; yz = qy * z2; wy = qw * y2;
312 | xz = qx * z2; zz = qz * z2; wz = qw * z2;
313 |
314 | m = np.empty(self.shape + (3,3))
315 | m[...,0,0] = 1.0 - (yy + zz)
316 | m[...,0,1] = xy - wz
317 | m[...,0,2] = xz + wy
318 | m[...,1,0] = xy + wz
319 | m[...,1,1] = 1.0 - (xx + zz)
320 | m[...,1,2] = yz - wx
321 | m[...,2,0] = xz - wy
322 | m[...,2,1] = yz + wx
323 | m[...,2,2] = 1.0 - (xx + yy)
324 |
325 | return m
326 |
327 | def ravel(self):
328 | return self.qs.ravel()
329 |
330 | @classmethod
331 | def id(cls, n):
332 |
333 | if isinstance(n, tuple):
334 | qs = np.zeros(n + (4,))
335 | qs[...,0] = 1.0
336 | return Quaternions(qs)
337 |
338 | if isinstance(n, int):  # Python 3: 'long' no longer exists
339 | qs = np.zeros((n,4))
340 | qs[:,0] = 1.0
341 | return Quaternions(qs)
342 |
343 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n)))
344 |
345 | @classmethod
346 | def id_like(cls, a):
347 | qs = np.zeros(a.shape + (4,))
348 | qs[...,0] = 1.0
349 | return Quaternions(qs)
350 |
351 | @classmethod
352 | def exp(cls, ws):
353 |
354 | ts = np.sum(ws**2.0, axis=-1)**0.5
355 | ts[ts == 0] = 0.001
356 | ls = np.sin(ts) / ts
357 |
358 | qs = np.empty(ws.shape[:-1] + (4,))
359 | qs[...,0] = np.cos(ts)
360 | qs[...,1] = ws[...,0] * ls
361 | qs[...,2] = ws[...,1] * ls
362 | qs[...,3] = ws[...,2] * ls
363 |
364 | return Quaternions(qs).normalized()
365 |
366 | @classmethod
367 | def slerp(cls, q0s, q1s, a):
368 |
369 | fst, snd = cls._broadcast(q0s.qs, q1s.qs)
370 | fst, a = cls._broadcast(fst, a, scalar=True)
371 | snd, a = cls._broadcast(snd, a, scalar=True)
372 |
373 | len = np.sum(fst * snd, axis=-1)
374 |
375 | neg = len < 0.0
376 | len[neg] = -len[neg]
377 | snd[neg] = -snd[neg]
378 |
379 | amount0 = np.zeros(a.shape)
380 | amount1 = np.zeros(a.shape)
381 |
382 | linear = (1.0 - len) < 0.01
383 | omegas = np.arccos(len[~linear])
384 | sinoms = np.sin(omegas)
385 |
386 | amount0[ linear] = 1.0 - a[linear]
387 | amount1[ linear] = a[linear]
388 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms
389 | amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms
390 |
391 | return Quaternions(
392 | amount0[...,np.newaxis] * fst +
393 | amount1[...,np.newaxis] * snd)
394 |
395 | @classmethod
396 | def between(cls, v0s, v1s):
397 | a = np.cross(v0s, v1s)
398 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1)
399 | return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized()
400 |
401 | @classmethod
402 | def from_angle_axis(cls, angles, axis):
403 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis]
404 | sines = np.sin(angles / 2.0)[...,np.newaxis]
405 | cosines = np.cos(angles / 2.0)[...,np.newaxis]
406 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1))
407 |
408 | @classmethod
409 | def from_euler(cls, es, order='xyz', world=False):
410 |
411 | axis = {
412 | 'x' : np.array([1,0,0]),
413 | 'y' : np.array([0,1,0]),
414 | 'z' : np.array([0,0,1]),
415 | }
416 |
417 | q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]])
418 | q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]])
419 | q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]])
420 |
421 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s))
422 |
423 | @classmethod
424 | def from_transforms(cls, ts):
425 |
426 | d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2]
427 |
428 | q0 = ( d0 + d1 + d2 + 1.0) / 4.0
429 | q1 = ( d0 - d1 - d2 + 1.0) / 4.0
430 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0
431 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0
432 |
433 | q0 = np.sqrt(q0.clip(0,None))
434 | q1 = np.sqrt(q1.clip(0,None))
435 | q2 = np.sqrt(q2.clip(0,None))
436 | q3 = np.sqrt(q3.clip(0,None))
437 |
438 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3)
439 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3)
440 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3)
441 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2)
442 |
443 | q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2])
444 | q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0])
445 | q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1])
446 |
447 | q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2])
448 | q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1])
449 | q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0])
450 |
451 | q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0])
452 | q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1])
453 | q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2])
454 |
455 | q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1])
456 | q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2])
457 | q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2])
458 |
459 | qs = np.empty(ts.shape[:-2] + (4,))
460 | qs[...,0] = q0
461 | qs[...,1] = q1
462 | qs[...,2] = q2
463 | qs[...,3] = q3
464 |
465 | return cls(qs)
466 |
467 |
468 |
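469 | # Minimal usage sketch (illustrative, not part of the original module), covering the three
470 | # multiplication modes documented in __mul__:
471 | #   >>> q = Quaternions.from_euler(np.array([[0.0, 0.0, np.pi / 2]]))   # 90 deg about z
472 | #   >>> (q * q).euler()                      # Quaternions * Quaternions composes rotations
473 | #   >>> q * np.array([[1.0, 0.0, 0.0]])      # Quaternions * 3-vector rotates it, ~[0, 1, 0]
474 | #   >>> (q * 0.5).euler()                    # Quaternions * scalar slerps towards the identity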
--------------------------------------------------------------------------------
/utils/Quaternions_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4 |
5 | class Quaternions_torch:
6 | """
7 | Quaternions_torch is a wrapper around a torch tensor
8 | that allows it to act as if it were a tensor of
9 | a quaternion data type.
10 |
11 | Therefore addition, subtraction, multiplication,
12 | division, negation, absolute, are all defined
13 | in terms of quaternion operations such as quaternion
14 | multiplication.
15 |
16 | This allows for much neater code and many routines
17 | which conceptually do the same thing to be written
18 | in the same way for point data and for rotation data.
19 |
20 | The Quaternions_torch class has been designed such that it
21 | should support broadcasting and slicing in all of the
22 | usual ways.
23 | """
24 |
25 | def __init__(self, qs):
26 | if isinstance(qs, torch.Tensor):
27 |
28 | if len(qs.shape) == 1: qs = qs.unsqueeze(0)
29 | self.qs = qs
30 | return
31 |
32 | if isinstance(qs, Quaternions_torch):
33 | self.qs = qs.qs
34 | return
35 |
36 | raise TypeError('Quaternions_torch must be constructed from a torch tensor or Quaternions_torch, not %s' % type(qs))
37 |
38 | def __str__(self): return "Quaternions("+ str(self.qs) + ")"
39 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")"
40 |
41 | """ Helper Methods for Broadcasting and Data extraction """
42 |
43 | @classmethod
44 | def _broadcast(cls, sqs, oqs, scalar=False):
45 |
46 | if isinstance(oqs, float): return sqs, oqs * torch.ones(sqs.shape[:-1])
47 |
48 | ss = torch.tensor(sqs.shape).to(device) if not scalar else torch.tensor(sqs.shape[:-1]).to(device)
49 | os = torch.tensor(oqs.shape).to(device)
50 |
51 | if len(ss) != len(os):
52 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
53 |
54 | if torch.all(ss == os): return sqs, oqs # TODO: check torch.all
55 | # ipdb.set_trace()
56 | if not torch.all((ss.to(device) == os.to(device)) | (os == torch.ones(len(os)).to(device)) | (ss == torch.ones(len(ss)).to(device))):
57 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape))
58 |
59 | sqsn, oqsn = sqs.clone(), oqs.clone()
60 |
61 | for a in torch.where(ss == 1)[0]: sqsn = sqsn.repeat_interleave(os[a], dim=a)
62 | for a in torch.where(os == 1)[0]: oqsn = oqsn.repeat_interleave(ss[a], dim=a)
63 |
64 | return sqsn, oqsn
65 |
66 | """ Adding Quaterions is just Defined as Multiplication """
67 |
68 | def __add__(self, other): return self * other
69 | def __sub__(self, other): return self / other
70 |
71 | """ Quaterion Multiplication """
72 |
73 | def __mul__(self, other):
74 | """
75 | Quaternion multiplication has three main methods.
76 |
77 | When multiplying a Quaternions array by Quaternions
78 | normal quaternion multiplication is performed.
79 |
80 | When multiplying a Quaternions array by a vector
81 | array of the same shape, where the last axis is 3,
82 | it is assumed to be a Quaternion by 3D-Vector
83 | multiplication and the 3D-Vectors are rotated
84 | in space by the Quaternions.
85 |
86 | When multiplying a Quaternions array by a scalar
87 | or vector of different shape it is assumed to be
88 | a Quaternions by Scalars multiplication and the
89 | Quaternions are scaled using Slerp and the identity
90 | quaternions.
91 | """
92 |
93 | """ If Quaternions type do Quaternions * Quaternions """
94 | if isinstance(other, Quaternions_torch):
95 |
96 | sqs, oqs = Quaternions_torch._broadcast(self.qs, other.qs)
97 |
98 | q0 = sqs[...,0]; q1 = sqs[...,1];
99 | q2 = sqs[...,2]; q3 = sqs[...,3];
100 | r0 = oqs[...,0]; r1 = oqs[...,1];
101 | r2 = oqs[...,2]; r3 = oqs[...,3];
102 |
103 | qs = torch.empty(sqs.shape).to(device)
104 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3
105 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2
106 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1
107 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0
108 |
109 | return Quaternions_torch(qs)
110 |
111 | """ If array type do Quaternions * Vectors """
112 | if isinstance(other, torch.Tensor) and other.shape[-1] == 3:
113 | vs = Quaternions_torch(torch.cat([torch.zeros(other.shape[:-1] + (1,)).to(device), other], dim=-1))
114 | # ipdb.set_trace()
115 | return (self * (vs * -self)).imaginaries
116 |
117 | """ If float do Quaternions * Scalars """
118 | if isinstance(other,torch.Tensor) or isinstance(other, float):
119 | return Quaternions_torch.slerp(Quaternions_torch.id_like(self), self, other)
120 |
121 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other)))
122 |
123 | def __div__(self, other):
124 | """
125 | When a Quaternion type is supplied, division is defined
126 | as multiplication by the inverse of that Quaternion.
127 |
128 | When a scalar or vector is supplied it is defined
129 | as multiplication of one over the supplied value.
130 | Essentially a scaling.
131 | """
132 |
133 | if isinstance(other, Quaternions_torch): return self * (-other)
134 | if isinstance(other, torch.Tensor): return self * (1.0 / other)
135 | if isinstance(other, float): return self * (1.0 / other)
136 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other)))
137 |
138 | def __eq__(self, other): return self.qs == other.qs
139 | def __ne__(self, other): return self.qs != other.qs
140 |
141 | def __neg__(self):
142 | """ Invert Quaternions """
143 | return Quaternions_torch(self.qs * torch.tensor([[1, -1, -1, -1]]).to(device))
144 |
145 | def __abs__(self):
146 | """ Unify Quaternions To Single Pole """
147 | qabs = self.normalized().copy()
148 | top = torch.sum(( qabs.qs) * torch.tensor([1,0,0,0]).to(device), dim=-1)
149 | bot = torch.sum((-qabs.qs) * torch.tensor([1,0,0,0]).to(device), dim=-1)
150 | qabs.qs[top < bot] = -qabs.qs[top < bot]
151 | return qabs
152 |
153 | def __iter__(self): return iter(self.qs)
154 | def __len__(self): return len(self.qs)
155 |
156 | def __getitem__(self, k): return Quaternions_torch(self.qs[k])
157 | def __setitem__(self, k, v): self.qs[k] = v.qs
158 |
159 | @property
160 | def lengths(self):
161 | return torch.sum(self.qs**2.0, axis=-1)**0.5
162 |
163 | @property
164 | def reals(self):
165 | return self.qs[...,0]
166 |
167 | @property
168 | def imaginaries(self):
169 | return self.qs[...,1:4]
170 |
171 | @property
172 | def shape(self): return self.qs.shape[:-1]
173 |
174 | def repeat(self, n, **kwargs):
175 | return Quaternions_torch(self.qs.repeat(n, **kwargs))
176 |
177 | def normalized(self):
178 | return Quaternions_torch(self.qs / self.lengths.unsqueeze(-1))
179 |
180 | def unsqueeze(self, dim):
181 | return Quaternions_torch(self.qs.unsqueeze(dim))
182 |
183 | def log(self):
184 | norm = torch.abs(self.normalized())
185 | imgs = norm.imaginaries
186 | lens = torch.sqrt(torch.sum(imgs**2, dim=-1))
187 | lens = torch.arctan2(lens, norm.reals) / (lens + 1e-10)
188 | return imgs * lens.unsqueeze(-1)
189 |
190 | def constrained(self, axis):
191 |
192 | rl = self.reals
193 | im = torch.sum(axis * self.imaginaries, dim=-1)
194 |
195 | t1 = -2 * torch.arctan2(rl, im) + torch.pi
196 | t2 = -2 * torch.arctan2(rl, im) - torch.pi
197 |
198 | top = Quaternions_torch.exp(axis.unsqueeze(-1) * (t1.unsqueeze(-1) / 2.0))
199 | bot = Quaternions_torch.exp(axis.unsqueeze(-1) * (t2.unsqueeze(-1) / 2.0))
200 | img = self.dot(top) > self.dot(bot)
201 |
202 | ret = top.detach().clone() #.copy()
203 | ret[ img] = top[ img]
204 | ret[~img] = bot[~img]
205 | return ret
206 |
207 | def constrained_x(self): return self.constrained(torch.tensor([1,0,0]).to(device))
208 | def constrained_y(self): return self.constrained(torch.tensor([0,1,0]).to(device))
209 | def constrained_z(self): return self.constrained(torch.tensor([0,0,1]).to(device))
210 |
211 | def dot(self, q): return torch.sum(self.qs * q.qs, axis=-1)
212 |
213 | def copy(self): return Quaternions_torch(self.qs.detach().clone())
214 |
215 | def reshape(self, s):
216 | self.qs = self.qs.reshape(s)
217 | return self
218 |
219 | @classmethod
220 | def between(cls, v0s, v1s):
221 | a = torch.cross(v0s, v1s)
222 | w = torch.sqrt((v0s**2).sum(dim=-1) * (v1s**2).sum(dim=-1)) + (v0s * v1s).sum(dim=-1)
223 | return Quaternions_torch(torch.cat([w.unsqueeze(-1), a], dim=-1)).normalized()
224 |
225 |
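226 | # Minimal usage sketch (illustrative, not part of the original module): between() builds the
227 | # rotation taking one batch of vectors onto another, with all intermediates kept on `device`.
228 | #   >>> v0 = torch.tensor([[1.0, 0.0, 0.0]]).to(device)
229 | #   >>> v1 = torch.tensor([[0.0, 1.0, 0.0]]).to(device)
230 | #   >>> q = Quaternions_torch.between(v0, v1)   # ~[0.707, 0, 0, 0.707], i.e. 90 deg about z
231 | #   >>> q * v0                                  # rotates v0 onto v1, ~[0, 1, 0]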
--------------------------------------------------------------------------------
/utils/__pycache__/Pivots.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Pivots.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/Pivots_torch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Pivots_torch.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/Quaternions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Quaternions.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/Quaternions_torch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Quaternions_torch.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/cfg_parser.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/cfg_parser.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/cfg_parser.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/cfg_parser.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/train_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/train_tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_tools.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/train_tools.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_tools.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_body.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils_body.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/cfg_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 |
6 | class Config(dict):
7 |
8 | def __init__(self,default_cfg_path=None,**kwargs):
9 |
10 | default_cfg = {}
11 | if default_cfg_path is not None and os.path.exists(default_cfg_path):
12 | default_cfg = self.load_cfg(default_cfg_path)
13 |
14 | super(Config,self).__init__(**kwargs)
15 |
16 | default_cfg.update(self)
17 | self.update(default_cfg)
18 | self.default_cfg = default_cfg
19 |
20 | def load_cfg(self,load_path):
21 | with open(load_path, 'r') as infile:
22 | cfg = yaml.safe_load(infile)
23 | return cfg if cfg is not None else {}
24 |
25 | def write_cfg(self,write_path=None):
26 |
27 | if write_path is None:
28 | write_path = 'yaml_config.yaml'
29 |
30 | dump_dict = {k:v for k,v in self.items() if k!='default_cfg'}
31 | with open(write_path, 'w') as outfile:
32 | yaml.safe_dump(dump_dict, outfile, default_flow_style=False)
33 |
34 | def __getattr__(self, key):
35 | try:
36 | return self[key]
37 | except KeyError:
38 | raise AttributeError(key)
39 |
40 | __setattr__ = dict.__setitem__
41 | __delattr__ = dict.__delitem__
42 |
43 |
44 | if __name__ == '__main__':
45 |
46 | cfg = {
47 | 'intent': 'all',
48 | 'only_contact': True,
49 | 'save_body_verts': False,
50 | 'save_object_verts': False,
51 | 'save_contact': False,
52 | }
53 |
54 | cfg = Config(**cfg)
55 | cfg.write_cfg()
56 |
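57 | # Note on precedence (illustrative sketch; assumes the YAML file defines a 'batch_size' key):
58 | # keyword arguments passed to Config override values loaded from default_cfg_path, because the
59 | # loaded defaults are updated with the kwargs before being merged back into the Config dict.
60 | #   >>> cfg = Config(default_cfg_path='WholeGraspPose/configs/WholeGraspPose.yaml', batch_size=32)
61 | #   >>> cfg.batch_size   # 32, regardless of the value stored in the YAML file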
--------------------------------------------------------------------------------
/utils/como/SSM2.json:
--------------------------------------------------------------------------------
1 | {
2 | "gender": "unknown",
3 | "markersets": [
4 | {
5 | "distance_from_skin": 0.0095,
6 | "indices": {
7 | "C7": 3832,
8 | "CLAV": 5533,
9 | "LANK": 5882,
10 | "LFWT": 3486,
11 | "LBAK": 3336,
12 | "LBCEP": 4029,
13 | "LBSH": 4137,
14 | "LBUM": 5694,
15 | "LBUST": 3228,
16 | "LCHEECK": 2081,
17 | "LELB": 4302,
18 | "LELBIN": 4363,
19 | "LFIN": 4788,
20 | "LFRM2": 4379,
21 | "LFTHI": 3504,
22 | "LFTHIIN": 3998,
23 | "LHEE": 8846,
24 | "LIWR": 4726,
25 | "LKNE": 3682,
26 | "LKNI": 3688,
27 | "LMT1": 5890,
28 | "LMT5": 5901,
29 | "LNWST": 3260,
30 | "LOWR": 4722,
31 | "LBWT": 5697,
32 | "LRSTBEEF": 5838,
33 | "LSHO": 4481,
34 | "LTHI": 4088,
35 | "LTHMB": 4839,
36 | "LTIB": 3745,
37 | "LTOE": 5787,
38 | "MBLLY": 5942,
39 | "RANK": 8576,
40 | "RFWT": 6248,
41 | "RBAK": 6127,
42 | "RBCEP": 6776,
43 | "RBSH": 7192,
44 | "RBUM": 8388,
45 | "RBUSTLO": 8157,
46 | "RCHEECK": 8786,
47 | "RELB": 7040,
48 | "RELBIN": 7099,
49 | "RFIN": 7524,
50 | "RFRM2": 7115,
51 | "RFRM2IN": 7303,
52 | "RFTHI": 6265,
53 | "RFTHIIN": 6746,
54 | "RHEE": 8634,
55 | "RKNE": 6443,
56 | "RKNI": 6449,
57 | "RMT1": 8584,
58 | "RMT5": 8595,
59 | "RNWST": 6023,
60 | "ROWR": 7458,
61 | "RBWT": 8391,
62 | "RRSTBEEF": 8532,
63 | "RSHO": 6627,
64 | "RTHI": 6832,
65 | "RTHMB": 7575,
66 | "RTIB": 6503,
67 | "RTOE": 8481,
68 | "STRN": 5531,
69 | "T8": 5487,
70 | "LFHD": 707,
71 | "LBHD": 2026,
72 | "RFHD": 2198,
73 | "RBHD": 3066
74 | },
75 | "marker_radius": 0.0095,
76 | "type": "body"
77 | }
78 | ]
79 | }
--------------------------------------------------------------------------------
/utils/como/__pycache__/como_smooth.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_smooth.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/como/__pycache__/como_smooth.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_smooth.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/como/__pycache__/como_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/como/__pycache__/como_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/como/como_smooth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from torch.autograd import Variable
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | class EncBlock(nn.Module):
12 | def __init__(self, nin, nout, downsample=True):
13 | super(EncBlock, self).__init__()
14 | self.downsample = downsample
15 |
16 | self.main = nn.Sequential(
17 | nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=1, padding=1),
18 | nn.LeakyReLU(0.2),
19 | nn.Conv2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1),
20 | nn.LeakyReLU(0.2),
21 | )
22 | self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
23 |
24 | def forward(self, input):
25 | output = self.main(input)
26 | if not self.downsample:
27 | return output
28 | else:
29 | output = self.pooling(output)
30 | return output
31 |
32 |
33 |
34 | class DecBlock(nn.Module):
35 | def __init__(self, nin, nout, upsample=True):
36 | super(DecBlock, self).__init__()
37 | self.upsample = upsample
38 | if upsample:
39 | deconv_stride = 2
40 | else:
41 | deconv_stride = 1
42 |
43 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=deconv_stride, padding=1)
44 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1)
45 | self.leaky_relu = nn.LeakyReLU(0.2)
46 |
47 | def forward(self, input, out_size):
48 | output = self.deconv1(input, output_size=out_size)
49 | output = self.leaky_relu(output)
50 | output = self.leaky_relu(self.deconv2(output))
51 | return output
52 |
53 |
54 | class DecBlock_output(nn.Module):
55 | def __init__(self, nin, nout, upsample=True):
56 | super(DecBlock_output, self).__init__()
57 | self.upsample = upsample
58 | if upsample:
59 | deconv_stride = 2
60 | else:
61 | deconv_stride = 1
62 |
63 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=deconv_stride, padding=1)
64 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1)
65 | self.leaky_relu = nn.LeakyReLU(0.2)
66 |
67 |
68 | def forward(self, input, out_size):
69 | output = self.deconv1(input, output_size=out_size)
70 | output = self.leaky_relu(output)
71 | output = self.deconv2(output)
72 | return output
73 |
74 |
75 |
76 |
77 | class Enc(nn.Module):
78 | def __init__(self, downsample=True, z_channel=64):
79 | super(Enc, self).__init__()
80 | if z_channel == 256:
81 | channel_2, channel_3 = 128, 256
82 | elif z_channel == 64:
83 | channel_2, channel_3 = 64, 64
84 | self.enc_blc1 = EncBlock(nin=1, nout=32, downsample=downsample)
85 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample)
86 | self.enc_blc3 = EncBlock(nin=64, nout=channel_2, downsample=downsample)
87 | self.enc_blc4 = EncBlock(nin=channel_2, nout=channel_3, downsample=downsample)
88 | self.enc_blc5 = EncBlock(nin=channel_3, nout=channel_3, downsample=downsample)
89 |
90 |
91 | def forward(self, input):
92 | # input: [bs, 1, d, T]
93 | # [bs, 1, 51/145/75, 120] (smplx_params, no hand, vposer/6d_rot)/(joints, no hand)
94 | x_down1 = self.enc_blc1(input) # [bs, 32, 26, 60]
95 | x_down2 = self.enc_blc2(x_down1) # [bs, 64, 13, 30]
96 | x_down3 = self.enc_blc3(x_down2) # [bs, 128, 7, 15]
97 | x_down4 = self.enc_blc4(x_down3) # [bs, 256, 4, 8]
98 | z = self.enc_blc5(x_down4) # [bs, 256, 2/5/3, 4] (smplx_params, no hand, vposer/6d_rot)/(joints, no hand)
99 | return z, input.size(), x_down1.size(), x_down2.size(), x_down3.size(), x_down4.size()
100 |
101 |
102 | class Dec(nn.Module):
103 | def __init__(self, downsample=True, z_channel=64):
104 | super(Dec, self).__init__()
105 | if z_channel == 256:
106 | channel_2, channel_3 = 128, 256
107 | elif z_channel == 64:
108 | channel_2, channel_3 = 64, 64
109 |
110 | self.dec_blc1 = DecBlock(nin=channel_3, nout=channel_3, upsample=downsample)
111 | self.dec_blc2 = DecBlock(nin=channel_3, nout=channel_2, upsample=downsample)
112 | self.dec_blc3 = DecBlock(nin=channel_2, nout=64, upsample=downsample)
113 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample)
114 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample)
115 |
116 |
117 | def forward(self, z, input_size, x_down1_size, x_down2_size, x_down3_size, x_down4_size):
118 | x_up4 = self.dec_blc1(z, x_down4_size) # [bs, 256, 4, 8]
119 | x_up3 = self.dec_blc2(x_up4, x_down3_size) # [bs, 128, 7, 15]
120 | x_up2 = self.dec_blc3(x_up3, x_down2_size) # [bs, 64, 13, 30]
121 | x_up1 = self.dec_blc4(x_up2, x_down1_size) # [bs, 32, 26, 60]
122 | output = self.dec_blc5(x_up1, input_size)
123 | return output
124 |
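125 | # Minimal shape round-trip sketch (illustrative, not part of the original module): the encoder
126 | # returns its latent plus the intermediate tensor sizes, which the decoder uses as output_size
127 | # hints so the reconstruction comes back at exactly the input resolution.
128 | #   >>> enc, dec = Enc(z_channel=64), Dec(z_channel=64)
129 | #   >>> x = torch.randn(2, 1, 75, 120)   # [bs, 1, d, T] motion "image"
130 | #   >>> z, *sizes = enc(x)
131 | #   >>> dec(z, *sizes).shape             # torch.Size([2, 1, 75, 120])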
--------------------------------------------------------------------------------
/utils/como/como_smooth_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/como_smooth_model.pkl
--------------------------------------------------------------------------------
/utils/como/my_SSM2.json:
--------------------------------------------------------------------------------
1 | {
2 | "gender": "unknown",
3 | "markersets": [
4 | {
5 | "distance_from_skin": 0.0095,
6 | "indices": {
7 | "C7": 3832,
8 | "CLAV": 5533,
9 | "LANK": 5882,
10 | "LFWT": 3486,
11 | "LBAK": 3336,
12 | "LBCEP": 4029,
13 | "LBSH": 4137,
14 | "LBUM": 5694,
15 | "LBUST": 3228,
16 | "LCHEECK": 2081,
17 | "LELB": 4302,
18 | "LELBIN": 4363,
19 | "LFIN": 4788,
20 | "LFRM2": 4379,
21 | "LFTHI": 3504,
22 | "LFTHIIN": 3998,
23 | "LHEE": 8846,
24 | "LIWR": 4726,
25 | "LKNE": 3682,
26 | "LKNI": 3688,
27 | "LMT1": 5890,
28 | "LMT5": 5901,
29 | "LNWST": 3260,
30 | "LOWR": 4722,
31 | "LBWT": 5697,
32 | "LRSTBEEF": 5838,
33 | "LSHO": 4481,
34 | "LTHI": 4088,
35 | "LTHMB": 4839,
36 | "LTIB": 3745,
37 | "LTOE": 5787,
38 | "MBLLY": 5942,
39 | "RANK": 8576,
40 | "RFWT": 6248,
41 | "RBAK": 6127,
42 | "RBCEP": 6776,
43 | "RBSH": 7192,
44 | "RBUM": 8388,
45 | "RBUSTLO": 8157,
46 | "RCHEECK": 8786,
47 | "RELB": 7040,
48 | "RELBIN": 7099,
49 | "RFIN": 7524,
50 | "RFRM2": 7115,
51 | "RFRM2IN": 7303,
52 | "RFTHI": 6265,
53 | "RFTHIIN": 6746,
54 | "RHEE": 8634,
55 | "RKNE": 6443,
56 | "RKNI": 6449,
57 | "RMT1": 8584,
58 | "RMT5": 8595,
59 | "RNWST": 6023,
60 | "ROWR": 7458,
61 | "RBWT": 8391,
62 | "RRSTBEEF": 8532,
63 | "RSHO": 6627,
64 | "RTHI": 6832,
65 | "RTHMB": 7575,
66 | "RTIB": 6503,
67 | "RTOE": 8481,
68 | "STRN": 5531,
69 | "T8": 5487,
70 | "LFHD": 707,
71 | "LBHD": 2026,
72 | "RFHD": 2198,
73 | "RBHD": 3066,
74 | "CHN1": 8757,
75 | "CHN2": 9066,
76 | "MTH3": 8985,
77 | "MTH7": 8947,
78 | "LIDX3": 4931,
79 | "LMID3": 5045,
80 | "LPNK3": 5268,
81 | "LRNG3": 5149,
82 | "LTHM4": 5346,
83 | "RIDX3": 7667,
84 | "RMID3": 7781,
85 | "RPNK3": 8001,
86 | "RRNG3": 7884,
87 | "RTHM4": 8082
88 | },
89 | "marker_radius": 0.0095,
90 | "type": "body"
91 | }
92 | ]
93 | }
--------------------------------------------------------------------------------
/utils/como/preprocess_stats_global_markers/Xmean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/preprocess_stats_global_markers/Xmean.npy
--------------------------------------------------------------------------------
/utils/como/preprocess_stats_global_markers/Xstd.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/preprocess_stats_global_markers/Xstd.npy
--------------------------------------------------------------------------------
/utils/train_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import chamfer_distance as chd
5 | import numpy as np
6 | import scipy.ndimage.filters as filters
7 | import torch
8 |
9 | from utils.Pivots import Pivots
10 | from utils.Pivots_torch import Pivots_torch
11 | from utils.Quaternions import Quaternions
12 | from utils.Quaternions_torch import Quaternions_torch
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 |
17 | def point2point_signed(
18 | x,
19 | y,
20 | x_normals=None,
21 | y_normals=None,
22 | return_vector=False,
23 | ):
24 | """
25 | signed distance between two pointclouds
26 |
27 | Args:
28 | x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
29 | with P1 points in each batch element, batch size N and feature
30 | dimension D.
31 | y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
32 | with P2 points in each batch element, batch size N and feature
33 | dimension D.
34 | x_normals: Optional FloatTensor of shape (N, P1, D).
35 | y_normals: Optional FloatTensor of shape (N, P2, D).
36 |
37 | Returns:
38 |
39 | - y2x_signed: torch.Tensor
40 | the signed distance from y to x
41 | - x2y_signed: torch.Tensor
42 | the signed distance from x to y
43 | - yidx_near: torch.Tensor
44 | the indices of x vertices closest to y
45 |
46 | """
47 |
48 |
49 | N, P1, D = x.shape
50 | P2 = y.shape[1]
51 |
52 | if y.shape[0] != N or y.shape[2] != D:
53 | raise ValueError("y does not have the correct shape.")
54 |
55 | # ch_dist = chd.ChamferDistance()
56 |
57 | x_near, y_near, xidx_near, yidx_near = chd.ChamferDistance(x,y)
58 |
59 | xidx_near_expanded = xidx_near.view(N, P1, 1).expand(N, P1, D).to(torch.long)
60 | x_near = y.gather(1, xidx_near_expanded)
61 |
62 | yidx_near_expanded = yidx_near.view(N, P2, 1).expand(N, P2, D).to(torch.long)
63 | y_near = x.gather(1, yidx_near_expanded)
64 |
65 | x2y = x - x_near # y point to x
66 | y2x = y - y_near # x point to y
67 |
68 | if x_normals is not None:
69 | y_nn = x_normals.gather(1, yidx_near_expanded)
70 | in_out = torch.bmm(y_nn.view(-1, 1, 3), y2x.view(-1, 3, 1)).view(N, -1).sign()
71 | y2x_signed = y2x.norm(dim=2) * in_out
72 |
73 | else:
74 | y2x_signed = y2x.norm(dim=2)
75 |
76 | if y_normals is not None:
77 | x_nn = y_normals.gather(1, xidx_near_expanded)
78 | in_out_x = torch.bmm(x_nn.view(-1, 1, 3), x2y.view(-1, 3, 1)).view(N, -1).sign()
79 | x2y_signed = x2y.norm(dim=2) * in_out_x
80 | else:
81 | x2y_signed = x2y.norm(dim=2)
82 |
83 | if not return_vector:
84 | return y2x_signed, x2y_signed, yidx_near, xidx_near
85 | else:
86 | return y2x_signed, x2y_signed, yidx_near, xidx_near, y2x, x2y
87 |
88 |
89 |
90 |
91 | class EarlyStopping:
92 | """Early stops the training if validation loss doesn't improve after a given patience."""
93 | def __init__(self, patience=7, verbose=False, delta=0, trace_func=None):
94 | """
95 | Args:
96 | patience (int): How long to wait after last time validation loss improved.
97 | Default: 7
98 | verbose (bool): If True, prints a message for each validation loss improvement.
99 | Default: False
100 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
101 | Default: 0
102 | trace_func (function): trace print function.
103 | Default: None
104 | """
105 | self.patience = patience
106 | self.verbose = verbose
107 | self.counter = 0
108 | self.best_score = None
109 | self.early_stop = False
110 | self.val_loss_min = np.Inf
111 | self.delta = delta
112 | self.trace_func = trace_func
113 | def __call__(self, val_loss):
114 |
115 | score = -val_loss
116 |
117 | if self.best_score is None:
118 | self.best_score = score
119 | elif score < self.best_score + self.delta:
120 | self.counter += 1
121 | if self.trace_func is not None:
122 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
123 | if self.counter >= self.patience:
124 | self.early_stop = True
125 | else:
126 | self.best_score = score
127 | self.counter = 0
128 | return self.early_stop
129 |
130 | def save_ckp(state, checkpoint_dir):
131 | f_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
132 | torch.save(state, f_path)
133 |
134 | def load_ckp(checkpoint_fpath, model, optimizer):
135 | checkpoint = torch.load(checkpoint_fpath)
136 | model.load_state_dict(checkpoint['state_dict'])
137 | optimizer.load_state_dict(checkpoint['optimizer'])
138 | return model, optimizer, checkpoint['epoch']
139 |
140 |
141 | def get_forward_joint(joint_start):
142 | """ Joint_start: [B, N, 3] in xyz """
143 | x_axis = joint_start[:, 2, :] - joint_start[:, 1, :]
144 | x_axis[:, -1] = 0
145 | x_axis = x_axis / torch.norm(x_axis, dim=-1).unsqueeze(1)
146 | z_axis = torch.tensor([0, 0, 1]).float().unsqueeze(0).repeat(len(x_axis), 1).to(device)
147 | y_axis = torch.cross(z_axis, x_axis)
148 | y_axis = y_axis / torch.norm(y_axis, dim=-1).unsqueeze(1)
149 | transf_rotmat = torch.stack([x_axis, y_axis, z_axis], dim=1)
150 | return y_axis, transf_rotmat
151 |
152 | def prepare_traj_input(joint_start, joint_end, traj_Xmean, traj_Xstd):
153 | """ Joints: [B, N, 3] in xyz """
154 | B, N, _ = joint_start.shape
155 | T = 62
156 | joint_sr_input_unnormed = torch.ones(B, 4, T) # [B, xyr, T]
157 | y_axis, transf_rotmat = get_forward_joint(joint_start)
158 | joint_start_new = joint_start.clone()
159 | joint_end_new = joint_end.clone() # to check whether original joints change or not
160 | joint_start_new = torch.matmul(joint_start - joint_start[:, 0:1], transf_rotmat)
161 | joint_end_new = torch.matmul(joint_end - joint_start[:, 0:1], transf_rotmat)
162 |
163 | # start_forward, _ = get_forward_joint(joint_start_new)
164 | start_forward = torch.tensor([0, 1, 0]).unsqueeze(0)
165 | end_forward, _ = get_forward_joint(joint_end_new)
166 |
167 | joint_sr_input_unnormed[:, :2, 0] = joint_start_new[:, 0, :2] # xy
168 | joint_sr_input_unnormed[:, :2, -2] = joint_end_new[:, 0, :2] # xy
169 | joint_sr_input_unnormed[:, 2:, 0] = start_forward[:, :2] # r
170 | joint_sr_input_unnormed[:, 2:, -2] = end_forward[:, :2] # r
171 |
172 | # normalize
173 | traj_mean = traj_Xmean.unsqueeze(2).cpu()
174 | traj_std = traj_Xstd.unsqueeze(2).cpu()
175 |
176 | # linear interpolation
177 | joint_sr_input_normed = (joint_sr_input_unnormed - traj_mean) / traj_std
178 | for t in range(joint_sr_input_normed.size(-1)):
179 | joint_sr_input_normed[:, :, t] = joint_sr_input_normed[:, :, 0] + (joint_sr_input_normed[:, :, -2] - joint_sr_input_normed[:, :, 0])*t/(joint_sr_input_normed.size(-1)-2)
180 | joint_sr_input_normed[:, -2:, t] = joint_sr_input_normed[:, -2:, t] / torch.norm(joint_sr_input_normed[:, -2:, t], dim=1).unsqueeze(1)
181 |
182 | for t in range(joint_sr_input_unnormed.size(-1)):
183 | joint_sr_input_unnormed[:, :, t] = joint_sr_input_unnormed[:, :, 0] + (joint_sr_input_unnormed[:, :, -2] - joint_sr_input_unnormed[:, :, 0])*t/(joint_sr_input_unnormed.size(-1)-2)
184 | joint_sr_input_unnormed[:, -2:, t] = joint_sr_input_unnormed[:, -2:, t] / torch.norm(joint_sr_input_unnormed[:, -2:, t], dim=1).unsqueeze(1)
185 |
186 | return joint_sr_input_normed.float().to(device), joint_sr_input_unnormed.float().to(device), transf_rotmat, joint_start_new, joint_end_new
187 |
188 | def prepare_clip_img_input(clip_img, marker_start, marker_end, joint_start, joint_end, joint_start_new, joint_end_new, transf_rotmat, traj_pred_unnormed, traj_sr_input_unnormed, traj_smoothed, markers_stats):
189 | traj_pred_unnormed = traj_pred_unnormed.detach().cpu().numpy()
190 |
191 | traj_pred_unnormed[:, :, 0] = traj_sr_input_unnormed[:, :, 0].detach().cpu().numpy()
192 | traj_pred_unnormed[:, :, -2] = traj_sr_input_unnormed[:, :, -2].detach().cpu().numpy()
193 |
194 | B, n_markers, _ = marker_start.shape
195 | _, n_joints, _ = joint_start.shape
196 | markers = torch.rand(B, 61, n_markers, 3) # [B, T, N ,3]
197 | joints = torch.rand(B, 61, n_joints, 3) # [B, T, N ,3]
198 |
199 | marker_start_new = torch.matmul(marker_start - joint_start[:, 0:1], transf_rotmat)
200 | marker_end_new = torch.matmul(marker_end - joint_start[:, 0:1], transf_rotmat)
201 |
202 | z_transl_to_floor_start = torch.min(marker_start_new[:, :, -1], dim=-1)[0]# - 0.03
203 | z_transl_to_floor_end = torch.min(marker_end_new[:, :, -1], dim=-1)[0]# - 0.03
204 |
205 | marker_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1)
206 | marker_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1)
207 | joint_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1)
208 | joint_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1)
209 |
210 | markers[:, 0] = marker_start_new
211 | markers[:, -1] = marker_end_new
212 | joints[:, 0] = joint_start_new
213 | joints[:, -1] = joint_end_new
214 |
215 | cur_body = torch.cat([joints[:, :, 0:1], markers], dim=2)
216 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # => xyz -> xzy
217 |     reference = cur_body[:, :, 0] * torch.tensor([1, 0, 1])  # world-xy (ground-plane) position of the pelvis; the height channel is zeroed (axes are in xzy order here)
218 | cur_body = torch.cat([reference.unsqueeze(2), cur_body], dim=2) # [B, T, 1(reference)+1(pelvis)+N, 3]
219 |
220 | # position to local frame
221 | cur_body[:, :, :, 0] = cur_body[:, :, :, 0] - cur_body[:, :, 0:1, 0]
222 | cur_body[:, :, :, -1] = cur_body[:, :, :, -1] - cur_body[:, :, 0:1, -1]
223 |
224 | forward = np.zeros((B, 62, 3))
225 | forward[:, :, :2] = traj_pred_unnormed[:, 2:].transpose(0, 2, 1)
226 | forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis]
227 | forward[:, :, [1, 2]] = forward[:, :, [2, 1]]
228 |
229 | if traj_smoothed:
230 | forward_saved = forward.copy()
231 | direction_filterwidth = 20
232 | forward = filters.gaussian_filter1d(forward, direction_filterwidth, axis=1, mode='nearest')
233 | traj_pred_unnormed[:, 2] = forward[:, :, 0]
234 | traj_pred_unnormed[:, 3] = forward[:, :, -1]
235 |
236 | target = np.array([[0, 0, 1]])
237 | rotation = Quaternions.between(forward, target)[:, :, np.newaxis] # [B, T, 1, 4]
238 |
239 | cur_body = rotation[:, :-1] * cur_body.detach().cpu().numpy() # [B, T, 1+1+N, xzy]
240 | cur_body[:, 1:-1] = 0
241 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # xzy => xyz
242 | cur_body = cur_body[:, :, 1:, :]
243 | cur_body = cur_body.reshape(cur_body.shape[0], cur_body.shape[1], -1) # [B, T, N*3]
244 |
245 | velocity = np.zeros((B, 3, 61))
246 | velocity[:, 0, :] = traj_pred_unnormed[:, 0, 1:] - traj_pred_unnormed[:, 0, 0:-1] # [B, 2, 61] on Joint frame
247 | velocity[:, -1, :] = traj_pred_unnormed[:, 1, 1:] - traj_pred_unnormed[:, 1, 0:-1] # [B, 2, 61] on Joint frame
248 |
249 | velocity = rotation[:, 1:] * velocity.transpose(0, 2, 1).reshape(B, 61, 1, 3)
250 | rvelocity = Pivots.from_quaternions(rotation[:, 1:] * -rotation[:, :-1]).ps # [B, T-1, 1]
251 | rot_0_pivot = Pivots.from_quaternions(rotation[:, 0]).ps
252 |
253 | global_x = velocity[:, :, 0, 0]
254 | global_y = velocity[:, :, 0, 2]
255 | contact_lbls = np.zeros((B, 61, 4))
256 |
257 | channel_local = np.concatenate([cur_body, contact_lbls], axis=-1)[:, np.newaxis, :, :] # [B, 1, T-1, d=N*3+4]
258 | T, d = channel_local.shape[-2], channel_local.shape[-1]
259 | channel_global_x = np.repeat(global_x, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
260 | channel_global_y = np.repeat(global_y, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
261 | channel_global_r = np.repeat(rvelocity, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
262 |
263 | cur_body = np.concatenate([clip_img[:, 0:1].detach().permute(0, 1, 3, 2).cpu().numpy(), channel_global_x, channel_global_y, channel_global_r], axis=1) # [B, 4, T-1, d]
264 |
265 | # cur_body[:, 0] = (cur_body[:, 0] - markers_stats['Xmean_local']) / markers_stats['Xstd_local']
266 | cur_body[:, 1:3] = (cur_body[:, 1:3] - markers_stats['Xmean_global_xy']) / markers_stats['Xstd_global_xy']
267 | cur_body[:, 3] = (cur_body[:, 3] - markers_stats['Xmean_global_r']) / markers_stats['Xstd_global_r']
268 |
269 | # mask cur_body
270 | cur_body = cur_body.transpose(0, 1, 3, 2) # [B, 4, D, T-1]
271 | mask_t_1 = [0, 60]
272 | mask_t_0 = list(set(range(60+1)) - set(mask_t_1))
273 | cur_body[:, 0, 2:, mask_t_0] = 0.
274 | cur_body[:, 0, -4:, :] = 0.
275 | # print('Mask the markers in the following frames: ', mask_t_0)
276 |
277 | return torch.from_numpy(cur_body).float().to(device), rot_0_pivot, marker_start_new, marker_end_new, traj_pred_unnormed
278 |
279 |
280 | def prepare_clip_img_input_torch(clip_img, marker_start, marker_end, joint_start, joint_end,
281 | joint_start_new, joint_end_new, transf_rotmat,
282 | traj_pred_unnormed, traj_sr_input_unnormed,
283 | traj_smoothed, markers_stats):
284 |
285 | traj_pred_unnormed[:, :, 0] = traj_sr_input_unnormed[:, :, 0]#.detach().cpu().numpy()
286 | traj_pred_unnormed[:, :, -2] = traj_sr_input_unnormed[:, :, -2]#.detach().cpu().numpy()
287 |
288 | B, n_markers, _ = marker_start.shape
289 | _, n_joints, _ = joint_start.shape
290 | markers = torch.rand(B, 61, n_markers, 3).to(device) # [B, T, N ,3]
291 | joints = torch.rand(B, 61, n_joints, 3).to(device) # [B, T, N ,3]
292 |
293 | marker_start_new = torch.matmul(marker_start - joint_start[:, 0:1], transf_rotmat)
294 | marker_end_new = torch.matmul(marker_end - joint_start[:, 0:1], transf_rotmat)
295 |
296 | z_transl_to_floor_start = torch.min(marker_start_new[:, :, -1], dim=-1)[0]# - 0.03
297 | z_transl_to_floor_end = torch.min(marker_end_new[:, :, -1], dim=-1)[0]# - 0.03
298 |
299 | marker_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1)
300 | marker_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1)
301 | joint_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1)
302 | joint_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1)
303 |
304 | markers[:, 0] = marker_start_new
305 | markers[:, -1] = marker_end_new
306 | joints[:, 0] = joint_start_new
307 | joints[:, -1] = joint_end_new
308 |
309 | cur_body = torch.cat([joints[:, :, 0:1], markers], dim=2)
310 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # => xyz -> xzy
311 |     reference = cur_body[:, :, 0] * torch.tensor([1, 0, 1]).to(device)  # world-xy (ground-plane) position of the pelvis; the height channel is zeroed (axes are in xzy order here)
312 | cur_body = torch.cat([reference.unsqueeze(2), cur_body], dim=2) # [B, T, 1(reference)+1(pelvis)+N, 3]
313 |
314 | # position to local frame
315 | cur_body[:, :, :, 0] = cur_body[:, :, :, 0] - cur_body[:, :, 0:1, 0]
316 | cur_body[:, :, :, -1] = cur_body[:, :, :, -1] - cur_body[:, :, 0:1, -1]
317 |
318 | forward = torch.zeros((B, 62, 3)).to(device)
319 | forward[:, :, :2] = traj_pred_unnormed[:, 2:].permute(0, 2, 1)
320 | forward = forward / torch.sqrt((forward ** 2).sum(dim=-1)).unsqueeze(-1)
321 | forward[:, :, [1, 2]] = forward[:, :, [2, 1]]
322 |
323 | if traj_smoothed:
324 |         # scipy's gaussian_filter1d operates on numpy arrays: smooth on CPU, then move the result back to the device
325 |         direction_filterwidth = 20
326 |         forward = torch.from_numpy(filters.gaussian_filter1d(forward.detach().cpu().numpy(), direction_filterwidth, axis=1, mode='nearest')).float().to(device)
327 |         traj_pred_unnormed[:, 2] = forward[:, :, 0]
328 |         traj_pred_unnormed[:, 3] = forward[:, :, -1]
329 |
330 | target = torch.tensor([[[0, 0, 1]]]).float().to(device).repeat(forward.size(0), forward.size(1), 1) #.repeat(len(forward), axis=0)
331 | # rotation = Quaternions.between(forward, target)[:, :, np.newaxis] # [B, T, 1, 4]
332 | rotation = Quaternions_torch.between(forward, target).unsqueeze(2)
333 |
334 | cur_body = rotation[:, :-1] * cur_body#.detach().cpu().numpy() # [B, T, 1+1+N, xzy]
335 | cur_body[:, 1:-1] = 0
336 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # xzy => xyz
337 | cur_body = cur_body[:, :, 1:, :]
338 | cur_body = cur_body.reshape(cur_body.shape[0], cur_body.shape[1], -1) # [B, T, N*3]
339 |
340 | velocity = torch.zeros((B, 3, 61)).to(device)
341 | velocity[:, 0, :] = traj_pred_unnormed[:, 0, 1:] - traj_pred_unnormed[:, 0, 0:-1] # [B, 2, 61] on Joint frame
342 | velocity[:, -1, :] = traj_pred_unnormed[:, 1, 1:] - traj_pred_unnormed[:, 1, 0:-1] # [B, 2, 61] on Joint frame
343 |
344 | velocity = rotation[:, 1:] * velocity.permute(0, 2, 1).reshape(B, 61, 1, 3)
345 | rvelocity = Pivots_torch.from_quaternions(rotation[:, 1:] * -rotation[:, :-1]).ps # [B, T-1, 1]
346 | rot_0_pivot = Pivots_torch.from_quaternions(rotation[:, 0]).ps
347 |
348 | global_x = velocity[:, :, 0, 0]
349 | global_y = velocity[:, :, 0, 2]
350 |
351 | T, d = clip_img.shape[-1], clip_img.shape[-2]
352 | channel_global_x = torch.repeat_interleave(global_x, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
353 | channel_global_y = torch.repeat_interleave(global_y, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
354 | channel_global_r = torch.repeat_interleave(rvelocity, d).reshape(-1, 1, T, d) # [B, 1, T-1, d]
355 |
356 | cur_body = torch.cat([clip_img[:, 0:1].permute(0, 1, 3, 2), channel_global_x, channel_global_y, channel_global_r], dim=1) # [B, 4, T-1, d]
357 |
358 | # cur_body[:, 0] = (cur_body[:, 0] - markers_stats['Xmean_local']) / markers_stats['Xstd_local']
359 | cur_body[:, 1:3] = (cur_body[:, 1:3] - torch.from_numpy(markers_stats['Xmean_global_xy']).float().to(device)) / torch.from_numpy(markers_stats['Xstd_global_xy']).float().to(device)
360 | cur_body[:, 3] = (cur_body[:, 3] - torch.from_numpy(markers_stats['Xmean_global_r']).float().to(device)) / torch.from_numpy(markers_stats['Xstd_global_r']).float().to(device)
361 |
362 | # mask cur_body
363 | cur_body = cur_body.permute(0, 1, 3, 2) # [B, 4, D, T-1]
364 | mask_t_1 = [0, 60]
365 | mask_t_0 = list(set(range(60+1)) - set(mask_t_1))
366 | cur_body[:, 0, 2:, mask_t_0] = 0.
367 | cur_body[:, 0, -4:, :] = 0.
368 | # print('Mask the markers in the following frames: ', mask_t_0)
369 | return cur_body, rot_0_pivot, marker_start_new, marker_end_new, traj_pred_unnormed
370 |
--------------------------------------------------------------------------------
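A minimal usage sketch (not part of the repository) for the canonicalization helper above. It assumes the file shown here is utils/train_helper.py (as the repository tree suggests), that the repository root is on sys.path, and that the module's own imports (torch, numpy, scipy) are available; the joint tensor is a random placeholder rather than real SMPL-X joints.

import torch

from utils.train_helper import get_forward_joint  # assumed import path

# mirrors the module-level device choice used inside get_forward_joint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# hypothetical batch of 2 skeletons with 25 joints each, in xyz
joint_start = torch.rand(2, 25, 3).to(device)
y_axis, transf_rotmat = get_forward_joint(joint_start)
print(y_axis.shape, transf_rotmat.shape)  # torch.Size([2, 3]) torch.Size([2, 3, 3])
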
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9 | to_cpu = lambda tensor: tensor.detach().cpu().numpy()
10 |
11 | class Struct(object):
12 | def __init__(self, **kwargs):
13 | for key, val in kwargs.items():
14 | setattr(self, key, val)
15 |
16 |
17 | def to_tensor(array, dtype=torch.float32):
18 | if not torch.is_tensor(array):
19 | array = torch.tensor(array)
20 | return array.to(dtype)
21 |
22 |
23 | def to_np(array, dtype=np.float32):
24 | if 'scipy.sparse' in str(type(array)):
25 | array = np.array(array.todense(), dtype=dtype)
26 | elif torch.is_tensor(array):
27 | array = array.detach().cpu().numpy()
28 | return array.astype(dtype)
29 |
30 |
31 | def makepath(desired_path, isfile = False):
32 | '''
33 |     If the path does not exist, create it.
34 |     :param desired_path: can be a path to a file or a folder name
35 |     :return: the desired_path (its directories are created if needed)
36 | '''
37 | import os
38 | if isfile:
39 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
40 | else:
41 | if not os.path.exists(desired_path): os.makedirs(desired_path)
42 | return desired_path
43 |
44 | def makelogger(log_dir,mode='w'):
45 |
46 |
47 | logger = logging.getLogger()
48 | logger.setLevel(logging.INFO)
49 |
50 | ch = logging.StreamHandler()
51 | ch.setLevel(logging.INFO)
52 |
53 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
54 |
55 | ch.setFormatter(formatter)
56 |
57 | logger.addHandler(ch)
58 |
59 | fh = logging.FileHandler('%s'%log_dir, mode=mode)
60 | fh.setFormatter(formatter)
61 | logger.addHandler(fh)
62 |
63 | return logger
64 |
65 | def CRot2rotmat(pose):
66 |
67 | reshaped_input = pose.view(-1, 3, 2)
68 |
69 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1)
70 |
71 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True)
72 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1)
73 | b3 = torch.cross(b1, b2, dim=1)
74 |
75 | return torch.stack([b1, b2, b3], dim=-1)
76 |
77 |
78 | def euler(rots, order='xyz', units='deg'):
79 |
80 | rots = np.asarray(rots)
81 | single_val = False if len(rots.shape)>1 else True
82 | rots = rots.reshape(-1,3)
83 | rotmats = []
84 |
85 | for xyz in rots:
86 | if units == 'deg':
87 | xyz = np.radians(xyz)
88 | r = np.eye(3)
89 | for theta, axis in zip(xyz,order):
90 | c = np.cos(theta)
91 | s = np.sin(theta)
92 | if axis=='x':
93 | r = np.dot(np.array([[1,0,0],[0,c,-s],[0,s,c]]), r)
94 | if axis=='y':
95 | r = np.dot(np.array([[c,0,s],[0,1,0],[-s,0,c]]), r)
96 | if axis=='z':
97 | r = np.dot(np.array([[c,-s,0],[s,c,0],[0,0,1]]), r)
98 | rotmats.append(r)
99 | rotmats = np.stack(rotmats).astype(np.float32)
100 | if single_val:
101 | return rotmats[0]
102 | else:
103 | return rotmats
104 |
105 | def batch_euler(bxyz,order='xyz', units='deg'):
106 |
107 | br = []
108 | for frame in range(bxyz.shape[0]):
109 | br.append(euler(bxyz[frame], order, units))
110 | return np.stack(br).astype(np.float32)
111 |
112 | def rotate(points,R):
113 | shape = points.shape
114 | if len(shape)>3:
115 | points = points.squeeze()
116 | if len(shape)<3:
117 | points = points[:,np.newaxis]
118 | r_points = torch.matmul(torch.from_numpy(points).to(device), torch.from_numpy(R).to(device).transpose(1,2))
119 | return r_points.cpu().numpy().reshape(shape)
120 |
121 | def rotmul(rotmat,R):
122 |
123 | shape = rotmat.shape
124 | rotmat = rotmat.squeeze()
125 | R = R.squeeze()
126 | rot = torch.matmul(torch.from_numpy(R).to(device),torch.from_numpy(rotmat).to(device))
127 | return rot.cpu().numpy().reshape(shape)
128 |
129 | # import torchgeometry as tgm
130 | # borrowed from the torchgeometry package
131 | def rotmat2aa(rotmat):
132 | '''
133 | :param rotmat: Nx1xnum_jointsx9
134 | :return: Nx1xnum_jointsx3
135 | '''
136 | batch_size = rotmat.size(0)
137 | homogen_matrot = F.pad(rotmat.view(-1, 3, 3), [0,1])
138 | pose = rotation_matrix_to_angle_axis(homogen_matrot).view(batch_size, 1, -1, 3).contiguous()
139 | return pose
140 |
141 | def aa2rotmat(axis_angle):
142 | '''
143 |     :param axis_angle: Nx1xnum_jointsx3
144 | :return: pose_matrot: Nx1xnum_jointsx9
145 | '''
146 | batch_size = axis_angle.size(0)
147 | pose_body_matrot = angle_axis_to_rotation_matrix(axis_angle.reshape(-1, 3))[:, :3, :3].contiguous().view(batch_size, 1, -1, 9)
148 | return pose_body_matrot
149 |
150 | def angle_axis_to_rotation_matrix(angle_axis):
151 | """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix
152 |
153 | Args:
154 | angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
155 |
156 | Returns:
157 | Tensor: tensor of 4x4 rotation matrices.
158 |
159 | Shape:
160 | - Input: :math:`(N, 3)`
161 | - Output: :math:`(N, 4, 4)`
162 |
163 | Example:
164 | >>> input = torch.rand(1, 3) # Nx3
165 | >>> output = angle_axis_to_rotation_matrix(input) # Nx4x4
166 | """
167 | def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
168 | # We want to be careful to only evaluate the square root if the
169 | # norm of the angle_axis vector is greater than zero. Otherwise
170 | # we get a division by zero.
171 | k_one = 1.0
172 | theta = torch.sqrt(theta2)
173 | wxyz = angle_axis / (theta + eps)
174 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
175 | cos_theta = torch.cos(theta)
176 | sin_theta = torch.sin(theta)
177 |
178 | r00 = cos_theta + wx * wx * (k_one - cos_theta)
179 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
180 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
181 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
182 | r11 = cos_theta + wy * wy * (k_one - cos_theta)
183 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
184 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
185 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
186 | r22 = cos_theta + wz * wz * (k_one - cos_theta)
187 | rotation_matrix = torch.cat(
188 | [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
189 | return rotation_matrix.view(-1, 3, 3)
190 |
191 | def _compute_rotation_matrix_taylor(angle_axis):
192 | rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
193 | k_one = torch.ones_like(rx)
194 | rotation_matrix = torch.cat(
195 | [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
196 | return rotation_matrix.view(-1, 3, 3)
197 |
198 | # stolen from ceres/rotation.h
199 |
200 | _angle_axis = torch.unsqueeze(angle_axis, dim=1)
201 | theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
202 | theta2 = torch.squeeze(theta2, dim=1)
203 |
204 | # compute rotation matrices
205 | rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
206 | rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)
207 |
208 | # create mask to handle both cases
209 | eps = 1e-6
210 | mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
211 | mask_pos = (mask).type_as(theta2)
212 | mask_neg = (mask == False).type_as(theta2) # noqa
213 |
214 | # create output pose matrix
215 | batch_size = angle_axis.shape[0]
216 | rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis)
217 | rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1)
218 | # fill output matrix with masked values
219 | rotation_matrix[..., :3, :3] = \
220 | mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
221 | return rotation_matrix # Nx4x4
222 |
223 | def rotation_matrix_to_angle_axis(rotation_matrix):
224 | """Convert 3x4 rotation matrix to Rodrigues vector
225 |
226 | Args:
227 | rotation_matrix (Tensor): rotation matrix.
228 |
229 | Returns:
230 | Tensor: Rodrigues vector transformation.
231 |
232 | Shape:
233 | - Input: :math:`(N, 3, 4)`
234 | - Output: :math:`(N, 3)`
235 |
236 | Example:
237 |         >>> input = torch.rand(2, 3, 4) # Nx3x4
238 | >>> output = rotation_matrix_to_angle_axis(input) # Nx3
239 | """
240 | # todo add check that matrix is a valid rotation matrix
241 | quaternion = rotation_matrix_to_quaternion(rotation_matrix)
242 | return quaternion_to_angle_axis(quaternion)
243 |
244 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
245 | """Convert 3x4 rotation matrix to 4d quaternion vector
246 |
247 | This algorithm is based on algorithm described in
248 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
249 |
250 | Args:
251 | rotation_matrix (Tensor): the rotation matrix to convert.
252 |
253 | Return:
254 | Tensor: the rotation in quaternion
255 |
256 | Shape:
257 | - Input: :math:`(N, 3, 4)`
258 | - Output: :math:`(N, 4)`
259 |
260 | Example:
261 | >>> input = torch.rand(4, 3, 4) # Nx3x4
262 | >>> output = rotation_matrix_to_quaternion(input) # Nx4
263 | """
264 | if not torch.is_tensor(rotation_matrix):
265 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
266 | type(rotation_matrix)))
267 |
268 | if len(rotation_matrix.shape) > 3:
269 | raise ValueError(
270 | "Input size must be a three dimensional tensor. Got {}".format(
271 | rotation_matrix.shape))
272 | if not rotation_matrix.shape[-2:] == (3, 4):
273 | raise ValueError(
274 | "Input size must be a N x 3 x 4 tensor. Got {}".format(
275 | rotation_matrix.shape))
276 |
277 | rmat_t = torch.transpose(rotation_matrix, 1, 2)
278 |
279 | mask_d2 = rmat_t[:, 2, 2] < eps
280 |
281 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
282 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
283 |
284 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
285 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
286 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
287 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
288 | t0_rep = t0.repeat(4, 1).t()
289 |
290 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
291 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
292 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
293 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
294 | t1_rep = t1.repeat(4, 1).t()
295 |
296 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
297 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
298 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
299 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
300 | t2_rep = t2.repeat(4, 1).t()
301 |
302 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
303 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
304 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
305 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
306 | t3_rep = t3.repeat(4, 1).t()
307 |
308 | mask_c0 = mask_d2 * mask_d0_d1
309 | mask_c1 = mask_d2 * (~mask_d0_d1)
310 | mask_c2 = (~mask_d2) * mask_d0_nd1
311 | mask_c3 = (~mask_d2) * (~mask_d0_nd1)
312 | mask_c0 = mask_c0.view(-1, 1).type_as(q0)
313 | mask_c1 = mask_c1.view(-1, 1).type_as(q1)
314 | mask_c2 = mask_c2.view(-1, 1).type_as(q2)
315 | mask_c3 = mask_c3.view(-1, 1).type_as(q3)
316 |
317 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
318 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
319 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
320 | q *= 0.5
321 | return q
322 |
323 | def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
324 | """Convert quaternion vector to angle axis of rotation.
325 |
326 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
327 |
328 | Args:
329 | quaternion (torch.Tensor): tensor with quaternions.
330 |
331 | Return:
332 | torch.Tensor: tensor with angle axis of rotation.
333 |
334 | Shape:
335 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions
336 | - Output: :math:`(*, 3)`
337 |
338 | Example:
339 | >>> quaternion = torch.rand(2, 4) # Nx4
340 | >>> angle_axis = quaternion_to_angle_axis(quaternion) # Nx3
341 | """
342 | if not torch.is_tensor(quaternion):
343 | raise TypeError("Input type is not a torch.Tensor. Got {}".format(
344 | type(quaternion)))
345 |
346 | if not quaternion.shape[-1] == 4:
347 | raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
348 | .format(quaternion.shape))
349 | # unpack input and compute conversion
350 | q1: torch.Tensor = quaternion[..., 1]
351 | q2: torch.Tensor = quaternion[..., 2]
352 | q3: torch.Tensor = quaternion[..., 3]
353 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
354 |
355 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
356 | cos_theta: torch.Tensor = quaternion[..., 0]
357 | two_theta: torch.Tensor = 2.0 * torch.where(
358 | cos_theta < 0.0,
359 | torch.atan2(-sin_theta, -cos_theta),
360 | torch.atan2(sin_theta, cos_theta))
361 |
362 | k_pos: torch.Tensor = two_theta / sin_theta
363 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
364 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
365 |
366 | angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
367 | angle_axis[..., 0] += q1 * k
368 | angle_axis[..., 1] += q2 * k
369 | angle_axis[..., 2] += q3 * k
370 | return angle_axis
371 |
372 |
373 | class RotConverter(nn.Module):
374 | '''
375 | this class is from smplx/vposer
376 | '''
377 | def __init__(self):
378 | super(RotConverter, self).__init__()
379 |
380 | def forward(self,module_input):
381 | pass
382 |
383 |
384 | @staticmethod
385 | def cont2rotmat(module_input):
386 | reshaped_input = module_input.view(-1, 3, 2)
387 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1)
388 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True)
389 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1)
390 | b3 = torch.cross(b1, b2, dim=1)
391 |
392 | return torch.stack([b1, b2, b3], dim=-1)
393 |
394 |
395 | @staticmethod
396 | def aa2cont(module_input):
397 | '''
398 |         :param module_input: NxTxnum_jointsx3
399 | :return: pose_matrot: NxTxnum_jointsx6
400 | '''
401 | batch_size = module_input.shape[0]
402 | n_frames = module_input.shape[1]
403 | pose_body_6d = angle_axis_to_rotation_matrix(module_input.reshape(-1, 3))[:, :3, :2].contiguous().view(batch_size, n_frames, -1, 6)
404 |
405 | return pose_body_6d
406 |
407 |
408 | @staticmethod
409 | def rotmat2aa(pose_matrot):
410 | '''
411 | :param pose_matrot: Nx1xnum_jointsx9
412 | :return: Nx1xnum_jointsx3
413 | '''
414 | homogen_matrot = F.pad(pose_matrot.view(-1, 3, 3), [0,1])
415 | pose = rotation_matrix_to_angle_axis(homogen_matrot).view(-1, 3).contiguous()
416 |
417 | return pose
418 |
419 |
420 | @staticmethod
421 | def aa2rotmat(pose):
422 | '''
423 |         :param pose: Nx1xnum_jointsx3
424 | :return: pose_matrot: Nx1xnum_jointsx9
425 | '''
426 | batch_size = pose.shape[0]
427 | n_frames = pose.shape[1]
428 | pose_body_matrot = angle_axis_to_rotation_matrix(pose.reshape(-1, 3))[:, :3, :3].contiguous().view(batch_size, n_frames, -1, 9)
429 |
430 | return pose_body_matrot
431 |
432 |
433 |
--------------------------------------------------------------------------------
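A small round-trip sketch (not from the repository) for the rotation conversions in utils/utils.py above. It assumes the repository root is on sys.path so that utils.utils is importable; the axis-angle poses are random placeholders.

import torch

from utils.utils import aa2rotmat, rotmat2aa  # assumed import path

# one sample, 21 hypothetical joints, small random axis-angle rotations: [N, 1, num_joints, 3]
pose_aa = 0.3 * torch.rand(1, 1, 21, 3)
pose_rotmat = aa2rotmat(pose_aa)                      # -> [1, 1, 21, 9]
pose_back = rotmat2aa(pose_rotmat)                    # -> [1, 1, 21, 3]
print(torch.allclose(pose_aa, pose_back, atol=1e-4))  # the round trip should recover the input
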
/visualization/__pycache__/visualization_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/visualization/__pycache__/visualization_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/visualization/vis_motion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 | import argparse
6 | import pickle
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import open3d as o3d
11 | import torch
12 | import torch.nn.functional as F
13 | from tqdm import tqdm
14 |
15 | from visualization_utils import (color_hex2rgb, create_lineset, get_body_mesh,
16 | get_object_mesh, update_cam)
17 |
18 |
19 | def vis_graspmotion_third_view(body_meshes, object_mesh, object_transl, sample_index):
20 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
21 | size=0.5, origin=[0, 0, 0])
22 |
23 | vis = o3d.visualization.Visualizer()
24 | vis.create_window()
25 | # vis.add_geometry(mesh_frame)
26 | render_opt=vis.get_render_option()
27 | render_opt.mesh_show_back_face=True
28 | render_opt.line_width=10
29 | render_opt.point_size=5
30 | render_opt.background_color = color_hex2rgb('#1c2434')
31 |
32 | vis.add_geometry(object_mesh)
33 |
34 | x_range = np.arange(-200, 200, 0.75)
35 | y_range = np.arange(-200, 200, 0.75)
36 | z_range = np.arange(0, 1, 1)
37 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
38 | gp_lines.paint_uniform_color(color_hex2rgb('#7ea4ab'))
39 | gp_pcd.paint_uniform_color(color_hex2rgb('#7ea4ab'))
40 | vis.add_geometry(gp_lines)
41 | vis.poll_events()
42 | vis.update_renderer()
43 | vis.add_geometry(gp_pcd)
44 | vis.poll_events()
45 | vis.update_renderer()
46 |
47 | for t in range(len(body_meshes)):
48 | vis.add_geometry(body_meshes[t])
49 |
50 | ### get cam R
51 | ### update render cam
52 | ctr = vis.get_view_control()
53 | cam_param = ctr.convert_to_pinhole_camera_parameters()
54 | trans = np.eye(4)
55 | trans[:3, :3] = np.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])
56 | trans[:3, -1] = np.array([4, 1, 1])
57 | cam_param = update_cam(cam_param, trans)
58 | ctr.convert_from_pinhole_camera_parameters(cam_param)
59 | vis.poll_events()
60 | vis.update_renderer()
61 |
62 | vis.capture_screen_image(
63 | vis_save_path_third_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True)
64 | vis.remove_geometry(body_meshes[t])
65 |
66 |
67 | def vis_graspmotion_first_view(body_meshes, object_mesh, object_transl, sample_index):
68 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
69 | size=0.5, origin=[0, 0, 0])
70 |
71 | vis = o3d.visualization.Visualizer()
72 | vis.create_window()
73 | render_opt=vis.get_render_option()
74 | render_opt.mesh_show_back_face=True
75 | render_opt.line_width=10
76 | render_opt.point_size=5
77 |
78 | render_opt.background_color = color_hex2rgb('#545454')
79 |
80 | vis.add_geometry(object_mesh)
81 |
82 | for t in range(len(body_meshes)):
83 | vis.add_geometry(body_meshes[t])
84 |
85 | ### get cam R
86 | cam_o = np.array(body_meshes[t].vertices)[8999]
87 | cam_z = object_transl - cam_o
88 | cam_z = cam_z / np.linalg.norm(cam_z)
89 | cam_x = np.array([cam_z[1], -cam_z[0], 0.0])
90 | cam_x = cam_x / np.linalg.norm(cam_x)
91 | cam_y = np.array([cam_z[0], cam_z[1], -(cam_z[0]**2 + cam_z[1]**2)/cam_z[2] ])
92 | cam_y = cam_y / np.linalg.norm(cam_y)
93 | cam_r = np.stack([cam_x, -cam_y, cam_z], axis=1)
94 | ### update render cam
95 | ctr = vis.get_view_control()
96 | cam_param = ctr.convert_to_pinhole_camera_parameters()
97 | transf = np.eye(4)
98 | transf[:3,:3]=cam_r
99 | transf[:3,-1] = cam_o
100 | cam_param = update_cam(cam_param, transf)
101 | ctr.convert_from_pinhole_camera_parameters(cam_param)
102 | vis.poll_events()
103 | vis.update_renderer()
104 |
105 | vis.capture_screen_image(
106 | vis_save_path_first_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True)
107 | vis.remove_geometry(body_meshes[t])
108 |
109 |
110 | def vis_graspmotion_top_view(body_meshes, object_mesh, object_transl, sample_index):
111 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
112 | size=0.5, origin=[0, 0, 0])
113 |
114 | vis = o3d.visualization.Visualizer()
115 | vis.create_window()
116 | render_opt=vis.get_render_option()
117 | render_opt.mesh_show_back_face=True
118 | render_opt.line_width=10
119 | render_opt.point_size=5
120 | render_opt.background_color = color_hex2rgb('#1c2434')
121 |
122 | vis.add_geometry(object_mesh)
123 |
124 | x_range = np.arange(-200, 200, 0.75)
125 | y_range = np.arange(-200, 200, 0.75)
126 | z_range = np.arange(0, 1, 1)
127 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
128 | gp_lines.paint_uniform_color(color_hex2rgb('#7ea4ab'))
129 | gp_pcd.paint_uniform_color(color_hex2rgb('#7ea4ab'))
130 | vis.add_geometry(gp_lines)
131 | vis.poll_events()
132 | vis.update_renderer()
133 | vis.add_geometry(gp_pcd)
134 | vis.poll_events()
135 | vis.update_renderer()
136 |
137 | for t in range(len(body_meshes)):
138 | vis.add_geometry(body_meshes[t])
139 |
140 | ctr = vis.get_view_control()
141 | cam_param = ctr.convert_to_pinhole_camera_parameters()
142 |
143 | cam_o = np.array([0, 0, 4])
144 | reference_point = np.zeros(3)
145 | reference_point[:2] = object_transl[:2]/2
146 | cam_z = reference_point - cam_o
147 | cam_z = cam_z / np.linalg.norm(cam_z)
148 | cam_x = np.array([cam_z[1], -cam_z[0], 0.0])
149 | cam_x = cam_x / np.linalg.norm(cam_x)
150 | cam_y = np.array([cam_z[0], cam_z[1], -(cam_z[0]**2 + cam_z[1]**2)/cam_z[2] ])
151 | cam_y = cam_y / np.linalg.norm(cam_y)
152 | cam_r = np.stack([cam_x, -cam_y, cam_z], axis=1)
153 | ### update render cam
154 | ctr = vis.get_view_control()
155 | cam_param = ctr.convert_to_pinhole_camera_parameters()
156 | transf = np.eye(4)
157 | transf[:3,:3]=cam_r
158 | transf[:3,-1] = cam_o
159 |
160 |
161 | cam_param = update_cam(cam_param, transf)
162 | ctr.convert_from_pinhole_camera_parameters(cam_param)
163 | vis.poll_events()
164 | vis.update_renderer()
165 |
166 | vis.capture_screen_image(
167 | vis_save_path_top_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True)
168 | vis.remove_geometry(body_meshes[t])
169 |
170 | if __name__ == '__main__':
171 |
172 |     parser = argparse.ArgumentParser(description='visualization of saved grasp-motion results')
173 |
174 | parser.add_argument('--GraspPose_exp_name', type=str, help='exp name')
175 |     parser.add_argument('--dataset', default='GRAB', type=str, help='dataset name')
176 |     parser.add_argument('--object', default='camera', type=str, help='object name')
177 |     parser.add_argument('--gender', type=str, help='gender')
178 |
179 | args = parser.parse_args()
180 |
181 | result_path = '../results/{}/GraspMotion/{}'.format(args.GraspPose_exp_name, args.object)
182 | opt_results = np.load('{}/fitting_results.npy'.format(result_path), allow_pickle=True)[()]
183 |
184 |     print('Saving visualization results to {}'.format(result_path))
185 |
186 | vis_save_path_first_view = os.path.join(result_path, 'visualization/first_view')
187 | vis_save_path_third_view = os.path.join(result_path, 'visualization/third_view')
188 | vis_save_path_top_view = os.path.join(result_path, 'visualization/top_view')
189 | if not os.path.exists(vis_save_path_first_view):
190 | os.makedirs(vis_save_path_first_view)
191 | if not os.path.exists(vis_save_path_third_view):
192 | os.makedirs(vis_save_path_third_view)
193 | if not os.path.exists(vis_save_path_top_view):
194 | os.makedirs(vis_save_path_top_view)
195 |
196 | object_params = opt_results['object'] # Tensor
197 | object_name = str(opt_results['object_name'])
198 | body_orig = opt_results['body_orig'] # Numpy
199 | body_opt = opt_results['body_opt']
200 |
201 | for key in body_orig:
202 | body_orig[key] = torch.from_numpy(body_orig[key])
203 | body_opt[key] = torch.from_numpy(body_opt[key])
204 |
205 | B, T, _ = body_orig['transl'].shape
206 |
207 | object_mesh = get_object_mesh(object_name, args.dataset, object_params['transl'][:B], object_params['global_orient'][:B], B, rotmat=True)
208 |
209 | # frame by frame visualization
210 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.25)
211 |
212 | x_range = np.arange(-200, 200, 0.75)
213 | y_range = np.arange(-200, 200, 0.75)
214 | z_range = np.arange(0, 1, 1)
215 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
216 |
217 | collision_eval_list = []
218 |
219 | camera_point_index = 8999
220 |
221 | pos_list, vel_list, acc_list = [], [], []
222 |
223 | for i in tqdm(range(0, B)):
224 | collision_eval_T = {}
225 | collision_eval_T['vol'] = []
226 | collision_eval_T['depth'] = []
227 | collision_eval_T['contact'] = []
228 | # get body mesh
229 | orig_smplxparams = {}
230 | opt_smplxparams = {}
231 | for k in body_orig.keys():
232 | orig_smplxparams[k] = body_orig[k][i]
233 | opt_smplxparams[k] = body_opt[k][i]
234 | body_meshes_orig, _ = get_body_mesh(orig_smplxparams, args.gender, n_samples=T, device='cpu', color='D4BEA3')
235 | body_meshes_opt, _ = get_body_mesh(opt_smplxparams, args.gender, n_samples=T, device='cpu', color='D4BEA3')
236 |
237 |
238 | body_meshes = body_meshes_opt
239 |
240 | vis_graspmotion_first_view(body_meshes, object_mesh[i], object_params['transl'][i], i)
241 | vis_graspmotion_top_view(body_meshes, object_mesh[i], object_params['transl'][i], i)
242 | vis_graspmotion_third_view(body_meshes, object_mesh[i], object_params['transl'][i], i)
243 |
--------------------------------------------------------------------------------
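A minimal sketch (not part of the repository) isolating the camera handling used by vis_graspmotion_third_view above. The rotation and translation values are copied from that function; the off-screen window and the import path are illustrative assumptions (the repository root must be on sys.path for visualization.visualization_utils to resolve).

import numpy as np
import open3d as o3d

from visualization.visualization_utils import update_cam  # assumed import path

vis = o3d.visualization.Visualizer()
vis.create_window(visible=False)  # off-screen window, only needed to obtain a view control
ctr = vis.get_view_control()
cam_param = ctr.convert_to_pinhole_camera_parameters()

# third-person viewpoint: same extrinsic as in vis_graspmotion_third_view
trans = np.eye(4)
trans[:3, :3] = np.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]])
trans[:3, -1] = np.array([4, 1, 1])

ctr.convert_from_pinhole_camera_parameters(update_cam(cam_param, trans))
vis.destroy_window()
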
/visualization/vis_pose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import numpy as np
6 | import open3d as o3d
7 |
8 | from visualization_utils import *
9 |
10 | if __name__ == '__main__':
11 |
12 | parser = argparse.ArgumentParser(description='grabpose-Testing')
13 |
14 | parser.add_argument('--exp_name', default = None, type=str,
15 | help='experiment name')
16 |
17 | parser.add_argument('--gender', default = None, type=str,
18 | help='gender')
19 |
20 | parser.add_argument('--object', default = None, type=str,
21 | help='object name')
22 |
23 | parser.add_argument('--object_format', default = 'mesh', type=str,
24 | help='pcd or mesh')
25 |
26 | args = parser.parse_args()
27 |
28 |
29 | cwd = os.getcwd()
30 |
31 | load_path = '../results/{}/GraspPose/{}/fitting_results.npz'.format(args.exp_name, args.object)
32 |
33 | data = np.load(load_path, allow_pickle=True)
34 | gender = args.gender
35 | object_name = args.object
36 |
37 | n_samples = len(data['markers'])
38 |
39 | # Prepare mesh and pcd
40 | object_pcd = get_pcd(data['object'][()]['verts_object'][:n_samples])
41 | object_mesh = get_object_mesh(object_name, 'GRAB', data['object'][()]['transl'][:n_samples], data['object'][()]['global_orient'][:n_samples], n_samples)
42 | body_mesh, _ = get_body_mesh(data['body'][()], gender, n_samples)
43 |
44 |
45 | # ground
46 | x_range = np.arange(-5, 50, 1)
47 | y_range = np.arange(-5, 50, 1)
48 | z_range = np.arange(0, 1, 1)
49 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
50 | gp_lines.paint_uniform_color(color_hex2rgb('#bdbfbe')) # grey
51 | gp_pcd.paint_uniform_color(color_hex2rgb('#bdbfbe')) # grey
52 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.25)
53 |
54 | for i in range(n_samples):
55 | print(body_mesh[i])
56 | visualization_list = [body_mesh[i], object_mesh[i], coord, gp_lines, gp_pcd]
57 | o3d.visualization.draw_geometries(visualization_list)
58 |
59 |
--------------------------------------------------------------------------------
/visualization/visualization_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('..')
4 | import os
5 |
6 | import numpy as np
7 | import open3d as o3d
8 | import smplx
9 | import torch
10 | from WholeGraspPose.models.objectmodel import ObjectModel
11 |
12 | def update_cam(cam_param, trans):
13 | cam_R = np.transpose(trans[:-1, :-1])
14 | cam_T = -trans[:-1, -1:]
15 |     cam_T = np.matmul(cam_R, cam_T)  # the translation is applied in the rotated coordinate frame
16 | cam_aux = np.array([[0, 0, 0, 1]])
17 | mat = np.concatenate([cam_R, cam_T], axis=-1)
18 | mat = np.concatenate([mat, cam_aux], axis=0)
19 | cam_param.extrinsic = mat
20 | return cam_param
21 |
22 | def color_hex2rgb(hex):
23 | h = hex.lstrip('#')
24 | return np.array( tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) )/255
25 | def create_lineset(x_range, y_range, z_range):
26 | gp_lines = o3d.geometry.LineSet()
27 | gp_pcd = o3d.geometry.PointCloud()
28 | points = np.stack(np.meshgrid(x_range, y_range, z_range), axis=-1)
29 |
30 | lines = []
31 | for ii in range( x_range.shape[0]-1):
32 | for jj in range(y_range.shape[0]-1):
33 | lines.append(np.array([ii*x_range.shape[0]+jj, ii*x_range.shape[0]+jj+1]))
34 | lines.append(np.array([ii*x_range.shape[0]+jj, ii*x_range.shape[0]+jj+y_range.shape[0]]))
35 |
36 | points = np.reshape(points, [-1,3])
37 | colors = np.random.rand(len(lines), 3)*0.5+0.5
38 |
39 | gp_lines.points = o3d.utility.Vector3dVector(points)
40 | gp_lines.colors = o3d.utility.Vector3dVector(colors)
41 | gp_lines.lines = o3d.utility.Vector2iVector(np.stack(lines,axis=0))
42 | gp_pcd.points = o3d.utility.Vector3dVector(points)
43 |
44 | return gp_lines, gp_pcd
45 |
46 | def get_body_model(type, gender, batch_size,device='cuda',v_template=None):
47 | '''
48 |     type: smpl, smplx, smplh and others; refer to the smplx tutorial
49 |     gender: male, female, neutral
50 |     batch_size: a positive integer
51 | '''
52 | body_model_path = '../body_utils/body_models'
53 | body_model = smplx.create(body_model_path, model_type=type,
54 | gender=gender, ext='npz',
55 | num_pca_comps=24,
56 | create_global_orient=True,
57 | create_body_pose=True,
58 | create_betas=True,
59 | create_left_hand_pose=True,
60 | create_right_hand_pose=True,
61 | create_expression=True,
62 | create_jaw_pose=True,
63 | create_leye_pose=True,
64 | create_reye_pose=True,
65 | create_transl=True,
66 | batch_size=batch_size,
67 | v_template=v_template
68 | )
69 | if device == 'cuda':
70 | return body_model.cuda()
71 | else:
72 | return body_model
73 |
74 | def get_body_mesh(smplxparams, gender, n_samples, device='cpu', color=None):
75 | body_mesh_list = []
76 |
77 | for key in smplxparams.keys():
78 | # print(key, smplxparams[key].shape)
79 | smplxparams[key] = torch.tensor(smplxparams[key][:n_samples]).to(device)
80 |
81 |
82 | bm = get_body_model('smplx', str(gender), n_samples, device=device)
83 | smplx_results = bm(return_verts=True, **smplxparams)
84 | verts = smplx_results.vertices.detach().cpu().numpy()
85 | face = bm.faces
86 |
87 | for i in range(n_samples):
88 | mesh = o3d.geometry.TriangleMesh()
89 | mesh.vertices = o3d.utility.Vector3dVector(verts[i])
90 | mesh.triangles = o3d.utility.Vector3iVector(face)
91 | mesh.compute_vertex_normals()
92 | if color is not None:
93 |             mesh.paint_uniform_color(color_hex2rgb(color))  # user-specified color
94 | body_mesh_list.append(mesh)
95 |
96 | return body_mesh_list, smplx_results
97 |
98 | def get_object_mesh(obj, dataset, transl, global_orient, n_samples, device='cpu', rotmat=False):
99 | object_mesh_list = []
100 | global_orient_dim = 9 if rotmat else 3
101 |
102 | global_orient = torch.FloatTensor(global_orient).to(device)
103 | transl = torch.FloatTensor(transl).to(device)
104 |
105 | if dataset == 'GRAB':
106 | mesh_base = '../dataset/contact_meshes'
107 | # mesh_base = '/home/dalco/wuyan/data/GRAB/tools/object_meshes/contact_meshes'
108 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, obj + '.ply'))
109 | elif dataset == 'FHB':
110 | mesh_base = '/home/dalco/wuyan/data/FHB/Object_models'
111 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}_model/{}_model.ply'.format(obj, obj)))
112 | elif dataset == 'HO3D':
113 | mesh_base = '/home/dalco/wuyan/data/HO3D/YCB_Video_Models/models'
114 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}/textured_simple.obj'.format(obj)))
115 | elif dataset == 'ShapeNet':
116 | mesh_base = '/home/dalco/wuyan/data/ShapeNet/ShapeNet_selected'
117 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}.obj'.format(obj)))
118 | obj_mesh_base.scale(0.15, center=np.zeros((3, 1)))
119 | else:
120 | raise NotImplementedError
121 |
122 | obj_mesh_base.compute_vertex_normals()
123 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(device).view(1, -1, 3).repeat(n_samples, 1, 1)
124 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(device).view(1, -1, 3).repeat(n_samples, 1, 1)
125 | obj_model = ObjectModel(v_temp, normal_temp, n_samples)
126 | # global_orient_dim = 9 if rotmat else 3
127 | object_output = obj_model(global_orient.view(n_samples, global_orient_dim), transl.view(n_samples, 3), v_temp.to(device), normal_temp.to(device), rotmat)
128 |
129 | object_verts = object_output[0].detach().squeeze().view(n_samples, -1, 3).cpu().numpy()
130 | # object_vertex_normal = object_output[1].detach().squeeze().cpu().numpy()
131 |
132 | for i in range(n_samples):
133 | mesh = o3d.geometry.TriangleMesh()
134 | mesh.vertices = o3d.utility.Vector3dVector(object_verts[i])
135 | mesh.triangles = obj_mesh_base.triangles
136 | mesh.compute_vertex_normals()
137 | mesh.paint_uniform_color(color_hex2rgb('#f59002')) # orange
138 | object_mesh_list.append(mesh)
139 |
140 | return object_mesh_list
141 |
142 |
143 | # def get_object_mesh(obj, transl, global_orient, n_samples, device='cpu'):
144 | # object_mesh_list = []
145 |
146 | # global_orient = torch.FloatTensor(global_orient).to(device)
147 | # transl = torch.FloatTensor(transl).to(device)
148 |
149 | # mesh_base = '../dataset/contact_meshes'
150 | # obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, obj + '.ply'))
151 | # obj_mesh_base.compute_vertex_normals()
152 | # v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(device).view(1, -1, 3).repeat(n_samples, 1, 1)
153 | # normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(device).view(1, -1, 3).repeat(n_samples, 1, 1)
154 | # obj_model = ObjectModel(v_temp, normal_temp, n_samples)
155 | # object_output = obj_model(global_orient.view(n_samples, 3), transl.view(n_samples, 3), v_temp.to(device), normal_temp.to(device))
156 |
157 | # object_verts = object_output[0].detach().squeeze().cpu().numpy()
158 | # # object_vertex_normal = object_output[1].detach().squeeze().cpu().numpy()
159 |
160 | # for i in range(n_samples):
161 | # mesh = o3d.geometry.TriangleMesh()
162 | # print('debug:', object_verts[i].shape)
163 | # mesh.vertices = o3d.utility.Vector3dVector(object_verts[i])
164 | # mesh.triangles = obj_mesh_base.triangles
165 | # mesh.compute_vertex_normals()
166 | # mesh.paint_uniform_color(color_hex2rgb('#f59002')) # orange
167 | # object_mesh_list.append(mesh)
168 |
169 | # return object_mesh_list
170 |
171 | def get_pcd(points, contact=None):
172 | print(points.shape)
173 | pcd_list = []
174 | n_samples = points.shape[0]
175 | for i in range(n_samples):
176 | pcd = o3d.geometry.PointCloud()
177 | pcd.points = o3d.utility.Vector3dVector(points[i])
178 | if contact is not None:
179 | colors = np.zeros((points.shape[1], 3))
180 | colors[:, 0] = contact[i].squeeze()
181 | pcd.colors = o3d.utility.Vector3dVector(colors)
182 | pcd_list.append(pcd)
183 | return pcd_list
184 |
185 |
--------------------------------------------------------------------------------
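A short sketch (not in the repository) of the ground-plane and point-cloud helpers above, mirroring the scene setup in vis_pose.py. The ranges, the grey color and the random point cloud are placeholder values; the import assumes the repository root is on sys.path, since the module pulls in WholeGraspPose and smplx at load time.

import numpy as np

from visualization.visualization_utils import color_hex2rgb, create_lineset, get_pcd  # assumed import path

# illustrative ground grid, painted grey as in vis_pose.py
x_range = np.arange(-5, 5, 1.0)
y_range = np.arange(-5, 5, 1.0)
z_range = np.arange(0, 1, 1.0)
gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
gp_lines.paint_uniform_color(color_hex2rgb('#bdbfbe'))
gp_pcd.paint_uniform_color(color_hex2rgb('#bdbfbe'))

# one hypothetical object point cloud with 2048 points
pcd_list = get_pcd(np.random.rand(1, 2048, 3))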