├── Engine ├── __init__.py └── th_utils │ ├── __init__.py │ ├── animation │ ├── __init__.py │ ├── util.py │ └── uv_generator.py │ ├── config.py │ ├── distributed │ ├── __init__.py │ ├── distributed.py │ └── sampler.py │ ├── dp_lookup_renderer.py │ ├── files.py │ ├── geometry.py │ ├── grid_sample_fix.py │ ├── io │ ├── __init__.py │ ├── prints.py │ └── visualizer.py │ ├── load_smpl_tmp.py │ ├── my_pytorch3d │ ├── __init__.py │ ├── mesh_io.py │ ├── smpl_util.py │ ├── textures.py │ ├── util │ │ └── sample_points_from_meshes.py │ └── vis.py │ ├── networks │ ├── __init__.py │ ├── attention.py │ ├── base_module.py │ ├── discriminator │ │ ├── __init__.py │ │ ├── model.py │ │ └── vqperceptual.py │ ├── embedder.py │ ├── load.py │ ├── loss.py │ ├── loss_vqgan │ │ ├── __init__.py │ │ ├── lpips.py │ │ └── taming_util.py │ ├── losses.py │ ├── nerf_net_utils.py │ ├── nerf_render.py │ ├── nerf_util │ │ ├── __init__.py │ │ ├── base_utils.py │ │ ├── config.py │ │ ├── nerf_data_util.py │ │ ├── nerf_net_utils.py │ │ └── yacs.py │ ├── net_utils.py │ ├── networks.py │ ├── stylegan.py │ └── util │ │ └── image_pool.py │ ├── num.py │ └── util.py ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── config_util.py ├── datasets │ └── zju │ │ ├── base.yml │ │ ├── motion_313_fv.yml │ │ ├── motion_315_fv.yml │ │ ├── motion_377_fv.yml │ │ ├── motion_386_fv.yml │ │ ├── motion_387_fv.yml │ │ └── motion_394_fv.yml ├── defaults.py ├── defaults.yml ├── methods │ ├── motion.yml │ ├── vrnr.py │ └── vrnr.yml ├── projects │ └── uvm.yml └── vrnr_setup.py ├── docs └── figs │ ├── summary.jpg │ └── test_example.jpg ├── download_models.py ├── requirements.txt ├── scripts └── zju │ ├── 313_test.sh │ ├── 313_train.sh │ ├── 315_test.sh │ ├── 315_train.sh │ ├── 377_test.sh │ ├── 377_train.sh │ ├── 386_test.sh │ ├── 386_train.sh │ ├── 387_test.sh │ ├── 387_train.sh │ ├── 394_test.sh │ └── 394_train.sh ├── test.py ├── train_dist.py └── uvm_lib ├── __init__.py ├── base_options ├── __init__.py ├── base_options.py ├── evaluate_options.py ├── motion_setup.py ├── other_setup.py ├── test_options.py ├── train_options.py └── vrnr_setup.py ├── data ├── __init__.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py └── dataset_zju.py ├── models ├── __init__.py ├── base_model.py ├── model_motion.py ├── models.py ├── nerf_render.py ├── net_PosFeature.py ├── net_nerf_uvMotion.py └── net_smooth.py ├── options ├── __init__.py ├── evaluation_option.py ├── project_option.py ├── test_option.py └── train_option.py └── util ├── html.py ├── smpl_renderer.py ├── util.py └── visualizer.py /Engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/animation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/animation/__init__.py -------------------------------------------------------------------------------- 
/Engine/th_utils/animation/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def get_can_pose_snapshot(init_pose_, is_dense = False): 5 | 6 | can_list=[] 7 | ds = [-1, 0, 1] 8 | if is_dense: 9 | step = 120 10 | ds = np.linspace(-1, 1, step) 11 | for arm_ele in ds: 12 | new_poses = np.zeros((24, 3)) 13 | init_pose = init_pose_.reshape(24, 3) 14 | 15 | new_poses[0] = init_pose[0] 16 | 17 | new_poses[13] = arm_ele * init_pose[13] 18 | new_poses[14] = arm_ele * init_pose[14] 19 | 20 | can_list.append(new_poses) 21 | 22 | return can_list 23 | 24 | def map_normalized_dp_to_tex_pytorch(img, norm_iuv_img, tex_res, fillconst=0): 25 | 26 | device = img.device 27 | tex = torch.ones((tex_res, tex_res, img.shape[2])).to(device) * fillconst 28 | tex_mask = torch.zeros((tex_res, tex_res)).to(device) 29 | 30 | valid_iuv = norm_iuv_img[norm_iuv_img[:, :, 0] > 0] 31 | valid_iuv = valid_iuv.cpu().numpy() 32 | 33 | if valid_iuv.size==0: 34 | return tex, tex_mask 35 | 36 | u_I = np.round(valid_iuv[:, 0] * (tex.shape[1] - 1)).astype(np.int32) 37 | v_I = np.round((1 - valid_iuv[:, 1]) * (tex.shape[0] - 1)).astype(np.int32) 38 | 39 | data = img[norm_iuv_img[:, :, 0] > 0] 40 | 41 | tex[v_I, u_I] = data 42 | tex_mask[v_I, u_I] = 1 43 | 44 | return tex, tex_mask 45 | 46 | 47 | def map_normalized_dp_to_tex(img, norm_iuv_img, tex_res, fillconst=128): 48 | tex = np.ones((tex_res, tex_res, img.shape[2])) * fillconst 49 | tex_mask = np.zeros((tex_res, tex_res)).astype(np.bool) 50 | 51 | # print('norm max, min', norm_iuv_img[:, :, 0].max(), norm_iuv_img[:, :, 0].min()) 52 | valid_iuv = norm_iuv_img[norm_iuv_img[:, :, 0] > 0] 53 | 54 | if valid_iuv.size==0: 55 | return tex, tex_mask 56 | 57 | if valid_iuv[:, 2].max() > 1: 58 | valid_iuv[:, 2] /= 255. 59 | valid_iuv[:, 1] /= 255. 60 | 61 | u_I = np.round(valid_iuv[:, 1] * (tex.shape[1] - 1)).astype(np.int32) 62 | v_I = np.round((1 - valid_iuv[:, 2]) * (tex.shape[0] - 1)).astype(np.int32) 63 | 64 | data = img[norm_iuv_img[:, :, 0] > 0] 65 | 66 | tex[v_I, u_I] = data 67 | tex_mask[v_I, u_I] = 1 68 | 69 | return tex, tex_mask -------------------------------------------------------------------------------- /Engine/th_utils/config.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | cur_path = str(pathlib.Path(__file__).parent.absolute()) 3 | 4 | cfg_smpl_paths = { 5 | "neutral": './data/asset/smpl_data/SMPL_NEUTRAL.pkl', 6 | "male": './data/asset/smpl_data/SMPL_MALE.pkl', 7 | "female": './data/asset/smpl_data/SMPL_FEMALE.pkl' 8 | } -------------------------------------------------------------------------------- /Engine/th_utils/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/distributed/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/distributed/distributed.py: -------------------------------------------------------------------------------- 1 | ##code borrowed from EVA3D. 
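This file only defines free helper functions (rank/world-size queries, a barrier, tensor and gradient reduction, a pickle-based all_gather, and reduce_loss_dict for averaging logged losses across ranks). Below is a minimal usage sketch, not part of the repository: `model`, `batch`, and `optimizer` are placeholders, and the process group is assumed to be initialized by the launcher (e.g. torchrun) before this runs.

from Engine.th_utils.distributed.distributed import (
    get_rank,
    reduce_loss_dict,
    synchronize,
)

def training_step(model, batch, optimizer):
    # `model`, `batch`, and `optimizer` are placeholders; any DDP-wrapped module
    # returning a dict of scalar loss tensors would fit here.
    loss_dict = model(batch)
    total = sum(loss_dict.values())

    optimizer.zero_grad()
    total.backward()
    optimizer.step()

    # Average the logged losses over all ranks; only rank 0 receives the result.
    reduced = reduce_loss_dict(loss_dict)
    if get_rank() == 0:
        print({k: v.item() for k, v in reduced.items()})
    synchronize()  # barrier so all ranks stay in step before the next iteration
    return total

Note that reduce_loss_dict only produces averaged values on rank 0, which is why the logging above is guarded by get_rank() == 0; synchronize() degrades to a no-op when the run is not distributed.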
2 | 3 | import math 4 | import pickle 5 | 6 | import torch 7 | from torch import distributed as dist 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | def get_rank(): 12 | if not dist.is_available(): 13 | return 0 14 | 15 | if not dist.is_initialized(): 16 | return 0 17 | 18 | return dist.get_rank() 19 | 20 | 21 | def synchronize(): 22 | if not dist.is_available(): 23 | return 24 | 25 | if not dist.is_initialized(): 26 | return 27 | 28 | world_size = dist.get_world_size() 29 | 30 | if world_size == 1: 31 | return 32 | 33 | dist.barrier() 34 | 35 | 36 | def get_world_size(): 37 | if not dist.is_available(): 38 | return 1 39 | 40 | if not dist.is_initialized(): 41 | return 1 42 | 43 | return dist.get_world_size() 44 | 45 | 46 | def reduce_sum(tensor): 47 | if not dist.is_available(): 48 | return tensor 49 | 50 | if not dist.is_initialized(): 51 | return tensor 52 | 53 | tensor = tensor.clone() 54 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 55 | 56 | return tensor 57 | 58 | 59 | def gather_grad(params): 60 | world_size = get_world_size() 61 | 62 | if world_size == 1: 63 | return 64 | 65 | for param in params: 66 | if param.grad is not None: 67 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 68 | param.grad.data.div_(world_size) 69 | 70 | 71 | def all_gather(data): 72 | world_size = get_world_size() 73 | 74 | if world_size == 1: 75 | return [data] 76 | 77 | buffer = pickle.dumps(data) 78 | storage = torch.ByteStorage.from_buffer(buffer) 79 | tensor = torch.ByteTensor(storage).to('cuda') 80 | 81 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 82 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 83 | dist.all_gather(size_list, local_size) 84 | size_list = [int(size.item()) for size in size_list] 85 | max_size = max(size_list) 86 | 87 | tensor_list = [] 88 | for _ in size_list: 89 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 90 | 91 | if local_size != max_size: 92 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 93 | tensor = torch.cat((tensor, padding), 0) 94 | 95 | dist.all_gather(tensor_list, tensor) 96 | 97 | data_list = [] 98 | 99 | for size, tensor in zip(size_list, tensor_list): 100 | buffer = tensor.cpu().numpy().tobytes()[:size] 101 | data_list.append(pickle.loads(buffer)) 102 | 103 | return data_list 104 | 105 | 106 | def reduce_loss_dict(loss_dict): 107 | world_size = get_world_size() 108 | 109 | if world_size < 2: 110 | return loss_dict 111 | 112 | with torch.no_grad(): 113 | keys = [] 114 | losses = [] 115 | 116 | for k in sorted(loss_dict.keys()): 117 | keys.append(k) 118 | losses.append(loss_dict[k]) 119 | 120 | losses = torch.stack(losses, 0) 121 | dist.reduce(losses, dst=0) 122 | 123 | if dist.get_rank() == 0: 124 | losses /= world_size 125 | 126 | reduced_losses = {k: v for k, v in zip(keys, losses)} 127 | 128 | return reduced_losses 129 | -------------------------------------------------------------------------------- /Engine/th_utils/distributed/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.distributed import DistributedSampler 2 | from torch.utils.data.sampler import WeightedRandomSampler 3 | from torch.utils.data import Dataset, Sampler 4 | from torch.utils import data 5 | from typing import Optional 6 | from operator import itemgetter 7 | 8 | def data_sampler(dataset, shuffle, distributed, world_size, rank): 9 | if distributed: 10 | return data.distributed.DistributedSampler(dataset, 
shuffle=shuffle, num_replicas=world_size, rank=rank) 11 | 12 | if shuffle: 13 | return data.RandomSampler(dataset) 14 | 15 | else: 16 | return data.SequentialSampler(dataset) 17 | 18 | class DatasetFromSampler(Dataset): 19 | """Dataset to create indexes from `Sampler`. 20 | Args: 21 | sampler: PyTorch sampler 22 | """ 23 | 24 | def __init__(self, sampler: Sampler): 25 | """Initialisation for DatasetFromSampler.""" 26 | self.sampler = sampler 27 | self.sampler_list = None 28 | 29 | def __getitem__(self, index: int): 30 | """Gets element of the dataset. 31 | Args: 32 | index: index of the element in the dataset 33 | Returns: 34 | Single element by index 35 | """ 36 | if self.sampler_list is None: 37 | self.sampler_list = list(self.sampler) 38 | return self.sampler_list[index] 39 | 40 | def __len__(self) -> int: 41 | """ 42 | Returns: 43 | int: length of the dataset 44 | """ 45 | return len(self.sampler) 46 | 47 | #borrowed from EVA3D 48 | class DistributedSamplerWrapper(DistributedSampler): 49 | """ 50 | Wrapper over `Sampler` for distributed training. 51 | Allows you to use any sampler in distributed mode. 52 | It is especially useful in conjunction with 53 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 54 | process can pass a DistributedSamplerWrapper instance as a DataLoader 55 | sampler, and load a subset of subsampled data of the original dataset 56 | that is exclusive to it. 57 | .. note:: 58 | Sampler is assumed to be of constant size. 59 | """ 60 | 61 | def __init__( 62 | self, 63 | sampler, 64 | num_replicas: Optional[int] = None, 65 | rank: Optional[int] = None, 66 | shuffle: bool = True, 67 | ): 68 | """ 69 | Args: 70 | sampler: Sampler used for subsampling 71 | num_replicas (int, optional): Number of processes participating in 72 | distributed training 73 | rank (int, optional): Rank of the current process 74 | within ``num_replicas`` 75 | shuffle (bool, optional): If true (default), 76 | sampler will shuffle the indices 77 | """ 78 | super(DistributedSamplerWrapper, self).__init__( 79 | DatasetFromSampler(sampler), 80 | num_replicas=num_replicas, 81 | rank=rank, 82 | shuffle=shuffle, 83 | ) 84 | self.sampler = sampler 85 | 86 | def __iter__(self): 87 | """@TODO: Docs. 
Contribution is welcome.""" 88 | self.dataset = DatasetFromSampler(self.sampler) 89 | indexes_of_indexes = super().__iter__() 90 | subsampler_indexes = self.dataset 91 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 92 | -------------------------------------------------------------------------------- /Engine/th_utils/dp_lookup_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | #from PIL import Image 5 | 6 | import matplotlib.pylab as plt 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | def getDPTexListTensor(combinedtex_b): 12 | r, c = 4, 6 13 | combinedtex_b = combinedtex_b.permute(0, 2, 3, 1) 14 | psize = int(combinedtex_b.shape[1] / r) 15 | 16 | dptexlist = torch.zeros(combinedtex_b.shape[0], 24, psize, psize, combinedtex_b.shape[-1]).cuda() 17 | count = 0 18 | for i in range(r): 19 | for j in range(c): 20 | dptexlist[:, count] = combinedtex_b[:, i * psize:i * psize + psize, j * psize:j * psize + psize, : ] 21 | count += 1 22 | return dptexlist 23 | 24 | 25 | class DPLookupRenderer(nn.Module): 26 | 27 | '''Given an IUV denspose image (iuvimage) and densepose texture(dptex), propogates the texture image by differentiable sampling (lookup_mode supports bilinear and nearest)''' 28 | def __init__(self, lookup_mode = 'bilinear'): 29 | super(DPLookupRenderer, self).__init__() 30 | self.lookup_mode = lookup_mode 31 | 32 | 33 | def forward(self, dp_comb, iuv_image_b): 34 | #input dp_tex, iuv_image 35 | 36 | dp_tex = getDPTexListTensor(dp_comb) 37 | 38 | iuv_image_b = iuv_image_b.permute(0, 2, 3, 1) 39 | 40 | nbatch = iuv_image_b.shape[0] 41 | rendered = torch.zeros(nbatch, dp_tex.shape[-1], iuv_image_b.shape[1], iuv_image_b.shape[2]).cuda() 42 | flowzero = torch.ones(nbatch, iuv_image_b.shape[1], iuv_image_b.shape[2], 2).cuda() * 5 #5 is random invalid number 43 | 44 | # for all 24 texmap 45 | for i in range(dp_tex.shape[1]): 46 | flow = torch.where(torch.unsqueeze(iuv_image_b[:, :, :, 0] == (i + 1), -1), iuv_image_b[:, :, :, 1:], 47 | flowzero) 48 | 49 | input_t = dp_tex[:,i, ...].permute(0, 3, 2, 1) 50 | out_t = grid_sample_fix.grid_sample(input_t, flow, mode=self.lookup_mode, align_corners=True) 51 | #print(out_t.max(), dp_tex[i].max(), flow.max(), flow.min(), input_t.shape, flow.shape, out_t.shape) 52 | rendered += out_t 53 | 54 | return rendered 55 | 56 | def forward_single(self, dp_tex, iuv_image_b): 57 | # input dp_tex, iuv_image 58 | 59 | iuv_image_b = iuv_image_b.permute(0, 2, 3, 1) 60 | 61 | nbatch = iuv_image_b.shape[0] 62 | rendered = torch.zeros(nbatch, dp_tex.shape[-1], iuv_image_b.shape[1], iuv_image_b.shape[2]).cuda() 63 | flowzero = torch.ones(nbatch, iuv_image_b.shape[1], iuv_image_b.shape[2], 64 | 2).cuda() * 5 # 5 is random invalid number 65 | 66 | # for all 24 texmap 67 | for i in range(dp_tex.shape[0]): 68 | flow = torch.where(torch.unsqueeze(iuv_image_b[:, :, :, 0] == (i + 1), -1), iuv_image_b[:, :, :, 1:], 69 | flowzero) 70 | 71 | input_t = dp_tex[i].unsqueeze(0).repeat(nbatch, 1, 1, 1).permute(0, 3, 2, 1) 72 | out_t = grid_sample_fix.grid_sample(input_t, flow, mode=self.lookup_mode, align_corners=True) 73 | # print(out_t.max(), dp_tex[i].max(), flow.max(), flow.min(), input_t.shape, flow.shape, out_t.shape) 74 | rendered += out_t 75 | 76 | return rendered 77 | 78 | class DPLookupRendererNormal(nn.Module): 79 | 80 | '''Given an IUV denspose image (iuvimage) and densepose texture(dptex), propogates the 
texture image by differentiable sampling (lookup_mode supports bilinear and nearest)''' 81 | def __init__(self, lookup_mode = 'bilinear'): 82 | super(DPLookupRendererNormal, self).__init__() 83 | self.lookup_mode = lookup_mode 84 | 85 | 86 | def forward(self, normal_tex, normal_flow_b): 87 | nbatch = normal_flow_b.shape[0] 88 | normal_flow_b = normal_flow_b.permute(0, 2, 3, 1) 89 | flowzero = torch.ones(nbatch, normal_flow_b.shape[1], normal_flow_b.shape[2], 2).float().cuda() * 5 90 | 91 | 92 | 93 | flow = torch.where(torch.unsqueeze(normal_flow_b[:, :, :, 0] == 1, -1), normal_flow_b[:, :, :, 1:], flowzero) 94 | input_t = normal_tex 95 | 96 | #print('flow,', flow.shape) 97 | out_t = grid_sample_fix.grid_sample(input_t, flow, mode=self.lookup_mode, align_corners=True) 98 | 99 | return out_t 100 | 101 | def forward_single(self, normal_tex, normal_flow_b): 102 | nbatch = normal_flow_b.shape[0] 103 | flowzero = torch.ones(nbatch, normal_flow_b.shape[1], normal_flow_b.shape[2], 2) * 5 104 | 105 | flow = torch.where(torch.unsqueeze(normal_flow_b[:, :, :, 0] == 1, -1), normal_flow_b[:, :, :, 1:], flowzero) 106 | input_t = normal_tex.unsqueeze(0).repeat(nbatch, 1, 1, 1).permute(0, 3, 2, 1) 107 | out_t = grid_sample_fix.grid_sample(input_t, flow, mode=self.lookup_mode, align_corners=True) 108 | 109 | return out_t 110 | 111 | 112 | if __name__ == '__main__': 113 | import sys 114 | 115 | 116 | sys.path.extend(['/HPS/impl_deep_volume/static00/detectron2/projects/DensePose/', 117 | '/HPS/impl_deep_volume/static00/detectron2/projects/DensePose/densepose', 118 | '/HPS/impl_deep_volume/static00/tex2shape/lib']) 119 | 120 | from structures import DensePoseResult 121 | # from maps import map_densepose_to_tex, normalize 122 | 123 | from detectron2.structures.boxes import BoxMode 124 | from util.densepose_utils import uvTransform, getDPImg, showDPtex, uvTransformDP, renderDP 125 | import pickle 126 | 127 | import cv2 128 | import six 129 | 130 | 131 | 132 | f = open('/HPS/impl_deep_volume3/static00/exp_data/marc_densepose/densepose.pkl', 'rb') 133 | dposer = pickle.load(f) 134 | dp = dposer[1] 135 | 136 | 137 | 138 | img = cv2.imread(dp['file_name']) 139 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255 140 | dp_image = getDPImg((256, 256, 3), dp) 141 | dp_tex = uvTransformDP(img, dp, 64) 142 | dp_tex = np.random.random((24, 64, 64, 128)) 143 | 144 | 145 | #plt.imshow(dp_image) 146 | #plt.show() 147 | 148 | #preprocessing for dataloader, get_i function 149 | iuv_image = dp_image.copy().astype(dtype=np.float32) 150 | iuv_image[:, :, 1:] = (iuv_image[:, :, 1:] / 255.0) 151 | iuv_image[:, :, 1:] = (iuv_image[:, :, 1:] - 0.5) * 2 152 | 153 | #batch formation and to tensor 154 | iuv_image_b = torch.from_numpy(iuv_image) 155 | iuv_image_b = torch.unsqueeze(iuv_image_b, dim=0) 156 | 157 | dp_tex = torch.from_numpy(dp_tex).float() 158 | 159 | 160 | rendered = model(dp_tex, iuv_image_b) 161 | rendered_viz = rendered.permute(0, 2, 3, 1).numpy() 162 | plt.imshow(rendered_viz[0].sum(axis = -1)) 163 | plt.show() 164 | 165 | 166 | -------------------------------------------------------------------------------- /Engine/th_utils/files.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import pickle 5 | 6 | def save_dict(d, p): 7 | os.makedirs(os.path.dirname(p), exist_ok=True) 8 | with open(p, 'wb') as f: 9 | pickle.dump(d, f, pickle.HIGHEST_PROTOCOL) 10 | print("keys saved ", len(d.keys())) 11 | 12 | def load_dict(path): 13 | with open(path, 
'rb') as f: 14 | return pickle.load(f) 15 | 16 | def get_croped_size(opt, mode="all"): 17 | 18 | ratio = opt.gen_ratio if not opt.only_nerf else opt.nerf_ratio 19 | 20 | h, w = opt.img_H * ratio, opt.img_W * ratio 21 | 22 | img_gen_h, img_gen_w = int(h), int(w) 23 | 24 | if opt.is_crop: 25 | img_gen_h = img_gen_h - opt.y0_crop - opt.y1_crop 26 | img_gen_w = img_gen_w - opt.x0_crop - opt.x1_crop 27 | 28 | if mode == "all" and not opt.only_nerf: 29 | img_nerf_h, img_nerf_w = int(img_gen_h * opt.nerf_ratio / opt.gen_ratio), int(img_gen_w * opt.nerf_ratio / opt.gen_ratio) 30 | return (img_gen_h, img_gen_w, img_nerf_h, img_nerf_w) 31 | else: 32 | return (img_gen_h, img_gen_w) 33 | 34 | 35 | dir="/home/th/projects/neural_body/neuralbody/configs" 36 | #rename_files(dir) 37 | 38 | #dir1="/fs/vulcan-projects/egocentric_video/neural_body/data/result/nr/result/wxd/wxd_12345_12345/wxd_12345_12345/smplpix_dnr_720_noaug_wxd_12345_12345/test_3_100/pose_dense/d0_gt.mp4" 39 | #dir2="/fs/vulcan-projects/egocentric_video/neural_body/data/result/nr/result/wxd/wxd_12345_12345/wxd_12345_12345/gt3.mp4" 40 | 41 | dir1="/fs/vulcan-projects/egocentric_video/neural_body/data/result/nr/result/bw/bw_12345_12345/bw_12345_12345/smplpix_dnr_720_noaug_bw_12345_12345/test_5_102/pose_dense/d0_gt.mp4" 42 | dir2="/fs/vulcan-projects/egocentric_video/neural_body/data/result/nr/result/bw/gt5.mp4" 43 | 44 | trans_video_format(dir1, dir2) -------------------------------------------------------------------------------- /Engine/th_utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | import Engine.th_utils.grid_sample_fix as grid_sample_fix 6 | 7 | def rotationX(angle): 8 | return [ 9 | [1, 0, 0], 10 | [0, math.cos(angle), -math.sin(angle)], 11 | [0, math.sin(angle), math.cos(angle)], 12 | ] 13 | 14 | 15 | def rotationY(angle): 16 | return [ 17 | [math.cos(angle), 0, math.sin(angle)], 18 | [0, 1, 0], 19 | [-math.sin(angle), 0, math.cos(angle)], 20 | ] 21 | 22 | 23 | def rotationZ(angle): 24 | return [ 25 | [math.cos(angle), -math.sin(angle), 0], 26 | [math.sin(angle), math.cos(angle), 0], 27 | [0, 0, 1], 28 | ] 29 | 30 | def batch_cross_3d(a, b): 31 | c = torch.zeros(a.shape[0], 3) 32 | c[:, 0], c[:, 1], c[:, 2] = a[:, 1]*b[:, 2]-a[:, 2]*b[:, 1], b[:, 0]*a[:, 2]-a[:, 0]*b[:, 2], a[:, 0]*b[:, 1]-b[:, 0]*a[:, 1] 33 | return c 34 | 35 | def cross_3d(a, b): 36 | return np.array([a[1]*b[2]-a[2]*b[1], b[0]*a[2]-a[0]*b[2], a[0]*b[1]-b[0]*a[1]]) 37 | 38 | def index(feat, uv, size=None): 39 | ''' 40 | :param feat: [B, C, H, W] image features 41 | :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1] 42 | :return: [B, C, N] image features at the uv coordinates 43 | ''' 44 | uv = uv.transpose(1, 2) # [B, N, 2] 45 | uv = uv.unsqueeze(2) # [B, N, 1, 2] 46 | if size != None: 47 | uv = (uv - size / 2) / (size / 2) 48 | # NOTE: for newer PyTorch, it seems that training results are degraded due to implementation diff in grid_sample_fix.grid_sample 49 | # for old versions, simply remove the aligned_corners argument. 
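# `uv` follows grid_sample's normalized [-1, 1] convention; the `size` branch above rescales raw pixel coordinates into that range.
# As written, grid_sample_fix.grid_sample hard-codes mode='bilinear', padding_mode='zeros', align_corners=False on both of its code paths, so the align_corners=True passed below is effectively ignored by that wrapper.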
50 | samples = grid_sample_fix.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1] 51 | return samples[:, :, :, 0] # [B, C, N] 52 | 53 | 54 | def orthogonal(points, calibrations, transforms=None): 55 | ''' 56 | Compute the orthogonal projections of 3D points into the image plane by given projection matrix 57 | :param points: [B, 3, N] Tensor of 3D points 58 | :param calibrations: [B, 4, 4] Tensor of projection matrix 59 | :param transforms: [B, 2, 3] Tensor of image transform matrix 60 | :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane 61 | ''' 62 | rot = calibrations[:, :3, :3] 63 | trans = calibrations[:, :3, 3:4] 64 | pts = torch.baddbmm(trans, rot, points) # [B, 3, N] 65 | if transforms is not None: 66 | scale = transforms[:2, :2] 67 | shift = transforms[:2, 2:3] 68 | pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :]) 69 | return pts 70 | 71 | 72 | def perspective(points, calibrations, transforms=None): 73 | ''' 74 | Compute the perspective projections of 3D points into the image plane by given projection matrix 75 | :param points: [Bx3xN] Tensor of 3D points 76 | :param calibrations: [Bx4x4] Tensor of projection matrix 77 | :param transforms: [Bx2x3] Tensor of image transform matrix 78 | :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane 79 | ''' 80 | B, _, N = points.shape 81 | device = points.device 82 | 83 | points = torch.cat([points, torch.ones((B, 1, N), device=device)], dim=1) 84 | 85 | #print(calibrations.shape, points.shape, calibrations.dtype, points.dtype) 86 | 87 | points = calibrations @ points 88 | points[:, :2, :] /= points[:, 2:, :] 89 | return points 90 | 91 | def rotationMatrixToAngles(R): 92 | """ 93 | R : (bs, 3, 3) 94 | """ 95 | # print(R.shape) 96 | sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0]) 97 | singular = sy < 1e-6 98 | mask = ~singular 99 | x = torch.zeros(R.shape[0]) 100 | y = torch.zeros(R.shape[0]) 101 | z = torch.zeros(R.shape[0]) 102 | if torch.sum(mask): 103 | x[mask] = torch.atan2(R[mask, 2, 1], R[mask, 2, 2]) 104 | y[mask] = torch.atan2(-R[mask, 2, 0], sy[mask]) 105 | z[mask] = torch.atan2(R[mask, 1, 0], R[mask, 0, 0]) 106 | if torch.sum(singular): 107 | x[singular] = math.atan2(-R[singular, 1, 2], R[singular, 1, 1]) 108 | y[singular] = torch.atan2(-R[singular, 2, 0], sy[singular]) 109 | z[singular] = 0 110 | return torch.cat([x.unsqueeze(1), y.unsqueeze(1), z.unsqueeze(1)], dim=1)# np.array([x, y, z]) -------------------------------------------------------------------------------- /Engine/th_utils/grid_sample_fix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 
11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = True # Enable the custom op by setting this to true. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | #assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /Engine/th_utils/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/io/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/io/prints.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | 3 | def printy(*arg): 4 | 
print(colored(arg,'yellow')) 5 | 6 | def prints(a, s=""): 7 | print("************") 8 | if a is None: 9 | print("None " + s) 10 | return 11 | print(a.shape, s) 12 | print("************") 13 | 14 | def printd(*arg): 15 | print("************") 16 | printy(arg) 17 | print("************") 18 | 19 | def printb(*arg): 20 | print(colored(arg,'blue')) 21 | 22 | def printg(*arg): 23 | print(colored(arg,'green')) 24 | 25 | def printr(*arg): 26 | print(colored(arg,'red')) 27 | 28 | 29 | def print_data(data, s=""): 30 | 31 | if isinstance(data, tuple): 32 | printg(f"{s} tuple ", len(data)) 33 | for d in data: print_data(d) 34 | elif isinstance(data, list): 35 | printg(f"{s} list ", len(data)) 36 | for d in data: print_data(d) 37 | elif isinstance(data, dict): 38 | printg(f"{s} dict ", data.keys()) 39 | for d in data.keys(): print_data(data[d]) 40 | elif hasattr(data, '__len__') and (not isinstance(data, str)): #ndarray 41 | printy(f"{s} arr ", type(data), data.shape) 42 | else: 43 | printy(f"{s} scalar ", data) 44 | 45 | -------------------------------------------------------------------------------- /Engine/th_utils/io/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .. import html 3 | try: 4 | from StringIO import StringIO # Python 2.7 5 | except ImportError: 6 | from io import BytesIO # Python 3.x 7 | 8 | 9 | class Visualizer(): 10 | def __init__(self, dir, win_size = 3000): 11 | # self.opt = opt 12 | #self.tf_log = opt.tf_log 13 | #self.use_html = opt.isTrain and not opt.no_html 14 | self.use_html = True #not (opt.phase == "test") 15 | self.win_size = win_size 16 | 17 | if self.use_html: 18 | self.web_dir = os.path.join(dir, 'web') 19 | print('create web directory %s...' % self.web_dir) 20 | for d in [self.web_dir]: 21 | os.makedirs(d, exist_ok=True) 22 | 23 | self.webpage = html.HTML(self.web_dir, '') 24 | 25 | def save_images(self, img, image_name, text="train"): 26 | 27 | image_dir = self.webpage.get_image_dir() 28 | 29 | self.webpage.add_header(image_name) 30 | ims = [] 31 | txts = [] 32 | links = [] 33 | 34 | ims.append(image_name) 35 | txts.append(text) 36 | links.append(image_name) 37 | self.webpage.add_images(ims, txts, links, width=self.win_size) 38 | 39 | def save(self): 40 | self.webpage.save() 41 | -------------------------------------------------------------------------------- /Engine/th_utils/my_pytorch3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/my_pytorch3d/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/my_pytorch3d/mesh_io.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from pytorch3d.renderer.blending import ( 6 | BlendParams, 7 | hard_rgb_blend, 8 | sigmoid_alpha_blend, 9 | softmax_rgb_blend, 10 | ) 11 | from pytorch3d.renderer.lighting import PointLights 12 | from pytorch3d.renderer.materials import Materials 13 | from pytorch3d.renderer.mesh.shading import flat_shading, gouraud_shading, phong_shading 14 | 15 | from pytorch3d.ops import interpolate_face_attributes 16 | 17 | from typing import List, Optional 18 | from pytorch3d.io import load_obj, save_obj 19 | 20 | import numpy as np 21 | 22 | from pytorch3d.renderer import ( 23 | TexturesUV, 24 | SfMPerspectiveCameras, 25 | DirectionalLights, 26 | 
RasterizationSettings, 27 | MeshRenderer, 28 | MeshRasterizer, 29 | ) 30 | from pytorch3d.structures import Meshes 31 | 32 | def load_obj_mesh_tex( 33 | files: list, 34 | device = None 35 | ): 36 | mesh_list = [] 37 | 38 | texture_map = np.zeros((256, 256, 3)).astype(np.uint8) 39 | texture_map = texture_map.astype(np.float32) / 255. 40 | 41 | 42 | for f_obj in files: 43 | verts, faces, aux = load_obj( 44 | f_obj 45 | ) 46 | tex = None 47 | 48 | verts_uvs = aux.verts_uvs.to(device) # (V, 2) 49 | faces_uvs = faces.textures_idx.to(device) # (F, 3) 50 | 51 | mesh_tex = TexturesUV(maps=[torch.from_numpy(texture_map).to(device=device)], faces_uvs=[faces_uvs], 52 | verts_uvs=[verts_uvs]) 53 | mesh = Meshes( 54 | verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=mesh_tex 55 | ) 56 | 57 | mesh_list.append(mesh) 58 | if len(mesh_list) == 1: 59 | return mesh_list[0] 60 | return join_meshes_as_batch(mesh_list) 61 | 62 | def smpl_pkl_to_mesh(): 63 | # Load SMPL and texture data 64 | with open(verts_filename, 'rb') as f: 65 | data = pickle.load(f, encoding='latin1') 66 | v_template = torch.Tensor(data['v_template']).to(device) # (6890, 3) 67 | ALP_UV = loadmat(data_filename) 68 | tex = torch.from_numpy(_read_image(file_name=tex_filename, format='RGB') / 255.).unsqueeze(0).to(device) 69 | 70 | verts = torch.from_numpy((ALP_UV["All_vertices"]).astype(int)).squeeze().to(device) # (7829, 1) 71 | U = torch.Tensor(ALP_UV['All_U_norm']).to(device) # (7829, 1) 72 | V = torch.Tensor(ALP_UV['All_V_norm']).to(device) # (7829, 1) 73 | faces = torch.from_numpy((ALP_UV['All_Faces'] - 1).astype(int)).to(device) # (13774, 3) 74 | face_indices = torch.Tensor(ALP_UV['All_FaceIndices']).squeeze() 75 | 76 | # Map each face to a (u, v) offset 77 | offset_per_part = {} 78 | already_offset = set() 79 | cols, rows = 4, 6 80 | for i, u in enumerate(np.linspace(0, 1, cols, endpoint=False)): 81 | for j, v in enumerate(np.linspace(0, 1, rows, endpoint=False)): 82 | part = rows * i + j + 1 # parts are 1-indexed in face_indices 83 | offset_per_part[part] = (u, v) 84 | 85 | # iterate over faces and offset the corresponding vertex u and v values 86 | for i in range(len(faces)): 87 | face_vert_idxs = faces[i] 88 | part = face_indices[i] 89 | offset_u, offset_v = offset_per_part[int(part.item())] 90 | 91 | for vert_idx in face_vert_idxs: 92 | # vertices are reused, but we don't want to offset multiple times 93 | if vert_idx.item() not in already_offset: 94 | # offset u value 95 | U[vert_idx] = U[vert_idx] / cols + offset_u 96 | # offset v value 97 | # this also flips each part locally, as each part is upside down 98 | V[vert_idx] = (1 - V[vert_idx]) / rows + offset_v 99 | # add vertex to our set tracking offsetted vertices 100 | already_offset.add(vert_idx.item()) 101 | 102 | # invert V values 103 | U_norm, V_norm = U, 1 - V 104 | 105 | 106 | # create our verts_uv values 107 | verts_uv = torch.cat([U_norm[None], V_norm[None]], dim=2) # (1, 7829, 2) 108 | 109 | # There are 6890 xyz vertex coordinates but 7829 vertex uv coordinates. 110 | # This is because the same vertex can be shared by multiple faces where each face may correspond to a different body part. 111 | # Therefore when initializing the Meshes class, 112 | # we need to map each of the vertices referenced by the DensePose faces (in verts, which is the "All_vertices" field) 113 | # to the correct xyz coordinate in the SMPL template mesh. 
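# "All_vertices" comes from the DensePose .mat file and is 1-indexed (MATLAB convention), hence the `vert - 1` below when gathering rows of the 0-indexed SMPL template.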
114 | v_template_extended = torch.stack(list(map(lambda vert: v_template[vert - 1], verts))).unsqueeze(0).to( 115 | device) # (1, 7829, 3) 116 | 117 | # add a batch dimension to faces 118 | faces = faces.unsqueeze(0) 119 | 120 | 121 | #load pkl files 122 | def load_smpl_pytorch3d( 123 | files: list, 124 | device = None 125 | ): 126 | 127 | mesh_list = [] 128 | 129 | 130 | 131 | 132 | 133 | 134 | texture_map = np.zeros((256, 256, 3)).astype(np.uint8) 135 | texture_map = texture_map.astype(np.float32) / 255. 136 | 137 | for f_obj in files: 138 | verts, faces, aux = load_obj( 139 | f_obj 140 | ) 141 | tex = None 142 | 143 | verts_uvs = aux.verts_uvs.to(device) # (V, 2) 144 | faces_uvs = faces.textures_idx.to(device) # (F, 3) 145 | 146 | # image = list(tex_maps.values())[0].to(device)[None] 147 | 148 | # tex = TexturesUV( 149 | # verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image 150 | # ) 151 | 152 | mesh_tex = TexturesUV(maps=[torch.from_numpy(texture_map).to(device=device)], faces_uvs=[faces_uvs], 153 | verts_uvs=[verts_uvs]) 154 | mesh = Meshes( 155 | verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=mesh_tex 156 | ) 157 | 158 | mesh_list.append(mesh) 159 | if len(mesh_list) == 1: 160 | return mesh_list[0] 161 | return join_meshes_as_batch(mesh_list) 162 | 163 | 164 | def load_mesh( 165 | files: list, 166 | device=None, 167 | load_textures: bool = True, 168 | create_texture_atlas: bool = False, 169 | texture_atlas_size: int = 4, 170 | texture_wrap: Optional[str] = "repeat", 171 | ): 172 | """ 173 | Load meshes from a list of .obj files using the load_obj function, and 174 | return them as a Meshes object. This only works for meshes which have a 175 | single texture image for the whole mesh. See the load_obj function for more 176 | details. material_colors and normals are not stored. 177 | 178 | Args: 179 | f: A list of file-like objects (with methods read, readline, tell, 180 | and seek), pathlib paths or strings containing file names. 181 | device: Desired device of returned Meshes. Default: 182 | uses the current device for the default tensor type. 183 | load_textures: Boolean indicating whether material files are loaded 184 | 185 | Returns: 186 | New Meshes object. 
187 | """ 188 | mesh_list = [] 189 | for f_obj in files: 190 | verts, faces, aux = load_obj( 191 | f_obj, 192 | load_textures=load_textures, 193 | create_texture_atlas=create_texture_atlas, 194 | texture_atlas_size=texture_atlas_size, 195 | texture_wrap=texture_wrap, 196 | ) 197 | tex = None 198 | if create_texture_atlas: 199 | # TexturesAtlas type 200 | tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)]) 201 | else: 202 | # TexturesUV type 203 | tex_maps = aux.texture_images 204 | 205 | print(tex_maps) 206 | print("tex") 207 | exit() 208 | 209 | if tex_maps is not None and len(tex_maps) > 0: 210 | verts_uvs = aux.verts_uvs.to(device) # (V, 2) 211 | faces_uvs = faces.textures_idx.to(device) # (F, 3) 212 | image = list(tex_maps.values())[0].to(device)[None] 213 | tex = TexturesUV( 214 | verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image 215 | ) 216 | 217 | mesh = Meshes( 218 | verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex 219 | ) 220 | mesh_list.append(mesh) 221 | if len(mesh_list) == 1: 222 | return mesh_list[0] 223 | return join_meshes_as_batch(mesh_list) -------------------------------------------------------------------------------- /Engine/th_utils/my_pytorch3d/textures.py: -------------------------------------------------------------------------------- 1 | from pytorch3d.io import load_obj, save_obj 2 | import numpy as np 3 | from torch._C import device 4 | import torch 5 | 6 | import os 7 | 8 | from pytorch3d.renderer import ( 9 | TexturesUV, 10 | SfMPerspectiveCameras, 11 | DirectionalLights, 12 | RasterizationSettings, 13 | MeshRenderer, 14 | MeshRasterizer, 15 | ) 16 | 17 | def get_smpl_uv_vts(): 18 | 19 | smpl_model_path = None 20 | flist = ["./data/asset/uv_table.npy"] 21 | for f in flist: 22 | if os.path.isfile(f): 23 | smpl_model_path = f 24 | break 25 | if smpl_model_path is None: 26 | print("no smlp uv obj file!") 27 | exit() 28 | 29 | uv = np.load(smpl_model_path) #allow_pickle=True 30 | return torch.from_numpy(uv) 31 | 32 | def get_default_texture_maps(): 33 | 34 | smpl_model_path = None 35 | flist = ["./data/asset/smpl_uv.obj"] 36 | for f in flist: 37 | if os.path.isfile(f): 38 | smpl_model_path = f 39 | break 40 | if smpl_model_path is None: 41 | print("no smlp uv obj file!") 42 | exit() 43 | 44 | texture_map = np.random.rand(256, 256, 3) 45 | texture_map = texture_map.astype(np.float32) 46 | 47 | verts, faces, aux = load_obj( 48 | smpl_model_path 49 | ) 50 | tex = None 51 | 52 | device="cuda" 53 | 54 | verts_uvs = aux.verts_uvs.to(device) # (V, 2) 55 | faces_uvs = faces.textures_idx.to(device) # (F, 3) 56 | 57 | mesh_tex = TexturesUV(maps=[torch.from_numpy(texture_map).to(device=device)], \ 58 | faces_uvs=[faces_uvs], verts_uvs=[verts_uvs]) 59 | return mesh_tex, verts_uvs, faces_uvs -------------------------------------------------------------------------------- /Engine/th_utils/my_pytorch3d/util/sample_points_from_meshes.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This module implements utility functions for sampling points from 4 | batches of meshes. 
5 | """ 6 | import sys 7 | from typing import Tuple, Union 8 | 9 | import torch 10 | from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals 11 | from pytorch3d.ops.packed_to_padded import packed_to_padded 12 | from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments 13 | 14 | from pytorch3d.ops import interpolate_face_attributes 15 | 16 | def sample_points_from_meshes( 17 | meshes, 18 | num_samples: int = 10000, 19 | return_normals: bool = False, 20 | return_textures: bool = False, 21 | return_uv: bool = False, verts_uvs = None, faces_uvs = None 22 | ) -> Union[ 23 | torch.Tensor, 24 | Tuple[torch.Tensor, torch.Tensor], 25 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 26 | ]: 27 | """ 28 | Convert a batch of meshes to a batch of pointclouds by uniformly sampling 29 | points on the surface of the mesh with probability proportional to the 30 | face area. 31 | Args: 32 | meshes: A Meshes object with a batch of N meshes. 33 | num_samples: Integer giving the number of point samples per mesh. 34 | return_normals: If True, return normals for the sampled points. 35 | return_textures: If True, return textures for the sampled points. 36 | Returns: 37 | 3-element tuple containing 38 | - **samples**: FloatTensor of shape (N, num_samples, 3) giving the 39 | coordinates of sampled points for each mesh in the batch. For empty 40 | meshes the corresponding row in the samples array will be filled with 0. 41 | - **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector 42 | to each sampled point. Only returned if return_normals is True. 43 | For empty meshes the corresponding row in the normals array will 44 | be filled with 0. 45 | - **textures**: FloatTensor of shape (N, num_samples, C) giving a C-dimensional 46 | texture vector to each sampled point. Only returned if return_textures is True. 47 | For empty meshes the corresponding row in the textures array will 48 | be filled with 0. 49 | Note that in a future releases, we will replace the 3-element tuple output 50 | with a `Pointclouds` datastructure, as follows 51 | .. code-block:: python 52 | Pointclouds(samples, normals=normals, features=textures) 53 | """ 54 | if meshes.isempty(): 55 | raise ValueError("Meshes are empty.") 56 | 57 | verts = meshes.verts_packed() 58 | if not torch.isfinite(verts).all(): 59 | raise ValueError("Meshes contain nan or inf.") 60 | 61 | if return_textures and meshes.textures is None: 62 | raise ValueError("Meshes do not contain textures.") 63 | 64 | faces = meshes.faces_packed() 65 | mesh_to_face = meshes.mesh_to_faces_packed_first_idx() 66 | num_meshes = len(meshes) 67 | num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes. 68 | 69 | num_samples = int(num_samples) 70 | # Initialize samples tensor with fill value 0 for empty meshes. 71 | samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device) 72 | 73 | # Only compute samples for non empty meshes 74 | with torch.no_grad(): 75 | areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero. 76 | max_faces = meshes.num_faces_per_mesh().max().item() 77 | areas_padded = packed_to_padded( 78 | areas, mesh_to_face[meshes.valid], max_faces 79 | ) # (N, F) 80 | 81 | # TODO (gkioxari) Confirm multinomial bug is not present with real data. 82 | sample_face_idxs = areas_padded.multinomial( 83 | num_samples, replacement=True 84 | ) # (N, num_samples) 85 | sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1) 86 | 87 | # Get the vertex coordinates of the sampled faces. 
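# `verts` and `faces` are in packed form (verts_packed()/faces_packed() above), so face_verts below has shape (total_faces, 3, 3): one (v0, v1, v2) triple per face.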
88 | face_verts = verts[faces] 89 | v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2] 90 | 91 | # Randomly generate barycentric coords. 92 | w0, w1, w2 = _rand_barycentric_coords( 93 | num_valid_meshes, num_samples, verts.dtype, verts.device 94 | ) 95 | 96 | # Use the barycentric coords to get a point on each sampled face. 97 | a = v0[sample_face_idxs] # (N, num_samples, 3) 98 | b = v1[sample_face_idxs] 99 | c = v2[sample_face_idxs] 100 | samples[meshes.valid] = w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c 101 | 102 | if return_normals: 103 | # Initialize normals tensor with fill value 0 for empty meshes. 104 | # Normals for the sampled points are face normals computed from 105 | # the vertices of the face in which the sampled point lies. 106 | normals = torch.zeros((num_meshes, num_samples, 3), device=meshes.device) 107 | vert_normals = (v1 - v0).cross(v2 - v1, dim=1) 108 | vert_normals = vert_normals / vert_normals.norm(dim=1, p=2, keepdim=True).clamp( 109 | min=sys.float_info.epsilon 110 | ) 111 | vert_normals = vert_normals[sample_face_idxs] 112 | normals[meshes.valid] = vert_normals 113 | 114 | if return_textures: 115 | # fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1. 116 | pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1 117 | bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3 118 | # zbuf and dists are not used in `sample_textures` so we initialize them with dummy 119 | dummy = torch.zeros( 120 | (len(meshes), num_samples, 1, 1), device=meshes.device, dtype=torch.float32 121 | ) # NxSx1x1 122 | fragments = MeshFragments( 123 | pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy 124 | ) 125 | textures = meshes.sample_textures(fragments) # NxSx1x1xC 126 | textures = textures[:, :, 0, 0, :] # NxSxC 127 | 128 | if return_uv: 129 | 130 | # fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1. 131 | pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1 132 | bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3 133 | 134 | # assert verts_uvs is not None 135 | # assert faces_uvs is not None 136 | 137 | # print(verts_uvs.shape, faces_uvs.shape) 138 | 139 | # packing_list = [ 140 | # i[j] for i, j in zip(verts_uvs, faces_uvs) 141 | # ] 142 | # faces_verts_uvs = torch.cat(packing_list) 143 | 144 | packing_list = [ 145 | i[j] for i, j in zip(meshes.textures.verts_uvs_list(), meshes.textures.faces_uvs_list()) 146 | ] 147 | faces_verts_uvs = torch.cat(packing_list) 148 | 149 | #batch = len(meshes) 150 | texture_maps = meshes.textures.maps_padded() 151 | 152 | bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3 153 | 154 | uvs = interpolate_face_attributes(pix_to_face, bary, faces_verts_uvs) 155 | 156 | N, H_out, W_out, K = pix_to_face.shape 157 | N, H_in, W_in, C = texture_maps.shape 158 | 159 | pts_uvs = uvs.permute(0,3,1,2,4).reshape(N*K, H_out, W_out, 2)[:,:,0,:] 160 | 161 | # return 162 | # TODO(gkioxari) consider returning a Pointclouds instance [breaking] 163 | 164 | if return_normals and return_uv: 165 | return samples, normals, pts_uvs 166 | if return_normals and return_textures: 167 | # pyre-fixme[61]: `normals` may not be initialized here. 168 | # pyre-fixme[61]: `textures` may not be initialized here. 169 | return samples, normals, textures 170 | if return_normals: # return_textures is False 171 | # pyre-fixme[61]: `normals` may not be initialized here. 
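# return_normals only: sampled points and their matching unit face normals, both of shape (N, num_samples, 3).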
172 | return samples, normals 173 | if return_textures: # return_normals is False 174 | # pyre-fixme[61]: `textures` may not be initialized here. 175 | return samples, textures 176 | if return_uv: 177 | return samples, pts_uvs 178 | return samples 179 | 180 | 181 | def _rand_barycentric_coords( 182 | size1, size2, dtype: torch.dtype, device: torch.device 183 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 184 | """ 185 | Helper function to generate random barycentric coordinates which are uniformly 186 | distributed over a triangle. 187 | Args: 188 | size1, size2: The number of coordinates generated will be size1*size2. 189 | Output tensors will each be of shape (size1, size2). 190 | dtype: Datatype to generate. 191 | device: A torch.device object on which the outputs will be allocated. 192 | Returns: 193 | w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric 194 | coordinates 195 | """ 196 | uv = torch.rand(2, size1, size2, dtype=dtype, device=device) 197 | u, v = uv[0], uv[1] 198 | u_sqrt = u.sqrt() 199 | w0 = 1.0 - u_sqrt 200 | w1 = u_sqrt * (1.0 - v) 201 | w2 = u_sqrt * v 202 | # pyre-fixme[7]: Expected `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` but 203 | # got `Tuple[float, typing.Any, typing.Any]`. 204 | return w0, w1, w2 205 | 206 | -------------------------------------------------------------------------------- /Engine/th_utils/my_pytorch3d/vis.py: -------------------------------------------------------------------------------- 1 | from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform 2 | from pytorch3d.renderer.mesh import shader, SoftPhongShader 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | from .textures import get_default_texture_maps 9 | 10 | import numpy as np 11 | 12 | from pytorch3d.renderer import ( 13 | TexturesUV, 14 | SfMPerspectiveCameras, 15 | DirectionalLights, 16 | RasterizationSettings, 17 | MeshRenderer, 18 | MeshRasterizer, 19 | ) 20 | from pytorch3d.structures import Meshes 21 | 22 | from pytorch3d.structures import Pointclouds 23 | from pytorch3d.renderer import ( 24 | look_at_view_transform, 25 | FoVOrthographicCameras, 26 | PointsRasterizationSettings, 27 | PointsRenderer, 28 | PulsarPointsRenderer, 29 | PointsRasterizer, 30 | AlphaCompositor, 31 | NormWeightedCompositor 32 | ) 33 | 34 | def get_camera(R, T, K, device): 35 | cameras = PerspectiveCameras(device = device, \ 36 | R=R, T=T, K=K) # focal_length=focal_length 37 | 38 | 39 | fx, fy = K[0][0][0], K[0][1][1] 40 | px, py = K[0][0][2], K[0][1][2] 41 | image_size = (1024,1024) 42 | print(fx, fy, px, py) 43 | 44 | off_cameras = PerspectiveCameras(device=device, R=R, T=T, focal_length=\ 45 | ((fx, fy),), \ 46 | principal_point=((px, py),), image_size=(image_size,)) 47 | 48 | return off_cameras 49 | 50 | def get_pointcloud(vertices): 51 | 52 | vertices = vertices.reshape(-1,3) 53 | rgb = np.random.rand(vertices.cpu().numpy().shape[0],3) 54 | rgb = torch.from_numpy(rgb).float().cuda() 55 | #torch.Tensor(pointcloud['rgb']).to(device) 56 | 57 | point_cloud = Pointclouds(points=[vertices], features=[rgb]) 58 | return point_cloud 59 | 60 | def get_mesh(vertices, faces): 61 | tex_maps = get_default_texture_maps() 62 | 63 | if not torch.is_tensor(vertices): 64 | vertices = torch.from_numpy(vertices).float().cuda() 65 | 66 | if faces is not None and (not torch.is_tensor(faces)): 67 | faces = torch.from_numpy(faces).float().cuda() 68 | 69 | batch_size = vertices.shape[0] 70 | device = vertices.device 71 | 72 | is_render_mesh = True 73 | if faces 
is None: 74 | is_render_mesh = False 75 | else: 76 | faces = faces.expand(batch_size, -1, -1) 77 | 78 | mesh = Meshes(verts = vertices, faces = faces, 79 | textures=tex_maps).cuda() 80 | return mesh 81 | 82 | def vis(mesh=None, pointcloud=None, camera=None): 83 | 84 | if pointcloud is not None: 85 | scenes = { 86 | "subplot_title": { 87 | "mesh_trace_title": mesh, 88 | "pointcloud_trace_title": pointcloud, 89 | "cameras_trace_title": camera 90 | } 91 | } 92 | else: 93 | scenes = { 94 | "subplot_title": { 95 | "mesh_trace_title": mesh, 96 | "cameras_trace_title": camera 97 | } 98 | } 99 | 100 | from pytorch3d.vis.plotly_vis import plot_scene 101 | fig = plot_scene(scenes, viewpoint_cameras=camera) 102 | #fig = plot_scene(scenes) 103 | fig.show() -------------------------------------------------------------------------------- /Engine/th_utils/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/networks/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/networks/base_module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/networks/base_module.py -------------------------------------------------------------------------------- /Engine/th_utils/networks/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/networks/discriminator/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/networks/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | #from taming.modules.util import ActNorm 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('BatchNorm') != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class NLayerDiscriminator(nn.Module): 17 | """Defines a PatchGAN discriminator as in Pix2Pix 18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 19 | """ 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm2d 31 | else: 32 | norm_layer = ActNorm 33 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 34 | use_bias = norm_layer.func != nn.BatchNorm2d 35 | else: 36 | use_bias = norm_layer != nn.BatchNorm2d 37 | 38 | kw = 4 39 | padw = 1 40 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 41 | nf_mult = 1 42 | nf_mult_prev = 1 43 | for n in range(1, n_layers): # gradually increase the 
number of filters 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2 ** n_layers, 8) 54 | sequence += [ 55 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 56 | norm_layer(ndf * nf_mult), 57 | nn.LeakyReLU(0.2, True) 58 | ] 59 | 60 | sequence += [ 61 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 62 | self.main = nn.Sequential(*sequence) 63 | 64 | def forward(self, input): 65 | """Standard forward.""" 66 | return self.main(input) 67 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/discriminator/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | #from Eng.modules.losses.lpips import LPIPS 6 | #from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | 8 | from Engine.th_utils.networks.loss_vqgan.lpips import LPIPS 9 | from Engine.th_utils.networks.discriminator.model import NLayerDiscriminator, weights_init 10 | 11 | class DummyLoss(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | 16 | def adopt_weight(weight, global_step, threshold=0, value=0.): 17 | if global_step < threshold: 18 | weight = value 19 | return weight 20 | 21 | 22 | def hinge_d_loss(logits_real, logits_fake): 23 | loss_real = torch.mean(F.relu(1. - logits_real)) 24 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 25 | d_loss = 0.5 * (loss_real + loss_fake) 26 | return d_loss, loss_real, loss_fake 27 | 28 | 29 | def vanilla_d_loss(logits_real, logits_fake): 30 | d_loss = 0.5 * ( 31 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 32 | torch.mean(torch.nn.functional.softplus(logits_fake))) 33 | return d_loss 34 | 35 | 36 | class VQLPIPSWithDiscriminator(nn.Module): 37 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 38 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 39 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 40 | disc_ndf=64, disc_loss="hinge"): 41 | super().__init__() 42 | assert disc_loss in ["hinge", "vanilla"] 43 | self.codebook_weight = codebook_weight 44 | self.pixel_weight = pixelloss_weight 45 | self.perceptual_loss = LPIPS().eval() 46 | self.perceptual_weight = perceptual_weight 47 | 48 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 49 | n_layers=disc_num_layers, 50 | use_actnorm=use_actnorm, 51 | ndf=disc_ndf 52 | ).apply(weights_init) 53 | self.discriminator_iter_start = disc_start 54 | if disc_loss == "hinge": 55 | self.disc_loss = hinge_d_loss 56 | elif disc_loss == "vanilla": 57 | self.disc_loss = vanilla_d_loss 58 | else: 59 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 60 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 61 | self.disc_factor = disc_factor 62 | self.discriminator_weight = disc_weight 63 | self.disc_conditional = disc_conditional 64 | 65 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 66 | if last_layer is not None: 67 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 68 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 69 | else: 70 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 71 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 72 | 73 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 74 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 75 | d_weight = d_weight * self.discriminator_weight 76 | return d_weight 77 | 78 | def forward(self, inputs, reconstructions, optimizer_idx, 79 | global_step, last_layer=None, cond=0, split="train", img_mask = None): 80 | 81 | if cond != 0: 82 | inputs, _ = torch.split(inputs, [inputs.shape[1] - cond, cond], dim=1) 83 | reconstructions, cinput_ = torch.split(reconstructions, [reconstructions.shape[1] - cond, cond], dim=1) 84 | cond_input = cinput_ 85 | 86 | if img_mask is not None: 87 | img_mask_ = img_mask.repeat(1, int(inputs.shape[1] // img_mask.shape[1]), 1, 1) 88 | pixel_loss = torch.abs(inputs.contiguous() * img_mask_ - reconstructions.contiguous() * img_mask_) * self.pixel_weight 89 | else: 90 | pixel_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) * self.pixel_weight 91 | 92 | if self.perceptual_weight > 0: 93 | #percept_loss = [] 94 | percept_loss = torch.tensor([0.0]).to(inputs.device) 95 | i = 0 96 | while i < inputs.shape[1]: 97 | #percept_loss.append(self.perceptual_loss(inputs.contiguous()[:,i:i+3,...], reconstructions.contiguous()[:,i:i+3,...])) 98 | percept_loss += (self.perceptual_loss(inputs.contiguous()[:,i:i+3,...], reconstructions.contiguous()[:,i:i+3,...]))[0][0][0] 99 | i += 3 100 | percept_loss /= (inputs.shape[1] // 3) 101 | percept_loss = self.perceptual_weight * percept_loss #torch.mean(percept_loss) 102 | else: 103 | percept_loss = torch.tensor([0.0]) 104 | 
#rec_loss = pixel_loss 105 | 106 | nll_loss = pixel_loss.to(inputs.device) + percept_loss.to(inputs.device) 107 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 108 | nll_loss = torch.mean(nll_loss) 109 | 110 | # now the GAN part 111 | if optimizer_idx == 0: 112 | # generator update 113 | if cond is None or cond == 0: 114 | assert not self.disc_conditional 115 | logits_fake = self.discriminator(reconstructions.contiguous()) 116 | else: 117 | assert self.disc_conditional 118 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond_input), dim=1)) 119 | g_loss = -torch.mean(logits_fake) 120 | 121 | if last_layer is None: 122 | d_weight = 1.0 123 | else: 124 | try: 125 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 126 | except RuntimeError: 127 | assert not self.training 128 | d_weight = torch.tensor(0.0) 129 | 130 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 131 | g_gan_loss = d_weight * disc_factor * g_loss #+ self.codebook_weight * codebook_loss.mean() 132 | 133 | if False: 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | #"{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | return pixel_loss.mean(), percept_loss.mean(), g_gan_loss, d_weight 144 | return loss, log 145 | 146 | if optimizer_idx == 1: 147 | # second pass for discriminator update 148 | if cond is None or cond == 0: 149 | logits_real = self.discriminator(inputs.contiguous().detach()) 150 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 151 | else: 152 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond_input), dim=1)) 153 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond_input), dim=1)) 154 | 155 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 156 | d_loss, real_loss, fake_loss = self.disc_loss(logits_real, logits_fake) 157 | d_loss *= disc_factor 158 | 159 | return d_loss, real_loss, fake_loss 160 | 161 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 162 | "{}/logits_real".format(split): logits_real.detach().mean(), 163 | "{}/logits_fake".format(split): logits_fake.detach().mean() 164 | } 165 | return d_loss, log 166 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Embedder: 4 | def __init__(self, **kwargs): 5 | self.kwargs = kwargs 6 | self.create_embedding_fn() 7 | 8 | def create_embedding_fn(self): 9 | embed_fns = [] 10 | d = self.kwargs['input_dims'] 11 | out_dim = 0 12 | if self.kwargs['include_input']: 13 | embed_fns.append(lambda x: x) 14 | out_dim += d 15 | 16 | max_freq = self.kwargs['max_freq_log2'] 17 | N_freqs = self.kwargs['num_freqs'] 18 | 19 | if self.kwargs['log_sampling']: 20 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 21 | else: 22 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 23 | 24 | for freq in 
freq_bands: 25 | for p_fn in self.kwargs['periodic_fns']: 26 | embed_fns.append( 27 | lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 28 | out_dim += d 29 | 30 | self.embed_fns = embed_fns 31 | self.out_dim = out_dim 32 | 33 | def embed(self, inputs): 34 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 35 | 36 | 37 | def get_embedder(multires, input_dims=3): 38 | embed_kwargs = { 39 | 'include_input': True, 40 | 'input_dims': input_dims, 41 | 'max_freq_log2': multires - 1, 42 | 'num_freqs': multires, 43 | 'log_sampling': True, 44 | 'periodic_fns': [torch.sin, torch.cos], 45 | } 46 | embedder_obj = Embedder(**embed_kwargs) 47 | embed = lambda x, eo=embedder_obj: eo.embed(x) 48 | return embed, embedder_obj.out_dim 49 | 50 | xyz_res = 10 51 | view_res = 4 52 | pose_res = 1 53 | 54 | xyz_embedder, xyz_dim = get_embedder(xyz_res) 55 | view_embedder, view_dim = get_embedder(view_res) 56 | pose_embedder, pose_dim = get_embedder(pose_res, input_dims=72) 57 | shape_embedder, shape_dim = get_embedder(2, input_dims=10) 58 | 59 | uvhn_embedder, uvhn_dim = get_embedder(6, input_dims=6) 60 | uvh_embedder, uvh_dim = get_embedder(6, input_dims=3) 61 | uv_embedder, uv_dim = get_embedder(6, input_dims=2) 62 | 63 | #pose_embedder_2, pose_dim_2 = get_embedder(2, input_dims=72) 64 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/load.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import defaultdict 4 | from typing import Any, cast, Dict, IO, Iterable, List, NamedTuple, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from iopath.common.file_io import HTTPURLHandler, PathManager 10 | from termcolor import colored 11 | from torch.nn.parallel import DataParallel, DistributedDataParallel 12 | 13 | 14 | TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) 15 | if TORCH_VERSION >= (1, 11): 16 | from torch.ao import quantization 17 | from torch.ao.quantization import FakeQuantizeBase, ObserverBase 18 | elif ( 19 | TORCH_VERSION >= (1, 8) 20 | and hasattr(torch.quantization, "FakeQuantizeBase") 21 | and hasattr(torch.quantization, "ObserverBase") 22 | ): 23 | from torch import quantization 24 | from torch.quantization import FakeQuantizeBase, ObserverBase 25 | 26 | __all__ = ["Checkpointer", "PeriodicCheckpointer"] 27 | 28 | 29 | TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) 30 | 31 | 32 | class _IncompatibleKeys( 33 | NamedTuple( 34 | "IncompatibleKeys", 35 | [ 36 | ("missing_keys", List[str]), 37 | ("unexpected_keys", List[str]), 38 | ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), 39 | ], 40 | ) 41 | ): 42 | pass 43 | 44 | 45 | def _load_model(net, checkpoint_state_dict) -> _IncompatibleKeys: 46 | """ 47 | Load weights from a checkpoint. 48 | 49 | Args: 50 | checkpoint (Any): checkpoint contains the weights. 51 | 52 | Returns: 53 | ``NamedTuple`` with ``missing_keys``, ``unexpected_keys``, 54 | and ``incorrect_shapes`` fields: 55 | * **missing_keys** is a list of str containing the missing keys 56 | * **unexpected_keys** is a list of str containing the unexpected keys 57 | * **incorrect_shapes** is a list of (key, shape in checkpoint, shape in model) 58 | 59 | This is just like the return value of 60 | :func:`torch.nn.Module.load_state_dict`, but with extra support 61 | for ``incorrect_shapes``. 
62 | """ 63 | #checkpoint_state_dict = checkpoint.pop("model") 64 | #self._convert_ndarray_to_tensor(checkpoint_state_dict) 65 | 66 | # if the state_dict comes from a model that was wrapped in a 67 | # DataParallel or DistributedDataParallel during serialization, 68 | # remove the "module" prefix before performing the matching. 69 | #_strip_prefix_if_present(checkpoint_state_dict, "module.") 70 | 71 | # workaround https://github.com/pytorch/pytorch/issues/24139 72 | model_state_dict = net.state_dict() 73 | 74 | #checkpoint_state_dict = checkpoint #.pop("model") 75 | 76 | incorrect_shapes = [] 77 | for k in list(checkpoint_state_dict.keys()): 78 | if k in model_state_dict: 79 | model_param = model_state_dict[k] 80 | # Allow mismatch for uninitialized parameters 81 | if TORCH_VERSION >= (1, 8) and isinstance( 82 | model_param, nn.parameter.UninitializedParameter 83 | ): 84 | continue 85 | shape_model = tuple(model_param.shape) 86 | shape_checkpoint = tuple(checkpoint_state_dict[k].shape) 87 | if shape_model != shape_checkpoint: 88 | 89 | has_observer_base_classes = ( 90 | TORCH_VERSION >= (1, 8) 91 | and hasattr(quantization, "ObserverBase") 92 | and hasattr(quantization, "FakeQuantizeBase") 93 | ) 94 | if has_observer_base_classes: 95 | # Handle the special case of quantization per channel observers, 96 | # where buffer shape mismatches are expected. 97 | def _get_module_for_key( 98 | model: torch.nn.Module, key: str 99 | ) -> torch.nn.Module: 100 | # foo.bar.param_or_buffer_name -> [foo, bar] 101 | key_parts = key.split(".")[:-1] 102 | cur_module = model 103 | for key_part in key_parts: 104 | cur_module = getattr(cur_module, key_part) 105 | return cur_module 106 | 107 | cls_to_skip = ( 108 | ObserverBase, 109 | FakeQuantizeBase, 110 | ) 111 | target_module = _get_module_for_key(net, k) 112 | if isinstance(target_module, cls_to_skip): 113 | # Do not remove modules with expected shape mismatches 114 | # them from the state_dict loading. They have special logic 115 | # in _load_from_state_dict to handle the mismatches. 116 | continue 117 | 118 | incorrect_shapes.append((k, shape_checkpoint, shape_model)) 119 | checkpoint_state_dict.pop(k) 120 | incompatible = net.load_state_dict(checkpoint_state_dict, strict=False) 121 | print("!!!! incorrect shapes ", incorrect_shapes) 122 | return len(incorrect_shapes) 123 | 124 | return _IncompatibleKeys( 125 | missing_keys=incompatible.missing_keys, 126 | unexpected_keys=incompatible.unexpected_keys, 127 | incorrect_shapes=incorrect_shapes, 128 | ) 129 | 130 | 131 | def update_network(net, prefix="criterion"): 132 | from collections import OrderedDict 133 | net_ = OrderedDict() 134 | for k in net.keys(): 135 | if k.startswith(prefix): 136 | continue 137 | 138 | net_[k] = net[k] 139 | return net_ 140 | """ 141 | Log information about the incompatible keys returned by ``_load_model``. 142 | """ 143 | for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes: 144 | self.logger.warning( 145 | "Skip loading parameter '{}' to the model due to incompatible " 146 | "shapes: {} in the checkpoint but {} in the " 147 | "model! 
You might want to double check if this is expected.".format( 148 | k, shape_checkpoint, shape_model 149 | ) 150 | ) 151 | if incompatible.missing_keys: 152 | missing_keys = _filter_reused_missing_keys( 153 | self.model, incompatible.missing_keys 154 | ) 155 | if missing_keys: 156 | self.logger.warning(get_missing_parameters_message(missing_keys)) 157 | if incompatible.unexpected_keys: 158 | self.logger.warning( 159 | get_unexpected_parameters_message(incompatible.unexpected_keys) 160 | ) -------------------------------------------------------------------------------- /Engine/th_utils/networks/loss_vqgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/networks/loss_vqgan/__init__.py -------------------------------------------------------------------------------- /Engine/th_utils/networks/loss_vqgan/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | #from .taming_util import get_ckpt_path 9 | from Engine.th_utils.networks.loss_vqgan.taming_util import get_ckpt_path 10 | 11 | 12 | 13 | class LPIPS(nn.Module): 14 | # Learned perceptual metric 15 | def __init__(self, use_dropout=True): 16 | super().__init__() 17 | self.scaling_layer = ScalingLayer() 18 | self.chns = [64, 128, 256, 512, 512] # vg16 features 19 | self.net = vgg16(pretrained=True, requires_grad=False) 20 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 21 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 22 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 23 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 24 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 25 | self.load_from_pretrained() 26 | for param in self.parameters(): 27 | param.requires_grad = False 28 | 29 | def load_from_pretrained(self, name="vgg_lpips"): 30 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 31 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 32 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 33 | 34 | @classmethod 35 | def from_pretrained(cls, name="vgg_lpips"): 36 | if name != "vgg_lpips": 37 | raise NotImplementedError 38 | model = cls() 39 | ckpt = get_ckpt_path(name) 40 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 41 | return model 42 | 43 | def forward(self, input, target): 44 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 45 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 46 | feats0, feats1, diffs = {}, {}, {} 47 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 48 | for kk in range(len(self.chns)): 49 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 50 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 51 | 52 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 53 | val = res[0] 54 | for l in range(1, len(self.chns)): 55 | val += res[l] 56 | return val 57 | 58 | 59 | class ScalingLayer(nn.Module): 60 | def __init__(self): 61 | super(ScalingLayer, self).__init__() 62 | self.register_buffer('shift', 
torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 63 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 64 | 65 | def forward(self, inp): 66 | return (inp - self.shift) / self.scale 67 | 68 | 69 | class NetLinLayer(nn.Module): 70 | """ A single linear layer which does a 1x1 conv """ 71 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 72 | super(NetLinLayer, self).__init__() 73 | layers = [nn.Dropout(), ] if (use_dropout) else [] 74 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 75 | self.model = nn.Sequential(*layers) 76 | 77 | 78 | class vgg16(torch.nn.Module): 79 | def __init__(self, requires_grad=False, pretrained=True): 80 | super(vgg16, self).__init__() 81 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 82 | self.slice1 = torch.nn.Sequential() 83 | self.slice2 = torch.nn.Sequential() 84 | self.slice3 = torch.nn.Sequential() 85 | self.slice4 = torch.nn.Sequential() 86 | self.slice5 = torch.nn.Sequential() 87 | self.N_slices = 5 88 | for x in range(4): 89 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(4, 9): 91 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(9, 16): 93 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(16, 23): 95 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 96 | for x in range(23, 30): 97 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 98 | if not requires_grad: 99 | for param in self.parameters(): 100 | param.requires_grad = False 101 | 102 | def forward(self, X): 103 | h = self.slice1(X) 104 | h_relu1_2 = h 105 | h = self.slice2(h) 106 | h_relu2_2 = h 107 | h = self.slice3(h) 108 | h_relu3_3 = h 109 | h = self.slice4(h) 110 | h_relu4_3 = h 111 | h = self.slice5(h) 112 | h_relu5_3 = h 113 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 114 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 115 | return out 116 | 117 | 118 | def normalize_tensor(x,eps=1e-10): 119 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 120 | return x/(norm_factor+eps) 121 | 122 | 123 | def spatial_average(x, keepdim=True): 124 | return x.mean([2,3],keepdim=keepdim) 125 | 126 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/loss_vqgan/taming_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = 
os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 
109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | class ContrastiveLoss(nn.Module): 9 | """ 10 | Contrastive loss 11 | Takes embeddings of two samples and a target label == 1 if samples are from the same class and label == 0 otherwise 12 | """ 13 | 14 | def __init__(self, margin = 0.1): 15 | super(ContrastiveLoss, self).__init__() 16 | self.margin = margin 17 | self.eps = 1e-9 18 | 19 | #by default it treats as if it is from the same identity 20 | def forward(self, output1, output2, target = 1, size_average=True): 21 | distances = (output2 - output1).pow(2).sum(1) # squared distances 22 | losses = 0.5 * (target * distances + 23 | (1 + -1 * target) * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2)) 24 | return losses.mean() if size_average else losses.sum() 25 | 26 | 27 | class TripletLoss(nn.Module): 28 | 29 | """ 30 | Triplet loss 31 | Takes embeddings of an anchor sample, a positive sample and a negative sample 32 | """ 33 | 34 | def __init__(self, margin): 35 | super(TripletLoss, self).__init__() 36 | self.margin = margin 37 | 38 | def forward(self, anchor, positive, negative, size_average=True): 39 | distance_positive = (anchor - positive).pow(2).sum(1) # .pow(.5) 40 | distance_negative = (anchor - negative).pow(2).sum(1) # .pow(.5) 41 | losses = F.relu(distance_positive - distance_negative + self.margin) 42 | return losses.mean() if size_average else losses.sum() 43 | 44 | class TotalVariation(nn.Module): 45 | r"""Computes the Total Variation according to [1]. 46 | 47 | Shape: 48 | - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. 49 | - Output: :math:`(N,)` or scalar. 
50 | 51 | Examples: 52 | >>> tv = TotalVariation() 53 | >>> output = tv(torch.ones((2, 3, 4, 4), requires_grad=True)) 54 | >>> output.data 55 | tensor([0., 0.]) 56 | >>> output.sum().backward() # grad can be implicitly created only for scalar outputs 57 | 58 | Reference: 59 | [1] https://en.wikipedia.org/wiki/Total_variation 60 | """ 61 | 62 | def __init__(self, direction="hw"): 63 | super(TotalVariation, self).__init__() 64 | self.direction = direction 65 | 66 | def forward(self, img) -> torch.Tensor: 67 | if self.direction == "hw": 68 | return self.total_variation_hw(img) 69 | elif self.direction == "h": 70 | return self.total_variation_h(img) 71 | elif self.direction == "w": 72 | return self.total_variation_w(img) 73 | 74 | def total_variation_h(self, img: torch.Tensor) -> torch.Tensor: 75 | r"""Function that computes Total Variation according to [1]. 76 | 77 | Args: 78 | img: the input image with shape :math:`(N, C, H, W)` or :math:`(C, H, W)`. 79 | 80 | Return: 81 | a scalar with the computer loss. 82 | 83 | Examples: 84 | >>> total_variation(torch.ones(3, 4, 4)) 85 | tensor(0.) 86 | 87 | Reference: 88 | [1] https://en.wikipedia.org/wiki/Total_variation 89 | """ 90 | if not isinstance(img, torch.Tensor): 91 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(img)}") 92 | 93 | if len(img.shape) < 3 or len(img.shape) > 4: 94 | raise ValueError(f"Expected input tensor to be of ndim 3 or 4, but got {len(img.shape)}.") 95 | 96 | pixel_dif1 = img[..., 1:, :] - img[..., :-1, :] 97 | 98 | reduce_axes = (-3, -2, -1) 99 | res1 = pixel_dif1.abs().sum(dim=reduce_axes) 100 | 101 | return res1 102 | 103 | def total_variation_w(self, img: torch.Tensor) -> torch.Tensor: 104 | r"""Function that computes Total Variation according to [1]. 105 | 106 | Args: 107 | img: the input image with shape :math:`(N, C, H, W)` or :math:`(C, H, W)`. 108 | 109 | Return: 110 | a scalar with the computer loss. 111 | 112 | Examples: 113 | >>> total_variation(torch.ones(3, 4, 4)) 114 | tensor(0.) 115 | 116 | Reference: 117 | [1] https://en.wikipedia.org/wiki/Total_variation 118 | """ 119 | if not isinstance(img, torch.Tensor): 120 | raise TypeError(f"Input type is not a torch.Tensor. Got {type(img)}") 121 | 122 | if len(img.shape) < 3 or len(img.shape) > 4: 123 | raise ValueError(f"Expected input tensor to be of ndim 3 or 4, but got {len(img.shape)}.") 124 | 125 | pixel_dif2 = img[..., :, 1:] - img[..., :, :-1] 126 | 127 | reduce_axes = (-3, -2, -1) 128 | res2 = pixel_dif2.abs().sum(dim=reduce_axes) 129 | 130 | return res2 131 | 132 | def total_variation_hw(self, img: torch.Tensor) -> torch.Tensor: 133 | r"""Function that computes Total Variation according to [1]. 134 | 135 | Args: 136 | img: the input image with shape :math:`(N, C, H, W)` or :math:`(C, H, W)`. 137 | 138 | Return: 139 | a scalar with the computer loss. 140 | 141 | Examples: 142 | >>> total_variation(torch.ones(3, 4, 4)) 143 | tensor(0.) 144 | 145 | Reference: 146 | [1] https://en.wikipedia.org/wiki/Total_variation 147 | """ 148 | if not isinstance(img, torch.Tensor): 149 | raise TypeError(f"Input type is not a torch.Tensor. 
Got {type(img)}") 150 | 151 | if len(img.shape) < 3 or len(img.shape) > 4: 152 | raise ValueError(f"Expected input tensor to be of ndim 3 or 4, but got {len(img.shape)}.") 153 | 154 | pixel_dif1 = img[..., 1:, :] - img[..., :-1, :] 155 | pixel_dif2 = img[..., :, 1:] - img[..., :, :-1] 156 | 157 | reduce_axes = (-3, -2, -1) 158 | res1 = pixel_dif1.abs().sum(dim=reduce_axes) 159 | res2 = pixel_dif2.abs().sum(dim=reduce_axes) 160 | 161 | return res1 + res2 162 | 163 | def Total_variation_loss(y): 164 | tv = ( 165 | torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + 166 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])) 167 | ) 168 | 169 | tv /= (2*y.nelement()) 170 | 171 | return tv 172 | 173 | class VGGLoss(nn.Module): 174 | def __init__(self, gpu_ids=0): 175 | super(VGGLoss, self).__init__() 176 | self.vgg = Vgg19().cuda(gpu_ids) 177 | self.criterion = nn.L1Loss() 178 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 179 | 180 | def forward(self, x, y): 181 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 182 | loss = 0 183 | for i in range(len(x_vgg)): 184 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 185 | return loss 186 | 187 | from torchvision import models 188 | class Vgg19(torch.nn.Module): 189 | def __init__(self, requires_grad=False): 190 | super(Vgg19, self).__init__() 191 | vgg_pretrained_features = models.vgg19(pretrained=True).features 192 | self.slice1 = torch.nn.Sequential() 193 | self.slice2 = torch.nn.Sequential() 194 | self.slice3 = torch.nn.Sequential() 195 | self.slice4 = torch.nn.Sequential() 196 | self.slice5 = torch.nn.Sequential() 197 | for x in range(2): 198 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 199 | for x in range(2, 7): 200 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 201 | for x in range(7, 12): 202 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 203 | for x in range(12, 21): 204 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 205 | for x in range(21, 30): 206 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 207 | if not requires_grad: 208 | for param in self.parameters(): 209 | param.requires_grad = False 210 | 211 | def forward(self, X): 212 | h_relu1 = self.slice1(X) 213 | h_relu2 = self.slice2(h_relu1) 214 | h_relu3 = self.slice3(h_relu2) 215 | h_relu4 = self.slice4(h_relu3) 216 | h_relu5 = self.slice5(h_relu4) 217 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 218 | return out 219 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/nerf_net_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | from lib.config import cfg 5 | 6 | 7 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, cfg = None): 8 | """Transforms model's predictions to semantically meaningful values. 9 | Args: 10 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 11 | z_vals: [num_rays, num_samples along ray]. Integration time. 12 | rays_d: [num_rays, 3]. Direction of each ray. 13 | Returns: 14 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 15 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 16 | acc_map: [num_rays]. Sum of weights along each ray. 17 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 18 | depth_map: [num_rays]. Estimated distance to object. 19 | """ 20 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. 
- torch.exp(-act_fn(raw) * 21 | dists) 22 | 23 | dists = z_vals[..., 1:] - z_vals[..., :-1] 24 | dists = torch.cat( 25 | [dists, 26 | torch.Tensor([1e10]).expand(dists[..., :1].shape).to(dists)], 27 | -1) # [N_rays, N_samples] 28 | 29 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 30 | 31 | rgb = torch.sigmoid(raw[..., 1:]) # [N_rays, N_samples, 3] 32 | #rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 33 | noise = 0. 34 | if raw_noise_std > 0.: 35 | noise = torch.randn(raw[..., 0].shape) * raw_noise_std 36 | 37 | alpha_raw = raw2alpha(raw[..., 0] + noise, dists) # [N_rays, N_samples] 38 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 39 | 40 | if (not cfg.no_local_nerf) and cfg.use_density_th: 41 | alpha = torch.sigmoid(F.relu(alpha_raw - cfg.density_th) * cfg.par_density_norm) 42 | else: 43 | alpha = alpha_raw 44 | 45 | cum = torch.cumprod( 46 | torch.cat( 47 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 48 | -1), -1)[:, :-1] 49 | 50 | weights = alpha * torch.cumprod( 51 | torch.cat( 52 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 53 | -1), -1)[:, :-1] 54 | 55 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 56 | 57 | depth_map = torch.sum(weights * z_vals, -1) 58 | 59 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), 60 | depth_map / torch.sum(weights, -1)) 61 | acc_map = torch.sum(weights, -1) 62 | 63 | if white_bkgd: 64 | rgb_map = rgb_map + (1. - acc_map[..., None]) 65 | 66 | return rgb_map, disp_map, acc_map, weights, depth_map, cum[...,-1] 67 | 68 | 69 | # Hierarchical sampling (section 5.2) 70 | def sample_pdf(bins, weights, N_samples, det=False): 71 | from torchsearchsorted import searchsorted 72 | 73 | # Get pdf 74 | weights = weights + 1e-5 # prevent nans 75 | pdf = weights / torch.sum(weights, -1, keepdim=True) 76 | cdf = torch.cumsum(pdf, -1) 77 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], 78 | -1) # (batch, len(bins)) 79 | 80 | # Take uniform samples 81 | if det: 82 | u = torch.linspace(0., 1., steps=N_samples).to(cdf) 83 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 84 | else: 85 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]).to(cdf) 86 | 87 | # Invert CDF 88 | u = u.contiguous() 89 | inds = searchsorted(cdf, u, side='right') 90 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 91 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 92 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 93 | 94 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 95 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 96 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 97 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 98 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 99 | 100 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 101 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 102 | t = (u - cdf_g[..., 0]) / denom 103 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 104 | 105 | return samples 106 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/nerf_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/Engine/th_utils/networks/nerf_util/__init__.py 
-------------------------------------------------------------------------------- /Engine/th_utils/networks/nerf_util/base_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | def read_pickle(pkl_path): 7 | with open(pkl_path, 'rb') as f: 8 | return pickle.load(f) 9 | 10 | 11 | def save_pickle(data, pkl_path): 12 | os.system('mkdir -p {}'.format(os.path.dirname(pkl_path))) 13 | with open(pkl_path, 'wb') as f: 14 | pickle.dump(data, f) 15 | 16 | 17 | def project(xyz, K, RT): 18 | """ 19 | xyz: [N, 3] 20 | K: [3, 3] 21 | RT: [3, 4] 22 | """ 23 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 24 | xyz = np.dot(xyz, K.T) 25 | xy = xyz[:, :2] / xyz[:, 2:] 26 | return xy 27 | 28 | def project_torch(xyz, K, RT): 29 | """ 30 | xyz: [N, 3] 31 | K: [3, 3] 32 | RT: [3, 4] 33 | """ 34 | #xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 35 | #xyz = np.dot(xyz, K.T) 36 | 37 | xyz = torch.mm(xyz, RT[:, :3].T) + RT[:, 3:].T 38 | xyz = torch.mm(xyz, K.T) 39 | xy = xyz[:, :2] / xyz[:, 2:] 40 | 41 | 42 | return xy 43 | 44 | 45 | 46 | def write_K_pose_inf(K, poses, img_root): 47 | K = K.copy() 48 | K[:2] = K[:2] * 8 49 | K_inf = os.path.join(img_root, 'Intrinsic.inf') 50 | os.system('mkdir -p {}'.format(os.path.dirname(K_inf))) 51 | with open(K_inf, 'w') as f: 52 | for i in range(len(poses)): 53 | f.write('%d\n'%i) 54 | f.write('%f %f %f\n %f %f %f\n %f %f %f\n' % tuple(K.reshape(9).tolist())) 55 | f.write('\n') 56 | 57 | pose_inf = os.path.join(img_root, 'CamPose.inf') 58 | with open(pose_inf, 'w') as f: 59 | for pose in poses: 60 | pose = np.linalg.inv(pose) 61 | A = pose[0:3,:] 62 | tmp = np.concatenate([A[0:3,2].T, A[0:3,0].T,A[0:3,1].T,A[0:3,3].T]) 63 | f.write('%f %f %f %f %f %f %f %f %f %f %f %f\n' % tuple(tmp.tolist())) 64 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/nerf_util/config.py: -------------------------------------------------------------------------------- 1 | #import open3d as o3d 2 | from .yacs import CfgNode as CN 3 | import numpy as np 4 | 5 | cfg = CN() 6 | 7 | # data 8 | cfg.body_sample_ratio = 0.5 9 | cfg.face_sample_ratio = 0. 10 | cfg.rot_ratio = 0. 11 | cfg.rot_range = np.pi / 32 12 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/nerf_util/nerf_net_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | from Engine.th_utils.io.prints import * 5 | 6 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, cfg = None): 7 | """Transforms model's predictions to semantically meaningful values. 8 | Args: 9 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 10 | z_vals: [num_rays, num_samples along ray]. Integration time. 11 | rays_d: [num_rays, 3]. Direction of each ray. 12 | Returns: 13 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 14 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 15 | acc_map: [num_rays]. Sum of weights along each ray. 16 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 17 | depth_map: [num_rays]. Estimated distance to object. 18 | """ 19 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. 
- torch.exp(-act_fn(raw) * 20 | dists) 21 | 22 | 23 | dists = z_vals[..., 1:] - z_vals[..., :-1] 24 | 25 | dists = torch.cat( 26 | [dists, 27 | torch.Tensor([1e10]).expand(dists[..., :1].shape).to(dists)], 28 | -1) # [N_rays, N_samples] 29 | 30 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 31 | 32 | rgb = torch.sigmoid(raw[..., 1:]) # [N_rays, N_samples, 3] 33 | #rgb = raw[..., 1:] # [N_rays, N_samples, 3] 34 | 35 | #rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 36 | noise = 0. 37 | if raw_noise_std > 0.: 38 | noise = torch.randn(raw[..., 0].shape) * raw_noise_std 39 | 40 | alpha_raw = raw2alpha(raw[..., 0] + noise, dists) # [N_rays, N_samples] 41 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 42 | 43 | if (not cfg.no_local_nerf) and cfg.use_density_th: 44 | alpha = torch.sigmoid(F.relu(alpha_raw - cfg.density_th) * cfg.par_density_norm) 45 | else: 46 | alpha = alpha_raw 47 | 48 | cum = torch.cumprod( 49 | torch.cat( 50 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 51 | -1), -1)[:, :-1] 52 | 53 | weights = alpha * torch.cumprod( 54 | torch.cat( 55 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 56 | -1), -1)[:, :-1] 57 | 58 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 59 | 60 | depth_map = torch.sum(weights * z_vals, -1) 61 | 62 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), 63 | depth_map / torch.sum(weights, -1)) 64 | acc_map = torch.sum(weights, -1) 65 | 66 | if white_bkgd: 67 | rgb_map = rgb_map + (1. - acc_map[..., None]) 68 | 69 | return rgb_map, disp_map, acc_map, weights, depth_map, cum[...,-1] 70 | 71 | 72 | # Hierarchical sampling (section 5.2) 73 | def sample_pdf(bins, weights, N_samples, det=False): 74 | from torchsearchsorted import searchsorted 75 | 76 | # Get pdf 77 | weights = weights + 1e-5 # prevent nans 78 | pdf = weights / torch.sum(weights, -1, keepdim=True) 79 | cdf = torch.cumsum(pdf, -1) 80 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], 81 | -1) # (batch, len(bins)) 82 | 83 | # Take uniform samples 84 | if det: 85 | u = torch.linspace(0., 1., steps=N_samples).to(cdf) 86 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 87 | else: 88 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]).to(cdf) 89 | 90 | # Invert CDF 91 | u = u.contiguous() 92 | inds = searchsorted(cdf, u, side='right') 93 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 94 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 95 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 96 | 97 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 98 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 99 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 100 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 101 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 102 | 103 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 104 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 105 | t = (u - cdf_g[..., 0]) / denom 106 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 107 | 108 | return samples 109 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch import nn 4 | import numpy as np 5 | 
import torch.nn.functional 6 | from collections import OrderedDict 7 | from termcolor import colored 8 | 9 | from collections import OrderedDict 10 | def divide_network_params(net, net1_prefix="uv_embedding_list"): 11 | net_1 = OrderedDict() 12 | net_2 = OrderedDict() 13 | for k in net.keys(): 14 | if k.startswith(net1_prefix): 15 | net_1[k] = net[k] 16 | continue 17 | net_2[k] = net[k] 18 | return net_1, net_2 19 | 20 | def print_key(net): 21 | par1 = dict(net.named_parameters()) 22 | for k in par1.keys(): 23 | print(k) 24 | 25 | def check_int(s): 26 | if s[0] in ('-', '+'): 27 | return s[1:].isdigit() 28 | return s.isdigit() 29 | 30 | def get_epoch(model_dir, epoch= ""): 31 | 32 | if not os.path.exists(model_dir): 33 | return -1 34 | 35 | tgt_model_name = 'latest' 36 | pth_file = "%s.pth" % tgt_model_name 37 | 38 | pths = [ 39 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 40 | if pth != 'latest.pth' and pth.find('.pth')!=-1 and check_int(pth.split('.')[0]) 41 | ] 42 | if len(pths) == 0 and pth_file not in os.listdir(model_dir): 43 | return -1 44 | 45 | if epoch == "" or epoch == "-1": 46 | if pth_file in os.listdir(model_dir): 47 | pth = tgt_model_name 48 | else: 49 | pth = max(pths) 50 | else: 51 | pth = epoch 52 | if not os.path.isfile(os.path.join(model_dir, '{}.pth'.format(pth))): 53 | return -1 54 | print('load model: {}'.format(os.path.join(model_dir, '{}.pth'.format(pth)))) 55 | pretrained_model = torch.load( 56 | os.path.join(model_dir, '{}.pth'.format(pth)), 'cuda') 57 | 58 | return pretrained_model['epoch'] 59 | 60 | def load_model(model_dir, 61 | net, 62 | opt_G = None, opt_D = None, 63 | scheduler = None, 64 | recorder = None, 65 | resume = True, 66 | epoch= ""): 67 | 68 | 69 | print(os.path.exists(model_dir), model_dir) 70 | if not os.path.exists(model_dir): 71 | return -1, 0, 0 72 | 73 | tgt_model_name = 'latest' 74 | pth_file = "%s.pth" % tgt_model_name 75 | 76 | is_estimate = False 77 | 78 | if epoch.startswith("ema"): #test 79 | is_estimate = True 80 | pths = [ 81 | int(pth.split('.')[0].split('_')[1]) for pth in os.listdir(model_dir) 82 | if pth != 'ema_latest.pth' and pth.find('.pth')!=-1 and check_int(pth.split('.')[0].split('_')[1]) 83 | ] 84 | else: 85 | pths = [ 86 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 87 | if pth != 'latest.pth' and pth.find('.pth')!=-1 and check_int(pth.split('.')[0]) 88 | ] 89 | 90 | if len(pths) == 0 and pth_file not in os.listdir(model_dir): 91 | return -1, 0, 0 92 | 93 | if epoch == "" or epoch == "-1" or epoch.find("latest")!=-1 : 94 | if pth_file in os.listdir(model_dir): 95 | pth = tgt_model_name 96 | else: 97 | pth = max(pths) 98 | else: 99 | pth = epoch 100 | 101 | if is_estimate: pth = "ema_%s" % pth 102 | if not os.path.isfile(os.path.join(model_dir,'{}.pth'.format(pth))): 103 | return -1, 0, 0 104 | 105 | print('load model: {}'.format(os.path.join(model_dir, 106 | '{}.pth'.format(pth)))) 107 | pretrained_model = torch.load( 108 | os.path.join(model_dir, '{}.pth'.format(pth)), 'cuda') 109 | 110 | net.load_state_dict(pretrained_model['net']) 111 | 112 | if opt_G is not None: 113 | opt_G.load_state_dict(pretrained_model['optimG']) 114 | 115 | if opt_D is not None: 116 | opt_D.load_state_dict(pretrained_model['optimD']) 117 | 118 | if scheduler is not None: 119 | scheduler.load_state_dict(pretrained_model['scheduler']) 120 | 121 | if recorder is not None: 122 | recorder.load_state_dict(pretrained_model['recorder']) 123 | 124 | lr = pretrained_model['lr'] 125 | 126 | return pretrained_model['epoch'], 
pretrained_model['iter']+1, lr 127 | 128 | 129 | def save_model(model_dir, net, opt_G, opt_D, label, epoch, iter, lr): 130 | 131 | os.system('mkdir -p {}'.format(model_dir)) 132 | 133 | pth_file = "%s.pth" % label 134 | 135 | #scheduler, recorder, 136 | model = { 137 | 'net': remove_net_prefix(net.state_dict()), 138 | #'scheduler': scheduler.state_dict(), 139 | #'recorder': recorder.state_dict(), 140 | 'epoch': epoch, 141 | 'label': label, 142 | 'iter': iter, 143 | 'lr': lr 144 | } 145 | if opt_G is not None: 146 | model.update({'optimG': opt_G.state_dict()}) 147 | if opt_D is not None: 148 | model.update({'optimD': opt_D.state_dict()}) 149 | 150 | torch.save(model, os.path.join(model_dir, pth_file)) 151 | 152 | pths = [ 153 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 154 | if pth != 'latest.pth' and pth.find('.pth')!=-1 and check_int(pth.split('.')[0]) 155 | ] 156 | if len(pths)<3: 157 | return 158 | 159 | 160 | def load_network(net, model_dir, resume=True, epoch=-1, strict=True): 161 | if not resume: 162 | return 0 163 | 164 | if not os.path.exists(model_dir): 165 | print(colored('pretrained model does not exist', 'red')) 166 | return 0 167 | 168 | if os.path.isdir(model_dir): 169 | 170 | pths = [ 171 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 172 | if pth != 'latest.pth' and pth != 'D_latest.pth' and (pth[0] !='D' and check_int(pth.split('.')[0])) 173 | ] 174 | 175 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir): 176 | return 0 177 | 178 | if epoch == -1: 179 | if 'latest.pth' in os.listdir(model_dir): 180 | pth = 'latest' 181 | else: 182 | pth = max(pths) 183 | else: 184 | pth = epoch 185 | 186 | model_path = os.path.join(model_dir, '{}.pth'.format(pth)) 187 | else: 188 | model_path = model_dir 189 | 190 | print('load model: {}'.format(model_path)) 191 | pretrained_model = torch.load(model_path) 192 | net.load_state_dict(pretrained_model['net'], strict=strict) 193 | return pretrained_model['epoch'] + 1 194 | 195 | def remove_untrainable_net(net): 196 | net_ = OrderedDict() 197 | for k in net.keys(): 198 | print(k, net[k].requires_grad) 199 | if not net[k].requires_grad: 200 | print(k) 201 | continue 202 | net_[k] = net[k] 203 | return net_ 204 | 205 | def remove_net_prefix(net, prefix="criterion"): 206 | net_ = OrderedDict() 207 | for k in net.keys(): 208 | if k.startswith(prefix): 209 | continue 210 | 211 | net_[k] = net[k] 212 | return net_ 213 | 214 | def remove_net_prefix_why(net, prefix="criterion"): 215 | net_ = OrderedDict() 216 | for k in net.keys(): 217 | if k.startswith(prefix): 218 | net_[k[len(prefix):]] = net[k] 219 | else: 220 | net_[k] = net[k] 221 | return net_ 222 | 223 | 224 | def add_net_prefix(net, prefix): 225 | net_ = OrderedDict() 226 | for k in net.keys(): 227 | net_[prefix + k] = net[k] 228 | return net_ 229 | 230 | 231 | def replace_net_prefix(net, orig_prefix, prefix): 232 | net_ = OrderedDict() 233 | for k in net.keys(): 234 | if k.startswith(orig_prefix): 235 | net_[prefix + k[len(orig_prefix):]] = net[k] 236 | else: 237 | net_[k] = net[k] 238 | return net_ 239 | 240 | 241 | def remove_net_layer(net, layers): 242 | keys = list(net.keys()) 243 | for k in keys: 244 | for layer in layers: 245 | if k.startswith(layer): 246 | del net[k] 247 | return net 248 | -------------------------------------------------------------------------------- /Engine/th_utils/networks/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from 
torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Surface-Based Motion Modeling for Dynamic Human Rendering 2 | 3 | 4 | 5 | We propose SurMo, a new paradigm for learning dynamic human rendering from videos by jointly modeling the temporal motion dynamics and human appearances in a unified framework based on a novel surface-based triplane. We extend the existing well-adopted paradigm of "Pose Encoding → Appearance Decoding" to "Motion Encoding → Physical Motion Decoding, Appearance Decoding". 6 | 7 | This repository contains the code of SurMo that is built upon HVTR and HVTR++. 
8 | 9 | **SurMo: Surface-based 4D Motion Modeling for Dynamic Human Rendering** 10 | Tao Hu, Fangzhou Hong, Ziwei Liu 11 | CVPR 2024 12 | [[Project Page]](https://taohuumd.github.io/projects/SurMo) [[Video]](https://www.youtube.com/watch?v=m_rP5HwL53I) [[Paper]](https://arxiv.org/pdf/2404.01225.pdf) 13 | 14 | 15 | **HVTR++: Image and Pose Driven Human Avatars using Hybrid Volumetric-Textural Rendering** 16 | Tao Hu, Hongyi Xu, Linjie Luo, Tao Yu, Zerong Zheng, He Zhang, Yebin Liu, Matthias Zwicker 17 | TVCG 2023 18 | [[Project Page]](https://TaoHuUMD.github.io/projects/hvtrpp/) [[Video]](https://youtu.be/RdKLfRYtg3I) [[Paper]](https://ieeexplore.ieee.org/document/10190111) 19 | 20 | **HVTR: Hybrid Volumetric-Textural Rendering for Human Avatars** 21 | Tao Hu, Tao Yu, Zerong Zheng, He Zhang, Yebin Liu, Matthias Zwicker 22 | 3DV 2022 23 | [[Project Page]](https://TaoHuUMD.github.io/projects/hvtr/) [[Video]](https://youtu.be/LE0-YpbLlkY?si=DfXp4vLKUVGCJlKG) [[Paper]](https://arxiv.org/pdf/2112.10203.pdf) 24 | 25 | 26 | # Instructions 27 | 28 | ## Test Results 29 | To facilitate comparisons with our model in subsequent work, we have saved our ZJU-MoCap rendering results on [OneDrive](https://1drv.ms/f/c/cd958c29ffd57ddb/Ett91f8pjJUggM1-DQAAAAAB892JXlTtzxmciQIh0MC3bg?e=r1HjCH). 30 | 31 | 32 | ## Installation 33 | NVIDIA GPUs are required for this project. We have trained and tested the code on NVIDIA V100 GPUs. We recommend using Anaconda to manage the Python environment. 34 | 35 | ```bash 36 | conda create --name surmo python=3.9 37 | conda install pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.1 -c pytorch 38 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 39 | conda install pytorch3d -c pytorch3d 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | ## Test 44 | 45 | ### Download Models & Assets & Datasets 46 | 47 | Download the pretrained models and necessary assets from [OneDrive](https://1drv.ms/f/c/cd958c29ffd57ddb/EsqvoFUGhCpIpuT10AaBDkMBb_ACQRf-dgjiC1FviCCFsA?e=IK007Z). Put them in *DATA_DIR/result/trained_model* and *DATA_DIR/asset*, respectively. *DATA_DIR* is set to *./data* by default. 48 | 49 | Download the [ZJU-MoCap](https://github.com/zju3dv/neuralbody/blob/master/INSTALL.md#zju-mocap-dataset) dataset and put it in the folder *zju_mocap* (e.g., *DATA_DIR/zju_mocap/CoreView_3XX*). 50 | 51 | Register and download the SMPL models [here](https://smpl.is.tue.mpg.de/). Put them in the folder *smpl_data*. 52 | 53 | The folder structure should look like this: 54 | 55 | ``` 56 | DATA_DIR 57 | └── asset/ 58 | ├── smpl_data/ 59 | └── SMPL_NEUTRAL.pkl 60 | ├── uv_sampler/ 61 | ├── uv_table.npy 62 | ├── smpl_uv.obj 63 | ├── sample_data.pkl 64 | ├── dataset 65 | ├──zju_mocap/ 66 | ├── result/ 67 | ├── trained_model/modelname/ 68 | └──xx.pth 69 | ├── test_output 70 | 71 | ``` 72 | 73 | ### Commands 74 | 75 | Run the test script for the models (313, 315, 377, 386, 387, 394) trained on ZJU-MoCap: 76 | ```bash 77 | bash scripts/zju/3XX_test.sh [gpu_ids] 78 | ``` 79 | e.g., `bash scripts/zju/313_test.sh` or `bash scripts/zju/313_test.sh 0` 80 | 81 | The test results will be found in *DATA_DIR/result/*. An example rendering result is shown in *docs/figs/test_example.jpg*, which includes a generated image, the ground-truth image, the positional map, the predicted normal map, and low-resolution NeRF renderings. 82 | 83 | ## Training 84 | 85 | ### Commands 86 | Run the training script for subjects (313, 315, 377, 386, 387, 394) on ZJU-MoCap:
87 | ```bash 88 | bash scripts/zju/3XX_train.sh [gpu_ids] 89 | ``` 90 | e.g., `bash scripts/zju/313_train.sh` or `bash scripts/zju/313_train.sh 0,1,2,3` 91 | 92 | The trained models will be saved in *DATA_DIR/result/trained_model/*. 93 | ## License 94 | 95 | Distributed under the S-Lab License. See `LICENSE` for more information. 96 | 97 | ## Citation 98 | ```bibtex 99 | @misc{hu2024surmo, 100 | title={SurMo: Surface-based 4D Motion Modeling for Dynamic Human Rendering}, 101 | author={Tao Hu and Fangzhou Hong and Ziwei Liu}, 102 | year={2024}, 103 | eprint={2404.01225}, 104 | archivePrefix={arXiv}, 105 | primaryClass={cs.CV} 106 | } 107 | 108 | @article{hu2023hvtrpp, 109 | author={Hu, Tao and Xu, Hongyi and Luo, Linjie and Yu, Tao and Zheng, Zerong and Zhang, He and Liu, Yebin and Zwicker, Matthias}, 110 | journal={IEEE Transactions on Visualization and Computer Graphics}, 111 | title={HVTR++: Image and Pose Driven Human Avatars using Hybrid Volumetric-Textural Rendering}, 112 | year={2023} 113 | } 114 | 115 | @inproceedings{hu2022hvtr, 116 | title={HVTR: Hybrid Volumetric-Textural Rendering for Human Avatars}, 117 | author={Hu, Tao and Yu, Tao and Zheng, Zerong and Zhang, He and Liu, Yebin and Zwicker, Matthias}, 118 | booktitle = {2022 International Conference on 3D Vision (3DV)}, 119 | year = {2022} 120 | } 121 | ``` -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/configs/__init__.py -------------------------------------------------------------------------------- /configs/config_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import sys 5 | 6 | import torch 7 | import yaml 8 | from easydict import EasyDict as edict 9 | 10 | 11 | def record_setting(out): 12 | """Record scripts and commandline arguments""" 13 | # out = out.split()[0].strip() 14 | source = out + "/source" 15 | if not os.path.exists(source): 16 | os.system('mkdir -p %s' % source) 17 | # os.mkdir(out) 18 | 19 | # subprocess.call("cp *.py %s" % source, shell=True) 20 | # subprocess.call("cp configs/*.yml %s" % out, shell=True) 21 | 22 | subprocess.call("find . -type d -name result -prune -o -name '*.py' -print0" 23 | "| xargs -0 cp --parents -p -t %s" % source, shell=True) 24 | subprocess.call("find . 
-type d -name result -prune -o -name '*.yml' -print0|" 25 | " xargs -0 cp --parents -p -t %s" % source, shell=True) 26 | 27 | with open(out + "/command.txt", "w") as f: 28 | f.write(" ".join(sys.argv) + "\n") 29 | 30 | def get_config_dataset(config_file): 31 | data_config="configs/datasets/%s" % config_file 32 | conf = edict(yaml.load(open(data_config), Loader=yaml.SafeLoader)) 33 | return conf 34 | 35 | def get_config(config_file): 36 | conf = edict(yaml.load(open(config_file), Loader=yaml.SafeLoader)) 37 | return conf 38 | 39 | def merge_config_opt(config_file, opt): 40 | conf = edict(yaml.load(open(config_file), Loader=yaml.SafeLoader)) 41 | for k in conf: 42 | setattr(opt, k, conf[k]) 43 | 44 | def yaml_config(config, default_cofig, resume_latest=False, num_workers=1): 45 | default = edict(yaml.load(open(default_cofig), Loader=yaml.SafeLoader)) 46 | 47 | #data_config="configs/datasets/%s" % config 48 | data_config = config 49 | conf = edict(yaml.load(open(data_config), Loader=yaml.SafeLoader)) 50 | 51 | def copy(conf, default): 52 | for key in conf: 53 | if key in default and isinstance(default[key], edict): 54 | copy(conf[key], default[key]) 55 | else: 56 | default[key] = conf[key] 57 | 58 | copy(conf, default) 59 | 60 | #default.resume_latest = resume_latest 61 | #default.dataset.num_workers = num_workers 62 | return default 63 | 64 | 65 | def read_config(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--config', type=str, default="configs/hinge.yml") 68 | args = parser.parse_args() 69 | 70 | config = yaml_config(args.config) 71 | return config 72 | 73 | 74 | def write(iter, loss, name, writer): 75 | writer.add_scalar("metrics/" + name, loss, iter) 76 | return loss 77 | 78 | 79 | def ddp_data_sampler(dataset, rank, world_size, shuffle, drop_last): 80 | dist_sampler = torch.utils.data.distributed.DistributedSampler( 81 | dataset, 82 | rank=rank, 83 | num_replicas=world_size, 84 | shuffle=shuffle, 85 | drop_last=drop_last 86 | ) 87 | return dist_sampler 88 | -------------------------------------------------------------------------------- /configs/datasets/zju/base.yml: -------------------------------------------------------------------------------- 1 | 2 | train_rgbd: False 3 | 4 | gender: "neutral" 5 | org_img_reso: 1024 6 | dataset : "zju" 7 | dataset_basedir: "zju_mocap" 8 | 9 | white_bg: False 10 | display_winsize: 3000 11 | 12 | fps: 30 13 | 14 | # image_dir : 'image' 15 | # mask_dir : 'refined_mask' 16 | # sub_flag : "Camera_B" 17 | # depth_dir : 'depth' 18 | # annots: "annots.npy" 19 | 20 | 21 | img_H: 1024 22 | img_W: 1024 -------------------------------------------------------------------------------- /configs/datasets/zju/motion_313_fv.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | img_H: 1024 4 | img_W: 1024 5 | 6 | iters_each_sample : 300 7 | demo_frame : "1" 8 | 9 | dataset: 10 | gender: "neutral" 11 | dirname: "CoreView_313" 12 | resname: "motion_313" 13 | id : 0 14 | valid_begin: 0 15 | valid_end : 652 16 | 17 | train_begin : [1] 18 | train_end : [410] 19 | test_begin: [1] 20 | test_end : [410] 21 | 22 | novel_pose_begin: [410] 23 | novel_pose_end: [710] 24 | 25 | 26 | train_view : "0 6 12 18" # 10 camera 27 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 28 | dataset_step: 1 29 | 30 | demo_view: "1 7 13 19" 31 | demo_begin: [1] 32 | demo_end : [410] 33 | 34 | novel_demo_begin: [10] 35 | novel_demo_end : [710] --------------------------------------------------------------------------------
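A note on how the per-subject YAML files in this directory are consumed: `yaml_config` in `configs/config_util.py` (above) loads `configs/defaults.yml` first and then recursively overwrites it with the per-subject config, so any key present in the subject YAML wins while unmatched defaults are kept. A minimal usage sketch (not a repo file), assuming the repository root is the working directory:

```python
# Minimal sketch: merge a per-subject config over the defaults using
# yaml_config from configs/config_util.py shown above.
from configs.config_util import yaml_config

cfg = yaml_config("configs/datasets/zju/motion_313_fv.yml", "configs/defaults.yml")

print(cfg.multi_id)           # True      -> kept from defaults.yml
print(cfg.img_H, cfg.img_W)   # 1024 1024 -> from the subject YAML
print(cfg.dataset.dirname)    # CoreView_313
```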
/configs/datasets/zju/motion_315_fv.yml: -------------------------------------------------------------------------------- 1 | img_H: 1024 2 | img_W: 1024 3 | 4 | iters_each_sample : 300 5 | demo_frame : "1" 6 | 7 | dataset: 8 | gender: "neutral" 9 | dirname: "CoreView_315" 10 | resname: "motion_315" 11 | id : 0 12 | valid_begin: 0 13 | valid_end : 2184 14 | train_begin : [1] 15 | train_end : [700] 16 | test_begin: [1] 17 | test_end : [700] 18 | train_view : "0 6 12 18" 19 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 20 | dataset_step: 1 21 | 22 | novel_pose_begin: [700] 23 | novel_pose_end: [1000] 24 | 25 | demo_view: "1 7 13 19" 26 | demo_begin: [1] 27 | demo_end : [700] 28 | 29 | novel_demo_begin: [700] 30 | novel_demo_end : [1000] -------------------------------------------------------------------------------- /configs/datasets/zju/motion_377_fv.yml: -------------------------------------------------------------------------------- 1 | 2 | img_H: 1024 3 | img_W: 1024 4 | 5 | iters_each_sample : 300 6 | demo_frame : "1" 7 | 8 | dataset: 9 | gender: "neutral" 10 | dirname: "CoreView_377" 11 | resname: "motion_377" 12 | id : 0 13 | valid_begin: 0 14 | valid_end : 652 15 | train_begin : [1] 16 | train_end : [300] 17 | test_begin: [1] 18 | test_end : [300] 19 | 20 | novel_pose_begin: [300] 21 | novel_pose_end: [618] 22 | 23 | train_view : "0 6 12 18" # 10 camera 24 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 25 | dataset_step: 1 26 | 27 | demo_view: "1 7 13 19" 28 | demo_begin: [1] 29 | demo_end : [300] 30 | 31 | novel_demo_begin: [300] 32 | novel_demo_end : [618] -------------------------------------------------------------------------------- /configs/datasets/zju/motion_386_fv.yml: -------------------------------------------------------------------------------- 1 | img_H: 1024 2 | img_W: 1024 3 | 4 | iters_each_sample : 300 5 | demo_frame : "1" 6 | 7 | dataset: 8 | gender: "neutral" 9 | dirname: "CoreView_386" 10 | resname: "motion_386" 11 | id : 0 12 | valid_begin: 0 13 | valid_end : 620 14 | train_begin : [1] 15 | train_end : [300] 16 | test_begin: [1] 17 | test_end : [300] 18 | 19 | novel_pose_begin: [300] 20 | novel_pose_end: [647] 21 | 22 | train_view : "0 6 12 18" # 10 camera 23 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 24 | dataset_step: 1 25 | 26 | demo_view: "1 7 13 19" 27 | demo_begin: [1] 28 | demo_end : [300] 29 | 30 | novel_demo_begin: [300] 31 | novel_demo_end : [647] -------------------------------------------------------------------------------- /configs/datasets/zju/motion_387_fv.yml: -------------------------------------------------------------------------------- 1 | img_H: 1024 2 | img_W: 1024 3 | 4 | iters_each_sample : 300 5 | demo_frame : "1" 6 | 7 | dataset: 8 | gender: "neutral" 9 | dirname: "CoreView_387" 10 | resname: "motion_387" 11 | id : 0 12 | valid_begin: 0 13 | valid_end : 652 14 | train_begin : [1] 15 | train_end : [600] 16 | test_begin: [1] 17 | test_end : [600] 18 | train_view : "0 6 12 18" 19 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 20 | dataset_step: 1 21 | 22 | novel_pose_begin: [600] 23 | novel_pose_end: [900] 24 | 25 | demo_view: "1 7 13 19" 26 | demo_begin: [1] 27 | demo_end : [600] 28 | 29 | novel_demo_begin: [600] 30 | novel_demo_end : [900] -------------------------------------------------------------------------------- /configs/datasets/zju/motion_394_fv.yml: -------------------------------------------------------------------------------- 1 | img_H: 1024 2 | img_W: 1024 3 | 4 | 
iters_each_sample : 300 5 | demo_frame : "1" 6 | 7 | dataset: 8 | gender: "neutral" 9 | dirname: "CoreView_394" 10 | resname: "motion_394" 11 | id : 0 12 | valid_begin: 0 13 | valid_end : 858 14 | train_begin : [1] 15 | train_end : [600] 16 | test_begin: [1] 17 | test_end : [600] 18 | train_view : "0 6 12 18" 19 | test_view : "1 2 3 4 5 7 8 9 10 11 13 14 15 16 17 19 20" 20 | dataset_step: 1 21 | 22 | novel_pose_begin: [600] 23 | novel_pose_end: [900] 24 | 25 | demo_view: "1 7 13 19" 26 | demo_begin: [1] 27 | demo_end : [600] 28 | 29 | novel_demo_begin: [600] 30 | novel_demo_end : [900] -------------------------------------------------------------------------------- /configs/defaults.py: -------------------------------------------------------------------------------- 1 | aug_nr: True -------------------------------------------------------------------------------- /configs/defaults.yml: -------------------------------------------------------------------------------- 1 | multi_id: True -------------------------------------------------------------------------------- /configs/methods/motion.yml: -------------------------------------------------------------------------------- 1 | use_nerf : True 2 | use_gen : True 3 | use_face : True 4 | #use_face : False 5 | 6 | model : P_Motion 7 | is_expand_nerf_channel : True 8 | is_insert_nerf_depth : True 9 | is_skip_nerf : True 10 | uv_reso: 256 -------------------------------------------------------------------------------- /configs/methods/vrnr.py: -------------------------------------------------------------------------------- 1 | use_nerf_time_latent: True -------------------------------------------------------------------------------- /configs/methods/vrnr.yml: -------------------------------------------------------------------------------- 1 | use_nerf : True 2 | use_gen : True 3 | use_face : True 4 | #use_face : False 5 | 6 | model : P_vrnr 7 | use_nerf_time_latent : True 8 | is_expand_nerf_channel : True 9 | is_insert_nerf_depth : True 10 | fuse_mode : 0 11 | nerf_output_rgb_dim : 256 12 | is_skip_nerf : True 13 | 14 | uv_reso: 128 -------------------------------------------------------------------------------- /configs/projects/uvm.yml: -------------------------------------------------------------------------------- 1 | project_directory: "uvm_lib" 2 | model_module: "model_motion" -------------------------------------------------------------------------------- /docs/figs/summary.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/docs/figs/summary.jpg -------------------------------------------------------------------------------- /docs/figs/test_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/docs/figs/test_example.jpg -------------------------------------------------------------------------------- /download_models.py: -------------------------------------------------------------------------------- 1 | #to update -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dominate==2.7.0 2 | easydict==1.10 3 | facenet==1.0.5 4 | imageio==2.10.5 5 | iopath==0.1.10 6 | lib==4.0.0 7 | matplotlib==3.6.2 8 | munch==2.5.0 9 | numpy==1.23.5 10 | omegaconf==2.3.0 11 | 
open3d==0.18.0 12 | opencv_python==4.6.0.66 13 | Pillow==9.3.0 14 | PyMCubes==0.1.2 15 | pyskeleton==1.0.0 16 | pytorch3d==0.6.2 17 | PyYAML==6.0 18 | Requests==2.32.3 19 | sets==0.3.2 20 | setuptools==68.0.0 21 | six==1.16.0 22 | smplx==0.1.28 23 | structures==0.9.5 24 | tensorboardX==2.5.1 25 | termcolor==2.5.0 26 | torch==1.10.1+cu111 27 | torchsummary==1.5.1 28 | torchvision==0.11.2+cu111 29 | tqdm==4.64.1 30 | trimesh==3.17.1 31 | utils==1.0.2 -------------------------------------------------------------------------------- /scripts/zju/313_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_313_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_313_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- /scripts/zju/313_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_313_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="train2_313_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 
--w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /scripts/zju/315_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_315_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_315_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- /scripts/zju/315_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_315_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="train_315_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 
--uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /scripts/zju/377_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_377_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_377_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- /scripts/zju/377_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_377_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot 
--aug_nr" 17 | 18 | modelname="train_377_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /scripts/zju/386_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_386_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_386_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- /scripts/zju/386_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_386_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 
--ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="train_386_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /scripts/zju/387_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_387_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_387_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- /scripts/zju/387_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config 
zju/motion_387_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="train_387_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /scripts/zju/394_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_394_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="pretrained_394_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | export CUDA_DEVICE_ORDER=PCI_BUS_ID 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | python -W ignore test.py ${cmd} --test_step_size 30 --test_eval -------------------------------------------------------------------------------- 
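The per-subject test scripts above differ only in the dataset config and the pretrained model name, so they can be driven in a loop. A hypothetical convenience launcher (not part of the repository) that runs all released test scripts on GPU 0 might look like this:

```python
# Hypothetical helper (not a repo file): run every released ZJU-MoCap test script on GPU 0.
import subprocess

SUBJECTS = ["313", "315", "377", "386", "387", "394"]

for sid in SUBJECTS:
    # Each script takes the GPU id(s) as its first argument and falls back to
    # $CUDA_VISIBLE_DEVICES when the argument is omitted.
    subprocess.run(["bash", f"scripts/zju/{sid}_test.sh", "0"], check=True)
```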
/scripts/zju/394_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | id=$1 4 | gpu="${id}" 5 | 6 | method="--use_style --learn_3d --use_posmap --uvVol_2d --nr_pose_dep_uv" 7 | 8 | config="--config zju/motion_394_fv.yml --project_config uvm.yml --method_config motion.yml" 9 | 10 | render="--superreso LightSup --vrnr" 11 | setup="--motion_mode --uv_type SMPL --uv_reso 256 --batchSize 1" 12 | net="--posenet_outdim 96 --tex_latent_dim 16 --ab_uvh_plane_c 32 --nerf_dim 32 --style_dim 256" 13 | 14 | ab="--c_velo --c_acce --c_traj --velocity 1 --pred_pose_uv --pred_normal_uv --ab_pred_pose_by_velocity --new_dynamics --pred_pose_uv_rot --rot_all_same" 15 | 16 | aug="--small_rot --aug_nr" 17 | 18 | modelname="train_394_Mp3dS" 19 | debug="" 20 | 21 | epoch="--niter 200 --niter_decay 0" 22 | reso="--gen_ratio 0.5 --nerf_ratio 0.25" 23 | 24 | improve="--ab_cond_uv_latent --ab_D_pose" 25 | 26 | basic="--N_samples 28 --uvVol_smpl_pts 8 --distributed --use_org_discrim --use_org_gan_loss --learn_uv --ab_uvh_plane --plus_uvh_enc --ab_nerf_rec --load_tmp_rendering" 27 | 28 | w="--w_D 1.0 --w_G_GAN 1.0 --w_Face 5 --w_G_L1 0.5 --w_G_feat 10 --w_nerf_rec 15 --w_pred_normal_uv 1.0 --w_rot_normal 1.0 --w_posmap_feat 0 --w_posmap 1" 29 | 30 | cmd="${basic} ${setup} ${net} ${ab} ${w} ${debug} ${epoch} ${reso} ${method} ${aug} ${render} ${improve} ${config} --name $modelname" 31 | 32 | if [ -z "$gpu" ] 33 | then 34 | gpu=$CUDA_VISIBLE_DEVICES 35 | fi 36 | 37 | export NVIDIA_VISIBLE_DEVICES=${gpu} 38 | export CUDA_VISIBLE_DEVICES=${gpu} 39 | 40 | MASTER_PORT=$((12000 + $RANDOM % 20000)) 41 | gpures="${gpu//[^,]}" 42 | NUM_GPU="${#gpures}" 43 | NUM_GPU=$((NUM_GPU+1)) 44 | extra_opt="${extra_opt} --gpu_ids ${gpu}" 45 | 46 | python -W ignore -m torch.distributed.launch --nproc_per_node=${NUM_GPU} --master_port=${MASTER_PORT} train_dist.py ${cmd} 47 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | import sys 5 | 6 | import torch 7 | 8 | from uvm_lib.data.data_loader import CreateDataLoader 9 | from uvm_lib.models.models import create_model 10 | from uvm_lib.util.visualizer import Visualizer 11 | from uvm_lib.util import html 12 | 13 | from uvm_lib.options.test_option import ProjectOptions 14 | 15 | 16 | if __name__ == "__main__": 17 | 18 | opt = ProjectOptions().parse(save=False) 19 | opt.nThreads = 0 # test code only supports nThreads = 1 20 | opt.batchSize = 1 # test code only supports batchSize = 1 21 | opt.serial_batches = True # no shuffle 22 | opt.no_flip = True # no flip 23 | opt.is_inference = True 24 | 25 | print("test ", opt.gpu_ids) 26 | 27 | torch.cuda.set_device(opt.gpu_ids[0]) 28 | 29 | opt.phase = "test" 30 | 31 | if opt.no_label: 32 | opt.render_with_dp_label = False 33 | 34 | data_loader = CreateDataLoader(opt, opt.phase) 35 | dataset = data_loader.load_data() 36 | 37 | visualizer = Visualizer(opt) 38 | 39 | exp_name = opt.name 40 | 41 | which_epoch = opt.which_epoch 42 | 43 | model = create_model(opt).cuda().module 44 | 45 | if opt.test_eval: 46 | model_name = "ema_latest" if opt.which_epoch == "-1" else "ema_%s" % opt.which_epoch 47 | else: model_name = opt.which_epoch 48 | 49 | test_epoch, epoch_iter = model.load_all(model_name, True) 50 | opt.which_epoch = test_epoch #- 1 51 | 52 | if test_epoch == -1: 53 | test_epoch, epoch_iter = 1, 0 54 | print("test model not trained") 55 | 56 | if not 
opt.save_tmp_rendering: 57 | exit() 58 | 59 | print(test_epoch) 60 | which_epoch = test_epoch #- 1 61 | 62 | view_num = len(opt.multiview_ids) 63 | 64 | #to remove 65 | opt.multiview_ids = opt.multiview_ids 66 | 67 | 68 | for view_id in opt.multiview_ids: 69 | 70 | web_dir = os.path.join(opt.results_dir, exp_name) 71 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (exp_name, opt.phase, which_epoch)) 72 | 73 | result_dir = os.path.join(opt.results_dir, exp_name, "images") 74 | result_image_num = len(os.listdir(result_dir)) 75 | test_image_num = len(dataset) 76 | 77 | # test 78 | if not opt.engine and not opt.onnx: 79 | if opt.data_type == 16: 80 | model.half() 81 | elif opt.data_type == 8: 82 | model.type(torch.uint8) 83 | 84 | if opt.verbose: 85 | print(model) 86 | else: 87 | t = 0 88 | 89 | model.eval() 90 | 91 | for i, data in enumerate(dataset): 92 | 93 | minibatch = 1 94 | 95 | data_vid = data["cam_ind"].cpu().numpy()[0] 96 | if data_vid != int(view_id): continue 97 | 98 | with torch.no_grad(): 99 | generated = model.inference(data) 100 | 101 | model.compute_visuals(which_epoch) 102 | img_idx = data['frame_index'][0].cpu().numpy() 103 | dataset_id = model.dataset_id if isinstance(model.dataset_id, int) else model.dataset_id[0].cpu().numpy() 104 | 105 | img_idx = "d%s_%04d" % (dataset_id, img_idx) 106 | 107 | print('process image... %s' % img_idx) 108 | visualizer.save_images(webpage, model.get_current_visuals(), img_idx, "test") 109 | 110 | print("test finished") 111 | 112 | webpage.save() 113 | 114 | sys.exit() -------------------------------------------------------------------------------- /uvm_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/__init__.py -------------------------------------------------------------------------------- /uvm_lib/base_options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/base_options/__init__.py -------------------------------------------------------------------------------- /uvm_lib/base_options/evaluate_options.py: -------------------------------------------------------------------------------- 1 | from .test_options import TestOptions 2 | 3 | class EvaluateOptions(TestOptions): 4 | def initialize(self): 5 | TestOptions.initialize(self) 6 | 7 | #self.parser.add_argument('--model_name', type=str, default=None) 8 | self.parser.add_argument('--l1', action='store_true') 9 | self.parser.add_argument('--ssim', action='store_true') 10 | self.parser.add_argument('--fid', action='store_true') 11 | self.parser.add_argument('--lpip', action='store_true') 12 | self.parser.add_argument('--psnr', action='store_true') 13 | 14 | self.parser.add_argument('--make_video', action='store_true') 15 | 16 | self.parser.add_argument('--eval_dp', action='store_true') 17 | self.parser.add_argument('--dp_dir', type=str, default='') 18 | 19 | self.parser.add_argument('--use_gpu', type=int, default=1) 20 | self.parser.add_argument('-v', '--version', type=str, default='0.1') 21 | 22 | self.parser.add_argument('--crop_bbox', action='store_true') 23 | 24 | self.isTrain = False 25 | print("evaluations") 26 | #if opt.subdir != '': 27 | # opt.results_dir = os.path.join(opt.results_dir, opt.subdir) 28 | 
-------------------------------------------------------------------------------- /uvm_lib/base_options/motion_setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def set_motion(parser): 6 | 7 | motion = parser.add_argument_group('motion') 8 | 9 | motion.add_argument("--infer_velocity", type=float, default=1, help='') 10 | 11 | 12 | motion.add_argument("--motion_chain", action='store_true', help='') 13 | motion.add_argument("--motion_steps", type=str, default="-25 -20 -15 -10 -5 -1 5 10 15 20 25", help='previous motion status') 14 | 15 | motion.add_argument("--motion_point", type=int, default=0, help='start or end point of a motion, <0, start, >0 end') 16 | 17 | 18 | motion.add_argument("--style_dim", type=int, default=256, help='') 19 | motion.add_argument("--ab_uvh_plane_c", type=int, default=16, help='dim of vh, uh plane') 20 | motion.add_argument("--nerf_dim", type=int, default=32, help='dim of nerf input') 21 | 22 | motion.add_argument('--aug_random_flip', action='store_true', help="") 23 | 24 | motion.add_argument("--use_global_posemap", action='store_true', help='whether use global verts in posemap') 25 | 26 | motion.add_argument("--is_pad_img", action='store_true', help='pad rectangle image to square') 27 | 28 | motion.add_argument("--ab_sup_only_dynamic_tex", action='store_true', help='no nerf, only supreso posemap out') 29 | 30 | motion.add_argument("--ab_sup_only_static_style", action='store_true', help='no nerf, only supreso on style') 31 | 32 | 33 | motion.add_argument("--ab_sup_2d_style", action='store_true', help='sup net only condition on 2d style latent') 34 | 35 | 36 | motion.add_argument("--dual_discrim_eg3d", action='store_true', help='#whether D on tex field') 37 | 38 | 39 | motion.add_argument("--use_org_gan_loss", action='store_true', help='#') 40 | motion.add_argument("--use_org_discrim", action='store_true', help='#') 41 | 42 | 43 | motion.add_argument("--ab_Dtex", action='store_true', help='#whether D on tex field') 44 | motion.add_argument("--ab_Dtex_pose", action='store_true', help='#whether Dtex is conditioned on pose map') 45 | 46 | motion.add_argument("--ab_uvh_plane", action='store_true', help='uvh tri-plane') 47 | motion.add_argument("--ab_nerf_rec", action='store_true', help='whether rec loss on nerf') 48 | 49 | motion.add_argument("--ab_Ddual", action='store_true', help='dual discriminator') 50 | motion.add_argument("--ab_D_pose", action='store_true', help='D cond on pose') 51 | motion.add_argument("--ab_tex_rec", action='store_true', help='uv tex recon') 52 | 53 | motion.add_argument("--D_label_noise", action='store_true', help='add label noise for D') 54 | motion.add_argument("--D_noise_factor", type=float, default=0.05, help='add label noise for D') 55 | 56 | motion.add_argument("--debug_data_size", type=int, default=10, help='debug small dataset') 57 | 58 | motion.add_argument("--abandon", action='store_true', help='not used') 59 | 60 | 61 | motion.add_argument("--ab_cond_uv_latent", action='store_true', help='super reso cond on 2d uv lat') 62 | motion.add_argument("--general_superreso", action='store_true', help='super reso cond on 2d uv lat') 63 | 64 | #in the future 65 | motion.add_argument("--deep_nerf", action='store_true', help='3 layer nerf network') 66 | motion.add_argument("--ab_cond_1d_lat", action='store_true', help='super reso cond on 1d style') 67 | -------------------------------------------------------------------------------- 
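`set_motion` above only registers an argument group; the project presumably wires it into its full option parser in `uvm_lib/base_options/base_options.py` (not shown here). A minimal sketch (not a repo file) of attaching the group to a standalone parser, assuming the repository root is on `PYTHONPATH`:

```python
# Minimal sketch: attach the "motion" option group to a bare argparse parser.
import argparse

from uvm_lib.base_options.motion_setup import set_motion

parser = argparse.ArgumentParser()
set_motion(parser)  # registers --style_dim, --ab_uvh_plane, --motion_steps, ...

opt = parser.parse_args(["--style_dim", "256", "--ab_uvh_plane", "--ab_D_pose"])
print(opt.style_dim)     # 256
print(opt.ab_uvh_plane)  # True
print(opt.motion_steps)  # default: "-25 -20 -15 -10 -5 -1 5 10 15 20 25"
```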
/uvm_lib/base_options/other_setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/base_options/other_setup.py -------------------------------------------------------------------------------- /uvm_lib/base_options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 7 | 8 | self.parser.add_argument('--results_dir', type=str, default='./data/result/test_output', help='saves results here.') 9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | self.parser.add_argument('--how_many', type=int, default=500000, help='how many test images to run') 13 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 14 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map') 15 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") 16 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") 17 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") 18 | 19 | self.parser.add_argument("--pid1", type=str, help="pairid1") 20 | self.parser.add_argument("--pid2", type=str, help="pairid1") 21 | 22 | self.parser.add_argument("--test_epoch", type=str, help="which epoch to test", default='') 23 | self.parser.add_argument("--save_name", type=str, help="default project name", default='') 24 | 25 | self.isTrain = False 26 | 27 | -------------------------------------------------------------------------------- /uvm_lib/base_options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | # for displays 7 | self.parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--print_freq', type=int, default=50, help='frequency of showing training results on console') 9 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 10 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 11 | 12 | self.parser.add_argument('--eva_epoch_freq', type=int, help='evaluation freq') 13 | self.parser.add_argument('--save_epoch_freq', type=int, default=2, help='frequency of saving checkpoints at the end of epochs') 14 | self.parser.add_argument('--save_latest_epoch_freq', type=int, default=5000, help='frequency of saving the latest results') 15 | self.parser.add_argument('--display_epoch_freq', type=int, default=0, help='frequency of showing 
training results on screen') 16 | 17 | # for training 18 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 19 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 20 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 21 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 22 | 23 | self.isTrain = True 24 | -------------------------------------------------------------------------------- /uvm_lib/base_options/vrnr_setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | def vrnr_init(parser): 5 | 6 | parser.add_argument('--only_nerf', action='store_true') 7 | parser.add_argument('--only_gen', action='store_true') 8 | parser.add_argument('--use_nerf', action='store_true') 9 | parser.add_argument('--use_gen', action='store_true') 10 | 11 | parser.add_argument('--uv2pix', action='store_true', help="pix2pix") 12 | 13 | parser.add_argument('--id_num_total', type=int, default=6) # 3 channels 14 | 15 | parser.add_argument('--share_3D', action='store_true') 16 | 17 | parser.add_argument('--nerf_light', action='store_true', help="lighting in nerf") 18 | parser.add_argument('--not_latent_uv', action='store_true', help="3d vts latetnt in nerf") 19 | #parser.add_argument('--uv_normal', action='store_true', help="unknown") 20 | 21 | parser.add_argument('--local_only_body', type=bool, default=True, help="sample points only on body local nerf") 22 | parser.add_argument('--use_dilate_model', type=bool, default=True, help="sample points only on body local nerf") 23 | parser.add_argument('--is_opt_mask', action='store_true', help="whether opt nerf mask") 24 | 25 | parser.add_argument('--white_vgg', action='store_true', help="vgg image mask") 26 | 27 | parser.add_argument('--swap_tex', action='store_true', help="vgg image mask") 28 | 29 | parser.add_argument('--not_even', type=bool, default=True, help="sample points only on body local nerf") 30 | parser.add_argument('--N_samples', type=int, default=24, help="") 31 | 32 | #parser.add_argument('--dataset_name', type=str, help="dataset name") 33 | 34 | parser.add_argument('--nerf_reso', type=int, default=64, help="") 35 | parser.add_argument('--gen_reso', type=int, default=256, help="") 36 | parser.add_argument('--nerf_ratio', type=float, default=0, help="") 37 | parser.add_argument('--gen_ratio', type=float, default=0, help="") 38 | 39 | parser.add_argument('--is_smooth_nerf_latent', action='store_true', help="whether smooth nerf latent") 40 | 41 | parser.add_argument('--In_Canonical', action='store_true', help="Canonical nerf") 42 | parser.add_argument('--check_mesh', action='store_true', help="check can mesh") 43 | parser.add_argument('--check_can_mesh', action='store_true', help="check can mesh") 44 | 45 | parser.add_argument('--no_local_nerf', action='store_true', help="global nerf sampling") 46 | parser.add_argument('--use_img_feature', action='store_true', help="global nerf sampling") 47 | parser.add_argument('--sample_all_pixels', action='store_true', help="global nerf sampling") 48 | 49 | #nerf 50 | parser.add_argument('--not_use_conv3d', action='store_true', help="conv3d in nerf") 51 | parser.add_argument('--not_pose_cond', action='store_true', help="not pose condit") 52 | 53 | #generator neural rendering 
54 | parser.add_argument('--no_encoder', action='store_true', help="encoder in nr") 55 | parser.add_argument('--pred_depth', action='store_true', help="nr predits depth") 56 | #parser.add_argument('--pred_normal_uv', action='store_true', help="nr predits normal") 57 | 58 | parser.add_argument('--is_cat_max_mean', action='store_true', help="fuse nr and nerf features") 59 | 60 | parser.add_argument('--begin_i', type=int, default = 10, help="vrnr") 61 | parser.add_argument('--ni', type=int, default = 200, help="vrnr") 62 | parser.add_argument('--N_rand', type=int, default = 64, help="vrnr") 63 | parser.add_argument('--voxel_size', type=float, default=0.005, help="weight") 64 | 65 | parser.add_argument('--nrays', type=int, default = 1024, help = "number of rays.. not used in vrnr") 66 | 67 | parser.add_argument('--w_nerf', type=float, default=3.0, help="weight") 68 | 69 | parser.add_argument('--use_small_dilation', type=bool, default=True, help="whether multiple identities") 70 | 71 | parser.add_argument('--max_ray_interval', type=float, default=0.25*0.2, help="local sampling interval") 72 | parser.add_argument('--perturb', type=int, default=1, help="perturb training") 73 | 74 | parser.add_argument('--white_bkgd', action='store_true', help="train snapshot") 75 | parser.add_argument('--raw_noise_std', type=float, default=0, help="weight") 76 | 77 | parser.add_argument('--uvVol_smpl_pts', type=float, default=5, help="train snapshot") 78 | 79 | #dataset 80 | parser.add_argument('--dataset_step', type=int, default=1, help="split dataset") 81 | parser.add_argument('--org_img_reso', type=int, default=1024, help="") 82 | 83 | parser.add_argument('--train_rgbd', action='store_true', help="train snapshot") 84 | 85 | parser.add_argument('--train_snap', action='store_true', help="train snapshot") 86 | parser.add_argument('--multi_id', type=bool, default=True, help="whether multiple identities") 87 | parser.add_argument('--id_num', type=int, default=1, help="identi number in training") 88 | parser.add_argument('--uvdim', type=int, default=16, help="uv latent dim") 89 | 90 | parser.add_argument('--debug_1pose_all_views', action='store_true', help="debug 1 frames") 91 | parser.add_argument('--vrnr_swap2', action='store_true', help="swap textures") 92 | 93 | parser.add_argument('--use_density_th', action='store_true', help="local nerf new dist") 94 | 95 | parser.add_argument('--use_new_dist', action='store_true', help="local nerf new dist") 96 | parser.add_argument('--par_density_norm', type=float, default=1000, help="uv latent dim") 97 | parser.add_argument('--density_th', type=float, default=0.5, help="uv latent dim") 98 | 99 | def dataset_setup(cfg): 100 | 101 | cfg.data_list = cfg.data_list.split(" ") 102 | 103 | dataset_id = {} 104 | for i in range(len(cfg.data_list)): 105 | dataset_id.update({cfg.data_list[i]: i}) 106 | 107 | cfg.dataset_id = [dataset_id] 108 | 109 | 110 | 111 | def vrnr_parse(cfg): 112 | 113 | cfg.use_nerf = True 114 | cfg.use_gen = True 115 | cfg.use_face = True 116 | 117 | if cfg.gen_ratio > 1: 118 | cfg.gen_ratio = 1 / cfg.gen_ratio 119 | 120 | if cfg.nerf_ratio > 1: 121 | cfg.nerf_ratio = 1 / cfg.nerf_ratio 122 | 123 | if cfg.gen_ratio == 0: cfg.gen_ratio = cfg.gen_reso / cfg.org_img_reso 124 | if cfg.nerf_ratio == 0: cfg.nerf_ratio = cfg.nerf_reso / cfg.org_img_reso 125 | 126 | def make_dataset_id(cfg): 127 | 128 | cfg.dataset_id = [{ 129 | "xyzc_394": 0, 130 | "xyzc_377": 1, 131 | "xyzc_313": 2, 132 | "xyzc_386": 3, 133 | "xyzc_392": 4, 134 | "xyzc_311": 5 135 | }] 136 | 137 | 
cfg.dataset_id_swap = cfg.dataset_id_swap1 138 | -------------------------------------------------------------------------------- /uvm_lib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/data/__init__.py -------------------------------------------------------------------------------- /uvm_lib/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /uvm_lib/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | from Engine.th_utils.distributed.sampler import data_sampler 8 | 9 | class BaseDataset(data.Dataset): 10 | def __init__(self): 11 | super(BaseDataset, self).__init__() 12 | 13 | def name(self): 14 | return 'BaseDataset' 15 | 16 | def initialize(self, opt): 17 | pass 18 | 19 | def get_train_sampler(self): 20 | print("base class data sampler") 21 | self.train_sampler = data_sampler(self, shuffle=True, distributed = self.opt.training.distributed) 22 | return self.train_sampler 23 | 24 | 25 | def get_params(opt, size): 26 | w, h = size 27 | new_h = h 28 | new_w = w 29 | if opt.resize_or_crop == 'resize_and_crop': 30 | new_h = new_w = opt.loadSize 31 | elif opt.resize_or_crop == 'scale_width_and_crop': 32 | new_w = opt.loadSize 33 | new_h = opt.loadSize * h // w 34 | 35 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 36 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 37 | 38 | flip = random.random() > 0.5 39 | return {'crop_pos': (x, y), 'flip': flip} 40 | 41 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, dp = False): 42 | transform_list = [] 43 | if 'resize' in opt.resize_or_crop: 44 | osize = [opt.loadSize, opt.loadSize] 45 | transform_list.append(transforms.Scale(osize, Image.NEAREST)) 46 | elif 'scale_width' in opt.resize_or_crop: 47 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 48 | 49 | if 'crop' in opt.resize_or_crop: 50 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 51 | 52 | if opt.resize_or_crop == 'none': 53 | base = float(2 ** opt.n_downsample_global) 54 | if opt.netG == 'local': 55 | base *= (2 ** opt.n_local_enhancers) 56 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 57 | 58 | if opt.isTrain and not opt.no_flip: 59 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 60 | 61 | if not dp: 62 | transform_list += [transforms.ToTensor()] 63 | 64 | if normalize: 65 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 66 | (0.5, 0.5, 0.5))] 67 | return transforms.Compose(transform_list) 68 | 69 | def get_transform_tensor(): 70 | transform_list = [] 71 | transform_list += [transforms.ToTensor()] 72 | 73 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 74 | (0.5, 0.5, 0.5))] 75 | return transforms.Compose(transform_list) 76 | 77 | def normalize(): 78 | return 
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 79 | 80 | def __make_power_2(img, base, method=Image.BICUBIC): 81 | ow, oh = img.size 82 | h = int(round(oh / base) * base) 83 | w = int(round(ow / base) * base) 84 | if (h == oh) and (w == ow): 85 | return img 86 | return img.resize((w, h), method) 87 | 88 | def __scale_width(img, target_width, method=Image.BICUBIC): 89 | ow, oh = img.size 90 | if (ow == target_width): 91 | return img 92 | w = target_width 93 | h = int(target_width * oh / ow) 94 | return img.resize((w, h), method) 95 | 96 | def __crop(img, pos, size): 97 | ow, oh = img.size 98 | x1, y1 = pos 99 | tw = th = size 100 | if (ow > tw or oh > th): 101 | return img.crop((x1, y1, x1 + tw, y1 + th)) 102 | return img 103 | 104 | def __flip(img, flip): 105 | if flip: 106 | return img.transpose(Image.FLIP_LEFT_RIGHT) 107 | return img 108 | -------------------------------------------------------------------------------- /uvm_lib/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from .base_data_loader import BaseDataLoader 3 | 4 | import importlib 5 | from munch import * 6 | 7 | def CreateDataset(opt, phase): 8 | dataset = None 9 | 10 | dataset_name = None 11 | if isinstance(opt.dataset, dict) or isinstance(opt.dataset, Munch): 12 | dataset_name = opt.dataset.dataset_name 13 | else: 14 | dataset_name = opt.dataset 15 | 16 | print("load dataset ", dataset_name) 17 | 18 | proj_dir = opt.project_directory 19 | 20 | dataset_name_list = ["zju", "mpi", "aist", "mpi_free"] 21 | data_dict = {} 22 | for s in dataset_name_list: 23 | data_dict.update({s: f'{proj_dir}.data.dataset_{s}'}) 24 | 25 | 26 | if dataset_name in dataset_name_list: 27 | #from data_dict[opt.dataset] import 28 | dataset = importlib.import_module(data_dict[dataset_name]).Dataset() 29 | dataset.initialize(opt, phase, opt.multi_datasets[0]) 30 | print("dataset [%s] was created" % (dataset_name)) 31 | return dataset 32 | else: 33 | raise NotImplementedError() 34 | 35 | 36 | class CustomDatasetDataLoader(BaseDataLoader): 37 | def name(self): 38 | return 'CustomDatasetDataLoader' 39 | 40 | def initialize(self, opt, phase="train"): 41 | BaseDataLoader.initialize(self, opt) 42 | self.dataset = CreateDataset(opt, phase) 43 | 44 | batch_size = opt.batchSize 45 | nthreads = int(opt.nThreads) 46 | if phase == "evaluate": 47 | nthreads = 0 48 | batch_size = 1 49 | 50 | self.dataloader = torch.utils.data.DataLoader( 51 | self.dataset, 52 | batch_size = batch_size, 53 | shuffle=not opt.serial_batches, 54 | #sampler = self.dataset.get_train_sampler() if phase == "train" else None, 55 | num_workers = int(nthreads)) 56 | 57 | def load_data(self): 58 | return self.dataloader 59 | 60 | def __len__(self): 61 | return len(self.dataset) 62 | #return min(len(self.dataset), self.opt.max_dataset_size) 63 | 64 | 65 | 66 | class DistributedCustomDatasetDataLoader(BaseDataLoader): 67 | def name(self): 68 | return 'DistributedCustomDatasetDataLoader' 69 | 70 | def initialize(self, opt, dataset, sampler, phase="train"): 71 | 72 | assert phase=="train" 73 | BaseDataLoader.initialize(self, opt) 74 | self.dataset = dataset 75 | 76 | nthreads = int(opt.nThreads) 77 | 78 | self.dataloader = torch.utils.data.DataLoader( 79 | dataset, 80 | batch_size= opt.batchSize, 81 | shuffle = True if sampler is None else False, 82 | sampler = sampler, 83 | num_workers = int(nthreads)) 84 | 85 | def load_data(self): 86 | return self.dataloader 87 | 88 | def __len__(self): 
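        # length of the underlying dataset (number of samples, not number of batches)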
89 | return len(self.dataset) 90 | #return min(len(self.dataset), self.opt.max_dataset_size) 91 | -------------------------------------------------------------------------------- /uvm_lib/data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt, phase="train"): 3 | from .custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt, phase) 7 | return data_loader 8 | 9 | def CreateDataLoaderDistributed(opt, dataset, sampler, phase="train"): 10 | from .custom_dataset_data_loader import DistributedCustomDatasetDataLoader 11 | data_loader = DistributedCustomDatasetDataLoader() 12 | print(data_loader.name()) 13 | data_loader.initialize(opt, dataset, sampler, phase) 14 | return data_loader 15 | 16 | def CreateDataset(opt, phase): 17 | from .custom_dataset_data_loader import CreateDataset as CreateDataset_ 18 | return CreateDataset_(opt, phase) -------------------------------------------------------------------------------- /uvm_lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/models/__init__.py -------------------------------------------------------------------------------- /uvm_lib/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | from collections import OrderedDict 5 | from ..util import util 6 | from torchsummary import summary 7 | 8 | class BaseModel(torch.nn.Module): 9 | def name(self): 10 | return 'BaseModel' 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.gpu_ids = opt.gpu_ids 15 | self.isTrain = opt.isTrain 16 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 17 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 18 | self.visual_names = [] 19 | self.model_names = [] 20 | self.loss_names = [] 21 | 22 | def set_input(self, input): 23 | self.input = input 24 | 25 | def evaluate(self, data): 26 | self.forward_org(data) 27 | 28 | def forward(self, data): 29 | pass 30 | 31 | # used in test time, no backprop 32 | def test(self): 33 | pass 34 | 35 | def get_image_paths(self): 36 | pass 37 | 38 | def optimize_parameters(self): 39 | pass 40 | 41 | def get_current_visuals(self): 42 | visual_ret = OrderedDict() 43 | if "visOutput" in self.visual_names: 44 | visual_ret["visOutput"] = getattr(self, "visOutput") 45 | return visual_ret 46 | 47 | for name in self.visual_names: 48 | if isinstance(name, str): 49 | visual_ret[name] = getattr(self, name) 50 | return visual_ret 51 | 52 | def tensor_to_viz(self, visuals): 53 | viz_mod = OrderedDict() 54 | for name in visuals: 55 | viz_mod[name] = util.tensor2im(visuals[name][0]) 56 | return viz_mod 57 | 58 | 59 | def get_current_losses(self): 60 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 61 | errors_ret = OrderedDict() 62 | for name in self.loss_names: 63 | if isinstance(name, str): 64 | errors_ret[name] = float(getattr(self, 'loss_' + name,0)) # float(...) works for both scalar tensor and float number 65 | return errors_ret 66 | 67 | def get_current_errors(self): 68 | return {} 69 | 70 | def save(self, label): 71 | pass 72 | 73 | def save_networks(self, epoch): 74 | """Save all the networks to the disk. 
75 | 76 | Parameters: 77 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 78 | """ 79 | for name in self.model_names: 80 | if isinstance(name, str): 81 | save_filename = '%s_net_%s.pth' % (epoch, name) 82 | save_path = os.path.join(self.save_dir, save_filename) 83 | net = getattr(self, 'net' + name) 84 | torch.save(net.cpu().state_dict(), save_path) 85 | if len(self.gpu_ids) and torch.cuda.is_available(): 86 | net.cuda() 87 | 88 | 89 | # helper saving function that can be used by subclasses 90 | def save_network(self, network, network_label, epoch_label, gpu_ids): 91 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 92 | save_path = os.path.join(self.save_dir, save_filename) 93 | torch.save(network.cpu().state_dict(), save_path) 94 | if len(gpu_ids) and torch.cuda.is_available(): 95 | network.cuda() 96 | 97 | # helper loading function that can be used by subclasses 98 | def load_network(self, network, network_label, epoch_label, save_dir=''): 99 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 100 | if not save_dir: 101 | save_dir = self.save_dir 102 | save_path = os.path.join(save_dir, save_filename) 103 | if not os.path.isfile(save_path): 104 | print('%s not exists yet!' % save_path) 105 | if network_label == 'G': 106 | raise('Generator must exist!') 107 | else: 108 | #network.load_state_dict(torch.load(save_path)) 109 | try: 110 | network.load_state_dict(torch.load(save_path)) 111 | except: 112 | pretrained_dict = torch.load(save_path) 113 | model_dict = network.state_dict() 114 | try: 115 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 116 | network.load_state_dict(pretrained_dict) 117 | if self.opt.verbose: 118 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 119 | except: 120 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) 121 | for k, v in pretrained_dict.items(): 122 | if v.size() == model_dict[k].size(): 123 | model_dict[k] = v 124 | 125 | if sys.version_info >= (3,0): 126 | not_initialized = set() 127 | else: 128 | from sets import Set 129 | not_initialized = Set() 130 | 131 | for k, v in model_dict.items(): 132 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 133 | not_initialized.add(k.split('.')[0]) 134 | 135 | print(sorted(not_initialized)) 136 | network.load_state_dict(model_dict) 137 | 138 | def load_networks(self, epoch): 139 | """Load all the networks from the disk. 
140 | 141 | Parameters: 142 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 143 | """ 144 | for name in self.model_names: 145 | if isinstance(name, str): 146 | load_filename = '%s_net_%s.pth' % (epoch, name) 147 | load_path = os.path.join(self.save_dir, load_filename) 148 | net = getattr(self, 'net' + name) 149 | #if isinstance(net, torch.nn.DataParallel): 150 | # net = net.module 151 | print('loading the model from %s' % load_path) 152 | # if you are using PyTorch newer than 0.4 (e.g., built from 153 | # GitHub source), you can remove str() on self.device 154 | state_dict = torch.load(load_path) 155 | if hasattr(state_dict, '_metadata'): 156 | del state_dict._metadata 157 | 158 | # patch InstanceNorm checkpoints prior to 0.4 159 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 160 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 161 | net.load_state_dict(state_dict) 162 | 163 | def update_learning_rate(): 164 | pass 165 | -------------------------------------------------------------------------------- /uvm_lib/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | 4 | def create_model(opt): 5 | 6 | #model = importlib.import_module(opt.model_module)() 7 | 8 | model_path = f'{opt.project_directory}.models.{opt.model_module}' 9 | print(model_path) 10 | 11 | if opt.isTrain: 12 | model = importlib.import_module(model_path).Model() 13 | else: 14 | model = importlib.import_module(model_path).Model().cuda() 15 | 16 | model.initialize(opt) 17 | 18 | if opt.verbose: 19 | print("model [%s] was created" % (model.name())) 20 | 21 | print('model opt, data parallel, gpu num ', len(opt.gpu_ids)) 22 | 23 | if opt.phase=="test": 24 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids).to(opt.gpu_ids[0]) 25 | return model 26 | else: #opt.training.distributed: 27 | model.cuda(opt.gpu_ids[0]) 28 | if(opt.phase=="evaluate"): 29 | return torch.nn.DataParallel(model, device_ids=opt.gpu_ids).to(opt.gpu_ids[0]) 30 | 31 | return model 32 | -------------------------------------------------------------------------------- /uvm_lib/models/net_nerf_uvMotion.py: -------------------------------------------------------------------------------- 1 | from pickle import TRUE 2 | from torch._C import device 3 | from torch.autograd import grad 4 | import torch.nn as nn 5 | #from ..depend import spconv 6 | import torch.nn.functional as F 7 | import torch 8 | 9 | 10 | from Engine.th_utils.networks import embedder 11 | 12 | #generator, discriminator. 
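# In outline, HumanUVNerfMotion below is a NeRF-style head: sampled 3D points are
# projected onto the SMPL surface, expressed as (u, v, h) coordinates, and features
# sampled from the UV / UH / HV pose planes (plus an optional per-identity style map)
# are decoded by 1x1 Conv1d layers into density and view-dependent color.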
13 | from Engine.th_utils.my_pytorch3d.smpl_util import SMPL_Util 14 | 15 | from Engine.th_utils.animation.uv_generator import Index_UV_Generator 16 | 17 | is_debug = False 18 | 19 | class HumanUVNerfMotion(nn.Module): 20 | def __init__(self, opt): 21 | 22 | super(HumanUVNerfMotion, self).__init__() 23 | 24 | self.opt = opt 25 | 26 | nerf_inputdim = opt.posenet_setup.posenet_outdim if self.opt.uv_2dplane else opt.posenet_setup.posenet_outdim // 3 27 | 28 | if self.opt.combine_pose_style: 29 | self.uvh_feat_dim = nerf_inputdim 30 | else: 31 | self.uvh_feat_dim = opt.posenet_setup.tex_latent_dim + nerf_inputdim 32 | if self.opt.posenet_setup.pred_texture_uv: 33 | self.uvh_feat_dim += self.opt.posenet_setup.pred_texture_dim 34 | 35 | if self.opt.debug_only_enc: 36 | self.uvh_feat_dim = 0 37 | 38 | self.output_rgb_dim = opt.motion.nerf_dim 39 | 40 | self.add_layer_density_color = opt.add_layer_density_color 41 | self.add_layer_geometry = opt.add_layer_geometry 42 | 43 | self.use_pose_cond = not opt.not_pose_cond 44 | 45 | self.input_dim = 0 46 | self.actvn = nn.ReLU() 47 | self.img2mse = lambda x, y : torch.mean((x - y) ** 2) 48 | 49 | self.voxel_size = 0.005 50 | 51 | self.h_bound = [-1.2, 1.1] 52 | 53 | if 'local_rank' in self.opt: 54 | self.gpu_ids = [self.opt.local_rank] 55 | elif 'gpu_ids' in self.opt: 56 | self.gpu_ids = self.opt.gpu_ids 57 | 58 | 59 | self.render_posmap = Index_UV_Generator(self.opt.posenet_setup.uv_reso, self.opt.posenet_setup.uv_reso, uv_type=self.opt.posenet_setup.uv_type, data_dir="../asset/data/uv_sampler") 60 | self.vts_uv = self.render_posmap.get_vts_uv().cuda().permute(0,2,1) 61 | self.vts_uv.requires_grad = False 62 | 63 | self.smpl_util = SMPL_Util(gender = self.opt.gender, faces_uvs=self.render_posmap.faces_uvs, verts_uvs = self.render_posmap.verts_uvs, smpl_uv_vts = self.vts_uv) 64 | 65 | self.uvdim = self.opt.posenet_setup.tex_latent_dim 66 | self.uv_reso = self.opt.posenet_setup.uv_reso 67 | 68 | self.add_nerf() 69 | 70 | def add_nerf(self): 71 | 72 | if self.opt.plus_uvh_enc: #yes. 73 | self.geo_fc_0 = nn.Conv1d(self.uvh_feat_dim + embedder.uvh_dim, 256, 1) 74 | else: 75 | self.geo_fc_0 = nn.Conv1d(self.uvh_feat_dim, 256, 1) 76 | 77 | self.geo_fc_1 = nn.Conv1d(256, 256, 1) 78 | 79 | self.alpha_fc = nn.Conv1d(256, 1, 1) 80 | 81 | self.view_fc_0 = nn.Conv1d(256 + self.uvh_feat_dim + embedder.view_dim, 256, 1) 82 | self.view_fc_1 = nn.Conv1d(256, 128, 1) 83 | self.rgb_fc = nn.Conv1d(128, self.output_rgb_dim, 1) 84 | 85 | #torch.nn.Softplus() 86 | 87 | def trans_to_uv(self, smpl_vertices, sampled_pts_smpl_space): 88 | k = 1 89 | 90 | debug = False 91 | if debug: 92 | sampled_pts_smpl_space = (sampled_pts_smpl_space[0][0])[None,None,...] 
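        # Project each sampled point in SMPL space onto the SMPL mesh: near_uv is the UV
        # coordinate of the nearest surface point and h its offset from the surface;
        # gen_UVlatent later uses (u, v, h) to index the pose feature planes.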
93 | 94 | near_uv, h = self.smpl_util.get_nearest_pts_in_mesh_torch(smpl_vertices, sampled_pts_smpl_space, k=k, num_samples = self.opt.uvVol_smpl_pts) 95 | 96 | return near_uv, h 97 | 98 | def extract_uvstyle_features(self): 99 | t = 0 100 | 101 | def get_posmap_loss(self): 102 | return self.posmap_loss 103 | 104 | def gen_UVlatent(self, batch, input_latent): 105 | 106 | sampled_pts_smpl_space = batch["sampled_pts_smpl"] 107 | smpl_vertices = batch["smpl_vertices"] 108 | 109 | device = sampled_pts_smpl_space.device 110 | 111 | near_uv, h_pred = self.trans_to_uv(smpl_vertices, sampled_pts_smpl_space) 112 | 113 | batch_size = near_uv.shape[0] 114 | 115 | b,c,h,w = input_latent.shape 116 | 117 | if self.opt.combine_pose_style: 118 | uvh_plane = input_latent 119 | else:#style 120 | uvh_plane = input_latent[:,:self.opt.posenet_setup.posenet_outdim, ...] 121 | style_uv = input_latent[:, self.opt.posenet_setup.posenet_outdim: , ...] 122 | c -= self.opt.posenet_setup.tex_latent_dim 123 | 124 | if self.opt.uv_2dplane: 125 | pts_huv = torch.cat((h_pred, near_uv), 2) 126 | 127 | input = uvh_plane 128 | if not self.opt.combine_pose_style: 129 | input = torch.cat((input, style_uv), 1) 130 | 131 | fused_feat = self.render_posmap.index_posmap_by_vts(input, near_uv) 132 | 133 | huvlat = embedder.uvh_embedder(pts_huv).permute(0,2,1) 134 | return fused_feat, huvlat, near_uv 135 | 136 | else: 137 | 138 | self.opt.motion.ab_uvh_plane_c = self.opt.posenet_setup.posenet_outdim // 3 139 | c = self.opt.posenet_setup.posenet_outdim 140 | uh_dim = self.opt.motion.ab_uvh_plane_c 141 | vh_dim = uh_dim 142 | uv_dim = c - uh_dim - vh_dim 143 | 144 | uv_plane = uvh_plane[:, :uv_dim, ...] 145 | uh_plane = uvh_plane[:, uv_dim: uv_dim + uh_dim, ...] 146 | hv_plane = uvh_plane[:, uv_dim + uh_dim:, ...] 147 | 148 | self.depth_bound = [-0.1, 0.1] 149 | 150 | pts_huv = torch.cat((h_pred, near_uv), 2) 151 | huvlat = embedder.uvh_embedder(pts_huv).permute(0,2,1) 152 | 153 | 154 | h_pred *= 1/self.depth_bound[1]#-1,1 155 | h_pred = (h_pred + 1)/2 #[0, 1] 156 | if is_debug: print('*** ', near_uv.shape, near_uv[:,[0],...].shape, h_pred.shape) 157 | 158 | ##B, N, C 159 | uh = torch.cat((near_uv[..., [0]], h_pred), -1) 160 | hv = torch.cat((h_pred, near_uv[..., [1]]), -1) 161 | 162 | uv_feature = self.render_posmap.index_posmap_by_vts(uv_plane, near_uv) 163 | uh_feature = self.render_posmap.index_posmap_by_vts(uh_plane, uh) 164 | hv_feature = self.render_posmap.index_posmap_by_vts(hv_plane, hv) 165 | 166 | if self.opt.combine_pose_style: 167 | fused_feat = torch.cat([uv_feature.unsqueeze(1), uh_feature.unsqueeze(1), hv_feature.unsqueeze(1)], 1).mean(1) 168 | else: 169 | fused_feat = torch.cat([uv_feature.unsqueeze(1), uh_feature.unsqueeze(1), hv_feature.unsqueeze(1)], 1).mean(1) 170 | style_uv_feat = self.render_posmap.index_posmap_by_vts(style_uv, near_uv) 171 | fused_feat = torch.cat((style_uv_feat, fused_feat), 1) 172 | 173 | return fused_feat, huvlat, near_uv 174 | 175 | 176 | def forward(self, batch, uv_latent, only_density=False): 177 | 178 | viewdir = batch['view_dir'] 179 | 180 | if self.opt.uv_2dplane: 181 | uv_feat, uvh_encoding, uv_coord = self.gen_UVlatent(batch, uv_latent) 182 | nerf_input = uv_feat 183 | if self.opt.plus_uvh_enc: 184 | nerf_input = torch.cat((nerf_input, uvh_encoding), 1) 185 | uvh_feat = uv_feat 186 | else: 187 | uvh_feat, uvh_encoding, uv_coord= self.gen_UVlatent(batch, uv_latent) 188 | nerf_input = uvh_feat 189 | 190 | if self.opt.plus_uvh_enc: #yes. 
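                # append the (h, u, v) positional encoding to the sampled plane features
                # before the first geometry layer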
191 | nerf_input = torch.cat((uvh_feat, uvh_encoding), 1) 192 | 193 | if self.opt.debug_only_enc: 194 | uvh_feat = None 195 | nerf_input = uvh_encoding 196 | 197 | net = self.geo_fc_0(nerf_input) 198 | net = self.actvn(net) 199 | 200 | net = self.geo_fc_1(net) 201 | net = self.actvn(net) 202 | 203 | alpha = self.alpha_fc(net) 204 | 205 | if self.opt.vrnr_mesh_demo: 206 | return alpha.transpose(1, 2) 207 | 208 | 209 | viewdir = viewdir.transpose(1, 2) 210 | 211 | feat_view = torch.cat((net, viewdir), 1) 212 | if uvh_feat is not None: 213 | feat_view = torch.cat((feat_view, uvh_feat), 1) 214 | 215 | net = self.view_fc_0(feat_view) 216 | net = self.actvn(net) 217 | 218 | net = self.view_fc_1(net) 219 | net = self.actvn(net) 220 | 221 | rgb = self.rgb_fc(net) 222 | 223 | raw = torch.cat((alpha, rgb), dim=1) 224 | raw = raw.transpose(1, 2) 225 | 226 | if self.opt.learn_uv: 227 | return torch.cat((raw, uv_coord), dim=-1) 228 | else: 229 | return raw -------------------------------------------------------------------------------- /uvm_lib/models/net_smooth.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class GeomConvLayers(nn.Module): 9 | ''' 10 | A few convolutional layers to smooth the geometric feature tensor 11 | ''' 12 | def __init__(self, input_nc=16, hidden_nc=16, output_nc=16, use_relu=False): 13 | super().__init__() 14 | self.use_relu = use_relu 15 | 16 | self.conv1 = nn.Conv2d(input_nc, hidden_nc, kernel_size=5, stride=1, padding=2, bias=False) 17 | self.conv2 = nn.Conv2d(hidden_nc, hidden_nc, kernel_size=5, stride=1, padding=2, bias=False) 18 | self.conv3 = nn.Conv2d(hidden_nc, output_nc, kernel_size=5, stride=1, padding=2, bias=False) 19 | if use_relu: 20 | self.relu = nn.LeakyReLU(0.2, inplace=True) 21 | 22 | def forward(self, x): 23 | x = self.conv1(x) 24 | if self.use_relu: 25 | x = self.relu(x) 26 | x = self.conv2(x) 27 | if self.use_relu: 28 | x = self.relu(x) 29 | x = self.conv3(x) 30 | 31 | return x -------------------------------------------------------------------------------- /uvm_lib/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoHuUMD/SurMo/ef68beea0a4615a85cceecaa35472d7525e592fb/uvm_lib/options/__init__.py -------------------------------------------------------------------------------- /uvm_lib/options/evaluation_option.py: -------------------------------------------------------------------------------- 1 | from ..base_options.evaluate_options import EvaluateOptions 2 | from .project_option import import_project_opt 3 | 4 | 5 | class ProjectOptions(EvaluateOptions): 6 | def initialize(self): 7 | EvaluateOptions.initialize(self) 8 | import_project_opt(self.parser) -------------------------------------------------------------------------------- /uvm_lib/options/project_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def import_project_opt(parser): 6 | 7 | parser.add_argument("--extract_gt_images", action='store_true', help='') 8 | 9 | motion = parser.add_argument_group('motion') 10 | 11 | motion.add_argument("--infer_velocity", type=float, default=1, help='') 12 | 13 | 14 | motion.add_argument("--motion_chain", action='store_true', help='') 15 | motion.add_argument("--motion_steps", type=str, default="-25 -20 -15 -10 -5 -1 5 10 15 20 
25", help='previous motion status') 16 | 17 | motion.add_argument("--motion_point", type=int, default=0, help='start or end point of a motion, <0, start, >0 end') 18 | 19 | 20 | motion.add_argument("--style_dim", type=int, default=256, help='') 21 | motion.add_argument("--ab_uvh_plane_c", type=int, default=16, help='dim of vh, uh plane') 22 | motion.add_argument("--nerf_dim", type=int, default=32, help='dim of nerf input') 23 | 24 | motion.add_argument('--aug_random_flip', action='store_true', help="") 25 | motion.add_argument("--use_global_posemap", action='store_true', help='whether use global verts in posemap') 26 | 27 | #2 no nerf 28 | #--ab_only_sup_dynamic_tex 29 | motion.add_argument("--ab_sup_only_dynamic_tex", action='store_true', help='no nerf, only supreso posemap out') 30 | 31 | #3 no nerf, supp. cond on static style. 32 | motion.add_argument("--ab_sup_only_static_style", action='store_true', help='no nerf, only supreso on style') 33 | 34 | 35 | motion.add_argument("--ab_sup_2d_style", action='store_true', help='sup net only condition on 2d style latent') 36 | 37 | 38 | motion.add_argument("--dual_discrim_eg3d", action='store_true', help='#whether D on tex field') 39 | 40 | 41 | motion.add_argument("--use_org_gan_loss", action='store_true', help='#') 42 | motion.add_argument("--use_org_discrim", action='store_true', help='#') 43 | 44 | 45 | motion.add_argument("--ab_Dtex", action='store_true', help='#whether D on tex field') 46 | motion.add_argument("--ab_Dtex_pose", action='store_true', help='#whether Dtex is conditioned on pose map') 47 | 48 | motion.add_argument("--ab_uvh_plane", action='store_true', help='uvh tri-plane') 49 | motion.add_argument("--ab_nerf_rec", action='store_true', help='whether rec loss on nerf') 50 | 51 | motion.add_argument("--ab_Ddual", action='store_true', help='dual discriminator') 52 | motion.add_argument("--ab_D_pose", action='store_true', help='D cond on pose') 53 | motion.add_argument("--ab_tex_rec", action='store_true', help='uv tex recon') 54 | 55 | motion.add_argument("--D_label_noise", action='store_true', help='add label noise for D') 56 | motion.add_argument("--D_noise_factor", type=float, default=0.05, help='add label noise for D') 57 | 58 | motion.add_argument("--debug_data_size", type=int, default=10, help='debug small dataset') 59 | 60 | motion.add_argument("--abandon", action='store_true', help='not used') 61 | 62 | 63 | motion.add_argument("--ab_cond_uv_latent", action='store_true', help='super reso cond on 2d uv lat') 64 | motion.add_argument("--general_superreso", action='store_true', help='super reso cond on 2d uv lat') 65 | 66 | #in the future 67 | motion.add_argument("--deep_nerf", action='store_true', help='3 layer nerf network') 68 | motion.add_argument("--ab_cond_1d_lat", action='store_true', help='super reso cond on 1d style') 69 | -------------------------------------------------------------------------------- /uvm_lib/options/test_option.py: -------------------------------------------------------------------------------- 1 | from ..base_options.test_options import TestOptions 2 | from .project_option import import_project_opt 3 | import os 4 | 5 | class ProjectOptions(TestOptions): 6 | def initialize(self): 7 | TestOptions.initialize(self) 8 | import_project_opt(self.parser) 9 | 10 | def parse_(self): 11 | opt = self.parse() 12 | 13 | opt.results_dir = os.path.join(opt.results_dir, opt.name) 14 | -------------------------------------------------------------------------------- /uvm_lib/options/train_option.py: 
-------------------------------------------------------------------------------- 1 | from ..base_options.train_options import TrainOptions 2 | from .project_option import import_project_opt 3 | 4 | 5 | class ProjectOptions(TrainOptions): 6 | def initialize(self): 7 | TrainOptions.initialize(self) 8 | import_project_opt(self.parser) -------------------------------------------------------------------------------- /uvm_lib/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | 4 | import os 5 | import copy 6 | 7 | class HTML: 8 | def __init__(self, web_dir, title, img_dir = 'images', refresh=0, pre=""): 9 | self.title = title 10 | self.web_dir = web_dir 11 | 12 | self.loc_img_dir = img_dir if pre=="" else pre + "_" + img_dir 13 | 14 | self.img_dir = os.path.join(self.web_dir, self.loc_img_dir) 15 | self.pre = pre 16 | 17 | if not os.path.exists(self.web_dir): 18 | os.makedirs(self.web_dir) 19 | if not os.path.exists(self.img_dir): 20 | os.makedirs(self.img_dir) 21 | 22 | self.doc = dominate.document(title=title) 23 | if refresh > 0: 24 | with self.doc.head: 25 | meta(http_equiv="refresh", content=str(refresh)) 26 | 27 | def get_image_dir(self): 28 | return self.img_dir 29 | 30 | def add_header(self, str): 31 | with self.doc: 32 | h3(str) 33 | 34 | def add_table(self, border=1): 35 | self.t = table(border=border, style="table-layout: fixed;") 36 | self.doc.add(self.t) 37 | 38 | def add_comp_images(self, ims, txts, links, width=300): 39 | """add images to the HTML file 40 | 41 | Parameters: 42 | ims (str list) -- a list of image paths 43 | txts (str list) -- a list of image names shown on the website 44 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 45 | """ 46 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 47 | self.doc.add(self.t) 48 | with self.t: 49 | with tr(): 50 | for im, txt, link in zip(ims, txts, links): 51 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 52 | with p(): 53 | with a(href=os.path.join('', link)): 54 | img(style="width:%dpx" % width, src=os.path.join('', im)) 55 | br() 56 | p(txt) 57 | 58 | def add_images(self, ims, txts, links, width=512): 59 | self.add_table() 60 | with self.t: 61 | with tr(): 62 | for im, txt, link in zip(ims, txts, links): 63 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 64 | with p(): 65 | with a(href=os.path.join(self.loc_img_dir, link)): 66 | img(style="width:%dpx" % (width), src=os.path.join(self.loc_img_dir, im)) 67 | br() 68 | p(txt) 69 | 70 | def save_tmp(self, idx=0): 71 | tmp_html_file = '%s/tmp_%s_index_%d.html' % (self.web_dir, self.pre, idx) 72 | f = open(tmp_html_file, 'wt') 73 | tmp_doc = copy.deepcopy(self.doc) 74 | f.write(tmp_doc.render()) 75 | f.close() 76 | 77 | def save(self): 78 | html_file = '%s/%s_index.html' % (self.web_dir, self.pre) 79 | f = open(html_file, 'wt') 80 | f.write(self.doc.render()) 81 | f.close() 82 | 83 | 84 | if __name__ == '__main__': 85 | html = HTML('web/', 'test_html') 86 | html.add_header('hello world') 87 | 88 | ims = [] 89 | txts = [] 90 | links = [] 91 | for n in range(4): 92 | ims.append('image_%d.jpg' % n) 93 | txts.append('text_%d' % n) 94 | links.append('image_%d.jpg' % n) 95 | html.add_images(ims, txts, links) 96 | html.save() 97 | -------------------------------------------------------------------------------- /uvm_lib/util/util.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import cv2 8 | 9 | fixed_mat = np.load('../asset/fixedmat_viz.npy') 10 | # Converts a Tensor into a Numpy array 11 | # |imtype|: the desired type of the converted numpy array 12 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, is_list = False): 13 | if isinstance(image_tensor, list) or is_list: 14 | image_numpy = [] 15 | for i in range(len(image_tensor)): 16 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 17 | return image_numpy 18 | image_numpy = image_tensor.cpu().float().numpy() 19 | if normalize: 20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 21 | else: 22 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 23 | image_numpy = np.clip(image_numpy, 0, 255) 24 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 25 | #fixed_mat = np.random.random_sample((16, 3)) 26 | 27 | image_numpy = image_numpy.sum(axis=-1) 28 | #image_numpy = np.matmul(image_numpy, fixed_mat) 29 | return image_numpy.astype(imtype) 30 | 31 | def tensor2imProj(image_tensor, imtype=np.uint8, normalize=True, is_list = False): 32 | global fixed_mat 33 | if isinstance(image_tensor, list) or is_list: 34 | image_numpy = [] 35 | for i in range(len(image_tensor)): 36 | image_numpy.append(tensor2imProj(image_tensor[i], imtype, normalize)) 37 | return image_numpy 38 | image_numpy = image_tensor.cpu().float().numpy() 39 | 40 | #print(image_numpy.shape, fixed_mat.shape) 41 | 42 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) 43 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] == 16: 44 | #image_numpy = image_numpy[:,:,0] 45 | image_numpy_proj = np.matmul(image_numpy, fixed_mat) 46 | 47 | max, min = image_numpy_proj.max(), image_numpy_proj.min() 48 | image_numpy_proj = (image_numpy_proj - min)/(max - min)*255 49 | image_numpy_proj = np.clip(image_numpy_proj, 0, 255) 50 | return image_numpy_proj.astype(imtype) 51 | elif image_numpy.shape[2] >3 : 52 | image_numpy = (image_numpy[:,:,3:6] + 1) / 2.0 * 255.0 53 | return image_numpy.astype(imtype) 54 | else: 55 | image_numpy = (image_numpy + 1) / 2.0 * 255.0 56 | return image_numpy.astype(imtype) 57 | 58 | # Converts a one-hot tensor into a colorful label map 59 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 60 | if n_label == 0: 61 | return tensor2im(label_tensor, imtype) 62 | label_tensor = label_tensor.cpu().float() 63 | if label_tensor.size()[0] > 1: 64 | label_tensor = label_tensor.max(0, keepdim=True)[1] 65 | label_tensor = Colorize(n_label)(label_tensor) 66 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 67 | return label_numpy.astype(imtype) 68 | 69 | 70 | def save_image(image_numpy, image_path): 71 | 72 | if not (image_numpy.size==256*256 or image_numpy.size==512*512 or image_numpy.size==1024*1024): 73 | cv2.imwrite(image_path, image_numpy[:,:,[2,1,0]]) 74 | else: 75 | cv2.imwrite(image_path, image_numpy*255) 76 | return 0 77 | image_pil = Image.fromarray(image_numpy) 78 | image_pil.save(image_path) 79 | 80 | def mkdirs(paths): 81 | if isinstance(paths, list) and not isinstance(paths, str): 82 | for path in paths: 83 | mkdir(path) 84 | else: 85 | mkdir(paths) 86 | 87 | def mkdir(path): 88 | os.makedirs(path, exist_ok = True) 89 | 90 | # if not os.path.exists(path): 91 | # os.makedirs(path) 92 | 93 | 
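# Illustrative usage (fake_image is a hypothetical generator output, not a name from this file):
#     img = tensor2im(fake_image[0])   # CHW tensor in [-1, 1] -> HWC uint8; >3 channels are summed
#     save_image(img, 'result.jpg')    # written via cv2 after an RGB -> BGR channel swap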
############################################################################### 94 | # Code from 95 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 96 | # Modified so it complies with the Citscape label map colors 97 | ############################################################################### 98 | def uint82bin(n, count=8): 99 | """returns the binary of integer n, count refers to amount of bits""" 100 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 101 | 102 | def labelcolormap(N): 103 | if N == 35: # cityscape 104 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 105 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 106 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 107 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 108 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 109 | dtype=np.uint8) 110 | else: 111 | cmap = np.zeros((N, 3), dtype=np.uint8) 112 | for i in range(N): 113 | r, g, b = 0, 0, 0 114 | id = i 115 | for j in range(7): 116 | str_id = uint82bin(id) 117 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 118 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 119 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 120 | id = id >> 3 121 | cmap[i, 0] = r 122 | cmap[i, 1] = g 123 | cmap[i, 2] = b 124 | return cmap 125 | 126 | class Colorize(object): 127 | def __init__(self, n=35): 128 | self.cmap = labelcolormap(n) 129 | self.cmap = torch.from_numpy(self.cmap[:n]) 130 | 131 | def __call__(self, gray_image): 132 | size = gray_image.size() 133 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 134 | 135 | for label in range(0, len(self.cmap)): 136 | mask = (label == gray_image[0]).cpu() 137 | color_image[0][mask] = self.cmap[label][0] 138 | color_image[1][mask] = self.cmap[label][1] 139 | color_image[2][mask] = self.cmap[label][2] 140 | 141 | return color_image 142 | 143 | 144 | def plot2Fig(mainfigure): 145 | mainfigure.canvas.draw() 146 | # data = np.frombuffer(mainfigure.canvas.tostring_rgb(), dtype=np.uint8) 147 | data = np.fromstring(mainfigure.canvas.tostring_rgb(), dtype=np.uint8, sep='') 148 | w, h = mainfigure.canvas.get_width_height()[::-1] 149 | data = data.reshape(mainfigure.canvas.get_width_height()[::-1] + (3,)) 150 | # data2 = data[:, :h//nc*(nc-1),:] 151 | 152 | return data 153 | 154 | import matplotlib 155 | matplotlib.use('Agg') 156 | import matplotlib.pyplot as plt 157 | 158 | def showDPtexPlt(dptex): 159 | figure = plt.figure() 160 | 161 | count = 0 162 | for i in range(4): 163 | for j in range(6): 164 | ax = plt.subplot(4, 6, count + 1) 165 | fig = ax.imshow(dptex[count, :, :].sum(axis = -1)) 166 | fig.axes.get_xaxis().set_visible(False) 167 | fig.axes.get_yaxis().set_visible(False) 168 | plt.axis('off') 169 | 170 | count += 1 171 | figure.subplots_adjust(wspace=0, hspace=0) 172 | return plot2Fig(figure) 173 | 174 | def showDPtex(dptex): 175 | psize = dptex.shape[1] 176 | combinedtex = np.zeros((psize*4, psize*6)) 177 | count = 0 178 | for i in range(4): 179 | for j in range(6): 180 | combinedtex[i*psize:i*psize+psize, j*psize:j*psize+psize] = dptex[count].sum(axis = -1) 181 | count += 1 182 | 183 | combinedtex = np.repeat(combinedtex[..., np.newaxis], 3, axis = -1) 184 | combinedtex = (combinedtex - combinedtex.min()) / (combinedtex.max() - 
combinedtex.min())*255 185 | return combinedtex.astype(np.uint8) 186 | 187 | def padSqH(img, size, pad = 0): 188 | 189 | if size==img.shape[0]: 190 | return img 191 | 192 | ratio = size / img.shape[0] 193 | h = size 194 | w = int(img.shape[1] * ratio) 195 | 196 | frame = np.ones((size, size, 3)).astype(np.uint8) * pad 197 | p = int((h - w) / 2) 198 | img = cv2.resize(img, (w, h)) 199 | frame[:, p:p + w] = img 200 | return frame 201 | 202 | def save_img_list(img_list, path, img_name, rnum = 2): 203 | img = np.concatenate(img_list, 1) 204 | w = img.shape[1] // rnum 205 | h = img.shape[0] 206 | img = img.reshape(-1, w, 3) 207 | cv2.imwrite(os.path.join(path, f'{img_name}.png'), img) 208 | 209 | 210 | def split_dict_batch(d, n=1): 211 | keys = list(d.keys()) 212 | batch_size = len(d[keys[0]]) 213 | if batch_size==1: 214 | yield d 215 | return 216 | #for i in range(0, len(keys), n): 217 | for i in range(0, batch_size, n): 218 | yield {k: d[k][i:i+n] for k in keys} 219 | -------------------------------------------------------------------------------- /uvm_lib/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from . import util 5 | try: 6 | from StringIO import StringIO # Python 2.7 7 | except ImportError: 8 | from io import BytesIO # Python 3.x 9 | 10 | import time 11 | from . import html 12 | 13 | from tensorboardX import SummaryWriter 14 | 15 | import imageio 16 | import torch 17 | 18 | class Visualizer(): 19 | def __init__(self, opt): 20 | 21 | self.use_html = True #not (opt.phase == "test") 22 | self.win_size = 3000 23 | self.name = opt.name 24 | self.tf_log = True 25 | 26 | if opt.phase == "test": return 27 | 28 | if "checkpoints_dir" in opt.keys(): 29 | checkpoints_dir = opt.checkpoints_dir 30 | else: 31 | checkpoints_dir = opt.training.checkpoints_dir 32 | 33 | if (opt.training.distributed and opt.local_rank == 0) or not opt.training.distributed: 34 | if self.tf_log: 35 | 36 | self.log_dir = os.path.join(checkpoints_dir, self.name) 37 | os.makedirs(self.log_dir, exist_ok=True) 38 | 39 | 40 | self.writer = SummaryWriter(log_dir=self.log_dir) 41 | 42 | if self.use_html: 43 | self.web_dir = os.path.join(self.log_dir, 'web') 44 | self.train_dir = os.path.join(self.web_dir, 'train_images') 45 | self.eval_dir = os.path.join(self.web_dir, 'eva_images') 46 | print('create web directory %s...' 
% self.web_dir) 47 | for d in [self.web_dir, self.train_dir, self.eval_dir]: 48 | os.makedirs(d, exist_ok=True) 49 | 50 | self.log_name = os.path.join(self.log_dir, 'loss_log.txt') 51 | 52 | with open(self.log_name, "a") as log_file: 53 | now = time.strftime("%c") 54 | log_file.write('================ Training Loss (%s) ================\n' % now) 55 | 56 | self.train_webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30, pre="train") 57 | self.eval_webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30, pre="eva") 58 | 59 | 60 | 61 | def save_eval_images(self, img_list, epoch, flag=""): 62 | if torch.is_tensor(flag): flag = str(flag.item()) 63 | self.save_img_list(img_list, self.eval_dir, 'epoch%s_%s' % (epoch, flag)) 64 | 65 | # |visuals|: dictionary of images to display or save 66 | def display_current_results(self, visuals, epoch, flag=""): 67 | if self.tf_log: # show images in tensorboard output 68 | img_summaries = [] 69 | for label, image_numpy in visuals.items(): 70 | # Write the image to a string 71 | if image_numpy is None: 72 | continue 73 | 74 | if torch.is_tensor(flag): flag = str(flag.item()) 75 | 76 | save_train = False 77 | save_eval = False 78 | img_dir = self.train_dir 79 | if flag.find("eva")!=-1: 80 | img_dir = self.eval_dir 81 | save_eval = True 82 | else: save_train = True 83 | 84 | if self.use_html: # save images to a html file 85 | for label, image_numpy in visuals.items(): 86 | if image_numpy is None: 87 | continue 88 | if isinstance(image_numpy, list): 89 | for i in range(len(image_numpy)): 90 | img_path = os.path.join(img_dir, 'epoch%.3d_%s_%s_%d.jpg' % (epoch, flag, label, i)) 91 | util.save_image(image_numpy[i], img_path) 92 | else: 93 | img_path = os.path.join(img_dir, 'epoch%.3d_%s_%s.jpg' % (epoch, flag, label)) 94 | util.save_image(image_numpy, img_path) 95 | 96 | 97 | # update website 98 | pre = "eva" if save_eval else "train" 99 | 100 | if save_eval: 101 | self.save_images(self.eval_webpage, visuals, 'epoch%.3d_%s' % (epoch, flag)) 102 | self.eval_webpage.save() 103 | 104 | else: 105 | self.save_images(self.train_webpage, visuals, 'epoch%.3d_%s' % (epoch, flag)) 106 | self.train_webpage.save() 107 | 108 | def save_training(self): 109 | self.train_webpage.save() 110 | self.eval_webpage.save() 111 | 112 | # errors: dictionary of error labels and values 113 | def plot_current_errors(self, errors, step): 114 | if self.tf_log: 115 | for tag, value in errors.items(): 116 | self.writer.add_scalar(tag,value,step) 117 | 118 | # errors: same format as |errors| of plotCurrentErrors 119 | def print_current_errors(self, epoch, i, errors, t): 120 | if isinstance(t, list): 121 | message = '(epoch: %d, iters: %d, timed: %.3f, timef: %.3f) ' % (epoch, i, t[0], t[1]) 122 | else: 123 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 124 | for k, v in errors.items(): 125 | if v != 0: 126 | message += '%s: %.3f ' % (k, v) 127 | 128 | print(message) 129 | with open(self.log_name, "a") as log_file: 130 | log_file.write('%s\n' % message) 131 | 132 | def save_img_list(self, img_list_, path, img_name, rnum = 2): 133 | img_list = [] 134 | for l in img_list_: 135 | if isinstance(l, list): 136 | img_list += l 137 | else: img_list.append(l) 138 | 139 | img = np.concatenate(img_list, 1) 140 | w = img.shape[1] // rnum 141 | img = np.concatenate([img[:, :w, :], img[:, w:, :]], 0) 142 | cv2.imwrite(os.path.join(path, f'{img_name}.png'), img[:,:,[2,1,0]]) 143 | 144 | def img_to_video(self, webpage, video_name): 145 | 
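        # gather frames under image_dir whose names start with the first two
        # underscore-separated tokens of video_name, then encode them at 25 fps
        # into a video written next to the image directory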
image_dir = webpage.get_image_dir() 146 | video_dir = os.path.join(image_dir, "../") 147 | 148 | vfs = sorted(os.listdir(image_dir)) 149 | 150 | video_pre = "%s_%s" % (video_name.split("_")[0], video_name.split("_")[1]) 151 | 152 | video_list = [] 153 | for f in vfs: 154 | if f.startswith(video_pre): 155 | video_list.append(os.path.join(image_dir, f)) 156 | 157 | def img_to_video(file_list, video_dir, fps, is_img=True): 158 | 159 | writer = imageio.get_writer(video_dir, fps=fps) 160 | 161 | for file_name in file_list: 162 | if is_img: 163 | img = imageio.imread(file_name) 164 | else: 165 | img = imageio.imread(file_name).astype(np.float32) 166 | 167 | writer.append_data(img) 168 | 169 | writer.close() 170 | 171 | img_to_video(video_list, os.path.join(video_dir, video_name), 25) 172 | print("%s video saved" % video_name) 173 | 174 | def save_images_list(self, webpage, visuals, frame_idx, phase="train"): 175 | 176 | image_dir = webpage.get_image_dir() 177 | 178 | name = "%s" % frame_idx 179 | 180 | webpage.add_header(name) 181 | ims = [] 182 | txts = [] 183 | links = [] 184 | 185 | image_numpy = visuals 186 | 187 | image_name = "%s.jpg" % (name) 188 | if phase=="test": 189 | image_name = "%s.png" % (name) 190 | 191 | save_path = os.path.join(image_dir, image_name) 192 | util.save_image(image_numpy, save_path) 193 | 194 | ims.append(image_name) 195 | txts.append(name) 196 | links.append(image_name) 197 | webpage.add_images(ims, txts, links, width=self.win_size) 198 | 199 | 200 | def save_images(self, webpage, visuals, frame_idx, phase="train"): 201 | 202 | image_dir = webpage.get_image_dir() 203 | 204 | name = "%s" % frame_idx 205 | 206 | webpage.add_header(name) 207 | ims = [] 208 | txts = [] 209 | links = [] 210 | 211 | for label, image_numpy in visuals.items(): 212 | if image_numpy is None: continue 213 | image_name = "%s_%s.jpg" % (name, label) 214 | if phase=="test": 215 | image_name = "%s_%s.png" % (name, label) 216 | #print(label, image_numpy.shape) 217 | save_path = os.path.join(image_dir, image_name) 218 | util.save_image(image_numpy, save_path) 219 | 220 | ims.append(image_name) 221 | txts.append(label) 222 | links.append(image_name) 223 | webpage.add_images(ims, txts, links, width=self.win_size) 224 | --------------------------------------------------------------------------------
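For orientation, here is a minimal sketch of how the modules dumped above fit together in a training entry point. The actual train_dist.py is not included in this dump, so `ProjectOptions().parse()`, `opt.niter` and the logging cadence are assumptions based on the option files and base classes shown, not the script's exact API.

    from uvm_lib.options.train_option import ProjectOptions
    from uvm_lib.data.data_loader import CreateDataLoader
    from uvm_lib.models.models import create_model
    from uvm_lib.util.visualizer import Visualizer

    opt = ProjectOptions().parse()                  # assumed: the options classes expose parse()
    data_loader = CreateDataLoader(opt, phase="train")
    dataset = data_loader.load_data()               # torch DataLoader over the configured dataset
    model = create_model(opt)                       # imports uvm_lib.models.<opt.model_module>.Model
    visualizer = Visualizer(opt)

    for epoch in range(opt.niter):                  # assumed epoch-count flag
        for i, batch in enumerate(dataset):
            model.set_input(batch)
            model.optimize_parameters()             # overridden by the concrete model classes
            if i % 100 == 0:
                visualizer.print_current_errors(epoch, i, model.get_current_losses(), 0.0)
                visualizer.plot_current_errors(model.get_current_losses(), epoch * len(dataset) + i)
        model.save_networks(epoch)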