├── README.md
├── refu_tailornet
│   ├── .DS_Store
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── canonical_pose_dataset.py
│   │   └── static_pose_shape_final.py
│   ├── global_var.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── geometric_loss.py
│   │   ├── networks.py
│   │   ├── ops.py
│   │   ├── sdf_collision_response_model.py
│   │   ├── skirt_model.md
│   │   ├── smpl4garment.py
│   │   ├── soft_collision_loss.py
│   │   ├── tailornet_model.py
│   │   └── torch_smpl4garment.py
│   ├── smpl_lib
│   │   ├── __init__.py
│   │   ├── ch.py
│   │   ├── ch_smpl.py
│   │   ├── lbs.py
│   │   ├── posemapper.py
│   │   ├── serialization.py
│   │   └── verts.py
│   ├── tnutils
│   │   ├── __init__.py
│   │   ├── diffusion_smoothing.py
│   │   ├── eval.py
│   │   ├── eval_col_info.py
│   │   ├── geometry.py
│   │   ├── interpenetration.py
│   │   ├── io.py
│   │   ├── logger.py
│   │   ├── renderer.py
│   │   ├── rotation.py
│   │   ├── sio.py
│   │   └── smpl_paths.py
│   └── trainer
│       ├── __init__.py
│       ├── base_trainer.py
│       ├── base_trainer_col_info.py
│       ├── eg_trainer.py
│       ├── hf_trainer.py
│       ├── hf_trainer_col_info.py
│       ├── lf_trainer.py
│       ├── refu_trainer.py
│       └── ss2g_trainer.py
└── sdf
    ├── datasets
    │   ├── __init__.py
    │   └── smpldataset.py
    ├── global_var_fun.py
    ├── igrutils
    │   ├── __init__.py
    │   ├── general.py
    │   └── plots.py
    ├── model
    │   ├── __init__.py
    │   ├── network.py
    │   └── sample.py
    └── shapespace
        ├── __init__.py
        ├── smpl_setup.conf
        └── train_smpl_sdf_value.py

/README.md:
--------------------------------------------------------------------------------
1 | # ReFU
2 | Code for the paper "A Repulsive Force Unit for Garment Collision Handling in Neural Networks"
3 | 
4 | ## Related Work
5 | Based on code from
6 | 
7 | 1. TailorNet: Predicting Clothing in 3D as a Function of Human Pose, Shape and Garment Style [Code](https://github.com/chaitanya100100/TailorNet)
8 | 
9 | 2. IGR: Implicit Geometric Regularization for Learning Shapes [Code](https://github.com/amosgropp/IGR)
10 | 
--------------------------------------------------------------------------------
/refu_tailornet/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/refu_tailornet/.DS_Store
--------------------------------------------------------------------------------
/refu_tailornet/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/refu_tailornet/dataset/__init__.py
--------------------------------------------------------------------------------
/refu_tailornet/dataset/canonical_pose_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | 
6 | import global_var
7 | from tnutils.rotation import get_Apose
8 | 
9 | 
10 | def get_style(style_idx, gender, garment_class):
11 |     gammas = np.load(os.path.join(
12 |         global_var.DATA_DIR,
13 |         '{}_{}/style/gamma_{}.npy'.format(garment_class, gender, style_idx)
14 |     )).astype(np.float32)
15 |     return gammas
16 | 
17 | 
18 | def get_shape(shape_idx, gender, garment_class):
19 |     betas = np.load(os.path.join(
20 |         global_var.DATA_DIR,
21 |         '{}_{}/shape/beta_{}.npy'.format(garment_class, gender, shape_idx)
22 |     )).astype(np.float32)
23 |     return betas[:10]
24 | 
25 | 
26 | class ShapeStyleCanonPose(Dataset):
27 |     """Dataset for garments in canonical pose.
28 | 
29 |     This dataset is used to train the ss2g (shape-style to garment) model, which is used
30 |     in weighting of the pivot high-frequency outputs.
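    The ss2g model maps (beta, gamma) directly to unposed garment vertices in the
    canonical pose; TailorNet's interp4 (models/tailornet_model.py) uses those
    predictions to measure RBF distances from a query shape-style to the pivots.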
31 | """ 32 | def __init__(self, garment_class, gender, shape_style_list_path='avail.txt', split=None): 33 | super(ShapeStyleCanonPose, self).__init__() 34 | self.garment_class = garment_class 35 | self.gender = gender 36 | root_dir = os.path.join(global_var.DATA_DIR, '{}_{}'.format(garment_class, gender)) 37 | 38 | if garment_class == 'old-t-shirt': 39 | betas = np.stack([np.load(os.path.join(root_dir, 'shape/beta_{:03d}.npy'.format(i))) for i in range(9)]).astype(np.float32)[:, :10] 40 | gammas = np.stack([np.load(os.path.join(root_dir, 'style/gamma_{:03d}.npy'.format(i))) for i in range(26)]).astype(np.float32) 41 | else: 42 | betas = np.load(os.path.join(root_dir, 'shape/betas.npy'))[:, :10] 43 | gammas = np.load(os.path.join(root_dir, 'style/gammas.npy')) 44 | 45 | with open(os.path.join(root_dir, shape_style_list_path), "r") as f: 46 | ss_list = [l.strip().split('_') for l in f.readlines()] 47 | 48 | assert(split in [None, 'train', 'test']) 49 | with open(os.path.join(root_dir, "test.txt"), "r") as f: 50 | test_ss = [l.strip().split('_') for l in f.readlines()] 51 | if split == 'train': 52 | ss_list = [ss for ss in ss_list if ss not in test_ss] 53 | elif split == 'test': 54 | ss_list = [ss for ss in ss_list if ss in test_ss] 55 | 56 | unpose_v = [] 57 | for shape_idx, style_idx in ss_list: 58 | fpath = os.path.join( 59 | root_dir, 'style_shape/beta{}_gamma{}.npy'.format(shape_idx, style_idx)) 60 | if not os.path.exists(fpath): 61 | print("shape {} and style {} not available".format(shape_idx, style_idx)) 62 | unpose_v.append(np.load(fpath)) 63 | unpose_v = np.stack(unpose_v) 64 | 65 | self.ss_list = ss_list 66 | # self.betas = torch.nn.parameter.Parameter(torch.from_numpy(betas.astype(np.float32)), requires_grad=False) 67 | # self.gammas = torch.nn.parameter.Parameter(torch.from_numpy(gammas.astype(np.float32)), requires_grad=False) 68 | # self.unpose_v = torch.nn.parameter.Parameter(torch.from_numpy(unpose_v.astype(np.float32)), requires_grad=False) 69 | # self.apose = torch.nn.parameter.Parameter(torch.from_numpy(get_Apose().astype(np.float32)), requires_grad=False) 70 | 71 | self.betas = torch.from_numpy(betas.astype(np.float32)) 72 | self.gammas = torch.from_numpy(gammas.astype(np.float32)) 73 | self.unpose_v = torch.from_numpy(unpose_v.astype(np.float32)) 74 | self.apose = torch.from_numpy(get_Apose().astype(np.float32)) 75 | def __len__(self): 76 | return self.unpose_v.shape[0] 77 | 78 | def __getitem__(self, item): 79 | bi, gi = self.ss_list[item] 80 | bi, gi = int(bi), int(gi) 81 | return self.unpose_v[item], self.apose, self.betas[bi], self.gammas[gi], item 82 | 83 | 84 | if __name__ == '__main__': 85 | gender = 'male' 86 | garment_class = 't-shirt' 87 | ds = ShapeStyleCanonPose(gender=gender, garment_class=garment_class, split='train') 88 | print(len(ds)) 89 | ds = ShapeStyleCanonPose(gender=gender, garment_class=garment_class, split='test') 90 | print(len(ds)) 91 | -------------------------------------------------------------------------------- /refu_tailornet/global_var.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | # Dataset root directory. Change it to point to downloaded data root directory. 
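# Layout expected under DATA_DIR by dataset/canonical_pose_dataset.py (paths as
# used in that file, per garment class and gender):
#   {garment_class}_{gender}/shape/beta_{idx}.npy   (or shape/betas.npy)
#   {garment_class}_{gender}/style/gamma_{idx}.npy  (or style/gammas.npy)
#   {garment_class}_{gender}/style_shape/beta{shape_idx}_gamma{style_idx}.npy
#   {garment_class}_{gender}/avail.txt, pivots.txt, test.txt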
6 | DATA_DIR = '/home/qtan/TailorNet_dataset'
7 | 
8 | # Set the paths to the SMPL models
9 | SMPL_PATH_NEUTRAL = '/home/qtan/TailorNet_dataset/smpl/basicmodel_neutral_lbs_10_207_0_v1.1.0.pkl'
10 | SMPL_PATH_MALE = '/home/qtan/TailorNet_dataset/smpl/basicmodel_m_lbs_10_207_0_v1.1.0.pkl'
11 | SMPL_PATH_FEMALE = '/home/qtan/TailorNet_dataset/smpl/basicmodel_f_lbs_10_207_0_v1.1.0.pkl'
12 | 
13 | # Log directory where training logs, checkpoints and visualizations will be stored
14 | # LOG_DIR = '/mnt/session_space/TailorNet/log'
15 | 
16 | LOG_DIR = '/home/qtan/TailorNet/log'
17 | 
18 | # Path to the downloaded TailorNet trained models
19 | MODEL_WEIGHTS_PATH = "/home/code-base/user_space/TailorNet_Models"
20 | 
21 | # --------------------------------------------------------------------
22 | # Variables below rarely need to change
23 | # --------------------------------------------------------------------
24 | 
25 | # Available genders
26 | GENDERS = ['neutral', 'male', 'female']
27 | 
28 | # This file in DATA_DIR contains pose indices (out of all SMPL poses) of
29 | # train/test splits as a dict {'train': <train_indices>, 'test': <test_indices>}
30 | POSE_SPLIT_FILE = 'split_static_pose_shape.npz'
31 | 
32 | # This file in DATA_DIR contains garment template information in the format
33 | # {<garment_class>: {'vert_indices': <vert_indices>, 'f': <faces>}}
34 | # where <vert_indices> refers to the indices of the high-resolution SMPL
35 | # template which make up the <garment_class> garment
36 | GAR_INFO_FILE = 'garment_class_info.pkl'
37 | 
38 | # Root dir for smooth data. Groundtruth smooth data is stored in the same
39 | # data hierarchy as simulation data under this directory.
40 | SMOOTH_DATA_DIR = DATA_DIR
41 | 
42 | # Indicates whether smooth groundtruth data is available. If False, smoothing
43 | # will be performed during training, which might slow down the training significantly.
44 | SMOOTH_STORED = True 45 | 46 | # Using smoothing in posed space for skirt 47 | POSED_SMOOTH_SKIRT = True 48 | 49 | """ 50 | ## SMPL joint 51 | ID parent name 52 | 0 -1 pelvis 53 | 1 0 L hip 54 | 2 0 R hip 55 | 3 0 stomach 56 | 4 1 L knee 57 | 5 2 R knee 58 | 6 3 Lower chest 59 | 7 4 L ankle 60 | 8 5 R ankle 61 | 9 6 Upper chest 62 | 10 7 L toe 63 | 11 8 R toe 64 | 12 9 throat 65 | 13 9 L Breast 66 | 14 9 R Breast 67 | 15 12 jaw 68 | 16 13 L shoulder 69 | 17 14 R shoulder 70 | 18 16 L elbow 71 | 19 17 R elbow 72 | 20 18 L wrist 73 | 21 19 R wrist 74 | 22 20 L hand 75 | 23 21 R hand 76 | """ 77 | 78 | # Lists the indices of joints which affect the deformations of particular garment 79 | VALID_THETA = { 80 | 't-shirt': [0, 1, 2, 3, 6, 9, 12, 13, 14, 16, 17, 18, 19], 81 | 'old-t-shirt': [0, 1, 2, 3, 6, 9, 12, 13, 14, 16, 17, 18, 19], 82 | 'shirt': [0, 1, 2, 3, 6, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21], 83 | 'pant': [0, 1, 2, 4, 5, 7, 8], 84 | 'short-pant': [0, 1, 2, 4, 5], 85 | 'skirt': [0, 1, 2, 4, 5], 86 | } 87 | -------------------------------------------------------------------------------- /refu_tailornet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/refu_tailornet/models/__init__.py -------------------------------------------------------------------------------- /refu_tailornet/models/geometric_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class Laplace_Mean_Euclidean_Loss(nn.Module): 5 | def __init__(self, neighbour, degrees, max_degree, point_num): 6 | super(Laplace_Mean_Euclidean_Loss, self).__init__() 7 | 8 | if type(neighbour) == torch.nn.parameter.Parameter: 9 | self.neighbour = neighbour 10 | else: 11 | self.neighbour = torch.nn.Parameter(torch.from_numpy(neighbour).long(), requires_grad=False) 12 | if type(degrees) == torch.nn.parameter.Parameter: 13 | self.degrees = degrees 14 | else: 15 | self.degrees = torch.nn.Parameter(torch.from_numpy(degrees).int(), requires_grad=False) 16 | 17 | self.max_degree = max_degree 18 | 19 | self.point_num = point_num 20 | 21 | def forward(self, predict_points, gt_points): 22 | 23 | batch = predict_points.shape[0] 24 | 25 | zeros = torch.zeros(batch, 1, 3).to(predict_points.device) 26 | 27 | padded_predict_points = torch.cat([predict_points, zeros], dim=1) 28 | padded_gt_points = torch.cat([gt_points, zeros], dim=1) 29 | 30 | gt_laplace = gt_points*(self.degrees.view(1, self.point_num, 1).repeat(batch, 1, 3)) - padded_gt_points[:, self.neighbour, :].sum(dim=2) 31 | 32 | predict_laplace = predict_points*(self.degrees.view(1, self.point_num, 1).repeat(batch, 1, 3)) - padded_predict_points[:, self.neighbour, :].sum(dim=2) 33 | 34 | loss = torch.sqrt(torch.pow(predict_laplace - gt_laplace, 2).sum(2)).sum(1).mean() 35 | 36 | return loss 37 | 38 | class Geometric_Mean_Euclidean_Loss(nn.Module): 39 | def __init__(self): 40 | super(Geometric_Mean_Euclidean_Loss, self).__init__() 41 | 42 | def forward(self, predict_points, gt_points): 43 | loss = (predict_points - gt_points).pow(2).sum(2).sqrt().mean() 44 | 45 | return loss 46 | 47 | class Per_Layer_Mean_Euclidean_Loss(nn.Module): 48 | def __init__(self, decoder_layers, vertex_num): 49 | super(Per_Layer_Mean_Euclidean_Loss, self).__init__() 50 | 51 | self.vertex_num = vertex_num 52 | 53 | temp_mapping_list = [torch.nn.Parameter(torch.from_numpy(layer[-1]).long(), 
requires_grad=False) for layer in decoder_layers] 54 | temp_mapping_list.pop(0) 55 | 56 | temp_mapping_list.append(temp_mapping_list[-1]) 57 | 58 | self.center_mapping_list = torch.nn.ParameterList(temp_mapping_list) 59 | 60 | def forward(self, per_layer_predict, gt_points): 61 | 62 | loss = [] 63 | 64 | for i, predict in enumerate(per_layer_predict): 65 | loss.append(torch.sqrt(torch.pow(predict-gt_points[:,self.center_mapping_list[i],:], 2).sum(2)).mean()*self.vertex_num) 66 | 67 | loss = torch.stack(loss).mean() 68 | 69 | return loss 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /refu_tailornet/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class FullyConnected(nn.Module): 6 | def __init__(self, input_size, output_size, hidden_size=1024, num_layers=None): 7 | super(FullyConnected, self).__init__() 8 | net = [ 9 | nn.Linear(input_size, hidden_size), 10 | nn.ReLU(inplace=True), 11 | nn.Dropout(p=0.2), 12 | ] 13 | for i in range(num_layers - 2): 14 | net.extend([ 15 | nn.Linear(hidden_size, hidden_size), 16 | nn.ReLU(inplace=True), 17 | ]) 18 | net.extend([ 19 | nn.Linear(hidden_size, output_size), 20 | ]) 21 | self.net = nn.Sequential(*net) 22 | 23 | def forward(self, x): 24 | return self.net(x) 25 | 26 | 27 | class FullyConnected_Reshape(nn.Module): 28 | def __init__(self, input_size, output_size, output_shape, hidden_size=1024, num_layers=None, drop_prob = 0): 29 | super(FullyConnected_Reshape, self).__init__() 30 | net = [ 31 | nn.Linear(input_size, hidden_size), 32 | nn.ReLU(inplace=True), 33 | nn.Dropout(p=drop_prob), 34 | ] 35 | for i in range(num_layers - 2): 36 | net.extend([ 37 | nn.Linear(hidden_size, hidden_size), 38 | nn.ReLU(inplace=True), 39 | ]) 40 | net.extend([ 41 | nn.Linear(hidden_size, output_size), 42 | ]) 43 | self.net = nn.Sequential(*net) 44 | 45 | self.output_shape = output_shape 46 | 47 | def forward(self, x): 48 | x = self.net(x) 49 | 50 | batch_shape = x.shape[0] 51 | 52 | x = torch.reshape(x, [batch_shape, self.output_shape[0], self.output_shape[1]]) 53 | 54 | return x 55 | 56 | class FullyConnected_SDF_Hybrid_Weight(nn.Module): 57 | def __init__(self, input_size, middle_output_size, middle_output_shape, hidden_size=1024, num_layers=None, drop_prob = 0): 58 | super(FullyConnected_SDF_Hybrid_Weight, self).__init__() 59 | net = [ 60 | nn.Linear(input_size, hidden_size), 61 | nn.ReLU(inplace=True), 62 | nn.Dropout(p=drop_prob), 63 | ] 64 | for i in range(num_layers - 2): 65 | net.extend([ 66 | nn.Linear(hidden_size, hidden_size), 67 | nn.ReLU(inplace=True), 68 | ]) 69 | net.extend([ 70 | nn.Linear(hidden_size, middle_output_size), 71 | ]) 72 | self.net = nn.Sequential(*net) 73 | 74 | self.middle_output_shape = middle_output_shape 75 | 76 | weight_net = [nn.Linear(self.middle_output_shape[1]+1, 10), 77 | nn.ReLU(inplace=True), 78 | nn.Linear(10, 1) 79 | ] 80 | 81 | self.weight_net = nn.Sequential(*weight_net) 82 | 83 | def forward(self, x, sdf_value): 84 | x = self.net(x) 85 | 86 | batch_shape = x.shape[0] 87 | 88 | x = torch.reshape(x, [batch_shape, self.middle_output_shape[0], self.middle_output_shape[1]]) 89 | 90 | x = torch.cat((x, sdf_value), -1) 91 | 92 | x = torch.abs(self.weight_net(x)) 93 | 94 | 95 | return x 96 | -------------------------------------------------------------------------------- /refu_tailornet/models/ops.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | from global_var import VALID_THETA
3 | 
4 | 
5 | def verts_dist(v1, v2, dim=None):
6 |     """
7 |     distance between two point sets
8 |     v1 and v2 shape: NxVx3
9 |     """
10 |     x = torch.pow(v2 - v1, 2)
11 |     x = torch.sum(x, -1)
12 |     x = torch.sqrt(x)
13 |     if dim == -1:
14 |         return x
15 |     elif dim is None:
16 |         return torch.mean(x)
17 |     else:
18 |         return torch.mean(x, dim=dim)
19 | 
20 | 
21 | def mask_thetas(thetas, garment_class):
22 |     """
23 |     thetas: shape [N, 72]
24 |     garment_class: e.g. t-shirt
25 |     """
26 |     valid_theta = VALID_THETA[garment_class]
27 |     mask = torch.zeros_like(thetas).view(-1, 24, 3)
28 |     mask[:, valid_theta, :] = 1.
29 |     mask = mask.view(-1, 72)
30 |     return thetas * mask
31 | 
32 | 
33 | def mask_betas(betas, garment_class):
34 |     """
35 |     betas: shape [N, 10]
36 |     garment_class: e.g. t-shirt
37 |     """
38 |     valid_beta = [0, 1]
39 |     mask = torch.zeros_like(betas)
40 |     mask[:, valid_beta] = 1.
41 |     return betas * mask
42 | 
43 | 
44 | def mask_gammas(gammas, garment_class):
45 |     """
46 |     gammas: shape [N, 4]
47 |     garment_class: e.g. t-shirt
48 |     """
49 |     valid_gamma = [0, 1]
50 |     mask = torch.zeros_like(gammas)
51 |     mask[:, valid_gamma] = 1.
52 |     gammas = gammas * mask
53 |     if garment_class == 'old-t-shirt':
54 |         gammas = gammas + torch.tensor(
55 |             [[0., 0., 1.5, 0.]], dtype=torch.float32, device=gammas.device)
56 |     return gammas
57 | 
58 | 
59 | def mask_inputs(thetas, betas, gammas, garment_class):
60 |     if thetas is not None:
61 |         thetas = mask_thetas(thetas, garment_class)
62 |     if betas is not None:
63 |         betas = mask_betas(betas, garment_class)
64 |     if gammas is not None:
65 |         gammas = mask_gammas(gammas, garment_class)
66 |     return thetas, betas, gammas
67 | 
68 | 
69 | def pairwise_distances(x, y=None):
70 |     """
71 |     Input: x is a Nxd matrix
72 |            y is an optional Mxd matrix
73 |     Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
74 |             if y is not given then use 'y=x'.
75 |             i.e.
dist[i,j] = ||x[i,:]-y[j,:]||^2 76 | """ 77 | x_norm = (x**2).sum(1).view(-1, 1) 78 | if y is not None: 79 | y_norm = (y**2).sum(1).view(1, -1) 80 | else: 81 | y = x 82 | y_norm = x_norm.view(1, -1) 83 | 84 | dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)) 85 | return dist 86 | -------------------------------------------------------------------------------- /refu_tailornet/models/sdf_collision_response_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pickle 8 | from models import ops 9 | 10 | import time 11 | 12 | from torch.autograd import grad 13 | 14 | class SDF_Collsion_Response_Hybrid(nn.Module): 15 | def __init__(self, sdf_network, hybrid_weight_model, garment_class): 16 | super(SDF_Collsion_Response_Hybrid, self).__init__() 17 | 18 | self.sdf_network = sdf_network 19 | 20 | self.hybrid_weight_model = hybrid_weight_model 21 | 22 | self.garment_class = garment_class 23 | 24 | def add_latent(self, thetas, betas, points): 25 | batch_size, num_of_points, dim = points.shape 26 | points = points.reshape(batch_size * num_of_points, dim) 27 | latent_inputs = torch.zeros(0).cuda() 28 | 29 | for ind in range(0, batch_size): 30 | latent_ind = torch.cat([betas[ind], thetas[ind]], 0) 31 | latent_repeat = latent_ind.expand(num_of_points, -1) 32 | latent_inputs = torch.cat([latent_inputs, latent_repeat], 0) 33 | points = torch.cat([latent_inputs, points], 1) 34 | return points 35 | 36 | def forward(self, thetas, betas, gammas, verts, eval = False): 37 | 38 | mlp_thetas, mlp_betas, mlp_gammas = ops.mask_inputs(thetas, betas, gammas, self.garment_class) 39 | 40 | batch_size, num_of_points, _ = verts.shape 41 | 42 | verts.requires_grad_() 43 | 44 | 45 | latent_verts = self.add_latent(thetas, betas, verts) 46 | 47 | 48 | sdf_value = self.sdf_network(latent_verts).reshape(batch_size, num_of_points, 1) 49 | 50 | 51 | d_points = torch.ones_like(sdf_value, requires_grad=False, device=sdf_value.device) 52 | 53 | sdf_gradient = grad(outputs = sdf_value, inputs = verts, grad_outputs=d_points, create_graph=False, retain_graph=True, only_inputs=True)[0] 54 | 55 | normalized_sdf_gradient = F.normalize(sdf_gradient, dim=2) 56 | 57 | 58 | hybrid_weight = self.hybrid_weight_model(torch.cat((mlp_thetas, mlp_betas, mlp_gammas), dim=1), sdf_value) 59 | 60 | hybrid_verts = torch.where(sdf_value<0, verts-sdf_value * hybrid_weight * normalized_sdf_gradient, verts) 61 | 62 | 63 | 64 | hybrid_latent_verts = self.add_latent(thetas, betas, hybrid_verts) 65 | 66 | hybrid_sdf_value = self.sdf_network(hybrid_latent_verts).reshape(batch_size, num_of_points, 1) 67 | 68 | return hybrid_verts, hybrid_sdf_value, sdf_value, hybrid_weight 69 | 70 | 71 | -------------------------------------------------------------------------------- /refu_tailornet/models/skirt_model.md: -------------------------------------------------------------------------------- 1 | TailorNet garment model assumes that the garment template is a submesh of SMPL 2 | body template, thereby it can take skinning weights from SMPL model. 3 | But that doesn't hold up for garments like skirt. 4 | First, we tried attaching a skirt template to the root joint of SMPL where the skirt 5 | is only subjected to the global rotation of SMPL pose. 6 | All other deformations are learned as displacements in a reasonable manner as shown 7 | in our paper. 
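A minimal sketch of the skinning-weight construction described below (illustrative only, not code from this repo; `scipy.spatial.cKDTree`, the epsilon, and all names here are assumptions):

```python
import numpy as np
from scipy.spatial import cKDTree

def skirt_skinning_weights(skirt_verts, body_verts, smpl_weights, K=100):
    """Skinning weights of each skirt vertex as an inverse-distance-weighted
    sum of the SMPL skinning weights of its K nearest body vertices."""
    dists, idx = cKDTree(body_verts).query(skirt_verts, k=K)  # each (S, K)
    w = 1.0 / (dists + 1e-8)              # weights inversely proportional to distance
    w = w / w.sum(axis=1, keepdims=True)  # normalize per skirt vertex
    return np.einsum('sk,skj->sj', w, smpl_weights[idx])  # (S, 24), rows sum to 1
```

Cf. the precomputed `skirt_weight.npz` loaded in `models/smpl4garment.py`, whose per-vertex weight matrix is turned into 24-bone skirt skinning via `skirt_weight.dot(self.smpl_base.weights)`.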
8 | However, the root-joint design above doesn't use the articulation of underlying body parts like the legs,
9 | and hence it is limiting.
10 | 
11 | Instead, we propose a simple modification to our garment model which works well for the skirt.
12 | We drape a simple skirt template on the canonical body and calculate the skinning weights of
13 | each skirt vertex as a weighted sum of the SMPL skinning weights of its K=100 nearest body vertices,
14 | where the weights are inversely proportional to the distances.
15 | This creates a primitive smooth skirt base which is articulated by SMPL pose and shape, and
16 | over which displacements can be added to predict a realistic skirt.
17 | Associating each skirt vertex with the K=100 nearest body vertices, instead of just one, reduces
18 | the sudden discontinuity of association between the two legs.
--------------------------------------------------------------------------------
/refu_tailornet/models/smpl4garment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import chumpy as ch
4 | import numpy as np
5 | import cv2
6 | from psbody.mesh import Mesh
7 | from smpl_lib.ch_smpl import Smpl
8 | from tnutils.smpl_paths import SmplPaths
9 | 
10 | import global_var
11 | 
12 | 
13 | class SMPL4Garment(object):
14 |     """SMPL class for garments."""
15 |     def __init__(self, gender):
16 |         self.gender = gender
17 |         smpl_model = SmplPaths(gender=gender).get_hres_smpl_model_data()
18 |         self.smpl_base = Smpl(smpl_model)
19 |         with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f:
20 |             self.class_info = pickle.load(f)
21 | 
22 |         # skirt_weight: n_skirt x n_body
23 |         # skirt_skinning: n_skirt x 24
24 |         self.skirt_weight = ch.array(np.load(os.path.join(
25 |             global_var.DATA_DIR, 'skirt_weight.npz'))['w'])
26 |         self.skirt_skinning = self.skirt_weight.dot(self.smpl_base.weights)
27 | 
28 |     def run(self, beta=None, theta=None, garment_d=None, garment_class=None):
29 |         """Outputs body and garment of specified garment class given theta, beta and displacements."""
30 |         if beta is not None:
31 |             self.smpl_base.betas[:beta.shape[0]] = beta
32 |         else:
33 |             self.smpl_base.betas[:] = 0
34 |         if theta is not None:
35 |             self.smpl_base.pose[:] = theta
36 |         else:
37 |             self.smpl_base.pose[:] = 0
38 |         self.smpl_base.v_personal[:] = 0
39 |         if garment_d is not None and garment_class is not None:
40 |             if 'skirt' not in garment_class:
41 |                 vert_indices = self.class_info[garment_class]['vert_indices']
42 |                 f = self.class_info[garment_class]['f']
43 |                 self.smpl_base.v_personal[vert_indices] = garment_d
44 |                 garment_m = Mesh(v=self.smpl_base.r[vert_indices], f=f)
45 |             else:
46 |                 # vert_indices = self.class_info[garment_class]['vert_indices']
47 |                 f = self.class_info[garment_class]['f']
48 | 
49 |                 A = self.smpl_base.A.reshape((16, 24)).T
50 |                 skirt_V = self.skirt_skinning.dot(A).reshape((-1, 4, 4))
51 | 
52 |                 verts = self.skirt_weight.dot(self.smpl_base.v_poseshaped)
53 |                 verts = verts + garment_d
54 |                 verts_h = ch.hstack((verts, ch.ones((verts.shape[0], 1))))
55 |                 verts = ch.sum(skirt_V * verts_h.reshape(-1, 1, 4), axis=-1)[:, :3]
56 |                 garment_m = Mesh(v=verts, f=f)
57 |         else:
58 |             garment_m = None
59 |         self.smpl_base.v_personal[:] = 0
60 |         body_m = Mesh(v=self.smpl_base.r, f=self.smpl_base.f)
61 |         return body_m, garment_m
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     gender = 'female'
66 |     garment_class = 'skirt'
67 |     shape_idx = '005'
68 |     style_idx = '020'
69 |     split = 'train'
70 |     smpl = SMPL4Garment(gender)
71 | 
72 |     from
dataset.static_pose_shape_final import OneStyleShape 73 | ds = OneStyleShape(garment_class, shape_idx, style_idx, split) 74 | K = 87 75 | verts_d, theta, beta, gamma, idx = ds[K] 76 | body_m, gar_m = smpl.run(theta=theta, beta=beta, garment_class=garment_class, garment_d=verts_d) 77 | gar_m.write_ply('/BS/cpatel/work/gar.ply') 78 | body_m.write_ply('/BS/cpatel/work/body.ply') 79 | -------------------------------------------------------------------------------- /refu_tailornet/models/soft_collision_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pickle 8 | from models import ops 9 | 10 | import time 11 | 12 | from torch.autograd import grad 13 | 14 | from psbody.mesh import Mesh 15 | from psbody.mesh.geometry.vert_normals import VertNormals 16 | from psbody.mesh.geometry.tri_normals import TriNormals 17 | from psbody.mesh.search import AabbTree 18 | 19 | def get_nearest_points_and_normals(vert, base_verts, base_faces): 20 | 21 | fn = TriNormals(v=base_verts, f=base_faces).reshape((-1, 3)) 22 | vn = VertNormals(v=base_verts, f=base_faces).reshape((-1, 3)) 23 | 24 | tree = AabbTree(Mesh(v=base_verts, f=base_faces)) 25 | nearest_tri, nearest_part, nearest_point = tree.nearest(vert, nearest_part=True) 26 | nearest_tri = nearest_tri.ravel().astype(np.long) 27 | nearest_part = nearest_part.ravel().astype(np.long) 28 | 29 | nearest_normals = np.zeros_like(vert) 30 | 31 | #nearest_part tells you whether the closest point in triangle abc is in the interior (0), on an edge (ab:1,bc:2,ca:3), or a vertex (a:4,b:5,c:6) 32 | cl_tri_idxs = np.nonzero(nearest_part == 0)[0].astype(np.int) 33 | cl_vrt_idxs = np.nonzero(nearest_part > 3)[0].astype(np.int) 34 | cl_edg_idxs = np.nonzero((nearest_part <= 3) & (nearest_part > 0))[0].astype(np.int) 35 | 36 | nt = nearest_tri[cl_tri_idxs] 37 | nearest_normals[cl_tri_idxs] = fn[nt] 38 | 39 | nt = nearest_tri[cl_vrt_idxs] 40 | npp = nearest_part[cl_vrt_idxs] - 4 41 | nearest_normals[cl_vrt_idxs] = vn[base_faces[nt, npp]] 42 | 43 | nt = nearest_tri[cl_edg_idxs] 44 | npp = nearest_part[cl_edg_idxs] - 1 45 | nearest_normals[cl_edg_idxs] += vn[base_faces[nt, npp]] 46 | npp = np.mod(nearest_part[cl_edg_idxs], 3) 47 | nearest_normals[cl_edg_idxs] += vn[base_faces[nt, npp]] 48 | 49 | nearest_normals = nearest_normals / (np.linalg.norm(nearest_normals, axis=-1, keepdims=True) + 1.e-10) 50 | 51 | return nearest_point, nearest_normals 52 | 53 | def get_collision_loss(garment_verts, body_verts, body_faces, output_collision_percentage=False): 54 | 55 | device = garment_verts.get_device() 56 | 57 | nearest_points, nearest_normals = get_nearest_points_and_normals(garment_verts.cpu().detach().numpy(), body_verts.cpu().detach().numpy(), body_faces) 58 | 59 | torch_nearest_points = torch.from_numpy(nearest_points).to(device) 60 | torch_nearest_normals = torch.from_numpy(nearest_normals).to(device) 61 | 62 | distance = torch.nn.functional.relu(torch.sum(- (garment_verts - torch_nearest_points) * torch_nearest_normals, axis=1)) 63 | 64 | if output_collision_percentage == False: 65 | return distance.sum() 66 | else: 67 | collision_vertices = distance>0 68 | return distance.sum(), torch.Tensor.float(collision_vertices).mean() 69 | 70 | def get_SDF(garment_verts, body_verts, body_faces): 71 | 72 | device = garment_verts.get_device() 73 | 74 | nearest_points, nearest_normals = 
get_nearest_points_and_normals(garment_verts.cpu().detach().numpy(), body_verts.cpu().detach().numpy(), body_faces) 75 | 76 | torch_nearest_points = torch.from_numpy(nearest_points).to(device) 77 | torch_nearest_normals = torch.from_numpy(nearest_normals).to(device) 78 | 79 | distance = torch.sum( (garment_verts - torch_nearest_points) * torch_nearest_normals, axis=1) 80 | 81 | return distance 82 | 83 | 84 | class Accurate_SDF(nn.Module): 85 | def __init__(self, body_faces): 86 | super(Accurate_SDF, self).__init__() 87 | 88 | self.body_faces = body_faces 89 | 90 | def forward(self, batch_garment_verts, batch_body_verts): 91 | batch_size = batch_garment_verts.shape[0] 92 | 93 | per_model_sdf = [] 94 | 95 | for i in range(0, batch_size): 96 | g_verts = batch_garment_verts[i,:,:] 97 | b_verts = batch_body_verts[i,:,:] 98 | 99 | per_model_sdf.append(get_SDF(g_verts, b_verts, self.body_faces)) 100 | 101 | return torch.stack(per_model_sdf) 102 | 103 | 104 | 105 | class Soft_Collision_Loss(nn.Module): 106 | def __init__(self, body_faces): 107 | super(Soft_Collision_Loss, self).__init__() 108 | 109 | self.body_faces = body_faces 110 | 111 | 112 | def forward(self, batch_garment_verts, batch_body_verts, output_collision_percentage=False): 113 | batch_size = batch_garment_verts.shape[0] 114 | 115 | per_model_collision_loss = [] 116 | 117 | if output_collision_percentage == True: 118 | per_model_collision_percentage = [] 119 | 120 | for i in range(0, batch_size): 121 | g_verts = batch_garment_verts[i,:,:] 122 | b_verts = batch_body_verts[i,:,:] 123 | 124 | if output_collision_percentage == False: 125 | per_model_collision_loss.append(get_collision_loss(g_verts, b_verts, self.body_faces)) 126 | else: 127 | collision_loss, collision_percentage = get_collision_loss(g_verts, b_verts, self.body_faces, output_collision_percentage=True) 128 | per_model_collision_loss.append(collision_loss) 129 | per_model_collision_percentage.append(collision_percentage) 130 | 131 | if output_collision_percentage == False: 132 | return torch.stack(per_model_collision_loss) 133 | else: 134 | return torch.stack(per_model_collision_loss), torch.stack(per_model_collision_percentage) 135 | 136 | class Estimated_Soft_Collision_Loss(nn.Module): 137 | def __init__(self, sdf_network): 138 | super(Estimated_Soft_Collision_Loss, self).__init__() 139 | 140 | self.sdf_network = sdf_network 141 | 142 | def add_latent(self, thetas, betas, points): 143 | batch_size, num_of_points, dim = points.shape 144 | points = points.reshape(batch_size * num_of_points, dim) 145 | latent_inputs = torch.zeros(0).cuda() 146 | 147 | for ind in range(0, batch_size): 148 | latent_ind = torch.cat([betas[ind], thetas[ind]], 0) 149 | latent_repeat = latent_ind.expand(num_of_points, -1) 150 | latent_inputs = torch.cat([latent_inputs, latent_repeat], 0) 151 | points = torch.cat([latent_inputs, points], 1) 152 | return points 153 | 154 | def forward(self, thetas, betas, verts): 155 | batch_size, num_of_points, _ = verts.shape 156 | 157 | latent_verts = self.add_latent(thetas, betas, verts) 158 | 159 | sdf_value = self.sdf_network(latent_verts).reshape(batch_size, num_of_points) 160 | 161 | collision_loss = torch.nn.functional.relu(-sdf_value).sum(dim=1) 162 | 163 | return collision_loss -------------------------------------------------------------------------------- /refu_tailornet/models/tailornet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | import 
global_var 6 | from trainer.lf_trainer import get_best_runner as lf_runner 7 | from trainer.hf_trainer import get_best_runner as hf_runner 8 | from trainer.ss2g_trainer import get_best_runner as ss2g_runner 9 | from dataset.canonical_pose_dataset import ShapeStyleCanonPose 10 | 11 | from trainer.lf_trainer import get_best_model as lf_model 12 | from trainer.hf_trainer import get_best_model as hf_model 13 | from trainer.ss2g_trainer import get_best_model as ss2g_model 14 | 15 | 16 | class TailorNetTestModel(object): 17 | """Main TailorNet model class. 18 | This class provides API for TailorNet prediction given trained models of low 19 | frequency predictor, pivot high frequency predictors and ss2g predictor. 20 | """ 21 | def __init__(self, lf_logdir, hf_logdir, ss2g_logdir, garment_class, gender): 22 | self.gender = gender 23 | self.garment_class = garment_class 24 | self.lf_logdir = lf_logdir 25 | self.hf_logdir = hf_logdir 26 | self.ss2g_logdir = ss2g_logdir 27 | print("USING LF LOG DIR: ", lf_logdir) 28 | print("USING HF LOG DIR: ", hf_logdir) 29 | print("USING SS2G LOG DIR: ", ss2g_logdir) 30 | 31 | pivots_ds = ShapeStyleCanonPose(garment_class=garment_class, gender=gender, 32 | shape_style_list_path='pivots.txt') 33 | 34 | self.train_betas = pivots_ds.betas.cuda() 35 | self.train_gammas = pivots_ds.gammas.cuda() 36 | self.basis = pivots_ds.unpose_v.cuda() 37 | self.train_pivots = pivots_ds.ss_list 38 | 39 | self.hf_runners = [ 40 | hf_runner("{}/{}_{}".format(hf_logdir, shape_idx, style_idx)) 41 | for shape_idx, style_idx in self.train_pivots 42 | ] 43 | self.lf_runner = lf_runner(lf_logdir) 44 | self.ss2g_runner = ss2g_runner(ss2g_logdir) 45 | 46 | def forward(self, thetas, betas, gammas, ret_separate=False): 47 | inp_type = type(thetas) 48 | inp_device = None if inp_type == np.ndarray else thetas.device 49 | bs = thetas.shape[0] 50 | 51 | if isinstance(thetas, np.ndarray): 52 | thetas = torch.from_numpy(thetas.astype(np.float32)) 53 | betas = torch.from_numpy(betas.astype(np.float32)) 54 | gammas = torch.from_numpy(gammas.astype(np.float32)) 55 | 56 | with torch.no_grad(): 57 | pred_disp_hf_pivot = torch.stack([ 58 | rr.forward(thetas.cuda(), betas.cuda(), 59 | gammas.cuda()).view(bs, -1, 3) 60 | for rr in self.hf_runners 61 | ]).transpose(0, 1) 62 | 63 | pred_disp_hf = self.interp4(thetas, betas, gammas, pred_disp_hf_pivot, sigma=0.01) 64 | pred_disp_lf = self.lf_runner.forward(thetas, betas, gammas).view(bs, -1, 3) 65 | 66 | if inp_type == np.ndarray: 67 | pred_disp_hf = pred_disp_hf.cpu().numpy() 68 | pred_disp_lf = pred_disp_lf.cpu().numpy() 69 | else: 70 | pred_disp_hf = pred_disp_hf.to(inp_device) 71 | pred_disp_lf = pred_disp_lf.to(inp_device) 72 | if ret_separate: 73 | return pred_disp_lf, pred_disp_hf 74 | else: 75 | return pred_disp_lf + pred_disp_hf 76 | 77 | def interp4(self, thetas, betas, gammas, pred_disp_pivot, sigma=0.5): 78 | """RBF interpolation with distance by SS2G.""" 79 | # disp for given shape-style in canon pose 80 | bs = pred_disp_pivot.shape[0] 81 | rest_verts = self.ss2g_runner.forward(betas=betas, gammas=gammas).view(bs, -1, 3) 82 | # distance of given shape-style from pivots in terms of displacement 83 | # difference in canon pose 84 | dist = rest_verts.unsqueeze(1) - self.basis.unsqueeze(0) 85 | dist = (dist ** 2).sum(-1).mean(-1) * 1000. 
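        # The normalized exponentials below form a Gaussian RBF over the pivots:
        # weight_i ∝ exp(-dist_i / sigma), where dist_i is the (scaled) mean squared
        # vertex distance computed above, so pivots whose ss2g canonical-pose garment
        # lies closest to the query shape-style dominate the interpolation.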
86 | 87 | # compute normalized RBF distance 88 | weight = torch.exp(-dist/sigma) 89 | weight = weight / weight.sum(1, keepdim=True) 90 | 91 | # interpolate using weights 92 | pred_disp = (pred_disp_pivot * weight.unsqueeze(-1).unsqueeze(-1)).sum(1) 93 | 94 | return pred_disp 95 | 96 | 97 | class TailorNetModel(torch.nn.Module): 98 | def __init__(self, lf_logdir, hf_logdir, ss2g_logdir, garment_class, gender): 99 | super(TailorNetModel, self).__init__() 100 | self.gender = gender 101 | self.garment_class = garment_class 102 | 103 | lf_logdir = os.path.join(lf_logdir, "{}_{}".format(garment_class, gender)) 104 | hf_logdir = os.path.join(hf_logdir, "{}_{}".format(garment_class, gender)) 105 | ss2g_logdir = os.path.join(ss2g_logdir, "{}_{}".format(garment_class, gender)) 106 | 107 | self.lf_logdir = lf_logdir 108 | self.hf_logdir = hf_logdir 109 | self.ss2g_logdir = ss2g_logdir 110 | print("USING LF LOG DIR: ", lf_logdir) 111 | print("USING HF LOG DIR: ", hf_logdir) 112 | print("USING SS2G LOG DIR: ", ss2g_logdir) 113 | 114 | pivots_ds = ShapeStyleCanonPose(garment_class=garment_class, gender=gender, 115 | shape_style_list_path='pivots.txt') 116 | 117 | self.train_betas = pivots_ds.betas.cuda() 118 | self.train_gammas = pivots_ds.gammas.cuda() 119 | self.basis = pivots_ds.unpose_v.cuda() 120 | self.train_pivots = pivots_ds.ss_list 121 | 122 | self.hf_models = torch.nn.ModuleList([ 123 | hf_model("{}/{}_{}".format(hf_logdir, shape_idx, style_idx)) 124 | for shape_idx, style_idx in self.train_pivots 125 | ]) 126 | self.lf_model = lf_model(lf_logdir) 127 | self.ss2g_model = ss2g_model(ss2g_logdir) 128 | 129 | def forward(self, thetas, betas, gammas): 130 | # inp_type = type(thetas) 131 | # inp_device = None if inp_type == np.ndarray else thetas.device 132 | bs = thetas.shape[0] 133 | 134 | # if isinstance(thetas, np.ndarray): 135 | # thetas = torch.from_numpy(thetas.astype(np.float32)) 136 | # betas = torch.from_numpy(betas.astype(np.float32)) 137 | # gammas = torch.from_numpy(gammas.astype(np.float32)) 138 | 139 | # with torch.no_grad(): 140 | pred_disp_hf_pivot = torch.stack([ 141 | model(thetas.cuda(), betas.cuda(), 142 | gammas.cuda()).view(bs, -1, 3) 143 | for model in self.hf_models 144 | ]).transpose(0, 1) 145 | 146 | # def interp4(thetas, betas, gammas, pred_disp_pivot, sigma=0.5): 147 | # """RBF interpolation with distance by SS2G.""" 148 | # # disp for given shape-style in canon pose 149 | # bs = pred_disp_pivot.shape[0] 150 | # rest_verts = self.ss2g_model(betas=betas, gammas=gammas).view(bs, -1, 3) 151 | # # distance of given shape-style from pivots in terms of displacement 152 | # # difference in canon pose 153 | 154 | 155 | 156 | # dist = rest_verts.unsqueeze(1) - self.basis.unsqueeze(0) 157 | # dist = (dist ** 2).sum(-1).mean(-1) * 1000. 
158 | 159 | # # compute normalized RBF distance 160 | # weight = torch.exp(-dist/sigma) 161 | # weight = weight / weight.sum(1, keepdim=True) 162 | 163 | # # interpolate using weights 164 | # pred_disp = (pred_disp_pivot * weight.unsqueeze(-1).unsqueeze(-1)).sum(1) 165 | 166 | # return pred_disp 167 | 168 | # pred_disp_hf = interp4(thetas, betas, gammas, pred_disp_hf_pivot, sigma=0.01) 169 | pred_disp_hf = self.interp4(thetas, betas, gammas, pred_disp_hf_pivot, sigma=0.01) 170 | pred_disp_lf = self.lf_model.forward(thetas, betas, gammas).view(bs, -1, 3) 171 | 172 | return pred_disp_lf + pred_disp_hf 173 | 174 | def interp4(self, thetas, betas, gammas, pred_disp_pivot, sigma=0.5): 175 | """RBF interpolation with distance by SS2G.""" 176 | # disp for given shape-style in canon pose 177 | bs = pred_disp_pivot.shape[0] 178 | rest_verts = self.ss2g_model(betas=betas, gammas=gammas).view(bs, -1, 3) 179 | # distance of given shape-style from pivots in terms of displacement 180 | # difference in canon pose 181 | dist = rest_verts.unsqueeze(1).to(self.basis.device) - self.basis.unsqueeze(0) 182 | dist = (dist ** 2).sum(-1).mean(-1) * 1000. 183 | 184 | dist = dist.to(pred_disp_pivot.device) 185 | 186 | # compute normalized RBF distance 187 | weight = torch.exp(-dist/sigma) 188 | weight = weight / weight.sum(1, keepdim=True) 189 | 190 | # interpolate using weights 191 | pred_disp = (pred_disp_pivot * weight.unsqueeze(-1).unsqueeze(-1)).sum(1) 192 | 193 | return pred_disp 194 | 195 | 196 | def get_best_runner(garment_class='t-shirt', gender='female', lf_logdir=None, hf_logdir=None, ss2g_logdir=None): 197 | """Helper function to get TailorNet runner.""" 198 | if lf_logdir is None: 199 | lf_logdir = os.path.join(global_var.MODEL_WEIGHTS_PATH, "{}_{}_weights/tn_orig_lf".format(garment_class, gender)) 200 | if hf_logdir is None: 201 | hf_logdir = os.path.join(global_var.MODEL_WEIGHTS_PATH, "{}_{}_weights/tn_orig_hf".format(garment_class, gender)) 202 | if ss2g_logdir is None: 203 | ss2g_logdir = os.path.join(global_var.MODEL_WEIGHTS_PATH, "{}_{}_weights/tn_orig_ss2g".format(garment_class, gender)) 204 | 205 | lf_logdir = os.path.join(lf_logdir, "{}_{}".format(garment_class, gender)) 206 | hf_logdir = os.path.join(hf_logdir, "{}_{}".format(garment_class, gender)) 207 | ss2g_logdir = os.path.join(ss2g_logdir, "{}_{}".format(garment_class, gender)) 208 | runner = TailorNetTestModel(lf_logdir, hf_logdir, ss2g_logdir, garment_class, gender) 209 | return runner 210 | 211 | 212 | if __name__ == '__main__': 213 | # gender = 'male' 214 | # garment_class = 't-shirt' 215 | # runner = get_best_runner(garment_class, gender) 216 | pass -------------------------------------------------------------------------------- /refu_tailornet/models/torch_smpl4garment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pickle 8 | 9 | import global_var 10 | from tnutils.smpl_paths import SmplPaths 11 | 12 | 13 | class TorchSMPL4Garment(nn.Module): 14 | """Pytorch version of models.smpl4garment.SMPL4Garment class.""" 15 | def __init__(self, gender): 16 | super(TorchSMPL4Garment, self).__init__() 17 | 18 | # with open(model_path, 'rb') as reader: 19 | # model = pickle.load(reader, encoding='iso-8859-1') 20 | model = SmplPaths(gender=gender).get_hres_smpl_model_data() 21 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 22 | 
class_info = pickle.load(f) 23 | for k in class_info.keys(): 24 | if isinstance(class_info[k]['vert_indices'], np.ndarray): 25 | class_info[k]['vert_indices'] = torch.tensor( 26 | class_info[k]['vert_indices'].astype(np.int64)) 27 | if isinstance(class_info[k]['f'], np.ndarray): 28 | class_info[k]['f'] = torch.tensor(class_info[k]['f'].astype(np.int64)) 29 | 30 | self.class_info = class_info 31 | self.gender = gender 32 | 33 | self.faces = model['f'] 34 | 35 | np_v_template = np.array(model['v_template'], dtype=np.float) 36 | 37 | self.register_buffer('v_template', torch.from_numpy(np_v_template).float()) 38 | self.size = [np_v_template.shape[0], 3] 39 | 40 | np_shapedirs = np.array(model['shapedirs'], dtype=np.float)[:, :, :10] 41 | self.num_betas = np_shapedirs.shape[-1] 42 | np_shapedirs = np.reshape(np_shapedirs, [-1, self.num_betas]).T 43 | self.register_buffer('shapedirs', torch.from_numpy(np_shapedirs).float()) 44 | 45 | np_J_regressor = np.array(model['J_regressor'].todense(), dtype=np.float).T 46 | self.register_buffer('J_regressor', torch.from_numpy(np_J_regressor).float()) 47 | 48 | np_posedirs = np.array(model['posedirs'], dtype=np.float) 49 | num_pose_basis = np_posedirs.shape[-1] 50 | np_posedirs = np.reshape(np_posedirs, [-1, num_pose_basis]).T 51 | self.register_buffer('posedirs', torch.from_numpy(np_posedirs).float()) 52 | 53 | self.parents = np.array(model['kintree_table'])[0].astype(np.int32) 54 | 55 | np_joint_regressor = np.array(model['J_regressor'].todense(), dtype=np.float) 56 | self.register_buffer('joint_regressor', torch.from_numpy(np_joint_regressor).float()) 57 | 58 | np_weights = np.array(model['weights'], dtype=np.float) 59 | 60 | vertex_count = np_weights.shape[0] 61 | vertex_component = np_weights.shape[1] 62 | 63 | self.register_buffer( 64 | 'weight', 65 | torch.from_numpy(np_weights).float().reshape(1, vertex_count, vertex_component)) 66 | 67 | self.register_buffer('e3', torch.eye(3).float()) 68 | self.cur_device = None 69 | self.num_verts = 27554 70 | 71 | skirt_weight = np.load(os.path.join(global_var.DATA_DIR, 'skirt_weight.npz'))['w'] 72 | self.register_buffer('skirt_weight', torch.from_numpy(skirt_weight).float()) 73 | skirt_skinning = skirt_weight.dot(np_weights) 74 | self.register_buffer('skirt_skinning', torch.from_numpy(skirt_skinning).float()) 75 | 76 | def save_obj(self, verts, obj_mesh_name): 77 | if self.faces is None: 78 | msg = 'obj not saveable!' 
79 | sys.exit(msg) 80 | 81 | with open(obj_mesh_name, 'w') as fp: 82 | for v in verts: 83 | fp.write('v %f %f %f\n' % (v[0], v[1], v[2])) 84 | 85 | for f in self.faces: # Faces are 1-based, not 0-based in obj files 86 | fp.write('f %d %d %d\n' % (f[0] + 1, f[1] + 1, f[2] + 1)) 87 | 88 | def forward(self, theta, beta=None, garment_d=None, 89 | garment_class=None, rotate_base=False, ret_skirt_skinning=False): 90 | if not self.cur_device: 91 | device = theta.device 92 | self.cur_device = torch.device(device.type, device.index) 93 | 94 | num_batch = theta.shape[0] 95 | 96 | if beta is not None: 97 | v_shaped = torch.matmul( 98 | beta, self.shapedirs).view(-1, self.size[0], self.size[1]) + self.v_template 99 | else: 100 | v_shaped = self.v_template.unsqueeze(0).expand(num_batch, -1, -1) 101 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) 102 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) 103 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) 104 | J = torch.stack([Jx, Jy, Jz], dim=2) 105 | 106 | Rs = batch_rodrigues(theta.contiguous().view(-1, 3)).view(-1, 24, 3, 3) 107 | pose_feature = (Rs[:, 1:, :, :]).sub(1.0, self.e3).view(-1, 207) 108 | v_posed = torch.matmul( 109 | pose_feature, self.posedirs).view(-1, self.size[0], self.size[1]) + v_shaped 110 | 111 | # garment deformation 112 | if garment_d is not None and garment_class is not None: 113 | v_deformed = v_posed.clone() 114 | v_deformed[:, self.class_info[garment_class]['vert_indices']] += garment_d 115 | 116 | self.J_transformed, A = batch_global_rigid_transformation( 117 | Rs, J, self.parents, self.cur_device, rotate_base=rotate_base) 118 | 119 | W = self.weight.view(1, self.num_verts, 24).repeat(num_batch, 1, 1) 120 | T = torch.matmul(W, A.view(num_batch, 24, 16)).view(num_batch, -1, 4, 4) 121 | 122 | v_posed_homo = torch.cat( 123 | [v_posed, torch.ones(num_batch, v_posed.shape[1], 1, device=self.cur_device)], dim=2) 124 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, -1)) 125 | v_body = v_homo[:, :, :3, 0] 126 | 127 | if garment_class is not None: 128 | if garment_class == 'skirt': 129 | skirt_W = self.skirt_skinning.repeat(num_batch, 1, 1) 130 | skirt_T = torch.matmul(skirt_W, A.view(num_batch, 24, 16)).view(num_batch, -1, 4, 4) 131 | v_deformed = v_posed.clone() 132 | v_skirt_base = torch.einsum('sb,nbt->nst', self.skirt_weight, v_deformed) 133 | v_skirt = v_skirt_base + garment_d 134 | v_skirt_homo = torch.cat([ 135 | v_skirt, torch.ones(num_batch, v_skirt.shape[1], 1, device=self.cur_device) 136 | ], dim=2) 137 | v_skirt = torch.matmul(skirt_T, torch.unsqueeze(v_skirt_homo, -1)) 138 | v_skirt = v_skirt[:, :, :3, 0] 139 | if ret_skirt_skinning: 140 | return v_body, v_skirt, skirt_T, v_skirt_base 141 | return v_body, v_skirt 142 | 143 | v_posed_homo = torch.cat([ 144 | v_deformed, 145 | torch.ones(num_batch, v_deformed.shape[1], 1, device=self.cur_device)], dim=2) 146 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, -1)) 147 | v_garment = v_homo[:, :, :3, 0] 148 | return v_body, v_garment[:, self.class_info[garment_class]['vert_indices']] 149 | else: 150 | return v_body 151 | 152 | def forward_poseshaped(self, theta, beta=None, garment_class=None): 153 | if not self.cur_device: 154 | device = theta.device 155 | self.cur_device = torch.device(device.type, device.index) 156 | 157 | num_batch = theta.shape[0] 158 | 159 | if beta is not None: 160 | v_shaped = torch.matmul( 161 | beta, self.shapedirs).view(-1, self.size[0], self.size[1]) + self.v_template 162 | else: 163 | v_shaped = 
self.v_template.unsqueeze(0).expand(num_batch, -1, -1) 164 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) 165 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) 166 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) 167 | J = torch.stack([Jx, Jy, Jz], dim=2) 168 | 169 | Rs = batch_rodrigues(theta.contiguous().view(-1, 3)).view(-1, 24, 3, 3) 170 | pose_feature = (Rs[:, 1:, :, :]).sub(1.0, self.e3).view(-1, 207) 171 | v_posed = torch.matmul( 172 | pose_feature, self.posedirs).view(-1, self.size[0], self.size[1]) + v_shaped 173 | 174 | if garment_class is not None: 175 | if garment_class == 'skirt': 176 | v_posed = torch.einsum('sb,nbt->nst', self.skirt_weight, v_posed.clone()) 177 | return v_posed 178 | v_posed = v_posed[:, self.class_info[garment_class]['vert_indices']] 179 | return v_posed 180 | 181 | 182 | def batch_rodrigues(theta): 183 | # theta N x 3 184 | l1norm = torch.norm(theta + 1e-8, p=2, dim=1) 185 | angle = torch.unsqueeze(l1norm, -1) 186 | normalized = torch.div(theta, angle) 187 | angle = angle * 0.5 188 | v_cos = torch.cos(angle) 189 | v_sin = torch.sin(angle) 190 | quat = torch.cat([v_cos, v_sin * normalized], dim=1) 191 | 192 | return quat2mat(quat) 193 | 194 | 195 | def quat2mat(quat): 196 | """Convert quaternion coefficients to rotation matrix. 197 | Args: 198 | quat: size = [B, 4] 4 <===>(w, x, y, z) 199 | Returns: 200 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 201 | """ 202 | norm_quat = quat 203 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) 204 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] 205 | 206 | B = quat.size(0) 207 | 208 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 209 | wx, wy, wz = w * x, w * y, w * z 210 | xy, xz, yz = x * y, x * z, y * z 211 | 212 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 213 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 214 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 215 | return rotMat 216 | 217 | 218 | def batch_global_rigid_transformation(Rs, Js, parent, device, rotate_base=False): 219 | N = Rs.shape[0] 220 | if rotate_base: 221 | np_rot_x = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float) 222 | np_rot_x = np.reshape(np.tile(np_rot_x, [N, 1]), [N, 3, 3]) 223 | rot_x = torch.from_numpy(np_rot_x).float().to(device) 224 | root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) 225 | else: 226 | root_rotation = Rs[:, 0, :, :] 227 | Js = torch.unsqueeze(Js, -1) 228 | 229 | def make_A(R, t): 230 | R_homo = F.pad(R, [0, 0, 0, 1, 0, 0]) 231 | t_homo = torch.cat([t, torch.ones(N, 1, 1).to(device)], dim=1) 232 | return torch.cat([R_homo, t_homo], 2) 233 | 234 | A0 = make_A(root_rotation, Js[:, 0]) 235 | results = [A0] 236 | 237 | for i in range(1, parent.shape[0]): 238 | j_here = Js[:, i] - Js[:, parent[i]] 239 | A_here = make_A(Rs[:, i], j_here) 240 | res_here = torch.matmul(results[parent[i]], A_here) 241 | results.append(res_here) 242 | 243 | results = torch.stack(results, dim=1) 244 | 245 | new_J = results[:, :, :3, 3] 246 | Js_w0 = torch.cat([Js, torch.zeros(N, 24, 1, 1).to(device)], dim=2) 247 | init_bone = torch.matmul(results, Js_w0) 248 | init_bone = F.pad(init_bone, [3, 0, 0, 0, 0, 0, 0, 0]) 249 | A = results - init_bone 250 | 251 | return new_J, A 252 | 253 | 254 | def batch_lrotmin(theta): 255 | theta = theta[:, 3:].contiguous() 256 | Rs = batch_rodrigues(theta.view(-1, 3)) 257 | print(Rs.shape) 258 | e = torch.eye(3).float() 259 | 
Rs = Rs.sub(1.0, e) 260 | 261 | return Rs.view(-1, 23 * 9) 262 | 263 | 264 | def batch_orth_proj(X, camera): 265 | ''' 266 | X is N x num_points x 3 267 | ''' 268 | camera = camera.view(-1, 1, 3) 269 | X_trans = X[:, :, :2] + camera[:, :, 1:] 270 | shape = X_trans.shape 271 | return (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) 272 | 273 | 274 | if __name__ == "__main__": 275 | device = torch.device("cpu") 276 | gender = 'female' 277 | smpl = TorchSMPL4Garment(gender=gender).to(device) 278 | -------------------------------------------------------------------------------- /refu_tailornet/smpl_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/refu_tailornet/smpl_lib/__init__.py -------------------------------------------------------------------------------- /refu_tailornet/smpl_lib/ch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import chumpy as ch 6 | import scipy.sparse as sp 7 | 8 | from chumpy.utils import col 9 | 10 | 11 | class sp_dot(ch.Ch): 12 | terms = 'a', 13 | dterms = 'b', 14 | 15 | def on_changed(self, which): 16 | if 'a' in which: 17 | a_csr = sp.csr_matrix(self.a) 18 | # To stay consistent with numpy, we must upgrade 1D arrays to 2D 19 | self.ar = sp.csr_matrix((a_csr.data, a_csr.indices, a_csr.indptr), 20 | shape=(max(np.sum(a_csr.shape[:-1]), 1), a_csr.shape[-1])) 21 | 22 | if 'b' in which: 23 | self.br = col(self.b.r) if len(self.b.r.shape) < 2 else self.b.r.reshape((self.b.r.shape[0], -1)) 24 | 25 | if 'a' in which or 'b' in which: 26 | self.k = sp.kron(self.ar, sp.eye(self.br.shape[1], self.br.shape[1])) 27 | 28 | def compute_r(self): 29 | return self.a.dot(self.b.r) 30 | 31 | def compute(self): 32 | if self.br.ndim <= 1: 33 | return self.ar 34 | elif self.br.ndim <= 2: 35 | return self.k 36 | else: 37 | raise NotImplementedError 38 | 39 | def compute_dr_wrt(self, wrt): 40 | if wrt is self.b: 41 | return self.compute() 42 | 43 | class PReLU(ch.Ch): 44 | terms = 'p' 45 | dterms = 'x' 46 | 47 | def compute_r(self): 48 | r = self.x.r.copy() 49 | r[r < 0] *= self.p 50 | 51 | return r 52 | 53 | def compute_dr_wrt(self, wrt): 54 | if wrt is not self.x: 55 | return None 56 | 57 | dr = np.zeros(self.x.r.shape) 58 | dr[self.x.r > 0] = 1 59 | dr[self.x.r < 0] = self.p 60 | 61 | return sp.diags([dr.ravel()], [0]) 62 | 63 | ## RelU is PRelU with p=0 64 | class Clamp(ch.Ch): 65 | dterms = 'x' 66 | terms = 'c' 67 | 68 | def compute_r(self): 69 | r = self.x.r.copy() 70 | r[r > self.c] = self.c 71 | 72 | return r 73 | 74 | def compute_dr_wrt(self, wrt): 75 | if wrt is not self.x: 76 | return None 77 | 78 | dr = np.zeros(self.x.r.shape) 79 | dr[self.x.r < self.c] = 1 80 | 81 | return sp.diags([dr.ravel()], [0]) -------------------------------------------------------------------------------- /refu_tailornet/smpl_lib/ch_smpl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import chumpy as ch 6 | import pickle as pkl 7 | import scipy.sparse as sp 8 | from chumpy.ch import Ch 9 | from .posemapper import posemap, Rodrigues 10 | from .serialization import backwards_compatibility_replacements 11 | 12 | from smpl_lib.ch import sp_dot 13 | 14 | 15 | class Smpl(Ch): 16 | """ 17 | Class to store SMPL 
object with slightly improved code and access to more matrices 18 | """ 19 | terms = 'model', 20 | dterms = 'trans', 'betas', 'pose', 'v_personal', 'v_template' 21 | 22 | def __init__(self, *args, **kwargs): 23 | self.on_changed(self._dirty_vars) 24 | 25 | def on_changed(self, which): 26 | if 'model' in which: 27 | if not isinstance(self.model, dict): 28 | dd = pkl.load(open(self.model, 'rb'), encoding='latin1') 29 | else: 30 | dd = self.model 31 | 32 | backwards_compatibility_replacements(dd) 33 | 34 | # for s in ['v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', 'betas', 'J']: 35 | for s in ['posedirs', 'shapedirs']: 36 | if (s in dd) and not hasattr(dd[s], 'dterms'): 37 | dd[s] = ch.array(dd[s]) 38 | 39 | self.f = dd['f'] 40 | self.shapedirs = dd['shapedirs'] 41 | self.J_regressor = dd['J_regressor'] 42 | if 'J_regressor_prior' in dd: 43 | self.J_regressor_prior = dd['J_regressor_prior'] 44 | self.bs_type = dd['bs_type'] 45 | self.bs_style = dd['bs_style'] 46 | self.weights = ch.array(dd['weights']) 47 | if 'vert_sym_idxs' in dd: 48 | self.vert_sym_idxs = dd['vert_sym_idxs'] 49 | if 'weights_prior' in dd: 50 | self.weights_prior = dd['weights_prior'] 51 | self.kintree_table = dd['kintree_table'] 52 | self.posedirs = dd['posedirs'] 53 | 54 | if not hasattr(self, 'betas'): 55 | self.betas = ch.zeros(self.shapedirs.shape[-1]) 56 | 57 | if not hasattr(self, 'trans'): 58 | self.trans = ch.zeros(3) 59 | 60 | if not hasattr(self, 'pose'): 61 | self.pose = ch.zeros(72) 62 | 63 | if not hasattr(self, 'v_template'): 64 | self.v_template = ch.array(dd['v_template']) 65 | 66 | if not hasattr(self, 'v_personal'): 67 | self.v_personal = ch.zeros_like(self.v_template) 68 | 69 | self._set_up() 70 | 71 | def _set_up(self): 72 | self.v_shaped = self.shapedirs.dot(self.betas) + self.v_template 73 | 74 | self.v_shaped_personal = self.v_shaped + self.v_personal 75 | if sp.issparse(self.J_regressor): 76 | self.J = sp_dot(self.J_regressor, self.v_shaped) 77 | else: 78 | self.J = ch.sum(self.J_regressor.T.reshape(-1, 1, 24) * self.v_shaped.reshape(-1, 3, 1), axis=0).T 79 | self.v_posevariation = self.posedirs.dot(posemap(self.bs_type)(self.pose)) 80 | self.v_poseshaped = self.v_shaped_personal + self.v_posevariation 81 | 82 | self.A, A_global = self._global_rigid_transformation() 83 | self.Jtr = ch.vstack([g[:3, 3] for g in A_global]) 84 | self.J_transformed = self.Jtr + self.trans.reshape((1, 3)) 85 | 86 | self.V = self.A.dot(self.weights.T) 87 | 88 | rest_shape_h = ch.hstack((self.v_poseshaped, ch.ones((self.v_poseshaped.shape[0], 1)))) 89 | self.v_posed = ch.sum(self.V.T * rest_shape_h.reshape(-1, 4, 1), axis=1)[:, :3] 90 | self.v = self.v_posed + self.trans 91 | 92 | def _global_rigid_transformation(self): 93 | results = {} 94 | pose = self.pose.reshape((-1, 3)) 95 | parent = {i: self.kintree_table[0, i] for i in range(1, self.kintree_table.shape[1])} 96 | 97 | with_zeros = lambda x: ch.vstack((x, ch.array([[0.0, 0.0, 0.0, 1.0]]))) 98 | pack = lambda x: ch.hstack([ch.zeros((4, 3)), x.reshape((4, 1))]) 99 | 100 | results[0] = with_zeros(ch.hstack((Rodrigues(pose[0, :]), self.J[0, :].reshape((3, 1))))) 101 | 102 | for i in range(1, self.kintree_table.shape[1]): 103 | results[i] = results[parent[i]].dot(with_zeros(ch.hstack(( 104 | Rodrigues(pose[i, :]), # rotation around bone endpoint 105 | (self.J[i, :] - self.J[parent[i], :]).reshape((3, 1)) # bone 106 | )))) 107 | 108 | results = [results[i] for i in sorted(results.keys())] 109 | results_global = results 110 | 111 | # subtract rotated J 
position 112 | results2 = [results[i] - (pack( 113 | results[i].dot(ch.concatenate((self.J[i, :], [0])))) 114 | ) for i in range(len(results))] 115 | result = ch.dstack(results2) 116 | 117 | return result, results_global 118 | 119 | def compute_r(self): 120 | return self.v.r 121 | 122 | def compute_dr_wrt(self, wrt): 123 | if wrt is not self.trans and wrt is not self.betas and wrt is not self.pose and wrt is not self.v_personal and wrt is not self.v_template: 124 | return None 125 | 126 | return self.v.dr_wrt(wrt) 127 | 128 | 129 | if __name__ == '__main__': 130 | from tnutils.smpl_paths import SmplPaths 131 | 132 | dp = SmplPaths(gender='male') 133 | 134 | smpl = Smpl(dp.get_smpl_file()) 135 | 136 | from psbody.mesh.meshviewer import MeshViewer 137 | from psbody.mesh import Mesh 138 | import IPython 139 | IPython.embed() 140 | #mv = MeshViewer() 141 | #mv.set_static_meshes([Mesh(smpl.r, smpl.f)]) 142 | 143 | input("Press Enter to continue...") 144 | -------------------------------------------------------------------------------- /refu_tailornet/smpl_lib/lbs.py: -------------------------------------------------------------------------------- 1 | ## This function is copied from https://github.com/Rubikplayer/flame-fitting 2 | 3 | ''' 4 | Copyright 2015 Matthew Loper, Naureen Mahmood and the Max Planck Gesellschaft. All rights reserved. 5 | This software is provided for research purposes only. 6 | By using this software you agree to the terms of the SMPL Model license here http://smpl.is.tue.mpg.de/license 7 | More information about SMPL is available here http://smpl.is.tue.mpg. 8 | For comments or questions, please email us at: smpl@tuebingen.mpg.de 9 | About this file: 10 | ================ 11 | This file defines linear blend skinning for the SMPL loader which 12 | defines the effect of bones and blendshapes on the vertices of the template mesh. 
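In matrix form, a rest vertex v (in homogeneous coordinates) with skinning
weights w_j over the joints is deformed by the per-joint global transforms
G_j(pose, J) as

    v' = ( sum_j w_j * G_j(pose, J) ) * v

verts_core below evaluates this blend for all vertices at once via
T = A.dot(weights.T).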
13 | Modules included:
14 | - global_rigid_transformation:
15 | computes global rotation & translation of the model
16 | - verts_core: [overloaded function inherited from verts.verts_core]
17 | computes the blending of joint-influences for each vertex based on type of skinning
18 | '''
19 | 
20 | from .posemapper import posemap
21 | import chumpy
22 | import numpy as np
23 | 
24 | 
25 | def global_rigid_transformation(pose, J, kintree_table, xp):
26 | results = {}
27 | pose = pose.reshape((-1, 3))
28 | id_to_col = {kintree_table[1, i]: i for i in range(kintree_table.shape[1])}
29 | parent = {i: id_to_col[kintree_table[0, i]] for i in range(1, kintree_table.shape[1])}
30 | 
31 | if xp == chumpy:
32 | from .posemapper import Rodrigues  # relative import, as at the top of this file (required under Python 3)
33 | rodrigues = lambda x: Rodrigues(x)
34 | else:
35 | import cv2
36 | rodrigues = lambda x: cv2.Rodrigues(x)[0]
37 | 
38 | with_zeros = lambda x: xp.vstack((x, xp.array([[0.0, 0.0, 0.0, 1.0]])))
39 | results[0] = with_zeros(xp.hstack((rodrigues(pose[0, :]), J[0, :].reshape((3, 1)))))
40 | 
41 | for i in range(1, kintree_table.shape[1]):
42 | results[i] = results[parent[i]].dot(with_zeros(xp.hstack((
43 | rodrigues(pose[i, :]),
44 | ((J[i, :] - J[parent[i], :]).reshape((3, 1)))
45 | ))))
46 | 
47 | pack = lambda x: xp.hstack([np.zeros((4, 3)), x.reshape((4, 1))])
48 | 
49 | results = [results[i] for i in sorted(results.keys())]
50 | results_global = results
51 | 
52 | if True:
53 | results2 = [results[i] - (pack(
54 | results[i].dot(xp.concatenate((J[i, :], [0]))))  # pad with [0] to make J homogeneous; np.concatenate cannot take a bare scalar
55 | ) for i in range(len(results))]
56 | results = results2
57 | result = xp.dstack(results)
58 | return result, results_global
59 | 
60 | 
61 | def verts_core(pose, v, J, weights, kintree_table, want_Jtr=False, xp=chumpy):
62 | A, A_global = global_rigid_transformation(pose, J, kintree_table, xp)
63 | T = A.dot(weights.T)
64 | 
65 | rest_shape_h = xp.vstack((v.T, np.ones((1, v.shape[0]))))
66 | 
67 | v = (T[:, 0, :] * rest_shape_h[0, :].reshape((1, -1)) +
68 | T[:, 1, :] * rest_shape_h[1, :].reshape((1, -1)) +
69 | T[:, 2, :] * rest_shape_h[2, :].reshape((1, -1)) +
70 | T[:, 3, :] * rest_shape_h[3, :].reshape((1, -1))).T
71 | 
72 | v = v[:, :3]
73 | 
74 | if not want_Jtr:
75 | return v
76 | Jtr = xp.vstack([g[:3, 3] for g in A_global])
77 | return (v, Jtr)
78 | 
--------------------------------------------------------------------------------
/refu_tailornet/smpl_lib/posemapper.py:
--------------------------------------------------------------------------------
1 | ## This function is copied from https://github.com/Rubikplayer/flame-fitting
2 | 
3 | '''
4 | Copyright 2015 Matthew Loper, Naureen Mahmood and the Max Planck Gesellschaft. All rights reserved.
5 | This software is provided for research purposes only.
6 | By using this software you agree to the terms of the SMPL Model license here http://smpl.is.tue.mpg.de/license
7 | More information about SMPL is available here http://smpl.is.tue.mpg.de.
8 | For comments or questions, please email us at: smpl@tuebingen.mpg.de
9 | About this file:
10 | ================
11 | This module defines the mapping of joint-angles to pose-blendshapes.
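For the 'lrotmin' mapping used by SMPL, the global root rotation is dropped
and each of the remaining 23 joint rotations contributes its flattened
deviation from the identity, (R(theta_j) - I_3), giving a 23 * 9 = 207
dimensional pose feature that is all zeros in the zero pose.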
12 | Modules included: 13 | - posemap: 14 | computes the joint-to-pose blend shape mapping given a mapping type as input 15 | ''' 16 | 17 | import chumpy as ch 18 | import numpy as np 19 | import cv2 20 | 21 | 22 | class Rodrigues(ch.Ch): 23 | dterms = 'rt' 24 | 25 | def compute_r(self): 26 | return cv2.Rodrigues(self.rt.r)[0] 27 | 28 | def compute_dr_wrt(self, wrt): 29 | if wrt is self.rt: 30 | return cv2.Rodrigues(self.rt.r)[1].T 31 | 32 | 33 | def lrotmin(p): 34 | if isinstance(p, np.ndarray): 35 | p = p.ravel()[3:] 36 | return np.concatenate( 37 | [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel() for pp in p.reshape((-1, 3))]).ravel() 38 | if p.ndim != 2 or p.shape[1] != 3: 39 | p = p.reshape((-1, 3)) 40 | p = p[1:] 41 | return ch.concatenate([(Rodrigues(pp) - ch.eye(3)).ravel() for pp in p]).ravel() 42 | 43 | 44 | def posemap(s): 45 | if s == 'lrotmin': 46 | return lrotmin 47 | else: 48 | raise Exception('Unknown posemapping: %s' % (str(s),)) -------------------------------------------------------------------------------- /refu_tailornet/smpl_lib/serialization.py: -------------------------------------------------------------------------------- 1 | ## This function is copied from https://github.com/Rubikplayer/flame-fitting 2 | 3 | ''' 4 | Copyright 2015 Matthew Loper, Naureen Mahmood and the Max Planck Gesellschaft. All rights reserved. 5 | This software is provided for research purposes only. 6 | By using this software you agree to the terms of the SMPL Model license here http://smpl.is.tue.mpg.de/license 7 | More information about SMPL is available here http://smpl.is.tue.mpg. 8 | For comments or questions, please email us at: smpl@tuebingen.mpg.de 9 | About this file: 10 | ================ 11 | This file defines the serialization functions of the SMPL model. 12 | Modules included: 13 | - save_model: 14 | saves the SMPL model to a given file location as a .pkl file 15 | - load_model: 16 | loads the SMPL model from a given file location (i.e. a .pkl file location), 17 | or a dictionary object. 
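A minimal usage sketch (the .pkl path below is hypothetical):

    model = load_model('models/basicModel_f_lbs_10_207_0_v1.0.0.pkl')
    model.pose[:] = 0.
    model.betas[:] = 0.
    verts = model.r  # posed SMPL vertices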
18 | '''
19 | import pickle
20 | import numpy as np
21 | import chumpy as ch
22 | from chumpy.ch import MatVecMult
23 | from .verts import verts_core
24 | from .posemapper import posemap
25 | 
26 | def backwards_compatibility_replacements(dd):
27 | # replacements
28 | if 'default_v' in dd:
29 | dd['v_template'] = dd['default_v']
30 | del dd['default_v']
31 | if 'template_v' in dd:
32 | dd['v_template'] = dd['template_v']
33 | del dd['template_v']
34 | if 'joint_regressor' in dd:
35 | dd['J_regressor'] = dd['joint_regressor']
36 | del dd['joint_regressor']
37 | if 'blendshapes' in dd:
38 | dd['posedirs'] = dd['blendshapes']
39 | del dd['blendshapes']
40 | if 'J' not in dd:
41 | dd['J'] = dd['joints']
42 | del dd['joints']
43 | 
44 | # defaults
45 | if 'bs_style' not in dd:
46 | dd['bs_style'] = 'lbs'
47 | 
48 | 
49 | def ready_arguments(fname_or_dict):
50 | if not isinstance(fname_or_dict, dict):
51 | dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')  # binary mode + latin1 lets Python 3 read the Python 2 SMPL pickles (as in ch_smpl.py)
52 | else:
53 | dd = fname_or_dict
54 | 
55 | backwards_compatibility_replacements(dd)
56 | 
57 | want_shapemodel = 'shapedirs' in dd
58 | nposeparms = dd['kintree_table'].shape[1] * 3
59 | 
60 | if 'trans' not in dd:
61 | dd['trans'] = np.zeros(3)
62 | if 'pose' not in dd:
63 | dd['pose'] = np.zeros(nposeparms)
64 | if 'shapedirs' in dd and 'betas' not in dd:
65 | dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
66 | 
67 | for s in ['v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', 'betas', 'J']:
68 | if (s in dd) and not hasattr(dd[s], 'dterms'):
69 | dd[s] = ch.array(dd[s])
70 | 
71 | if want_shapemodel:
72 | dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
73 | v_shaped = dd['v_shaped']
74 | J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
75 | J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
76 | J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
77 | dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
78 | dd['v_posed'] = v_shaped + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))
79 | else:
80 | dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))
81 | 
82 | return dd
83 | 
84 | def load_model(fname_or_dict):
85 | dd = ready_arguments(fname_or_dict)
86 | 
87 | args = {
88 | 'pose': dd['pose'],
89 | 'v': dd['v_posed'],
90 | 'J': dd['J'],
91 | 'weights': dd['weights'],
92 | 'kintree_table': dd['kintree_table'],
93 | 'xp': ch,
94 | 'want_Jtr': True,
95 | 'bs_style': dd['bs_style']
96 | }
97 | 
98 | result, Jtr = verts_core(**args)
99 | result = result + dd['trans'].reshape((1, 3))
100 | result.J_transformed = Jtr + dd['trans'].reshape((1, 3))
101 | 
102 | for k, v in dd.items():
103 | setattr(result, k, v)
104 | 
105 | return result
106 | 
--------------------------------------------------------------------------------
/refu_tailornet/smpl_lib/verts.py:
--------------------------------------------------------------------------------
1 | ## This function is copied from https://github.com/Rubikplayer/flame-fitting
2 | 
3 | '''
4 | Copyright 2015 Matthew Loper, Naureen Mahmood and the Max Planck Gesellschaft. All rights reserved.
5 | This software is provided for research purposes only.
6 | By using this software you agree to the terms of the SMPL Model license here http://smpl.is.tue.mpg.de/license
7 | More information about SMPL is available here http://smpl.is.tue.mpg.de.
8 | For comments or questions, please email us at: smpl@tuebingen.mpg.de 9 | About this file: 10 | ================ 11 | This file defines the basic skinning modules for the SMPL loader which 12 | defines the effect of bones and blendshapes on the vertices of the template mesh. 13 | Modules included: 14 | - verts_decorated: 15 | creates an instance of the SMPL model which inherits model attributes from another 16 | SMPL model. 17 | - verts_core: [overloaded function inherited by lbs.verts_core] 18 | computes the blending of joint-influences for each vertex based on type of skinning 19 | ''' 20 | 21 | import chumpy 22 | from . import lbs 23 | from .posemapper import posemap 24 | import scipy.sparse as sp 25 | from chumpy.ch import MatVecMult 26 | 27 | 28 | def ischumpy(x): return hasattr(x, 'dterms') 29 | 30 | 31 | def verts_decorated(trans, pose, 32 | v_template, J, weights, kintree_table, bs_style, f, 33 | bs_type=None, posedirs=None, betas=None, shapedirs=None, want_Jtr=False): 34 | for which in [trans, pose, v_template, weights, posedirs, betas, shapedirs]: 35 | if which is not None: 36 | assert ischumpy(which) 37 | 38 | v = v_template 39 | 40 | if shapedirs is not None: 41 | if betas is None: 42 | betas = chumpy.zeros(shapedirs.shape[-1]) 43 | v_shaped = v + shapedirs.dot(betas) 44 | else: 45 | v_shaped = v 46 | 47 | if posedirs is not None: 48 | v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose)) 49 | else: 50 | v_posed = v_shaped 51 | 52 | v = v_posed 53 | 54 | if sp.issparse(J): 55 | regressor = J 56 | J_tmpx = MatVecMult(regressor, v_shaped[:, 0]) 57 | J_tmpy = MatVecMult(regressor, v_shaped[:, 1]) 58 | J_tmpz = MatVecMult(regressor, v_shaped[:, 2]) 59 | J = chumpy.vstack((J_tmpx, J_tmpy, J_tmpz)).T 60 | else: 61 | assert (ischumpy(J)) 62 | 63 | assert (bs_style == 'lbs') 64 | result, Jtr = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr=True, xp=chumpy) 65 | 66 | tr = trans.reshape((1, 3)) 67 | result = result + tr 68 | Jtr = Jtr + tr 69 | 70 | result.trans = trans 71 | result.f = f 72 | result.pose = pose 73 | result.v_template = v_template 74 | result.J = J 75 | result.weights = weights 76 | result.kintree_table = kintree_table 77 | result.bs_style = bs_style 78 | result.bs_type = bs_type 79 | if posedirs is not None: 80 | result.posedirs = posedirs 81 | result.v_posed = v_posed 82 | if shapedirs is not None: 83 | result.shapedirs = shapedirs 84 | result.betas = betas 85 | result.v_shaped = v_shaped 86 | if want_Jtr: 87 | result.J_transformed = Jtr 88 | return result 89 | 90 | 91 | def verts_core(pose, v, J, weights, kintree_table, bs_style, want_Jtr=False, xp=chumpy): 92 | if xp == chumpy: 93 | assert (hasattr(pose, 'dterms')) 94 | assert (hasattr(v, 'dterms')) 95 | assert (hasattr(J, 'dterms')) 96 | assert (hasattr(weights, 'dterms')) 97 | 98 | assert (bs_style == 'lbs') 99 | result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp) 100 | 101 | return result 102 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/refu_tailornet/tnutils/__init__.py -------------------------------------------------------------------------------- /refu_tailornet/tnutils/diffusion_smoothing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import 
scipy.sparse as sp 4 | from psbody.mesh import Mesh 5 | # , MeshViewer 6 | 7 | 8 | def numpy_laplacian_uniform(v, f): 9 | """Computes uniform laplacian operator on mesh.""" 10 | import scipy.sparse as sp 11 | from sklearn.preprocessing import normalize 12 | from psbody.mesh.topology.connectivity import get_vert_connectivity 13 | 14 | connectivity = get_vert_connectivity(Mesh(v=v, f=f)) 15 | # connectivity is a sparse matrix, and np.clip can not applied directly on 16 | # a sparse matrix. 17 | connectivity.data = np.clip(connectivity.data, 0, 1) 18 | lap = normalize(connectivity, norm='l1', axis=1) 19 | lap = lap - sp.eye(connectivity.shape[0]) 20 | 21 | return lap 22 | 23 | 24 | def numpy_laplacian_cot(v, f): 25 | """Computes cotangent laplacian operator on mesh.""" 26 | n = len(v) 27 | 28 | v_a = f[:, 0] 29 | v_b = f[:, 1] 30 | v_c = f[:, 2] 31 | 32 | ab = v[v_a] - v[v_b] 33 | bc = v[v_b] - v[v_c] 34 | ca = v[v_c] - v[v_a] 35 | 36 | cot_a = -1 * (ab * ca).sum(axis=1) / (np.sqrt(np.sum(np.cross(ab, ca) ** 2, axis=-1)) + 1.e-10) 37 | cot_b = -1 * (bc * ab).sum(axis=1) / (np.sqrt(np.sum(np.cross(bc, ab) ** 2, axis=-1)) + 1.e-10) 38 | cot_c = -1 * (ca * bc).sum(axis=1) / (np.sqrt(np.sum(np.cross(ca, bc) ** 2, axis=-1)) + 1.e-10) 39 | 40 | I = np.concatenate((v_a, v_c, v_a, v_b, v_b, v_c)) 41 | J = np.concatenate((v_c, v_a, v_b, v_a, v_c, v_b)) 42 | W = 0.5 * np.concatenate((cot_b, cot_b, cot_c, cot_c, cot_a, cot_a)) 43 | 44 | L = sp.csr_matrix((W, (I, J)), shape=(n, n)) 45 | L = L - sp.spdiags(L * np.ones(n), 0, n, n) 46 | 47 | return L 48 | 49 | 50 | def direct_smoothing(v, f, smoothness=0.1, Ltype='cotangent'): 51 | """Apply direct smoothing on mesh.""" 52 | if Ltype == 'cotangent': 53 | L = numpy_laplacian_cot(v, f) 54 | elif Ltype == 'uniform': 55 | L = numpy_laplacian_uniform(v, f) 56 | else: 57 | raise AttributeError 58 | new_v = v + smoothness * L.dot(v) 59 | return new_v 60 | 61 | 62 | class DiffusionSmoothing(object): 63 | """A class useful to apply smoothing repeatedly in efficient manner on the same-topology meshes.""" 64 | 65 | def __init__(self, v, f): 66 | """Computes and stores necessary variables. 67 | 68 | v is only used for getting total number of vertices. f defines the topology. 69 | """ 70 | self.num_v = v.shape[0] 71 | self.v = v 72 | self.f = f 73 | self.set_boundary_ids_and_mats(v, f) 74 | self.uniL = None 75 | 76 | def get_uniform_lap_smoothing(self): 77 | """Computes uniform laplacian for smoothing. 78 | 79 | Boundary vertices are smoothed not by all neighbors but only neighboring 80 | boundary vertices in order to prevent boundary shrinking. 
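Each application of the returned operator performs one explicit diffusion
step,

    v  <-  v + smoothness * L.dot(v)

where an interior row of L averages a vertex's neighbours and subtracts the
vertex itself, and a boundary row uses only the two adjacent ring vertices
(weights 0.5, 0.5, -1).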
81 | """ 82 | L = numpy_laplacian_uniform(self.v, self.f) 83 | 84 | # remove rows corresponding to boundary vertices 85 | for row in self.b_ids: 86 | L.data[L.indptr[row]:L.indptr[row + 1]] = 0 87 | L.eliminate_zeros() 88 | 89 | num_b = self.b_ids.shape[0] 90 | I = np.tile(self.b_ids, 3) 91 | J = np.hstack(( 92 | self.b_ids, 93 | self.b_ids[self.l_ids], 94 | self.b_ids[self.r_ids], 95 | )) 96 | W = np.hstack(( 97 | -1 * np.ones(num_b), 98 | 0.5 * np.ones(num_b), 99 | 0.5 * np.ones(num_b), 100 | )) 101 | mat = sp.csr_matrix((W, (I, J)), shape=(self.num_v, self.num_v)) 102 | L = L + mat 103 | return L 104 | 105 | def set_boundary_ids_and_mats(self, v, f): 106 | from .geometry import get_boundary_verts 107 | _, b_rings = get_boundary_verts(v, f) 108 | 109 | def shift_left(ls, k): 110 | return ls[k:] + ls[:k] 111 | 112 | b_ids = [] 113 | l_ids = [] 114 | r_ids = [] 115 | for rg in b_rings: 116 | tmp = list(range(len(b_ids), len(b_ids) + len(rg))) 117 | ltmp = shift_left(tmp, 1) 118 | rtmp = shift_left(tmp, -1) 119 | l_ids.extend(ltmp) 120 | r_ids.extend(rtmp) 121 | 122 | b_ids.extend(rg) 123 | 124 | b_ids = np.asarray(b_ids, dtype=np.int64) 125 | num_b = b_ids.shape[0] 126 | m_ids = np.arange(num_b, dtype=np.int64) 127 | l_ids = np.asarray(l_ids, dtype=np.int64) 128 | r_ids = np.asarray(r_ids, dtype=np.int64) 129 | 130 | self.right_edge_mat = sp.csr_matrix(( 131 | np.hstack((-1*np.ones(num_b), np.ones(num_b))), 132 | (np.hstack((m_ids, m_ids)), np.hstack((m_ids, r_ids))) 133 | ), shape=(num_b, num_b) 134 | ) 135 | 136 | self.left_edge_mat = sp.csr_matrix(( 137 | np.hstack((-1 * np.ones(num_b), np.ones(num_b))), 138 | (np.hstack((m_ids, m_ids)), np.hstack((m_ids, l_ids))) 139 | ), shape=(num_b, num_b) 140 | ) 141 | 142 | # boundary vertex ids 143 | self.b_ids = b_ids 144 | # left and right boundary neighbour vertex ids 145 | self.l_ids = l_ids 146 | self.r_ids = r_ids 147 | 148 | def smooth_cotlap(self, verts, smoothness=0.03): 149 | """Smooth using cotangent laplacian. 150 | 151 | Boundary vertices are smoothed only by neighboring boundary vertices 152 | in order to prevent boundary shrinking. 153 | """ 154 | L = numpy_laplacian_cot(verts, self.f) 155 | new_verts = verts + smoothness * L.dot(verts) 156 | 157 | b_verts = verts[self.b_ids] 158 | le = 1. / (np.linalg.norm(self.left_edge_mat.dot(b_verts), axis=-1) + 1.0e-10) 159 | ri = 1. / (np.linalg.norm(self.right_edge_mat.dot(b_verts), axis=-1) + 1.0e-10) 160 | 161 | num_b = b_verts.shape[0] 162 | I = np.tile(np.arange(num_b), 3) 163 | J = np.hstack(( 164 | np.arange(num_b), 165 | self.l_ids, 166 | self.r_ids, 167 | )) 168 | W = np.hstack(( 169 | -1*np.ones(num_b), 170 | le / (le + ri), 171 | ri / (le + ri), 172 | )) 173 | mat = sp.csr_matrix((W, (I, J)), shape=(num_b, num_b)) 174 | new_verts[self.b_ids] = verts[self.b_ids] + smoothness * mat.dot(verts[self.b_ids]) 175 | return new_verts 176 | 177 | def smooth_uniform(self, verts, smoothness=0.03): 178 | """Smooth using uniform laplacian. 179 | 180 | Boundary vertices are smoothed only by neighboring boundary vertices 181 | in order to prevent boundary shrinking. 
182 | """ 183 | if self.uniL is None: 184 | self.uniL = self.get_uniform_lap_smoothing() 185 | new_verts = verts + smoothness * self.uniL.dot(verts) 186 | return new_verts 187 | 188 | def smooth(self, verts, smoothness=0.03, n=1, Ltype="cotangent"): 189 | assert(Ltype in ["cotangent", "uniform"]) 190 | for i in range(n): 191 | if Ltype == 'uniform': 192 | verts = self.smooth_uniform(verts, smoothness) 193 | else: 194 | verts = self.smooth_cotlap(verts, smoothness) 195 | return verts 196 | 197 | 198 | if __name__ == "__main__": 199 | IS_SMPL = True 200 | fpath = "/BS/cpatel/work/data/learn_anim/mixture_exp31/000_0/smooth_TShirtNoCoat/0990/pred_0.ply" 201 | 202 | if not IS_SMPL: 203 | ms = Mesh(filename=fpath) 204 | else: 205 | from tnutils.smpl_paths import SmplPaths 206 | 207 | dp = SmplPaths(gender='female') 208 | smpl = dp.get_smpl() 209 | ms = Mesh(v=smpl.r, f=smpl.f) 210 | 211 | smoothing = DiffusionSmoothing(ms.v, ms.f) 212 | 213 | verts_smooth = ms.v.copy() 214 | for i in range(20): 215 | verts_smooth = smoothing.smooth(verts_smooth, smoothness=0.05) 216 | ms_smooth = Mesh(v=verts_smooth, f=ms.f) 217 | 218 | # from psbody.mesh import MeshViewers 219 | # mvs = MeshViewers((1,3)) 220 | # mvs[0][0].set_static_meshes([ms]) 221 | # mvs[0][1].set_static_meshes([ms_smooth]) 222 | # mvs[0][2].set_static_meshes([ms_smooth2]) 223 | # import ipdb 224 | # ipdb.set_trace() 225 | 226 | ms.write_ply("/BS/cpatel/work/orig.ply") 227 | ms_smooth.write_ply("/BS/cpatel/work/smooth.ply") 228 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | from numpy import save 5 | 6 | currentdir = os.path.dirname(os.path.realpath(__file__)) 7 | parentdir = os.path.dirname(currentdir) 8 | sys.path.append(parentdir) 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | def evaluate(): 31 | """Evaluate TailorNet (or any model for that matter) on test set.""" 32 | from dataset.static_pose_shape_final import MultiStyleShape 33 | import torch 34 | from torch.utils.data import DataLoader 35 | from tnutils.eval import AverageMeter 36 | from models import ops 37 | 38 | gender = 'female' 39 | garment_class = 'skirt' 40 | 41 | dataset = MultiStyleShape(garment_class=garment_class, gender=gender, split='test') 42 | dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False, drop_last=False) 43 | print(len(dataset)) 44 | 45 | val_dist = AverageMeter() 46 | from models.tailornet_model import get_best_runner as tn_runner 47 | runner = tn_runner(garment_class, gender) 48 | # from trainer.base_trainer import get_best_runner as baseline_runner 49 | # runner = baseline_runner("/BS/cpatel/work/data/learn_anim/{}_{}_weights/tn_orig_baseline/{}_{}".format(garment_class, gender, garment_class, gender)) 50 | 51 | device = torch.device('cuda:0') 52 | with torch.no_grad(): 53 | for i, inputs in enumerate(dataloader): 54 | gt_verts, thetas, betas, gammas, _ = inputs 55 | 56 | thetas, betas, gammas = ops.mask_inputs(thetas, betas, gammas, garment_class) 57 | 
gt_verts = gt_verts.to(device) 58 | thetas = thetas.to(device) 59 | betas = betas.to(device) 60 | gammas = gammas.to(device) 61 | pred_verts = runner.forward(thetas=thetas, betas=betas, gammas=gammas).view(gt_verts.shape) 62 | 63 | dist = ops.verts_dist(gt_verts, pred_verts) * 1000. 64 | val_dist.update(dist.item(), gt_verts.shape[0]) 65 | print(i, len(dataloader)) 66 | print(val_dist.avg) 67 | 68 | 69 | def evaluate_save(): 70 | """Evaluate TailorNet (or any model for that matter) on test set.""" 71 | from dataset.static_pose_shape_final import MultiStyleShape 72 | import torch 73 | from torch.utils.data import DataLoader 74 | from tnutils.eval import AverageMeter 75 | from models import ops 76 | from models.smpl4garment import SMPL4Garment 77 | import os 78 | 79 | gender = 'female' 80 | garment_class = 'skirt' 81 | smpl = SMPL4Garment(gender) 82 | vis_freq = 512 83 | log_dir = "/BS/cpatel/work/code_test2/try" 84 | 85 | dataset = MultiStyleShape(garment_class=garment_class, gender=gender, split='test') 86 | dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False, drop_last=False) 87 | print(len(dataset)) 88 | 89 | val_dist = AverageMeter() 90 | from models.tailornet_model import get_best_runner as tn_runner 91 | runner = tn_runner(garment_class, gender) 92 | # from trainer.base_trainer import get_best_runner as baseline_runner 93 | # runner = baseline_runner("/BS/cpatel/work/data/learn_anim/{}_{}_weights/tn_orig_baseline/{}_{}".format(garment_class, gender, garment_class, gender)) 94 | 95 | device = torch.device('cuda:0') 96 | with torch.no_grad(): 97 | for i, inputs in enumerate(dataloader): 98 | gt_verts, thetas, betas, gammas, idxs = inputs 99 | 100 | thetas, betas, gammas = ops.mask_inputs(thetas, betas, gammas, garment_class) 101 | gt_verts = gt_verts.to(device) 102 | thetas = thetas.to(device) 103 | betas = betas.to(device) 104 | gammas = gammas.to(device) 105 | pred_verts = runner.forward(thetas=thetas, betas=betas, gammas=gammas).view(gt_verts.shape) 106 | 107 | for lidx, idx in enumerate(idxs): 108 | if idx % vis_freq != 0: 109 | continue 110 | theta = thetas[lidx].cpu().numpy() 111 | beta = betas[lidx].cpu().numpy() 112 | pred_vert = pred_verts[lidx].cpu().numpy() 113 | gt_vert = gt_verts[lidx].cpu().numpy() 114 | 115 | body_m, pred_m = smpl.run(theta=theta, garment_d=pred_vert, 116 | beta=beta, 117 | garment_class=garment_class) 118 | _, gt_m = smpl.run(theta=theta, garment_d=gt_vert, 119 | beta=beta, 120 | garment_class=garment_class) 121 | 122 | save_dir = log_dir 123 | pred_m.write_ply( 124 | os.path.join(save_dir, "pred_{}.ply".format(idx))) 125 | gt_m.write_ply(os.path.join(save_dir, "gt_{}.ply".format(idx))) 126 | body_m.write_ply( 127 | os.path.join(save_dir, "body_{}.ply".format(idx))) 128 | 129 | print(val_dist.avg) 130 | 131 | def new_evaluate_save(): 132 | """Evaluate TailorNet (or any model for that matter) on test set.""" 133 | from dataset.static_pose_shape_final import MultiStyleShape 134 | import torch 135 | from torch.utils.data import DataLoader 136 | from tnutils.eval import AverageMeter 137 | from models import ops 138 | from models.smpl4garment import SMPL4Garment 139 | import os 140 | import os.path as osp 141 | 142 | gender = 'male' 143 | garment_class = 'shirt' 144 | smpl = SMPL4Garment(gender) 145 | # vis_freq = 512 146 | save_dir = osp.join("/home/code-base/user_space/TailorNet_eval_results", garment_class+'_'+gender) 147 | 148 | body_save_dir = osp.join(save_dir, 'body') 149 | garment_save_dir = osp.join(save_dir, 'garment') 150 | 
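# The loop below writes one .obj per test sample, laid out as follows
# (see the write_obj calls at the bottom of this function):
#   <save_dir>/body/<idx>.obj             posed body mesh
#   <save_dir>/garment/predict/<idx>.obj  TailorNet prediction
#   <save_dir>/garment/gt/<idx>.obj       simulation ground truth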
151 | predict_garment_save_dir = osp.join(garment_save_dir, 'predict') 152 | gt_garment_save_dir = osp.join(garment_save_dir, 'gt') 153 | 154 | os.makedirs(body_save_dir, exist_ok=True) 155 | os.makedirs(predict_garment_save_dir, exist_ok=True) 156 | os.makedirs(gt_garment_save_dir, exist_ok=True) 157 | 158 | dataset = MultiStyleShape(garment_class=garment_class, gender=gender, split='test') 159 | dataloader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=False, drop_last=False) 160 | print(len(dataset)) 161 | 162 | # val_dist = AverageMeter() 163 | from models.tailornet_model import get_best_runner as tn_runner 164 | runner = tn_runner(garment_class, gender) 165 | # from trainer.base_trainer import get_best_runner as baseline_runner 166 | # runner = baseline_runner("/BS/cpatel/work/data/learn_anim/{}_{}_weights/tn_orig_baseline/{}_{}".format(garment_class, gender, garment_class, gender)) 167 | 168 | device = torch.device('cuda:0') 169 | with torch.no_grad(): 170 | for i, inputs in enumerate(dataloader): 171 | gt_verts, thetas, betas, gammas, idxs = inputs 172 | 173 | thetas, betas, gammas = ops.mask_inputs(thetas, betas, gammas, garment_class) 174 | gt_verts = gt_verts.to(device) 175 | thetas = thetas.to(device) 176 | betas = betas.to(device) 177 | gammas = gammas.to(device) 178 | pred_verts = runner.forward(thetas=thetas, betas=betas, gammas=gammas).view(gt_verts.shape) 179 | 180 | for lidx, idx in enumerate(idxs): 181 | # if idx % vis_freq != 0: 182 | # continue 183 | theta = thetas[lidx].cpu().numpy() 184 | beta = betas[lidx].cpu().numpy() 185 | pred_vert = pred_verts[lidx].cpu().numpy() 186 | gt_vert = gt_verts[lidx].cpu().numpy() 187 | 188 | body_m, pred_m = smpl.run(theta=theta, garment_d=pred_vert, 189 | beta=beta, 190 | garment_class=garment_class) 191 | _, gt_m = smpl.run(theta=theta, garment_d=gt_vert, 192 | beta=beta, 193 | garment_class=garment_class) 194 | 195 | # pred_m.write_ply( 196 | # os.path.join(save_dir, "pred_{}.ply".format(idx))) 197 | # gt_m.write_ply(os.path.join(save_dir, "gt_{}.ply".format(idx))) 198 | # body_m.write_ply( 199 | # os.path.join(save_dir, "body_{}.ply".format(idx))) 200 | 201 | pred_m.write_obj( 202 | os.path.join(predict_garment_save_dir, "{}.obj".format(idx))) 203 | gt_m.write_obj(os.path.join(gt_garment_save_dir, "{}.obj".format(idx))) 204 | body_m.write_obj( 205 | os.path.join(body_save_dir, "{}.obj".format(idx))) 206 | 207 | # print(val_dist.avg) 208 | 209 | 210 | if __name__ == '__main__': 211 | # evaluate() 212 | new_evaluate_save() 213 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chumpy as ch 3 | from psbody.mesh import Mesh 4 | import torch 5 | import scipy.sparse as sp 6 | from chumpy.utils import row, col 7 | 8 | 9 | def get_face_normals(verts, faces): 10 | num_batch = verts.size(0) 11 | num_faces = faces.size(0) 12 | 13 | # faces by vertices 14 | fbv = torch.index_select(verts, 1, faces.view(-1)).view(num_batch, num_faces, 3, 3) 15 | normals = torch.cross(fbv[:, :, 1] - fbv[:, :, 0], fbv[:, :, 2] - fbv[:, :, 0], dim=2) 16 | normals = normals / (torch.norm(normals, dim=-1, keepdim=True) + 1.e-10) 17 | return normals 18 | 19 | 20 | def get_vertex_normals(verts, faces, ret_face_normals=False): 21 | num_faces = faces.size(0) 22 | num_verts = verts.size(1) 23 | face_normals = get_face_normals(verts, faces) 24 | 25 | FID = 
torch.arange(num_faces).unsqueeze(1).repeat(1, 3).view(-1) 26 | VID = faces.view(-1) 27 | data = torch.ones_like(FID, dtype=torch.float32) 28 | 29 | mat = torch.sparse_coo_tensor( 30 | indices=torch.stack((VID, FID)), 31 | values=data, 32 | size=(num_verts, num_faces) 33 | ) 34 | degree = torch.sparse.sum(mat, dim=1).to_dense() 35 | vertex_normals = torch.stack(( 36 | torch.sparse.mm(mat, face_normals[:, :, 0].t()), 37 | torch.sparse.mm(mat, face_normals[:, :, 1].t()), 38 | torch.sparse.mm(mat, face_normals[:, :, 2].t()), 39 | ), dim=-1) 40 | vertex_normals = vertex_normals.transpose(1, 0) / degree.unsqueeze(0).unsqueeze(-1) 41 | vertex_normals = vertex_normals / (torch.norm(vertex_normals, dim=-1, keepdim=True) + 1.e-10) 42 | 43 | if ret_face_normals: 44 | return vertex_normals, face_normals 45 | else: 46 | return vertex_normals 47 | 48 | 49 | def unpose_garment(smpl, v_free, vert_indices=None): 50 | smpl.v_personal[:] = 0 51 | c = smpl[vert_indices] 52 | E = { 53 | 'v_personal_high': c - v_free 54 | } 55 | ch.minimize(E, x0=[smpl.v_personal], options={'e_3': .00001}) 56 | smpl.pose[:] = 0 57 | smpl.trans[:] = 0 58 | 59 | return Mesh(smpl.r, smpl.f).keep_vertices(vert_indices), np.copy(np.array(smpl.v_personal)) 60 | 61 | 62 | def merge_mesh(vs, fs, vcs): 63 | v_num = 0 64 | new_fs = [fs[0]] 65 | new_vcs = [] 66 | for i in range(len(vs)): 67 | if i >= 1: 68 | v_num += vs[i-1].shape[0] 69 | new_fs.append(fs[i]+v_num) 70 | if vcs is not None: 71 | if vcs[i].ndim == 1: 72 | new_vcs.append(np.tile(np.expand_dims(vcs[i], 0), [vs[i].shape[0], 1])) 73 | else: 74 | new_vcs.append(vcs) 75 | vs = np.concatenate(vs, 0) 76 | new_fs = np.concatenate(new_fs, 0) 77 | if vcs is not None: 78 | vcs = np.concatenate(new_vcs, 0) 79 | return vs, new_fs, vcs 80 | 81 | 82 | def get_edges2face(faces): 83 | from itertools import combinations 84 | from collections import OrderedDict 85 | # Returns a structure that contains the faces corresponding to every edge 86 | edges = OrderedDict() 87 | for iface, f in enumerate(faces): 88 | sorted_face_edges = tuple(combinations(sorted(f), 2)) 89 | for sorted_face_edge in sorted_face_edges: 90 | if sorted_face_edge in edges: 91 | edges[sorted_face_edge].faces.add(iface) 92 | else: 93 | edges[sorted_face_edge] = lambda:0 94 | edges[sorted_face_edge].faces = set([iface]) 95 | return edges 96 | 97 | 98 | def get_boundary_verts(verts, faces, connected_boundaries=True, connected_faces=False): 99 | """ 100 | Given a mesh returns boundary vertices 101 | if connected_boundaries is True it returs a list of lists 102 | OUTPUT: 103 | boundary_verts: list of verts 104 | cnct_bound_verts: list of list containing the N ordered rings of the mesh 105 | """ 106 | MIN_NUM_VERTS_RING = 10 107 | # Ordred dictionary 108 | edge_dict = get_edges2face(faces) 109 | 110 | boundary_verts = [] 111 | boundary_edges = [] 112 | boundary_faces = [] 113 | for edge, (key, val) in enumerate(edge_dict.items()): 114 | if len(val.faces) == 1: 115 | boundary_verts += list(key) 116 | boundary_edges.append(edge) 117 | for face_id in val.faces: 118 | boundary_faces.append(face_id) 119 | boundary_verts = list(set(boundary_verts)) 120 | if not connected_boundaries: 121 | return boundary_verts 122 | n_removed_verts = 0 123 | if connected_boundaries: 124 | edge_mat = np.array(list(edge_dict.keys())) 125 | # Edges on the boundary 126 | edge_mat = edge_mat[np.array(boundary_edges, dtype=np.int64)] 127 | 128 | # check that every vertex is shared by only two edges 129 | for v in boundary_verts: 130 | if np.sum(edge_mat == 
v) != 2: 131 | import ipdb; ipdb.set_trace(); 132 | raise ValueError('The boundary edges are not closed loops!') 133 | 134 | cnct_bound_verts = [] 135 | while len(edge_mat > 0): 136 | # boundary verts, indices of conected boundary verts in order 137 | bverts = [] 138 | orig_vert = edge_mat[0, 0] 139 | bverts.append(orig_vert) 140 | vert = edge_mat[0, 1] 141 | edge = 0 142 | while orig_vert != vert: 143 | bverts.append(vert) 144 | # remove edge from queue 145 | edge_mask = np.ones(edge_mat.shape[0], dtype=bool) 146 | edge_mask[edge] = False 147 | edge_mat = edge_mat[edge_mask] 148 | edge = np.where(np.sum(edge_mat == vert, axis=1) > 0)[0] 149 | tmp = edge_mat[edge] 150 | vert = tmp[tmp != vert][0] 151 | # remove the last edge 152 | edge_mask = np.ones(edge_mat.shape[0], dtype=bool) 153 | edge_mask[edge] = False 154 | edge_mat = edge_mat[edge_mask] 155 | if len(bverts) > MIN_NUM_VERTS_RING: 156 | # add ring to the list 157 | cnct_bound_verts.append(bverts) 158 | else: 159 | n_removed_verts += len(bverts) 160 | count = 0 161 | for ring in cnct_bound_verts: count += len(ring) 162 | assert(len(boundary_verts) - n_removed_verts == count), "Error computing boundary rings !!" 163 | 164 | if connected_faces: 165 | return (boundary_verts, boundary_faces, cnct_bound_verts) 166 | else: 167 | return (boundary_verts, cnct_bound_verts) 168 | 169 | 170 | def loop_subdivider(mesh_v, mesh_f): 171 | """Copied from opendr and modified to work in python3.""" 172 | 173 | IS = [] 174 | JS = [] 175 | data = [] 176 | 177 | vc = get_vert_connectivity(mesh_v, mesh_f) 178 | ve = get_vertices_per_edge(mesh_v, mesh_f) 179 | vo = get_vert_opposites_per_edge(mesh_v, mesh_f) 180 | 181 | if True: 182 | # New values for each vertex 183 | for idx in range(len(mesh_v)): 184 | 185 | # find neighboring vertices 186 | nbrs = np.nonzero(vc[:,idx])[0] 187 | 188 | nn = len(nbrs) 189 | 190 | if nn < 3: 191 | wt = 0. 192 | elif nn == 3: 193 | wt = 3./16. 194 | elif nn > 3: 195 | wt = 3. / (8. * nn) 196 | else: 197 | raise Exception('nn should be 3 or more') 198 | if wt > 0.: 199 | for nbr in nbrs: 200 | IS.append(idx) 201 | JS.append(nbr) 202 | data.append(wt) 203 | 204 | JS.append(idx) 205 | IS.append(idx) 206 | data.append(1. 
- (wt * nn)) 207 | 208 | start = len(mesh_v) 209 | edge_to_midpoint = {} 210 | 211 | if True: 212 | # New values for each edge: 213 | # new edge verts depend on the verts they span 214 | for idx, vs in enumerate(ve): 215 | 216 | vsl = list(vs) 217 | vsl.sort() 218 | IS.append(start + idx) 219 | IS.append(start + idx) 220 | JS.append(vsl[0]) 221 | JS.append(vsl[1]) 222 | data.append(3./8) 223 | data.append(3./8) 224 | 225 | opposites = vo[(vsl[0], vsl[1])] 226 | for opp in opposites: 227 | IS.append(start + idx) 228 | JS.append(opp) 229 | data.append(2./8./len(opposites)) 230 | 231 | edge_to_midpoint[(vsl[0], vsl[1])] = start + idx 232 | edge_to_midpoint[(vsl[1], vsl[0])] = start + idx 233 | 234 | f = [] 235 | 236 | for f_i, old_f in enumerate(mesh_f): 237 | ff = np.concatenate((old_f, old_f)) 238 | 239 | for i in range(3): 240 | v0 = edge_to_midpoint[(ff[i], ff[i+1])] 241 | v1 = ff[i+1] 242 | v2 = edge_to_midpoint[(ff[i+1], ff[i+2])] 243 | f.append(row(np.array([v0,v1,v2]))) 244 | 245 | v0 = edge_to_midpoint[(ff[0], ff[1])] 246 | v1 = edge_to_midpoint[(ff[1], ff[2])] 247 | v2 = edge_to_midpoint[(ff[2], ff[3])] 248 | f.append(row(np.array([v0,v1,v2]))) 249 | 250 | f = np.vstack(f) 251 | 252 | IS = np.array(IS, dtype=np.uint32) 253 | JS = np.array(JS, dtype=np.uint32) 254 | 255 | if True: # for x,y,z coords 256 | IS = np.concatenate((IS*3, IS*3+1, IS*3+2)) 257 | JS = np.concatenate((JS*3, JS*3+1, JS*3+2)) 258 | data = np.concatenate ((data,data,data)) 259 | 260 | ij = np.vstack((IS.flatten(), JS.flatten())) 261 | mtx = sp.csc_matrix((data, ij)) 262 | 263 | return mtx, f 264 | 265 | 266 | def get_vert_connectivity(mesh_v, mesh_f): 267 | """Returns a sparse matrix (of size #verts x #verts) where each nonzero 268 | element indicates a neighborhood relation. For example, if there is a 269 | nonzero element in position (15,12), that means vertex 15 is connected 270 | by an edge to vertex 12. 271 | 272 | Copied from opendr library. 273 | """ 274 | 275 | vpv = sp.csc_matrix((len(mesh_v),len(mesh_v))) 276 | 277 | # for each column in the faces... 278 | for i in range(3): 279 | IS = mesh_f[:,i] 280 | JS = mesh_f[:,(i+1)%3] 281 | data = np.ones(len(IS)) 282 | ij = np.vstack((row(IS.flatten()), row(JS.flatten()))) 283 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape) 284 | vpv = vpv + mtx + mtx.T 285 | 286 | return vpv 287 | 288 | 289 | def get_vertices_per_edge(mesh_v, mesh_f): 290 | """Returns an Ex2 array of adjacencies between vertices, where 291 | each element in the array is a vertex index. Each edge is included 292 | only once. If output of get_faces_per_edge is provided, this is used to 293 | avoid call to get_vert_connectivity() 294 | 295 | Copied from opendr library. 296 | """ 297 | 298 | vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f)) 299 | result = np.hstack((col(vc.row), col(vc.col))) 300 | result = result[result[:,0] < result[:,1]] # for uniqueness 301 | 302 | return result 303 | 304 | 305 | def get_faces_per_edge(mesh_v, mesh_f, verts_per_edge=None): 306 | """Copied from opendr library.""" 307 | if verts_per_edge is None: 308 | verts_per_edge = get_vertices_per_edge(mesh_v, mesh_f) 309 | 310 | v2f = {i: set([]) for i in range(len(mesh_v))} 311 | # TODO: cythonize? 
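# v2f maps every vertex to the set of faces incident on it; the faces
# adjacent to edge (a, b) are then the (at most two) faces shared by both
# endpoint sets. Boundary edges touch a single face, so their second
# column in fpe stays -1.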
312 | for idx, f in enumerate(mesh_f): 313 | v2f[f[0]].add(idx) 314 | v2f[f[1]].add(idx) 315 | v2f[f[2]].add(idx) 316 | 317 | fpe = -np.ones_like(verts_per_edge) 318 | for idx, edge in enumerate(verts_per_edge): 319 | faces = v2f[edge[0]].intersection(v2f[edge[1]]) 320 | faces = list(faces)[:2] 321 | for i, f in enumerate(faces): 322 | fpe[idx,i] = f 323 | 324 | return fpe 325 | 326 | 327 | def get_vert_opposites_per_edge(mesh_v, mesh_f): 328 | """Returns a dictionary from vertidx-pairs to opposites. 329 | For example, a key consist of [4,5)] meaning the edge between 330 | vertices 4 and 5, and a value might be [10,11] which are the indices 331 | of the vertices opposing this edge. 332 | 333 | Copied from opendr library. 334 | """ 335 | result = {} 336 | for f in mesh_f: 337 | for i in range(3): 338 | key = [f[i], f[(i+1)%3]] 339 | key.sort() 340 | key = tuple(key) 341 | val = f[(i+2)%3] 342 | 343 | if key in result: 344 | result[key].append(val) 345 | else: 346 | result[key] = [val] 347 | return result 348 | 349 | 350 | if __name__ == '__main__': 351 | ms = Mesh(filename="/BS/cpatel/work/data/learn_anim/test_py3/t-shirt_male/0000/gt_0.ply") 352 | b = get_boundary_verts(ms.v, ms.f) -------------------------------------------------------------------------------- /refu_tailornet/tnutils/interpenetration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import scipy.sparse as sp 4 | from scipy.sparse import vstack, csr_matrix 5 | from scipy.sparse.linalg import spsolve 6 | from psbody.mesh import Mesh 7 | 8 | from psbody.mesh import Mesh 9 | from psbody.mesh.geometry.vert_normals import VertNormals 10 | from psbody.mesh.geometry.tri_normals import TriNormals 11 | from psbody.mesh.search import AabbTree 12 | from tnutils.diffusion_smoothing import numpy_laplacian_uniform as laplacian 13 | 14 | 15 | def get_nearest_points_and_normals(vert, base_verts, base_faces): 16 | """For each vertex of `vert`, find nearest surface points on 17 | base mesh (`base_verts`, `base_faces`). 
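Returns (nearest_point, nearest_normals), each of shape (len(vert), 3).
The normal is taken from the closest triangle's face normal, the closest
vertex's normal, or the average of the two endpoint normals of the closest
edge, depending on which region of the triangle the nearest point falls in.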
18 | """ 19 | fn = TriNormals(v=base_verts, f=base_faces).reshape((-1, 3)) 20 | vn = VertNormals(v=base_verts, f=base_faces).reshape((-1, 3)) 21 | 22 | tree = AabbTree(Mesh(v=base_verts, f=base_faces)) 23 | nearest_tri, nearest_part, nearest_point = tree.nearest(vert, nearest_part=True) 24 | nearest_tri = nearest_tri.ravel().astype(np.long) 25 | nearest_part = nearest_part.ravel().astype(np.long) 26 | 27 | nearest_normals = np.zeros_like(vert) 28 | 29 | # nearest_part tells you whether the closest point in triangle abc is 30 | # in the interior (0), on an edge (ab:1,bc:2,ca:3), or a vertex (a:4,b:5,c:6) 31 | cl_tri_idxs = np.nonzero(nearest_part == 0)[0].astype(np.int) 32 | cl_vrt_idxs = np.nonzero(nearest_part > 3)[0].astype(np.int) 33 | cl_edg_idxs = np.nonzero((nearest_part <= 3) & (nearest_part > 0))[0].astype(np.int) 34 | 35 | nt = nearest_tri[cl_tri_idxs] 36 | nearest_normals[cl_tri_idxs] = fn[nt] 37 | 38 | nt = nearest_tri[cl_vrt_idxs] 39 | npp = nearest_part[cl_vrt_idxs] - 4 40 | nearest_normals[cl_vrt_idxs] = vn[base_faces[nt, npp]] 41 | 42 | nt = nearest_tri[cl_edg_idxs] 43 | npp = nearest_part[cl_edg_idxs] - 1 44 | nearest_normals[cl_edg_idxs] += vn[base_faces[nt, npp]] 45 | npp = np.mod(nearest_part[cl_edg_idxs], 3) 46 | nearest_normals[cl_edg_idxs] += vn[base_faces[nt, npp]] 47 | 48 | nearest_normals = nearest_normals / (np.linalg.norm(nearest_normals, axis=-1, keepdims=True) + 1.e-10) 49 | 50 | return nearest_point, nearest_normals 51 | 52 | 53 | def remove_interpenetration_fast(mesh, base, L=None): 54 | """Deforms `mesh` to remove its interpenetration from `base`. 55 | This is posed as least square optimization problem which can be solved 56 | faster with sparse solver. 57 | """ 58 | 59 | eps = 0.001 60 | ww = 2.0 61 | nverts = mesh.v.shape[0] 62 | 63 | if L is None: 64 | L = laplacian(mesh.v, mesh.f) 65 | 66 | nearest_points, nearest_normals = get_nearest_points_and_normals(mesh.v, base.v, base.f) 67 | direction = np.sign( np.sum((mesh.v - nearest_points) * nearest_normals, axis=-1) ) 68 | 69 | indices = np.where(direction < 0)[0] 70 | 71 | pentgt_points = nearest_points[indices] - mesh.v[indices] 72 | pentgt_points = nearest_points[indices] \ 73 | + eps * pentgt_points / np.expand_dims(0.0001 + np.linalg.norm(pentgt_points, axis=1), 1) 74 | tgt_points = mesh.v.copy() 75 | tgt_points[indices] = ww * pentgt_points 76 | 77 | rc = np.arange(nverts) 78 | data = np.ones(nverts) 79 | data[indices] *= ww 80 | I = csr_matrix((data, (rc, rc)), shape=(nverts, nverts)) 81 | 82 | A = vstack([L, I]) 83 | b = np.vstack(( 84 | L.dot(mesh.v), 85 | tgt_points 86 | )) 87 | 88 | res = spsolve(A.T.dot(A), A.T.dot(b)) 89 | mres = Mesh(v=res, f=mesh.f) 90 | return mres 91 | 92 | 93 | if __name__ == '__main__': 94 | import os 95 | ROOT = "/BS/cpatel/work/data/learn_anim/mixture_exp31/000_0/smooth_TShirtNoCoat/0990/" 96 | body = Mesh(filename=os.path.join(ROOT, "body_160.ply")) 97 | mesh = Mesh(filename=os.path.join(ROOT, "pred_160.ply")) 98 | 99 | mesh1 = remove_interpenetration_fast(mesh, body) 100 | mesh1.write_ply("/BS/cpatel/work/proccessed.ply") 101 | mesh.write_ply("/BS/cpatel/work/orig.ply") 102 | body.write_ply("/BS/cpatel/work/body.ply") 103 | 104 | # from psbody.mesh import MeshViewers 105 | # mvs = MeshViewers((1, 2)) 106 | # mesh1.set_vertex_colors_from_weights(np.linalg.norm(mesh.v - mesh1.v, axis=1)) 107 | # mesh.set_vertex_colors_from_weights(np.linalg.norm(mesh.v - mesh1.v, axis=1)) 108 | # # mesh1.set_vertex_colors_from_weights(np.zeros(mesh.v.shape[0])) 109 | # # 
mesh.set_vertex_colors_from_weights(np.zeros(mesh.v.shape[0])) 110 | # mvs[0][0].set_static_meshes([mesh, body]) 111 | # mvs[0][1].set_static_meshes([mesh1, body]) 112 | # mesh1.show() 113 | 114 | import ipdb 115 | ipdb.set_trace() -------------------------------------------------------------------------------- /refu_tailornet/tnutils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as opj 3 | import shutil 4 | from datetime import datetime 5 | import numpy as np 6 | # from plyfile import PlyData 7 | import struct 8 | import global_var 9 | 10 | def write_obj(verts, faces, path, color_idx=None): 11 | faces = faces + 1 12 | with open(path, 'w') as f: 13 | for vidx, v in enumerate(verts): 14 | if color_idx is not None and color_idx[vidx]: 15 | f.write("v {:.5f} {:.5f} {:.5f} 1 0 0\n".format(v[0], v[1], v[2])) 16 | else: 17 | f.write("v {:.5f} {:.5f} {:.5f}\n".format(v[0], v[1], v[2])) 18 | for fa in faces: 19 | f.write("f {:d} {:d} {:d}\n".format(fa[0], fa[1], fa[2])) 20 | 21 | def read_obj(path): 22 | with open(path, 'r') as obj: 23 | datas = obj.read() 24 | 25 | lines = datas.splitlines() 26 | 27 | vertices = [] 28 | faces = [] 29 | 30 | for line in lines: 31 | elem = line.split() 32 | if elem: 33 | if elem[0] == 'v': 34 | vertices.append([float(elem[1]), float(elem[2]), float(elem[3])]) 35 | elif elem[0] == 'f': 36 | face = [] 37 | for i in range(1, len(elem)): 38 | face.append(int(elem[i].split('/')[0])) 39 | faces.append(face) 40 | else: 41 | pass 42 | 43 | vertices = np.array(vertices) 44 | faces = np.array(faces)-1 45 | 46 | return vertices, faces -------------------------------------------------------------------------------- /refu_tailornet/tnutils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from datetime import datetime 4 | import global_var 5 | 6 | 7 | class BaseLogger(object): 8 | def __init__(self, log_name, fields): 9 | self.fpath = os.path.join(global_var.LOG_DIR, log_name) 10 | self.fields = fields 11 | 12 | def add_item(self, **kwargs): 13 | kwargs = kwargs.copy() 14 | if 'time' not in kwargs: 15 | kwargs['time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 16 | for k in list(kwargs): 17 | if k not in self.fields: 18 | kwargs.pop(k) 19 | else: 20 | kwargs[k] = str(kwargs[k]) 21 | if os.path.exists(self.fpath): 22 | with open(self.fpath, 'a') as f: 23 | writer = csv.DictWriter(f, fieldnames=self.fields) 24 | writer.writerow(kwargs) 25 | else: 26 | with open(self.fpath, 'w') as f: 27 | writer = csv.DictWriter(f, fieldnames=self.fields) 28 | writer.writeheader() 29 | writer.writerow(kwargs) 30 | 31 | 32 | class TailorNetLogger(BaseLogger): 33 | def __init__(self, log_name='tailornet.csv'): 34 | super(TailorNetLogger, self).__init__(log_name, 35 | ['garment_class', 'gender', 'smooth_level', 'best_error', 'best_epoch', 'time', 'batch_size', 36 | 'lr', 'weight_decay', 'note', 'log_name', 'shape_style', 'checkpoint']) 37 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/renderer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | if 'SSH_CONNECTION' in os.environ: 7 | from tnutils.renderer_software import Renderer 8 | print("Warning: You're logged via SSH. 
Thus only software renderer is available, which is much slower") 9 | else: 10 | import pyrender 11 | import trimesh 12 | import numpy as np 13 | 14 | class Renderer(object): 15 | """ 16 | This is a wrapper of pyrender 17 | see documentation of __call__ for detailed usage 18 | """ 19 | 20 | def __init__(self, img_size, bg_color=None): 21 | if bg_color is None: 22 | bg_color = np.array([0.1, 0.1, 0.1, 1.]) 23 | self.scene = pyrender.Scene(bg_color=bg_color) 24 | self.focal_len = 5. 25 | camera = pyrender.PerspectiveCamera(yfov=np.tan(1 / self.focal_len) * 2, aspectRatio=1.0) 26 | camera_pose = np.eye(4, dtype=np.float32) 27 | self.scene.add(camera, pose=camera_pose) 28 | light = pyrender.DirectionalLight(color=np.ones(3), intensity=10.0, 29 | ) 30 | self.scene.add(light, pose=camera_pose) 31 | if not hasattr(img_size, '__iter__'): 32 | img_size = [img_size, img_size] 33 | self.r = pyrender.OffscreenRenderer(*img_size) 34 | 35 | def __call__(self, vs, fs, vcs=None, trans=(1., 0., 0.), euler=(0., 0., 0.), center=True): 36 | """ 37 | This function will put the center of objects at origin point. 38 | vs, fs, vcs: 39 | vertices, faces, colors of vertices. 40 | They are numpy array or list of numpy array (multiple meshes) 41 | trans: 42 | It is a 3 element tuple. The first is scale factor. The last two is x,y translation 43 | euler: 44 | euler angle of objects (degree not radian). It follows the order of YXZ, 45 | which means Y-axis, X-axis, Z-axis are yaw, pitch, roll respectively. 46 | """ 47 | if isinstance(vs, np.ndarray): 48 | vs = [vs] 49 | fs = [fs] 50 | vcs = [vcs] 51 | ms = [] 52 | mnodes = [] 53 | vss = np.concatenate(vs, 0) 54 | cen = (np.max(vss, 0, keepdims=True) + np.min(vss, 0, keepdims=True)) / 2. 55 | rotmat = self.euler2rotmat(euler) 56 | for v, f, vs in zip(vs, fs, vcs): 57 | trans_v = v - cen if center else v 58 | trans_v = np.einsum('pq,nq->np', rotmat, trans_v) 59 | trans_v[:, :2] += np.expand_dims(np.array(trans[1:]), 0) 60 | trans_v[:, 2] -= self.focal_len / trans[0] 61 | ms.append(trimesh.Trimesh(vertices=trans_v, faces=f, vertex_colors=vs)) 62 | for m in ms: 63 | mnode = self.scene.add(pyrender.Mesh.from_trimesh(m)) 64 | mnodes.append(mnode) 65 | img, depth = self.r.render(self.scene) 66 | for mnode in mnodes: 67 | self.scene.remove_node(mnode) 68 | return img 69 | 70 | @staticmethod 71 | def euler2rotmat(euler): 72 | euler = np.array(euler)*np.pi/180. 
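# Compose R = R_y(yaw) @ R_x(pitch) @ R_z(roll): indices 0/1/2 of `euler`
# are the Y/X/Z angles, matching the YXZ order documented in __call__.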
73 | se, ce = np.sin(euler), np.cos(euler) 74 | s1, c1 = se[0], ce[0] 75 | s2, c2 = se[1], ce[1] 76 | s3, c3 = se[2], ce[2] 77 | return np.array([[c1*c3+s1*s2*s3, c3*s1*s2-c1*s3, c2*s1], 78 | [c2*s3, c2*c3, -s2], 79 | [c1*s2*s3-c3*s1, c1*c3*s2+s1*s3, c1*c2]]) 80 | 81 | if __name__ == '__main__': 82 | from psbody.mesh import Mesh 83 | import cv2 84 | import numpy as np 85 | import global_var 86 | 87 | m1 = Mesh(filename='/home/zliao/cloth-anim/work/data/md/cloth_test/121611457711203/apose_avatar.obj') 88 | m2 = Mesh(filename='/home/zliao/cloth-anim/work/data/md/cloth_test/121611457711203/result_Pants.obj') 89 | vs = [m1.v, m2.v] 90 | fs = [m1.f, m2.f] 91 | colors = [np.array([0.6, 0.6, 0.9]), np.array([0.8, 0.5, 0.3])] 92 | renderer = Renderer(800) 93 | img = renderer(vs, fs, colors) 94 | trans_img = renderer(vs, fs, colors, trans=(0.8, 0.3, 0.3)) 95 | euler_img = renderer(vs, fs, colors, euler=(45, 20, 0)) 96 | img = np.concatenate([img, trans_img, euler_img], 1) 97 | cv2.imwrite(os.path.join(global_var.DATA_DIR, 'img.png'), img) 98 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/rotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import pickle 5 | import global_var 6 | 7 | 8 | def flip_theta(theta, batch=False): 9 | """ 10 | flip SMPL theta along y-z plane 11 | if batch is True, theta shape is Nx72, otherwise 72 12 | """ 13 | exg_idx = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22] 14 | if batch: 15 | new_theta = np.reshape(theta, [-1, 24, 3]) 16 | new_theta = new_theta[:, exg_idx] 17 | new_theta[:, :, 1:3] *= -1 18 | else: 19 | new_theta = np.reshape(theta, [24, 3]) 20 | new_theta = new_theta[exg_idx] 21 | new_theta[:, 1:3] *= -1 22 | new_theta = new_theta.reshape(theta.shape) 23 | return new_theta 24 | 25 | 26 | def get_Apose(): 27 | """Return thetas for A-pose.""" 28 | with open(os.path.join(global_var.DATA_DIR, 'apose.pkl'), 'rb') as f: 29 | APOSE = np.array(pickle.load(f, encoding='latin1')['pose']).astype(np.float32) 30 | flip_pose = flip_theta(APOSE) 31 | APOSE[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15]] = 0 32 | APOSE[[14, 17, 19, 21, 23]] = flip_pose[[14, 17, 19, 21, 23]] 33 | APOSE = APOSE.reshape([72]) 34 | return APOSE 35 | 36 | 37 | def normalize_y_rotation(raw_theta): 38 | """Rotate along y axis so that root rotation can always face the camera. 39 | Theta should be a [3] or [72] numpy array. 
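For a full (72,) pose only the first three (global) entries are modified,
e.g.:

    theta_front = normalize_y_rotation(theta)  # root now faces the camera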
40 | """ 41 | only_global = True 42 | if raw_theta.shape == (72,): 43 | theta = raw_theta[:3] 44 | only_global = False 45 | else: 46 | theta = raw_theta[:] 47 | raw_rot = cv2.Rodrigues(theta)[0] 48 | rot_z = raw_rot[:, 2] 49 | # we should rotate along y axis counter-clockwise for t rads to make the object face the camera 50 | if rot_z[2] == 0: 51 | t = (rot_z[0] / np.abs(rot_z[0])) * np.pi / 2 52 | elif rot_z[2] > 0: 53 | t = np.arctan(rot_z[0]/rot_z[2]) 54 | else: 55 | t = np.arctan(rot_z[0]/rot_z[2]) + np.pi 56 | cost, sint = np.cos(t), np.sin(t) 57 | norm_rot = np.array([[cost, 0, -sint],[0, 1, 0],[sint, 0, cost]]) 58 | final_rot = np.matmul(norm_rot, raw_rot) 59 | final_theta = cv2.Rodrigues(final_rot)[0][:, 0] 60 | if not only_global: 61 | return np.concatenate([final_theta, raw_theta[3:]], 0) 62 | else: 63 | return final_theta 64 | -------------------------------------------------------------------------------- /refu_tailornet/tnutils/sio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import json 5 | from datetime import datetime 6 | from os.path import join as opj 7 | import cv2 8 | import struct 9 | import numpy as np 10 | import global_var 11 | 12 | 13 | def backup_file(src, dst): 14 | if os.path.exists(opj(src, 'nobackup')): 15 | return 16 | if len([k for k in os.listdir(src) if k.endswith('.py') or k.endswith('.sh')]) == 0: 17 | return 18 | if not os.path.isdir(dst): 19 | os.makedirs(dst) 20 | all_files = os.listdir(src) 21 | for fname in all_files: 22 | fname_full = opj(src, fname) 23 | fname_dst = opj(dst, fname) 24 | if os.path.isdir(fname_full): 25 | backup_file(fname_full, fname_dst) 26 | elif fname.endswith('.py') or fname.endswith('.sh'): 27 | shutil.copy(fname_full, fname_dst) 28 | 29 | 30 | def prepare_log_dir(log_name): 31 | if len(log_name) == 0: 32 | log_name = datetime.now().strftime("%b%d_%H%M%S") 33 | 34 | log_dir = os.path.join(global_var.LOG_DIR, log_name) 35 | if not os.path.exists(log_dir): 36 | print('making %s' % log_dir) 37 | os.makedirs(log_dir) 38 | else: 39 | warning_info = 'Warning: log_dir({}) already exists\n' \ 40 | 'Are you sure continuing? 
y/n'.format(log_dir) 41 | if sys.version_info.major == 3: 42 | a = input(warning_info) 43 | else: 44 | a = input(warning_info) 45 | if a != 'y': 46 | exit() 47 | 48 | backup_dir = opj(log_dir, 'code') 49 | if not os.path.exists(backup_dir): 50 | os.makedirs(backup_dir) 51 | backup_file(global_var.ROOT_DIR, backup_dir) 52 | print("Backup code in {}".format(backup_dir)) 53 | return log_dir 54 | 55 | 56 | def save_params(log_dir, params, save_name="params"): 57 | same_num = 1 58 | while os.path.exists(save_name): 59 | save_name = save_name + "({})".format(same_num) 60 | with open(os.path.join(log_dir, save_name+".json"), 'w') as f: 61 | json.dump(params, f) 62 | 63 | 64 | def save_pc2(vertices, path): 65 | # vertices: (N, V, 3), N is the number of frames, V is the number of vertices 66 | # path: a .pc2 file 67 | nframes, nverts, _ = vertices.shape 68 | with open(path, 'wb') as f: 69 | headerStr = struct.pack('<12siiffi', b'POINTCACHE2\0', 70 | 1, nverts, 1, 1, nframes) 71 | f.write(headerStr) 72 | v = vertices.reshape(-1, 3).astype(np.float32) 73 | for v_ in v: 74 | f.write(struct.pack(' params['batch_size']: 94 | drop_last = True 95 | else: 96 | drop_last = False 97 | dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=0, shuffle=shuffle, 98 | drop_last=drop_last) 99 | return dataset, dataloader 100 | 101 | def build_model(self): 102 | params = self.params 103 | model = getattr(networks, self.model_name)( 104 | input_size=72+10+4, output_size=self.vert_indices.shape[0] * 3, 105 | num_layers=params['num_layers'], 106 | hidden_size=params['hidden_size']) 107 | return model 108 | 109 | def get_logger(self): 110 | return TailorNetLogger() 111 | 112 | def one_step(self, inputs): 113 | """One forward pass. 114 | Takes `inputs` tuple. Returns output(s) and loss. 115 | """ 116 | gt_verts, thetas, betas, gammas, _ = inputs 117 | 118 | thetas, betas, gammas = ops.mask_inputs(thetas, betas, gammas, self.garment_class) 119 | gt_verts = gt_verts.to(device) 120 | thetas = thetas.to(device) 121 | betas = betas.to(device) 122 | gammas = gammas.to(device) 123 | pred_verts = self.model( 124 | torch.cat((thetas, betas, gammas), dim=1)).view(gt_verts.shape) 125 | 126 | # L1 loss 127 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 128 | return pred_verts, data_loss 129 | 130 | def train(self, epoch): 131 | """Train for one epoch.""" 132 | epoch_loss = AverageMeter() 133 | self.model.train() 134 | for i, inputs in enumerate(self.train_loader): 135 | self.optimizer.zero_grad() 136 | outputs, loss = self.one_step(inputs) 137 | loss.backward() 138 | self.optimizer.step() 139 | 140 | self.logger.add_scalar("train/loss", loss.item(), self.iter_nums) 141 | print("Iter {}, loss: {:.8f}".format(self.iter_nums, loss.item())) 142 | epoch_loss.update(loss, inputs[0].shape[0]) 143 | self.iter_nums += 1 144 | 145 | self.logger.add_scalar("train_epoch/loss", epoch_loss.avg, epoch) 146 | 147 | def update_metrics(self, metrics, inputs, outputs): 148 | """Update metrics from inputs and predicted outputs.""" 149 | gt_verts = inputs[0] 150 | pred_verts = outputs 151 | dist = ops.verts_dist(gt_verts, pred_verts.cpu()) * 1000. 
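# ops.verts_dist gives the mean per-vertex Euclidean distance in meters;
# the factor 1000 converts it to millimeters (cf. the 'val_dist' meter
# created in validate()).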
152 | metrics['val_dist'].update(dist.item(), gt_verts.shape[0]) 153 | 154 | def visualize_batch(self, inputs, outputs, epoch): 155 | """Save visualizations of some samples of the batch.""" 156 | gt_verts, thetas, betas, gammas, idxs = inputs 157 | pred_verts = outputs 158 | idxs = idxs.numpy() 159 | for lidx, idx in enumerate(idxs): 160 | if idx % self.vis_freq != 0: 161 | continue 162 | theta = thetas[lidx].cpu().numpy() 163 | beta = betas[lidx].cpu().numpy() 164 | pred_vert = pred_verts[lidx].cpu().numpy() 165 | gt_vert = gt_verts[lidx].cpu().numpy() 166 | 167 | body_m, pred_m = self.smpl.run(theta=theta, garment_d=pred_vert, beta=beta, 168 | garment_class=self.garment_class) 169 | _, gt_m = self.smpl.run(theta=theta, garment_d=gt_vert, beta=beta, 170 | garment_class=self.garment_class) 171 | 172 | save_dir = os.path.join(self.log_dir, "{:04d}".format(epoch)) 173 | pred_m.write_ply(os.path.join(save_dir, "pred_{}.ply".format(idx))) 174 | gt_m.write_ply(os.path.join(save_dir, "gt_{}.ply".format(idx))) 175 | body_m.write_ply(os.path.join(save_dir, "body_{}.ply".format(idx))) 176 | 177 | def validate(self, epoch): 178 | """Evaluate on test dataset.""" 179 | val_loss = AverageMeter() 180 | metrics = { 181 | 'val_dist': AverageMeter(), # per vertex distance in mm 182 | } 183 | self.model.eval() 184 | with torch.no_grad(): 185 | for i, inputs in enumerate(self.test_loader): 186 | outputs, loss = self.one_step(inputs) 187 | val_loss.update(loss.item(), inputs[0].shape[0]) 188 | 189 | self.update_metrics(metrics, inputs, outputs) 190 | self.visualize_batch(inputs, outputs, epoch) 191 | 192 | val_dist_avg = metrics['val_dist'].avg 193 | self.logger.add_scalar("val/loss", val_loss.avg, epoch) 194 | self.logger.add_scalar("val/dist", val_dist_avg, epoch) 195 | print("VALIDATION") 196 | print("Epoch {}, loss: {:.4f}, dist: {:.4f} mm".format( 197 | epoch, val_loss.avg, val_dist_avg)) 198 | 199 | if val_dist_avg < self.best_error: 200 | self.best_error = val_dist_avg 201 | self.best_epoch = epoch 202 | self.save_ckpt_best() 203 | with open(os.path.join(self.log_dir, 'best_epoch'), 'w') as f: 204 | f.write("{:04d}".format(epoch)) 205 | 206 | def write_log(self): 207 | """Log training info once training is done.""" 208 | if self.best_epoch >= 0: 209 | self.csv_logger.add_item( 210 | best_error=self.best_error, best_epoch=self.best_epoch, **self.params) 211 | 212 | def save_ckpt(self, epoch): 213 | """Save checkpoint in given epoch's directory.""" 214 | save_dir = os.path.join(self.log_dir, "{:04d}".format(epoch)) 215 | if not os.path.exists(save_dir): 216 | os.makedirs(save_dir) 217 | torch.save(self.model.state_dict(), os.path.join(save_dir, 'lin.pth.tar')) 218 | torch.save(self.optimizer.state_dict(), os.path.join(save_dir, "optimizer.pth.tar")) 219 | 220 | def save_ckpt_best(self): 221 | """Save checkpoint in log directory.""" 222 | save_dir = self.log_dir 223 | if not os.path.exists(save_dir): 224 | os.makedirs(save_dir) 225 | torch.save(self.model.state_dict(), os.path.join(save_dir, 'lin.pth.tar')) 226 | torch.save(self.optimizer.state_dict(), os.path.join(save_dir, "optimizer.pth.tar")) 227 | 228 | 229 | class Runner(object): 230 | """A helper class to load a trained model.""" 231 | def __init__(self, ckpt, params): 232 | model_name = params['model_name'] 233 | garment_class = params['garment_class'] 234 | 235 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 236 | class_info = pickle.load(f) 237 | output_size = len(class_info[garment_class]['vert_indices']) * 
3 238 | 239 | self.model = getattr(networks, model_name)( 240 | input_size=72+10+4, output_size=output_size, 241 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 242 | num_layers=params['num_layers'] if 'num_layers' in params else 3 243 | ) 244 | self.garment_class = params['garment_class'] 245 | 246 | print("loading {}".format(ckpt)) 247 | if torch.cuda.is_available(): 248 | self.model.cuda() 249 | state_dict = torch.load(ckpt) 250 | else: 251 | state_dict = torch.load(ckpt,map_location='cpu') 252 | self.model.load_state_dict(state_dict) 253 | self.model.eval() 254 | 255 | def forward(self, thetas, betas, gammas): 256 | thetas, betas, gammas = ops.mask_inputs( 257 | thetas, betas, gammas, garment_class=self.garment_class) 258 | pred_verts = self.model(torch.cat((thetas, betas, gammas), dim=1)) 259 | return pred_verts.view(thetas.shape[0], -1, 3) 260 | 261 | def cuda(self): 262 | self.model.cuda() 263 | 264 | def to(self, device): 265 | self.model.to(device) 266 | 267 | class Model(torch.nn.Module): 268 | """A helper class to load a trained model.""" 269 | def __init__(self, params, ckpt = None): 270 | super(Model, self).__init__() 271 | model_name = params['model_name'] 272 | garment_class = params['garment_class'] 273 | 274 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 275 | class_info = pickle.load(f) 276 | output_size = len(class_info[garment_class]['vert_indices']) * 3 277 | 278 | self.model = getattr(networks, model_name)( 279 | input_size=72+10+4, output_size=output_size, 280 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 281 | num_layers=params['num_layers'] if 'num_layers' in params else 3 282 | ) 283 | self.garment_class = params['garment_class'] 284 | 285 | if ckpt is not None: 286 | print("loading {}".format(ckpt)) 287 | if torch.cuda.is_available(): 288 | self.model.cuda() 289 | state_dict = torch.load(ckpt) 290 | else: 291 | state_dict = torch.load(ckpt,map_location='cpu') 292 | self.model.load_state_dict(state_dict) 293 | 294 | def forward(self, thetas, betas, gammas): 295 | thetas, betas, gammas = ops.mask_inputs( 296 | thetas, betas, gammas, garment_class=self.garment_class) 297 | pred_verts = self.model(torch.cat((thetas, betas, gammas), dim=1)) 298 | return pred_verts.view(thetas.shape[0], -1, 3) 299 | 300 | def get_best_runner(log_dir, epoch_num=None): 301 | """Returns a trained model runner given the log_dir.""" 302 | ckpt_dir = log_dir 303 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 304 | params = json.load(jf) 305 | 306 | # if epoch_num is not given then pick up the best epoch 307 | if epoch_num is None: 308 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 309 | else: 310 | # with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 311 | # best_epoch = int(f.read().strip()) 312 | best_epoch = epoch_num 313 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 314 | 315 | runner = Runner(ckpt_path, params) 316 | return runner 317 | 318 | def get_best_model(log_dir, epoch_num=None): 319 | ckpt_dir = log_dir 320 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 321 | params = json.load(jf) 322 | 323 | # if epoch_num is not given then pick up the best epoch 324 | if epoch_num is None: 325 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 326 | else: 327 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(epoch_num), 'lin.pth.tar') 328 | 329 | model = Model(params, ckpt_path) 330 | return model 331 | 332 | 333 | 334 | def 
parse_argument(): 335 | parser = argparse.ArgumentParser() 336 | parser.add_argument('--local_config', default='') 337 | 338 | parser.add_argument('--garment_class', default="shirt") 339 | parser.add_argument('--gender', default="male") 340 | parser.add_argument('--shape_style', default="") 341 | 342 | # some training hyper parameters 343 | parser.add_argument('--vis_freq', default=512, type=int) 344 | parser.add_argument('--batch_size', default=32, type=int) 345 | parser.add_argument('--lr', default=1e-4, type=float) 346 | parser.add_argument('--weight_decay', default=1e-6, type=float) 347 | parser.add_argument('--max_epoch', default=100, type=int) 348 | parser.add_argument('--start_epoch', default=0, type=int) 349 | parser.add_argument('--checkpoint', default="") 350 | 351 | # name under which experiment will be logged 352 | parser.add_argument('--log_name', default="tn_baseline") 353 | 354 | # smooth_level=0 will train TailorNet MLP baseline 355 | parser.add_argument('--smooth_level', default=0, type=int) 356 | 357 | # model specification. 358 | parser.add_argument('--model_name', default="FullyConnected") 359 | parser.add_argument('--num_layers', default=3) 360 | parser.add_argument('--hidden_size', default=1048) 361 | 362 | # small experiment description 363 | parser.add_argument('--note', default="MLP Baseline") 364 | 365 | args = parser.parse_args() 366 | params = args.__dict__ 367 | 368 | # load params from local config if provided 369 | if os.path.exists(params['local_config']): 370 | print("loading config from {}".format(params['local_config'])) 371 | with open(params['local_config']) as f: 372 | lc = json.load(f) 373 | for k, v in lc.items(): 374 | params[k] = v 375 | return params 376 | 377 | 378 | def main(): 379 | params = parse_argument() 380 | 381 | print("start training {}".format(params['garment_class'])) 382 | trainer = Trainer(params) 383 | 384 | for i in range(params['start_epoch'], params['max_epoch']): 385 | print("epoch: {}".format(i)) 386 | trainer.train(i) 387 | trainer.validate(i) 388 | # trainer.save_ckpt(i) 389 | 390 | trainer.write_log() 391 | print("safely quit!") 392 | 393 | 394 | if __name__ == '__main__': 395 | main() 396 | -------------------------------------------------------------------------------- /refu_tailornet/trainer/base_trainer_col_info.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorboardX 3 | import argparse 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | import json 8 | import pickle 9 | 10 | import sys 11 | 12 | currentdir = os.path.dirname(os.path.realpath(__file__)) 13 | parentdir = os.path.dirname(currentdir) 14 | sys.path.append(parentdir) 15 | 16 | from models import networks 17 | from models import ops 18 | from models.smpl4garment import SMPL4Garment 19 | from dataset.static_pose_shape_final import MultiStyleShape 20 | from tnutils.eval import AverageMeter 21 | from tnutils.logger import TailorNetLogger 22 | from tnutils import sio 23 | import global_var 24 | 25 | device = torch.device("cuda:0") 26 | # device = torch.device("cpu") 27 | 28 | 29 | class Trainer(object): 30 | """Implements trainer class for TailorNet MLP baseline. 31 | It is also a base class for TailorNet LF, HF and SS2G trainers. 
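    Unlike base_trainer.Trainer, this variant constructs its datasets with
    collision_info=True (see load_dataset below), so each sample additionally
    carries the precomputed body-garment collision information used by the
    collision-aware (ReFU) training stages.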
32 | """ 33 | 34 | def __init__(self, params): 35 | self.params = params 36 | self.gender = params['gender'] 37 | self.garment_class = params['garment_class'] 38 | 39 | self.bs = params['batch_size'] 40 | self.vis_freq = params['vis_freq'] 41 | self.model_name = params['model_name'] 42 | self.note = params['note'] 43 | 44 | # log and backup 45 | log_name = os.path.join(params['log_name'], 46 | '{}_{}'.format(self.garment_class, self.gender)) 47 | if params['shape_style'] != '': 48 | log_name = os.path.join(log_name, params['shape_style']) 49 | self.log_dir = sio.prepare_log_dir(log_name) 50 | sio.save_params(self.log_dir, params, save_name='params') 51 | 52 | self.iter_nums = 0 if 'iter_nums' not in params else params['iter_nums'] 53 | 54 | # smpl for garment 55 | self.smpl = SMPL4Garment(gender=self.gender) 56 | 57 | # garment specific things 58 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 59 | class_info = pickle.load(f) 60 | self.body_f_np = self.smpl.smpl_base.f.astype(np.long) 61 | self.garment_f_np = class_info[self.garment_class]['f'] 62 | self.garment_f_torch = torch.tensor(self.garment_f_np.astype(np.long)).long().to(device) 63 | self.vert_indices = np.array( 64 | class_info[self.garment_class]['vert_indices']) 65 | 66 | # get dataset and dataloader 67 | self.train_dataset, self.train_loader = self.load_dataset('train', collision_info=True) 68 | self.test_dataset, self.test_loader = self.load_dataset('test', collision_info=True) 69 | print("Train dataset size", len(self.train_dataset)) 70 | print("Test dataset size", len(self.test_dataset)) 71 | 72 | # model and optimizer 73 | self.model = self.build_model() 74 | self.model.to(device) 75 | self.optimizer = torch.optim.Adam( 76 | self.model.parameters(), lr=params['lr'], weight_decay=params['weight_decay']) 77 | 78 | # continue training from checkpoint if provided 79 | if params['checkpoint']: 80 | ckpt_path = params['checkpoint'] 81 | print('loading ckpt from {}'.format(ckpt_path)) 82 | state_dict = torch.load(os.path.join(ckpt_path, 'lin.pth.tar')) 83 | self.model.load_state_dict(state_dict) 84 | state_dict = torch.load(os.path.join(ckpt_path, 'optimizer.pth.tar')) 85 | self.optimizer.load_state_dict(state_dict) 86 | 87 | self.best_error = np.inf 88 | self.best_epoch = -1 89 | 90 | # logger 91 | self.logger = tensorboardX.SummaryWriter(os.path.join(self.log_dir)) 92 | self.csv_logger = self.get_logger() 93 | 94 | def load_dataset(self, split, collision_info = False): 95 | params = self.params 96 | dataset = MultiStyleShape(self.garment_class, split=split, gender=self.gender, 97 | smooth_level=params['smooth_level'], collision_info=collision_info) 98 | shuffle = True if split == 'train' else False 99 | if split == 'train' and len(dataset) > params['batch_size']: 100 | drop_last = True 101 | else: 102 | drop_last = False 103 | dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=0, shuffle=shuffle, 104 | drop_last=drop_last) 105 | return dataset, dataloader 106 | 107 | def build_model(self): 108 | params = self.params 109 | model = getattr(networks, self.model_name)( 110 | input_size=72+10+4, output_size=self.vert_indices.shape[0] * 3, 111 | num_layers=params['num_layers'], 112 | hidden_size=params['hidden_size']) 113 | return model 114 | 115 | def get_logger(self): 116 | return TailorNetLogger() 117 | 118 | def one_step(self, inputs): 119 | """One forward pass. 120 | Takes `inputs` tuple. Returns output(s) and loss. 
121 | """ 122 | gt_verts, thetas, betas, gammas, _ = inputs 123 | 124 | thetas, betas, gammas = ops.mask_inputs(thetas, betas, gammas, self.garment_class) 125 | gt_verts = gt_verts.to(device) 126 | thetas = thetas.to(device) 127 | betas = betas.to(device) 128 | gammas = gammas.to(device) 129 | pred_verts = self.model( 130 | torch.cat((thetas, betas, gammas), dim=1)).view(gt_verts.shape) 131 | 132 | # L1 loss 133 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 134 | return pred_verts, data_loss 135 | 136 | def train(self, epoch): 137 | """Train for one epoch.""" 138 | epoch_loss = AverageMeter() 139 | self.model.train() 140 | for i, inputs in enumerate(self.train_loader): 141 | self.optimizer.zero_grad() 142 | outputs, loss = self.one_step(inputs) 143 | loss.backward() 144 | self.optimizer.step() 145 | 146 | self.logger.add_scalar("train/loss", loss.item(), self.iter_nums) 147 | print("Iter {}, loss: {:.8f}".format(self.iter_nums, loss.item())) 148 | epoch_loss.update(loss, inputs[0].shape[0]) 149 | self.iter_nums += 1 150 | 151 | self.logger.add_scalar("train_epoch/loss", epoch_loss.avg, epoch) 152 | 153 | def update_metrics(self, metrics, inputs, outputs): 154 | """Update metrics from inputs and predicted outputs.""" 155 | gt_verts = inputs[0] 156 | pred_verts = outputs 157 | dist = ops.verts_dist(gt_verts, pred_verts.cpu()) * 1000. 158 | metrics['val_dist'].update(dist.item(), gt_verts.shape[0]) 159 | 160 | def visualize_batch(self, inputs, outputs, epoch): 161 | """Save visualizations of some samples of the batch.""" 162 | gt_verts, thetas, betas, gammas, idxs = inputs 163 | pred_verts = outputs 164 | idxs = idxs.numpy() 165 | for lidx, idx in enumerate(idxs): 166 | if idx % self.vis_freq != 0: 167 | continue 168 | theta = thetas[lidx].cpu().numpy() 169 | beta = betas[lidx].cpu().numpy() 170 | pred_vert = pred_verts[lidx].cpu().numpy() 171 | gt_vert = gt_verts[lidx].cpu().numpy() 172 | 173 | body_m, pred_m = self.smpl.run(theta=theta, garment_d=pred_vert, beta=beta, 174 | garment_class=self.garment_class) 175 | _, gt_m = self.smpl.run(theta=theta, garment_d=gt_vert, beta=beta, 176 | garment_class=self.garment_class) 177 | 178 | save_dir = os.path.join(self.log_dir, "{:04d}".format(epoch)) 179 | pred_m.write_ply(os.path.join(save_dir, "pred_{}.ply".format(idx))) 180 | gt_m.write_ply(os.path.join(save_dir, "gt_{}.ply".format(idx))) 181 | body_m.write_ply(os.path.join(save_dir, "body_{}.ply".format(idx))) 182 | 183 | def validate(self, epoch): 184 | """Evaluate on test dataset.""" 185 | val_loss = AverageMeter() 186 | metrics = { 187 | 'val_dist': AverageMeter(), # per vertex distance in mm 188 | } 189 | self.model.eval() 190 | with torch.no_grad(): 191 | for i, inputs in enumerate(self.test_loader): 192 | outputs, loss = self.one_step(inputs) 193 | val_loss.update(loss.item(), inputs[0].shape[0]) 194 | 195 | self.update_metrics(metrics, inputs, outputs) 196 | self.visualize_batch(inputs, outputs, epoch) 197 | 198 | val_dist_avg = metrics['val_dist'].avg 199 | self.logger.add_scalar("val/loss", val_loss.avg, epoch) 200 | self.logger.add_scalar("val/dist", val_dist_avg, epoch) 201 | print("VALIDATION") 202 | print("Epoch {}, loss: {:.4f}, dist: {:.4f} mm".format( 203 | epoch, val_loss.avg, val_dist_avg)) 204 | 205 | if val_dist_avg < self.best_error: 206 | self.best_error = val_dist_avg 207 | self.best_epoch = epoch 208 | self.save_ckpt_best() 209 | with open(os.path.join(self.log_dir, 'best_epoch'), 'w') as f: 210 | f.write("{:04d}".format(epoch)) 211 | 212 | def 
write_log(self): 213 | """Log training info once training is done.""" 214 | if self.best_epoch >= 0: 215 | self.csv_logger.add_item( 216 | best_error=self.best_error, best_epoch=self.best_epoch, **self.params) 217 | 218 | def save_ckpt(self, epoch): 219 | """Save checkpoint in given epoch's directory.""" 220 | save_dir = os.path.join(self.log_dir, "{:04d}".format(epoch)) 221 | if not os.path.exists(save_dir): 222 | os.makedirs(save_dir) 223 | torch.save(self.model.state_dict(), os.path.join(save_dir, 'lin.pth.tar')) 224 | torch.save(self.optimizer.state_dict(), os.path.join(save_dir, "optimizer.pth.tar")) 225 | 226 | def save_ckpt_best(self): 227 | """Save checkpoint in log directory.""" 228 | save_dir = self.log_dir 229 | if not os.path.exists(save_dir): 230 | os.makedirs(save_dir) 231 | torch.save(self.model.state_dict(), os.path.join(save_dir, 'lin.pth.tar')) 232 | torch.save(self.optimizer.state_dict(), os.path.join(save_dir, "optimizer.pth.tar")) 233 | 234 | 235 | class Runner(object): 236 | """A helper class to load a trained model.""" 237 | def __init__(self, ckpt, params): 238 | model_name = params['model_name'] 239 | garment_class = params['garment_class'] 240 | 241 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 242 | class_info = pickle.load(f) 243 | output_size = len(class_info[garment_class]['vert_indices']) * 3 244 | 245 | self.model = getattr(networks, model_name)( 246 | input_size=72+10+4, output_size=output_size, 247 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 248 | num_layers=params['num_layers'] if 'num_layers' in params else 3 249 | ) 250 | self.garment_class = params['garment_class'] 251 | 252 | print("loading {}".format(ckpt)) 253 | if torch.cuda.is_available(): 254 | self.model.cuda() 255 | state_dict = torch.load(ckpt) 256 | else: 257 | state_dict = torch.load(ckpt,map_location='cpu') 258 | self.model.load_state_dict(state_dict) 259 | self.model.eval() 260 | 261 | def forward(self, thetas, betas, gammas): 262 | thetas, betas, gammas = ops.mask_inputs( 263 | thetas, betas, gammas, garment_class=self.garment_class) 264 | pred_verts = self.model(torch.cat((thetas, betas, gammas), dim=1)) 265 | return pred_verts.view(thetas.shape[0], -1, 3) 266 | 267 | def cuda(self): 268 | self.model.cuda() 269 | 270 | def to(self, device): 271 | self.model.to(device) 272 | 273 | 274 | def get_best_runner(log_dir, epoch_num=None): 275 | """Returns a trained model runner given the log_dir.""" 276 | ckpt_dir = log_dir 277 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 278 | params = json.load(jf) 279 | 280 | # if epoch_num is not given then pick up the best epoch 281 | if epoch_num is None: 282 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 283 | else: 284 | # with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 285 | # best_epoch = int(f.read().strip()) 286 | best_epoch = epoch_num 287 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 288 | 289 | runner = Runner(ckpt_path, params) 290 | return runner 291 | 292 | 293 | def parse_argument(): 294 | parser = argparse.ArgumentParser() 295 | parser.add_argument('--local_config', default='') 296 | 297 | parser.add_argument('--garment_class', default="shirt") 298 | parser.add_argument('--gender', default="male") 299 | parser.add_argument('--shape_style', default="") 300 | 301 | # some training hyper parameters 302 | parser.add_argument('--vis_freq', default=512, type=int) 303 | parser.add_argument('--batch_size', default=32, 
type=int)
304 |     parser.add_argument('--lr', default=1e-4, type=float)
305 |     parser.add_argument('--weight_decay', default=1e-6, type=float)
306 |     parser.add_argument('--max_epoch', default=100, type=int)
307 |     parser.add_argument('--start_epoch', default=0, type=int)
308 |     parser.add_argument('--checkpoint', default="")
309 | 
310 |     # name under which experiment will be logged
311 |     # parser.add_argument('--log_name', default="tn_baseline")
312 |     parser.add_argument('--log_name', default="tn_lf_100_1")
313 | 
314 |     # smooth_level=0 will train TailorNet MLP baseline
315 |     # parser.add_argument('--smooth_level', default=0, type=int)
316 | 
317 |     # smooth_level=1 will train TailorNet low frequency predictor
318 |     parser.add_argument('--smooth_level', default=1, type=int)
319 | 
320 |     # model specification.
321 |     parser.add_argument('--model_name', default="FullyConnected")
322 |     parser.add_argument('--num_layers', default=3, type=int)
323 |     parser.add_argument('--hidden_size', default=1024, type=int)
324 | 
325 |     # small experiment description
326 |     # parser.add_argument('--note', default="MLP Baseline")
327 |     parser.add_argument('--note', default="TailorNet low frequency prediction")
328 | 
329 |     args = parser.parse_args()
330 |     params = args.__dict__
331 | 
332 |     # load params from local config if provided
333 |     if os.path.exists(params['local_config']):
334 |         print("loading config from {}".format(params['local_config']))
335 |         with open(params['local_config']) as f:
336 |             lc = json.load(f)
337 |         for k, v in lc.items():
338 |             params[k] = v
339 |     return params
340 | 
341 | 
342 | def main():
343 |     params = parse_argument()
344 | 
345 |     print("start training {}".format(params['garment_class']))
346 |     trainer = Trainer(params)
347 | 
348 |     for i in range(params['start_epoch'], params['max_epoch']):
349 |         print("epoch: {}".format(i))
350 |         trainer.train(i)
351 |         trainer.validate(i)
352 |         # trainer.save_ckpt(i)
353 | 
354 |     trainer.write_log()
355 |     print("safely quit!")
356 | 
357 | 
358 | if __name__ == '__main__':
359 |     main()
360 | 
--------------------------------------------------------------------------------
/refu_tailornet/trainer/eg_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements trainer class for the EG baseline.
3 | 
4 | The EG baseline predicts posed garment vertices as the SS2G canonical-pose
5 | prediction plus an MLP residual conditioned on pose, shape and style.
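In symbols (see EGTrainer.one_step below):

    pred_verts = ss2g(beta, gamma) + mlp(theta, beta, gamma)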
6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import torch 11 | import base_trainer 12 | from ss2g_trainer import get_best_runner as ss2g_runner 13 | from models import ops 14 | 15 | device = torch.device("cuda:0") 16 | # device = torch.device("cpu") 17 | 18 | 19 | class EGTrainer(base_trainer.Trainer): 20 | def __init__(self, params): 21 | super(EGTrainer, self).__init__(params) 22 | ss2g_logdir = "/BS/cpatel/work/data/learn_anim/test_ss2g/{}_{}".format( 23 | self.garment_class, self.gender) 24 | self.ss2g_runner = ss2g_runner(ss2g_logdir) 25 | 26 | def one_step(self, inputs): 27 | gt_verts, thetas, betas, gammas, _ = inputs 28 | 29 | thetas = ops.mask_thetas(thetas, self.garment_class) 30 | gt_verts = gt_verts.to(device) 31 | thetas = thetas.to(device) 32 | betas = betas.to(device) 33 | gammas = gammas.to(device) 34 | 35 | ss2g_verts = self.ss2g_runner.forward(betas=betas, gammas=gammas).view(gt_verts.shape) 36 | pred_verts = ss2g_verts + self.model( 37 | torch.cat((thetas, betas, gammas), dim=1)).view(gt_verts.shape) 38 | 39 | # L1 loss 40 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 41 | return pred_verts, data_loss 42 | 43 | 44 | class Runner(base_trainer.Runner): 45 | def __init__(self, ckpt, params): 46 | super(Runner, self).__init__(ckpt, params) 47 | ss2g_logdir = "/BS/cpatel/work/data/learn_anim/test_ss2g/{}_{}".format( 48 | params['garment_class'], params['gender']) 49 | self.ss2g_runner = ss2g_runner(ss2g_logdir) 50 | 51 | def forward(self, thetas, betas, gammas): 52 | pred_verts = super(Runner, self).forward(thetas=thetas, betas=betas, gammas=gammas) 53 | pred_verts = pred_verts + self.ss2g_runner(betas=betas, gammas=gammas) 54 | return pred_verts 55 | 56 | 57 | def get_best_runner(log_dir, epoch_num=None): 58 | """Returns a trained model runner given the log_dir.""" 59 | ckpt_dir = log_dir 60 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 61 | params = json.load(jf) 62 | 63 | # if epoch_num is not given then pick up the best epoch 64 | if epoch_num is None: 65 | with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 66 | best_epoch = int(f.read().strip()) 67 | else: 68 | best_epoch = epoch_num 69 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 70 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 71 | 72 | runner = Runner(ckpt_path, params) 73 | return runner 74 | 75 | 76 | def parse_argument(): 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--local_config', default='') 79 | 80 | parser.add_argument('--garment_class', default="t-shirt") 81 | parser.add_argument('--gender', default="male") 82 | parser.add_argument('--shape_style', default="") 83 | 84 | # some training hyper parameters 85 | parser.add_argument('--vis_freq', default=512, type=int) 86 | parser.add_argument('--batch_size', default=32, type=int) 87 | parser.add_argument('--lr', default=1e-4, type=float) 88 | parser.add_argument('--weight_decay', default=1e-6, type=float) 89 | parser.add_argument('--max_epoch', default=100, type=int) 90 | parser.add_argument('--start_epoch', default=0, type=int) 91 | parser.add_argument('--checkpoint', default="") 92 | 93 | # name under which experiment will be logged 94 | parser.add_argument('--log_name', default="test_eg_baseline") 95 | 96 | # smooth_level=1 will train TailorNet low frequency predictor 97 | parser.add_argument('--smooth_level', default=0, type=int) 98 | 99 | # model specification. 
100 | parser.add_argument('--model_name', default="FcModified") 101 | parser.add_argument('--num_layers', default=3) 102 | parser.add_argument('--hidden_size', default=1024) 103 | 104 | # small experiment description 105 | parser.add_argument('--note', default="EG baseline") 106 | 107 | args = parser.parse_args() 108 | params = args.__dict__ 109 | 110 | # load params from local config if provided 111 | if os.path.exists(params['local_config']): 112 | print("loading config from {}".format(params['local_config'])) 113 | with open(params['local_config']) as f: 114 | lc = json.load(f) 115 | for k, v in lc.items(): 116 | params[k] = v 117 | return params 118 | 119 | 120 | def main(): 121 | params = parse_argument() 122 | 123 | print("start training {}".format(params['garment_class'])) 124 | trainer = EGTrainer(params) 125 | 126 | # try: 127 | if True: 128 | for i in range(params['start_epoch'], params['max_epoch']): 129 | print("epoch: {}".format(i)) 130 | trainer.train(i) 131 | trainer.validate(i) 132 | # trainer.save_ckpt(i) 133 | 134 | # except Exception as e: 135 | # print(str(e)) 136 | # finally: 137 | trainer.write_log() 138 | print("safely quit!") 139 | 140 | 141 | if __name__ == '__main__': 142 | main() -------------------------------------------------------------------------------- /refu_tailornet/trainer/hf_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import json 7 | import pickle 8 | 9 | from models import networks 10 | from models import ops 11 | from dataset.static_pose_shape_final import OneStyleShapeHF 12 | import global_var 13 | from trainer import base_trainer 14 | 15 | device = torch.device("cuda:0") 16 | # device = torch.device("cpu") 17 | 18 | 19 | class HFTrainer(base_trainer.Trainer): 20 | """Implements trainer class for TailorNet high frequency predictor. 21 | 22 | It overloads some functions of base_trainer.Trainer class. 23 | """ 24 | 25 | def load_dataset(self, split): 26 | params = self.params 27 | shape_idx, style_idx = params['shape_style'].split('_') 28 | 29 | dataset = OneStyleShapeHF(self.garment_class, shape_idx=shape_idx, style_idx=style_idx, split=split, 30 | gender=self.gender, smooth_level=params['smooth_level']) 31 | shuffle = True if split == 'train' else False 32 | if split == 'train' and len(dataset) > params['batch_size']: 33 | drop_last = True 34 | else: 35 | drop_last = False 36 | dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=0, shuffle=shuffle, 37 | drop_last=drop_last) 38 | return dataset, dataloader 39 | 40 | def build_model(self): 41 | params = self.params 42 | model = getattr(networks, self.model_name)( 43 | input_size=72, output_size=self.vert_indices.shape[0] * 3, 44 | num_layers=params['num_layers'], 45 | hidden_size=params['hidden_size']) 46 | return model 47 | 48 | def one_step(self, inputs): 49 | gt_verts, smooth_verts, thetas, _, _, _ = inputs 50 | 51 | thetas = ops.mask_thetas(thetas, self.garment_class) 52 | gt_verts = gt_verts.to(device) 53 | smooth_verts = smooth_verts.to(device) 54 | thetas = thetas.to(device) 55 | 56 | # predicts residual over smooth groundtruth. 
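        # i.e. final verts = smoothed (low-frequency) ground truth + a pose-only
        # high-frequency residual: build_model above uses input_size=72, so the
        # HF net sees only theta, and one such net is trained per (shape, style)
        # pivot (see main() below).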
57 | pred_verts = self.model(thetas).view(gt_verts.shape) + smooth_verts 58 | 59 | # L1 loss 60 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 61 | return pred_verts, data_loss 62 | 63 | def update_metrics(self, metrics, inputs, outputs): 64 | gt_verts = inputs[0] 65 | pred_verts = outputs 66 | dist = ops.verts_dist(gt_verts, pred_verts.cpu()) * 1000. 67 | metrics['val_dist'].update(dist.item(), gt_verts.shape[0]) 68 | 69 | def visualize_batch(self, inputs, outputs, epoch): 70 | gt_verts, smooth_verts, thetas, betas, gammas, idxs = inputs 71 | new_inputs = (gt_verts, thetas, betas, gammas, idxs) 72 | super(HFTrainer, self).visualize_batch(new_inputs, outputs, epoch) 73 | 74 | 75 | class Runner(object): 76 | """A helper class to load a trained model.""" 77 | def __init__(self, ckpt, params): 78 | model_name = params['model_name'] 79 | garment_class = params['garment_class'] 80 | 81 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 82 | class_info = pickle.load(f) 83 | output_size = len(class_info[garment_class]['vert_indices']) * 3 84 | 85 | self.model = getattr(networks, model_name)( 86 | input_size=72, output_size=output_size, 87 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 88 | num_layers=params['num_layers'] if 'num_layers' in params else 3 89 | ) 90 | self.garment_class = params['garment_class'] 91 | 92 | print("loading {}".format(ckpt)) 93 | if torch.cuda.is_available(): 94 | self.model.cuda() 95 | state_dict = torch.load(ckpt) 96 | else: 97 | state_dict = torch.load(ckpt,map_location='cpu') 98 | self.model.load_state_dict(state_dict) 99 | self.model.eval() 100 | 101 | def forward(self, thetas, betas=None, gammas=None): 102 | thetas = ops.mask_thetas(thetas=thetas, garment_class=self.garment_class) 103 | pred_verts = self.model(thetas) 104 | return pred_verts 105 | 106 | def cuda(self): 107 | self.model.cuda() 108 | 109 | def to(self, device): 110 | self.model.to(device) 111 | 112 | 113 | class HFModel(torch.nn.Module): 114 | def __init__(self, params, ckpt = None): 115 | super(HFModel, self).__init__() 116 | model_name = params['model_name'] 117 | garment_class = params['garment_class'] 118 | 119 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 120 | class_info = pickle.load(f) 121 | output_size = len(class_info[garment_class]['vert_indices']) * 3 122 | 123 | self.model = getattr(networks, model_name)( 124 | input_size=72, output_size=output_size, 125 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 126 | num_layers=params['num_layers'] if 'num_layers' in params else 3 127 | ) 128 | self.garment_class = params['garment_class'] 129 | 130 | if ckpt is not None: 131 | print("loading {}".format(ckpt)) 132 | if torch.cuda.is_available(): 133 | self.model.cuda() 134 | state_dict = torch.load(ckpt) 135 | else: 136 | state_dict = torch.load(ckpt,map_location='cpu') 137 | self.model.load_state_dict(state_dict) 138 | 139 | def forward(self, thetas, betas=None, gammas=None): 140 | thetas = ops.mask_thetas(thetas=thetas, garment_class=self.garment_class) 141 | pred_verts = self.model(thetas) 142 | return pred_verts 143 | 144 | 145 | def get_best_runner(log_dir, epoch_num=None): 146 | """Returns a trained model runner given the log_dir.""" 147 | ckpt_dir = log_dir 148 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 149 | params = json.load(jf) 150 | 151 | # if epoch_num is not given then pick up the best epoch 152 | if epoch_num is None: 153 | 
ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 154 | else: 155 | # with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 156 | # best_epoch = int(f.read().strip()) 157 | best_epoch = epoch_num 158 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 159 | 160 | runner = Runner(ckpt_path, params) 161 | return runner 162 | 163 | def get_best_model(log_dir, epoch_num = None): 164 | ckpt_dir = log_dir 165 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 166 | params = json.load(jf) 167 | 168 | # if epoch_num is not given then pick up the best epoch 169 | if epoch_num is None: 170 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 171 | else: 172 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(epoch_num), 'lin.pth.tar') 173 | 174 | model = HFModel(params, ckpt_path) 175 | return model 176 | 177 | 178 | 179 | def parse_argument(): 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('--local_config', default='') 182 | 183 | parser.add_argument('--garment_class', default="shirt") 184 | parser.add_argument('--gender', default="male") 185 | parser.add_argument('--shape_style', default = ['006_011', '007_002', '004_018', '001_007', '005_021', '007_004', '006_013', '001_009', '004_021', '007_018', '002_020', '002_024', '001_005', '003_024', '003_020', '002_001', '001_001', '000_011', '001_024'], nargs='+') 186 | 187 | # some training hyper parameters 188 | parser.add_argument('--vis_freq', default=16, type=int) 189 | parser.add_argument('--batch_size', default=32, type=int) 190 | parser.add_argument('--lr', default=1e-4, type=float) 191 | parser.add_argument('--weight_decay', default=1e-6, type=float) 192 | parser.add_argument('--max_epoch', default=400, type=int) 193 | parser.add_argument('--start_epoch', default=0, type=int) 194 | parser.add_argument('--checkpoint', default="") 195 | 196 | # name under which experiment will be logged 197 | parser.add_argument('--log_name', default="tn_hf_new") 198 | 199 | # smooth_level=1 will train HF for that smoothness level 200 | parser.add_argument('--smooth_level', default=1, type=int) 201 | 202 | # model specification. 
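    # One HF network is trained per --shape_style pivot: main() iterates over the
    # pivot list, and each run is logged under its own
    # log_name/{garment_class}_{gender}/{shape_style} subdirectory.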
203 | parser.add_argument('--model_name', default="FullyConnected") 204 | parser.add_argument('--num_layers', default=3) 205 | parser.add_argument('--hidden_size', default=1024) 206 | 207 | # small experiment description 208 | parser.add_argument('--note', default="TailorNet high frequency prediction") 209 | 210 | args = parser.parse_args() 211 | params = args.__dict__ 212 | 213 | # load params from local config if provided 214 | if os.path.exists(params['local_config']): 215 | print("loading config from {}".format(params['local_config'])) 216 | with open(params['local_config']) as f: 217 | lc = json.load(f) 218 | for k, v in lc.items(): 219 | params[k] = v 220 | return params 221 | 222 | 223 | def main(): 224 | params = parse_argument() 225 | shape_styles = params['shape_style'] 226 | 227 | for ss in shape_styles: 228 | params['shape_style'] = ss 229 | print("start training {} on {}".format(params['garment_class'], ss)) 230 | trainer = HFTrainer(params) 231 | 232 | for i in range(params['start_epoch'], params['max_epoch']): 233 | print("epoch: {}".format(i)) 234 | trainer.train(i) 235 | if i % 20 == 0: 236 | trainer.validate(i) 237 | # if i % 40 == 0: 238 | # trainer.save_ckpt(i) 239 | 240 | trainer.save_ckpt(params['max_epoch']-1) 241 | trainer.write_log() 242 | print("safely quit!") 243 | 244 | 245 | if __name__ == '__main__': 246 | main() -------------------------------------------------------------------------------- /refu_tailornet/trainer/hf_trainer_col_info.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import json 7 | import pickle 8 | 9 | import sys 10 | 11 | currentdir = os.path.dirname(os.path.realpath(__file__)) 12 | parentdir = os.path.dirname(currentdir) 13 | sys.path.append(parentdir) 14 | 15 | from models import networks 16 | from models import ops 17 | from dataset.static_pose_shape_final import OneStyleShapeHF 18 | import global_var 19 | from trainer import base_trainer 20 | 21 | device = torch.device("cuda:0") 22 | # device = torch.device("cpu") 23 | 24 | 25 | class HFTrainer(base_trainer.Trainer): 26 | """Implements trainer class for TailorNet high frequency predictor. 27 | 28 | It overloads some functions of base_trainer.Trainer class. 
29 | """ 30 | 31 | def load_dataset(self, split): 32 | params = self.params 33 | shape_idx, style_idx = params['shape_style'].split('_') 34 | 35 | dataset = OneStyleShapeHF(self.garment_class, shape_idx=shape_idx, style_idx=style_idx, split=split, 36 | gender=self.gender, smooth_level=params['smooth_level'], collision_info=True) 37 | shuffle = True if split == 'train' else False 38 | if split == 'train' and len(dataset) > params['batch_size']: 39 | drop_last = True 40 | else: 41 | drop_last = False 42 | dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=0, shuffle=shuffle, 43 | drop_last=drop_last) 44 | return dataset, dataloader 45 | 46 | def build_model(self): 47 | params = self.params 48 | model = getattr(networks, self.model_name)( 49 | input_size=72, output_size=self.vert_indices.shape[0] * 3, 50 | num_layers=params['num_layers'], 51 | hidden_size=params['hidden_size']) 52 | return model 53 | 54 | def one_step(self, inputs): 55 | gt_verts, smooth_verts, thetas, _, _, _ = inputs 56 | 57 | thetas = ops.mask_thetas(thetas, self.garment_class) 58 | gt_verts = gt_verts.to(device) 59 | smooth_verts = smooth_verts.to(device) 60 | thetas = thetas.to(device) 61 | 62 | # predicts residual over smooth groundtruth. 63 | pred_verts = self.model(thetas).view(gt_verts.shape) + smooth_verts 64 | 65 | # L1 loss 66 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 67 | return pred_verts, data_loss 68 | 69 | def update_metrics(self, metrics, inputs, outputs): 70 | gt_verts = inputs[0] 71 | pred_verts = outputs 72 | dist = ops.verts_dist(gt_verts, pred_verts.cpu()) * 1000. 73 | metrics['val_dist'].update(dist.item(), gt_verts.shape[0]) 74 | 75 | def visualize_batch(self, inputs, outputs, epoch): 76 | gt_verts, smooth_verts, thetas, betas, gammas, idxs = inputs 77 | new_inputs = (gt_verts, thetas, betas, gammas, idxs) 78 | super(HFTrainer, self).visualize_batch(new_inputs, outputs, epoch) 79 | 80 | 81 | class Runner(object): 82 | """A helper class to load a trained model.""" 83 | def __init__(self, ckpt, params): 84 | model_name = params['model_name'] 85 | garment_class = params['garment_class'] 86 | 87 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 88 | class_info = pickle.load(f) 89 | output_size = len(class_info[garment_class]['vert_indices']) * 3 90 | 91 | self.model = getattr(networks, model_name)( 92 | input_size=72, output_size=output_size, 93 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 94 | num_layers=params['num_layers'] if 'num_layers' in params else 3 95 | ) 96 | self.garment_class = params['garment_class'] 97 | 98 | print("loading {}".format(ckpt)) 99 | if torch.cuda.is_available(): 100 | self.model.cuda() 101 | state_dict = torch.load(ckpt) 102 | else: 103 | state_dict = torch.load(ckpt,map_location='cpu') 104 | self.model.load_state_dict(state_dict) 105 | self.model.eval() 106 | 107 | def forward(self, thetas, betas=None, gammas=None): 108 | thetas = ops.mask_thetas(thetas=thetas, garment_class=self.garment_class) 109 | pred_verts = self.model(thetas) 110 | return pred_verts 111 | 112 | def cuda(self): 113 | self.model.cuda() 114 | 115 | def to(self, device): 116 | self.model.to(device) 117 | 118 | 119 | def get_best_runner(log_dir, epoch_num=None): 120 | """Returns a trained model runner given the log_dir.""" 121 | ckpt_dir = log_dir 122 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 123 | params = json.load(jf) 124 | 125 | # if epoch_num is not given then pick up the best epoch 126 | if 
epoch_num is None: 127 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 128 | else: 129 | # with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 130 | # best_epoch = int(f.read().strip()) 131 | best_epoch = epoch_num 132 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 133 | 134 | runner = Runner(ckpt_path, params) 135 | return runner 136 | 137 | 138 | def parse_argument(): 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--local_config', default='') 141 | 142 | parser.add_argument('--garment_class', default="shirt") 143 | parser.add_argument('--gender', default="male") 144 | # t-shirt male 145 | # parser.add_argument('--shape_style', default = ['008_000', '006_023', '008_002', '006_016', '007_018', '008_009', '008_007', '008_018', '004_012', '006_006', '002_022', '003_019', '002_005', '005_005', '003_020', '004_024', '002_011', '001_006', '005_004', '002_024'], nargs='+') 146 | # shirt male 147 | parser.add_argument('--shape_style', default = ['006_011', '007_002', '004_018', '001_007', '005_021', '007_004', '006_013', '001_009', '004_021', '007_018', '002_020', '002_024', '001_005', '003_024', '003_020', '002_001', '001_001', '000_011', '001_024'], nargs='+') 148 | # short-pant male 149 | # parser.add_argument('--shape_style', default = ['008_023', '008_002', '008_009', '008_013', '008_004', '005_021', '006_018', '005_016', '005_012', '005_007', '000_005', '002_024', '002_005', '000_002', '001_024', '000_024', '003_001', '003_020', '003_014', '000_022'], nargs='+') 150 | # skirt female 151 | # parser.add_argument('--shape_style', default = ['005_020', '004_000', '005_000', '000_009', '002_005', '006_016', '007_012', '006_008', '004_001', '004_019', '002_003', '001_018', '000_003', '003_018', '002_020', '001_003', '001_020', '001_002', '000_001', '003_003'], nargs='+') 152 | 153 | 154 | 155 | 156 | # some training hyper parameters 157 | parser.add_argument('--vis_freq', default=16, type=int) 158 | parser.add_argument('--batch_size', default=32, type=int) 159 | parser.add_argument('--lr', default=1e-4, type=float) 160 | parser.add_argument('--weight_decay', default=1e-6, type=float) 161 | parser.add_argument('--max_epoch', default=800, type=int) 162 | parser.add_argument('--start_epoch', default=0, type=int) 163 | parser.add_argument('--checkpoint', default="") 164 | 165 | # name under which experiment will be logged 166 | parser.add_argument('--log_name', default="tn_hf_800_1") 167 | 168 | # smooth_level=1 will train HF for that smoothness level 169 | parser.add_argument('--smooth_level', default=1, type=int) 170 | 171 | # model specification. 
172 | parser.add_argument('--model_name', default="FullyConnected") 173 | parser.add_argument('--num_layers', default=3) 174 | parser.add_argument('--hidden_size', default=1024) 175 | 176 | # small experiment description 177 | parser.add_argument('--note', default="TailorNet high frequency prediction") 178 | 179 | args = parser.parse_args() 180 | params = args.__dict__ 181 | 182 | # load params from local config if provided 183 | if os.path.exists(params['local_config']): 184 | print("loading config from {}".format(params['local_config'])) 185 | with open(params['local_config']) as f: 186 | lc = json.load(f) 187 | for k, v in lc.items(): 188 | params[k] = v 189 | return params 190 | 191 | 192 | def main(): 193 | params = parse_argument() 194 | shape_styles = params['shape_style'] 195 | 196 | for ss in shape_styles: 197 | params['shape_style'] = ss 198 | print("start training {} on {}".format(params['garment_class'], ss)) 199 | trainer = HFTrainer(params) 200 | 201 | for i in range(params['start_epoch'], params['max_epoch']): 202 | print("epoch: {}".format(i)) 203 | trainer.train(i) 204 | if i % 20 == 0: 205 | trainer.validate(i) 206 | # if i % 40 == 0: 207 | # trainer.save_ckpt(i) 208 | 209 | trainer.save_ckpt(params['max_epoch']-1) 210 | trainer.write_log() 211 | print("safely quit!") 212 | 213 | 214 | if __name__ == '__main__': 215 | main() -------------------------------------------------------------------------------- /refu_tailornet/trainer/lf_trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import torch 5 | from trainer import base_trainer 6 | 7 | device = torch.device("cuda:0") 8 | # device = torch.device("cpu") 9 | 10 | 11 | class LFTrainer(base_trainer.Trainer): 12 | """ 13 | Implements trainer class for TailorNet low frequency predictor. 14 | 15 | LF predictor training is exactly same as MLP baseline training 16 | but one change: smooth_level is set to 1 in config parameters 17 | to get low frequency GT data. 18 | """ 19 | pass 20 | 21 | 22 | def get_best_runner(log_dir, epoch_num=None): 23 | return base_trainer.get_best_runner(log_dir, epoch_num) 24 | 25 | def get_best_model(log_dir, epoch_num=None): 26 | return base_trainer.get_best_model(log_dir, epoch_num) 27 | 28 | 29 | def parse_argument(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--local_config', default='') 32 | 33 | parser.add_argument('--garment_class', default="t-shirt") 34 | parser.add_argument('--gender', default="male") 35 | parser.add_argument('--shape_style', default="") 36 | 37 | # some training hyper parameters 38 | parser.add_argument('--vis_freq', default=512, type=int) 39 | parser.add_argument('--batch_size', default=32, type=int) 40 | parser.add_argument('--lr', default=1e-4, type=float) 41 | parser.add_argument('--weight_decay', default=1e-6, type=float) 42 | parser.add_argument('--max_epoch', default=100, type=int) 43 | parser.add_argument('--start_epoch', default=0, type=int) 44 | parser.add_argument('--checkpoint', default="") 45 | 46 | # name under which experiment will be logged 47 | parser.add_argument('--log_name', default="tn_lf") 48 | 49 | # smooth_level=1 will train TailorNet low frequency predictor 50 | parser.add_argument('--smooth_level', default=1, type=int) 51 | 52 | # model specification. 
53 | parser.add_argument('--model_name', default="FullyConnected") 54 | parser.add_argument('--num_layers', default=3) 55 | parser.add_argument('--hidden_size', default=1024) 56 | 57 | # small experiment description 58 | parser.add_argument('--note', default="TailorNet low frequency prediction") 59 | 60 | args = parser.parse_args() 61 | params = args.__dict__ 62 | 63 | # load params from local config if provided 64 | if os.path.exists(params['local_config']): 65 | print("loading config from {}".format(params['local_config'])) 66 | with open(params['local_config']) as f: 67 | lc = json.load(f) 68 | for k, v in lc.items(): 69 | params[k] = v 70 | return params 71 | 72 | 73 | def main(): 74 | params = parse_argument() 75 | 76 | print("start training {}".format(params['garment_class'])) 77 | trainer = LFTrainer(params) 78 | 79 | for i in range(params['start_epoch'], params['max_epoch']): 80 | print("epoch: {}".format(i)) 81 | trainer.train(i) 82 | trainer.validate(i) 83 | # trainer.save_ckpt(i) 84 | 85 | trainer.write_log() 86 | print("safely quit!") 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /refu_tailornet/trainer/ss2g_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import json 6 | import pickle 7 | 8 | import sys 9 | 10 | currentdir = os.path.dirname(os.path.realpath(__file__)) 11 | parentdir = os.path.dirname(currentdir) 12 | sys.path.append(parentdir) 13 | 14 | from models import networks 15 | from dataset.canonical_pose_dataset import ShapeStyleCanonPose 16 | import global_var 17 | from trainer import base_trainer 18 | from models import ops 19 | 20 | device = torch.device("cuda:0") 21 | # device = torch.device("cpu") 22 | 23 | 24 | class SS2GTrainer(base_trainer.Trainer): 25 | """ 26 | Implements trainer class to predict deformations in canonical pose. 27 | 28 | It overloads base_trainer.Trainer class. This predictor is used in 29 | TailorNet to get the weights of pivot high frequency outputs. 
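    At test time, TailorNet evaluates this predictor for the query (beta, gamma)
    and for each pivot, and turns the distances between the canonical-pose
    predictions into (RBF-style) mixture weights over the pivot HF outputs.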
30 | """ 31 | def load_dataset(self, split): 32 | params = self.params 33 | dataset = ShapeStyleCanonPose(self.garment_class, split=split, gender=self.gender) 34 | shuffle = True if split == 'train' else False 35 | if split == 'train' and len(dataset) > params['batch_size']: 36 | drop_last = True 37 | else: 38 | drop_last = False 39 | dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=0, shuffle=shuffle, 40 | drop_last=drop_last) 41 | return dataset, dataloader 42 | 43 | def build_model(self): 44 | params = self.params 45 | model = getattr(networks, self.model_name)( 46 | input_size=10+4, output_size=self.vert_indices.shape[0] * 3, 47 | num_layers=params['num_layers'], 48 | hidden_size=params['hidden_size']) 49 | return model 50 | 51 | def one_step(self, inputs): 52 | gt_verts, _, betas, gammas, _ = inputs 53 | _, betas, gammas = ops.mask_inputs(None, betas, gammas, self.garment_class) 54 | 55 | gt_verts = gt_verts.to(device) 56 | betas = betas.to(device) 57 | gammas = gammas.to(device) 58 | pred_verts = self.model( 59 | torch.cat((betas, gammas), dim=1)).view(gt_verts.shape) 60 | 61 | # L1 loss 62 | data_loss = (pred_verts - gt_verts).abs().sum(-1).mean() 63 | return pred_verts, data_loss 64 | 65 | def visualize_batch(self, inputs, outputs, epoch): 66 | # This is easy training so no need to visualize here 67 | return 68 | 69 | 70 | 71 | class Runner(object): 72 | """A helper class to load a trained model.""" 73 | def __init__(self, ckpt, params): 74 | model_name = params['model_name'] 75 | garment_class = params['garment_class'] 76 | 77 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 78 | class_info = pickle.load(f) 79 | output_size = len(class_info[garment_class]['vert_indices']) * 3 80 | 81 | self.model = getattr(networks, model_name)( 82 | input_size=10+4, output_size=output_size, 83 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 84 | num_layers=params['num_layers'] if 'num_layers' in params else 3 85 | ) 86 | self.garment_class = params['garment_class'] 87 | 88 | print("loading {}".format(ckpt)) 89 | if torch.cuda.is_available(): 90 | self.model.cuda() 91 | state_dict = torch.load(ckpt) 92 | else: 93 | state_dict = torch.load(ckpt,map_location='cpu') 94 | self.model.load_state_dict(state_dict) 95 | self.model.eval() 96 | 97 | def forward(self, thetas=None, betas=None, gammas=None): 98 | _, betas, gammas = ops.mask_inputs(None, betas, gammas, self.garment_class) 99 | pred_verts = self.model(torch.cat((betas, gammas), dim=1)) 100 | return pred_verts 101 | 102 | def cuda(self): 103 | self.model.cuda() 104 | 105 | def to(self, device): 106 | self.model.to(device) 107 | 108 | class SS2GModel(torch.nn.Module): 109 | """A helper class to load a trained model.""" 110 | def __init__(self, params, ckpt = None): 111 | super(SS2GModel, self).__init__() 112 | model_name = params['model_name'] 113 | garment_class = params['garment_class'] 114 | 115 | with open(os.path.join(global_var.DATA_DIR, global_var.GAR_INFO_FILE), 'rb') as f: 116 | class_info = pickle.load(f) 117 | output_size = len(class_info[garment_class]['vert_indices']) * 3 118 | 119 | self.model = getattr(networks, model_name)( 120 | input_size=10+4, output_size=output_size, 121 | hidden_size=params['hidden_size'] if 'hidden_size' in params else 1024, 122 | num_layers=params['num_layers'] if 'num_layers' in params else 3 123 | ) 124 | self.garment_class = params['garment_class'] 125 | 126 | if ckpt is not None: 127 | print("loading {}".format(ckpt)) 128 | if 
torch.cuda.is_available(): 129 | self.model.cuda() 130 | state_dict = torch.load(ckpt) 131 | else: 132 | state_dict = torch.load(ckpt,map_location='cpu') 133 | self.model.load_state_dict(state_dict) 134 | 135 | def forward(self, thetas=None, betas=None, gammas=None): 136 | _, betas, gammas = ops.mask_inputs(None, betas, gammas, self.garment_class) 137 | pred_verts = self.model(torch.cat((betas, gammas), dim=1)) 138 | return pred_verts 139 | 140 | 141 | def get_best_runner(log_dir, epoch_num=None): 142 | """Returns a trained model runner given the log_dir.""" 143 | ckpt_dir = log_dir 144 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 145 | params = json.load(jf) 146 | 147 | # if epoch_num is not given then pick up the best epoch 148 | if epoch_num is None: 149 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 150 | else: 151 | # with open(os.path.join(ckpt_dir, 'best_epoch')) as f: 152 | # best_epoch = int(f.read().strip()) 153 | best_epoch = epoch_num 154 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(best_epoch), 'lin.pth.tar') 155 | 156 | runner = Runner(ckpt_path, params) 157 | return runner 158 | 159 | 160 | def get_best_model(log_dir, epoch_num = None): 161 | ckpt_dir = log_dir 162 | with open(os.path.join(ckpt_dir, 'params.json')) as jf: 163 | params = json.load(jf) 164 | 165 | # if epoch_num is not given then pick up the best epoch 166 | if epoch_num is None: 167 | ckpt_path = os.path.join(ckpt_dir, 'lin.pth.tar') 168 | else: 169 | ckpt_path = os.path.join(ckpt_dir, "{:04d}".format(epoch_num), 'lin.pth.tar') 170 | 171 | model = SS2GModel(params, ckpt_path) 172 | return model 173 | 174 | 175 | def parse_argument(): 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument('--local_config', default='') 178 | 179 | parser.add_argument('--garment_class', default="shirt") 180 | parser.add_argument('--gender', default="male") 181 | parser.add_argument('--shape_style', default="") 182 | 183 | # some training hyper parameters 184 | parser.add_argument('--vis_freq', default=512, type=int) 185 | parser.add_argument('--batch_size', default=32, type=int) 186 | parser.add_argument('--lr', default=1e-3, type=float) 187 | parser.add_argument('--weight_decay', default=1e-5, type=float) 188 | parser.add_argument('--max_epoch', default=400, type=int) 189 | parser.add_argument('--start_epoch', default=0, type=int) 190 | parser.add_argument('--checkpoint', default="") 191 | 192 | # name under which experiment will be logged 193 | parser.add_argument('--log_name', default="tn_ss2g_400_1") 194 | 195 | # smooth_level=0 will train TailorNet MLP baseline 196 | parser.add_argument('--smooth_level', default=0, type=int) 197 | 198 | # model specification. 
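    # ss2g maps only (beta, gamma) to canonical-pose vertices (input_size=10+4 in
    # build_model), hence the much smaller default hidden_size below compared to
    # the 1024 used by the pose-conditioned trainers.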
199 | parser.add_argument('--model_name', default="FullyConnected") 200 | parser.add_argument('--num_layers', default=3) 201 | parser.add_argument('--hidden_size', default=128) 202 | 203 | # small experiment description 204 | parser.add_argument('--note', default="SS2G training") 205 | 206 | args = parser.parse_args() 207 | params = args.__dict__ 208 | 209 | # load params from local config if provided 210 | if os.path.exists(params['local_config']): 211 | print("loading config from {}".format(params['local_config'])) 212 | with open(params['local_config']) as f: 213 | lc = json.load(f) 214 | for k, v in lc.items(): 215 | params[k] = v 216 | return params 217 | 218 | 219 | def main(): 220 | params = parse_argument() 221 | 222 | print("start training ss2g {}".format(params['garment_class'])) 223 | trainer = SS2GTrainer(params) 224 | 225 | # try: 226 | if True: 227 | for i in range(params['start_epoch'], params['max_epoch']): 228 | print("epoch: {}".format(i)) 229 | trainer.train(i) 230 | trainer.validate(i) 231 | 232 | # except Exception as e: 233 | # print(str(e)) 234 | # finally: 235 | trainer.write_log() 236 | print("safely quit!") 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /sdf/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/sdf/datasets/__init__.py -------------------------------------------------------------------------------- /sdf/datasets/smpldataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import os 5 | 6 | import pickle 7 | import trimesh 8 | from tqdm import tqdm, trange 9 | import random 10 | 11 | import sys 12 | 13 | currentdir = os.path.dirname(os.path.realpath(__file__)) 14 | parentdir = os.path.dirname(currentdir) 15 | sys.path.append(parentdir) 16 | 17 | import igrutils.general as utils 18 | 19 | class SMPLDataSet(data.Dataset): 20 | 21 | def __init__(self, dataset_path, file_name_list_file, points_batch=8000, with_normals=False, return_index=False): 22 | 23 | self.dataset_path = dataset_path 24 | 25 | with open(file_name_list_file, 'rb') as inputfile: 26 | self.file_name_list = pickle.load(inputfile) 27 | 28 | self.points_batch = points_batch 29 | 30 | self.with_normals = with_normals 31 | 32 | self.return_index = return_index 33 | 34 | def load_data(self, index): 35 | info_str, id = self.file_name_list[index] 36 | 37 | shape_str, style_str, seq_str = info_str.split('_') 38 | 39 | beta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'shape/beta_{}.npy'.format(shape_str)))[0:10]).float() 40 | theta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'pose/{}_{}/poses_{}.npz'.format(shape_str, style_str, seq_str)))['thetas'][id, :]).float() 41 | 42 | obj_dir_path = os.path.join(self.dataset_path, 'pose_obj_total', info_str, 'body') 43 | 44 | cloth_obj_dir_path = os.path.join(self.dataset_path, 'pose_obj_total', info_str, 'garment') 45 | 46 | mnlfold = np.load(os.path.join(obj_dir_path, 'manifold_'+str(id)+'.npy')).astype(np.float32) 47 | 48 | point_set_mnlfold = torch.from_numpy(mnlfold[:, :3]) 49 | 50 | sdf_pair = np.load(os.path.join(obj_dir_path, 'sdf_'+str(id)+'.npy')).astype(np.float32) 51 | 52 | cloth_sdf_pair = np.load(os.path.join(cloth_obj_dir_path, 
--------------------------------------------------------------------------------
/sdf/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/sdf/datasets/__init__.py
--------------------------------------------------------------------------------
/sdf/datasets/smpldataset.py:
--------------------------------------------------------------------------------
import os
import pickle
import sys

import numpy as np
import torch
import torch.utils.data as data
import trimesh

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

import igrutils.general as utils


class SMPLDataSet(data.Dataset):
    """Per-frame samples of (beta, theta, surface points, SDF point/value pairs).

    Each entry of `file_name_list` is an (info_str, frame_id) pair, where
    info_str encodes shape, style and sequence as '{shape}_{style}_{seq}'.
    """

    def __init__(self, dataset_path, file_name_list_file, points_batch=8000, with_normals=False, return_index=False):
        self.dataset_path = dataset_path

        with open(file_name_list_file, 'rb') as inputfile:
            self.file_name_list = pickle.load(inputfile)

        self.points_batch = points_batch
        self.with_normals = with_normals
        self.return_index = return_index

    def load_data(self, index):
        info_str, id = self.file_name_list[index]
        shape_str, style_str, seq_str = info_str.split('_')

        beta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'shape/beta_{}.npy'.format(shape_str)))[0:10]).float()
        theta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'pose/{}_{}/poses_{}.npz'.format(shape_str, style_str, seq_str)))['thetas'][id, :]).float()

        obj_dir_path = os.path.join(self.dataset_path, 'pose_obj_total', info_str, 'body')
        cloth_obj_dir_path = os.path.join(self.dataset_path, 'pose_obj_total', info_str, 'garment')

        # on-surface (manifold) points, optionally with normals in the last 3 columns
        mnlfold = np.load(os.path.join(obj_dir_path, 'manifold_' + str(id) + '.npy')).astype(np.float32)
        point_set_mnlfold = torch.from_numpy(mnlfold[:, :3])

        # off-surface samples with precomputed SDF values: body and garment sets
        sdf_pair = np.load(os.path.join(obj_dir_path, 'sdf_' + str(id) + '.npy')).astype(np.float32)
        cloth_sdf_pair = np.load(os.path.join(cloth_obj_dir_path, 'garment_sdf_' + str(id) + '.npy')).astype(np.float32)
        sdf_pair = np.concatenate((sdf_pair, cloth_sdf_pair), axis=0)

        sdf_point = torch.from_numpy(sdf_pair[:, 0:3])
        sdf_value = torch.from_numpy(sdf_pair[:, 3])

        if self.with_normals:
            normals = torch.from_numpy(mnlfold[:, -3:])
            return beta, theta, point_set_mnlfold, normals, sdf_point, sdf_value
        else:
            return beta, theta, point_set_mnlfold, sdf_point, sdf_value

    def __getitem__(self, index):
        if self.with_normals:
            beta, theta, point_set_mnlfold, normals, sdf_point, sdf_value = self.load_data(index)
        else:
            beta, theta, point_set_mnlfold, sdf_point, sdf_value = self.load_data(index)
            normals = torch.empty(0)

        random_idx = torch.randperm(point_set_mnlfold.shape[0])[:self.points_batch]
        point_set_mnlfold = torch.index_select(point_set_mnlfold, 0, random_idx)
        if self.with_normals:
            normals = torch.index_select(normals, 0, random_idx)

        # draw 1/8 extra SDF samples relative to the surface batch
        random_idx = torch.randperm(sdf_point.shape[0])[:(self.points_batch + self.points_batch // 8)]
        sdf_point = torch.index_select(sdf_point, 0, random_idx)
        sdf_value = torch.index_select(sdf_value, 0, random_idx)

        name = self.file_name_list[index][0] + '_' + str(self.file_name_list[index][1])
        if self.return_index:
            return beta, theta, point_set_mnlfold, normals, sdf_point, sdf_value, name, index
        else:
            return beta, theta, point_set_mnlfold, normals, sdf_point, sdf_value, name

    def __len__(self):
        return len(self.file_name_list)


# NOTE: the methods below appear in the source without a class declaration;
# a header has been restored here and the name SMPLVertexDataSet is assumed.
class SMPLVertexDataSet(data.Dataset):
    """Like SMPLDataSet, but returns body-mesh vertices loaded from .obj files
    instead of precomputed SDF samples."""

    def __init__(self, dataset_path, file_name_list_file, points_batch=8000, with_normals=False):
        self.dataset_path = dataset_path

        with open(file_name_list_file, 'rb') as inputfile:
            self.file_name_list = pickle.load(inputfile)

        self.points_batch = points_batch
        self.with_normals = with_normals

    def load_data(self, index):
        info_str, id = self.file_name_list[index]
        shape_str, style_str, seq_str = info_str.split('_')

        beta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'shape/beta_{}.npy'.format(shape_str)))[0:10]).float()
        theta = torch.from_numpy(np.load(os.path.join(self.dataset_path, 'pose/{}_{}/poses_{}.npz'.format(shape_str, style_str, seq_str)))['thetas'][id, :]).float()

        obj_dir_path = os.path.join(self.dataset_path, 'pose_obj_total', info_str, 'body')
        vertices = torch.from_numpy(np.array(trimesh.load(os.path.join(obj_dir_path, str(id) + '.obj')).vertices)).float()

        if self.with_normals:
            with open(os.path.join(obj_dir_path, 'vn_' + str(id) + '.pkl'), 'rb') as normal_file:
                normals = torch.from_numpy(pickle.load(normal_file)).float()
            return beta, theta, vertices, normals
        else:
            return beta, theta, vertices

    def __getitem__(self, index):
        if self.with_normals:
            beta, theta, vertices, normals = self.load_data(index)
        else:
            beta, theta, vertices = self.load_data(index)
            normals = torch.empty(0)

        if self.points_batch is not None:
            random_idx = torch.randperm(vertices.shape[0])[:self.points_batch]
            vertices = torch.index_select(vertices, 0, random_idx)
            if self.with_normals:
                normals = torch.index_select(normals, 0, random_idx)

        return beta, theta, vertices, normals, self.file_name_list[index][0] + '_' + str(self.file_name_list[index][1])

    def __len__(self):
        return len(self.file_name_list)
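A short sketch of how `SMPLDataSet` might be consumed, assuming the TailorNet-style directory layout the loaders above expect; the dataset path is a placeholder and the .pkl name is taken from `shapespace/smpl_setup.conf` later in this repo:

```python
from torch.utils.data import DataLoader

from datasets.smpldataset import SMPLDataSet

# Hypothetical paths; the .pkl must hold a list of (info_str, frame_id) pairs.
dataset = SMPLDataSet(
    dataset_path='/path/to/TailorNet_dataset/shirt_male',
    file_name_list_file='shirt_male_train_file_name_list_new.pkl',
    points_batch=8000,
    with_normals=True,
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for beta, theta, surf_pts, normals, sdf_pts, sdf_vals, name in loader:
    # surf_pts: (B, 8000, 3) on-surface samples; sdf_pts/sdf_vals carry the
    # precomputed body+garment SDF supervision, (B, 9000, 3) / (B, 9000).
    break
```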
--------------------------------------------------------------------------------
/sdf/global_var_fun.py:
--------------------------------------------------------------------------------
import os
from os.path import join as opj
import shutil

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))


def backup_file(src, dst):
    # skip directories that opt out via a 'nobackup' marker file
    if os.path.exists(opj(src, 'nobackup')):
        return
    # nothing to back up if the directory holds no .py/.sh files
    if len([k for k in os.listdir(src) if k.endswith('.py') or k.endswith('.sh')]) == 0:
        return
    if not os.path.isdir(dst):
        os.makedirs(dst)
    all_files = os.listdir(src)
    for fname in all_files:
        fname_full = opj(src, fname)
        fname_dst = opj(dst, fname)
        if os.path.isdir(fname_full):
            backup_file(fname_full, fname_dst)
        elif fname.endswith('.py') or fname.endswith('.sh'):
            shutil.copy(fname_full, fname_dst)
--------------------------------------------------------------------------------
/sdf/igrutils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/sdf/igrutils/__init__.py
--------------------------------------------------------------------------------
/sdf/igrutils/general.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
import trimesh


def mkdir_ifnotexists(directory):
    if not os.path.exists(directory):
        os.mkdir(directory)


def as_mesh(scene_or_mesh):
    """
    Convert a possible scene to a mesh.

    If conversion occurs, the returned mesh has only vertex and face data.
18 | """ 19 | if isinstance(scene_or_mesh, trimesh.Scene): 20 | if len(scene_or_mesh.geometry) == 0: 21 | mesh = None # empty scene 22 | else: 23 | # we lose texture information here 24 | mesh = trimesh.util.concatenate( 25 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 26 | for g in scene_or_mesh.geometry.values())) 27 | else: 28 | assert(isinstance(scene_or_mesh, trimesh.Trimesh)) 29 | mesh = scene_or_mesh 30 | return mesh 31 | 32 | 33 | def concat_home_dir(path): 34 | return os.path.join(os.environ['HOME'],'data',path) 35 | 36 | 37 | def get_class(kls): 38 | parts = kls.split('.') 39 | module = ".".join(parts[:-1]) 40 | m = __import__(module) 41 | for comp in parts[1:]: 42 | m = getattr(m, comp) 43 | return m 44 | 45 | 46 | def to_cuda(torch_obj): 47 | if torch.cuda.is_available(): 48 | return torch_obj.cuda() 49 | else: 50 | return torch_obj 51 | 52 | 53 | def load_point_cloud_by_file_extension(file_name): 54 | 55 | ext = file_name.split('.')[-1] 56 | 57 | if ext == "npz" or ext == "npy": 58 | point_set = torch.tensor(np.load(file_name)).float() 59 | else: 60 | point_set = torch.tensor(trimesh.load(file_name, ext).vertices).float() 61 | 62 | return point_set 63 | 64 | 65 | class LearningRateSchedule: 66 | def get_learning_rate(self, epoch): 67 | pass 68 | 69 | 70 | class StepLearningRateSchedule(LearningRateSchedule): 71 | def __init__(self, initial, interval, factor): 72 | self.initial = initial 73 | self.interval = interval 74 | self.factor = factor 75 | 76 | def get_learning_rate(self, epoch): 77 | return np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6) -------------------------------------------------------------------------------- /sdf/igrutils/plots.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objs as go 2 | import plotly.offline as offline 3 | import torch 4 | import numpy as np 5 | from skimage import measure 6 | import os 7 | import igrutils.general as utils 8 | 9 | 10 | def get_threed_scatter_trace(points,caption = None,colorscale = None,color = None): 11 | 12 | if (type(points) == list): 13 | trace = [go.Scatter3d( 14 | x=p[0][:, 0], 15 | y=p[0][:, 1], 16 | z=p[0][:, 2], 17 | mode='markers', 18 | name=p[1], 19 | marker=dict( 20 | size=3, 21 | line=dict( 22 | width=2, 23 | ), 24 | opacity=0.9, 25 | colorscale=colorscale, 26 | showscale=True, 27 | color=color, 28 | ), text=caption) for p in points] 29 | 30 | else: 31 | 32 | trace = [go.Scatter3d( 33 | x=points[:,0], 34 | y=points[:,1], 35 | z=points[:,2], 36 | mode='markers', 37 | name='projection', 38 | marker=dict( 39 | size=3, 40 | line=dict( 41 | width=2, 42 | ), 43 | opacity=0.9, 44 | colorscale=colorscale, 45 | showscale=True, 46 | color=color, 47 | ), text=caption)] 48 | 49 | return trace 50 | 51 | 52 | def plot_threed_scatter(points,path,epoch,in_epoch): 53 | trace = get_threed_scatter_trace(points) 54 | layout = go.Layout(width=1200, height=1200, scene=dict(xaxis=dict(range=[-2, 2], autorange=False), 55 | yaxis=dict(range=[-2, 2], autorange=False), 56 | zaxis=dict(range=[-2, 2], autorange=False), 57 | aspectratio=dict(x=1, y=1, z=1))) 58 | 59 | fig1 = go.Figure(data=trace, layout=layout) 60 | 61 | filename = '{0}/scatter_iteration_{1}_{2}.html'.format(path, epoch, in_epoch) 62 | offline.plot(fig1, filename=filename, auto_open=False) 63 | 64 | 65 | 66 | 67 | def plot_surface(decoder,path,epoch, shapename,resolution,mc_value,is_uniform_grid,verbose,save_html,save_ply,overwrite, points=None, with_points=False, 
latent=None, connected=False, filename_format_simple=False): 68 | 69 | if filename_format_simple == True: 70 | filename = '{0}/{1}'.format(path, shapename) 71 | else: 72 | filename = '{0}/igr_{1}_{2}'.format(path, epoch, shapename) 73 | 74 | if (not os.path.exists(filename) or overwrite): 75 | 76 | if with_points: 77 | pnts_val = decoder(points) 78 | pnts_val = pnts_val.cpu() 79 | points = points.cpu() 80 | caption = ["decoder : {0}".format(val.item()) for val in pnts_val.squeeze()] 81 | trace_pnts = get_threed_scatter_trace(points[:,-3:],caption=caption) 82 | 83 | surface = get_surface_trace(points,decoder,latent,resolution,mc_value,is_uniform_grid,verbose,save_ply, connected) 84 | trace_surface = surface["mesh_trace"] 85 | 86 | layout = go.Layout(title= go.layout.Title(text=shapename), width=1200, height=1200, scene=dict(xaxis=dict(range=[-2, 2], autorange=False), 87 | yaxis=dict(range=[-2, 2], autorange=False), 88 | zaxis=dict(range=[-2, 2], autorange=False), 89 | aspectratio=dict(x=1, y=1, z=1))) 90 | if (with_points): 91 | fig1 = go.Figure(data=trace_pnts + trace_surface, layout=layout) 92 | else: 93 | fig1 = go.Figure(data=trace_surface, layout=layout) 94 | 95 | 96 | if (save_html): 97 | offline.plot(fig1, filename=filename + '.html', auto_open=False) 98 | if (not surface['mesh_export'] is None): 99 | surface['mesh_export'].export(filename + '.ply', 'ply') 100 | return surface['mesh_export'] 101 | 102 | def get_surface_verts(points,decoder,latent,resolution,mc_value,is_uniform_grid,verbose,save_html,save_ply,overwrite): 103 | 104 | if (is_uniform_grid): 105 | grid = get_grid_uniform(resolution) 106 | else: 107 | if not points is None: 108 | grid = get_grid(points[:,-3:],resolution) 109 | else: 110 | grid = get_grid(None, resolution) 111 | 112 | z = [] 113 | 114 | for i,pnts in enumerate(torch.split(grid['grid_points'],100000,dim=0)): 115 | if (verbose): 116 | print ('{0}'.format(i/(grid['grid_points'].shape[0] // 100000) * 100)) 117 | 118 | if (not latent is None): 119 | pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1) 120 | z.append(decoder(pnts).detach().cpu().numpy()) 121 | z = np.concatenate(z,axis=0) 122 | 123 | if (not (np.min(z) > mc_value or np.max(z) < mc_value)): 124 | 125 | z = z.astype(np.float64) 126 | 127 | verts, faces, normals, values = measure.marching_cubes_lewiner( 128 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 129 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 130 | level=mc_value, 131 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 132 | grid['xyz'][0][2] - grid['xyz'][0][1], 133 | grid['xyz'][0][2] - grid['xyz'][0][1])) 134 | 135 | verts = verts + np.array([grid['xyz'][0][0],grid['xyz'][1][0],grid['xyz'][2][0]]) 136 | 137 | return verts 138 | 139 | 140 | def get_surface_trace(points,decoder,latent,resolution,mc_value,is_uniform,verbose,save_ply, connected=False): 141 | 142 | trace = [] 143 | meshexport = None 144 | 145 | if (is_uniform): 146 | grid = get_grid_uniform(resolution) 147 | else: 148 | if not points is None: 149 | grid = get_grid(points[:,-3:],resolution) 150 | else: 151 | grid = get_grid(None, resolution) 152 | 153 | z = [] 154 | 155 | for i,pnts in enumerate(torch.split(grid['grid_points'],100000,dim=0)): 156 | if (verbose): 157 | print ('{0}'.format(i/(grid['grid_points'].shape[0] // 100000) * 100)) 158 | 159 | if (not latent is None): 160 | pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1) 161 | z.append(decoder(pnts).detach().cpu().numpy()) 162 | z = np.concatenate(z,axis=0) 163 | 164 
| if (not (np.min(z) > mc_value or np.max(z) < mc_value)): 165 | 166 | import trimesh 167 | z = z.astype(np.float64) 168 | 169 | verts, faces, normals, values = measure.marching_cubes_lewiner( 170 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 171 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 172 | level=mc_value, 173 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 174 | grid['xyz'][0][2] - grid['xyz'][0][1], 175 | grid['xyz'][0][2] - grid['xyz'][0][1])) 176 | 177 | verts = verts + np.array([grid['xyz'][0][0],grid['xyz'][1][0],grid['xyz'][2][0]]) 178 | if (save_ply): 179 | meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values) 180 | if connected: 181 | connected_comp = meshexport.split(only_watertight=False) 182 | max_area = 0 183 | max_comp = None 184 | for comp in connected_comp: 185 | if comp.area > max_area: 186 | max_area = comp.area 187 | max_comp = comp 188 | meshexport = max_comp 189 | 190 | def tri_indices(simplices): 191 | return ([triplet[c] for triplet in simplices] for c in range(3)) 192 | 193 | I, J, K = tri_indices(faces) 194 | 195 | trace.append(go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], 196 | i=I, j=J, k=K, name='', 197 | color='orange', opacity=0.5)) 198 | 199 | 200 | 201 | return {"mesh_trace":trace, 202 | "mesh_export":meshexport} 203 | 204 | 205 | def plot_cuts_axis(points,decoder,latent,path,epoch,near_zero,axis,file_name_sep='/'): 206 | onedim_cut = np.linspace(-1.0, 1.0, 200) 207 | xx, yy = np.meshgrid(onedim_cut, onedim_cut) 208 | xx = xx.ravel() 209 | yy = yy.ravel() 210 | min_axis = points[:,axis].min(dim=0)[0].item() 211 | max_axis = points[:,axis].max(dim=0)[0].item() 212 | mask = np.zeros(3) 213 | mask[axis] = 1.0 214 | if (axis == 0): 215 | position_cut = np.vstack(([np.zeros(xx.shape[0]), xx, yy])) 216 | elif (axis == 1): 217 | position_cut = np.vstack(([xx,np.zeros(xx.shape[0]), yy])) 218 | elif (axis == 2): 219 | position_cut = np.vstack(([xx, yy, np.zeros(xx.shape[0])])) 220 | position_cut = [position_cut + i*mask.reshape(-1, 1) for i in np.linspace(min_axis - 0.1, max_axis + 0.1, 50)] 221 | for index, pos in enumerate(position_cut): 222 | #fig = tools.make_subplots(rows=1, cols=1) 223 | 224 | field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float)) 225 | z = [] 226 | for i, pnts in enumerate(torch.split(field_input, 10000, dim=0)): 227 | if (not latent is None): 228 | pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1) 229 | z.append(decoder(pnts).detach().cpu().numpy()) 230 | z = np.concatenate(z, axis=0) 231 | 232 | if (near_zero): 233 | if (np.min(z) < -1.0e-5): 234 | start = -0.1 235 | else: 236 | start = 0.0 237 | trace1 = go.Contour(x=onedim_cut, 238 | y=onedim_cut, 239 | z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]), 240 | name='axis {0} = {1}'.format(axis,pos[axis, 0]), # colorbar=dict(len=0.4, y=0.8), 241 | autocontour=False, 242 | contours=dict( 243 | start=start, 244 | end=0.1, 245 | size=0.01 246 | ) 247 | # ),colorbar = {'dtick': 0.05} 248 | ) 249 | else: 250 | trace1 = go.Contour(x=onedim_cut, 251 | y=onedim_cut, 252 | z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]), 253 | name='axis {0} = {1}'.format(axis,pos[axis, 0]), # colorbar=dict(len=0.4, y=0.8), 254 | autocontour=True, 255 | ncontours=70 256 | # contours=dict( 257 | # start=-0.001, 258 | # end=0.001, 259 | # size=0.00001 260 | # ) 261 | # ),colorbar = {'dtick': 0.05} 262 | ) 263 | 264 | layout = go.Layout(width=1200, height=1200, scene=dict(xaxis=dict(range=[-1, 1], autorange=False), 265 
| yaxis=dict(range=[-1, 1], autorange=False), 266 | aspectratio=dict(x=1, y=1)), 267 | title=dict(text='axis {0} = {1}'.format(axis,pos[axis, 0]))) 268 | # fig['layout']['xaxis2'].update(range=[-1, 1]) 269 | # fig['layout']['yaxis2'].update(range=[-1, 1], scaleanchor="x2", scaleratio=1) 270 | 271 | filename = '{0}{1}cutsaxis_{2}_{3}_{4}.html'.format(path,file_name_sep,axis, epoch, index) 272 | fig1 = go.Figure(data=[trace1], layout=layout) 273 | offline.plot(fig1, filename=filename, auto_open=False) 274 | 275 | 276 | def plot_cuts(points,decoder,path,epoch,near_zero,latent=None): 277 | onedim_cut = np.linspace(-1, 1, 200) 278 | xx, yy = np.meshgrid(onedim_cut, onedim_cut) 279 | xx = xx.ravel() 280 | yy = yy.ravel() 281 | min_y = points[:,-2].min(dim=0)[0].item() 282 | max_y = points[:,-2].max(dim=0)[0].item() 283 | position_cut = np.vstack(([xx, np.zeros(xx.shape[0]), yy])) 284 | position_cut = [position_cut + np.array([0., i, 0.]).reshape(-1, 1) for i in np.linspace(min_y - 0.1, max_y + 0.1, 10)] 285 | for index, pos in enumerate(position_cut): 286 | #fig = tools.make_subplots(rows=1, cols=1) 287 | 288 | field_input = torch.tensor(pos.T, dtype=torch.float).cuda() 289 | z = [] 290 | for i, pnts in enumerate(torch.split(field_input, 1000, dim=-1)): 291 | input_=pnts 292 | if (not latent is None): 293 | input_ = torch.cat([latent.expand(pnts.shape[0],-1) ,pnts],dim=1) 294 | z.append(decoder(input_).detach().cpu().numpy()) 295 | z = np.concatenate(z, axis=0) 296 | 297 | if (near_zero): 298 | trace1 = go.Contour(x=onedim_cut, 299 | y=onedim_cut, 300 | z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]), 301 | name='y = {0}'.format(pos[1, 0]), # colorbar=dict(len=0.4, y=0.8), 302 | autocontour=False, 303 | contours=dict( 304 | start=-0.001, 305 | end=0.001, 306 | size=0.00001 307 | ) 308 | # ),colorbar = {'dtick': 0.05} 309 | ) 310 | else: 311 | trace1 = go.Contour(x=onedim_cut, 312 | y=onedim_cut, 313 | z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]), 314 | name='y = {0}'.format(pos[1, 0]), # colorbar=dict(len=0.4, y=0.8), 315 | autocontour=True, 316 | # contours=dict( 317 | # start=-0.001, 318 | # end=0.001, 319 | # size=0.00001 320 | # ) 321 | # ),colorbar = {'dtick': 0.05} 322 | ) 323 | 324 | layout = go.Layout(width=1200, height=1200, scene=dict(xaxis=dict(range=[-1, 1], autorange=False), 325 | yaxis=dict(range=[-1, 1], autorange=False), 326 | aspectratio=dict(x=1, y=1)), 327 | title=dict(text='y = {0}'.format(pos[1, 0]))) 328 | # fig['layout']['xaxis2'].update(range=[-1, 1]) 329 | # fig['layout']['yaxis2'].update(range=[-1, 1], scaleanchor="x2", scaleratio=1) 330 | 331 | filename = '{0}/cuts{1}_{2}.html'.format(path, epoch, index) 332 | fig1 = go.Figure(data=[trace1], layout=layout) 333 | offline.plot(fig1, filename=filename, auto_open=False) 334 | 335 | 336 | def get_grid(points,resolution): 337 | eps = 0.1 338 | input_min = torch.min(points, dim=0)[0].squeeze().cpu().numpy() 339 | input_max = torch.max(points, dim=0)[0].squeeze().cpu().numpy() 340 | bounding_box = input_max - input_min 341 | shortest_axis = np.argmin(bounding_box) 342 | if (shortest_axis == 0): 343 | x = np.linspace(input_min[shortest_axis] - eps, 344 | input_max[shortest_axis] + eps, resolution) 345 | length = np.max(x) - np.min(x) 346 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 347 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 348 | elif (shortest_axis == 1): 349 | y = 
np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
        length = np.max(y) - np.min(y)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
        z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
    elif (shortest_axis == 2):
        z = np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
        length = np.max(z) - np.min(z)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
        y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}


def get_grid_uniform(resolution):
    x = np.linspace(-1.2, 1.2, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": 2.4,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}
--------------------------------------------------------------------------------
/sdf/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/sdf/model/__init__.py
--------------------------------------------------------------------------------
/sdf/model/network.py:
--------------------------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch
from torch.autograd import grad


def gradient(inputs, outputs):
    """Gradient of `outputs` w.r.t. the last 3 input dims (the xyz point)."""
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, -3:]
    return points_grad


class ImplicitNet(nn.Module):
    def __init__(
        self,
        d_in,
        dims,
        skip_in=(),
        geometric_init=True,
        radius_init=1,
        beta=100
    ):
        super().__init__()

        dims = [d_in] + dims + [1]

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for layer in range(0, self.num_layers - 1):

            # skip connections re-concatenate the input, so reserve room for it
            if layer + 1 in skip_in:
                out_dim = dims[layer + 1] - d_in
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            # if True, perform the geometric initialization of IGR, which biases
            # the network toward the SDF of a sphere with radius `radius_init`
            if geometric_init:
                if layer == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.00001)
                    torch.nn.init.constant_(lin.bias, -radius_init)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            setattr(self, "lin" + str(layer), lin)

        if beta > 0:
            self.activation = nn.Softplus(beta=beta)
        # vanilla relu
        else:
            self.activation = nn.ReLU()

    def forward(self, input):
        x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))

            if layer in self.skip_in:
                x = torch.cat([x, input], -1) / np.sqrt(2)

            x = lin(x)

            if layer < self.num_layers - 2:
                x = self.activation(x)

        return x
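A minimal sketch of instantiating `ImplicitNet` with the hyperparameters from `shapespace/smpl_setup.conf` below (an 82-dim condition code plus 3 spatial dims) and computing an eikonal-style penalty with `gradient`; the tensors are random placeholders, and conditioning per point this way is an assumption of the sketch:

```python
import torch
from model.network import ImplicitNet, gradient

# dims/skip_in follow smpl_setup.conf in this repo; latent_size = 82.
net = ImplicitNet(d_in=82 + 3, dims=[1024] * 8, skip_in=[4])

latent = torch.randn(8000, 82)                   # per-point condition code
points = torch.randn(8000, 3, requires_grad=True)
sdf = net(torch.cat([latent, points], dim=1))    # (8000, 1) predicted SDF

# gradient() differentiates w.r.t. the last 3 dims, i.e. the xyz point,
# which is what an eikonal penalty ||grad|| ~ 1 needs.
g = gradient(points, sdf)
eikonal = ((g.norm(2, dim=-1) - 1) ** 2).mean()
```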
--------------------------------------------------------------------------------
/sdf/model/sample.py:
--------------------------------------------------------------------------------
import abc

import torch

import igrutils.general as utils


class Sampler(metaclass=abc.ABCMeta):

    @abc.abstractmethod
    def get_points(self, pc_input):
        pass

    @staticmethod
    def get_sampler(sampler_type):
        return utils.get_class("model.sample.{0}".format(sampler_type))


class NormalPerPoint(Sampler):
    """Samples query points around a point cloud: one Gaussian perturbation per
    input point (local), plus uniform samples in a box (global)."""

    def __init__(self, global_sigma, local_sigma=0.01):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input, local_sigma=None):
        batch_size, sample_size, dim = pc_input.shape

        if local_sigma is not None:
            sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma.unsqueeze(-1))
        else:
            sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)

        # one global sample per 8 local ones ('//' is floor division), drawn
        # uniformly from the box [-global_sigma, global_sigma]^dim
        sample_global = (torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) * (self.global_sigma * 2)) - self.global_sigma

        sample = torch.cat([sample_local, sample_global], dim=1)

        return sample
--------------------------------------------------------------------------------
/sdf/shapespace/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aldehydecho/ReFU/9476bb1da1b8adc1001d21758e884aefb91ffd60/sdf/shapespace/__init__.py
--------------------------------------------------------------------------------
/sdf/shapespace/smpl_setup.conf:
--------------------------------------------------------------------------------
train{
    d_in = 3
    plot_frequency = 1
    checkpoint_frequency = 1
    status_frequency = 20
    preprocess = True
    latent_size = 82
    dataset_path = /mnt/session_space/TailorNet_dataset/shirt_male
    dataset = datasets.smpl.shirt_male
    train_file_list = shirt_male_train_file_name_list_new.pkl
    test_file_list = shirt_male_test_file_name_list_new.pkl
    weight_decay = 0
    learning_rate_schedule = [{
        "Type" : "Step",
        "Initial" : 0.005,
        "Interval" : 500,
        "Factor" : 0.5
    }]
    network_class = model.network.ImplicitNet
}

plot{
    resolution = 100
    mc_value = 0.0
    is_uniform_grid = False
    verbose = False
    save_html = True
    save_ply = True
    overwrite = True
}

network{
    inputs{
        dims = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
        skip_in = [4]
        geometric_init = True
        radius_init = 1
        beta = 100
    }
    sampler{
        sampler_type = NormalPerPoint
        properties{
            global_sigma = 2.0
            local_sigma = 0.01
        }
    }
    loss{
        lambda = 0.1
        sdf_lambda = 2.0
        normals_lambda = 1.0
        latent_lambda = 1e-3
        penalty_lambda = 0.1
    }
}
--------------------------------------------------------------------------------
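To tie the sampler{} block above to code, here is a sketch of resolving `NormalPerPoint` by name, as the config's `sampler_type` intends, and drawing query points around a batch of surface samples (shapes are illustrative):

```python
import torch
from model.sample import Sampler

# Resolve the sampler class by its config name.
sampler_cls = Sampler.get_sampler("NormalPerPoint")
sampler = sampler_cls(global_sigma=2.0, local_sigma=0.01)

surface_pts = torch.randn(4, 8000, 3)        # (batch, points, xyz)
queries = sampler.get_points(surface_pts)    # local + global samples

# These off-surface queries are where the eikonal/penalty terms of the
# IGR-style loss are evaluated during SDF training.
print(queries.shape)  # torch.Size([4, 9000, 3]) = 8000 local + 8000 // 8 global
```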