├── .gitignore ├── LICENSE ├── MotionFill ├── data │ ├── AMASS_localmotionfill_dataloader.py │ ├── GRAB_end2end_dataloader.py │ └── __pycache__ │ │ └── GRAB_end2end_dataloader.cpython-38.pyc └── models │ ├── LocalMotionFill.py │ ├── TrajFill.py │ ├── __pycache__ │ ├── LocalMotionFill.cpython-310.pyc │ ├── LocalMotionFill.cpython-38.pyc │ ├── TrajFill.cpython-310.pyc │ ├── TrajFill.cpython-38.pyc │ ├── fittingop.cpython-38.pyc │ └── motionfill_cvae.cpython-310.pyc │ └── fittingop.py ├── README.md ├── WholeGraspPose ├── __pycache__ │ ├── trainer.cpython-310.pyc │ └── trainer.cpython-38.pyc ├── configs │ ├── WholeGraspPose.yaml │ ├── rhand_weight.npy │ └── verts_per_edge.npy ├── data │ ├── __pycache__ │ │ ├── dataloader.cpython-310.pyc │ │ └── dataloader.cpython-38.pyc │ └── dataloader.py ├── models │ ├── __pycache__ │ │ ├── fittingop.cpython-38.pyc │ │ ├── models.cpython-310.pyc │ │ ├── models.cpython-38.pyc │ │ ├── objectmodel.cpython-38.pyc │ │ ├── pointnet.cpython-38.pyc │ │ ├── pointnet_util.cpython-310.pyc │ │ └── pointnet_util.cpython-38.pyc │ ├── fittingop.py │ ├── models.py │ ├── objectmodel.py │ └── pointnet.py └── trainer.py ├── body_utils ├── body_models │ └── VPoser │ │ ├── .DS_Store │ │ ├── vposerDecoderWeights.npz │ │ ├── vposerEncoderWeights.npz │ │ └── vposerMeanPose.npz ├── body_segments │ ├── L_Hand.json │ ├── L_Leg.json │ ├── R_Hand.json │ ├── R_Leg.json │ ├── back.json │ ├── body_mask.json │ ├── butt.json │ ├── gluteus.json │ └── thighs.json ├── left_heel_verts_id.npy ├── left_toe_verts_id.npy ├── left_whole_foot_verts_id.npy ├── right_heel_verts_id.npy ├── right_toe_verts_id.npy ├── right_whole_foot_verts_id.npy ├── smplx_mano_flame_correspondences │ ├── MANO_SMPLX_vertex_ids.pkl │ └── SMPL-X__FLAME_vertex_ids.npy └── smplx_markerset.json ├── images ├── binoculars-0.jpg ├── binoculars-60-first-view.jpg ├── binoculars-60.jpg ├── binoculars-movie-first-view.gif ├── binoculars-video.gif ├── teaser.png ├── two-stage-pipeline.png ├── wineglass-0.jpg ├── wineglass-60-first-view.jpg ├── wineglass-60.jpg ├── wineglass-movie-first-view.gif └── wineglass-video.gif ├── opt_graspmotion.py ├── opt_grasppose.py ├── train_graspmotion.py ├── train_grasppose.py ├── utils ├── Pivots.py ├── Pivots_torch.py ├── Quaternions.py ├── Quaternions_torch.py ├── __pycache__ │ ├── Pivots.cpython-38.pyc │ ├── Pivots_torch.cpython-38.pyc │ ├── Quaternions.cpython-38.pyc │ ├── Quaternions_torch.cpython-38.pyc │ ├── cfg_parser.cpython-310.pyc │ ├── cfg_parser.cpython-38.pyc │ ├── train_helper.cpython-38.pyc │ ├── train_tools.cpython-310.pyc │ ├── train_tools.cpython-38.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-38.pyc │ └── utils_body.cpython-38.pyc ├── cfg_parser.py ├── como │ ├── SSM2.json │ ├── __pycache__ │ │ ├── como_smooth.cpython-310.pyc │ │ ├── como_smooth.cpython-38.pyc │ │ ├── como_utils.cpython-310.pyc │ │ └── como_utils.cpython-38.pyc │ ├── como_smooth.py │ ├── como_smooth_model.pkl │ ├── como_utils.py │ ├── my_SSM2.json │ └── preprocess_stats_global_markers │ │ ├── Xmean.npy │ │ └── Xstd.npy ├── train_helper.py ├── utils.py └── utils_body.py └── visualization ├── __pycache__ └── visualization_utils.cpython-38.pyc ├── vis_motion.py ├── vis_pose.py └── visualization_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/ 2 | results/ 3 | smplx/ 4 | mano_v1_2/ 5 | vposer_v1_0/ 6 | pretrained_model/ 7 | logs/ 8 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jiahao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MotionFill/data/__pycache__/GRAB_end2end_dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/data/__pycache__/GRAB_end2end_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /MotionFill/models/LocalMotionFill.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch.autograd import Variable 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | class EncBlock(nn.Module): 11 | def __init__(self, nin, nout, downsample=True, kernel=3): 12 | super(EncBlock, self).__init__() 13 | self.downsample = downsample 14 | padding = kernel // 2 15 | 16 | self.main = nn.Sequential( 17 | nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=1, padding=padding, padding_mode='replicate'), 18 | nn.LeakyReLU(0.2), 19 | nn.Conv2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding, padding_mode='replicate'), 20 | nn.LeakyReLU(0.2), 21 | ) 22 | 23 | if self.downsample: 24 | self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 25 | else: 26 | self.pooling = nn.MaxPool2d(kernel_size=(3,3), stride=(2, 1), padding=1) 27 | 28 | def forward(self, input): 29 | output = self.main(input) 30 | output = self.pooling(output) 31 | return output 32 | 33 | class DecBlock(nn.Module): 34 | def __init__(self, nin, nout, upsample=True, kernel=3): 35 | super(DecBlock, self).__init__() 36 | self.upsample = upsample 37 | 38 | padding = kernel // 2 39 | if upsample: 40 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=2, padding=padding) 41 | else: 42 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=(2, 1), padding=padding) 43 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding) 44 | self.leaky_relu = 
nn.LeakyReLU(0.2) 45 | 46 | def forward(self, input, out_size): 47 | output = self.deconv1(input, output_size=out_size) 48 | output = self.leaky_relu(output) 49 | output = self.leaky_relu(self.deconv2(output)) 50 | return output 51 | 52 | 53 | class DecBlock_output(nn.Module): 54 | def __init__(self, nin, nout, upsample=True, kernel=3): 55 | super(DecBlock_output, self).__init__() 56 | self.upsample = upsample 57 | padding = kernel // 2 58 | 59 | if upsample: 60 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=2, padding=padding) 61 | else: 62 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=kernel, stride=(2, 1), padding=padding) 63 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=kernel, stride=1, padding=padding) 64 | self.leaky_relu = nn.LeakyReLU(0.2) 65 | 66 | 67 | def forward(self, input, out_size): 68 | output = self.deconv1(input, output_size=out_size) 69 | output = self.leaky_relu(output) 70 | output = self.deconv2(output) 71 | return output 72 | 73 | 74 | class AE(nn.Module): 75 | def __init__(self, downsample=True, in_channel=1, kernel=3): 76 | super(AE, self).__init__() 77 | self.enc_blc1 = EncBlock(nin=in_channel, nout=32, downsample=downsample, kernel=kernel) 78 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample, kernel=kernel) 79 | self.enc_blc3 = EncBlock(nin=64, nout=128, downsample=downsample, kernel=kernel) 80 | self.enc_blc4 = EncBlock(nin=128, nout=256, downsample=downsample, kernel=kernel) 81 | self.enc_blc5 = EncBlock(nin=256, nout=256, downsample=downsample, kernel=kernel) 82 | 83 | self.dec_blc1 = DecBlock(nin=256, nout=256, upsample=downsample, kernel=kernel) 84 | self.dec_blc2 = DecBlock(nin=256, nout=128, upsample=downsample, kernel=kernel) 85 | self.dec_blc3 = DecBlock(nin=128, nout=64, upsample=downsample, kernel=kernel) 86 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample, kernel=kernel) 87 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample, kernel=kernel) 88 | 89 | def forward(self, input): 90 | # input: [bs, c, d, T] 91 | x_down1 = self.enc_blc1(input) 92 | x_down2 = self.enc_blc2(x_down1) 93 | x_down3 = self.enc_blc3(x_down2) 94 | x_down4 = self.enc_blc4(x_down3) 95 | z = self.enc_blc5(x_down4) 96 | 97 | x_up4 = self.dec_blc1(z, x_down4.size()) 98 | x_up3 = self.dec_blc2(x_up4, x_down3.size()) 99 | x_up2 = self.dec_blc3(x_up3, x_down2.size()) 100 | x_up1 = self.dec_blc4(x_up2, x_down1.size()) 101 | output = self.dec_blc5(x_up1, input.size()) 102 | 103 | return output, z 104 | 105 | 106 | class Flatten(nn.Module): 107 | def forward(self, input): 108 | return input.view(input.size(0), -1) 109 | 110 | 111 | class CNN_Encoder(nn.Module): 112 | def __init__(self, downsample=True, in_channel=1, kernel=3): 113 | super(CNN_Encoder, self).__init__() 114 | self.enc_blc1 = EncBlock(nin=in_channel, nout=32, downsample=downsample, kernel=kernel) 115 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample, kernel=kernel) 116 | self.enc_blc3 = EncBlock(nin=64, nout=128, downsample=downsample, kernel=kernel) 117 | self.enc_blc4 = EncBlock(nin=128, nout=256, downsample=downsample, kernel=kernel) 118 | self.enc_blc5 = EncBlock(nin=256, nout=256, downsample=downsample, kernel=kernel) 119 | 120 | def forward(self, input): 121 | x_down1 = self.enc_blc1(input) 122 | x_down2 = self.enc_blc2(x_down1) 123 | x_down3 = self.enc_blc3(x_down2) 124 | x_down4 = self.enc_blc4(x_down3) 125 | z = 
self.enc_blc5(x_down4) 126 | size_list = [x_down4.size(), x_down3.size(), x_down2.size(), x_down1.size(), input.size()] 127 | return z, size_list 128 | 129 | 130 | class CNN_Decoder(nn.Module): 131 | def __init__(self, downsample=True, kernel=3): 132 | super(CNN_Decoder, self).__init__() 133 | self.dec_blc1 = DecBlock(nin=512, nout=256, upsample=downsample, kernel=kernel) 134 | self.dec_blc2 = DecBlock(nin=256, nout=128, upsample=downsample, kernel=kernel) 135 | self.dec_blc3 = DecBlock(nin=128, nout=64, upsample=downsample, kernel=kernel) 136 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample, kernel=kernel) 137 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample, kernel=kernel) 138 | 139 | def forward(self, z, size_list): 140 | x_up4 = self.dec_blc1(z, size_list[0]) 141 | x_up3 = self.dec_blc2(x_up4, size_list[1]) 142 | x_up2 = self.dec_blc3(x_up3, size_list[2]) 143 | x_up1 = self.dec_blc4(x_up2, size_list[3]) 144 | output = self.dec_blc5(x_up1, size_list[4]) 145 | return output 146 | 147 | 148 | class Motion_CNN_CVAE(nn.Module): 149 | def __init__(self, nz, downsample=True, in_channel=1, kernel=3, clip_seconds=2): 150 | super(Motion_CNN_CVAE, self).__init__() 151 | self.nz = nz # dim of latent variables 152 | self.enc_conv_input = CNN_Encoder(downsample, in_channel, kernel) 153 | self.enc_conv_gt = CNN_Encoder(downsample, in_channel, kernel) 154 | self.enc_conv_cat = nn.Sequential( 155 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=kernel, stride=1, padding=kernel//2, padding_mode='replicate'), 156 | nn.LeakyReLU(0.2), 157 | Flatten(), 158 | ) 159 | self.enc_mu = nn.Linear(512*8*clip_seconds, self.nz) 160 | self.enc_logvar = nn.Linear(512*8*clip_seconds, self.nz) 161 | self.dec_dense = nn.Linear(self.nz, 256*8*clip_seconds) 162 | self.dec_conv = CNN_Decoder(downsample, kernel) 163 | 164 | 165 | def encode(self, x, y): 166 | # x: [bs, c, d, T] (input) 167 | # y: [bs, c, d, T] (gt) 168 | e_x, _ = self.enc_conv_input(x) 169 | e_y, _ = self.enc_conv_gt(y) 170 | e_xy = torch.cat((e_x, e_y), dim=1) 171 | z = self.enc_conv_cat(e_xy) 172 | z_mu = self.enc_mu(z) 173 | z_logvar = self.enc_logvar(z) 174 | return z_mu, z_logvar 175 | 176 | def reparameterize(self, mu, logvar, eps): 177 | std = torch.exp(0.5*logvar) 178 | return mu + eps * std 179 | 180 | def decode(self, x, z): 181 | e_x, size_list = self.enc_conv_input(x) 182 | d_z_dense = self.dec_dense(z) 183 | d_z = d_z_dense.view(e_x.size(0), e_x.size(1), e_x.size(2), e_x.size(3)) 184 | d_xz = torch.cat((e_x, d_z), dim=1) 185 | y_hat = self.dec_conv(d_xz, size_list) 186 | return y_hat 187 | 188 | def forward(self, input, gt=None, is_train=True, z=None, is_twice=None): 189 | # input: [bs, c, d, T] 190 | self.bs = len(input) 191 | if is_train: 192 | mu, logvar = self.encode(input, gt) 193 | eps = torch.randn_like(logvar) 194 | z = self.reparameterize(mu, logvar, eps) #if self.training else mu 195 | else: 196 | if is_twice: 197 | mu, logvar = self.encode(input, gt) 198 | z = mu 199 | else: 200 | if z is None: 201 | z = torch.randn((self.bs, self.nz), device=input.device) 202 | mu = 0 203 | logvar = 1 204 | 205 | pred = self.decode(input, z) 206 | return pred, mu, logvar 207 | 208 | def sample_prior(self, x): 209 | z = torch.randn((x.shape[1], self.nz), device=x.device) 210 | return self.decode(x, z) 211 | -------------------------------------------------------------------------------- /MotionFill/models/TrajFill.py: -------------------------------------------------------------------------------- 1 | 
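
The `Motion_CNN_CVAE` defined in `MotionFill/models/LocalMotionFill.py` above treats a motion clip as a 2D "image" of shape `[bs, c, d, T]` and compresses it through five stride-2 pooling stages before the latent bottleneck. Below is a minimal sketch (not part of the repo) of calling it. The latent size `nz=512` and the resolution `d=256, T=64` are assumptions chosen only so that five poolings reduce the grid to the `8 x clip_seconds` size expected by `enc_mu`/`dec_dense`; the import assumes the repository root is on `PYTHONPATH`.

```python
# Hypothetical usage sketch of Motion_CNN_CVAE; shapes and nz are assumptions, not repo settings.
import torch
from MotionFill.models.LocalMotionFill import Motion_CNN_CVAE

model = Motion_CNN_CVAE(nz=512, downsample=True, in_channel=1, kernel=3, clip_seconds=2)

bs, d, T = 4, 256, 64                      # batch, feature rows, time steps (assumed)
masked_motion = torch.randn(bs, 1, d, T)   # partially observed motion "image" (condition)
gt_motion = torch.randn(bs, 1, d, T)       # ground-truth motion, required when is_train=True

# Training-style call: encode (input, gt), reparameterize, decode conditioned on the input.
pred, mu, logvar = model(masked_motion, gt=gt_motion, is_train=True)
print(pred.shape)                          # (bs, 1, d, T)

# Inference-style call: z is drawn from the standard normal prior instead of the posterior.
sample, _, _ = model(masked_motion, is_train=False)
```
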
import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch.autograd import Variable 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | class ResBlock(nn.Module): 11 | 12 | def __init__(self, 13 | Fin, 14 | Fout, 15 | n_neurons=256): 16 | 17 | super(ResBlock, self).__init__() 18 | self.Fin = Fin 19 | self.Fout = Fout 20 | 21 | self.fc1 = nn.Linear(Fin, n_neurons) 22 | self.bn1 = nn.BatchNorm1d(n_neurons) 23 | 24 | self.fc2 = nn.Linear(n_neurons, Fout) 25 | self.bn2 = nn.BatchNorm1d(Fout) 26 | 27 | if Fin != Fout: 28 | self.fc3 = nn.Linear(Fin, Fout) 29 | 30 | self.ll = nn.LeakyReLU(negative_slope=0.2) 31 | 32 | def forward(self, x, final_nl=True): 33 | Xin = x if self.Fin == self.Fout else self.ll(self.fc3(x)) 34 | 35 | Xout = self.fc1(x) # n_neurons 36 | Xout = self.bn1(Xout) 37 | Xout = self.ll(Xout) 38 | 39 | Xout = self.fc2(Xout) 40 | Xout = self.bn2(Xout) 41 | Xout = Xin + Xout 42 | 43 | if final_nl: 44 | return self.ll(Xout) 45 | return Xout 46 | 47 | 48 | class Traj_MLP_CVAE(nn.Module): # T*4 => T*8 => ResBlock -> z => ResBlock 49 | def __init__(self, nz, feature_dim, T, residual=False, load_path=None): 50 | super(Traj_MLP_CVAE, self).__init__() 51 | self.T = T 52 | self.feature_dim = feature_dim 53 | self.nz = nz 54 | self.residual = residual 55 | self.load_path = load_path 56 | 57 | """MLP""" 58 | self.enc1 = ResBlock(Fin=2*feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T) 59 | self.enc2 = ResBlock(Fin=2*feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T) 60 | self.enc_mu = nn.Linear(2*feature_dim*T, nz) 61 | self.enc_var = nn.Linear(2*feature_dim*T, nz) 62 | 63 | self.dec1 = ResBlock(Fin=nz+feature_dim*T, Fout=2*feature_dim*T, n_neurons=2*feature_dim*T) 64 | self.dec2 = ResBlock(Fin=2*feature_dim*T + feature_dim*T, Fout=feature_dim*T, n_neurons=feature_dim*T) 65 | 66 | if self.load_path is not None: 67 | self._load_model() 68 | 69 | def encode(self, x, y): 70 | """ x: [bs, T*feature_dim] """ 71 | bs = x.shape[0] 72 | x = torch.cat([x, y], dim=-1) 73 | h = self.enc1(x) 74 | h = self.enc2(h) 75 | z_mu = self.enc_mu(h) 76 | z_logvar = self.enc_var(h) 77 | 78 | return z_mu, z_logvar 79 | 80 | def decode(self, z, y): 81 | """z: [bs, nz]; y: [bs, 2*feature_dim]""" 82 | bs = y.shape[0] 83 | x = torch.cat([z, y], dim=-1) 84 | x = self.dec1(x) 85 | x = torch.cat([x, y], dim=-1) 86 | x = self.dec2(x) 87 | 88 | if self.residual: 89 | x = x + y 90 | 91 | return x.reshape(bs, self.feature_dim, -1) 92 | 93 | def reparameterize(self, mu, logvar, eps): 94 | std = torch.exp(0.5*logvar) 95 | return mu + eps * std 96 | 97 | def forward(self, x, y): 98 | bs = x.shape[0] 99 | # print(x.shape, y.shape) 100 | mu, logvar = self.encode(x, y) 101 | eps = torch.randn_like(logvar) 102 | z = self.reparameterize(mu, logvar, eps) #if self.training else mu 103 | pred = self.decode(z, y) 104 | 105 | return pred, mu, logvar 106 | 107 | def sample(self, y, z=None): 108 | if z is None: 109 | z = torch.randn((y.shape[0], self.nz), device=y.device) 110 | return self.decode(z, y) 111 | 112 | def _load_model(self): 113 | print('Loading Traj_CVAE from {} ...'.format(self.load_path)) 114 | assert self.load_path is not None 115 | model_cp = torch.load(self.load_path) 116 | self.load_state_dict(model_cp['model_dict']) 117 | -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/LocalMotionFill.cpython-310.pyc: 
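
A minimal sketch (not part of the repo) of driving `Traj_MLP_CVAE` from `TrajFill.py` above is shown below. The class comment hints at a `T*4` trajectory representation, so `feature_dim=4` is an assumption, as are `T=62` and `nz=512`; both the target trajectory `x` and the condition `y` are flattened to `[bs, feature_dim*T]` before being passed in, and the import again assumes the repository root is on `PYTHONPATH`.

```python
# Hypothetical usage sketch of Traj_MLP_CVAE; feature_dim, T, and nz are assumptions.
import torch
from MotionFill.models.TrajFill import Traj_MLP_CVAE

feature_dim, T, nz = 4, 62, 512
model = Traj_MLP_CVAE(nz=nz, feature_dim=feature_dim, T=T, residual=True)

bs = 4
x = torch.randn(bs, feature_dim * T)   # full trajectory (training target), flattened
y = torch.randn(bs, feature_dim * T)   # conditioning trajectory (e.g. interpolated endpoints), flattened

# Training-style forward: encode (x, y), reparameterize, decode conditioned on y.
pred, mu, logvar = model(x, y)
print(pred.shape)                      # (bs, feature_dim, T)

# Inference: draw z from the prior and decode conditioned on y only.
model.eval()
traj = model.sample(y)
```
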
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/LocalMotionFill.cpython-310.pyc -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/LocalMotionFill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/LocalMotionFill.cpython-38.pyc -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/TrajFill.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/TrajFill.cpython-310.pyc -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/TrajFill.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/TrajFill.cpython-38.pyc -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/fittingop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/fittingop.cpython-38.pyc -------------------------------------------------------------------------------- /MotionFill/models/__pycache__/motionfill_cvae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/MotionFill/models/__pycache__/motionfill_cvae.cpython-310.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | SAGA: Stochastic Whole-Body Grasping with Contact 3 |

4 | 5 | > [**SAGA: Stochastic Whole-Body Grasping with Contact**](https://jiahaoplus.github.io/SAGA/saga.html) 6 | > **ECCV 2022** 7 | > Yan Wu*, Jiahao Wang*, Yan Zhang, Siwei Zhang, Otmar Hilliges, Fisher Yu, Siyu Tang 8 | 9 | ![alt text](https://github.com/JiahaoPlus/SAGA/blob/main/images/teaser.png) 10 | This repository is the official implementation for the ECCV 2022 paper: [SAGA: Stochastic Whole-Body Grasping with Contact](https://jiahaoplus.github.io/SAGA/saga.html). 11 | 12 | \[[Project Page](https://jiahaoplus.github.io/SAGA/saga.html) | [Paper](https://arxiv.org/abs/2112.10103)\] 13 | 14 | ## Introduction 15 | Given an object in 3D space and a human initial pose, we aim to generate diverse human motion sequences to approach and grasp the given object. We propose a two-stage pipeline to address this problem by generating grasping ending pose first and then infilling the in-between motion. 16 | 17 | ![alt text](https://github.com/JiahaoPlus/SAGA/blob/main/images/two-stage-pipeline.png) 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 |
Input | First-stage result | Second-stage result
44 | 45 | ## Contents 46 | - [Installation](https://github.com/JiahaoPlus/SAGA#installation) 47 | - [Dataset Preparation](https://github.com/JiahaoPlus/SAGA#Dataset) 48 | - [Pretrained models](https://github.com/JiahaoPlus/SAGA#pretrained-models) 49 | - [Train](https://github.com/JiahaoPlus/SAGA#train) 50 | - [Grasping poses and motions generation for given object](https://github.com/JiahaoPlus/SAGA#inference) (object position and orientation can be customized) 51 | - [Visualization](https://github.com/JiahaoPlus/SAGA#visualization) 52 | 53 | ## Installation 54 | - Packages 55 | - python>=3.8 56 | - pytorch==1.12.1 57 | - [human-body-prior](https://pypi.org/project/human-body-prior/) 58 | - [SMPLX](https://github.com/vchoutas/smplx) 59 | - [Chamfer Distance](https://github.com/otaheri/chamfer_distance) 60 | - Open3D 61 | 62 | - Body Models 63 | Download [SMPL-X body model and vposer v1.0 model](https://smpl-x.is.tue.mpg.de/index.html) and put them under /body_utils/body_models folder as below: 64 | ``` 65 | SAGA 66 | │ 67 | └───body_utils 68 | │ 69 | └───body_models 70 | │ 71 | └───smplx 72 | │ └───SMPLX_FEMALE.npz 73 | │ └───... 74 | │ 75 | └───vposer_v1_0 76 | │ └───snapshots 77 | │ └───TR00_E096.pt 78 | │ └───... 79 | │ 80 | └───VPoser 81 | │ └───vposerDecoderWeights.npz 82 | │ └───vposerEnccoderWeights.npz 83 | │ └───vposerMeanPose.npz 84 | │ 85 | └───... 86 | │ 87 | └───... 88 | ``` 89 | 90 | ## Dataset 91 | ### 92 | Download [GRAB](https://grab.is.tue.mpg.de/) object mesh 93 | 94 | Download dataset for the first stage (GraspPose) from [[Google Drive]](https://drive.google.com/uc?export=download&id=1OfSGa3Y1QwkbeXUmAhrfeXtF89qvZj54) 95 | 96 | Download dataset for the second stage (GraspMotion) from [[Google Drive]](https://drive.google.com/uc?export=download&id=1QiouaqunhxKuv0D0QHv1JHlwVU-F6dWm) 97 | 98 | Put them under /dataset as below, 99 | ``` 100 | SAGA 101 | │ 102 | └───dataset 103 | │ 104 | └───GraspPose 105 | │ └───train 106 | │ └───s1 107 | │ └───... 108 | │ └───eval 109 | │ └───s1 110 | │ └───... 111 | │ └───test 112 | │ └───s1 113 | │ └───... 114 | │ 115 | └───GraspMotion 116 | │ └───Processed_Traj 117 | │ └───s1 118 | │ └───... 119 | │ 120 | └───contact_meshes 121 | │ └───airplane.ply 122 | │ └───... 123 | │ 124 | └───... 125 | ``` 126 | 127 | ## Pretrained models 128 | Download pretrained models from [[Google Drive]](https://drive.google.com/uc?export=download&id=1dxzBBOUbRuUAlNQGxnbmWLmtP7bmEE_9), and the pretrained models include: 129 | - Stage 1: pretrained WholeGrasp-VAE for male and female respectively 130 | - Stage 2: pretrained TrajFill-VAE and LocalMotionFill-VAE (end to end) 131 | 132 | ## Train 133 | ### First Stage: WholeGrasp-VAE training 134 | ``` 135 | python train_grasppose.py --data_path ./dataset/GraspPose --gender male --exp_name male 136 | ``` 137 | 138 | ### Second Stage: MotionFill-VAE training 139 | Can train TrajFill-VAE and LocalMotionFill-VAE separately first (download separately trained models from [[Google Drive]](https://drive.google.com/uc?export=download&id=1eyUW7YLmnAj-CHwIMe9qsAs6W63Aw7Ce)), and then train them end-to-end: 140 | ``` 141 | python train_graspmotion.py --pretrained_path_traj $PRETRAINED_MODEL_PATH/TrajFill_model_separate_trained.pkl --pretrained_path_motion $PRETRAINED_MODEL_PATH/LocalMotionFill_model_separate_trained.pkl 142 | ``` 143 | 144 | ## Inference 145 | ### First Stage: WholeGrasp-VAE sampling + GraspPose-Opt 146 | At the first stage, we generate grasping poses for the given object. 
147 | The example command below generates 10 male pose samples to grasp camera, where the object's height and orientation are randomly set within a reasonable range. You can also easily customize your own setting accordingly. 148 | ``` 149 | python opt_grasppose.py --exp_name pretrained_male --gender male --pose_ckpt_path $PRETRAINED_MODEL_PATH/male_grasppose_model.pt --object camera --n_object_samples 10 150 | ``` 151 | ### Second Stage: MotionFill-VAE sampling + GraspMotion-Opt 152 | At the second stage, with generated ending pose from the first stage and a customizable human initial pose, we generate in-between motions. 153 | The example command below generates male motion samples to grasp camera, where the human initial pose and the initial distance away from the given object are randomly set within a reasonable range. You can also easily customize your own setting accordingly. 154 | ``` 155 | python opt_graspmotion.py --GraspPose_exp_name pretrained_male --object camera --gender male --traj_ckpt_path $PRETRAINED_MODEL_PATH/TrajFill_model.pkl --motion_ckpt_path $PRETRAINED_MODEL_PATH/LocalMotionFill_model.pkl 156 | ``` 157 | 158 | ## Visualization 159 | We provide visualization script to visualize the generated grasping ending pose results which is saved at (by default) _/results/$EXP_NAME/GraspPose/$OBJECT/fitting_results.npz_. 160 | ``` 161 | cd visualization 162 | python vis_pose.py --exp_name pretrained_male --gender male --object camera 163 | ``` 164 | 165 | We provide visualization script to visualize the generated grasping motion result which is saved at (by default) _/results/$EXP_NAME/GraspMotion/$OBJECT/fitting_results.npy_, from 3 view points, the first-person view, third-person view and the bird-eye view. 166 | ``` 167 | cd visualization 168 | python vis_motion.py --GraspPose_exp_name pretrained_male --gender male --object camera 169 | ``` 170 | 171 | ### Contact 172 | If you have any questions, feel free to contact us: 173 | - Yan Wu: yan.wu@vision.ee.ethz.ch 174 | - Jiahao Wang: jiwang@mpi-inf.mpg.de 175 | ### Citation 176 | ```bash 177 | @inproceedings{wu2022saga, 178 | title = {SAGA: Stochastic Whole-Body Grasping with Contact}, 179 | author = {Wu, Yan and Wang, Jiahao and Zhang, Yan and Zhang, Siwei and Hilliges, Otmar and Yu, Fisher and Tang, Siyu}, 180 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)}, 181 | year = {2022} 182 | } 183 | ``` 184 | -------------------------------------------------------------------------------- /WholeGraspPose/__pycache__/trainer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/__pycache__/trainer.cpython-310.pyc -------------------------------------------------------------------------------- /WholeGraspPose/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/configs/WholeGraspPose.yaml: -------------------------------------------------------------------------------- 1 | base_lr: 0.0005 2 | batch_size: 128 3 | best_net: null 4 | bps_size: 4096 5 | c_weights_path: null 6 | cuda_id: 0 7 | dataset_dir: null 8 | kl_coef: 0.005 9 | latentD: 16 10 | 
log_every_epoch: 10 11 | n_epochs: 100 12 | n_markers: 512 13 | n_neurons: 512 14 | n_workers: 8 15 | reg_coef: 0.0005 16 | # rhm_path: null 17 | seed: 4815 18 | try_num: 0 19 | use_multigpu: false 20 | vpe_path: null 21 | work_dir: null 22 | load_on_ram: False 23 | cond_object_height: True 24 | # gender: male # 'male' / 'female' / 25 | motion_intent: False # 'use'/'offhand'/'pass'/'lift'/ False(all intent) 26 | object_class: ['all'] # ['', ''] / ['all'] 27 | robustkl: False 28 | kl_annealing: True 29 | kl_annealing_epoch: 100 30 | marker_weight: 1 31 | foot_weight: 0 32 | collision_weight: 0 33 | consistency_weight: 1 34 | dropout: 0.1 35 | obj_feature: 12 36 | pointnet_hc: 64 37 | continue_train: False 38 | data_representation: 'markers_143' # 'joints' / 'markers_143' / 'markers_214' / 'markers_593' -------------------------------------------------------------------------------- /WholeGraspPose/configs/rhand_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/configs/rhand_weight.npy -------------------------------------------------------------------------------- /WholeGraspPose/configs/verts_per_edge.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/configs/verts_per_edge.npy -------------------------------------------------------------------------------- /WholeGraspPose/data/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/data/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /WholeGraspPose/data/__pycache__/dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/data/__pycache__/dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import json 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | from smplx.lbs import batch_rodrigues 9 | from torch.utils import data 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | to_cpu = lambda tensor: tensor.detach().cpu().numpy() 13 | 14 | 15 | class LoadData(data.Dataset): 16 | def __init__(self, 17 | dataset_dir, 18 | ds_name='train', 19 | gender=None, 20 | motion_intent=None, 21 | object_class=['all'], 22 | dtype=torch.float32, 23 | data_type = 'markers_143'): 24 | 25 | super().__init__() 26 | 27 | print('Preparing {} data...'.format(ds_name.upper())) 28 | self.sbj_idxs = [] 29 | self.objs_frames = {} 30 | self.ds_path = os.path.join(dataset_dir, ds_name) 31 | self.gender = gender 32 | self.motion_intent = motion_intent 33 | self.object_class = object_class 34 | self.data_type = data_type 35 | 36 | with open('body_utils/smplx_markerset.json') as f: 37 | markerset = json.load(f)['markersets'] 38 | self.markers_idx = [] 39 | for marker in markerset: 40 | if marker['type'] not in ['palm_5']: # 'palm_5' contains selected 5 
markers per palm, but for training we use 'palm' set where there are 22 markers per palm. 41 | self.markers_idx += list(marker['indices'].values()) 42 | print(len(self.markers_idx)) 43 | self.ds = self.load_full_data(self.ds_path) 44 | 45 | def load_full_data(self, path): 46 | rec_list = [] 47 | output = {} 48 | 49 | markers_list = [] 50 | transf_transl_list = [] 51 | verts_object_list = [] 52 | contacts_object_list = [] 53 | normal_object_list = [] 54 | transl_object_list = [] 55 | global_orient_object_list = [] 56 | rotmat_list = [] 57 | contacts_markers_list = [] 58 | body_list = {} 59 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']: 60 | body_list[key] = [] 61 | 62 | subsets_dict = {'male':['s1', 's2', 's8', 's9', 's10'], 63 | 'female': ['s3', 's4', 's5', 's6', 's7']} 64 | subsets = subsets_dict[self.gender] 65 | 66 | print('loading {} dataset: {}'.format(self.gender, subsets)) 67 | for subset in subsets: 68 | subset_path = os.path.join(path, subset) 69 | rec_list += [os.path.join(subset_path, i) for i in os.listdir(subset_path)] 70 | 71 | index = 0 72 | 73 | for rec in rec_list: 74 | data = np.load(rec, allow_pickle=True) 75 | 76 | ## select object 77 | obj_name = rec.split('/')[-1].split('_')[0] 78 | if 'all' not in self.object_class: 79 | if obj_name not in self.object_class: 80 | continue 81 | 82 | verts_object_list.append(data['verts_object']) 83 | markers_list.append(data[self.data_type]) 84 | transf_transl_list.append(data['transf_transl']) 85 | normal_object_list.append(data['normal_object']) 86 | global_orient_object_list.append(data['global_orient_object']) 87 | 88 | orient = torch.tensor(data['global_orient_object']) 89 | rot_mats = batch_rodrigues(orient.view(-1, 3)).view([orient.shape[0], 9]).numpy() 90 | rotmat_list.append(rot_mats) 91 | 92 | object_contact = data['contact_object'] 93 | markers_contact = data['contact_body'][:, self.markers_idx] 94 | object_contact_binary = (object_contact>0).astype(int) 95 | contacts_object_list.append(object_contact_binary) 96 | markers_contact_binary = (markers_contact>0).astype(int) 97 | contacts_markers_list.append(markers_contact_binary) 98 | 99 | # SMPLX parameters (optional) 100 | for key in data['body'][()].keys(): 101 | body_list[key].append(data['body'][()][key]) 102 | 103 | sbj_id = rec.split('/')[-2] 104 | self.sbj_idxs += [sbj_id]*data['verts_object'].shape[0] 105 | if obj_name in self.objs_frames.keys(): 106 | self.objs_frames[obj_name] += list(range(index, index+data['verts_object'].shape[0])) 107 | else: 108 | self.objs_frames[obj_name] = list(range(index, index+data['verts_object'].shape[0])) 109 | index += data['verts_object'].shape[0] 110 | output['transf_transl'] = torch.tensor(np.concatenate(transf_transl_list, axis=0)) 111 | output['markers'] = torch.tensor(np.concatenate(markers_list, axis=0)) # (B, 99, 3) 112 | output['verts_object'] = torch.tensor(np.concatenate(verts_object_list, axis=0)) # (B, 2048, 3) 113 | output['contacts_object'] = torch.tensor(np.concatenate(contacts_object_list, axis=0)) # (B, 2048, 3) 114 | output['contacts_markers'] = torch.tensor(np.concatenate(contacts_markers_list, axis=0)) # (B, 2048, 3) 115 | output['normal_object'] = torch.tensor(np.concatenate(normal_object_list, axis=0)) # (B, 2048, 3) 116 | output['global_orient_object'] = torch.tensor(np.concatenate(global_orient_object_list, axis=0)) # (B, 2048, 3) 117 | output['rotmat'] = torch.tensor(np.concatenate(rotmat_list, axis=0)) # (B, 
2048, 3) 118 | 119 | # SMPLX parameters 120 | output['smplxparams'] = {} 121 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']: 122 | output['smplxparams'][key] = torch.tensor(np.concatenate(body_list[key], axis=0)) 123 | 124 | return output 125 | 126 | def __len__(self): 127 | k = list(self.ds.keys())[0] 128 | return self.ds[k].shape[0] 129 | 130 | def __getitem__(self, idx): 131 | 132 | data_out = {} 133 | 134 | data_out['markers'] = self.ds['markers'][idx] 135 | data_out['contacts_markers'] = self.ds['contacts_markers'][idx] 136 | data_out['verts_object'] = self.ds['verts_object'][idx] 137 | data_out['normal_object'] = self.ds['normal_object'][idx] 138 | data_out['global_orient_object'] = self.ds['global_orient_object'][idx] 139 | data_out['transf_transl'] = self.ds['transf_transl'][idx] 140 | data_out['contacts_object'] = self.ds['contacts_object'][idx] 141 | if len(data_out['verts_object'].shape) == 2: 142 | data_out['feat_object'] = torch.cat([self.ds['normal_object'][idx], self.ds['rotmat'][idx, :6].view(1, 6).repeat(2048, 1)], -1) 143 | else: 144 | data_out['feat_object'] = torch.cat([self.ds['normal_object'][idx], self.ds['rotmat'][idx, :6].view(-1, 1, 6).repeat(1, 2048, 1)], -1) 145 | 146 | """You may want to uncomment it when you need smplxparams!!!""" 147 | data_out['smplxparams'] = {} 148 | for key in ['transl', 'global_orient', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose', 'expression']: 149 | data_out['smplxparams'][key] = self.ds['smplxparams'][key][idx] 150 | 151 | ## random rotation augmentation 152 | bsz = 1 153 | theta = torch.FloatTensor(np.random.uniform(-np.pi/6, np.pi/6, bsz)) 154 | orient = torch.zeros((bsz, 3)) 155 | orient[:, -1] = theta 156 | rot_mats = batch_rodrigues(orient.view(-1, 3)).view([bsz, 3, 3]) 157 | if len(data_out['verts_object'].shape) == 3: 158 | data_out['markers'] = torch.matmul(data_out['markers'][:, :, :3], rot_mats.squeeze()) 159 | data_out['verts_object'] = torch.matmul(data_out['verts_object'][:, :, :3], rot_mats.squeeze()) 160 | data_out['normal_object'][:, :, :3] = torch.matmul(data_out['normal_object'][:, :, :3], rot_mats.squeeze()) 161 | else: 162 | data_out['markers'] = torch.matmul(data_out['markers'][:, :3], rot_mats.squeeze()) 163 | data_out['verts_object'] = torch.matmul(data_out['verts_object'][:, :3], rot_mats.squeeze()) 164 | data_out['normal_object'][:, :3] = torch.matmul(data_out['normal_object'][:, :3], rot_mats.squeeze()) 165 | 166 | return data_out 167 | -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/fittingop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/fittingop.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/models.cpython-38.pyc: 
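
A minimal sketch (not part of the repo) of wrapping the `LoadData` dataset above in a standard PyTorch `DataLoader` is shown below, assuming the dataset layout from the README and that the script is run from the repository root (the marker set is read from the relative path `body_utils/smplx_markerset.json`). The batch size and worker count follow `WholeGraspPose.yaml`; everything else is illustrative.

```python
# Hypothetical usage sketch of the stage-1 GraspPose dataloader; paths/settings are assumptions.
import torch
from torch.utils.data import DataLoader
from WholeGraspPose.data.dataloader import LoadData

train_set = LoadData(dataset_dir='./dataset/GraspPose',
                     ds_name='train',
                     gender='male',
                     object_class=['all'],
                     data_type='markers_143')

loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=8, drop_last=True)

batch = next(iter(loader))
print(batch['markers'].shape)         # (B, 143, 3) body markers, rotation-augmented
print(batch['verts_object'].shape)    # (B, 2048, 3) object point cloud
print(batch['contacts_object'].shape) # per-point binary contact labels on the object
```
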
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/models.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/objectmodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/objectmodel.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/pointnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/pointnet_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet_util.cpython-310.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/__pycache__/pointnet_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/WholeGraspPose/models/__pycache__/pointnet_util.cpython-38.pyc -------------------------------------------------------------------------------- /WholeGraspPose/models/models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | sys.path.append('..') 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from WholeGraspPose.models.pointnet import (PointNetFeaturePropagation, 10 | PointNetSetAbstraction) 11 | 12 | 13 | class ResBlock(nn.Module): 14 | def __init__(self, 15 | Fin, 16 | Fout, 17 | n_neurons=256): 18 | 19 | super(ResBlock, self).__init__() 20 | self.Fin = Fin 21 | self.Fout = Fout 22 | 23 | self.fc1 = nn.Linear(Fin, n_neurons) 24 | self.bn1 = nn.BatchNorm1d(n_neurons) 25 | 26 | self.fc2 = nn.Linear(n_neurons, Fout) 27 | self.bn2 = nn.BatchNorm1d(Fout) 28 | 29 | if Fin != Fout: 30 | self.fc3 = nn.Linear(Fin, Fout) 31 | 32 | self.ll = nn.LeakyReLU(negative_slope=0.2) 33 | 34 | def forward(self, x, final_nl=True): 35 | Xin = x if self.Fin == self.Fout else self.ll(self.fc3(x)) 36 | 37 | Xout = self.fc1(x) # n_neurons 38 | Xout = self.bn1(Xout) 39 | Xout = self.ll(Xout) 40 | 41 | Xout = self.fc2(Xout) 42 | Xout = self.bn2(Xout) 43 | Xout = Xin + Xout 44 | 45 | if final_nl: 46 | return self.ll(Xout) 47 | return Xout 48 | 49 | 50 | class PointNetEncoder(nn.Module): 51 | 52 | def __init__(self, 53 | hc, 54 | in_feature): 55 | 56 | super(PointNetEncoder, self).__init__() 57 | self.hc = hc 58 | self.in_feature = in_feature 59 | 60 | self.enc_sa1 = PointNetSetAbstraction(npoint=256, radius=0.2, nsample=32, in_channel=self.in_feature, mlp=[self.hc, self.hc*2], group_all=False) 61 | self.enc_sa2 = PointNetSetAbstraction(npoint=128, radius=0.25, nsample=64, in_channel=self.hc*2 + 3, mlp=[self.hc*2, self.hc*4], 
group_all=False) 62 | self.enc_sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=self.hc*4 + 3, mlp=[self.hc*4, self.hc*8], group_all=True) 63 | 64 | def forward(self, l0_xyz, l0_points): 65 | 66 | l1_xyz, l1_points = self.enc_sa1(l0_xyz, l0_points) 67 | l2_xyz, l2_points = self.enc_sa2(l1_xyz, l1_points) 68 | l3_xyz, l3_points = self.enc_sa3(l2_xyz, l2_points) 69 | x = l3_points.view(-1, self.hc*8) 70 | 71 | return l1_xyz, l1_points, l2_xyz, l2_points, l3_xyz, x 72 | 73 | class MarkerNet(nn.Module): 74 | def __init__(self, cfg, n_neurons=1024, in_cond=1024, latentD=16, in_feature=143*3, **kwargs): 75 | super(MarkerNet, self).__init__() 76 | 77 | self.cfg = cfg 78 | ## condition features 79 | self.obj_cond_feature = 1 if cfg.cond_object_height else 0 80 | 81 | self.enc_bn1 = nn.BatchNorm1d(in_feature + self.obj_cond_feature) 82 | self.enc_rb1 = ResBlock(in_cond + in_feature + int(in_feature/3) + self.obj_cond_feature, n_neurons) 83 | self.enc_rb2 = ResBlock(n_neurons + in_cond + in_feature + int(in_feature/3) + self.obj_cond_feature, n_neurons) 84 | 85 | self.dec_rb1 = ResBlock(latentD + in_cond, n_neurons) 86 | self.dec_rb2_xyz = ResBlock(n_neurons + latentD + in_cond + self.obj_cond_feature, n_neurons) 87 | self.dec_rb2_p = ResBlock(n_neurons + latentD + in_cond, n_neurons) 88 | 89 | self.dec_output_xyz = nn.Linear(n_neurons, 143*3) 90 | self.dec_output_p = nn.Linear(n_neurons, 143) 91 | self.p_output = nn.Sigmoid() 92 | 93 | def enc(self, cond_object, markers, contacts_markers, transf_transl): 94 | _, _, _, _, _, object_cond = cond_object 95 | 96 | X = markers.view(markers.shape[0], -1) 97 | 98 | if self.obj_cond_feature == 1: 99 | X = torch.cat([X, transf_transl[:, -1, None]], dim=-1).float() 100 | 101 | X0 = self.enc_bn1(X) 102 | 103 | X0 = torch.cat([X0, contacts_markers.view(-1, 143), object_cond], dim=-1) 104 | 105 | 106 | X = self.enc_rb1(X0, True) 107 | X = self.enc_rb2(torch.cat([X0, X], dim=1), True) 108 | 109 | return X 110 | 111 | def dec(self, Z, cond_object, transf_transl): 112 | 113 | _, _, _, _, _, object_cond = cond_object 114 | X0 = torch.cat([Z, object_cond], dim=1).float() 115 | 116 | X = self.dec_rb1(X0, True) 117 | X_xyz = self.dec_rb2_xyz(torch.cat([X0, X, transf_transl[:, -1, None]], dim=1).float(), True) 118 | X_p = self.dec_rb2_p(torch.cat([X0, X], dim=1).float(), True) 119 | 120 | xyz_pred = self.dec_output_xyz(X_xyz) 121 | p_pred = self.p_output(self.dec_output_p(X_p)) 122 | 123 | return xyz_pred, p_pred 124 | 125 | class ContactNet(nn.Module): 126 | def __init__(self, cfg, latentD=16, hc=64, object_feature=6, **kwargs): 127 | super(ContactNet, self).__init__() 128 | self.latentD = latentD 129 | self.hc = hc 130 | self.object_feature = object_feature 131 | 132 | self.enc_pointnet = PointNetEncoder(self.hc, self.object_feature+1) 133 | 134 | self.dec_fc1 = nn.Linear(self.latentD, self.hc*2) 135 | self.dec_bn1 = nn.BatchNorm1d(self.hc*2) 136 | self.dec_drop1 = nn.Dropout(0.1) 137 | self.dec_fc2 = nn.Linear(self.hc*2, self.hc*4) 138 | self.dec_bn2 = nn.BatchNorm1d(self.hc*4) 139 | self.dec_drop2 = nn.Dropout(0.1) 140 | self.dec_fc3 = nn.Linear(self.hc*4, self.hc*8) 141 | self.dec_bn3 = nn.BatchNorm1d(self.hc*8) 142 | self.dec_drop3 = nn.Dropout(0.1) 143 | 144 | self.dec_fc4 = nn.Linear(self.hc*8+self.latentD, self.hc*8) 145 | self.dec_bn4 = nn.BatchNorm1d(self.hc*8) 146 | self.dec_drop4 = nn.Dropout(0.1) 147 | 148 | self.dec_fp3 = PointNetFeaturePropagation(in_channel=self.hc*8+self.hc*4, mlp=[self.hc*8, self.hc*4]) 149 | 
self.dec_fp2 = PointNetFeaturePropagation(in_channel=self.hc*4+self.hc*2, mlp=[self.hc*4, self.hc*2]) 150 | self.dec_fp1 = PointNetFeaturePropagation(in_channel=self.hc*2+self.object_feature, mlp=[self.hc*2, self.hc*2]) 151 | 152 | self.dec_conv1 = nn.Conv1d(self.hc*2, self.hc*2, 1) 153 | self.dec_conv_bn1 = nn.BatchNorm1d(self.hc*2) 154 | self.dec_conv_drop1 = nn.Dropout(0.1) 155 | self.dec_conv2 = nn.Conv1d(self.hc*2, 1, 1) 156 | 157 | self.dec_output = nn.Sigmoid() 158 | 159 | def enc(self, contacts_object, verts_object, feat_object): 160 | l0_xyz = verts_object[:, :3, :] 161 | l0_points = torch.cat([feat_object, contacts_object], 1) if feat_object is not None else contacts_object 162 | _, _, _, _, _, x = self.enc_pointnet(l0_xyz, l0_points) 163 | 164 | return x 165 | 166 | def dec(self, z, cond_object, verts_object, feat_object): 167 | l0_xyz = verts_object[:, :3, :] 168 | l0_points = feat_object 169 | 170 | l1_xyz, l1_points, l2_xyz, l2_points, l3_xyz, l3_points = cond_object 171 | 172 | l3_points = torch.cat([l3_points, z], 1) 173 | l3_points = self.dec_drop4(F.relu(self.dec_bn4(self.dec_fc4(l3_points)), inplace=True)) 174 | l3_points = l3_points.view(l3_points.size()[0], l3_points.size()[1], 1) 175 | 176 | l2_points = self.dec_fp3(l2_xyz, l3_xyz, l2_points, l3_points) 177 | l1_points = self.dec_fp2(l1_xyz, l2_xyz, l1_points, l2_points) 178 | if l0_points is None: 179 | l0_points = self.dec_fp1(l0_xyz, l1_xyz, l0_xyz, l1_points) 180 | else: 181 | l0_points = self.dec_fp1(l0_xyz, l1_xyz, torch.cat([l0_xyz,l0_points],1), l1_points) 182 | feat = F.relu(self.dec_conv_bn1(self.dec_conv1(l0_points)), inplace=True) 183 | x = self.dec_conv_drop1(feat) 184 | x = self.dec_conv2(x) 185 | x = self.dec_output(x) 186 | 187 | return x 188 | 189 | 190 | class FullBodyGraspNet(nn.Module): 191 | def __init__(self, cfg, **kwargs): 192 | super(FullBodyGraspNet, self).__init__() 193 | 194 | self.cfg = cfg 195 | self.latentD = cfg.latentD 196 | self.in_feature_list = {} 197 | self.in_feature_list['joints'] = 127*3 198 | self.in_feature_list['markers_143'] = 143*3 199 | self.in_feature_list['markers_214'] = 214*3 200 | self.in_feature_list['markers_593'] = 593*3 201 | 202 | self.in_feature = self.in_feature_list[cfg.data_representation] 203 | 204 | self.marker_net = MarkerNet(cfg, n_neurons=cfg.n_markers, in_cond=cfg.pointnet_hc*8, latentD=cfg.latentD, in_feature=self.in_feature) 205 | self.contact_net = ContactNet(cfg, latentD=cfg.latentD, hc=cfg.pointnet_hc, object_feature=cfg.obj_feature) 206 | 207 | self.pointnet = PointNetEncoder(hc=cfg.pointnet_hc, in_feature=cfg.obj_feature) 208 | # encoder fusion 209 | self.enc_fusion = ResBlock(cfg.n_markers+self.cfg.pointnet_hc*8, cfg.n_neurons) 210 | 211 | self.enc_mu = nn.Linear(cfg.n_neurons, cfg.latentD) 212 | self.enc_var = nn.Linear(cfg.n_neurons, cfg.latentD) 213 | 214 | 215 | def encode(self, object_cond, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl): 216 | # marker branch 217 | marker_feat = self.marker_net.enc(object_cond, markers, contacts_markers, transf_transl) # [B, n_neurons=1024] 218 | 219 | # contact branch 220 | contact_feat = self.contact_net.enc(contacts_object, verts_object, feat_object) # [B, hc*8] 221 | 222 | # fusion 223 | X = torch.cat([marker_feat, contact_feat], dim=-1) 224 | X = self.enc_fusion(X, True) 225 | 226 | return torch.distributions.normal.Normal(self.enc_mu(X), F.softplus(self.enc_var(X))) 227 | 228 | 229 | def decode(self, Z, object_cond, verts_object, feat_object, transf_transl): 230 
| 231 | bs = Z.shape[0] 232 | # marker_branch 233 | markers_xyz_pred, markers_p_pred = self.marker_net.dec(Z, object_cond, transf_transl) 234 | 235 | # contact branch 236 | contact_pred = self.contact_net.dec(Z, object_cond, verts_object, feat_object) 237 | 238 | return markers_xyz_pred.view(bs, -1, 3), markers_p_pred, contact_pred 239 | 240 | def forward(self, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl, **kwargs): 241 | object_cond = self.pointnet(l0_xyz=verts_object, l0_points=feat_object) 242 | z = self.encode(object_cond, verts_object, feat_object, contacts_object, markers, contacts_markers, transf_transl) 243 | z_s = z.rsample() 244 | 245 | markers_xyz_pred, markers_p_pred, object_p_pred = self.decode(z_s, object_cond, verts_object, feat_object, transf_transl) 246 | 247 | results = {'markers': markers_xyz_pred, 'contacts_markers': markers_p_pred, 'contacts_object': object_p_pred, 'object_code': object_cond[-1], 'mean': z.mean, 'std': z.scale} 248 | 249 | return results 250 | 251 | 252 | def sample(self, verts_object, feat_object, transf_transl, seed=None): 253 | bs = verts_object.shape[0] 254 | if seed is not None: 255 | np.random.seed(seed) 256 | dtype = verts_object.dtype 257 | device = verts_object.device 258 | self.eval() 259 | with torch.no_grad(): 260 | Zgen = np.random.normal(0., 1., size=(bs, self.latentD)) 261 | Zgen = torch.tensor(Zgen,dtype=dtype).to(device) 262 | 263 | object_cond = self.pointnet(l0_xyz=verts_object, l0_points=feat_object) 264 | 265 | return self.decode(Zgen, object_cond, verts_object, feat_object, transf_transl) 266 | -------------------------------------------------------------------------------- /WholeGraspPose/models/objectmodel.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from smplx.lbs import batch_rodrigues 7 | 8 | model_output = namedtuple('output', ['vertices', 'vertex_normals', 'global_orient', 'transl']) 9 | 10 | class ObjectModel(nn.Module): 11 | 12 | def __init__(self, 13 | v_template, 14 | normal_template, 15 | batch_size=1, 16 | dtype=torch.float32): 17 | 18 | super(ObjectModel, self).__init__() 19 | 20 | 21 | self.dtype = dtype 22 | self.batch_size = batch_size 23 | 24 | 25 | def forward(self, global_orient=None, transl=None, v_template=None, n_template=None, rotmat=False, **kwargs): 26 | 27 | 28 | if global_orient is None: 29 | global_orient = self.global_orient 30 | if transl is None: 31 | transl = self.transl 32 | if v_template is None: 33 | v_template = self.v_template 34 | if n_template is None: 35 | n_template = self.n_template 36 | 37 | if not rotmat: 38 | rot_mats = batch_rodrigues(global_orient.view(-1, 3)).view([self.batch_size, 3, 3]) 39 | else: 40 | rot_mats = global_orient.view([self.batch_size, 3, 3]) 41 | 42 | vertices = torch.matmul(v_template, rot_mats) + transl.unsqueeze(dim=1) 43 | 44 | vertex_normals = torch.matmul(n_template, rot_mats) 45 | 46 | output = model_output(vertices=vertices, 47 | vertex_normals = vertex_normals, 48 | global_orient=global_orient, 49 | transl=transl) 50 | 51 | return output 52 | 53 | -------------------------------------------------------------------------------- /WholeGraspPose/models/pointnet.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional 
as F 7 | 8 | 9 | def timeit(tag, t): 10 | print("{}: {}s".format(tag, time() - t)) 11 | return time() 12 | 13 | def pc_normalize(pc): 14 | l = pc.shape[0] 15 | centroid = np.mean(pc, axis=0) 16 | pc = pc - centroid 17 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 18 | pc = pc / m 19 | return pc 20 | 21 | def square_distance(src, dst): 22 | """ 23 | Calculate Euclid distance between each two points. 24 | 25 | src^T * dst = xn * xm + yn * ym + zn * zm; 26 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 27 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 28 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 29 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 30 | 31 | Input: 32 | src: source points, [B, N, C] 33 | dst: target points, [B, M, C] 34 | Output: 35 | dist: per-point square distance, [B, N, M] 36 | """ 37 | B, N, _ = src.shape 38 | _, M, _ = dst.shape 39 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 40 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 41 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 42 | return dist 43 | 44 | 45 | def index_points(points, idx): 46 | """ 47 | 48 | Input: 49 | points: input points data, [B, N, C] 50 | idx: sample index data, [B, S] 51 | Return: 52 | new_points:, indexed points data, [B, S, C] 53 | """ 54 | device = points.device 55 | B = points.shape[0] 56 | view_shape = list(idx.shape) 57 | view_shape[1:] = [1] * (len(view_shape) - 1) 58 | repeat_shape = list(idx.shape) 59 | repeat_shape[0] = 1 60 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 61 | new_points = points[batch_indices, idx, :] 62 | return new_points 63 | 64 | 65 | def farthest_point_sample(xyz, npoint): 66 | """ 67 | Input: 68 | xyz: pointcloud data, [B, N, 3] 69 | npoint: number of samples 70 | Return: 71 | centroids: sampled pointcloud index, [B, npoint] 72 | """ 73 | device = xyz.device 74 | B, N, C = xyz.shape 75 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 76 | distance = torch.ones(B, N).to(device) * 1e10 77 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 78 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 79 | for i in range(npoint): 80 | centroids[:, i] = farthest 81 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 82 | dist = torch.sum((xyz - centroid) ** 2, -1) 83 | mask = dist < distance 84 | distance[mask] = dist[mask] 85 | farthest = torch.max(distance, -1)[1] 86 | return centroids 87 | 88 | 89 | def query_ball_point(radius, nsample, xyz, new_xyz): 90 | """ 91 | Input: 92 | radius: local region radius 93 | nsample: max sample number in local region 94 | xyz: all points, [B, N, 3] 95 | new_xyz: query points, [B, S, 3] 96 | Return: 97 | group_idx: grouped points index, [B, S, nsample] 98 | """ 99 | device = xyz.device 100 | B, N, C = xyz.shape 101 | _, S, _ = new_xyz.shape 102 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 103 | sqrdists = square_distance(new_xyz, xyz) 104 | group_idx[sqrdists > radius ** 2] = N 105 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 106 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 107 | mask = group_idx == N 108 | group_idx[mask] = group_first[mask] 109 | return group_idx 110 | 111 | 112 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 113 | """ 114 | Input: 115 | npoint: 116 | radius: 117 | nsample: 118 | xyz: input points position data, [B, N, 3] 119 | points: input points data, [B, N, D] 120 | Return: 121 | 
new_xyz: sampled points position data, [B, npoint, nsample, 3] 122 | new_points: sampled points data, [B, npoint, nsample, 3+D] 123 | """ 124 | B, N, C = xyz.shape 125 | S = npoint 126 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 127 | torch.cuda.empty_cache() 128 | new_xyz = index_points(xyz, fps_idx) 129 | torch.cuda.empty_cache() 130 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 131 | torch.cuda.empty_cache() 132 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 133 | torch.cuda.empty_cache() 134 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 135 | torch.cuda.empty_cache() 136 | 137 | if points is not None: 138 | grouped_points = index_points(points, idx) 139 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 140 | else: 141 | new_points = grouped_xyz_norm 142 | if returnfps: 143 | return new_xyz, new_points, grouped_xyz, fps_idx 144 | else: 145 | return new_xyz, new_points 146 | 147 | 148 | def sample_and_group_all(xyz, points): 149 | """ 150 | Input: 151 | xyz: input points position data, [B, N, 3] 152 | points: input points data, [B, N, D] 153 | Return: 154 | new_xyz: sampled points position data, [B, 1, 3] 155 | new_points: sampled points data, [B, 1, N, 3+D] 156 | """ 157 | device = xyz.device 158 | B, N, C = xyz.shape 159 | new_xyz = torch.zeros(B, 1, C).to(device) 160 | grouped_xyz = xyz.view(B, 1, N, C) 161 | if points is not None: 162 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 163 | else: 164 | new_points = grouped_xyz 165 | return new_xyz, new_points 166 | 167 | 168 | class PointNetSetAbstraction(nn.Module): 169 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 170 | super(PointNetSetAbstraction, self).__init__() 171 | self.npoint = npoint 172 | self.radius = radius 173 | self.nsample = nsample 174 | self.mlp_convs = nn.ModuleList() 175 | self.mlp_bns = nn.ModuleList() 176 | last_channel = in_channel 177 | for out_channel in mlp: 178 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 179 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 180 | last_channel = out_channel 181 | self.group_all = group_all 182 | 183 | def forward(self, xyz, points): 184 | """ 185 | Input: 186 | xyz: input points position data, [B, C, N] 187 | points: input points data, [B, D, N] 188 | Return: 189 | new_xyz: sampled points position data, [B, C, S] 190 | new_points_concat: sample points feature data, [B, D', S] 191 | """ 192 | xyz = xyz.permute(0, 2, 1) 193 | if points is not None: 194 | points = points.permute(0, 2, 1) 195 | # print('before sample:', points.shape) 196 | 197 | if self.group_all: 198 | new_xyz, new_points = sample_and_group_all(xyz, points) 199 | else: 200 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 201 | # new_xyz: sampled points position data, [B, npoint, C] 202 | # new_points: sampled points data, [B, npoint, nsample, C+D] 203 | # print('after sample:', new_points.shape) 204 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 205 | for i, conv in enumerate(self.mlp_convs): 206 | bn = self.mlp_bns[i] 207 | new_points = F.relu(bn(conv(new_points)), inplace=True) 208 | # print('after conv:', new_points.shape) 209 | new_points = torch.max(new_points, 2)[0] 210 | new_xyz = new_xyz.permute(0, 2, 1) 211 | return new_xyz, new_points 212 | 213 | 214 | class PointNetSetAbstractionMsg(nn.Module): 215 | def __init__(self, npoint, radius_list, 
nsample_list, in_channel, mlp_list): 216 | super(PointNetSetAbstractionMsg, self).__init__() 217 | self.npoint = npoint 218 | self.radius_list = radius_list 219 | self.nsample_list = nsample_list 220 | self.conv_blocks = nn.ModuleList() 221 | self.bn_blocks = nn.ModuleList() 222 | for i in range(len(mlp_list)): 223 | convs = nn.ModuleList() 224 | bns = nn.ModuleList() 225 | last_channel = in_channel + 3 226 | for out_channel in mlp_list[i]: 227 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 228 | bns.append(nn.BatchNorm2d(out_channel)) 229 | last_channel = out_channel 230 | self.conv_blocks.append(convs) 231 | self.bn_blocks.append(bns) 232 | 233 | def forward(self, xyz, points): 234 | """ 235 | Input: 236 | xyz: input points position data, [B, C, N] 237 | points: input points data, [B, D, N] 238 | Return: 239 | new_xyz: sampled points position data, [B, C, S] 240 | new_points_concat: sample points feature data, [B, D', S] 241 | """ 242 | xyz = xyz.permute(0, 2, 1) 243 | if points is not None: 244 | points = points.permute(0, 2, 1) 245 | 246 | B, N, C = xyz.shape 247 | S = self.npoint 248 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 249 | new_points_list = [] 250 | for i, radius in enumerate(self.radius_list): 251 | K = self.nsample_list[i] 252 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 253 | grouped_xyz = index_points(xyz, group_idx) 254 | grouped_xyz -= new_xyz.view(B, S, 1, C) 255 | if points is not None: 256 | grouped_points = index_points(points, group_idx) 257 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 258 | else: 259 | grouped_points = grouped_xyz 260 | 261 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 262 | for j in range(len(self.conv_blocks[i])): 263 | conv = self.conv_blocks[i][j] 264 | bn = self.bn_blocks[i][j] 265 | grouped_points = F.relu(bn(conv(grouped_points)), inplace=True) 266 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 267 | new_points_list.append(new_points) 268 | 269 | new_xyz = new_xyz.permute(0, 2, 1) 270 | new_points_concat = torch.cat(new_points_list, dim=1) 271 | return new_xyz, new_points_concat 272 | 273 | 274 | class PointNetFeaturePropagation(nn.Module): 275 | def __init__(self, in_channel, mlp): 276 | super(PointNetFeaturePropagation, self).__init__() 277 | self.mlp_convs = nn.ModuleList() 278 | self.mlp_bns = nn.ModuleList() 279 | last_channel = in_channel 280 | for out_channel in mlp: 281 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 282 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 283 | last_channel = out_channel 284 | 285 | def forward(self, xyz1, xyz2, points1, points2): 286 | """ 287 | Input: 288 | xyz1: input points position data, [B, C, N] 289 | xyz2: sampled input points position data, [B, C, S] 290 | points1: input points data, [B, D, N] 291 | points2: input points data, [B, D, S] 292 | Return: 293 | new_points: upsampled points data, [B, D', N] 294 | """ 295 | xyz1 = xyz1.permute(0, 2, 1) 296 | xyz2 = xyz2.permute(0, 2, 1) 297 | 298 | points2 = points2.permute(0, 2, 1) 299 | B, N, C = xyz1.shape 300 | _, S, _ = xyz2.shape 301 | 302 | if S == 1: 303 | interpolated_points = points2.repeat(1, N, 1) 304 | else: 305 | dists = square_distance(xyz1, xyz2) 306 | dists, idx = dists.sort(dim=-1) 307 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 308 | 309 | dist_recip = 1.0 / (dists + 1e-8) 310 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 311 | weight = dist_recip / norm 312 | interpolated_points = 
torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 313 | 314 | if points1 is not None: 315 | points1 = points1.permute(0, 2, 1) 316 | new_points = torch.cat([points1, interpolated_points], dim=-1) 317 | else: 318 | new_points = interpolated_points 319 | 320 | new_points = new_points.permute(0, 2, 1) 321 | for i, conv in enumerate(self.mlp_convs): 322 | bn = self.mlp_bns[i] 323 | new_points = F.relu(bn(conv(new_points)), inplace=True) 324 | return new_points 325 | 326 | -------------------------------------------------------------------------------- /body_utils/body_models/VPoser/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/.DS_Store -------------------------------------------------------------------------------- /body_utils/body_models/VPoser/vposerDecoderWeights.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerDecoderWeights.npz -------------------------------------------------------------------------------- /body_utils/body_models/VPoser/vposerEncoderWeights.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerEncoderWeights.npz -------------------------------------------------------------------------------- /body_utils/body_models/VPoser/vposerMeanPose.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/body_models/VPoser/vposerMeanPose.npz -------------------------------------------------------------------------------- /body_utils/body_segments/L_Hand.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [3320, 3333, 3335, 3347, 3352, 3357, 3369, 3376, 3389, 3392, 3427, 3445, 3446, 3450, 3453, 3454, 3456, 3458, 3459, 3463, 3468, 3471, 3477, 3478, 3484, 3488, 3498, 3504, 3514, 3526, 3527, 3532, 3535, 3546, 3549, 3551, 3560, 3561, 3566, 3576, 3578, 3591, 3593, 3594, 3595, 3596, 3597, 3601, 3613, 3614, 3615, 3620, 3625, 3626, 3627, 3628, 3632, 3634, 3638, 3640, 3646, 3647, 3650, 3652, 3653, 3654, 3655, 3659, 3664, 3668, 3681, 3685, 3686, 3687, 3690, 3695, 3698, 3708, 3709, 3714, 3715, 3741, 3753, 3756, 3771, 3788, 3823, 3825, 3826, 3831, 3833, 3990, 3991, 3996, 3998, 3999, 4004, 4006, 4011, 4083, 4093, 4130, 4131, 4136, 4145, 4152, 4155, 4166, 4273, 4342, 4359, 4360, 4366, 4376, 4378, 4384, 4406, 4649, 4653, 4696, 4700, 4707, 4708, 4716, 4722, 4773, 4777, 4794, 4804, 4825, 4837, 4948, 4963, 5019, 5022, 5026, 5033, 5042, 5046, 5091, 5133, 5150, 5156, 5168, 5178, 5179, 5180, 5193, 5196, 5224, 5247, 5252, 5258, 5304, 5306, 5365, 5550, 5558, 5564, 5576, 5589, 5616, 5646, 5647, 5648, 5661, 5668, 5670, 5677, 5716, 5721, 5767, 5774, 5789, 5802, 5803, 5809, 5878, 5911, 13854, 13867, 13869, 13871, 13881, 13886, 13891, 13903, 13909, 13922, 13925, 13954, 13960, 13977, 13978, 13982, 13985, 13986, 13988, 13990, 13991, 13995, 14000, 14003, 14009, 14010, 14016, 14020, 14030, 14036, 14046, 14058, 14059, 14064, 14067, 14078, 14081, 14083, 14092, 14093, 14098, 14108, 14110, 14123, 14125, 14126, 14127, 14128, 14129, 
14133, 14145, 14146, 14147, 14152, 14157, 14158, 14159, 14160, 14164, 14166, 14170, 14172, 14178, 14182, 14184, 14185, 14186, 14187, 14189, 14191, 14196, 14198, 14200, 14213, 14217, 14218, 14219, 14222, 14223, 14224, 14227, 14230, 14240, 14241, 14246, 14247, 14273, 14285, 14288, 14303, 14320, 14355, 14357, 14358, 14363, 14365, 14522, 14523, 14528, 14531, 14536, 14538, 14543, 14615, 14625, 14662, 14663, 14668, 14684, 14687, 14698, 14804, 14873, 14890, 14891, 14897, 14907, 14915, 14937, 15180, 15184, 15202, 15227, 15231, 15238, 15239, 15247, 15253, 15308, 15324, 15325, 15335, 15356, 15368, 15494, 15550, 15553, 15557, 15573, 15620, 15662, 15679, 15685, 15697, 15707, 15708, 15709, 15722, 15725, 15753, 15772, 15776, 15781, 15787, 15831, 15833, 15835, 15840, 15894, 16079, 16087, 16093, 16105, 16118, 16145, 16175, 16176, 16177, 16190, 16197, 16199, 16206, 16249, 16295, 16302, 16317, 16330, 16331, 16337, 16406, 16439], "verts_ind": [4595, 4596, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606, 4607, 4627, 4628, 4629, 4630, 4631, 4634, 4637, 4639, 4640, 4641, 4642, 4643, 4644, 4656, 4657, 4658, 4659, 4660, 4669, 4670, 4671, 4672, 4688, 4691, 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4702, 4704, 4707, 4708, 4709, 4710, 4711, 4727, 4728, 4729, 4730, 4731, 4732, 4734, 4735, 4736, 4741, 4743, 4744, 4745, 4746, 4753, 4754, 4755, 4756, 4758, 4770, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4800, 4801, 4802, 4803, 4804, 4805, 4836, 4846, 4847, 4850, 4853, 4854, 4859, 4860, 4861, 4862, 4863, 4870, 4877, 4878, 4880, 4881, 4885, 4892, 4894, 4895, 4898, 4903, 4913, 4916, 4940, 4941, 4943, 4944, 4945, 4955, 4956, 4957, 4958, 4959, 4962, 4963, 4965, 4971, 4972, 4973, 4974, 4985, 4987, 4988, 4989, 4992, 4993, 4994, 4995, 5000, 5001, 5002, 5011, 5012, 5013, 5017, 5018, 5025, 5028, 5049, 5050, 5051, 5052, 5053, 5054, 5055, 5056, 5057, 5067, 5068, 5069, 5070, 5071, 5074, 5075, 5077, 5083, 5084, 5085, 5086, 5099, 5102, 5103, 5104, 5105, 5111, 5112, 5115, 5116, 5121, 5122, 5128, 5129, 5138, 5139, 5164, 5165, 5166, 5168, 5178, 5179, 5180, 5181, 5182, 5185, 5188, 5192, 5194, 5195, 5196, 5197, 5198, 5209, 5210, 5211, 5212, 5215, 5216, 5217, 5218, 5219, 5220, 5221, 5222, 5229, 5230, 5233, 5234, 5235, 5236, 5237, 5239, 5240, 5245, 5246, 5255, 5256, 5282, 5283, 5295, 5296, 5297, 5298, 5299, 5300, 5301, 5302, 5303, 5305, 5314, 5315, 5316, 5318, 5325, 5326, 5347, 5348, 5351, 5353, 5354, 5355, 5356, 5357, 5358, 5359, 5360, 5370, 5371, 5372, 5373, 5380, 5386, 5387, 5388, 5389, 5390, 5391, 5392, 5393]} -------------------------------------------------------------------------------- /body_utils/body_segments/L_Leg.json: -------------------------------------------------------------------------------- 1 | {"verts_ind": [5774, 5781, 5789, 5790, 5791, 5792, 5793, 5794, 5797, 5805, 5806, 5807, 5808, 5813, 5814, 5815, 5816, 5817, 5818, 5824, 5827, 5830, 5831, 5832, 5839, 5840, 5842, 5843, 5844, 5847, 5850, 5851, 5854, 5855, 5859, 5861, 5862, 5864, 5865, 5869, 5902, 5906, 5907, 5908, 5909, 5910, 5911, 5912, 5913, 5914, 5915, 5916, 5917, 8866, 8867, 8868, 8879, 8880, 8881, 8882, 8883, 8884, 8888, 8889, 8890, 8891, 8897, 8898, 8899, 8900, 8901, 8902, 8903, 8904, 8905, 8906, 8907, 8908, 8909, 8910, 8911, 8912, 8913, 8914, 8915, 8916, 8917, 8919, 8920, 8921, 8922, 8923, 8924, 8925, 8929, 8930, 8934], "faces_ind": [4257, 4259, 4274, 4283, 4284, 4286, 4288, 4307, 4309, 4311, 4314, 4320, 4340, 4354, 4371, 4422, 4462, 4463, 4464, 4467, 4469, 4470, 4471, 4472, 4478, 4530, 4545, 4546, 4547, 4549, 4550, 4551, 4552, 4553, 4554, 4568, 4581, 4616, 
4618, 4874, 4875, 4876, 4978, 5004, 5072, 5074, 5075, 5076, 5077, 5204, 5263, 5264, 5265, 5268, 5324, 5353, 5375, 5376, 5386, 5390, 5627, 5684, 5694, 5712, 5783, 5785, 5860, 5867, 14790, 14805, 14808, 14814, 14815, 14817, 14819, 14834, 14840, 14845, 14858, 14859, 14871, 14922, 14953, 14993, 14994, 14995, 14998, 15000, 15001, 15002, 15003, 15009, 15061, 15076, 15077, 15078, 15080, 15081, 15082, 15083, 15084, 15085, 15099, 15112, 15147, 15149, 15405, 15406, 15407, 15434, 15509, 15535, 15602, 15604, 15605, 15606, 15607, 15733, 15793, 15794, 15797, 15853, 15904, 15905, 15915, 15919, 16156, 16213, 16240, 16311, 16313, 16388, 16394, 16395, 16415]} -------------------------------------------------------------------------------- /body_utils/body_segments/R_Hand.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [6679, 6687, 6692, 6700, 6712, 6713, 6717, 6718, 6719, 6720, 6730, 6732, 6733, 6735, 6736, 6739, 6749, 6750, 6759, 6762, 6790, 6798, 6801, 6803, 6818, 6820, 6821, 6823, 6824, 6830, 6831, 6847, 6854, 6857, 6864, 6874, 6878, 6879, 6882, 6887, 6889, 6894, 6898, 6910, 6942, 6943, 6944, 6945, 6946, 6948, 6951, 6952, 6953, 6954, 6958, 6962, 6963, 6964, 6971, 6972, 6975, 6980, 6987, 6991, 6992, 6999, 7005, 7007, 7019, 7059, 7060, 7061, 7062, 7063, 7070, 7071, 7074, 7079, 7080, 7081, 7082, 7088, 7091, 7100, 7106, 7110, 7117, 7123, 7125, 7137, 7143, 7177, 7178, 7179, 7180, 7183, 7186, 7187, 7188, 7189, 7197, 7198, 7199, 7206, 7209, 7214, 7218, 7220, 7225, 7226, 7229, 7233, 7239, 7241, 7244, 7252, 7253, 7260, 7280, 7293, 7294, 7295, 7296, 7297, 7298, 7299, 7302, 7303, 7304, 7305, 7308, 7309, 7313, 7314, 7315, 7316, 7321, 7322, 7325, 7330, 7331, 7333, 7337, 7364, 7365, 7366, 7379, 7380, 7381, 7382, 7383, 7384, 7385, 7391, 7392, 7400, 7401, 7402, 7403, 7408, 7411, 7412, 7414, 7415, 7416, 7417, 7418, 7419, 7420, 7421, 7422, 7423, 7424, 7426, 7428, 7429, 7437, 7440, 7453, 7462, 7463, 7468, 7479, 7480, 7481, 17207, 17215, 17219, 17220, 17228, 17240, 17241, 17245, 17246, 17247, 17248, 17258, 17260, 17263, 17264, 17277, 17278, 17287, 17290, 17317, 17325, 17328, 17330, 17345, 17347, 17348, 17350, 17351, 17357, 17373, 17380, 17383, 17385, 17390, 17400, 17404, 17405, 17408, 17413, 17415, 17420, 17424, 17435, 17467, 17468, 17469, 17470, 17471, 17473, 17476, 17477, 17478, 17479, 17482, 17483, 17487, 17488, 17489, 17496, 17497, 17500, 17502, 17505, 17512, 17516, 17517, 17524, 17530, 17532, 17544, 17584, 17585, 17586, 17587, 17588, 17595, 17596, 17599, 17604, 17605, 17606, 17613, 17616, 17625, 17631, 17635, 17642, 17648, 17650, 17662, 17669, 17702, 17703, 17704, 17705, 17706, 17708, 17711, 17713, 17714, 17722, 17723, 17724, 17731, 17734, 17739, 17743, 17745, 17746, 17750, 17751, 17754, 17758, 17764, 17766, 17769, 17777, 17778, 17785, 17818, 17819, 17820, 17821, 17822, 17824, 17827, 17828, 17829, 17830, 17833, 17834, 17838, 17839, 17840, 17841, 17846, 17847, 17850, 17855, 17856, 17858, 17862, 17889, 17890, 17891, 17904, 17905, 17906, 17907, 17908, 17909, 17910, 17916, 17917, 17925, 17926, 17927, 17933, 17936, 17937, 17939, 17940, 17941, 17942, 17943, 17944, 17945, 17946, 17947, 17948, 17951, 17954, 17964, 17969, 17977, 17986, 17987, 17992, 18003, 18004], "verts_ind": [7335, 7336, 7337, 7338, 7342, 7360, 7363, 7364, 7365, 7366, 7367, 7370, 7373, 7375, 7377, 7378, 7379, 7380, 7394, 7395, 7405, 7406, 7407, 7408, 7424, 7427, 7428, 7429, 7430, 7431, 7432, 7433, 7434, 7438, 7440, 7441, 7442, 7443, 7444, 7445, 7446, 7447, 7463, 7464, 7465, 7466, 7470, 7471, 7472, 7477, 
7478, 7479, 7480, 7481, 7482, 7489, 7491, 7492, 7494, 7506, 7507, 7508, 7509, 7510, 7511, 7512, 7513, 7536, 7537, 7540, 7541, 7565, 7566, 7572, 7582, 7583, 7586, 7589, 7590, 7598, 7599, 7604, 7605, 7606, 7612, 7613, 7614, 7615, 7616, 7617, 7621, 7625, 7628, 7630, 7631, 7634, 7639, 7651, 7652, 7679, 7680, 7681, 7683, 7691, 7692, 7693, 7694, 7695, 7696, 7697, 7698, 7699, 7700, 7701, 7702, 7704, 7705, 7706, 7707, 7708, 7709, 7710, 7721, 7725, 7728, 7729, 7730, 7737, 7738, 7747, 7748, 7753, 7754, 7764, 7789, 7790, 7791, 7793, 7803, 7804, 7805, 7806, 7807, 7810, 7811, 7813, 7814, 7816, 7817, 7818, 7819, 7820, 7821, 7822, 7835, 7838, 7839, 7840, 7847, 7848, 7852, 7857, 7858, 7864, 7865, 7873, 7874, 7875, 7900, 7901, 7902, 7904, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7924, 7927, 7928, 7930, 7931, 7932, 7933, 7934, 7945, 7946, 7947, 7948, 7951, 7952, 7953, 7954, 7955, 7956, 7957, 7958, 7965, 7966, 7969, 7970, 7971, 7972, 7975, 7976, 7981, 7982, 7989, 7991, 7992, 8016, 8017, 8018, 8019, 8020, 8021, 8031, 8032, 8033, 8034, 8035, 8036, 8037, 8038, 8039, 8040, 8041, 8042, 8044, 8045, 8046, 8050, 8051, 8052, 8054, 8055, 8062, 8065, 8087, 8088, 8089, 8090, 8091, 8092, 8093, 8094, 8104, 8105, 8106, 8107, 8108, 8111, 8112, 8114, 8117, 8118, 8119, 8120, 8121, 8122, 8123, 8124, 8125, 8126, 8127]} -------------------------------------------------------------------------------- /body_utils/body_segments/R_Leg.json: -------------------------------------------------------------------------------- 1 | {"verts_ind": [8468, 8484, 8485, 8487, 8488, 8499, 8500, 8501, 8502, 8507, 8508, 8509, 8510, 8511, 8512, 8521, 8522, 8523, 8524, 8525, 8526, 8527, 8530, 8533, 8534, 8536, 8537, 8538, 8541, 8543, 8544, 8545, 8546, 8548, 8549, 8555, 8556, 8558, 8559, 8563, 8600, 8601, 8602, 8603, 8604, 8605, 8606, 8607, 8608, 8609, 8610, 8611, 8654, 8655, 8656, 8667, 8668, 8669, 8670, 8671, 8672, 8676, 8677, 8678, 8679, 8685, 8686, 8687, 8688, 8689, 8690, 8691, 8692, 8693, 8694, 8695, 8696, 8697, 8698, 8699, 8700, 8701, 8702, 8703, 8704, 8705, 8706, 8707, 8708, 8709, 8710, 8711, 8712, 8713, 8714, 8715, 8716], "faces_ind": [8100, 8105, 8107, 8108, 8112, 8113, 8117, 8122, 8123, 8132, 8133, 8139, 8149, 8157, 8171, 8182, 8188, 8192, 8227, 8229, 8230, 8231, 8232, 8295, 8296, 8297, 8298, 8299, 8300, 8301, 8302, 8303, 8304, 8305, 8306, 8307, 8308, 8309, 8310, 8311, 8312, 8313, 8314, 8315, 8316, 8317, 8318, 8319, 8320, 8323, 8324, 8325, 8326, 8327, 8328, 8329, 8330, 8331, 8333, 8334, 8335, 8336, 8337, 8558, 8559, 8560, 8561, 8562, 18623, 18628, 18630, 18631, 18635, 18636, 18640, 18645, 18646, 18648, 18656, 18662, 18666, 18672, 18680, 18694, 18703, 18705, 18711, 18750, 18752, 18753, 18754, 18755, 18818, 18819, 18820, 18821, 18822, 18823, 18824, 18825, 18826, 18827, 18828, 18829, 18830, 18831, 18832, 18833, 18834, 18835, 18836, 18837, 18838, 18839, 18840, 18841, 18842, 18843, 18847, 18848, 18849, 18850, 18851, 18852, 18853, 18854, 18856, 18857, 18858, 18859, 18860, 19081, 19082, 19083, 19084]} -------------------------------------------------------------------------------- /body_utils/body_segments/back.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [1517, 1535, 1593, 1598, 3126, 3195, 3264, 3269, 3272, 3507, 3552, 3739, 3846, 3861, 3866, 3868, 3869, 3871, 3872, 3881, 3882, 3891, 3892, 3893, 3894, 3895, 3896, 3918, 3921, 3922, 3977, 3978, 3981, 4010, 4043, 4099, 4100, 4191, 4224, 4229, 4302, 4573, 4656, 4665, 4675, 4677, 4737, 4748, 4756, 4779, 4787, 4791, 4796, 4800, 
4805, 4807, 4811, 4818, 4819, 4820, 4832, 4838, 4844, 4862, 4949, 5009, 5010, 5038, 5055, 5066, 5085, 5121, 5253, 5388, 5599, 5630, 5687, 5691, 5698, 5763, 5799, 5833, 5859, 5862, 5888, 5896, 5965, 5972, 5973, 6024, 6025, 6026, 6159, 6169, 6174, 6175, 6343, 6543, 6552, 6553, 7518, 7526, 7530, 7534, 7648, 7655, 7657, 7681, 7682, 7683, 7684, 7685, 7686, 7687, 7703, 7704, 7705, 7712, 7723, 7724, 7725, 7726, 7727, 7737, 7738, 7739, 7740, 7741, 7742, 7743, 7744, 7745, 7771, 7772, 7773, 7783, 7784, 7786, 7787, 7853, 7854, 7855, 7856, 7931, 7932, 7933, 8038, 8047, 8049, 8447, 8450, 8451, 8465, 8466, 8477, 8492, 8503, 8504, 8505, 8509, 8510, 8511, 8527, 8541, 8591, 8617, 8623, 8662, 8671, 8672, 8673, 8676, 8687, 8749, 8755, 12066, 12084, 12142, 12147, 13660, 13729, 13798, 13803, 13806, 14084, 14271, 14378, 14393, 14398, 14400, 14404, 14413, 14414, 14417, 14423, 14424, 14425, 14426, 14428, 14453, 14454, 14509, 14510, 14513, 14542, 14575, 14631, 14632, 14723, 14755, 14760, 14833, 15104, 15187, 15196, 15206, 15208, 15268, 15273, 15279, 15287, 15310, 15318, 15322, 15327, 15331, 15336, 15338, 15342, 15349, 15350, 15351, 15363, 15369, 15375, 15393, 15480, 15540, 15541, 15569, 15585, 15596, 15614, 15650, 15782, 15917, 15938, 16128, 16159, 16216, 16220, 16227, 16233, 16291, 16327, 16361, 16362, 16387, 16390, 16416, 16493, 16500, 16501, 16552, 16553, 16554, 16687, 16697, 16702, 16703, 16815, 16871, 17071, 17080, 17081, 18042, 18050, 18054, 18058, 18172, 18179, 18181, 18182, 18205, 18206, 18207, 18208, 18209, 18210, 18211, 18227, 18228, 18229, 18236, 18248, 18249, 18250, 18251, 18262, 18263, 18264, 18265, 18266, 18267, 18268, 18269, 18295, 18296, 18297, 18307, 18308, 18310, 18311, 18377, 18378, 18379, 18380, 18455, 18456, 18457, 18561, 18570, 18572, 18969, 18972, 18973, 18988, 18999, 19014, 19025, 19026, 19027, 19031, 19032, 19033, 19049, 19063, 19139, 19184, 19194, 19195, 19198, 19209, 19271, 19277], "verts_ind": [3336, 3337, 3338, 3339, 3352, 3357, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3381, 3384, 3398, 3399, 3426, 3428, 3431, 3516, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3823, 3844, 3845, 3846, 3848, 3872, 3873, 3883, 3886, 3887, 3888, 3891, 3892, 3893, 4068, 4069, 4070, 4071, 4126, 4127, 4128, 4129, 4430, 4431, 4432, 4433, 4443, 4444, 4445, 4452, 4453, 4454, 4455, 4456, 4457, 5402, 5403, 5404, 5405, 5411, 5412, 5413, 5414, 5420, 5421, 5422, 5427, 5428, 5460, 5461, 5462, 5483, 5485, 5498, 5516, 5521, 5523, 5524, 5525, 5526, 5530, 5535, 5536, 5537, 5538, 5545, 5546, 5560, 5561, 5562, 5563, 5564, 5565, 5566, 5567, 5568, 5569, 5598, 5610, 5611, 5631, 5699, 6099, 6100, 6101, 6102, 6115, 6121, 6122, 6123, 6124, 6125, 6126, 6127, 6128, 6142, 6145, 6159, 6160, 6187, 6189, 6192, 6194, 6277, 6279, 6280, 6281, 6282, 6283, 6284, 6285, 6286, 6287, 6580, 6599, 6600, 6601, 6603, 6604, 6623, 6624, 6633, 6636, 6637, 6638, 6639, 6640, 6641, 6812, 6813, 6814, 6870, 6871, 6872, 6873, 7166, 7167, 7168, 7169, 7179, 7180, 7188, 7189, 7190, 7192, 8136, 8137, 8138, 8139, 8145, 8146, 8147, 8148, 8154, 8155, 8156, 8161, 8162, 8192, 8194, 8195, 8196, 8197, 8218, 8223, 8238, 8241, 8242, 8243, 8244, 8245, 8246, 8247, 8248, 8249, 8258, 8259, 8260, 8272, 8274, 8275, 8276, 8277, 8278, 8279, 8280, 8281, 8307, 8308, 8316, 8325, 8393]} -------------------------------------------------------------------------------- /body_utils/body_segments/butt.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [1573, 3081, 3125, 3483, 4047, 4049, 4086, 4189, 4190, 4194, 4198, 4199, 
4213, 4215, 4237, 4293, 4402, 4511, 4592, 4615, 4745, 4851, 4856, 4859, 4912, 4929, 5124, 5276, 5313, 5339, 5343, 5354, 5414, 5428, 5600, 5605, 5625, 5685, 5775, 5781, 5843, 5874, 6007, 6008, 6009, 6022, 6173, 7899, 7902, 7903, 7908, 7909, 7912, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7923, 7924, 7925, 7926, 7927, 7928, 8012, 8013, 8014, 8015, 8016, 8050, 8061, 8430, 8431, 8441, 8442, 8448, 8521, 8522, 8523, 8524, 8525, 8526, 8539, 8564, 8606, 8732, 12122, 13615, 13659, 13916, 14015, 14579, 14581, 14618, 14721, 14722, 14726, 14730, 14731, 14744, 14746, 14768, 14824, 14933, 15042, 15123, 15146, 15276, 15382, 15390, 15443, 15460, 15653, 15805, 15842, 15868, 15872, 15883, 15943, 15957, 16129, 16134, 16154, 16214, 16303, 16309, 16371, 16402, 16535, 16537, 16550, 16701, 18426, 18427, 18432, 18433, 18436, 18438, 18439, 18440, 18441, 18442, 18443, 18444, 18445, 18446, 18447, 18448, 18449, 18450, 18451, 18452, 18536, 18537, 18538, 18539, 18540, 18573, 18584, 18585, 18952, 18953, 18963, 18964, 18970, 19043, 19044, 19045, 19046, 19047, 19048, 19061, 19086, 19128, 19254], "verts_ind": [3462, 3463, 3464, 3465, 3466, 3468, 3469, 3470, 3471, 3472, 3473, 3483, 3501, 3512, 3513, 3514, 3515, 3867, 3884, 3885, 5574, 5575, 5596, 5613, 5614, 5659, 5661, 5665, 5666, 5667, 5673, 5674, 5675, 5676, 5678, 5680, 5681, 5682, 5683, 5684, 5685, 5686, 5687, 5688, 5689, 5690, 5691, 5692, 5693, 5694, 5695, 5696, 5712, 5713, 5714, 5715, 5934, 6223, 6224, 6225, 6226, 6227, 6229, 6230, 6231, 6232, 6233, 6234, 6273, 6274, 6275, 6276, 6634, 6635, 6651, 7145, 8353, 8354, 8355, 8359, 8360, 8361, 8362, 8363, 8365, 8366, 8367, 8368, 8369, 8370, 8372, 8374, 8375, 8376, 8377, 8378, 8379, 8380, 8381, 8382, 8383, 8384, 8385, 8386, 8387, 8388, 8389, 8390, 8405, 8406, 8407, 8408, 8409]} -------------------------------------------------------------------------------- /body_utils/body_segments/gluteus.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [1573, 3081, 3125, 3483, 4047, 4049, 4086, 4189, 4190, 4194, 4198, 4199, 4213, 4215, 4237, 4293, 4402, 4511, 4592, 4615, 4745, 4851, 4856, 4859, 4912, 4929, 5124, 5276, 5313, 5339, 5343, 5354, 5414, 5428, 5600, 5605, 5625, 5685, 5775, 5781, 5843, 5874, 6007, 6008, 6009, 6022, 6173, 7899, 7902, 7903, 7908, 7909, 7912, 7914, 7915, 7916, 7917, 7918, 7919, 7920, 7921, 7922, 7923, 7924, 7925, 7926, 7927, 7928, 8012, 8013, 8014, 8015, 8016, 8050, 8061, 8430, 8431, 8441, 8442, 8448, 8521, 8522, 8523, 8524, 8525, 8526, 8539, 8564, 8606, 8732, 12122, 13615, 13659, 13916, 14015, 14579, 14581, 14618, 14721, 14722, 14726, 14730, 14731, 14744, 14746, 14768, 14824, 14933, 15042, 15123, 15146, 15276, 15382, 15390, 15443, 15460, 15653, 15805, 15842, 15868, 15872, 15883, 15943, 15957, 16129, 16134, 16154, 16214, 16303, 16309, 16371, 16402, 16535, 16537, 16550, 16701, 18426, 18427, 18432, 18433, 18436, 18438, 18439, 18440, 18441, 18442, 18443, 18444, 18445, 18446, 18447, 18448, 18449, 18450, 18451, 18452, 18536, 18537, 18538, 18539, 18540, 18573, 18584, 18585, 18952, 18953, 18963, 18964, 18970, 19043, 19044, 19045, 19046, 19047, 19048, 19061, 19086, 19128, 19254], "verts_ind": [3462, 3463, 3464, 3465, 3466, 3468, 3469, 3470, 3471, 3472, 3473, 3483, 3501, 3512, 3513, 3514, 3515, 3867, 3884, 3885, 5574, 5575, 5596, 5613, 5614, 5659, 5661, 5665, 5666, 5667, 5673, 5674, 5675, 5676, 5678, 5680, 5681, 5682, 5683, 5684, 5685, 5686, 5687, 5688, 5689, 5690, 5691, 5692, 5693, 5694, 5695, 5696, 5712, 5713, 5714, 5715, 5934, 6223, 6224, 6225, 6226, 
6227, 6229, 6230, 6231, 6232, 6233, 6234, 6273, 6274, 6275, 6276, 6634, 6635, 6651, 7145, 8353, 8354, 8355, 8359, 8360, 8361, 8362, 8363, 8365, 8366, 8367, 8368, 8369, 8370, 8372, 8374, 8375, 8376, 8377, 8378, 8379, 8380, 8381, 8382, 8383, 8384, 8385, 8386, 8387, 8388, 8389, 8390, 8405, 8406, 8407, 8408, 8409]} -------------------------------------------------------------------------------- /body_utils/body_segments/thighs.json: -------------------------------------------------------------------------------- 1 | {"faces_ind": [1615, 1620, 1627, 3172, 3173, 3174, 3179, 3286, 3386, 4056, 4058, 4070, 4088, 4847, 4886, 4954, 5096, 5296, 5307, 5735, 5820, 5869, 5908, 6055, 6063, 6064, 6075, 6301, 6302, 6304, 6310, 6311, 6312, 6338, 7941, 7952, 7953, 7954, 7958, 7959, 8453, 8531, 12164, 12169, 12176, 13706, 13707, 13708, 13713, 13820, 13919, 14588, 14602, 14620, 15200, 15378, 15417, 15485, 15625, 15825, 15836, 16348, 16397, 16436, 16583, 16591, 16592, 16603, 16829, 16830, 16832, 16838, 16839, 16840, 16866, 18464, 18465, 18476, 18477, 18478, 18482, 18483, 18975, 19053], "verts_ind": [3530, 3531, 3533, 3592, 3596, 3599, 3600, 3601, 3602, 3611, 3612, 3613, 3614, 3615, 3616, 3617, 3618, 3619, 3620, 3622, 3623, 3653, 3654, 3656, 3801, 3802, 3803, 4091, 4092, 4093, 4094, 4095, 4096, 6290, 6291, 6327, 6353, 6356, 6357, 6358, 6359, 6376, 6377, 6378, 6379, 6380, 6381, 6382, 6383, 6384, 6385, 6416, 6417, 6559, 6560, 6561, 6835, 6836, 6837, 6838, 6839, 6840]} -------------------------------------------------------------------------------- /body_utils/left_heel_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_heel_verts_id.npy -------------------------------------------------------------------------------- /body_utils/left_toe_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_toe_verts_id.npy -------------------------------------------------------------------------------- /body_utils/left_whole_foot_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/left_whole_foot_verts_id.npy -------------------------------------------------------------------------------- /body_utils/right_heel_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_heel_verts_id.npy -------------------------------------------------------------------------------- /body_utils/right_toe_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_toe_verts_id.npy -------------------------------------------------------------------------------- /body_utils/right_whole_foot_verts_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/right_whole_foot_verts_id.npy -------------------------------------------------------------------------------- 
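Note: the six *_verts_id.npy files above (left/right heel, toe, and whole foot) are not documented in this dump. A reasonable assumption is that each stores a 1-D integer array of SMPL-X vertex indices for the corresponding foot region, the usual form such masks take for foot-ground contact checks. A minimal, hypothetical loading sketch under that assumption (the function name and the contact heuristic are illustrative, not taken from this repository):

import numpy as np
import torch

# Assumed: each .npy file holds integer SMPL-X vertex indices for one foot region.
left_heel_ids = np.load('body_utils/left_heel_verts_id.npy')
right_toe_ids = np.load('body_utils/right_toe_verts_id.npy')

def foot_region_height(body_verts, verts_id):
    # body_verts: [B, 10475, 3] posed SMPL-X vertices (SMPL-X has 10475 vertices).
    # Returns the mean z-height of the selected region per frame; thresholding this
    # (or its frame-to-frame velocity) is one common way to flag ground contact.
    ids = torch.as_tensor(verts_id, dtype=torch.long, device=body_verts.device)
    return body_verts[:, ids, 2].mean(dim=-1)

The actual contact logic used by the training and fitting code may differ; this only illustrates how the index arrays plug into SMPL-X vertex tensors.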
/body_utils/smplx_mano_flame_correspondences/MANO_SMPLX_vertex_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/smplx_mano_flame_correspondences/MANO_SMPLX_vertex_ids.pkl -------------------------------------------------------------------------------- /body_utils/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/body_utils/smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy -------------------------------------------------------------------------------- /body_utils/smplx_markerset.json: -------------------------------------------------------------------------------- 1 | {"markersets": [{"distance_from_skin": 0.0095, "indices": {"C7": 4391, "CLAV": 5533, "LANK": 5761, "LBSH": 4509, "LBWT": 5678, "LELB": 4245, "LFRM": 4379, "LFSH": 4515, "LFWT": 5726, "LHEL": 8852, "LIEL": 4258, "LKNE": 3638, "LKNI": 3781, "LSHN": 3705, "LTHI": 3479, "LUPA": 4039, "MBWT": 4297, "MFWT": 5615, "RANK": 8455, "RBSH": 7179, "RBWT": 7145, "RELB": 7028, "RFRM": 7115, "RFSH": 7251, "RFWT": 8421, "RHEL": 8634, "RIEL": 7036, "RKNE": 6401, "RKNI": 6539, "RSHN": 6466, "RTHI": 6352, "RUPA": 6778, "STRN": 5532, "ROWR": 7293, "RIWR": 7274, "LOWR": 4557, "LIWR": 4538, "T10": 5944}, "type": "body"}, {"distance_from_skin": 0.0095, "indices": {"LMT1": 5893, "LMT5": 5899, "LTOE": 5857, "RMT1": 8587, "RMT5": 8593, "RTOE": 8551}, "type": "foot"}, {"distance_from_skin": 0.0095, "indices": {"RFHD": 3035, "LFHD": 2148, "LBHD": 2041, "RBHD": 3076, "ARIEL": 9002}, "type": "head"}, {"distance_from_skin": 0.0002, "indices": {"LIDX1": 4875, "LIDX2": 4897, "LIDX3": 4931, "LMID1": 5014, "LMID2": 5020, "LMID3": 5045, "LPNK1": 5242, "LPNK2": 5250, "LPNK3": 5268, "LRNG1": 5124, "LRNG2": 5131, "LRNG3": 5149, "LTHM2": 4683, "LTHM3": 5321, "LTHM4": 5346, "RIDX1": 7611, "RIDX2": 7633, "RIDX3": 7667, "RMID1": 7750, "RMID2": 7756, "RMID3": 7781, "RPNK1": 7978, "RPNK2": 7984, "RPNK3": 8001, "RRNG1": 7860, "RRNG2": 7867, "RRNG3": 7884, "RTHM2": 7419, "RTHM3": 7602, "RTHM4": 8082}, "type": "finger"}, {"distance_from_skin": 0.0039, "indices": {"LTHM1": 4686, "RTHM1": 7423, "LIHAND": 4748, "LOHAND": 4615, "RIHAND": 7500, "ROHAND": 7351}, "type": "hand"}, {"distance_from_skin": 0.0002, "indices": {"MTH1": 2819, "MTH2": 2813, "MTH3": 8985, "MTH4": 1696, "MTH5": 1703, "MTH6": 1795, "MTH7": 8947, "MTH8": 2898, "CHN1": 8757, "CHN2": 9066}, "type": "face"}, {"distance_from_skin": 0.0002, "indices": {"REYE1": 2383, "REYE2": 2311, "LEYE1": 1043, "LEYE2": 919}, "type": "eyelids"}, {"distance_from_skin": 0, "indices": {"0": 4628, "1": 4641, "2": 4660, "3": 4690, "4": 4691, "5": 4710, "6": 4750, "7": 4885, "8": 4957, "9": 4970, "10": 5001, "11": 5012, "12": 5082, "13": 5111, "14": 5179, "15": 5193, "16": 5229, "17": 5296, "18": 5306, "19": 5315, "20": 5353, "21": 5387, "22": 7357, "23": 7396, "24": 7443, "25": 7446, "26": 7536, "27": 7589, "28": 7618, "29": 7625, "30": 7692, "31": 7706, "32": 7730, "33": 7748, "34": 7789, "35": 7847, "36": 7858, "37": 7924, "38": 7931, "39": 7976, "40": 8039, "41": 8050, "42": 8087, "43": 8122}, "type": "palm"}, {"distance_from_skin": 0, "indices": {"0": 4970, "1": 5082, "2": 5193, "3": 5306, "4": 5353, "5": 7706, "6": 7789, "7": 7924, "8": 8039, "9": 8087}, "type": "palm_5"}]} 
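Note: smplx_markerset.json above groups named markers into sets ("body", "foot", "head", "finger", "hand", "face", "eyelids", "palm", "palm_5"), where each set maps a marker name to a SMPL-X vertex index and records a distance_from_skin offset. The fitting code later in this dump works with a 143-marker layout ('markers_143'), but which set types compose it is not stated here, so the selection below is purely illustrative. A minimal sketch of turning a chosen subset of these sets into marker positions, assuming body_verts is a [B, 10475, 3] tensor of posed SMPL-X vertices:

import json
import torch

with open('body_utils/smplx_markerset.json') as f:
    markersets = json.load(f)['markersets']

# Collect vertex indices for the set types we care about (illustrative choice only).
wanted = {'body', 'finger', 'palm_5'}
marker_ids = []
for ms in markersets:
    if ms['type'] in wanted:
        marker_ids.extend(ms['indices'].values())
marker_ids = torch.tensor(sorted(marker_ids), dtype=torch.long)

def markers_from_verts(body_verts):
    # body_verts: [B, 10475, 3] SMPL-X vertices -> [B, n_markers, 3] marker positions
    return body_verts[:, marker_ids, :]

This sketch ignores distance_from_skin; a fuller implementation would presumably offset each marker along the local vertex normal by that amount, as mocap markers sit slightly off the skin surface.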
-------------------------------------------------------------------------------- /images/binoculars-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-0.jpg -------------------------------------------------------------------------------- /images/binoculars-60-first-view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-60-first-view.jpg -------------------------------------------------------------------------------- /images/binoculars-60.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-60.jpg -------------------------------------------------------------------------------- /images/binoculars-movie-first-view.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-movie-first-view.gif -------------------------------------------------------------------------------- /images/binoculars-video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/binoculars-video.gif -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/teaser.png -------------------------------------------------------------------------------- /images/two-stage-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/two-stage-pipeline.png -------------------------------------------------------------------------------- /images/wineglass-0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-0.jpg -------------------------------------------------------------------------------- /images/wineglass-60-first-view.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-60-first-view.jpg -------------------------------------------------------------------------------- /images/wineglass-60.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-60.jpg -------------------------------------------------------------------------------- /images/wineglass-movie-first-view.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-movie-first-view.gif -------------------------------------------------------------------------------- /images/wineglass-video.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/images/wineglass-video.gif -------------------------------------------------------------------------------- /opt_grasppose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | # import smplx 8 | import open3d as o3d 9 | import torch 10 | from smplx.lbs import batch_rodrigues 11 | from tqdm import tqdm 12 | 13 | from utils.cfg_parser import Config 14 | from utils.utils import makelogger, makepath 15 | from WholeGraspPose.models.fittingop import FittingOP 16 | from WholeGraspPose.models.objectmodel import ObjectModel 17 | from WholeGraspPose.trainer import Trainer 18 | 19 | 20 | #### inference 21 | def load_object_data_random(object_name, n_samples): 22 | mesh_base = './dataset/contact_meshes' 23 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, object_name + '.ply')) 24 | obj_mesh_base.compute_vertex_normals() 25 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1) 26 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1) 27 | obj_model = ObjectModel(v_temp, normal_temp, n_samples) 28 | 29 | """Prepare transl/global_orient data""" 30 | """Example: randomly sample object height and orientation""" 31 | transf_transl_list = torch.rand(n_samples) + 0.6 #### can be customized 32 | global_orient_list = (np.pi)*torch.rand(n_samples) - np.pi/2 #### can be customized 33 | transl = torch.zeros(n_samples, 3) # for object model which is centered at object 34 | transf_transl = torch.zeros(n_samples, 3) 35 | transf_transl[:, -1] = transf_transl_list 36 | global_orient = torch.zeros(n_samples, 3) 37 | global_orient[:, -1] = global_orient_list 38 | global_orient_rotmat = batch_rodrigues(global_orient.view(-1, 3)).to(grabpose.device) # [N, 3, 3] 39 | 40 | object_output = obj_model(global_orient_rotmat, transl.to(grabpose.device), v_temp.to(grabpose.device), normal_temp.to(grabpose.device), rotmat=True) 41 | object_verts = object_output[0].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[0].detach().cpu().numpy() 42 | object_normal = object_output[1].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[1].detach().cpu().numpy() 43 | 44 | index = np.linspace(0, object_verts.shape[1], num=2048, endpoint=False,retstep=True,dtype=int)[0] 45 | 46 | verts_object = object_verts[:, index] 47 | normal_object = object_normal[:, index] 48 | global_orient_rotmat_6d = global_orient_rotmat.view(-1, 1, 9)[:, :, :6].detach().cpu().numpy() 49 | feat_object = np.concatenate([normal_object, global_orient_rotmat_6d.repeat(2048, axis=1)], axis=-1) 50 | 51 | verts_object = torch.from_numpy(verts_object).to(grabpose.device) 52 | feat_object = torch.from_numpy(feat_object).to(grabpose.device) 53 | transf_transl = transf_transl.to(grabpose.device) 54 | return {'verts_object':verts_object, 'normal_object': normal_object, 'global_orient':global_orient, 'global_orient_rotmat':global_orient_rotmat, 'feat_object':feat_object, 'transf_transl':transf_transl} 55 | 56 | 57 | def load_object_data_uniform_sample(object_name, n_samples): 58 | mesh_base = './dataset/contact_meshes' 59 | obj_mesh_base = 
o3d.io.read_triangle_mesh(os.path.join(mesh_base, object_name + '.ply')) 60 | obj_mesh_base.compute_vertex_normals() 61 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1) 62 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(grabpose.device).view(1, -1, 3).repeat(n_samples, 1, 1) 63 | obj_model = ObjectModel(v_temp, normal_temp, n_samples) 64 | 65 | """Prepare transl/global_orient data""" 66 | """Example: uniformly sample object height and orientation (can be customized)""" 67 | transf_transl_list = torch.arange(n_samples)*1.0/(n_samples-1) + 0.5 68 | global_orient_list = (2*np.pi)*torch.arange(n_samples)/n_samples 69 | n_samples = transf_transl_list.shape[0] * global_orient_list.shape[0] 70 | transl = torch.zeros(n_samples, 3) # for object model which is centered at object 71 | transf_transl = torch.zeros(n_samples, 3) 72 | transf_transl[:, -1] = transf_transl_list.repeat_interleave(global_orient_list.shape[0]) 73 | global_orient = torch.zeros(n_samples, 3) 74 | global_orient[:, -1] = global_orient_list.repeat(transf_transl_list.shape[0]) # [6+6+6.....] 75 | global_orient_rotmat = batch_rodrigues(global_orient.view(-1, 3)).to(grabpose.device) # [N, 3, 3] 76 | 77 | object_output = obj_model(global_orient_rotmat, transl.to(grabpose.device), v_temp.to(grabpose.device), normal_temp.to(grabpose.device), rotmat=True) 78 | object_verts = object_output[0].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[0].detach().cpu().numpy() 79 | object_normal = object_output[1].detach().squeeze().cpu().numpy() if n_samples != 1 else object_output[1].detach().cpu().numpy() 80 | 81 | index = np.linspace(0, object_verts.shape[1], num=2048, endpoint=False,retstep=True,dtype=int)[0] 82 | 83 | verts_object = object_verts[:, index] 84 | normal_object = object_normal[:, index] 85 | global_orient_rotmat_6d = global_orient_rotmat.view(-1, 1, 9)[:, :, :6].detach().cpu().numpy() 86 | feat_object = np.concatenate([normal_object, global_orient_rotmat_6d.repeat(2048, axis=1)], axis=-1) 87 | 88 | verts_object = torch.from_numpy(verts_object).to(grabpose.device) 89 | feat_object = torch.from_numpy(feat_object).to(grabpose.device) 90 | transf_transl = transf_transl.to(grabpose.device) 91 | return {'verts_object':verts_object, 'normal_object': normal_object, 'global_orient':global_orient, 'global_orient_rotmat':global_orient_rotmat, 'feat_object':feat_object, 'transf_transl':transf_transl} 92 | 93 | def inference(grabpose, obj, n_samples, n_rand_samples, object_type, save_dir): 94 | """ prepare test object data: [verts_object, feat_object(normal + rotmat), transf_transl] """ 95 | ### object centered 96 | # for obj in grabpose.cfg.object_class: 97 | if object_type == 'uniform': 98 | obj_data = load_object_data_uniform_sample(obj, n_samples) 99 | elif object_type == 'random': 100 | obj_data = load_object_data_random(obj, n_samples) 101 | obj_data['feat_object'] = obj_data['feat_object'].permute(0,2,1) 102 | obj_data['verts_object'] = obj_data['verts_object'].permute(0,2,1) 103 | 104 | n_samples_total = obj_data['feat_object'].shape[0] 105 | 106 | markers_gen = [] 107 | object_contact_gen = [] 108 | markers_contact_gen = [] 109 | for i in range(n_samples_total): 110 | sample_results = grabpose.full_grasp_net.sample(obj_data['verts_object'][None, i].repeat(n_rand_samples,1,1), obj_data['feat_object'][None, i].repeat(n_rand_samples,1,1), obj_data['transf_transl'][None, i].repeat(n_rand_samples,1)) 111 | 
markers_gen.append((sample_results[0]+obj_data['transf_transl'][None, i])) 112 | markers_contact_gen.append(sample_results[1]) 113 | object_contact_gen.append(sample_results[2]) 114 | 115 | markers_gen = torch.cat(markers_gen, dim=0) # [B, N, 3] 116 | object_contact_gen = torch.cat(object_contact_gen, dim=0).squeeze() # [B, 2048] 117 | markers_contact_gen = torch.cat(markers_contact_gen, dim=0) # [B, N] 118 | 119 | output = {} 120 | output['markers_gen'] = markers_gen.detach().cpu().numpy() 121 | output['markers_contact_gen'] = markers_contact_gen.detach().cpu().numpy() 122 | output['object_contact_gen'] = object_contact_gen.detach().cpu().numpy() 123 | output['normal_object'] = obj_data['normal_object']#.repeat(n_rand_samples, axis=0) 124 | output['transf_transl'] = obj_data['transf_transl'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0) 125 | output['global_orient_object'] = obj_data['global_orient'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0) 126 | output['global_orient_object_rotmat'] = obj_data['global_orient_rotmat'].detach().cpu().numpy()#.repeat(n_rand_samples, axis=0) 127 | output['verts_object'] = (obj_data['verts_object']+obj_data['transf_transl'].view(-1,3,1).repeat(1,1,2048)).permute(0, 2, 1).detach().cpu().numpy()#.repeat(n_rand_samples, axis=0) 128 | 129 | save_path = os.path.join(save_dir, 'markers_results.npy') 130 | np.save(save_path, output) 131 | print('Saving to {}'.format(save_path)) 132 | 133 | return output 134 | 135 | def fitting_data_save(save_data, 136 | markers, 137 | markers_fit, 138 | smplxparams, 139 | gender, 140 | object_contact, body_contact, 141 | object_name, verts_object, global_orient_object, transf_transl_object): 142 | # markers & markers_fit 143 | save_data['markers'].append(markers) 144 | save_data['markers_fit'].append(markers_fit) 145 | # print('markers:', markers.shape) 146 | 147 | # body params 148 | for key in save_data['body'].keys(): 149 | # print(key, smplxparams[key].shape) 150 | save_data['body'][key].append(smplxparams[key].detach().cpu().numpy()) 151 | # object name & object params 152 | save_data['object_name'].append(object_name) 153 | save_data['gender'].append(gender) 154 | save_data['object']['transl'].append(transf_transl_object) 155 | save_data['object']['global_orient'].append(global_orient_object) 156 | save_data['object']['verts_object'].append(verts_object) 157 | 158 | # contact 159 | save_data['contact']['body'].append(body_contact) 160 | save_data['contact']['object'].append(object_contact) 161 | 162 | #### fitting 163 | 164 | def pose_opt(grabpose, samples_results, n_random_samples, obj, gender, save_dir, logger, device): 165 | # prepare objects 166 | n_samples = len(samples_results['verts_object']) 167 | verts_object = torch.tensor(samples_results['verts_object'])[:n_samples].to(device) # (n, 2048, 3) 168 | normals_object = torch.tensor(samples_results['normal_object'])[:n_samples].to(device) # (n, 2048, 3) 169 | global_orients_object = torch.tensor(samples_results['global_orient_object'])[:n_samples].to(device) # (n, 2048, 3) 170 | transf_transl_object = torch.tensor(samples_results['transf_transl'])[:n_samples].to(device) # (n, 2048, 3) 171 | 172 | # prepare body markers 173 | markers_gen = torch.tensor(samples_results['markers_gen']).to(device) # (n*k, 143, 3) 174 | object_contacts_gen = torch.tensor(samples_results['object_contact_gen']).view(markers_gen.shape[0], -1, 1).to(device) # (n, 2048, 1) 175 | markers_contacts_gen = 
torch.tensor(samples_results['markers_contact_gen']).view(markers_gen.shape[0], -1, 1).to(device) # (n, 143, 1) 176 | 177 | print('Fitting {} {} samples for {}...'.format(n_samples, cfg.gender, obj.upper())) 178 | 179 | fittingconfig={ 'init_lr_h': 0.008, 180 | 'num_iter': [300,400,500], 181 | 'batch_size': 1, 182 | 'num_markers': 143, 183 | 'device': device, 184 | 'cfg': cfg, 185 | 'verbose': False, 186 | 'hand_ncomps': 24, 187 | 'only_rec': False, # True / False 188 | 'contact_loss': 'contact', # contact / prior / False 189 | 'logger': logger, 190 | 'data_type': 'markers_143', 191 | } 192 | fittingop = FittingOP(fittingconfig) 193 | 194 | save_data_gen = {} 195 | for data in [save_data_gen]: 196 | data['markers'] = [] 197 | data['markers_fit'] = [] 198 | data['body'] = {} 199 | for key in ['betas', 'transl', 'global_orient', 'body_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose']: 200 | data['body'][key] = [] 201 | data['object'] = {} 202 | for key in ['transl', 'global_orient', 'verts_object']: 203 | data['object'][key] = [] 204 | data['contact'] = {} 205 | for key in ['body', 'object']: 206 | data['contact'][key] = [] 207 | data['gender'] = [] 208 | data['object_name'] = [] 209 | 210 | 211 | for i in tqdm(range(n_samples)): 212 | # prepare object 213 | vert_object = verts_object[None, i, :, :] 214 | normal_object = normals_object[None, i, :, :] 215 | 216 | marker_gen = markers_gen[i*n_random_samples:(i+1)*n_random_samples, :, :] 217 | object_contact_gen = object_contacts_gen[i*n_random_samples:(i+1)*n_random_samples, :, :] 218 | markers_contact_gen = markers_contacts_gen[i*n_random_samples:(i+1)*n_random_samples, :, :] 219 | 220 | for k in range(n_random_samples): 221 | print('Fitting for {}-th GEN...'.format(k+1)) 222 | markers_fit_gen, smplxparams_gen, loss_gen = fittingop.fitting(marker_gen[None, k, :], object_contact_gen[None, k, :], markers_contact_gen[None, k], vert_object, normal_object, gender) 223 | fitting_data_save(save_data_gen, 224 | marker_gen[k, :].detach().cpu().numpy().reshape(1, -1 ,3), 225 | markers_fit_gen[-1].squeeze().reshape(1, -1 ,3), 226 | smplxparams_gen[-1], 227 | gender, 228 | object_contact_gen[k].detach().cpu().numpy().reshape(1, -1), markers_contact_gen[k].detach().cpu().numpy().reshape(1, -1), 229 | obj, vert_object.detach().cpu().numpy(), global_orients_object[i].detach().cpu().numpy(), transf_transl_object[i].detach().cpu().numpy()) 230 | 231 | 232 | for data in [save_data_gen]: 233 | # for data in [save_data_gt, save_data_rec, save_data_gen]: 234 | data['markers'] = np.vstack(data['markers']) 235 | data['markers_fit'] = np.vstack(data['markers_fit']) 236 | for key in ['betas', 'transl', 'global_orient', 'body_pose', 'leye_pose', 'reye_pose', 'left_hand_pose', 'right_hand_pose']: 237 | data['body'][key] = np.vstack(data['body'][key]) 238 | for key in ['transl', 'global_orient', 'verts_object']: 239 | data['object'][key] = np.vstack(data['object'][key]) 240 | for key in ['body', 'object']: 241 | data['contact'][key] = np.vstack(data['contact'][key]) 242 | 243 | np.savez(os.path.join(save_dir, 'fitting_results.npz'), **save_data_gen) 244 | 245 | if __name__ == '__main__': 246 | 247 | parser = argparse.ArgumentParser(description='grabpose-Testing') 248 | 249 | parser.add_argument('--data_path', default = './dataset/GraspPose', type=str, 250 | help='The path to the folder that contains grabpose data') 251 | 252 | parser.add_argument('--object', default = None, type=str, 253 | help='object name') 254 | 255 | 
parser.add_argument('--gender', default=None, type=str, 256 | help='The gender of dataset') 257 | 258 | parser.add_argument('--config_path', default = None, type=str, 259 | help='The path to the confguration of the trained grabpose model') 260 | 261 | parser.add_argument('--exp_name', default = None, type=str, 262 | help='experiment name') 263 | 264 | parser.add_argument('--pose_ckpt_path', default = None, type=str, 265 | help='checkpoint path') 266 | 267 | parser.add_argument('--n_object_samples', default = 5, type=int, 268 | help='The number of object samples of this object') 269 | 270 | parser.add_argument('--type_object_samples', default = 'random', type=str, 271 | help='For the given object mesh, we provide two types of object heights and orientation sampling mode: random / uniform') 272 | 273 | parser.add_argument('--n_rand_samples_per_object', default = 1, type=int, 274 | help='The number of whole-body poses random samples generated per object') 275 | 276 | args = parser.parse_args() 277 | 278 | cwd = os.getcwd() 279 | 280 | best_net = os.path.join(cwd, args.pose_ckpt_path) 281 | 282 | vpe_path = '/configs/verts_per_edge.npy' 283 | c_weights_path = cwd + '/WholeGraspPose/configs/rhand_weight.npy' 284 | work_dir = cwd + '/results/{}/GraspPose'.format(args.exp_name) 285 | print(work_dir) 286 | config = { 287 | 'dataset_dir': args.data_path, 288 | 'work_dir':work_dir, 289 | 'vpe_path': vpe_path, 290 | 'c_weights_path': c_weights_path, 291 | 'exp_name': args.exp_name, 292 | 'gender': args.gender, 293 | 'best_net': best_net 294 | } 295 | 296 | cfg_path = 'WholeGraspPose/configs/WholeGraspPose.yaml' 297 | cfg = Config(default_cfg_path=cfg_path, **config) 298 | 299 | save_dir = os.path.join(work_dir, args.object) 300 | if not os.path.exists(save_dir): 301 | os.makedirs(save_dir) 302 | 303 | logger = makelogger(makepath(os.path.join(save_dir, '%s.log' % (args.object)), isfile=True)).info 304 | 305 | grabpose = Trainer(cfg=cfg, inference=True, logger=logger) 306 | 307 | samples_results = inference(grabpose, args.object, args.n_object_samples, args.n_rand_samples_per_object, args.type_object_samples, save_dir) 308 | fitting_results = pose_opt(grabpose, samples_results, args.n_rand_samples_per_object, args.object, cfg.gender, save_dir, logger, grabpose.device) 309 | 310 | -------------------------------------------------------------------------------- /train_grasppose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from utils.cfg_parser import Config 6 | from WholeGraspPose.trainer import Trainer 7 | 8 | if __name__ == '__main__': 9 | 10 | parser = argparse.ArgumentParser(description='GrabNet-Training') 11 | 12 | parser.add_argument('--work-dir', default='logs/GraspPose', type=str, 13 | help='The path to the downloaded grab data') 14 | 15 | parser.add_argument('--gender', default=None, type=str, 16 | help='The gender of dataset') 17 | 18 | parser.add_argument('--data_path', default = '/cluster/work/cvl/wuyan/data/GRAB-series/GrabPose_r_fullbody/data', type=str, 19 | help='The path to the folder that contains grabpose data') 20 | 21 | parser.add_argument('--batch-size', default=64, type=int, 22 | help='Training batch size') 23 | 24 | parser.add_argument('--n-workers', default=8, type=int, 25 | help='Number of PyTorch dataloader workers') 26 | 27 | parser.add_argument('--lr', default=5e-4, type=float, 28 | help='Training learning rate') 29 | 30 | parser.add_argument('--kl-coef', default=0.5, type=float, 
31 | help='KL divergence coefficent for Coarsenet training') 32 | 33 | parser.add_argument('--use-multigpu', default=False, 34 | type=lambda arg: arg.lower() in ['true', '1'], 35 | help='If to use multiple GPUs for training') 36 | 37 | parser.add_argument('--exp_name', default = None, type=str, 38 | help='experiment name') 39 | 40 | 41 | args = parser.parse_args() 42 | 43 | work_dir = os.path.join(args.work_dir, args.exp_name) 44 | 45 | cwd = os.getcwd() 46 | default_cfg_path = 'WholeGraspPose/configs/WholeGraspPose.yaml' 47 | 48 | cfg = { 49 | 'batch_size': args.batch_size, 50 | 'n_workers': args.n_workers, 51 | 'use_multigpu': args.use_multigpu, 52 | 'kl_coef': args.kl_coef, 53 | 'dataset_dir': args.data_path, 54 | 'base_dir': cwd, 55 | 'work_dir': work_dir, 56 | 'base_lr': args.lr, 57 | 'best_net': None, 58 | 'gender': args.gender, 59 | 'exp_name': args.exp_name, 60 | } 61 | 62 | cfg = Config(default_cfg_path=default_cfg_path, **cfg) 63 | grabpose_trainer = Trainer(cfg=cfg) 64 | grabpose_trainer.fit() 65 | 66 | cfg = grabpose_trainer.cfg 67 | cfg.write_cfg(os.path.join(work_dir, 'TR%02d_%s' % (cfg.try_num, os.path.basename(default_cfg_path)))) 68 | -------------------------------------------------------------------------------- /utils/Pivots.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | 6 | sys.path.append(os.getcwd()) 7 | from utils.Quaternions import Quaternions 8 | 9 | 10 | class Pivots: 11 | """ 12 | Pivots is an ndarray of angular rotations 13 | 14 | This wrapper provides some functions for 15 | working with pivots. 16 | 17 | These are particularly useful as a number 18 | of atomic operations (such as adding or 19 | subtracting) cannot be achieved using 20 | the standard arithmatic and need to be 21 | defined differently to work correctly 22 | """ 23 | 24 | def __init__(self, ps): self.ps = np.array(ps) 25 | def __str__(self): return "Pivots("+ str(self.ps) + ")" 26 | def __repr__(self): return "Pivots("+ repr(self.ps) + ")" 27 | 28 | def __add__(self, other): return Pivots(np.arctan2(np.sin(self.ps + other.ps), np.cos(self.ps + other.ps))) 29 | def __sub__(self, other): return Pivots(np.arctan2(np.sin(self.ps - other.ps), np.cos(self.ps - other.ps))) 30 | def __mul__(self, other): return Pivots(self.ps * other.ps) 31 | def __div__(self, other): return Pivots(self.ps / other.ps) 32 | def __mod__(self, other): return Pivots(self.ps % other.ps) 33 | def __pow__(self, other): return Pivots(self.ps ** other.ps) 34 | 35 | def __lt__(self, other): return self.ps < other.ps 36 | def __le__(self, other): return self.ps <= other.ps 37 | def __eq__(self, other): return self.ps == other.ps 38 | def __ne__(self, other): return self.ps != other.ps 39 | def __ge__(self, other): return self.ps >= other.ps 40 | def __gt__(self, other): return self.ps > other.ps 41 | 42 | def __abs__(self): return Pivots(abs(self.ps)) 43 | def __neg__(self): return Pivots(-self.ps) 44 | 45 | def __iter__(self): return iter(self.ps) 46 | def __len__(self): return len(self.ps) 47 | 48 | def __getitem__(self, k): return Pivots(self.ps[k]) 49 | def __setitem__(self, k, v): self.ps[k] = v.ps 50 | 51 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape)) 52 | 53 | def quaternions(self, plane='xz'): 54 | fa = self._ellipsis() 55 | axises = np.ones(self.ps.shape + (3,)) 56 | axises[fa + ("xyz".index(plane[0]),)] = 0.0 57 | axises[fa + ("xyz".index(plane[1]),)] = 0.0 58 | return 
Quaternions.from_angle_axis(self.ps, axises) 59 | 60 | def directions(self, plane='xz'): 61 | dirs = np.zeros((len(self.ps), 3)) 62 | dirs["xyz".index(plane[0])] = np.sin(self.ps) 63 | dirs["xyz".index(plane[1])] = np.cos(self.ps) 64 | return dirs 65 | 66 | def normalized(self): 67 | xs = np.copy(self.ps) 68 | while np.any(xs > np.pi): xs[xs > np.pi] = xs[xs > np.pi] - 2 * np.pi 69 | while np.any(xs < -np.pi): xs[xs < -np.pi] = xs[xs < -np.pi] + 2 * np.pi 70 | return Pivots(xs) 71 | 72 | def interpolate(self, ws): 73 | dir = np.average(self.directions, weights=ws, axis=0) 74 | return np.arctan2(dir[2], dir[0]) 75 | 76 | def copy(self): 77 | return Pivots(np.copy(self.ps)) 78 | 79 | @property 80 | def shape(self): 81 | return self.ps.shape 82 | 83 | @classmethod 84 | def from_quaternions(cls, qs, forward='z', plane='xz'): 85 | ds = np.zeros(qs.shape + (3,)) 86 | ds[...,'xyz'.index(forward)] = 1.0 87 | return Pivots.from_directions(qs * ds, plane=plane) 88 | 89 | @classmethod 90 | def from_directions(cls, ds, plane='xz'): 91 | ys = ds[...,'xyz'.index(plane[0])] 92 | xs = ds[...,'xyz'.index(plane[1])] 93 | return Pivots(np.arctan2(ys, xs)) 94 | 95 | -------------------------------------------------------------------------------- /utils/Pivots_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | sys.path.append(os.getcwd()) 7 | from utils.Quaternions_torch import Quaternions_torch 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | class Pivots_torch: 12 | """ 13 | Pivots is an ndarray of angular rotations 14 | 15 | This wrapper provides some functions for 16 | working with pivots. 17 | 18 | These are particularly useful as a number 19 | of atomic operations (such as adding or 20 | subtracting) cannot be achieved using 21 | the standard arithmatic and need to be 22 | defined differently to work correctly 23 | """ 24 | 25 | def __init__(self, ps): self.ps = torch.tensor(ps).to(device) 26 | def __str__(self): return "Pivots("+ str(self.ps) + ")" 27 | def __repr__(self): return "Pivots("+ repr(self.ps) + ")" 28 | 29 | def __add__(self, other): return Pivots_torch(torch.atan2(torch.sin(self.ps + other.ps), torch.cos(self.ps + other.ps))) 30 | def __sub__(self, other): return Pivots_torch(torch.atan2(torch.sin(self.ps - other.ps), torch.cos(self.ps - other.ps))) 31 | def __mul__(self, other): return Pivots_torch(self.ps * other.ps) 32 | def __div__(self, other): return Pivots_torch(self.ps / other.ps) 33 | def __mod__(self, other): return Pivots_torch(self.ps % other.ps) 34 | def __pow__(self, other): return Pivots_torch(self.ps ** other.ps) 35 | 36 | def __lt__(self, other): return self.ps < other.ps 37 | def __le__(self, other): return self.ps <= other.ps 38 | def __eq__(self, other): return self.ps == other.ps 39 | def __ne__(self, other): return self.ps != other.ps 40 | def __ge__(self, other): return self.ps >= other.ps 41 | def __gt__(self, other): return self.ps > other.ps 42 | 43 | def __abs__(self): return Pivots_torch(torch.abs(self.ps)) 44 | def __neg__(self): return Pivots_torch(-self.ps) 45 | 46 | def __iter__(self): return iter(self.ps) 47 | def __len__(self): return len(self.ps) 48 | 49 | def __getitem__(self, k): return Pivots_torch(self.ps[k]) 50 | def __setitem__(self, k, v): self.ps[k] = v.ps 51 | 52 | def _ellipsis(self): return tuple(map(lambda x: slice(None), self.shape)) 53 | 54 | def quaternions(self, plane='xz'): 55 | fa = self._ellipsis() 56 | 
axises = torch.ones(self.ps.shape + (3,)).to(device) 57 | axises[fa + ("xyz".index(plane[0]),)] = 0.0 58 | axises[fa + ("xyz".index(plane[1]),)] = 0.0 59 | return Quaternions_torch.from_angle_axis(self.ps, axises) 60 | 61 | def directions(self, plane='xz'): 62 | dirs = torch.zeros((len(self.ps), 3)).to(device) 63 | dirs[..., "xyz".index(plane[0])] = torch.sin(self.ps) 64 | dirs[..., "xyz".index(plane[1])] = torch.cos(self.ps) 65 | return dirs 66 | 67 | def normalized(self): 68 | xs = self.ps.clone() 69 | while torch.any(xs > torch.pi): xs[xs > torch.pi] = xs[xs > torch.pi] - 2 * torch.pi 70 | while torch.any(xs < -torch.pi): xs[xs < -torch.pi] = xs[xs < -torch.pi] + 2 * torch.pi 71 | return Pivots_torch(xs) 72 | 73 | # def interpolate(self, ws): 74 | # dir = np.average(self.directions, weights=ws, axis=0) 75 | # return torch.atan2(dir[2], dir[0]) 76 | 77 | def clone(self): 78 | return Pivots_torch((self.ps).clone()) 79 | 80 | @property 81 | def shape(self): 82 | return self.ps.shape 83 | 84 | @classmethod 85 | def from_quaternions(cls, qs, forward='z', plane='xz'): 86 | ds = torch.zeros(qs.shape + (3,)).to(device) 87 | ds[...,'xyz'.index(forward)] = 1.0 88 | return Pivots_torch.from_directions(qs * ds, plane=plane) 89 | 90 | @classmethod 91 | def from_directions(cls, ds, plane='xz'): 92 | ys = ds[...,'xyz'.index(plane[0])] 93 | xs = ds[...,'xyz'.index(plane[1])] 94 | return Pivots_torch(torch.atan2(ys, xs)) 95 | 96 | -------------------------------------------------------------------------------- /utils/Quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Quaternions: 4 | """ 5 | Quaternions is a wrapper around a numpy ndarray 6 | that allows it to act as if it were an ndarray of 7 | a quaternion data type. 8 | 9 | Therefore addition, subtraction, multiplication, 10 | division, negation, absolute, are all defined 11 | in terms of quaternion operations such as quaternion 12 | multiplication. 13 | 14 | This allows for much neater code and many routines 15 | which conceptually do the same thing to be written 16 | in the same way for point data and for rotation data. 17 | 18 | The Quaternions class has been designed such that it 19 | should support broadcasting and slicing in all of the 20 | usual ways.
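    Illustrative usage sketch (added for clarity, not part of the original
    source; the Euler angles below are arbitrary example values in degrees):

        qs = Quaternions.from_euler(np.radians(np.array([[90.0, 0.0, 0.0]])))
        composed = qs * qs            # compose two rotations via multiplication
        angles = composed.euler()     # back to Euler angles, in radians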
21 | """ 22 | 23 | def __init__(self, qs): 24 | if isinstance(qs, np.ndarray): 25 | 26 | if len(qs.shape) == 1: qs = np.array([qs]) 27 | self.qs = qs 28 | return 29 | 30 | if isinstance(qs, Quaternions): 31 | self.qs = qs.qs 32 | return 33 | 34 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) 35 | 36 | def __str__(self): return "Quaternions("+ str(self.qs) + ")" 37 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")" 38 | 39 | """ Helper Methods for Broadcasting and Data extraction """ 40 | 41 | @classmethod 42 | def _broadcast(cls, sqs, oqs, scalar=False): 43 | 44 | if isinstance(oqs, float): return sqs, oqs * np.ones(sqs.shape[:-1]) 45 | 46 | ss = np.array(sqs.shape) if not scalar else np.array(sqs.shape[:-1]) 47 | os = np.array(oqs.shape) 48 | 49 | if len(ss) != len(os): 50 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 51 | 52 | if np.all(ss == os): return sqs, oqs 53 | 54 | if not np.all((ss == os) | (os == np.ones(len(os))) | (ss == np.ones(len(ss)))): 55 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 56 | 57 | sqsn, oqsn = sqs.copy(), oqs.copy() 58 | 59 | for a in np.where(ss == 1)[0]: sqsn = sqsn.repeat(os[a], axis=a) 60 | for a in np.where(os == 1)[0]: oqsn = oqsn.repeat(ss[a], axis=a) 61 | 62 | return sqsn, oqsn 63 | 64 | """ Adding Quaterions is just Defined as Multiplication """ 65 | 66 | def __add__(self, other): return self * other 67 | def __sub__(self, other): return self / other 68 | 69 | """ Quaterion Multiplication """ 70 | 71 | def __mul__(self, other): 72 | """ 73 | Quaternion multiplication has three main methods. 74 | 75 | When multiplying a Quaternions array by Quaternions 76 | normal quaternion multiplication is performed. 77 | 78 | When multiplying a Quaternions array by a vector 79 | array of the same shape, where the last axis is 3, 80 | it is assumed to be a Quaternion by 3D-Vector 81 | multiplication and the 3D-Vectors are rotated 82 | in space by the Quaternions. 83 | 84 | When multipplying a Quaternions array by a scalar 85 | or vector of different shape it is assumed to be 86 | a Quaternions by Scalars multiplication and the 87 | Quaternions are scaled using Slerp and the identity 88 | quaternions. 
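        Illustrative sketch of the vector case (added for clarity, not part of
        the original source; the values are arbitrary):

            qs = Quaternions.id(2)                  # two identity rotations
            vs = np.array([[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0]])
            rotated = qs * vs                       # shape (2, 3); identity leaves vs unchanged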
89 | """ 90 | 91 | """ If Quaternions type do Quaternions * Quaternions """ 92 | if isinstance(other, Quaternions): 93 | 94 | sqs, oqs = Quaternions._broadcast(self.qs, other.qs) 95 | 96 | q0 = sqs[...,0]; q1 = sqs[...,1]; 97 | q2 = sqs[...,2]; q3 = sqs[...,3]; 98 | r0 = oqs[...,0]; r1 = oqs[...,1]; 99 | r2 = oqs[...,2]; r3 = oqs[...,3]; 100 | 101 | qs = np.empty(sqs.shape) 102 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 103 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 104 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 105 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 106 | 107 | return Quaternions(qs) 108 | 109 | """ If array type do Quaternions * Vectors """ 110 | if isinstance(other, np.ndarray) and other.shape[-1] == 3: 111 | vs = Quaternions(np.concatenate([np.zeros(other.shape[:-1] + (1,)), other], axis=-1)) 112 | return (self * (vs * -self)).imaginaries 113 | 114 | """ If float do Quaternions * Scalars """ 115 | if isinstance(other, np.ndarray) or isinstance(other, float): 116 | return Quaternions.slerp(Quaternions.id_like(self), self, other) 117 | 118 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) 119 | 120 | def __div__(self, other): 121 | """ 122 | When a Quaternion type is supplied, division is defined 123 | as multiplication by the inverse of that Quaternion. 124 | 125 | When a scalar or vector is supplied it is defined 126 | as multiplicaion of one over the supplied value. 127 | Essentially a scaling. 128 | """ 129 | 130 | if isinstance(other, Quaternions): return self * (-other) 131 | if isinstance(other, np.ndarray): return self * (1.0 / other) 132 | if isinstance(other, float): return self * (1.0 / other) 133 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) 134 | 135 | def __eq__(self, other): return self.qs == other.qs 136 | def __ne__(self, other): return self.qs != other.qs 137 | 138 | def __neg__(self): 139 | """ Invert Quaternions """ 140 | return Quaternions(self.qs * np.array([[1, -1, -1, -1]])) 141 | 142 | def __abs__(self): 143 | """ Unify Quaternions To Single Pole """ 144 | qabs = self.normalized().copy() 145 | top = np.sum(( qabs.qs) * np.array([1,0,0,0]), axis=-1) 146 | bot = np.sum((-qabs.qs) * np.array([1,0,0,0]), axis=-1) 147 | qabs.qs[top < bot] = -qabs.qs[top < bot] 148 | return qabs 149 | 150 | def __iter__(self): return iter(self.qs) 151 | def __len__(self): return len(self.qs) 152 | 153 | def __getitem__(self, k): return Quaternions(self.qs[k]) 154 | def __setitem__(self, k, v): self.qs[k] = v.qs 155 | 156 | @property 157 | def lengths(self): 158 | return np.sum(self.qs**2.0, axis=-1)**0.5 159 | 160 | @property 161 | def reals(self): 162 | return self.qs[...,0] 163 | 164 | @property 165 | def imaginaries(self): 166 | return self.qs[...,1:4] 167 | 168 | @property 169 | def shape(self): return self.qs.shape[:-1] 170 | 171 | def repeat(self, n, **kwargs): 172 | return Quaternions(self.qs.repeat(n, **kwargs)) 173 | 174 | def normalized(self): 175 | return Quaternions(self.qs / self.lengths[...,np.newaxis]) 176 | 177 | def log(self): 178 | norm = abs(self.normalized()) 179 | imgs = norm.imaginaries 180 | lens = np.sqrt(np.sum(imgs**2, axis=-1)) 181 | lens = np.arctan2(lens, norm.reals) / (lens + 1e-10) 182 | return imgs * lens[...,np.newaxis] 183 | 184 | def constrained(self, axis): 185 | 186 | rl = self.reals 187 | im = np.sum(axis * self.imaginaries, axis=-1) 188 | 189 | t1 = -2 * np.arctan2(rl, im) + np.pi 190 | t2 = -2 * np.arctan2(rl, im) - np.pi 191 | 
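        # Added note: t1 and t2 differ by 2*pi, so `top` and `bot` below encode the same
        # twist about `axis` with opposite quaternion sign; the dot-product test keeps,
        # per element, the candidate lying in the same hemisphere as the original rotation.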
192 | top = Quaternions.exp(axis[np.newaxis] * (t1[:,np.newaxis] / 2.0)) 193 | bot = Quaternions.exp(axis[np.newaxis] * (t2[:,np.newaxis] / 2.0)) 194 | img = self.dot(top) > self.dot(bot) 195 | 196 | ret = top.copy() 197 | ret[ img] = top[ img] 198 | ret[~img] = bot[~img] 199 | return ret 200 | 201 | def constrained_x(self): return self.constrained(np.array([1,0,0])) 202 | def constrained_y(self): return self.constrained(np.array([0,1,0])) 203 | def constrained_z(self): return self.constrained(np.array([0,0,1])) 204 | 205 | def dot(self, q): return np.sum(self.qs * q.qs, axis=-1) 206 | 207 | def copy(self): return Quaternions(np.copy(self.qs)) 208 | 209 | def reshape(self, s): 210 | self.qs.reshape(s) 211 | return self 212 | 213 | def interpolate(self, ws): 214 | return Quaternions.exp(np.average(abs(self).log, axis=0, weights=ws)) 215 | 216 | def euler(self, order='xyz'): 217 | 218 | q = self.normalized().qs 219 | q0 = q[...,0] 220 | q1 = q[...,1] 221 | q2 = q[...,2] 222 | q3 = q[...,3] 223 | es = np.zeros(self.shape + (3,)) 224 | 225 | if order == 'xyz': 226 | es[...,0] = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 227 | es[...,1] = np.arcsin((2 * (q0 * q2 - q3 * q1)).clip(-1,1)) 228 | es[...,2] = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 229 | elif order == 'yzx': 230 | es[...,0] = np.arctan2(2 * (q1 * q0 - q2 * q3), -q1 * q1 + q2 * q2 - q3 * q3 + q0 * q0) 231 | es[...,1] = np.arctan2(2 * (q2 * q0 - q1 * q3), q1 * q1 - q2 * q2 - q3 * q3 + q0 * q0) 232 | es[...,2] = np.arcsin((2 * (q1 * q2 + q3 * q0)).clip(-1,1)) 233 | else: 234 | raise NotImplementedError('Cannot convert from ordering %s' % order) 235 | 236 | """ 237 | 238 | # These conversion don't appear to work correctly for Maya. 239 | # http://bediyap.com/programming/convert-quaternion-to-euler-rotations/ 240 | 241 | if order == 'xyz': 242 | es[fa + (0,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 243 | es[fa + (1,)] = np.arcsin((2 * (q1 * q3 + q0 * q2)).clip(-1,1)) 244 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 245 | elif order == 'yzx': 246 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 247 | es[fa + (1,)] = np.arcsin((2 * (q1 * q2 + q0 * q3)).clip(-1,1)) 248 | es[fa + (2,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 249 | elif order == 'zxy': 250 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 - q1 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 251 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 + q2 * q3)).clip(-1,1)) 252 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 - q1 * q2), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 253 | elif order == 'xzy': 254 | es[fa + (0,)] = np.arctan2(2 * (q0 * q2 + q1 * q3), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) 255 | es[fa + (1,)] = np.arcsin((2 * (q0 * q3 - q1 * q2)).clip(-1,1)) 256 | es[fa + (2,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 257 | elif order == 'yxz': 258 | es[fa + (0,)] = np.arctan2(2 * (q1 * q2 + q0 * q3), q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) 259 | es[fa + (1,)] = np.arcsin((2 * (q0 * q1 - q2 * q3)).clip(-1,1)) 260 | es[fa + (2,)] = np.arctan2(2 * (q1 * q3 + q0 * q2), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 261 | elif order == 'zyx': 262 | es[fa + (0,)] = np.arctan2(2 * (q0 * q1 + q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) 263 | es[fa + (1,)] = np.arcsin((2 * (q0 * q2 - q1 * q3)).clip(-1,1)) 264 | es[fa + (2,)] = np.arctan2(2 * (q0 * q3 + q1 * q2), q0 * 
q0 + q1 * q1 - q2 * q2 - q3 * q3) 265 | else: 266 | raise KeyError('Unknown ordering %s' % order) 267 | 268 | """ 269 | 270 | # https://github.com/ehsan/ogre/blob/master/OgreMain/src/OgreMatrix3.cpp 271 | # Use this class and convert from matrix 272 | 273 | return es 274 | 275 | 276 | def average(self): 277 | 278 | if len(self.shape) == 1: 279 | 280 | import numpy.core.umath_tests as ut 281 | system = ut.matrix_multiply(self.qs[:,:,np.newaxis], self.qs[:,np.newaxis,:]).sum(axis=0) 282 | w, v = np.linalg.eigh(system) 283 | qiT_dot_qref = (self.qs[:,:,np.newaxis] * v[np.newaxis,:,:]).sum(axis=1) 284 | return Quaternions(v[:,np.argmin((1.-qiT_dot_qref**2).sum(axis=0))]) 285 | 286 | else: 287 | 288 | raise NotImplementedError('Cannot average multi-dimensionsal Quaternions') 289 | 290 | def angle_axis(self): 291 | 292 | norm = self.normalized() 293 | s = np.sqrt(1 - (norm.reals**2.0)) 294 | s[s == 0] = 0.001 295 | 296 | angles = 2.0 * np.arccos(norm.reals) 297 | axis = norm.imaginaries / s[...,np.newaxis] 298 | 299 | return angles, axis 300 | 301 | 302 | def transforms(self): 303 | 304 | qw = self.qs[...,0] 305 | qx = self.qs[...,1] 306 | qy = self.qs[...,2] 307 | qz = self.qs[...,3] 308 | 309 | x2 = qx + qx; y2 = qy + qy; z2 = qz + qz; 310 | xx = qx * x2; yy = qy * y2; wx = qw * x2; 311 | xy = qx * y2; yz = qy * z2; wy = qw * y2; 312 | xz = qx * z2; zz = qz * z2; wz = qw * z2; 313 | 314 | m = np.empty(self.shape + (3,3)) 315 | m[...,0,0] = 1.0 - (yy + zz) 316 | m[...,0,1] = xy - wz 317 | m[...,0,2] = xz + wy 318 | m[...,1,0] = xy + wz 319 | m[...,1,1] = 1.0 - (xx + zz) 320 | m[...,1,2] = yz - wx 321 | m[...,2,0] = xz - wy 322 | m[...,2,1] = yz + wx 323 | m[...,2,2] = 1.0 - (xx + yy) 324 | 325 | return m 326 | 327 | def ravel(self): 328 | return self.qs.ravel() 329 | 330 | @classmethod 331 | def id(cls, n): 332 | 333 | if isinstance(n, tuple): 334 | qs = np.zeros(n + (4,)) 335 | qs[...,0] = 1.0 336 | return Quaternions(qs) 337 | 338 | if isinstance(n, int) or isinstance(n, long): 339 | qs = np.zeros((n,4)) 340 | qs[:,0] = 1.0 341 | return Quaternions(qs) 342 | 343 | raise TypeError('Cannot Construct Quaternion from %s type' % str(type(n))) 344 | 345 | @classmethod 346 | def id_like(cls, a): 347 | qs = np.zeros(a.shape + (4,)) 348 | qs[...,0] = 1.0 349 | return Quaternions(qs) 350 | 351 | @classmethod 352 | def exp(cls, ws): 353 | 354 | ts = np.sum(ws**2.0, axis=-1)**0.5 355 | ts[ts == 0] = 0.001 356 | ls = np.sin(ts) / ts 357 | 358 | qs = np.empty(ws.shape[:-1] + (4,)) 359 | qs[...,0] = np.cos(ts) 360 | qs[...,1] = ws[...,0] * ls 361 | qs[...,2] = ws[...,1] * ls 362 | qs[...,3] = ws[...,2] * ls 363 | 364 | return Quaternions(qs).normalized() 365 | 366 | @classmethod 367 | def slerp(cls, q0s, q1s, a): 368 | 369 | fst, snd = cls._broadcast(q0s.qs, q1s.qs) 370 | fst, a = cls._broadcast(fst, a, scalar=True) 371 | snd, a = cls._broadcast(snd, a, scalar=True) 372 | 373 | len = np.sum(fst * snd, axis=-1) 374 | 375 | neg = len < 0.0 376 | len[neg] = -len[neg] 377 | snd[neg] = -snd[neg] 378 | 379 | amount0 = np.zeros(a.shape) 380 | amount1 = np.zeros(a.shape) 381 | 382 | linear = (1.0 - len) < 0.01 383 | omegas = np.arccos(len[~linear]) 384 | sinoms = np.sin(omegas) 385 | 386 | amount0[ linear] = 1.0 - a[linear] 387 | amount1[ linear] = a[linear] 388 | amount0[~linear] = np.sin((1.0 - a[~linear]) * omegas) / sinoms 389 | amount1[~linear] = np.sin( a[~linear] * omegas) / sinoms 390 | 391 | return Quaternions( 392 | amount0[...,np.newaxis] * fst + 393 | amount1[...,np.newaxis] * snd) 394 | 395 | 
@classmethod 396 | def between(cls, v0s, v1s): 397 | a = np.cross(v0s, v1s) 398 | w = np.sqrt((v0s**2).sum(axis=-1) * (v1s**2).sum(axis=-1)) + (v0s * v1s).sum(axis=-1) 399 | return Quaternions(np.concatenate([w[...,np.newaxis], a], axis=-1)).normalized() 400 | 401 | @classmethod 402 | def from_angle_axis(cls, angles, axis): 403 | axis = axis / (np.sqrt(np.sum(axis**2, axis=-1)) + 1e-10)[...,np.newaxis] 404 | sines = np.sin(angles / 2.0)[...,np.newaxis] 405 | cosines = np.cos(angles / 2.0)[...,np.newaxis] 406 | return Quaternions(np.concatenate([cosines, axis * sines], axis=-1)) 407 | 408 | @classmethod 409 | def from_euler(cls, es, order='xyz', world=False): 410 | 411 | axis = { 412 | 'x' : np.array([1,0,0]), 413 | 'y' : np.array([0,1,0]), 414 | 'z' : np.array([0,0,1]), 415 | } 416 | 417 | q0s = Quaternions.from_angle_axis(es[...,0], axis[order[0]]) 418 | q1s = Quaternions.from_angle_axis(es[...,1], axis[order[1]]) 419 | q2s = Quaternions.from_angle_axis(es[...,2], axis[order[2]]) 420 | 421 | return (q2s * (q1s * q0s)) if world else (q0s * (q1s * q2s)) 422 | 423 | @classmethod 424 | def from_transforms(cls, ts): 425 | 426 | d0, d1, d2 = ts[...,0,0], ts[...,1,1], ts[...,2,2] 427 | 428 | q0 = ( d0 + d1 + d2 + 1.0) / 4.0 429 | q1 = ( d0 - d1 - d2 + 1.0) / 4.0 430 | q2 = (-d0 + d1 - d2 + 1.0) / 4.0 431 | q3 = (-d0 - d1 + d2 + 1.0) / 4.0 432 | 433 | q0 = np.sqrt(q0.clip(0,None)) 434 | q1 = np.sqrt(q1.clip(0,None)) 435 | q2 = np.sqrt(q2.clip(0,None)) 436 | q3 = np.sqrt(q3.clip(0,None)) 437 | 438 | c0 = (q0 >= q1) & (q0 >= q2) & (q0 >= q3) 439 | c1 = (q1 >= q0) & (q1 >= q2) & (q1 >= q3) 440 | c2 = (q2 >= q0) & (q2 >= q1) & (q2 >= q3) 441 | c3 = (q3 >= q0) & (q3 >= q1) & (q3 >= q2) 442 | 443 | q1[c0] *= np.sign(ts[c0,2,1] - ts[c0,1,2]) 444 | q2[c0] *= np.sign(ts[c0,0,2] - ts[c0,2,0]) 445 | q3[c0] *= np.sign(ts[c0,1,0] - ts[c0,0,1]) 446 | 447 | q0[c1] *= np.sign(ts[c1,2,1] - ts[c1,1,2]) 448 | q2[c1] *= np.sign(ts[c1,1,0] + ts[c1,0,1]) 449 | q3[c1] *= np.sign(ts[c1,0,2] + ts[c1,2,0]) 450 | 451 | q0[c2] *= np.sign(ts[c2,0,2] - ts[c2,2,0]) 452 | q1[c2] *= np.sign(ts[c2,1,0] + ts[c2,0,1]) 453 | q3[c2] *= np.sign(ts[c2,2,1] + ts[c2,1,2]) 454 | 455 | q0[c3] *= np.sign(ts[c3,1,0] - ts[c3,0,1]) 456 | q1[c3] *= np.sign(ts[c3,2,0] + ts[c3,0,2]) 457 | q2[c3] *= np.sign(ts[c3,2,1] + ts[c3,1,2]) 458 | 459 | qs = np.empty(ts.shape[:-2] + (4,)) 460 | qs[...,0] = q0 461 | qs[...,1] = q1 462 | qs[...,2] = q2 463 | qs[...,3] = q3 464 | 465 | return cls(qs) 466 | 467 | 468 | -------------------------------------------------------------------------------- /utils/Quaternions_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | class Quaternions_torch: 6 | """ 7 | Quaternions is a wrapper around a numpy ndarray 8 | that allows it to act as if it were an narray of 9 | a quaternion data type. 10 | 11 | Therefore addition, subtraction, multiplication, 12 | division, negation, absolute, are all defined 13 | in terms of quaternion operations such as quaternion 14 | multiplication. 15 | 16 | This allows for much neater code and many routines 17 | which conceptually do the same thing to be written 18 | in the same way for point data and for rotation data. 19 | 20 | The Quaternions class has been desgined such that it 21 | should support broadcasting and slicing in all of the 22 | usual ways. 
23 | """ 24 | 25 | def __init__(self, qs): 26 | if isinstance(qs, torch.Tensor): 27 | 28 | if len(qs.shape) == 1: qs = torch.tensor([qs]) 29 | self.qs = qs 30 | return 31 | 32 | if isinstance(qs, Quaternions_torch): 33 | self.qs = qs.qs 34 | return 35 | 36 | raise TypeError('Quaternions must be constructed from iterable, numpy array, or Quaternions, not %s' % type(qs)) 37 | 38 | def __str__(self): return "Quaternions("+ str(self.qs) + ")" 39 | def __repr__(self): return "Quaternions("+ repr(self.qs) + ")" 40 | 41 | """ Helper Methods for Broadcasting and Data extraction """ 42 | 43 | @classmethod 44 | def _broadcast(cls, sqs, oqs, scalar=False): 45 | 46 | if isinstance(oqs, float): return sqs, oqs * torch.ones(sqs.shape[:-1]) 47 | 48 | ss = torch.tensor(sqs.shape).to(device) if not scalar else torch.tensor(sqs.shape[:-1]).to(device) 49 | os = torch.tensor(oqs.shape).to(device) 50 | 51 | if len(ss) != len(os): 52 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 53 | 54 | if torch.all(ss == os): return sqs, oqs # TODO: check torch.all 55 | # ipdb.set_trace() 56 | if not torch.all((ss.to(device) == os.to(device)) | (os == torch.ones(len(os)).to(device)) | (ss == torch.ones(len(ss)).to(device))): 57 | raise TypeError('Quaternions cannot broadcast together shapes %s and %s' % (sqs.shape, oqs.shape)) 58 | 59 | sqsn, oqsn = sqs.clone(), oqs.clone() 60 | 61 | for a in torch.where(ss == 1)[0]: sqsn = sqsn.repeat_interleave(os[a], dim=a) 62 | for a in torch.where(os == 1)[0]: oqsn = oqsn.repeat_interleave(ss[a], dim=a) 63 | 64 | return sqsn, oqsn 65 | 66 | """ Adding Quaterions is just Defined as Multiplication """ 67 | 68 | def __add__(self, other): return self * other 69 | def __sub__(self, other): return self / other 70 | 71 | """ Quaterion Multiplication """ 72 | 73 | def __mul__(self, other): 74 | """ 75 | Quaternion multiplication has three main methods. 76 | 77 | When multiplying a Quaternions array by Quaternions 78 | normal quaternion multiplication is performed. 79 | 80 | When multiplying a Quaternions array by a vector 81 | array of the same shape, where the last axis is 3, 82 | it is assumed to be a Quaternion by 3D-Vector 83 | multiplication and the 3D-Vectors are rotated 84 | in space by the Quaternions. 85 | 86 | When multipplying a Quaternions array by a scalar 87 | or vector of different shape it is assumed to be 88 | a Quaternions by Scalars multiplication and the 89 | Quaternions are scaled using Slerp and the identity 90 | quaternions. 
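        Illustrative sketch of the vector case (added for clarity, not part of
        the original source; note that both operands should already live on `device`):

            qs = Quaternions_torch(torch.tensor([[1.0, 0.0, 0.0, 0.0]]).to(device))  # identity
            vs = torch.tensor([[0.0, 1.0, 0.0]]).to(device)
            rotated = qs * vs        # shape (1, 3); the identity rotation leaves vs unchanged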
91 | """ 92 | 93 | """ If Quaternions type do Quaternions * Quaternions """ 94 | if isinstance(other, Quaternions_torch): 95 | 96 | sqs, oqs = Quaternions_torch._broadcast(self.qs, other.qs) 97 | 98 | q0 = sqs[...,0]; q1 = sqs[...,1]; 99 | q2 = sqs[...,2]; q3 = sqs[...,3]; 100 | r0 = oqs[...,0]; r1 = oqs[...,1]; 101 | r2 = oqs[...,2]; r3 = oqs[...,3]; 102 | 103 | qs = torch.empty(sqs.shape).to(device) 104 | qs[...,0] = r0 * q0 - r1 * q1 - r2 * q2 - r3 * q3 105 | qs[...,1] = r0 * q1 + r1 * q0 - r2 * q3 + r3 * q2 106 | qs[...,2] = r0 * q2 + r1 * q3 + r2 * q0 - r3 * q1 107 | qs[...,3] = r0 * q3 - r1 * q2 + r2 * q1 + r3 * q0 108 | 109 | return Quaternions_torch(qs) 110 | 111 | """ If array type do Quaternions * Vectors """ 112 | if isinstance(other, torch.Tensor) and other.shape[-1] == 3: 113 | vs = Quaternions_torch(torch.cat([torch.zeros(other.shape[:-1] + (1,)).to(device), other], dim=-1)) 114 | # ipdb.set_trace() 115 | return (self * (vs * -self)).imaginaries 116 | 117 | """ If float do Quaternions * Scalars """ 118 | if isinstance(other,torch.Tensor) or isinstance(other, float): 119 | return Quaternions_torch.slerp(Quaternions_torch.id_like(self), self, other) 120 | 121 | raise TypeError('Cannot multiply/add Quaternions with type %s' % str(type(other))) 122 | 123 | def __div__(self, other): 124 | """ 125 | When a Quaternion type is supplied, division is defined 126 | as multiplication by the inverse of that Quaternion. 127 | 128 | When a scalar or vector is supplied it is defined 129 | as multiplicaion of one over the supplied value. 130 | Essentially a scaling. 131 | """ 132 | 133 | if isinstance(other, Quaternions_torch): return self * (-other) 134 | if isinstance(other, torch.Tensor): return self * (1.0 / other) 135 | if isinstance(other, float): return self * (1.0 / other) 136 | raise TypeError('Cannot divide/subtract Quaternions with type %s' + str(type(other))) 137 | 138 | def __eq__(self, other): return self.qs == other.qs 139 | def __ne__(self, other): return self.qs != other.qs 140 | 141 | def __neg__(self): 142 | """ Invert Quaternions """ 143 | return Quaternions_torch(self.qs * torch.tensor([[1, -1, -1, -1]]).to(device)) 144 | 145 | def __abs__(self): 146 | """ Unify Quaternions To Single Pole """ 147 | qabs = self.normalized().copy() 148 | top = torch.sum(( qabs.qs) * torch.tensor([1,0,0,0]).to(device), dim=-1) 149 | bot = torch.sum((-qabs.qs) * torch.tensor([1,0,0,0]).to(device), dim=-1) 150 | qabs.qs[top < bot] = -qabs.qs[top < bot] 151 | return qabs 152 | 153 | def __iter__(self): return iter(self.qs) 154 | def __len__(self): return len(self.qs) 155 | 156 | def __getitem__(self, k): return Quaternions_torch(self.qs[k]) 157 | def __setitem__(self, k, v): self.qs[k] = v.qs 158 | 159 | @property 160 | def lengths(self): 161 | return torch.sum(self.qs**2.0, axis=-1)**0.5 162 | 163 | @property 164 | def reals(self): 165 | return self.qs[...,0] 166 | 167 | @property 168 | def imaginaries(self): 169 | return self.qs[...,1:4] 170 | 171 | @property 172 | def shape(self): return self.qs.shape[:-1] 173 | 174 | def repeat(self, n, **kwargs): 175 | return Quaternions_torch(self.qs.repeat(n, **kwargs)) 176 | 177 | def normalized(self): 178 | return Quaternions_torch(self.qs / self.lengths.unsqueeze(-1)) 179 | 180 | def unsqueeze(self, dim): 181 | return Quaternions_torch(self.qs.unsqueeze(dim)) 182 | 183 | def log(self): 184 | norm = torch.abs(self.normalized()) 185 | imgs = norm.imaginaries 186 | lens = torch.sqrt(torch.sum(imgs**2, dim=-1)) 187 | lens = torch.arctan2(lens, norm.reals) 
/ (lens + 1e-10) 188 | return imgs * lens.unsqueeze(-1) 189 | 190 | def constrained(self, axis): 191 | 192 | rl = self.reals 193 | im = torch.sum(axis * self.imaginaries, dim=-1) 194 | 195 | t1 = -2 * torch.arctan2(rl, im) + torch.pi 196 | t2 = -2 * torch.arctan2(rl, im) - torch.pi 197 | 198 | top = Quaternions_torch.exp(axis.unsqueeze(-1) * (t1.unsqueeze(-1) / 2.0)) 199 | bot = Quaternions_torch.exp(axis.unsqueeze(-1) * (t2.unsqueeze(-1) / 2.0)) 200 | img = self.dot(top) > self.dot(bot) 201 | 202 | ret = top.detach().clone() #.copy() 203 | ret[ img] = top[ img] 204 | ret[~img] = bot[~img] 205 | return ret 206 | 207 | def constrained_x(self): return self.constrained(torch.tensor([1,0,0]).to(device)) 208 | def constrained_y(self): return self.constrained(torch.tensor([0,1,0]).to(device)) 209 | def constrained_z(self): return self.constrained(torch.tensor([0,0,1]).to(device)) 210 | 211 | def dot(self, q): return torch.sum(self.qs * q.qs, axis=-1) 212 | 213 | def copy(self): return Quaternions_torch(self.qs.detach().clone()) 214 | 215 | def reshape(self, s): 216 | self.qs.reshape(s) 217 | return self 218 | 219 | @classmethod 220 | def between(cls, v0s, v1s): 221 | a = torch.cross(v0s, v1s) 222 | w = torch.sqrt((v0s**2).sum(dim=-1) * (v1s**2).sum(dim=-1)) + (v0s * v1s).sum(dim=-1) 223 | return Quaternions_torch(torch.cat([w.unsqueeze(-1), a], dim=-1)).normalized() 224 | 225 | -------------------------------------------------------------------------------- /utils/__pycache__/Pivots.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Pivots.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Pivots_torch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Pivots_torch.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Quaternions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Quaternions.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Quaternions_torch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/Quaternions_torch.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/cfg_parser.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/cfg_parser.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/cfg_parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/cfg_parser.cpython-38.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/train_helper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_helper.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_tools.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_tools.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/train_tools.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_body.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/__pycache__/utils_body.cpython-38.pyc -------------------------------------------------------------------------------- /utils/cfg_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | 6 | class Config(dict): 7 | 8 | def __init__(self,default_cfg_path=None,**kwargs): 9 | 10 | default_cfg = {} 11 | if default_cfg_path is not None and os.path.exists(default_cfg_path): 12 | default_cfg = self.load_cfg(default_cfg_path) 13 | 14 | super(Config,self).__init__(**kwargs) 15 | 16 | default_cfg.update(self) 17 | self.update(default_cfg) 18 | self.default_cfg = default_cfg 19 | 20 | def load_cfg(self,load_path): 21 | with open(load_path, 'r') as infile: 22 | cfg = yaml.safe_load(infile) 23 | return cfg if cfg is not None else {} 24 | 25 | def write_cfg(self,write_path=None): 26 | 27 | if write_path is None: 28 | write_path = 'yaml_config.yaml' 29 | 30 | dump_dict = {k:v for k,v in self.items() if k!='default_cfg'} 31 | with open(write_path, 'w') as outfile: 32 | yaml.safe_dump(dump_dict, outfile, default_flow_style=False) 33 | 34 | def __getattr__(self, key): 35 | try: 36 | return self[key] 37 | except KeyError: 38 | raise AttributeError(key) 39 | 40 | __setattr__ = dict.__setitem__ 41 | __delattr__ = dict.__delitem__ 42 | 43 | 44 | if __name__ == '__main__': 45 | 46 | cfg = { 47 | 'intent': 'all', 48 | 'only_contact': True, 49 | 'save_body_verts': False, 50 | 'save_object_verts': False, 51 | 'save_contact': False, 52 | } 53 | 54 | cfg = Config(**cfg) 55 | cfg.write_cfg() 56 | 
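# Usage sketch (added for illustration, not part of the original file). A training
# script would typically load a YAML of defaults and override individual entries via
# keyword arguments; the config path below exists in this repository, while the
# overridden keys are assumed example names:
#
#   from utils.cfg_parser import Config
#   cfg = Config(default_cfg_path='WholeGraspPose/configs/WholeGraspPose.yaml',
#                batch_size=32, exp_name='debug_run')
#   print(cfg.batch_size)                        # attribute-style access via __getattr__
#   cfg.write_cfg('results/debug_run/config.yaml')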
-------------------------------------------------------------------------------- /utils/como/SSM2.json: -------------------------------------------------------------------------------- 1 | { 2 | "gender": "unknown", 3 | "markersets": [ 4 | { 5 | "distance_from_skin": 0.0095, 6 | "indices": { 7 | "C7": 3832, 8 | "CLAV": 5533, 9 | "LANK": 5882, 10 | "LFWT": 3486, 11 | "LBAK": 3336, 12 | "LBCEP": 4029, 13 | "LBSH": 4137, 14 | "LBUM": 5694, 15 | "LBUST": 3228, 16 | "LCHEECK": 2081, 17 | "LELB": 4302, 18 | "LELBIN": 4363, 19 | "LFIN": 4788, 20 | "LFRM2": 4379, 21 | "LFTHI": 3504, 22 | "LFTHIIN": 3998, 23 | "LHEE": 8846, 24 | "LIWR": 4726, 25 | "LKNE": 3682, 26 | "LKNI": 3688, 27 | "LMT1": 5890, 28 | "LMT5": 5901, 29 | "LNWST": 3260, 30 | "LOWR": 4722, 31 | "LBWT": 5697, 32 | "LRSTBEEF": 5838, 33 | "LSHO": 4481, 34 | "LTHI": 4088, 35 | "LTHMB": 4839, 36 | "LTIB": 3745, 37 | "LTOE": 5787, 38 | "MBLLY": 5942, 39 | "RANK": 8576, 40 | "RFWT": 6248, 41 | "RBAK": 6127, 42 | "RBCEP": 6776, 43 | "RBSH": 7192, 44 | "RBUM": 8388, 45 | "RBUSTLO": 8157, 46 | "RCHEECK": 8786, 47 | "RELB": 7040, 48 | "RELBIN": 7099, 49 | "RFIN": 7524, 50 | "RFRM2": 7115, 51 | "RFRM2IN": 7303, 52 | "RFTHI": 6265, 53 | "RFTHIIN": 6746, 54 | "RHEE": 8634, 55 | "RKNE": 6443, 56 | "RKNI": 6449, 57 | "RMT1": 8584, 58 | "RMT5": 8595, 59 | "RNWST": 6023, 60 | "ROWR": 7458, 61 | "RBWT": 8391, 62 | "RRSTBEEF": 8532, 63 | "RSHO": 6627, 64 | "RTHI": 6832, 65 | "RTHMB": 7575, 66 | "RTIB": 6503, 67 | "RTOE": 8481, 68 | "STRN": 5531, 69 | "T8": 5487, 70 | "LFHD": 707, 71 | "LBHD": 2026, 72 | "RFHD": 2198, 73 | "RBHD": 3066 74 | }, 75 | "marker_radius": 0.0095, 76 | "type": "body" 77 | } 78 | ] 79 | } -------------------------------------------------------------------------------- /utils/como/__pycache__/como_smooth.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_smooth.cpython-310.pyc -------------------------------------------------------------------------------- /utils/como/__pycache__/como_smooth.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_smooth.cpython-38.pyc -------------------------------------------------------------------------------- /utils/como/__pycache__/como_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/como/__pycache__/como_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/__pycache__/como_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/como/como_smooth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch.autograd import Variable 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class 
EncBlock(nn.Module): 12 | def __init__(self, nin, nout, downsample=True): 13 | super(EncBlock, self).__init__() 14 | self.downsample = downsample 15 | 16 | self.main = nn.Sequential( 17 | nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=1, padding=1), 18 | nn.LeakyReLU(0.2), 19 | nn.Conv2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1), 20 | nn.LeakyReLU(0.2), 21 | ) 22 | self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 23 | 24 | def forward(self, input): 25 | output = self.main(input) 26 | if not self.downsample: 27 | return output 28 | else: 29 | output = self.pooling(output) 30 | return output 31 | 32 | 33 | 34 | class DecBlock(nn.Module): 35 | def __init__(self, nin, nout, upsample=True): 36 | super(DecBlock, self).__init__() 37 | self.upsample = upsample 38 | if upsample: 39 | deconv_stride = 2 40 | else: 41 | deconv_stride = 1 42 | 43 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=deconv_stride, padding=1) 44 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1) 45 | self.leaky_relu = nn.LeakyReLU(0.2) 46 | 47 | def forward(self, input, out_size): 48 | output = self.deconv1(input, output_size=out_size) 49 | output = self.leaky_relu(output) 50 | output = self.leaky_relu(self.deconv2(output)) 51 | return output 52 | 53 | 54 | class DecBlock_output(nn.Module): 55 | def __init__(self, nin, nout, upsample=True): 56 | super(DecBlock_output, self).__init__() 57 | self.upsample = upsample 58 | if upsample: 59 | deconv_stride = 2 60 | else: 61 | deconv_stride = 1 62 | 63 | self.deconv1 = nn.ConvTranspose2d(in_channels=nin, out_channels=nout, kernel_size=3, stride=deconv_stride, padding=1) 64 | self.deconv2 = nn.ConvTranspose2d(in_channels=nout, out_channels=nout, kernel_size=3, stride=1, padding=1) 65 | self.leaky_relu = nn.LeakyReLU(0.2) 66 | 67 | 68 | def forward(self, input, out_size): 69 | output = self.deconv1(input, output_size=out_size) 70 | output = self.leaky_relu(output) 71 | output = self.deconv2(output) 72 | return output 73 | 74 | 75 | 76 | 77 | class Enc(nn.Module): 78 | def __init__(self, downsample=True, z_channel=64): 79 | super(Enc, self).__init__() 80 | if z_channel == 256: 81 | channel_2, channel_3 = 128, 256 82 | elif z_channel == 64: 83 | channel_2, channel_3 = 64, 64 84 | self.enc_blc1 = EncBlock(nin=1, nout=32, downsample=downsample) 85 | self.enc_blc2 = EncBlock(nin=32, nout=64, downsample=downsample) 86 | self.enc_blc3 = EncBlock(nin=64, nout=channel_2, downsample=downsample) 87 | self.enc_blc4 = EncBlock(nin=channel_2, nout=channel_3, downsample=downsample) 88 | self.enc_blc5 = EncBlock(nin=channel_3, nout=channel_3, downsample=downsample) 89 | 90 | 91 | def forward(self, input): 92 | # input: [bs, 1, d, T] 93 | # [bs, 1, 51/145/75, 120] (smplx_params, no hand, vposer/6d_rot)/(joints, no hand) 94 | x_down1 = self.enc_blc1(input) # [bs, 32, 26, 60] 95 | x_down2 = self.enc_blc2(x_down1) # [bs, 64, 13, 30] 96 | x_down3 = self.enc_blc3(x_down2) # [bs, 128, 7, 15] 97 | x_down4 = self.enc_blc4(x_down3) # [bs, 256, 4, 8] 98 | z = self.enc_blc5(x_down4) # [bs, 256, 2/5/3, 4] (smplx_params, no hand, vposer/6d_rot)/(joints, no hand) 99 | return z, input.size(), x_down1.size(), x_down2.size(), x_down3.size(), x_down4.size() 100 | 101 | 102 | class Dec(nn.Module): 103 | def __init__(self, downsample=True, z_channel=64): 104 | super(Dec, self).__init__() 105 | if z_channel == 256: 106 | channel_2, channel_3 = 128, 256 
107 | elif z_channel == 64: 108 | channel_2, channel_3 = 64, 64 109 | 110 | self.dec_blc1 = DecBlock(nin=channel_3, nout=channel_3, upsample=downsample) 111 | self.dec_blc2 = DecBlock(nin=channel_3, nout=channel_2, upsample=downsample) 112 | self.dec_blc3 = DecBlock(nin=channel_2, nout=64, upsample=downsample) 113 | self.dec_blc4 = DecBlock(nin=64, nout=32, upsample=downsample) 114 | self.dec_blc5 = DecBlock_output(nin=32, nout=1, upsample=downsample) 115 | 116 | 117 | def forward(self, z, input_size, x_down1_size, x_down2_size, x_down3_size, x_down4_size): 118 | x_up4 = self.dec_blc1(z, x_down4_size) # [bs, 256, 4, 8] 119 | x_up3 = self.dec_blc2(x_up4, x_down3_size) # [bs, 128, 7, 15] 120 | x_up2 = self.dec_blc3(x_up3, x_down2_size) # [bs, 64, 13, 30] 121 | x_up1 = self.dec_blc4(x_up2, x_down1_size) # [bs, 32, 26, 60] 122 | output = self.dec_blc5(x_up1, input_size) 123 | return output 124 | -------------------------------------------------------------------------------- /utils/como/como_smooth_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/como_smooth_model.pkl -------------------------------------------------------------------------------- /utils/como/my_SSM2.json: -------------------------------------------------------------------------------- 1 | { 2 | "gender": "unknown", 3 | "markersets": [ 4 | { 5 | "distance_from_skin": 0.0095, 6 | "indices": { 7 | "C7": 3832, 8 | "CLAV": 5533, 9 | "LANK": 5882, 10 | "LFWT": 3486, 11 | "LBAK": 3336, 12 | "LBCEP": 4029, 13 | "LBSH": 4137, 14 | "LBUM": 5694, 15 | "LBUST": 3228, 16 | "LCHEECK": 2081, 17 | "LELB": 4302, 18 | "LELBIN": 4363, 19 | "LFIN": 4788, 20 | "LFRM2": 4379, 21 | "LFTHI": 3504, 22 | "LFTHIIN": 3998, 23 | "LHEE": 8846, 24 | "LIWR": 4726, 25 | "LKNE": 3682, 26 | "LKNI": 3688, 27 | "LMT1": 5890, 28 | "LMT5": 5901, 29 | "LNWST": 3260, 30 | "LOWR": 4722, 31 | "LBWT": 5697, 32 | "LRSTBEEF": 5838, 33 | "LSHO": 4481, 34 | "LTHI": 4088, 35 | "LTHMB": 4839, 36 | "LTIB": 3745, 37 | "LTOE": 5787, 38 | "MBLLY": 5942, 39 | "RANK": 8576, 40 | "RFWT": 6248, 41 | "RBAK": 6127, 42 | "RBCEP": 6776, 43 | "RBSH": 7192, 44 | "RBUM": 8388, 45 | "RBUSTLO": 8157, 46 | "RCHEECK": 8786, 47 | "RELB": 7040, 48 | "RELBIN": 7099, 49 | "RFIN": 7524, 50 | "RFRM2": 7115, 51 | "RFRM2IN": 7303, 52 | "RFTHI": 6265, 53 | "RFTHIIN": 6746, 54 | "RHEE": 8634, 55 | "RKNE": 6443, 56 | "RKNI": 6449, 57 | "RMT1": 8584, 58 | "RMT5": 8595, 59 | "RNWST": 6023, 60 | "ROWR": 7458, 61 | "RBWT": 8391, 62 | "RRSTBEEF": 8532, 63 | "RSHO": 6627, 64 | "RTHI": 6832, 65 | "RTHMB": 7575, 66 | "RTIB": 6503, 67 | "RTOE": 8481, 68 | "STRN": 5531, 69 | "T8": 5487, 70 | "LFHD": 707, 71 | "LBHD": 2026, 72 | "RFHD": 2198, 73 | "RBHD": 3066, 74 | "CHN1": 8757, 75 | "CHN2": 9066, 76 | "MTH3": 8985, 77 | "MTH7": 8947, 78 | "LIDX3": 4931, 79 | "LMID3": 5045, 80 | "LPNK3": 5268, 81 | "LRNG3": 5149, 82 | "LTHM4": 5346, 83 | "RIDX3": 7667, 84 | "RMID3": 7781, 85 | "RPNK3": 8001, 86 | "RRNG3": 7884, 87 | "RTHM4": 8082 88 | }, 89 | "marker_radius": 0.0095, 90 | "type": "body" 91 | } 92 | ] 93 | } -------------------------------------------------------------------------------- /utils/como/preprocess_stats_global_markers/Xmean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/preprocess_stats_global_markers/Xmean.npy 
-------------------------------------------------------------------------------- /utils/como/preprocess_stats_global_markers/Xstd.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/utils/como/preprocess_stats_global_markers/Xstd.npy -------------------------------------------------------------------------------- /utils/train_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import chamfer_distance as chd 5 | import numpy as np 6 | import scipy.ndimage.filters as filters 7 | import torch 8 | 9 | from utils.Pivots import Pivots 10 | from utils.Pivots_torch import Pivots_torch 11 | from utils.Quaternions import Quaternions 12 | from utils.Quaternions_torch import Quaternions_torch 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | def point2point_signed( 18 | x, 19 | y, 20 | x_normals=None, 21 | y_normals=None, 22 | return_vector=False, 23 | ): 24 | """ 25 | signed distance between two pointclouds 26 | 27 | Args: 28 | x: FloatTensor of shape (N, P1, D) representing a batch of point clouds 29 | with P1 points in each batch element, batch size N and feature 30 | dimension D. 31 | y: FloatTensor of shape (N, P2, D) representing a batch of point clouds 32 | with P2 points in each batch element, batch size N and feature 33 | dimension D. 34 | x_normals: Optional FloatTensor of shape (N, P1, D). 35 | y_normals: Optional FloatTensor of shape (N, P2, D). 36 | 37 | Returns: 38 | 39 | - y2x_signed: Torch.Tensor 40 | the signed distance from y to x 41 | - x2y_signed: Torch.Tensor 42 | the signed distance from x to y 43 | - yidx_near: Torch.tensor 44 | the indices of x vertices closest to y 45 | 46 | """ 47 | 48 | 49 | N, P1, D = x.shape 50 | P2 = y.shape[1] 51 | 52 | if y.shape[0] != N or y.shape[2] != D: 53 | raise ValueError("y does not have the correct shape.") 54 | 55 | # ch_dist = chd.ChamferDistance() 56 | 57 | x_near, y_near, xidx_near, yidx_near = chd.ChamferDistance(x,y) 58 | 59 | xidx_near_expanded = xidx_near.view(N, P1, 1).expand(N, P1, D).to(torch.long) 60 | x_near = y.gather(1, xidx_near_expanded) 61 | 62 | yidx_near_expanded = yidx_near.view(N, P2, 1).expand(N, P2, D).to(torch.long) 63 | y_near = x.gather(1, yidx_near_expanded) 64 | 65 | x2y = x - x_near # y point to x 66 | y2x = y - y_near # x point to y 67 | 68 | if x_normals is not None: 69 | y_nn = x_normals.gather(1, yidx_near_expanded) 70 | in_out = torch.bmm(y_nn.view(-1, 1, 3), y2x.view(-1, 3, 1)).view(N, -1).sign() 71 | y2x_signed = y2x.norm(dim=2) * in_out 72 | 73 | else: 74 | y2x_signed = y2x.norm(dim=2) 75 | 76 | if y_normals is not None: 77 | x_nn = y_normals.gather(1, xidx_near_expanded) 78 | in_out_x = torch.bmm(x_nn.view(-1, 1, 3), x2y.view(-1, 3, 1)).view(N, -1).sign() 79 | x2y_signed = x2y.norm(dim=2) * in_out_x 80 | else: 81 | x2y_signed = x2y.norm(dim=2) 82 | 83 | if not return_vector: 84 | return y2x_signed, x2y_signed, yidx_near, xidx_near 85 | else: 86 | return y2x_signed, x2y_signed, yidx_near, xidx_near, y2x, x2y 87 | 88 | 89 | 90 | 91 | class EarlyStopping: 92 | """Early stops the training if validation loss doesn't improve after a given patience.""" 93 | def __init__(self, patience=7, verbose=False, delta=0, trace_func=None): 94 | """ 95 | Args: 96 | patience (int): How long to wait after last time validation loss improved.
97 | Default: 7 98 | verbose (bool): If True, prints a message for each validation loss improvement. 99 | Default: False 100 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 101 | Default: 0 102 | trace_func (function): trace print function. 103 | Default: print 104 | """ 105 | self.patience = patience 106 | self.verbose = verbose 107 | self.counter = 0 108 | self.best_score = None 109 | self.early_stop = False 110 | self.val_loss_min = np.Inf 111 | self.delta = delta 112 | self.trace_func = trace_func 113 | def __call__(self, val_loss): 114 | 115 | score = -val_loss 116 | 117 | if self.best_score is None: 118 | self.best_score = score 119 | elif score < self.best_score + self.delta: 120 | self.counter += 1 121 | if self.trace_func is not None: 122 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 123 | if self.counter >= self.patience: 124 | self.early_stop = True 125 | else: 126 | self.best_score = score 127 | self.counter = 0 128 | return self.early_stop 129 | 130 | def save_ckp(state, checkpoint_dir): 131 | f_path = os.path.join(checkpoint_dir, 'checkpoint.pt') 132 | torch.save(state, f_path) 133 | 134 | def load_ckp(checkpoint_fpath, model, optimizer): 135 | checkpoint = torch.load(checkpoint_fpath) 136 | model.load_state_dict(checkpoint['state_dict']) 137 | optimizer.load_state_dict(checkpoint['optimizer']) 138 | return model, optimizer, checkpoint['epoch'] 139 | 140 | 141 | def get_forward_joint(joint_start): 142 | """ Joint_start: [B, N, 3] in xyz """ 143 | x_axis = joint_start[:, 2, :] - joint_start[:, 1, :] 144 | x_axis[:, -1] = 0 145 | x_axis = x_axis / torch.norm(x_axis, dim=-1).unsqueeze(1) 146 | z_axis = torch.tensor([0, 0, 1]).float().unsqueeze(0).repeat(len(x_axis), 1).to(device) 147 | y_axis = torch.cross(z_axis, x_axis) 148 | y_axis = y_axis / torch.norm(y_axis, dim=-1).unsqueeze(1) 149 | transf_rotmat = torch.stack([x_axis, y_axis, z_axis], dim=1) 150 | return y_axis, transf_rotmat 151 | 152 | def prepare_traj_input(joint_start, joint_end, traj_Xmean, traj_Xstd): 153 | """ Joints: [B, N, 3] in xyz """ 154 | B, N, _ = joint_start.shape 155 | T = 62 156 | joint_sr_input_unnormed = torch.ones(B, 4, T) # [B, xyr, T] 157 | y_axis, transf_rotmat = get_forward_joint(joint_start) 158 | joint_start_new = joint_start.clone() 159 | joint_end_new = joint_end.clone() # to check whether original joints change or not 160 | joint_start_new = torch.matmul(joint_start - joint_start[:, 0:1], transf_rotmat) 161 | joint_end_new = torch.matmul(joint_end - joint_start[:, 0:1], transf_rotmat) 162 | 163 | # start_forward, _ = get_forward_joint(joint_start_new) 164 | start_forward = torch.tensor([0, 1, 0]).unsqueeze(0) 165 | end_forward, _ = get_forward_joint(joint_end_new) 166 | 167 | joint_sr_input_unnormed[:, :2, 0] = joint_start_new[:, 0, :2] # xy 168 | joint_sr_input_unnormed[:, :2, -2] = joint_end_new[:, 0, :2] # xy 169 | joint_sr_input_unnormed[:, 2:, 0] = start_forward[:, :2] # r 170 | joint_sr_input_unnormed[:, 2:, -2] = end_forward[:, :2] # r 171 | 172 | # normalize 173 | traj_mean = traj_Xmean.unsqueeze(2).cpu() 174 | traj_std = traj_Xstd.unsqueeze(2).cpu() 175 | 176 | # linear interpolation 177 | joint_sr_input_normed = (joint_sr_input_unnormed - traj_mean) / traj_std 178 | for t in range(joint_sr_input_normed.size(-1)): 179 | joint_sr_input_normed[:, :, t] = joint_sr_input_normed[:, :, 0] + (joint_sr_input_normed[:, :, -2] - joint_sr_input_normed[:, :, 0])*t/(joint_sr_input_normed.size(-1)-2) 180 | 
joint_sr_input_normed[:, -2:, t] = joint_sr_input_normed[:, -2:, t] / torch.norm(joint_sr_input_normed[:, -2:, t], dim=1).unsqueeze(1) 181 | 182 | for t in range(joint_sr_input_unnormed.size(-1)): 183 | joint_sr_input_unnormed[:, :, t] = joint_sr_input_unnormed[:, :, 0] + (joint_sr_input_unnormed[:, :, -2] - joint_sr_input_unnormed[:, :, 0])*t/(joint_sr_input_unnormed.size(-1)-2) 184 | joint_sr_input_unnormed[:, -2:, t] = joint_sr_input_unnormed[:, -2:, t] / torch.norm(joint_sr_input_unnormed[:, -2:, t], dim=1).unsqueeze(1) 185 | 186 | return joint_sr_input_normed.float().to(device), joint_sr_input_unnormed.float().to(device), transf_rotmat, joint_start_new, joint_end_new 187 | 188 | def prepare_clip_img_input(clip_img, marker_start, marker_end, joint_start, joint_end, joint_start_new, joint_end_new, transf_rotmat, traj_pred_unnormed, traj_sr_input_unnormed, traj_smoothed, markers_stats): 189 | traj_pred_unnormed = traj_pred_unnormed.detach().cpu().numpy() 190 | 191 | traj_pred_unnormed[:, :, 0] = traj_sr_input_unnormed[:, :, 0].detach().cpu().numpy() 192 | traj_pred_unnormed[:, :, -2] = traj_sr_input_unnormed[:, :, -2].detach().cpu().numpy() 193 | 194 | B, n_markers, _ = marker_start.shape 195 | _, n_joints, _ = joint_start.shape 196 | markers = torch.rand(B, 61, n_markers, 3) # [B, T, N ,3] 197 | joints = torch.rand(B, 61, n_joints, 3) # [B, T, N ,3] 198 | 199 | marker_start_new = torch.matmul(marker_start - joint_start[:, 0:1], transf_rotmat) 200 | marker_end_new = torch.matmul(marker_end - joint_start[:, 0:1], transf_rotmat) 201 | 202 | z_transl_to_floor_start = torch.min(marker_start_new[:, :, -1], dim=-1)[0]# - 0.03 203 | z_transl_to_floor_end = torch.min(marker_end_new[:, :, -1], dim=-1)[0]# - 0.03 204 | 205 | marker_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1) 206 | marker_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1) 207 | joint_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1) 208 | joint_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1) 209 | 210 | markers[:, 0] = marker_start_new 211 | markers[:, -1] = marker_end_new 212 | joints[:, 0] = joint_start_new 213 | joints[:, -1] = joint_end_new 214 | 215 | cur_body = torch.cat([joints[:, :, 0:1], markers], dim=2) 216 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # => xyz -> xzy 217 | reference = cur_body[:, :, 0] * torch.tensor([1, 0, 1]) # => the xy of pelvis joint? 
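    # Added note: `reference` is the pelvis row with its height component zeroed
    # (the data is in the swapped x-z-y layout at this point); the subtractions a few
    # lines below express every joint and marker relative to this ground-projected root.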
218 | cur_body = torch.cat([reference.unsqueeze(2), cur_body], dim=2) # [B, T, 1(reference)+1(pelvis)+N, 3] 219 | 220 | # position to local frame 221 | cur_body[:, :, :, 0] = cur_body[:, :, :, 0] - cur_body[:, :, 0:1, 0] 222 | cur_body[:, :, :, -1] = cur_body[:, :, :, -1] - cur_body[:, :, 0:1, -1] 223 | 224 | forward = np.zeros((B, 62, 3)) 225 | forward[:, :, :2] = traj_pred_unnormed[:, 2:].transpose(0, 2, 1) 226 | forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis] 227 | forward[:, :, [1, 2]] = forward[:, :, [2, 1]] 228 | 229 | if traj_smoothed: 230 | forward_saved = forward.copy() 231 | direction_filterwidth = 20 232 | forward = filters.gaussian_filter1d(forward, direction_filterwidth, axis=1, mode='nearest') 233 | traj_pred_unnormed[:, 2] = forward[:, :, 0] 234 | traj_pred_unnormed[:, 3] = forward[:, :, -1] 235 | 236 | target = np.array([[0, 0, 1]]) 237 | rotation = Quaternions.between(forward, target)[:, :, np.newaxis] # [B, T, 1, 4] 238 | 239 | cur_body = rotation[:, :-1] * cur_body.detach().cpu().numpy() # [B, T, 1+1+N, xzy] 240 | cur_body[:, 1:-1] = 0 241 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # xzy => xyz 242 | cur_body = cur_body[:, :, 1:, :] 243 | cur_body = cur_body.reshape(cur_body.shape[0], cur_body.shape[1], -1) # [B, T, N*3] 244 | 245 | velocity = np.zeros((B, 3, 61)) 246 | velocity[:, 0, :] = traj_pred_unnormed[:, 0, 1:] - traj_pred_unnormed[:, 0, 0:-1] # [B, 2, 61] on Joint frame 247 | velocity[:, -1, :] = traj_pred_unnormed[:, 1, 1:] - traj_pred_unnormed[:, 1, 0:-1] # [B, 2, 61] on Joint frame 248 | 249 | velocity = rotation[:, 1:] * velocity.transpose(0, 2, 1).reshape(B, 61, 1, 3) 250 | rvelocity = Pivots.from_quaternions(rotation[:, 1:] * -rotation[:, :-1]).ps # [B, T-1, 1] 251 | rot_0_pivot = Pivots.from_quaternions(rotation[:, 0]).ps 252 | 253 | global_x = velocity[:, :, 0, 0] 254 | global_y = velocity[:, :, 0, 2] 255 | contact_lbls = np.zeros((B, 61, 4)) 256 | 257 | channel_local = np.concatenate([cur_body, contact_lbls], axis=-1)[:, np.newaxis, :, :] # [B, 1, T-1, d=N*3+4] 258 | T, d = channel_local.shape[-2], channel_local.shape[-1] 259 | channel_global_x = np.repeat(global_x, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 260 | channel_global_y = np.repeat(global_y, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 261 | channel_global_r = np.repeat(rvelocity, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 262 | 263 | cur_body = np.concatenate([clip_img[:, 0:1].detach().permute(0, 1, 3, 2).cpu().numpy(), channel_global_x, channel_global_y, channel_global_r], axis=1) # [B, 4, T-1, d] 264 | 265 | # cur_body[:, 0] = (cur_body[:, 0] - markers_stats['Xmean_local']) / markers_stats['Xstd_local'] 266 | cur_body[:, 1:3] = (cur_body[:, 1:3] - markers_stats['Xmean_global_xy']) / markers_stats['Xstd_global_xy'] 267 | cur_body[:, 3] = (cur_body[:, 3] - markers_stats['Xmean_global_r']) / markers_stats['Xstd_global_r'] 268 | 269 | # mask cur_body 270 | cur_body = cur_body.transpose(0, 1, 3, 2) # [B, 4, D, T-1] 271 | mask_t_1 = [0, 60] 272 | mask_t_0 = list(set(range(60+1)) - set(mask_t_1)) 273 | cur_body[:, 0, 2:, mask_t_0] = 0. 274 | cur_body[:, 0, -4:, :] = 0. 
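    # Added note on the masking above: rows 2 onward of the local-pose channel are
    # zeroed for every in-between frame (only frames 0 and 60 keep them), and the last
    # 4 rows (the contact labels) are zeroed for all frames, leaving the in-between
    # motion to be infilled downstream.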
275 | # print('Mask the markers in the following frames: ', mask_t_0) 276 | 277 | return torch.from_numpy(cur_body).float().to(device), rot_0_pivot, marker_start_new, marker_end_new, traj_pred_unnormed 278 | 279 | 280 | def prepare_clip_img_input_torch(clip_img, marker_start, marker_end, joint_start, joint_end, 281 | joint_start_new, joint_end_new, transf_rotmat, 282 | traj_pred_unnormed, traj_sr_input_unnormed, 283 | traj_smoothed, markers_stats): 284 | 285 | traj_pred_unnormed[:, :, 0] = traj_sr_input_unnormed[:, :, 0]#.detach().cpu().numpy() 286 | traj_pred_unnormed[:, :, -2] = traj_sr_input_unnormed[:, :, -2]#.detach().cpu().numpy() 287 | 288 | B, n_markers, _ = marker_start.shape 289 | _, n_joints, _ = joint_start.shape 290 | markers = torch.rand(B, 61, n_markers, 3).to(device) # [B, T, N ,3] 291 | joints = torch.rand(B, 61, n_joints, 3).to(device) # [B, T, N ,3] 292 | 293 | marker_start_new = torch.matmul(marker_start - joint_start[:, 0:1], transf_rotmat) 294 | marker_end_new = torch.matmul(marker_end - joint_start[:, 0:1], transf_rotmat) 295 | 296 | z_transl_to_floor_start = torch.min(marker_start_new[:, :, -1], dim=-1)[0]# - 0.03 297 | z_transl_to_floor_end = torch.min(marker_end_new[:, :, -1], dim=-1)[0]# - 0.03 298 | 299 | marker_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1) 300 | marker_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1) 301 | joint_start_new[:, :, -1] -= z_transl_to_floor_start.unsqueeze(1) 302 | joint_end_new[:, :, -1] -= z_transl_to_floor_end.unsqueeze(1) 303 | 304 | markers[:, 0] = marker_start_new 305 | markers[:, -1] = marker_end_new 306 | joints[:, 0] = joint_start_new 307 | joints[:, -1] = joint_end_new 308 | 309 | cur_body = torch.cat([joints[:, :, 0:1], markers], dim=2) 310 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # => xyz -> xzy 311 | reference = cur_body[:, :, 0] * torch.tensor([1, 0, 1]).to(device) # => the xy of pelvis joint? 
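    # Added note: same canonicalization as in prepare_clip_img_input above, but kept
    # on-device; `reference` is the pelvis row with its height zeroed, and the
    # subtractions below move all positions into this ground-projected root frame.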
312 | cur_body = torch.cat([reference.unsqueeze(2), cur_body], dim=2) # [B, T, 1(reference)+1(pelvis)+N, 3] 313 | 314 | # position to local frame 315 | cur_body[:, :, :, 0] = cur_body[:, :, :, 0] - cur_body[:, :, 0:1, 0] 316 | cur_body[:, :, :, -1] = cur_body[:, :, :, -1] - cur_body[:, :, 0:1, -1] 317 | 318 | forward = torch.zeros((B, 62, 3)).to(device) 319 | forward[:, :, :2] = traj_pred_unnormed[:, 2:].permute(0, 2, 1) 320 | forward = forward / torch.sqrt((forward ** 2).sum(dim=-1)).unsqueeze(-1) 321 | forward[:, :, [1, 2]] = forward[:, :, [2, 1]] 322 | 323 | if traj_smoothed: 324 | # forward_saved = forward.copy() 325 | direction_filterwidth = 20 326 | forward = filters.gaussian_filter1d(forward, direction_filterwidth, axis=1, mode='nearest') 327 | traj_pred_unnormed[:, 2] = forward[:, :, 0] 328 | traj_pred_unnormed[:, 3] = forward[:, :, -1] 329 | 330 | target = torch.tensor([[[0, 0, 1]]]).float().to(device).repeat(forward.size(0), forward.size(1), 1) #.repeat(len(forward), axis=0) 331 | # rotation = Quaternions.between(forward, target)[:, :, np.newaxis] # [B, T, 1, 4] 332 | rotation = Quaternions_torch.between(forward, target).unsqueeze(2) 333 | 334 | cur_body = rotation[:, :-1] * cur_body#.detach().cpu().numpy() # [B, T, 1+1+N, xzy] 335 | cur_body[:, 1:-1] = 0 336 | cur_body[:, :, :, [1, 2]] = cur_body[:, :, :, [2, 1]] # xzy => xyz 337 | cur_body = cur_body[:, :, 1:, :] 338 | cur_body = cur_body.reshape(cur_body.shape[0], cur_body.shape[1], -1) # [B, T, N*3] 339 | 340 | velocity = torch.zeros((B, 3, 61)).to(device) 341 | velocity[:, 0, :] = traj_pred_unnormed[:, 0, 1:] - traj_pred_unnormed[:, 0, 0:-1] # [B, 2, 61] on Joint frame 342 | velocity[:, -1, :] = traj_pred_unnormed[:, 1, 1:] - traj_pred_unnormed[:, 1, 0:-1] # [B, 2, 61] on Joint frame 343 | 344 | velocity = rotation[:, 1:] * velocity.permute(0, 2, 1).reshape(B, 61, 1, 3) 345 | rvelocity = Pivots_torch.from_quaternions(rotation[:, 1:] * -rotation[:, :-1]).ps # [B, T-1, 1] 346 | rot_0_pivot = Pivots_torch.from_quaternions(rotation[:, 0]).ps 347 | 348 | global_x = velocity[:, :, 0, 0] 349 | global_y = velocity[:, :, 0, 2] 350 | 351 | T, d = clip_img.shape[-1], clip_img.shape[-2] 352 | channel_global_x = torch.repeat_interleave(global_x, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 353 | channel_global_y = torch.repeat_interleave(global_y, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 354 | channel_global_r = torch.repeat_interleave(rvelocity, d).reshape(-1, 1, T, d) # [B, 1, T-1, d] 355 | 356 | cur_body = torch.cat([clip_img[:, 0:1].permute(0, 1, 3, 2), channel_global_x, channel_global_y, channel_global_r], dim=1) # [B, 4, T-1, d] 357 | 358 | # cur_body[:, 0] = (cur_body[:, 0] - markers_stats['Xmean_local']) / markers_stats['Xstd_local'] 359 | cur_body[:, 1:3] = (cur_body[:, 1:3] - torch.from_numpy(markers_stats['Xmean_global_xy']).float().to(device)) / torch.from_numpy(markers_stats['Xstd_global_xy']).float().to(device) 360 | cur_body[:, 3] = (cur_body[:, 3] - torch.from_numpy(markers_stats['Xmean_global_r']).float().to(device)) / torch.from_numpy(markers_stats['Xstd_global_r']).float().to(device) 361 | 362 | # mask cur_body 363 | cur_body = cur_body.permute(0, 1, 3, 2) # [B, 4, D, T-1] 364 | mask_t_1 = [0, 60] 365 | mask_t_0 = list(set(range(60+1)) - set(mask_t_1)) 366 | cur_body[:, 0, 2:, mask_t_0] = 0. 367 | cur_body[:, 0, -4:, :] = 0. 
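    # Torch counterpart of the numpy routine above: it builds the same masked 4-channel
    # motion image, but entirely with torch ops on `device`, presumably so this
    # preprocessing step can stay on the GPU inside a differentiable pipeline.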
368 | # print('Mask the markers in the following frames: ', mask_t_0) 369 | return cur_body, rot_0_pivot, marker_start_new, marker_end_new, traj_pred_unnormed 370 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | to_cpu = lambda tensor: tensor.detach().cpu().numpy() 10 | 11 | class Struct(object): 12 | def __init__(self, **kwargs): 13 | for key, val in kwargs.items(): 14 | setattr(self, key, val) 15 | 16 | 17 | def to_tensor(array, dtype=torch.float32): 18 | if not torch.is_tensor(array): 19 | array = torch.tensor(array) 20 | return array.to(dtype) 21 | 22 | 23 | def to_np(array, dtype=np.float32): 24 | if 'scipy.sparse' in str(type(array)): 25 | array = np.array(array.todense(), dtype=dtype) 26 | elif torch.is_tensor(array): 27 | array = array.detach().cpu().numpy() 28 | return array.astype(dtype) 29 | 30 | 31 | def makepath(desired_path, isfile = False): 32 | ''' 33 | if the path does not exist make it 34 | :param desired_path: can be path to a file or a folder name 35 | :return: 36 | ''' 37 | import os 38 | if isfile: 39 | if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) 40 | else: 41 | if not os.path.exists(desired_path): os.makedirs(desired_path) 42 | return desired_path 43 | 44 | def makelogger(log_dir,mode='w'): 45 | 46 | 47 | logger = logging.getLogger() 48 | logger.setLevel(logging.INFO) 49 | 50 | ch = logging.StreamHandler() 51 | ch.setLevel(logging.INFO) 52 | 53 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 54 | 55 | ch.setFormatter(formatter) 56 | 57 | logger.addHandler(ch) 58 | 59 | fh = logging.FileHandler('%s'%log_dir, mode=mode) 60 | fh.setFormatter(formatter) 61 | logger.addHandler(fh) 62 | 63 | return logger 64 | 65 | def CRot2rotmat(pose): 66 | 67 | reshaped_input = pose.view(-1, 3, 2) 68 | 69 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1) 70 | 71 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True) 72 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1) 73 | b3 = torch.cross(b1, b2, dim=1) 74 | 75 | return torch.stack([b1, b2, b3], dim=-1) 76 | 77 | 78 | def euler(rots, order='xyz', units='deg'): 79 | 80 | rots = np.asarray(rots) 81 | single_val = False if len(rots.shape)>1 else True 82 | rots = rots.reshape(-1,3) 83 | rotmats = [] 84 | 85 | for xyz in rots: 86 | if units == 'deg': 87 | xyz = np.radians(xyz) 88 | r = np.eye(3) 89 | for theta, axis in zip(xyz,order): 90 | c = np.cos(theta) 91 | s = np.sin(theta) 92 | if axis=='x': 93 | r = np.dot(np.array([[1,0,0],[0,c,-s],[0,s,c]]), r) 94 | if axis=='y': 95 | r = np.dot(np.array([[c,0,s],[0,1,0],[-s,0,c]]), r) 96 | if axis=='z': 97 | r = np.dot(np.array([[c,-s,0],[s,c,0],[0,0,1]]), r) 98 | rotmats.append(r) 99 | rotmats = np.stack(rotmats).astype(np.float32) 100 | if single_val: 101 | return rotmats[0] 102 | else: 103 | return rotmats 104 | 105 | def batch_euler(bxyz,order='xyz', units='deg'): 106 | 107 | br = [] 108 | for frame in range(bxyz.shape[0]): 109 | br.append(euler(bxyz[frame], order, units)) 110 | return np.stack(br).astype(np.float32) 111 | 112 | def rotate(points,R): 113 | shape = points.shape 114 | if len(shape)>3: 115 | points = points.squeeze() 
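    # Bring `points` to shape [N, P, 3] before the batched matmul with the transposed
    # rotation matrices below, then reshape the rotated points back to the input shape.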
116 | if len(shape)<3: 117 | points = points[:,np.newaxis] 118 | r_points = torch.matmul(torch.from_numpy(points).to(device), torch.from_numpy(R).to(device).transpose(1,2)) 119 | return r_points.cpu().numpy().reshape(shape) 120 | 121 | def rotmul(rotmat,R): 122 | 123 | shape = rotmat.shape 124 | rotmat = rotmat.squeeze() 125 | R = R.squeeze() 126 | rot = torch.matmul(torch.from_numpy(R).to(device),torch.from_numpy(rotmat).to(device)) 127 | return rot.cpu().numpy().reshape(shape) 128 | 129 | # import torchgeometry as tgm 130 | # borrowed from the torchgeometry package 131 | def rotmat2aa(rotmat): 132 | ''' 133 | :param rotmat: Nx1xnum_jointsx9 134 | :return: Nx1xnum_jointsx3 135 | ''' 136 | batch_size = rotmat.size(0) 137 | homogen_matrot = F.pad(rotmat.view(-1, 3, 3), [0,1]) 138 | pose = rotation_matrix_to_angle_axis(homogen_matrot).view(batch_size, 1, -1, 3).contiguous() 139 | return pose 140 | 141 | def aa2rotmat(axis_angle): 142 | ''' 143 | :param Nx1xnum_jointsx3 144 | :return: pose_matrot: Nx1xnum_jointsx9 145 | ''' 146 | batch_size = axis_angle.size(0) 147 | pose_body_matrot = angle_axis_to_rotation_matrix(axis_angle.reshape(-1, 3))[:, :3, :3].contiguous().view(batch_size, 1, -1, 9) 148 | return pose_body_matrot 149 | 150 | def angle_axis_to_rotation_matrix(angle_axis): 151 | """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix 152 | 153 | Args: 154 | angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. 155 | 156 | Returns: 157 | Tensor: tensor of 4x4 rotation matrices. 158 | 159 | Shape: 160 | - Input: :math:`(N, 3)` 161 | - Output: :math:`(N, 4, 4)` 162 | 163 | Example: 164 | >>> input = torch.rand(1, 3) # Nx3 165 | >>> output = angle_axis_to_rotation_matrix(input) # Nx4x4 166 | """ 167 | def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): 168 | # We want to be careful to only evaluate the square root if the 169 | # norm of the angle_axis vector is greater than zero. Otherwise 170 | # we get a division by zero. 
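        # The entries below spell out Rodrigues' rotation formula,
        # R = cos(theta) * I + (1 - cos(theta)) * w w^T + sin(theta) * [w]_x,
        # with w = angle_axis / theta (eps added to avoid division by zero).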
171 | k_one = 1.0 172 | theta = torch.sqrt(theta2) 173 | wxyz = angle_axis / (theta + eps) 174 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1) 175 | cos_theta = torch.cos(theta) 176 | sin_theta = torch.sin(theta) 177 | 178 | r00 = cos_theta + wx * wx * (k_one - cos_theta) 179 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) 180 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) 181 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta 182 | r11 = cos_theta + wy * wy * (k_one - cos_theta) 183 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) 184 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) 185 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) 186 | r22 = cos_theta + wz * wz * (k_one - cos_theta) 187 | rotation_matrix = torch.cat( 188 | [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1) 189 | return rotation_matrix.view(-1, 3, 3) 190 | 191 | def _compute_rotation_matrix_taylor(angle_axis): 192 | rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) 193 | k_one = torch.ones_like(rx) 194 | rotation_matrix = torch.cat( 195 | [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1) 196 | return rotation_matrix.view(-1, 3, 3) 197 | 198 | # stolen from ceres/rotation.h 199 | 200 | _angle_axis = torch.unsqueeze(angle_axis, dim=1) 201 | theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) 202 | theta2 = torch.squeeze(theta2, dim=1) 203 | 204 | # compute rotation matrices 205 | rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) 206 | rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) 207 | 208 | # create mask to handle both cases 209 | eps = 1e-6 210 | mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) 211 | mask_pos = (mask).type_as(theta2) 212 | mask_neg = (mask == False).type_as(theta2) # noqa 213 | 214 | # create output pose matrix 215 | batch_size = angle_axis.shape[0] 216 | rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis) 217 | rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1) 218 | # fill output matrix with masked values 219 | rotation_matrix[..., :3, :3] = \ 220 | mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor 221 | return rotation_matrix # Nx4x4 222 | 223 | def rotation_matrix_to_angle_axis(rotation_matrix): 224 | """Convert 3x4 rotation matrix to Rodrigues vector 225 | 226 | Args: 227 | rotation_matrix (Tensor): rotation matrix. 228 | 229 | Returns: 230 | Tensor: Rodrigues vector transformation. 231 | 232 | Shape: 233 | - Input: :math:`(N, 3, 4)` 234 | - Output: :math:`(N, 3)` 235 | 236 | Example: 237 | >>> input = torch.rand(2, 3, 4) # Nx4x4 238 | >>> output = rotation_matrix_to_angle_axis(input) # Nx3 239 | """ 240 | # todo add check that matrix is a valid rotation matrix 241 | quaternion = rotation_matrix_to_quaternion(rotation_matrix) 242 | return quaternion_to_angle_axis(quaternion) 243 | 244 | def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): 245 | """Convert 3x4 rotation matrix to 4d quaternion vector 246 | 247 | This algorithm is based on algorithm described in 248 | https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 249 | 250 | Args: 251 | rotation_matrix (Tensor): the rotation matrix to convert. 
252 | 253 | Return: 254 | Tensor: the rotation in quaternion 255 | 256 | Shape: 257 | - Input: :math:`(N, 3, 4)` 258 | - Output: :math:`(N, 4)` 259 | 260 | Example: 261 | >>> input = torch.rand(4, 3, 4) # Nx3x4 262 | >>> output = rotation_matrix_to_quaternion(input) # Nx4 263 | """ 264 | if not torch.is_tensor(rotation_matrix): 265 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 266 | type(rotation_matrix))) 267 | 268 | if len(rotation_matrix.shape) > 3: 269 | raise ValueError( 270 | "Input size must be a three dimensional tensor. Got {}".format( 271 | rotation_matrix.shape)) 272 | if not rotation_matrix.shape[-2:] == (3, 4): 273 | raise ValueError( 274 | "Input size must be a N x 3 x 4 tensor. Got {}".format( 275 | rotation_matrix.shape)) 276 | 277 | rmat_t = torch.transpose(rotation_matrix, 1, 2) 278 | 279 | mask_d2 = rmat_t[:, 2, 2] < eps 280 | 281 | mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] 282 | mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] 283 | 284 | t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 285 | q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 286 | t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 287 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) 288 | t0_rep = t0.repeat(4, 1).t() 289 | 290 | t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] 291 | q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 292 | rmat_t[:, 0, 1] + rmat_t[:, 1, 0], 293 | t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) 294 | t1_rep = t1.repeat(4, 1).t() 295 | 296 | t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 297 | q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], 298 | rmat_t[:, 2, 0] + rmat_t[:, 0, 2], 299 | rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) 300 | t2_rep = t2.repeat(4, 1).t() 301 | 302 | t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] 303 | q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], 304 | rmat_t[:, 2, 0] - rmat_t[:, 0, 2], 305 | rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) 306 | t3_rep = t3.repeat(4, 1).t() 307 | 308 | mask_c0 = mask_d2 * mask_d0_d1 309 | mask_c1 = mask_d2 * (~mask_d0_d1) 310 | mask_c2 = (~mask_d2) * mask_d0_nd1 311 | mask_c3 = (~mask_d2) * (~mask_d0_nd1) 312 | mask_c0 = mask_c0.view(-1, 1).type_as(q0) 313 | mask_c1 = mask_c1.view(-1, 1).type_as(q1) 314 | mask_c2 = mask_c2.view(-1, 1).type_as(q2) 315 | mask_c3 = mask_c3.view(-1, 1).type_as(q3) 316 | 317 | q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 318 | q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa 319 | t2_rep * mask_c2 + t3_rep * mask_c3) # noqa 320 | q *= 0.5 321 | return q 322 | 323 | def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor: 324 | """Convert quaternion vector to angle axis of rotation. 325 | 326 | Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h 327 | 328 | Args: 329 | quaternion (torch.Tensor): tensor with quaternions. 330 | 331 | Return: 332 | torch.Tensor: tensor with angle axis of rotation. 333 | 334 | Shape: 335 | - Input: :math:`(*, 4)` where `*` means, any number of dimensions 336 | - Output: :math:`(*, 3)` 337 | 338 | Example: 339 | >>> quaternion = torch.rand(2, 4) # Nx4 340 | >>> angle_axis = quaternion_to_angle_axis(quaternion) # Nx3 341 | """ 342 | if not torch.is_tensor(quaternion): 343 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 344 | type(quaternion))) 345 | 346 | if not quaternion.shape[-1] == 4: 347 | raise ValueError("Input must be a tensor of shape Nx4 or 4. 
Got {}" 348 | .format(quaternion.shape)) 349 | # unpack input and compute conversion 350 | q1: torch.Tensor = quaternion[..., 1] 351 | q2: torch.Tensor = quaternion[..., 2] 352 | q3: torch.Tensor = quaternion[..., 3] 353 | sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 354 | 355 | sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) 356 | cos_theta: torch.Tensor = quaternion[..., 0] 357 | two_theta: torch.Tensor = 2.0 * torch.where( 358 | cos_theta < 0.0, 359 | torch.atan2(-sin_theta, -cos_theta), 360 | torch.atan2(sin_theta, cos_theta)) 361 | 362 | k_pos: torch.Tensor = two_theta / sin_theta 363 | k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) 364 | k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) 365 | 366 | angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] 367 | angle_axis[..., 0] += q1 * k 368 | angle_axis[..., 1] += q2 * k 369 | angle_axis[..., 2] += q3 * k 370 | return angle_axis 371 | 372 | 373 | class RotConverter(nn.Module): 374 | ''' 375 | this class is from smplx/vposer 376 | ''' 377 | def __init__(self): 378 | super(RotConverter, self).__init__() 379 | 380 | def forward(self,module_input): 381 | pass 382 | 383 | 384 | @staticmethod 385 | def cont2rotmat(module_input): 386 | reshaped_input = module_input.view(-1, 3, 2) 387 | b1 = F.normalize(reshaped_input[:, :, 0], dim=1) 388 | dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True) 389 | b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1) 390 | b3 = torch.cross(b1, b2, dim=1) 391 | 392 | return torch.stack([b1, b2, b3], dim=-1) 393 | 394 | 395 | @staticmethod 396 | def aa2cont(module_input): 397 | ''' 398 | :param NxTxnum_jointsx3 399 | :return: pose_matrot: NxTxnum_jointsx6 400 | ''' 401 | batch_size = module_input.shape[0] 402 | n_frames = module_input.shape[1] 403 | pose_body_6d = angle_axis_to_rotation_matrix(module_input.reshape(-1, 3))[:, :3, :2].contiguous().view(batch_size, n_frames, -1, 6) 404 | 405 | return pose_body_6d 406 | 407 | 408 | @staticmethod 409 | def rotmat2aa(pose_matrot): 410 | ''' 411 | :param pose_matrot: Nx1xnum_jointsx9 412 | :return: Nx1xnum_jointsx3 413 | ''' 414 | homogen_matrot = F.pad(pose_matrot.view(-1, 3, 3), [0,1]) 415 | pose = rotation_matrix_to_angle_axis(homogen_matrot).view(-1, 3).contiguous() 416 | 417 | return pose 418 | 419 | 420 | @staticmethod 421 | def aa2rotmat(pose): 422 | ''' 423 | :param Nx1xnum_jointsx3 424 | :return: pose_matrot: Nx1xnum_jointsx9 425 | ''' 426 | batch_size = pose.shape[0] 427 | n_frames = pose.shape[1] 428 | pose_body_matrot = angle_axis_to_rotation_matrix(pose.reshape(-1, 3))[:, :3, :3].contiguous().view(batch_size, n_frames, -1, 9) 429 | 430 | return pose_body_matrot 431 | 432 | 433 | -------------------------------------------------------------------------------- /visualization/__pycache__/visualization_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiahaoPlus/SAGA/ecaed33b580afbfa61a29a8b6bd4ad019de91108/visualization/__pycache__/visualization_utils.cpython-38.pyc -------------------------------------------------------------------------------- /visualization/vis_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | import argparse 6 | import pickle 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import open3d as o3d 11 | import torch 12 | import 
torch.nn.functional as F 13 | from tqdm import tqdm 14 | 15 | from visualization_utils import (color_hex2rgb, create_lineset, get_body_mesh, 16 | get_object_mesh, update_cam) 17 | 18 | 19 | def vis_graspmotion_third_view(body_meshes, object_mesh, object_transl, sample_index): 20 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( 21 | size=0.5, origin=[0, 0, 0]) 22 | 23 | vis = o3d.visualization.Visualizer() 24 | vis.create_window() 25 | # vis.add_geometry(mesh_frame) 26 | render_opt=vis.get_render_option() 27 | render_opt.mesh_show_back_face=True 28 | render_opt.line_width=10 29 | render_opt.point_size=5 30 | render_opt.background_color = color_hex2rgb('#1c2434') 31 | 32 | vis.add_geometry(object_mesh) 33 | 34 | x_range = np.arange(-200, 200, 0.75) 35 | y_range = np.arange(-200, 200, 0.75) 36 | z_range = np.arange(0, 1, 1) 37 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range) 38 | gp_lines.paint_uniform_color(color_hex2rgb('#7ea4ab')) 39 | gp_pcd.paint_uniform_color(color_hex2rgb('#7ea4ab')) 40 | vis.add_geometry(gp_lines) 41 | vis.poll_events() 42 | vis.update_renderer() 43 | vis.add_geometry(gp_pcd) 44 | vis.poll_events() 45 | vis.update_renderer() 46 | 47 | for t in range(len(body_meshes)): 48 | vis.add_geometry(body_meshes[t]) 49 | 50 | ### get cam R 51 | ### update render cam 52 | ctr = vis.get_view_control() 53 | cam_param = ctr.convert_to_pinhole_camera_parameters() 54 | trans = np.eye(4) 55 | trans[:3, :3] = np.array([[0, 0, -1], [-1, 0, 0], [0, -1, 0]]) 56 | trans[:3, -1] = np.array([4, 1, 1]) 57 | cam_param = update_cam(cam_param, trans) 58 | ctr.convert_from_pinhole_camera_parameters(cam_param) 59 | vis.poll_events() 60 | vis.update_renderer() 61 | 62 | vis.capture_screen_image( 63 | vis_save_path_third_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True) 64 | vis.remove_geometry(body_meshes[t]) 65 | 66 | 67 | def vis_graspmotion_first_view(body_meshes, object_mesh, object_transl, sample_index): 68 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( 69 | size=0.5, origin=[0, 0, 0]) 70 | 71 | vis = o3d.visualization.Visualizer() 72 | vis.create_window() 73 | render_opt=vis.get_render_option() 74 | render_opt.mesh_show_back_face=True 75 | render_opt.line_width=10 76 | render_opt.point_size=5 77 | 78 | render_opt.background_color = color_hex2rgb('#545454') 79 | 80 | vis.add_geometry(object_mesh) 81 | 82 | for t in range(len(body_meshes)): 83 | vis.add_geometry(body_meshes[t]) 84 | 85 | ### get cam R 86 | cam_o = np.array(body_meshes[t].vertices)[8999] 87 | cam_z = object_transl - cam_o 88 | cam_z = cam_z / np.linalg.norm(cam_z) 89 | cam_x = np.array([cam_z[1], -cam_z[0], 0.0]) 90 | cam_x = cam_x / np.linalg.norm(cam_x) 91 | cam_y = np.array([cam_z[0], cam_z[1], -(cam_z[0]**2 + cam_z[1]**2)/cam_z[2] ]) 92 | cam_y = cam_y / np.linalg.norm(cam_y) 93 | cam_r = np.stack([cam_x, -cam_y, cam_z], axis=1) 94 | ### update render cam 95 | ctr = vis.get_view_control() 96 | cam_param = ctr.convert_to_pinhole_camera_parameters() 97 | transf = np.eye(4) 98 | transf[:3,:3]=cam_r 99 | transf[:3,-1] = cam_o 100 | cam_param = update_cam(cam_param, transf) 101 | ctr.convert_from_pinhole_camera_parameters(cam_param) 102 | vis.poll_events() 103 | vis.update_renderer() 104 | 105 | vis.capture_screen_image( 106 | vis_save_path_first_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True) 107 | vis.remove_geometry(body_meshes[t]) 108 | 109 | 110 | def vis_graspmotion_top_view(body_meshes, object_mesh, object_transl, sample_index): 111 | mesh_frame = 
o3d.geometry.TriangleMesh.create_coordinate_frame(
112 |         size=0.5, origin=[0, 0, 0])
113 | 
114 |     vis = o3d.visualization.Visualizer()
115 |     vis.create_window()
116 |     render_opt=vis.get_render_option()
117 |     render_opt.mesh_show_back_face=True
118 |     render_opt.line_width=10
119 |     render_opt.point_size=5
120 |     render_opt.background_color = color_hex2rgb('#1c2434')
121 | 
122 |     vis.add_geometry(object_mesh)
123 | 
124 |     x_range = np.arange(-200, 200, 0.75)
125 |     y_range = np.arange(-200, 200, 0.75)
126 |     z_range = np.arange(0, 1, 1)
127 |     gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range)
128 |     gp_lines.paint_uniform_color(color_hex2rgb('#7ea4ab'))
129 |     gp_pcd.paint_uniform_color(color_hex2rgb('#7ea4ab'))
130 |     vis.add_geometry(gp_lines)
131 |     vis.poll_events()
132 |     vis.update_renderer()
133 |     vis.add_geometry(gp_pcd)
134 |     vis.poll_events()
135 |     vis.update_renderer()
136 | 
137 |     for t in range(len(body_meshes)):
138 |         vis.add_geometry(body_meshes[t])
139 | 
140 |         ctr = vis.get_view_control()
141 |         cam_param = ctr.convert_to_pinhole_camera_parameters()
142 | 
143 |         cam_o = np.array([0, 0, 4])
144 |         reference_point = np.zeros(3)
145 |         reference_point[:2] = object_transl[:2]/2
146 |         cam_z = reference_point - cam_o
147 |         cam_z = cam_z / np.linalg.norm(cam_z)
148 |         cam_x = np.array([cam_z[1], -cam_z[0], 0.0])
149 |         cam_x = cam_x / np.linalg.norm(cam_x)
150 |         cam_y = np.array([cam_z[0], cam_z[1], -(cam_z[0]**2 + cam_z[1]**2)/cam_z[2] ])
151 |         cam_y = cam_y / np.linalg.norm(cam_y)
152 |         cam_r = np.stack([cam_x, -cam_y, cam_z], axis=1)
153 |         ### update render cam
154 |         ctr = vis.get_view_control()
155 |         cam_param = ctr.convert_to_pinhole_camera_parameters()
156 |         transf = np.eye(4)
157 |         transf[:3,:3]=cam_r
158 |         transf[:3,-1] = cam_o
159 | 
160 | 
161 |         cam_param = update_cam(cam_param, transf)
162 |         ctr.convert_from_pinhole_camera_parameters(cam_param)
163 |         vis.poll_events()
164 |         vis.update_renderer()
165 | 
166 |         vis.capture_screen_image(
167 |             vis_save_path_top_view+"/clip_%04d_%04d.jpg" % (sample_index, t), True)
168 |         vis.remove_geometry(body_meshes[t])
169 | 
170 | if __name__ == '__main__':
171 | 
172 |     parser = argparse.ArgumentParser(description='visualization from saved')
173 | 
174 |     parser.add_argument('--GraspPose_exp_name', type=str, help='exp name')
175 |     parser.add_argument('--dataset', default='GRAB', type=str, help='dataset name')
176 |     parser.add_argument('--object', default='camera', type=str, help='object name')
177 |     parser.add_argument('--gender', type=str, help='gender')
178 | 
179 |     args = parser.parse_args()
180 | 
181 |     result_path = '../results/{}/GraspMotion/{}'.format(args.GraspPose_exp_name, args.object)
182 |     opt_results = np.load('{}/fitting_results.npy'.format(result_path), allow_pickle=True)[()]
183 | 
184 |     print('Saving visualization results to {}'.format(result_path))
185 | 
186 |     vis_save_path_first_view = os.path.join(result_path, 'visualization/first_view')
187 |     vis_save_path_third_view = os.path.join(result_path, 'visualization/third_view')
188 |     vis_save_path_top_view = os.path.join(result_path, 'visualization/top_view')
189 |     if not os.path.exists(vis_save_path_first_view):
190 |         os.makedirs(vis_save_path_first_view)
191 |     if not os.path.exists(vis_save_path_third_view):
192 |         os.makedirs(vis_save_path_third_view)
193 |     if not os.path.exists(vis_save_path_top_view):
194 |         os.makedirs(vis_save_path_top_view)
195 | 
196 |     object_params = opt_results['object'] # Tensor
197 |     object_name = str(opt_results['object_name'])
198 |     body_orig =
opt_results['body_orig'] # Numpy 199 | body_opt = opt_results['body_opt'] 200 | 201 | for key in body_orig: 202 | body_orig[key] = torch.from_numpy(body_orig[key]) 203 | body_opt[key] = torch.from_numpy(body_opt[key]) 204 | 205 | B, T, _ = body_orig['transl'].shape 206 | 207 | object_mesh = get_object_mesh(object_name, args.dataset, object_params['transl'][:B], object_params['global_orient'][:B], B, rotmat=True) 208 | 209 | # frame by frame visualization 210 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.25) 211 | 212 | x_range = np.arange(-200, 200, 0.75) 213 | y_range = np.arange(-200, 200, 0.75) 214 | z_range = np.arange(0, 1, 1) 215 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range) 216 | 217 | collision_eval_list = [] 218 | 219 | camera_point_index = 8999 220 | 221 | pos_list, vel_list, acc_list = [], [], [] 222 | 223 | for i in tqdm(range(0, B)): 224 | collision_eval_T = {} 225 | collision_eval_T['vol'] = [] 226 | collision_eval_T['depth'] = [] 227 | collision_eval_T['contact'] = [] 228 | # get body mesh 229 | orig_smplxparams = {} 230 | opt_smplxparams = {} 231 | for k in body_orig.keys(): 232 | orig_smplxparams[k] = body_orig[k][i] 233 | opt_smplxparams[k] = body_opt[k][i] 234 | body_meshes_orig, _ = get_body_mesh(orig_smplxparams, args.gender, n_samples=T, device='cpu', color='D4BEA3') 235 | body_meshes_opt, _ = get_body_mesh(opt_smplxparams, args.gender, n_samples=T, device='cpu', color='D4BEA3') 236 | 237 | 238 | body_meshes = body_meshes_opt 239 | 240 | vis_graspmotion_first_view(body_meshes, object_mesh[i], object_params['transl'][i], i) 241 | vis_graspmotion_top_view(body_meshes, object_mesh[i], object_params['transl'][i], i) 242 | vis_graspmotion_third_view(body_meshes, object_mesh[i], object_params['transl'][i], i) 243 | -------------------------------------------------------------------------------- /visualization/vis_pose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | 8 | from visualization_utils import * 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser(description='grabpose-Testing') 13 | 14 | parser.add_argument('--exp_name', default = None, type=str, 15 | help='experiment name') 16 | 17 | parser.add_argument('--gender', default = None, type=str, 18 | help='gender') 19 | 20 | parser.add_argument('--object', default = None, type=str, 21 | help='object name') 22 | 23 | parser.add_argument('--object_format', default = 'mesh', type=str, 24 | help='pcd or mesh') 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | cwd = os.getcwd() 30 | 31 | load_path = '../results/{}/GraspPose/{}/fitting_results.npz'.format(args.exp_name, args.object) 32 | 33 | data = np.load(load_path, allow_pickle=True) 34 | gender = args.gender 35 | object_name = args.object 36 | 37 | n_samples = len(data['markers']) 38 | 39 | # Prepare mesh and pcd 40 | object_pcd = get_pcd(data['object'][()]['verts_object'][:n_samples]) 41 | object_mesh = get_object_mesh(object_name, 'GRAB', data['object'][()]['transl'][:n_samples], data['object'][()]['global_orient'][:n_samples], n_samples) 42 | body_mesh, _ = get_body_mesh(data['body'][()], gender, n_samples) 43 | 44 | 45 | # ground 46 | x_range = np.arange(-5, 50, 1) 47 | y_range = np.arange(-5, 50, 1) 48 | z_range = np.arange(0, 1, 1) 49 | gp_lines, gp_pcd = create_lineset(x_range, y_range, z_range) 50 | gp_lines.paint_uniform_color(color_hex2rgb('#bdbfbe')) # grey 
51 | gp_pcd.paint_uniform_color(color_hex2rgb('#bdbfbe')) # grey 52 | coord = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.25) 53 | 54 | for i in range(n_samples): 55 | print(body_mesh[i]) 56 | visualization_list = [body_mesh[i], object_mesh[i], coord, gp_lines, gp_pcd] 57 | o3d.visualization.draw_geometries(visualization_list) 58 | 59 | -------------------------------------------------------------------------------- /visualization/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('..') 4 | import os 5 | 6 | import numpy as np 7 | import open3d as o3d 8 | import smplx 9 | import torch 10 | from WholeGraspPose.models.objectmodel import ObjectModel 11 | 12 | def update_cam(cam_param, trans): 13 | cam_R = np.transpose(trans[:-1, :-1]) 14 | cam_T = -trans[:-1, -1:] 15 | cam_T = np.matmul(cam_R, cam_T) # !!!!!! T is applied in the rotated coord 16 | cam_aux = np.array([[0, 0, 0, 1]]) 17 | mat = np.concatenate([cam_R, cam_T], axis=-1) 18 | mat = np.concatenate([mat, cam_aux], axis=0) 19 | cam_param.extrinsic = mat 20 | return cam_param 21 | 22 | def color_hex2rgb(hex): 23 | h = hex.lstrip('#') 24 | return np.array( tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) )/255 25 | def create_lineset(x_range, y_range, z_range): 26 | gp_lines = o3d.geometry.LineSet() 27 | gp_pcd = o3d.geometry.PointCloud() 28 | points = np.stack(np.meshgrid(x_range, y_range, z_range), axis=-1) 29 | 30 | lines = [] 31 | for ii in range( x_range.shape[0]-1): 32 | for jj in range(y_range.shape[0]-1): 33 | lines.append(np.array([ii*x_range.shape[0]+jj, ii*x_range.shape[0]+jj+1])) 34 | lines.append(np.array([ii*x_range.shape[0]+jj, ii*x_range.shape[0]+jj+y_range.shape[0]])) 35 | 36 | points = np.reshape(points, [-1,3]) 37 | colors = np.random.rand(len(lines), 3)*0.5+0.5 38 | 39 | gp_lines.points = o3d.utility.Vector3dVector(points) 40 | gp_lines.colors = o3d.utility.Vector3dVector(colors) 41 | gp_lines.lines = o3d.utility.Vector2iVector(np.stack(lines,axis=0)) 42 | gp_pcd.points = o3d.utility.Vector3dVector(points) 43 | 44 | return gp_lines, gp_pcd 45 | 46 | def get_body_model(type, gender, batch_size,device='cuda',v_template=None): 47 | ''' 48 | type: smpl, smplx smplh and others. 
Refer to smplx tutorial 49 | gender: male, female, neutral 50 | batch_size: an positive integar 51 | ''' 52 | body_model_path = '../body_utils/body_models' 53 | body_model = smplx.create(body_model_path, model_type=type, 54 | gender=gender, ext='npz', 55 | num_pca_comps=24, 56 | create_global_orient=True, 57 | create_body_pose=True, 58 | create_betas=True, 59 | create_left_hand_pose=True, 60 | create_right_hand_pose=True, 61 | create_expression=True, 62 | create_jaw_pose=True, 63 | create_leye_pose=True, 64 | create_reye_pose=True, 65 | create_transl=True, 66 | batch_size=batch_size, 67 | v_template=v_template 68 | ) 69 | if device == 'cuda': 70 | return body_model.cuda() 71 | else: 72 | return body_model 73 | 74 | def get_body_mesh(smplxparams, gender, n_samples, device='cpu', color=None): 75 | body_mesh_list = [] 76 | 77 | for key in smplxparams.keys(): 78 | # print(key, smplxparams[key].shape) 79 | smplxparams[key] = torch.tensor(smplxparams[key][:n_samples]).to(device) 80 | 81 | 82 | bm = get_body_model('smplx', str(gender), n_samples, device=device) 83 | smplx_results = bm(return_verts=True, **smplxparams) 84 | verts = smplx_results.vertices.detach().cpu().numpy() 85 | face = bm.faces 86 | 87 | for i in range(n_samples): 88 | mesh = o3d.geometry.TriangleMesh() 89 | mesh.vertices = o3d.utility.Vector3dVector(verts[i]) 90 | mesh.triangles = o3d.utility.Vector3iVector(face) 91 | mesh.compute_vertex_normals() 92 | if color is not None: 93 | mesh.paint_uniform_color(color_hex2rgb(color)) # orange 94 | body_mesh_list.append(mesh) 95 | 96 | return body_mesh_list, smplx_results 97 | 98 | def get_object_mesh(obj, dataset, transl, global_orient, n_samples, device='cpu', rotmat=False): 99 | object_mesh_list = [] 100 | global_orient_dim = 9 if rotmat else 3 101 | 102 | global_orient = torch.FloatTensor(global_orient).to(device) 103 | transl = torch.FloatTensor(transl).to(device) 104 | 105 | if dataset == 'GRAB': 106 | mesh_base = '../dataset/contact_meshes' 107 | # mesh_base = '/home/dalco/wuyan/data/GRAB/tools/object_meshes/contact_meshes' 108 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, obj + '.ply')) 109 | elif dataset == 'FHB': 110 | mesh_base = '/home/dalco/wuyan/data/FHB/Object_models' 111 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}_model/{}_model.ply'.format(obj, obj))) 112 | elif dataset == 'HO3D': 113 | mesh_base = '/home/dalco/wuyan/data/HO3D/YCB_Video_Models/models' 114 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}/textured_simple.obj'.format(obj))) 115 | elif dataset == 'ShapeNet': 116 | mesh_base = '/home/dalco/wuyan/data/ShapeNet/ShapeNet_selected' 117 | obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, '{}.obj'.format(obj))) 118 | obj_mesh_base.scale(0.15, center=np.zeros((3, 1))) 119 | else: 120 | raise NotImplementedError 121 | 122 | obj_mesh_base.compute_vertex_normals() 123 | v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(device).view(1, -1, 3).repeat(n_samples, 1, 1) 124 | normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(device).view(1, -1, 3).repeat(n_samples, 1, 1) 125 | obj_model = ObjectModel(v_temp, normal_temp, n_samples) 126 | # global_orient_dim = 9 if rotmat else 3 127 | object_output = obj_model(global_orient.view(n_samples, global_orient_dim), transl.view(n_samples, 3), v_temp.to(device), normal_temp.to(device), rotmat) 128 | 129 | object_verts = object_output[0].detach().squeeze().view(n_samples, -1, 3).cpu().numpy() 130 | # object_vertex_normal = 
object_output[1].detach().squeeze().cpu().numpy() 131 | 132 | for i in range(n_samples): 133 | mesh = o3d.geometry.TriangleMesh() 134 | mesh.vertices = o3d.utility.Vector3dVector(object_verts[i]) 135 | mesh.triangles = obj_mesh_base.triangles 136 | mesh.compute_vertex_normals() 137 | mesh.paint_uniform_color(color_hex2rgb('#f59002')) # orange 138 | object_mesh_list.append(mesh) 139 | 140 | return object_mesh_list 141 | 142 | 143 | # def get_object_mesh(obj, transl, global_orient, n_samples, device='cpu'): 144 | # object_mesh_list = [] 145 | 146 | # global_orient = torch.FloatTensor(global_orient).to(device) 147 | # transl = torch.FloatTensor(transl).to(device) 148 | 149 | # mesh_base = '../dataset/contact_meshes' 150 | # obj_mesh_base = o3d.io.read_triangle_mesh(os.path.join(mesh_base, obj + '.ply')) 151 | # obj_mesh_base.compute_vertex_normals() 152 | # v_temp = torch.FloatTensor(obj_mesh_base.vertices).to(device).view(1, -1, 3).repeat(n_samples, 1, 1) 153 | # normal_temp = torch.FloatTensor(obj_mesh_base.vertex_normals).to(device).view(1, -1, 3).repeat(n_samples, 1, 1) 154 | # obj_model = ObjectModel(v_temp, normal_temp, n_samples) 155 | # object_output = obj_model(global_orient.view(n_samples, 3), transl.view(n_samples, 3), v_temp.to(device), normal_temp.to(device)) 156 | 157 | # object_verts = object_output[0].detach().squeeze().cpu().numpy() 158 | # # object_vertex_normal = object_output[1].detach().squeeze().cpu().numpy() 159 | 160 | # for i in range(n_samples): 161 | # mesh = o3d.geometry.TriangleMesh() 162 | # print('debug:', object_verts[i].shape) 163 | # mesh.vertices = o3d.utility.Vector3dVector(object_verts[i]) 164 | # mesh.triangles = obj_mesh_base.triangles 165 | # mesh.compute_vertex_normals() 166 | # mesh.paint_uniform_color(color_hex2rgb('#f59002')) # orange 167 | # object_mesh_list.append(mesh) 168 | 169 | # return object_mesh_list 170 | 171 | def get_pcd(points, contact=None): 172 | print(points.shape) 173 | pcd_list = [] 174 | n_samples = points.shape[0] 175 | for i in range(n_samples): 176 | pcd = o3d.geometry.PointCloud() 177 | pcd.points = o3d.utility.Vector3dVector(points[i]) 178 | if contact is not None: 179 | colors = np.zeros((points.shape[1], 3)) 180 | colors[:, 0] = contact[i].squeeze() 181 | pcd.colors = o3d.utility.Vector3dVector(colors) 182 | pcd_list.append(pcd) 183 | return pcd_list 184 | 185 | --------------------------------------------------------------------------------