├── README.md
├── Teaser.png
├── data
│   ├── SMAL
│   │   └── raw
│   │       └── move_smal.zip_here.txt
│   └── move_bone.zip_here.txt
├── datasets
│   ├── __init__.py
│   ├── bone.py
│   ├── dfaust.py
│   ├── meshdata.py
│   └── smal.py
├── main.py
├── models
│   ├── __init__.py
│   ├── arap.py
│   ├── conv
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── cheb_conv.cpython-36.pyc
│   │   │   └── message_passing.cpython-36.pyc
│   │   ├── cheb_conv.py
│   │   └── message_passing.py
│   ├── inits.py
│   └── networks.py
├── template
│   ├── dfaust.ply
│   ├── smal_0.ply
│   └── smpl_male_template.ply
├── test_bone.sh
├── test_dfaust.sh
├── test_smal.sh
├── train_and_test_dfaust.sh
├── train_bone.sh
├── train_smal.sh
├── utils
│   ├── __init__.py
│   ├── dataloader.py
│   ├── mesh_sampling.py
│   ├── read.py
│   ├── train_eval.py
│   ├── utils.py
│   └── writer.py
└── work_dir
    ├── SMAL
    │   └── out
    │       └── move_smal_ckpt.zip_here.txt
    └── move_bone_ckpt.zip_here.txt


/README.md:
--------------------------------------------------------------------------------
# ARAPReg
Code for the ICCV 2021 paper [ARAPReg: An As-Rigid-As Possible Regularization Loss for Learning Deformable Shape Generators](https://arxiv.org/pdf/2108.09432.pdf).

## Installation
The code was developed with `Python 3.6` and `cuda 10.2` on Ubuntu 18.04. It depends on:
* [PyTorch](https://pytorch.org/) (1.9.0)
* [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric)
* [OpenMesh](https://github.com/nmaxwell/OpenMesh-Python) (1.1.3)
* [MPI-IS Mesh](https://github.com/MPI-IS/mesh): we suggest installing this library from source.
* [scikit-learn](https://scikit-learn.org/) (used for the vertex PCA in `datasets/meshdata.py`)
* [tqdm](https://github.com/tqdm/tqdm)

Note that the PyTorch and PyTorch Geometric versions may need to change to match your CUDA version.


## Data Preparation
We provide data for three datasets: [DFAUST](https://dfaust.is.tue.mpg.de/), [SMAL](https://smal.is.tue.mpg.de/) and a Bone dataset.

### DFAUST
We use 32933 training shapes and 4264 test shapes from the DFAUST dataset.
You can download the dataset [here](https://drive.google.com/file/d/1BaACAdJO0uoH5P084Gw11a_j3nKVSUjn/view?usp=sharing).
Please place `dfaust.zip` in `data/DFaust/raw/`.

### SMAL
We use 400 shapes from family 0 of the SMAL dataset. We generated the shapes with the SMAL demo, setting the mean and the variance of the pose vectors to 0 and 0.2, respectively. We split them into 300 training and 100 test samples.

You can download the generated dataset [here](https://drive.google.com/file/d/1L3n6i097bgZtNYAmnGM9NwOWBNd4c1Fr/view?usp=sharing).
After downloading, please move `smal.zip` to `./data/SMAL/raw`.

### Bone
We created a conventional bone dataset with 4 categories: tibia, pelvis, scapula and femur. Each category has about 50 shapes, split into 40 training and 10 test samples.
You can download the dataset [here](https://drive.google.com/file/d/1Naq1F6V-Oxw4AQZJeaCKfRrOCQneF0gT/view?usp=sharing).
After downloading, please move `bone.zip` to `./data` and then extract it.
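
Before running any script, it can help to confirm that the archives sit where this section says they should. The following is a minimal sketch of such a check; `missing_archives` and the script name `check_data.py` are hypothetical and not part of this repository. It only looks at the three archive paths listed above, so Bone will also be reported as missing once you have extracted `bone.zip` and deleted the archive.

```python
from pathlib import Path

# Archive locations as listed in "Data Preparation" (relative to the repository root).
# This helper is a hypothetical convenience script, not part of ARAPReg itself.
EXPECTED_ARCHIVES = {
    "DFAUST": Path("data/DFaust/raw/dfaust.zip"),
    "SMAL": Path("data/SMAL/raw/smal.zip"),
    "Bone": Path("data/bone.zip"),
}


def missing_archives(repo_root="."):
    """Return the dataset names whose archive is not at its documented path."""
    root = Path(repo_root)
    return [name for name, rel_path in EXPECTED_ARCHIVES.items()
            if not (root / rel_path).is_file()]


if __name__ == "__main__":
    missing = missing_archives()
    if missing:
        print("Missing archives for: " + ", ".join(missing))
    else:
        print("All dataset archives are in place.")
```

Run it from the repository root, e.g. `python check_data.py`.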

## Testing
### Pretrained checkpoints
You can find pre-trained models and training logs at the following links:

**DFAUST**: [checkpoints.zip](https://drive.google.com/file/d/1mCiF-XkMWPNDih4mmxRn6aaPnTAHdpK0/view?usp=sharing). Uncompressing it under the repository root places two checkpoints in `DFaust/out/arap/checkpoints/` and `DFaust/out/arap/test_checkpoints/`.

**SMAL**: [smal_ckpt.zip](https://drive.google.com/file/d/1IIAW5SmylMHsFpU-croeu-uNPdKP_fnL/view?usp=sharing). Move it to `./work_dir/SMAL/out`, then extract it.

**Bone**: [bone_ckpt.zip](https://drive.google.com/file/d/15I-uABi6_-2qM3QG40Df9G9oNh-K55Nl/view?usp=sharing). Move it to `./work_dir`, then extract it. It contains checkpoints for all 4 bone categories.

### Run testing
After putting the pre-trained checkpoints in their corresponding paths, you can run the following scripts to optimize latent vectors for shape reconstruction. Because our model uses an auto-decoder architecture, testing still includes a stage that optimizes one latent vector per test shape; in `--mode test`, only these test latent vectors are passed to the optimizer.

Note that both the SMAL and Bone checkpoints were trained on a single GPU, so do not pass `--distributed` to `main.py` (keep `args.distributed` as `False`) when testing them. You can still use multiple GPUs for your own training.

**DFAUST**:
```
bash test_dfaust.sh
```
**SMAL**:
```
bash test_smal.sh
```
**Bone**:
```
bash test_bone.sh
```
Note that for the Bone dataset, we train and test the 4 categories separately. The training and testing scripts currently use `tibia`; replace it with `femur`, `pelvis` or `scapula` to get results for the other 3 categories.


## Model training
To retrain our model, run the following scripts after downloading and extracting the datasets.

**DFAUST**:
On DFAUST, multiple GPUs are preferred for better efficiency. The DFAUST script also tracks the reconstruction error to avoid over-fitting.
```
bash train_and_test_dfaust.sh
```
**SMAL**:
```
bash train_smal.sh
```
Note: `batch_size` is set to 16 in the script. This causes a CUDA out-of-memory error once the ARAPReg loss is added at epoch 800, so you will have to reduce the batch size to 8 after 800 epochs.

**Bone**:
```
bash train_bone.sh
```


## Train on a new dataset
Data preprocessing and loading scripts are in `./datasets`.
To train on a new dataset, write a data loading file similar to `./datasets/dfaust.py`, then add the dataset to `./datasets/meshdata.py` and `main.py`. Finally, write a training script similar to `train_and_test_dfaust.sh`.



--------------------------------------------------------------------------------
/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/Teaser.png
--------------------------------------------------------------------------------
/data/SMAL/raw/move_smal.zip_here.txt:
--------------------------------------------------------------------------------
hh

--------------------------------------------------------------------------------
/data/move_bone.zip_here.txt:
--------------------------------------------------------------------------------
hh

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .smal import SMAL
from .bone import Bone
from .dfaust import DFaust
from .meshdata import MeshData

__all__ = ['MeshData', 'Bone', 'DFaust', 'SMAL']

--------------------------------------------------------------------------------
/datasets/bone.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
from glob import glob

import torch
from torch_geometric.data import InMemoryDataset, extract_zip
from utils.read import read_mesh

from tqdm import tqdm


class Bone(InMemoryDataset):
    url = 'Not needed'

    categories = [
        'femur',
        'pelvis',
        'scapula',
        'tibia',
    ]

    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 pre_transform=None):
        if not osp.exists(osp.join(root, 'processed')):
            os.makedirs(osp.join(root, 'processed'))

        super().__init__(root, transform, pre_transform)
        self.train_list = []
        self.test_list = []
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return 'bone_data.zip'

    @property
    def processed_file_names(self):
        return ['training.pt', 'test.pt']

    def process(self):
        print('Processing...')

        fps = glob(osp.join(self.raw_dir, '*.stl'))
        if len(fps) == 0:
            extract_zip(self.raw_paths[0], self.raw_dir, log=False)
            fps = glob(osp.join(self.raw_dir, '*.stl'))
        train_data_list, test_data_list = [], []
        train_id = 0
        val_id = 0
        for idx, fp in enumerate(tqdm(fps)):
            # The first 10 meshes of every block of 100 go to the test split.
            if (idx % 100) < 10:
                data_id = val_id
                val_id = val_id + 1
            else:
                data_id = train_id
                train_id = train_id + 1
            data = read_mesh(fp, data_id)
            # data = read_mesh(fp)
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            if (idx % 100) < 10:
                test_data_list.append(data)
            else:
                train_data_list.append(data)

        torch.save(self.collate(train_data_list), self.processed_paths[0])
        torch.save(self.collate(test_data_list), self.processed_paths[1])

--------------------------------------------------------------------------------
/datasets/dfaust.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
from glob import glob

import torch
from torch_geometric.data import InMemoryDataset, extract_zip, Data
from torch_geometric.utils import to_undirected

import numpy as np

class DFaust(InMemoryDataset):

    def __init__(self,
                 root='data/dfaust/',
                 train=True,
                 transform=None,
                 pre_transform=None):
        if not osp.exists(osp.join(root, 'processed')):
            os.makedirs(osp.join(root, 'processed'))

        super().__init__(root, transform, pre_transform)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return 'dfaust.zip'

    @property
    def processed_file_names(self):
        return ['training.pt', 'test.pt']

    def download(self):
        raise RuntimeError(
            'Dataset not found. Please '
            'move {} to {}'.format(self.raw_file_names, self.raw_dir))

    def process(self):

        def convert_to_data(face, points, data_id, pose=None):
            # Use the `face` argument rather than the enclosing `faces` variable.
            face = torch.from_numpy(face).T.type(torch.long)
            x = torch.tensor(points, dtype=torch.float)
            # Each triangle (v0, v1, v2) contributes the edges v0-v1, v1-v2 and v0-v2.
            edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
            edge_index = to_undirected(edge_index)
            if pose is not None:
                return Data(x=x, edge_index=edge_index, data_id=data_id, pose=pose)
            return Data(x=x, edge_index=edge_index, data_id=data_id)

        if not osp.exists(osp.join(self.raw_dir, 'train.npy')):
            extract_zip(self.raw_paths[0], self.raw_dir, log=False)

        train_points = np.load(osp.join(self.raw_dir, 'train.npy')).astype(
            np.float32).reshape((-1, 6890, 3))

        test_points = np.load(osp.join(self.raw_dir, 'test.npy')).astype(
            np.float32).reshape((-1, 6890, 3))

        #eval_points = np.load(osp.join(self.raw_dir, 'eval.npy')).astype(
        #    np.float32).reshape((-1, 6890, 3))

        faces = np.load(osp.join(self.raw_dir, 'faces.npy')).astype(
            np.int32).reshape((-1, 3))

        print('Processing...')

        train_data_list = []
        for i in range(train_points.shape[0]):
            data_id = i
            train = train_points[i]
            if i % 100 == 0:
                print('processing training data %d/%d' % (i, train_points.shape[0]))

            data = convert_to_data(faces, train, data_id)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            train_data_list.append(data)
        torch.save(self.collate(train_data_list), self.processed_paths[0])


        """
        eval_data_list = []
        for data_id, eval in enumerate(tqdm(eval_points)):
            data = convert_to_data(faces, eval, data_id)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            eval_data_list.append(data)
        torch.save(self.collate(eval_data_list), self.processed_paths[1])
        """

        test_data_list = []
        data_id = 0
        for i in range(test_points.shape[0]):
            print('processing testing data %d/%d' % (i, test_points.shape[0]))
            test = test_points[i]

            data = convert_to_data(faces, test, data_id)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            test_data_list.append(data)
            data_id += 1
        torch.save(self.collate(test_data_list), self.processed_paths[1])


if __name__ == '__main__':
    dfaust = DFaust()


--------------------------------------------------------------------------------
/datasets/meshdata.py:
--------------------------------------------------------------------------------
import openmesh as om
from datasets import SMAL, DFaust, Bone
from sklearn.decomposition import PCA
import numpy as np
import sys
sys.path.append("./")
import torch
import os

class MeshData(object):
    def __init__(self,
                 root,
                 template_fp,
                 dataset='DFaust',
                 transform=None,
                 pre_transform=None,
                 pca_n_comp=8,
                 vert_pca=False,
                 heat_kernel=False):
        self.root = root
        self.template_fp = template_fp
        self.dataset = dataset
        self.transform = transform
        self.pre_transform = pre_transform
        self.train_dataset = None
        self.test_dataset = None
        self.template_points = None
        self.template_face = None
        self.mean = None
        self.std = None
        self.num_nodes = None
        self.use_vert_pca = vert_pca
        self.use_heat_kernel = heat_kernel
        self.pca = PCA(n_components=pca_n_comp)

        self.load()

    def load(self):
        if self.dataset == 'SMAL':
            self.train_dataset = SMAL(self.root,
                                      train=True,
                                      transform=self.transform,
                                      pre_transform=self.pre_transform)
            self.test_dataset = SMAL(self.root,
                                     train=False,
                                     transform=self.transform,
                                     pre_transform=self.pre_transform)

        elif self.dataset == 'DFaust':
            self.train_dataset = DFaust(self.root,
                                        train=True,
                                        transform=self.transform,
                                        pre_transform=self.pre_transform)
            self.test_dataset = DFaust(self.root,
                                       train=False,
                                       transform=self.transform,
                                       pre_transform=self.pre_transform)
        elif self.dataset == 'Bone':
            self.train_dataset = Bone(self.root,
                                      train=True,
                                      transform=self.transform,
                                      pre_transform=self.pre_transform)
            self.test_dataset = Bone(self.root,
                                     train=False,
                                     transform=self.transform,
                                     pre_transform=self.pre_transform)

        tmp_mesh = om.read_trimesh(self.template_fp)
        self.template_points = tmp_mesh.points()
        self.template_face = tmp_mesh.face_vertex_indices()
        self.num_nodes = self.train_dataset[0].num_nodes

        self.num_train_graph = len(self.train_dataset)
        self.num_test_graph = len(self.test_dataset)

        # Per-vertex mean and standard deviation over the training shapes.
        self.mean = self.train_dataset.data.x.view(self.num_train_graph, -1,
                                                   3).mean(dim=0)
        self.std = self.train_dataset.data.x.view(self.num_train_graph, -1,
                                                  3).std(dim=0)
        if self.dataset == 'SMAL':
            self.std = torch.ones(self.std.size())*0.2
        self.normalize()

    def normalize(self):

        vertices_train = self.train_dataset.data.x.view(self.num_train_graph, -1, 3).numpy()
        vertices_test = self.test_dataset.data.x.view(self.num_test_graph, -1, 3).numpy()

        print('Normalizing...')
        self.train_dataset.data.x = (
            (self.train_dataset.data.x.view(self.num_train_graph, -1, 3) -
             self.mean) / self.std).view(-1, 3)
        self.test_dataset.data.x = (
            (self.test_dataset.data.x.view(self.num_test_graph, -1, 3) -
             self.mean) / self.std).view(-1, 3)

        if self.use_vert_pca:
            print("Computing vertex PCA...")
            self.pca.fit(np.reshape(vertices_train, (self.num_train_graph, -1)))
            pca_axes = self.pca.components_
            train_pca_sv = np.matmul(np.reshape(vertices_train, (self.num_train_graph, -1)), pca_axes.transpose())
            test_pca_sv = np.matmul(np.reshape(vertices_test, (self.num_test_graph, -1)), pca_axes.transpose())
            pca_sv_mean = np.mean(train_pca_sv, axis=0)
            pca_sv_std = np.std(train_pca_sv, axis=0)
            self.train_pca_sv = (train_pca_sv - pca_sv_mean)/pca_sv_std
            self.test_pca_sv = (test_pca_sv - pca_sv_mean)/pca_sv_std

        print('Done!')

    def save_mesh(self, fp, x):
        x = x * self.std + self.mean
        om.write_mesh(fp, om.TriMesh(x.numpy(), self.template_face))

--------------------------------------------------------------------------------
/datasets/smal.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
from glob import glob

import torch
from torch_geometric.data import InMemoryDataset, extract_zip
from utils.read import read_mesh

from tqdm import tqdm
import numpy as np

class SMAL(InMemoryDataset):

    def __init__(self,
                 root,
                 train=True,
                 transform=None,
                 pre_transform=None):
        if not osp.exists(osp.join(root, 'processed')):
            os.makedirs(osp.join(root, 'processed'))

        super().__init__(root, transform, pre_transform)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return 'smal.zip'

    @property
    def processed_file_names(self):
        return ['training.pt', 'test.pt']

    def download(self):
        raise RuntimeError(
            'Dataset not found. Please '
            'move {} to {}'.format(self.raw_file_names, self.raw_dir))

    def process(self):
        print('Processing...')
        fps = glob(osp.join(self.raw_dir, '*/*.ply'))
        if len(fps) == 0:
            extract_zip(self.raw_paths[0], self.raw_dir, log=False)
            fps = glob(osp.join(self.raw_dir, '*/*.ply'))
        train_data_list, test_data_list = [], []
        train_id = 0
        val_id = 0

        poses = np.load(osp.join(self.raw_dir, 'results_pose_800/poses.npy'))

        for i in range(poses.shape[0]):
            print('processing %d/%d' % (i, poses.shape[0]))

            fp = osp.join(self.raw_dir, 'results_pose_800/%d.ply' % (i+1))
            pose_id = int(fp.split('/')[-1].split('.')[0]) - 1

            # 300 train, 100 test
            if pose_id > 299 and pose_id < 400:
                data_id = val_id
                val_id = val_id + 1
            elif pose_id < 300:
                data_id = train_id
                train_id = train_id + 1
            else:
                # Poses beyond the first 400 are not used; skip them instead of
                # reading a mesh with a stale (or undefined) data_id.
                continue

            data = read_mesh(fp, data_id, pose=poses[pose_id], return_face=False)
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            if pose_id > 299 and pose_id < 400:
                test_data_list.append(data)

            elif pose_id < 300:
                train_data_list.append(data)

        torch.save(self.collate(train_data_list), self.processed_paths[0])
        torch.save(self.collate(test_data_list), self.processed_paths[1])


--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import pickle
import argparse
import os
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch_geometric.transforms as T
from psbody.mesh import Mesh

from models import AE, AE_single, Pool
from datasets import MeshData
from utils import utils, writer, train_eval, DataLoader, mesh_sampling


def str2bool(value):
    # argparse's type=bool treats every non-empty string (including 'False') as True,
    # so parse the common spellings explicitly.
    return str(value).lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser(description='mesh autoencoder')
parser.add_argument('--exp_name', type=str, default='model0')
# Must match the names checked below: DFaust, SMAL or Bone.
parser.add_argument('--dataset', type=str, default='DFaust', help='[DFaust, SMAL, Bone]')
parser.add_argument('--n_threads', type=int, default=4)
parser.add_argument('--device_idx', type=int, default=0)
parser.add_argument('--cpu', action='store_true', help='if True, use CPU only')
parser.add_argument('--mode', type=str, default='train', help='[train, test, interpolate, extraplate]')
parser.add_argument('--work_dir', type=str, default='./out')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--checkpoint', type=str, default=None)
parser.add_argument('--test_checkpoint', type=str, default=None)
parser.add_argument('--distributed', action='store_true')
parser.add_argument('--alsotest', action='store_true')

# network hyperparameters
parser.add_argument('--out_channels',
                    nargs='+',
                    #default=[32, 32, 32, 64],
                    default=[16, 16, 16, 32],
                    type=int)

parser.add_argument('--ds_factors',
                    nargs='+',
                    default=[4, 4, 4, 4],
                    type=int)

parser.add_argument('--latent_channels', type=int, default=8)
parser.add_argument('--in_channels', type=int, default=3)
parser.add_argument('--K', type=int, default=6)

# optimizer hyperparameters
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--lr', type=float, default=8e-3)
parser.add_argument('--lr_decay', type=float, default=0.99)
parser.add_argument('--decay_step', type=int, default=1)
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--test_lr', type=float, default=0.01)
parser.add_argument('--test_decay_step', type=int, default=1)

parser.add_argument('--arap_weight', type=float, default=0.05)
parser.add_argument('--use_arap_epoch', type=int, default=600, help='epoch at which we start using the ARAP loss')
parser.add_argument('--nz_max', type=int, default=60, help='randomly sample nz_max latent channels to compute the ARAP energy')

# training hyperparameters
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--epochs', type=int, default=1500)
parser.add_argument('--test_epochs', type=int, default=2000)
parser.add_argument('--continue_train', type=str2bool, default=False, help='If True, continue training from the last checkpoint')


# interpolate
parser.add_argument('--inter_num', type=int, default=10, help='number of intermediate shapes between two shapes in interpolation')
parser.add_argument('--extra_num', type=int, default=5, help='number of extrapolation perturbations per shape')
parser.add_argument('--extra_thres', type=float, default=0.2)


# others
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--use_vert_pca', type=str2bool, default=False, help='If True, use the vertex PCA as the latent vector initialization [DFAUST, Bone]')
parser.add_argument('--use_pose_init', type=str2bool, default=False, help='If True, use the provided pose vector as the latent vector initialization in training [SMAL]')

args = parser.parse_args()

args.data_fp = osp.join(args.data_dir)
args.out_dir = osp.join(args.work_dir, 'out', args.exp_name)  # save checkpoints and logs
args.results_dir = osp.join(args.work_dir, 'results', args.exp_name)  # save training and testing results
args.checkpoints_dir = osp.join(args.out_dir, 'checkpoints')
args.checkpoints_dir_test = osp.join(args.out_dir, 'test_checkpoints')
print(args)

utils.makedirs(args.out_dir)
utils.makedirs(args.checkpoints_dir)
utils.makedirs(args.checkpoints_dir_test)
utils.makedirs(args.results_dir)
results_dir_train = os.path.join(args.results_dir, "train")
results_dir_test = os.path.join(args.results_dir, "test")
utils.makedirs(results_dir_train)
utils.makedirs(results_dir_test)
writer = writer.Writer(args)

if args.cpu:
    device = torch.device('cpu')
else:
    device = torch.device('cuda', args.device_idx)


torch.set_num_threads(args.n_threads)

# deterministic
torch.manual_seed(args.seed)
cudnn.benchmark = False
cudnn.deterministic = True

# load dataset
if args.dataset == 'SMAL':
    template_fp = osp.join('template', 'smal_0.ply')
elif args.dataset == 'DFaust':
    template_fp = osp.join('template', 'smpl_male_template.ply')
elif args.dataset == 'Bone':
    template_fp = osp.join(args.data_dir, 'template.obj')
else:
    print('invalid dataset!')
    exit(-1)

meshdata = MeshData(args.data_fp,
                    template_fp,
                    dataset=args.dataset,
                    pca_n_comp=args.latent_channels,
                    vert_pca=args.use_vert_pca)

train_loader = DataLoader(meshdata.train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size, shuffle=False)

# generate/load transform matrices
transform_fp = osp.join(args.data_fp, 'transform.pkl')
if not osp.exists(transform_fp):
    print('Generating transform matrices...')
    mesh = Mesh(filename=template_fp)
    ds_factors = args.ds_factors
    if args.dataset == 'SMAL':
        ds_factors = [1, 1, 1, 1]
    _, A, D, U, F = mesh_sampling.generate_transform_matrices(mesh, ds_factors)
    tmp = {'face': F, 'adj': A, 'down_transform': D, 'up_transform': U}

    with open(transform_fp, 'wb') as fp:
        pickle.dump(tmp, fp)
    print('Done!')
    print('Transform matrices are saved in \'{}\''.format(transform_fp))
else:
    with open(transform_fp, 'rb') as f:
        tmp = pickle.load(f, encoding='latin1')

edge_index_list = [utils.to_edge_index(adj).to(device) for adj in tmp['adj']]
down_transform_list = [
    utils.to_sparse(down_transform).to(device)
    for down_transform in tmp['down_transform']
]
up_transform_list = [
    utils.to_sparse(up_transform).to(device)
    for up_transform in tmp['up_transform']
]


if args.distributed:
    model = AE(args.in_channels,
               args.out_channels,
               args.latent_channels,
               edge_index_list,
               down_transform_list,
               up_transform_list,
               K=args.K)
    #from mmcv.parallel import MMDistributedDataParallel
    #from mmcv.runner import get_dist_info, init_dist
    #init_dist('pytorch')

    #model = MMDistributedDataParallel(
    #    model.cuda(),
    #    device_ids=[torch.cuda.current_device()],
    #    broadcast_buffers=False,
    #    find_unused_parameters=False
    #)
    model = torch.nn.DataParallel(model)
    model = model.to(device)
else:
    model = AE_single(args.in_channels,
                      args.out_channels,
                      args.latent_channels,
                      edge_index_list,
                      down_transform_list,
                      up_transform_list,
                      K=args.K)
    model = model.to(device)
#print(model)

rand_std = 1.0
if args.dataset == 'SMAL' and args.use_pose_init:
    rand_std = 0.2
test_num_scenes = len(meshdata.test_dataset)
test_lat_vecs = torch.nn.Embedding(test_num_scenes, args.latent_channels)
torch.nn.init.normal_(test_lat_vecs.weight.data, 0.0, rand_std)
test_lat_vecs = test_lat_vecs.to(device)

if args.use_vert_pca:
    pca_init = torch.from_numpy(meshdata.train_pca_sv)
    lat_vecs = torch.nn.Embedding.from_pretrained(pca_init, freeze=False)
    print(meshdata.train_pca_sv.mean(), np.std(meshdata.train_pca_sv))

    test_pca_init = torch.from_numpy(meshdata.test_pca_sv)
    test_lat_vecs = torch.nn.Embedding.from_pretrained(test_pca_init, freeze=False)
    test_lat_vecs = test_lat_vecs.to(device)
elif args.use_pose_init:
    pose_init = torch.from_numpy(np.array(meshdata.train_dataset.data.pose, np.float32))
    pose_init = pose_init.reshape(meshdata.num_train_graph, -1)
    lat_vecs = torch.nn.Embedding.from_pretrained(pose_init, freeze=False)
else:
    train_num_scenes = len(meshdata.train_dataset)
    lat_vecs = torch.nn.Embedding(train_num_scenes, args.latent_channels)
    torch.nn.init.normal_(lat_vecs.weight.data, 0.0, rand_std)

lat_vecs = lat_vecs.to(device)
if args.continue_train:
    start_epoch = writer.load_checkpoint(model, lat_vecs, None,
                                         None, checkpoint=args.checkpoint)

if args.dataset == 'SMAL':
    train_vec_lr = 8e-3
else:
    train_vec_lr = args.lr
optimizer_all = torch.optim.Adam(
    [
        {
            "params": model.parameters(),
            "lr": args.lr,
            "weight_decay": args.weight_decay
        },
        {
            "params": lat_vecs.parameters(),
            "lr": train_vec_lr,
            "weight_decay": args.weight_decay
        },
    ]
)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer_all,
                                            args.decay_step,
                                            gamma=args.lr_decay)
if args.mode == 'train':
    if args.alsotest:
        optimizer_test = torch.optim.Adam(
            test_lat_vecs.parameters(), lr=args.test_lr,
            weight_decay=args.weight_decay)
        scheduler_test = torch.optim.lr_scheduler.StepLR(
            optimizer_test, args.test_decay_step, gamma=args.lr_decay)
    else:
        optimizer_test = None
        scheduler_test = None

    train_eval.run(model,
                   train_loader, lat_vecs, optimizer_all, scheduler,
                   test_loader, test_lat_vecs, optimizer_test, scheduler_test,
                   args.epochs, writer, device, results_dir_train,
                   meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face,
                   arap_weight=args.arap_weight, use_arap_epoch=args.use_arap_epoch,
                   nz_max=args.nz_max,
                   continue_train=args.continue_train,
                   checkpoint=args.checkpoint, test_checkpoint=args.test_checkpoint, dataset=args.dataset)

elif args.mode == 'test':
    # Auto-decoder testing: only the test latent vectors are optimized.
    optimizer_test = torch.optim.Adam(
        test_lat_vecs.parameters(), lr=args.test_lr,
        weight_decay=args.weight_decay)
    scheduler_test = torch.optim.lr_scheduler.StepLR(
        optimizer_test, args.test_decay_step, gamma=args.lr_decay)

    train_eval.test_reconstruct(model, test_loader, test_lat_vecs,
                                args.test_epochs, optimizer_test, scheduler_test, writer,
                                device, results_dir_test, meshdata.mean.numpy(),
                                meshdata.std.numpy(), meshdata.template_face,
                                checkpoint=args.checkpoint, test_checkpoint=args.test_checkpoint, dataset=args.dataset)

elif args.mode == 'interpolate':
    train_eval.global_interpolate(model, lat_vecs, optimizer_all, scheduler,
                                  writer, device, args.results_dir, meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face, args.inter_num)

elif args.mode == 'extraplate':
    optimizer_test = torch.optim.Adam(test_lat_vecs.parameters(),
                                      lr=args.test_lr,
                                      weight_decay=args.weight_decay)
    scheduler_test = torch.optim.lr_scheduler.StepLR(optimizer_test,
                                                     args.test_decay_step,
                                                     gamma=args.lr_decay)
    train_eval.extrapolation(model, test_lat_vecs, optimizer_test, scheduler_test,
                             writer, device, args.results_dir, meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face, args.extra_num, args.extra_thres)


--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .networks import AE, AE_single, Pool
from .arap import ARAP
__all__ = ['Pool', 'AE', 'AE_single', 'ARAP']

--------------------------------------------------------------------------------
/models/arap.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.utils import degree, get_laplacian
import torch_sparse as ts
import numpy as np
import sys

def get_laplacian_kron3x3(edge_index, edge_weights, N):
    edge_index, edge_weight = get_laplacian(edge_index, edge_weights, num_nodes=N)
    edge_weight *= 2
    e0, e1 = edge_index
    i0 = [e0*3, e0*3+1, e0*3+2]
    i1 = [e1*3, e1*3+1, e1*3+2]
    vals = [edge_weight, edge_weight, edge_weight]
    i0 = torch.cat(i0, 0)
    i1 = torch.cat(i1, 0)
    vals = torch.cat(vals, 0)
    indices, vals = ts.coalesce([i0, i1], vals, N*3, N*3)
    return indices, vals

class ARAP(torch.nn.Module):
    def __init__(self, template_face, num_points):
        super(ARAP, self).__init__()
        N = num_points
        self.template_face = template_face
        adj = np.zeros((num_points, num_points))
        adj[template_face[:, 0], template_face[:, 1]] = 1
        adj[template_face[:, 1], template_face[:, 2]] = 1
        adj[template_face[:, 0], template_face[:, 2]] = 1
        adj = adj + adj.T
        edge_index = torch.as_tensor(np.stack(np.where(adj > 0), 0),
                                     dtype=torch.long)
        e0, e1 = edge_index
        deg = degree(e0, N)
        edge_weight = torch.ones_like(e0)

        L_indices, L_vals = get_laplacian_kron3x3(edge_index, edge_weight, N)
        self.register_buffer('L_indices', L_indices)
        self.register_buffer('L_vals', L_vals)
        self.register_buffer('edge_weight', edge_weight)
        self.register_buffer('edge_index', edge_index)

    def forward(self, x, J, k=0):
        """
        x: [B, N, 3] point locations.
        J: [B, N*3, D] Jacobian of generator.
46 | J_eigvals: [B, D] 47 | """ 48 | num_batches, N = x.shape[:2] 49 | e0, e1 = self.edge_index 50 | edge_vecs = x[:, e0, :] - x[:, e1, :] 51 | trace_ = [] 52 | 53 | for i in range(num_batches): 54 | LJ = ts.spmm(self.L_indices, self.L_vals, N*3, N*3, J[i]) 55 | JTLJ = J[i].T.matmul(LJ) 56 | 57 | B0, B1, B_vals = [], [], [] 58 | B0.append(e0*3 ); B1.append(e1*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 59 | B0.append(e0*3 ); B1.append(e1*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 60 | B0.append(e0*3+1); B1.append(e1*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 61 | B0.append(e0*3+1); B1.append(e1*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 62 | B0.append(e0*3+2); B1.append(e1*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 63 | B0.append(e0*3+2); B1.append(e1*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 64 | 65 | B0.append(e0*3 ); B1.append(e0*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight) 66 | B0.append(e0*3 ); B1.append(e0*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight) 67 | B0.append(e0*3+1); B1.append(e0*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight) 68 | B0.append(e0*3+1); B1.append(e0*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight) 69 | B0.append(e0*3+2); B1.append(e0*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight) 70 | B0.append(e0*3+2); B1.append(e0*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight) 71 | B0 = torch.cat(B0, 0) 72 | B1 = torch.cat(B1, 0) 73 | B_vals = torch.cat(B_vals, 0) 74 | B_indices, B_vals = ts.coalesce([B0, B1], B_vals, N*3, N*3) 75 | BT_indices, BT_vals = ts.transpose(B_indices, B_vals, N*3, N*3) 76 | 77 | C0, C1, C_vals = [], [], [] 78 | edge_vecs_sq = (edge_vecs[i] * edge_vecs[i]).sum(-1) 79 | evi = edge_vecs[i] 80 | for di in range(3): 81 | for dj in range(3): 82 | C0.append(e0*3+di); C1.append(e0*3+dj); C_vals.append(-evi[:, di]*evi[:, dj]*self.edge_weight) 83 | C0.append(e0*3+di); 
C1.append(e0*3+di); C_vals.append(edge_vecs_sq*self.edge_weight) 84 | C0 = torch.cat(C0, 0) 85 | C1 = torch.cat(C1, 0) 86 | C_vals = torch.cat(C_vals, 0) 87 | C_indices, C_vals = ts.coalesce([C0, C1], C_vals, N*3, N*3) 88 | C_vals = C_vals.view(N, 3, 3).inverse().reshape(-1) 89 | BTJ = ts.spmm(BT_indices, BT_vals, N*3, N*3, J[i]) 90 | CBTJ = ts.spmm(C_indices, C_vals, N*3, N*3, BTJ) 91 | JTBCBTJ = BTJ.T.mm(CBTJ) 92 | 93 | e = torch.linalg.eigvalsh(JTLJ-JTBCBTJ).clip(0) 94 | 95 | e = e ** 0.5 96 | 97 | trace = e.sum() 98 | 99 | trace_.append(trace) 100 | 101 | trace_ = torch.stack(trace_, ) 102 | return trace_.mean() 103 | -------------------------------------------------------------------------------- /models/conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .message_passing import MessagePassing 2 | from .cheb_conv import ChebConv 3 | 4 | 5 | __all__ = [ 6 | 'MessagePassing', 7 | 'ChebConv', 8 | ] 9 | -------------------------------------------------------------------------------- /models/conv/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/conv/__pycache__/cheb_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/cheb_conv.cpython-36.pyc -------------------------------------------------------------------------------- /models/conv/__pycache__/message_passing.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/message_passing.cpython-36.pyc -------------------------------------------------------------------------------- /models/conv/cheb_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.utils import remove_self_loops, add_self_loops 4 | from torch_geometric.utils import get_laplacian 5 | from .message_passing import MessagePassing 6 | 7 | from ..inits import glorot, zeros 8 | 9 | 10 | class ChebConv(MessagePassing): 11 | """ 12 | Args: 13 | in_channels (int): Size of each input sample. 14 | out_channels (int): Size of each output sample. 15 | K (int): Chebyshev filter size, *i.e.* number of hops :math:`K`. 16 | normalization (str, optional): The normalization scheme for the graph 17 | Laplacian (default: :obj:`"sym"`): 18 | 19 | 1. :obj:`None`: No normalization 20 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` 21 | 22 | 2. :obj:`"sym"`: Symmetric normalization 23 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} 24 | \mathbf{D}^{-1/2}` 25 | 26 | 3. :obj:`"rw"`: Random-walk normalization 27 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` 28 | 29 | You need to pass :obj:`lambda_max` to the :meth:`forward` method of 30 | this operator in case the normalization is non-symmetric. 31 | :obj:`\lambda_max` should be a :class:`torch.Tensor` of size 32 | :obj:`[num_graphs]` in a mini-batch scenario and a scalar when 33 | operating on single graphs. 34 | You can pre-compute :obj:`lambda_max` via the 35 | :class:`torch_geometric.transforms.LaplacianLambdaMax` transform. 
36 | cached (bool, optional): If set to :obj:`True`, the layer will cache 37 | the computation of the scaled and normalized Laplacian 38 | :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}` on first execution, 39 | and will use the cached version for further executions. 40 | This parameter should only be set to :obj:`True` in 41 | fixed graph scenarios. (default: :obj:`True`) 42 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 43 | an additive bias. (default: :obj:`True`) 44 | **kwargs (optional): Additional arguments of 45 | :class:`torch_geometric.nn.conv.MessagePassing`. 46 | """ 47 | 48 | def __init__(self, 49 | in_channels, 50 | out_channels, 51 | K, 52 | normalization='sym', 53 | cached=True, 54 | bias=True, 55 | **kwargs): 56 | super(ChebConv, self).__init__(aggr='add', **kwargs) 57 | 58 | assert K > 0 59 | assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' 60 | 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.normalization = normalization 64 | self.cached = cached 65 | self.weight = Parameter(torch.Tensor(K, in_channels, out_channels)) 66 | 67 | if bias: 68 | self.bias = Parameter(torch.Tensor(out_channels)) 69 | else: 70 | self.register_parameter('bias', None) 71 | 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | glorot(self.weight) 76 | zeros(self.bias) 77 | self.cached_result = None 78 | self.cached_num_edges = None 79 | 80 | @staticmethod 81 | def norm(edge_index, 82 | num_nodes, 83 | edge_weight, 84 | normalization, 85 | lambda_max, 86 | dtype=None, 87 | batch=None): 88 | 89 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 90 | 91 | edge_index, edge_weight = get_laplacian(edge_index, edge_weight, 92 | normalization, dtype, 93 | num_nodes) 94 | 95 | if batch is not None and torch.is_tensor(lambda_max): 96 | lambda_max = lambda_max[batch[edge_index[0]]] 97 | 98 | edge_weight = (2.0 * edge_weight) / lambda_max 99 | 
edge_weight[edge_weight == float('inf')] = 0 100 | 101 | edge_index, edge_weight = add_self_loops(edge_index, 102 | edge_weight, 103 | fill_value=-1, 104 | num_nodes=num_nodes) 105 | 106 | return edge_index, edge_weight 107 | 108 | def forward(self, 109 | x, 110 | edge_index, 111 | edge_weight=None, 112 | batch=None, 113 | lambda_max=None): 114 | """""" 115 | if self.normalization != 'sym' and lambda_max is None: 116 | raise ValueError('You need to pass `lambda_max` to `forward()` in ' 117 | 'case the normalization is non-symmetric.') 118 | lambda_max = 2.0 if lambda_max is None else lambda_max 119 | 120 | if not self.cached or self.cached_result is None: 121 | edge_index, norm = self.norm(edge_index, 122 | x.size(1), 123 | edge_weight, 124 | self.normalization, 125 | lambda_max, 126 | dtype=x.dtype, 127 | batch=batch) 128 | self.cached_result = edge_index, norm 129 | 130 | edge_index, norm = self.cached_result 131 | Tx_0 = x 132 | out = torch.matmul(Tx_0, self.weight[0]) 133 | 134 | if self.weight.size(0) > 1: 135 | Tx_1 = self.propagate(edge_index, x=x, norm=norm) 136 | out = out + torch.matmul(Tx_1, self.weight[1]) 137 | 138 | for k in range(2, self.weight.size(0)): 139 | Tx_2 = 2 * self.propagate(edge_index, x=Tx_1, norm=norm) - Tx_0 140 | out = out + torch.matmul(Tx_2, self.weight[k]) 141 | Tx_0, Tx_1 = Tx_1, Tx_2 142 | 143 | if self.bias is not None: 144 | out = out + self.bias 145 | 146 | return out 147 | 148 | def message(self, x_j, norm): 149 | return norm.view(-1, 1) * x_j 150 | 151 | def __repr__(self): 152 | return '{}({}, {}, K={}, normalization={})'.format( 153 | self.__class__.__name__, self.in_channels, self.out_channels, 154 | self.weight.size(0), self.normalization) 155 | -------------------------------------------------------------------------------- /models/conv/message_passing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | 4 | import torch 5 | from torch_scatter import
scatter 6 | 7 | special_args = [ 8 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j' 9 | ] 10 | __size_error_msg__ = ('All tensors which should get mapped to the same source ' 11 | 'or target nodes must be of same size in dimension 0.') 12 | 13 | is_python2 = sys.version_info[0] < 3 14 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec 15 | 16 | 17 | class MessagePassing(torch.nn.Module): 18 | def __init__(self, aggr='add', flow='source_to_target'): 19 | super(MessagePassing, self).__init__() 20 | 21 | self.aggr = aggr 22 | assert self.aggr in ['add', 'mean', 'max'] 23 | 24 | self.flow = flow 25 | assert self.flow in ['source_to_target', 'target_to_source'] 26 | 27 | self.__message_args__ = getargspec(self.message)[0][1:] 28 | self.__special_args__ = [(i, arg) 29 | for i, arg in enumerate(self.__message_args__) 30 | if arg in special_args] 31 | self.__message_args__ = [ 32 | arg for arg in self.__message_args__ if arg not in special_args 33 | ] 34 | self.__update_args__ = getargspec(self.update)[0][2:] 35 | 36 | def propagate(self, edge_index, size=None, dim=0, **kwargs): 37 | dim = 1 # aggregate messages wrt nodes for batched_data: [batch_size, nodes, features] 38 | size = [None, None] if size is None else list(size) 39 | assert len(size) == 2 40 | 41 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0) 42 | ij = {"_i": i, "_j": j} 43 | 44 | message_args = [] 45 | for arg in self.__message_args__: 46 | if arg[-2:] in ij.keys(): 47 | tmp = kwargs.get(arg[:-2], None) 48 | if tmp is None: # pragma: no cover 49 | message_args.append(tmp) 50 | else: 51 | idx = ij[arg[-2:]] 52 | if isinstance(tmp, tuple) or isinstance(tmp, list): 53 | assert len(tmp) == 2 54 | if tmp[1 - idx] is not None: 55 | if size[1 - idx] is None: 56 | size[1 - idx] = tmp[1 - idx].size(dim) 57 | if size[1 - idx] != tmp[1 - idx].size(dim): 58 | raise ValueError(__size_error_msg__) 59 | tmp = tmp[idx] 60 | 61 | if tmp is None: 62 | 
message_args.append(tmp) 63 | else: 64 | if size[idx] is None: 65 | size[idx] = tmp.size(dim) 66 | if size[idx] != tmp.size(dim): 67 | raise ValueError(__size_error_msg__) 68 | 69 | tmp = torch.index_select(tmp, dim, edge_index[idx]) 70 | message_args.append(tmp) 71 | else: 72 | message_args.append(kwargs.get(arg, None)) 73 | 74 | size[0] = size[1] if size[0] is None else size[0] 75 | size[1] = size[0] if size[1] is None else size[1] 76 | 77 | kwargs['edge_index'] = edge_index 78 | kwargs['size'] = size 79 | 80 | for (idx, arg) in self.__special_args__: 81 | if arg[-2:] in ij.keys(): 82 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]]) 83 | else: 84 | message_args.insert(idx, kwargs[arg]) 85 | 86 | update_args = [kwargs[arg] for arg in self.__update_args__] 87 | 88 | out = self.message(*message_args) 89 | out = scatter(out, edge_index[i], dim=dim, dim_size=size[i], reduce=self.aggr) 90 | out = self.update(out, *update_args) 91 | 92 | return out 93 | 94 | def message(self, x_j): # pragma: no cover 95 | return x_j 96 | 97 | def update(self, aggr_out): # pragma: no cover 98 | return aggr_out 99 | -------------------------------------------------------------------------------- /models/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def uniform(size, tensor): 5 | bound = 1.0 / math.sqrt(size) 6 | if tensor is not None: 7 | tensor.data.uniform_(-bound, bound) 8 | 9 | 10 | def kaiming_uniform(tensor, fan, a): 11 | if tensor is not None: 12 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 13 | tensor.data.uniform_(-bound, bound) 14 | 15 | 16 | def glorot(tensor): 17 | if tensor is not None: 18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 19 | tensor.data.uniform_(-stdv, stdv) 20 | 21 | 22 | def zeros(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0) 25 | 26 | 27 | def ones(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(1) 30 | 31 | 32 | def reset(nn): 33 | 
def _reset(item): 34 | if hasattr(item, 'reset_parameters'): 35 | item.reset_parameters() 36 | 37 | if nn is not None: 38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 39 | for item in nn.children(): 40 | _reset(item) 41 | else: 42 | _reset(nn) 43 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.conv import ChebConv 5 | 6 | from .inits import reset 7 | from torch_scatter import scatter_add 8 | 9 | 10 | def Pool(x, trans, dim=1): 11 | row, col = trans._indices() 12 | value = trans._values().unsqueeze(-1) 13 | out = torch.index_select(x, dim, col) * value 14 | out = scatter_add(out, row, dim, dim_size=trans.size(0)) 15 | return out 16 | 17 | 18 | class Enblock(nn.Module): 19 | def __init__(self, in_channels, out_channels, K, **kwargs): 20 | super(Enblock, self).__init__() 21 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs) 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | for name, param in self.conv.named_parameters(): 26 | if 'bias' in name: 27 | nn.init.constant_(param, 0) 28 | else: 29 | nn.init.xavier_uniform_(param) 30 | 31 | def forward(self, x, edge_index, down_transform): 32 | out = F.elu(self.conv(x, edge_index)) 33 | out = Pool(out, down_transform) 34 | return out 35 | 36 | 37 | class Deblock(nn.Module): 38 | def __init__(self, in_channels, out_channels, K, **kwargs): 39 | super(Deblock, self).__init__() 40 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | for name, param in self.conv.named_parameters(): 45 | if 'bias' in name: 46 | nn.init.constant_(param, 0) 47 | else: 48 | nn.init.xavier_uniform_(param) 49 | 50 | def forward(self, x, edge_index, up_transform): 51 | out = Pool(x, 
up_transform) 52 | out = F.elu(self.conv(out, edge_index)) 53 | return out 54 | 55 | 56 | class AE(nn.Module): 57 | def __init__(self, in_channels, out_channels, latent_channels, 58 | edge_index, down_transform, up_transform, K, **kwargs): 59 | super(AE, self).__init__() 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | #self.edge_index = edge_index 63 | self.num_edge_index = len(edge_index) 64 | for i in range(self.num_edge_index): 65 | self.register_buffer(f'edge_index_{i}', edge_index[i]) 66 | setattr(self, f'edge_index_{i}', edge_index[i]) 67 | 68 | #self.down_transform = down_transform 69 | self.num_down_transform = len(down_transform) 70 | for i in range(self.num_down_transform): 71 | self.register_buffer(f'down_transform_{i}', down_transform[i]) 72 | setattr(self, f'down_transform_{i}', down_transform[i]) 73 | 74 | #self.up_transform = up_transform 75 | self.num_up_transform = len(up_transform) 76 | for i in range(self.num_up_transform): 77 | self.register_buffer(f'up_transform_{i}', up_transform[i]) 78 | setattr(self, f'up_transform_{i}', up_transform[i]) 79 | # self.num_vert used in the last and the first layer of encoder and decoder 80 | self.num_vert = down_transform[-1].size(0) 81 | 82 | # encoder 83 | #self.en_layers = nn.ModuleList() 84 | #for idx in range(len(out_channels)): 85 | # if idx == 0: 86 | # self.en_layers.append( 87 | # Enblock(in_channels, out_channels[idx], K, **kwargs)) 88 | # else: 89 | # self.en_layers.append( 90 | # Enblock(out_channels[idx - 1], out_channels[idx], K, 91 | # **kwargs)) 92 | #self.en_layers.append( 93 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels)) 94 | 95 | # decoder 96 | self.de_layers = nn.ModuleList() 97 | self.de_layers.append( 98 | nn.Linear(latent_channels, self.num_vert * out_channels[-1])) 99 | for idx in range(len(out_channels)): 100 | if idx == 0: 101 | self.de_layers.append( 102 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K, 103 | **kwargs)) 104 
| else: 105 | self.de_layers.append( 106 | Deblock(out_channels[-idx], out_channels[-idx - 1], K, 107 | **kwargs)) 108 | # reconstruction 109 | self.de_layers.append( 110 | ChebConv(out_channels[0], in_channels, K, **kwargs)) 111 | 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | for name, param in self.named_parameters(): 116 | if 'bias' in name: 117 | nn.init.constant_(param, 0) 118 | else: 119 | nn.init.xavier_uniform_(param) 120 | 121 | def encoder(self, x): 122 | for i, layer in enumerate(self.en_layers): 123 | if i != len(self.en_layers) - 1: 124 | x = layer(x, getattr(self, f'edge_index_{i}'), 125 | getattr(self, f'down_transform_{i}')) 126 | else: 127 | x = x.view(-1, layer.weight.size(1)) 128 | x = layer(x) 129 | return x 130 | 131 | def decoder(self, x): 132 | num_layers = len(self.de_layers) 133 | num_deblocks = num_layers - 2 134 | for i, layer in enumerate(self.de_layers): 135 | if i == 0: 136 | x = layer(x) 137 | x = x.view(-1, self.num_vert, self.out_channels[-1]) 138 | elif i != num_layers - 1: 139 | x = layer(x, getattr(self, f'edge_index_{num_deblocks - i}'), 140 | getattr(self, f'up_transform_{num_deblocks - i}')) 141 | else: 142 | # last layer 143 | x = layer(x, getattr(self, 'edge_index_0')) 144 | return x 145 | 146 | def forward(self, x): 147 | # x - batched feature matrix 148 | #z = self.encoder(x) 149 | out = self.decoder(x) 150 | return out 151 | 152 | class AE_single(nn.Module): 153 | def __init__(self, in_channels, out_channels, latent_channels, 154 | edge_index, down_transform, up_transform, K, **kwargs): 155 | super(AE_single, self).__init__() 156 | self.in_channels = in_channels 157 | self.out_channels = out_channels 158 | self.edge_index = edge_index 159 | self.down_transform = down_transform 160 | self.up_transform = up_transform 161 | # self.num_vert used in the last and the first layer of encoder and decoder 162 | self.num_vert = self.down_transform[-1].size(0) 163 | 164 | # encoder 165 | #self.en_layers 
= nn.ModuleList() 166 | #for idx in range(len(out_channels)): 167 | # if idx == 0: 168 | # self.en_layers.append( 169 | # Enblock(in_channels, out_channels[idx], K, **kwargs)) 170 | # else: 171 | # self.en_layers.append( 172 | # Enblock(out_channels[idx - 1], out_channels[idx], K, 173 | # **kwargs)) 174 | #self.en_layers.append( 175 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels)) 176 | 177 | # decoder 178 | self.de_layers = nn.ModuleList() 179 | self.de_layers.append( 180 | nn.Linear(latent_channels, self.num_vert * out_channels[-1])) 181 | for idx in range(len(out_channels)): 182 | if idx == 0: 183 | self.de_layers.append( 184 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K, 185 | **kwargs)) 186 | else: 187 | self.de_layers.append( 188 | Deblock(out_channels[-idx], out_channels[-idx - 1], K, 189 | **kwargs)) 190 | # reconstruction 191 | self.de_layers.append( 192 | ChebConv(out_channels[0], in_channels, K, **kwargs)) 193 | 194 | self.reset_parameters() 195 | 196 | def reset_parameters(self): 197 | for name, param in self.named_parameters(): 198 | if 'bias' in name: 199 | nn.init.constant_(param, 0) 200 | else: 201 | nn.init.xavier_uniform_(param) 202 | 203 | def encoder(self, x): 204 | for i, layer in enumerate(self.en_layers): 205 | if i != len(self.en_layers) - 1: 206 | x = layer(x, self.edge_index[i], self.down_transform[i]) 207 | else: 208 | x = x.view(-1, layer.weight.size(1)) 209 | x = layer(x) 210 | return x 211 | 212 | def decoder(self, x): 213 | num_layers = len(self.de_layers) 214 | num_deblocks = num_layers - 2 215 | for i, layer in enumerate(self.de_layers): 216 | if i == 0: 217 | x = layer(x) 218 | x = x.view(-1, self.num_vert, self.out_channels[-1]) 219 | elif i != num_layers - 1: 220 | x = layer(x, self.edge_index[num_deblocks - i], 221 | self.up_transform[num_deblocks - i]) 222 | else: 223 | # last layer 224 | x = layer(x, self.edge_index[0]) 225 | return x 226 | 227 | def forward(self, x): 228 | # x - batched feature 
matrix 229 | #z = self.encoder(x) 230 | out = self.decoder(x) 231 | return out 232 | 233 | -------------------------------------------------------------------------------- /template/dfaust.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/dfaust.ply -------------------------------------------------------------------------------- /template/smal_0.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/smal_0.ply -------------------------------------------------------------------------------- /template/smpl_male_template.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/smpl_male_template.ply -------------------------------------------------------------------------------- /test_bone.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --exp_name 'model0' \ 4 | --device_idx 0 \ 5 | --epochs 4500 \ 6 | --lr 8e-3 \ 7 | --arap_weight 0.05 \ 8 | --use_arap_epoch 1500 \ 9 | --decay_step 10 \ 10 | --latent_channels 8 \ 11 | --use_vert_pca True \ 12 | --work_dir ./work_dir/bone/tibia \ 13 | --dataset Bone \ 14 | --data_dir ./data/bone/tibia \ 15 | --continue_train True \ 16 | --test_lr 8e-2 \ 17 | --test_epochs 2500 \ 18 | --test_decay_step 30 \ 19 | --mode test \ 20 | -------------------------------------------------------------------------------- /test_dfaust.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore main.py \ 3 | --out_channels 32 32 32 64 \ 4 | 
--ds_factors 2 2 2 2 \ 5 | --exp_name 'arap' \ 6 | --device_idx 0 \ 7 | --batch_size 64 \ 8 | --epochs 2000 \ 9 | --n_threads 4 \ 10 | --lr 1e-4 \ 11 | --arap_weight 0.0 \ 12 | --use_arap_epoch 800 \ 13 | --decay_step 3 \ 14 | --latent_channels 72 \ 15 | --use_vert_pca True \ 16 | --work_dir ./work_dir/DFaust \ 17 | --dataset DFaust \ 18 | --data_dir ./data/DFaust \ 19 | --continue_train True \ 20 | --test_lr 1e-3 \ 21 | --test_epochs 2500 \ 22 | --test_decay_step 5 \ 23 | --mode test \ 24 | --distributed \ 25 | --checkpoint work_dir/DFaust/out/arap/checkpoints/checkpoint_0410.pt \ 26 | --test_checkpoint work_dir/DFaust/out/arap/test_checkpoints/checkpoint_1230.pt \ 27 | -------------------------------------------------------------------------------- /test_smal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py \ 3 | --exp_name 'model0' \ 4 | --device_idx 0 \ 5 | --batch_size 8 \ 6 | --epochs 2000 \ 7 | --lr 0.01 \ 8 | --arap_weight 0.0 \ 9 | --use_arap_epoch 800 \ 10 | --nz_max 50 \ 11 | --decay_step 3 \ 12 | --latent_channels 96 \ 13 | --use_pose_init True \ 14 | --work_dir ./work_dir/SMAL \ 15 | --dataset SMAL \ 16 | --data_dir ./data/SMAL \ 17 | --continue_train True \ 18 | --test_lr 0.01 \ 19 | --test_epochs 2500 \ 20 | --test_decay_step 5 \ 21 | --mode test \ 22 | 23 | -------------------------------------------------------------------------------- /train_and_test_dfaust.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore main.py \ 3 | --out_channels 32 32 32 64 \ 4 | --ds_factors 2 2 2 2 \ 5 | --exp_name 'arap' \ 6 | --device_idx 0 \ 7 | --batch_size 64 \ 8 | --epochs 450 \ 9 | --lr 1e-4 \ 10 | --test_lr 1e-3 \ 11 | --test_decay_step 5 \ 12 | --arap_weight 5e-4 \ 13 | --use_arap_epoch 150 \ 14 | --decay_step 3 \ 15 | --latent_channels 72 
\ 16 | --use_vert_pca True \ 17 | --work_dir ./work_dir/DFaust \ 18 | --dataset DFaust \ 19 | --data_dir ./data/DFaust \ 20 | --distributed \ 21 | --alsotest \ 22 | #--continue_train True \ 23 | #--checkpoint work_dir/DFaust/out/arap/checkpoints/checkpoint_0410.pt \ 24 | #--test_checkpoint work_dir/DFaust/out/arap/test_checkpoints/checkpoint_01200.pt \ 25 | -------------------------------------------------------------------------------- /train_bone.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python main.py \ 3 | --exp_name 'model0' \ 4 | --device_idx 0 \ 5 | --epochs 4500 \ 6 | --lr 8e-3 \ 7 | --arap_weight 0.05 \ 8 | --use_arap_epoch 1000 \ 9 | --decay_step 10 \ 10 | --latent_channels 8 \ 11 | --use_vert_pca True \ 12 | --work_dir ./work_dir/bone/tibia \ 13 | --dataset Bone \ 14 | --data_dir ./data/bone/tibia \ 15 | -------------------------------------------------------------------------------- /train_smal.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py \ 3 | --exp_name 'model0' \ 4 | --device_idx 0 \ 5 | --batch_size 16 \ 6 | --epochs 2000 \ 7 | --lr 0.01 \ 8 | --arap_weight 0.05 \ 9 | --use_arap_epoch 800 \ 10 | --nz_max 96 \ 11 | --decay_step 3 \ 12 | --latent_channels 96 \ 13 | --use_pose_init True \ 14 | --work_dir ./work_dir/SMAL \ 15 | --dataset SMAL \ 16 | --data_dir ./data/SMAL \ 17 | --continue_train True \ 18 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import DataLoader 2 | 3 | __all__ = [ 4 | 'DataLoader', 5 | ] 6 | -------------------------------------------------------------------------------- /utils/dataloader.py:
-------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torch.utils.data.dataloader import default_collate 3 | 4 | from torch_geometric.data import Data, Batch 5 | 6 | 7 | class DataLoader(torch.utils.data.DataLoader): 8 | def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): 9 | def dense_collate(data_list): 10 | batch = Batch() 11 | batch.batch = [] 12 | for key in data_list[0].keys: 13 | batch[key] = default_collate([d[key] for d in data_list]) 14 | for i, data in enumerate(data_list): 15 | num_nodes = data.num_nodes 16 | if num_nodes is not None: 17 | item = torch.full((num_nodes, ), i, dtype=torch.long) 18 | batch.batch.append(item) 19 | batch.batch = torch.cat(batch.batch, dim=0) 20 | return batch 21 | 22 | super(DataLoader, self).__init__(dataset, 23 | batch_size, 24 | shuffle, 25 | collate_fn=dense_collate, 26 | **kwargs) 27 | -------------------------------------------------------------------------------- /utils/mesh_sampling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import heapq 3 | import numpy as np 4 | import os 5 | import scipy.sparse as sp 6 | from psbody.mesh import Mesh 7 | from scipy.spatial import KDTree 8 | 9 | def row(A): 10 | return A.reshape((1, -1)) 11 | 12 | def col(A): 13 | return A.reshape((-1, 1)) 14 | 15 | def get_vert_connectivity(mesh_v, mesh_f): 16 | """Returns a sparse matrix (of size #verts x #verts) where each nonzero 17 | element indicates a neighborhood relation. For example, if there is a 18 | nonzero element in position (15,12), that means vertex 15 is connected 19 | by an edge to vertex 12.""" 20 | 21 | vpv = sp.csc_matrix((len(mesh_v), len(mesh_v))) 22 | 23 | # for each column in the faces... 
24 | for i in range(3): 25 | IS = mesh_f[:, i] 26 | JS = mesh_f[:, (i + 1) % 3] 27 | data = np.ones(len(IS)) 28 | ij = np.vstack((row(IS.ravel()), row(JS.ravel()))) 29 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape) 30 | vpv = vpv + mtx + mtx.T 31 | 32 | return vpv 33 | 34 | 35 | def get_vertices_per_edge(mesh_v, mesh_f): 36 | """Returns an Ex2 array of adjacencies between vertices, where 37 | each element in the array is a vertex index. Each edge is included 38 | only once. If output of get_faces_per_edge is provided, this is used to 39 | avoid call to get_vert_connectivity()""" 40 | 41 | vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f)) 42 | result = np.hstack((col(vc.row), col(vc.col))) 43 | result = result[result[:, 0] < result[:, 1]] # for uniqueness 44 | 45 | return result 46 | 47 | 48 | def vertex_quadrics(mesh): 49 | """Computes a quadric for each vertex in the Mesh. 50 | 51 | Returns: 52 | v_quadrics: an (N x 4 x 4) array, where N is # vertices. 53 | """ 54 | 55 | # Allocate quadrics 56 | v_quadrics = np.zeros(( 57 | len(mesh.v), 58 | 4, 59 | 4, 60 | )) 61 | 62 | # For each face... 
63 | for f_idx in range(len(mesh.f)): 64 | 65 | # Compute normalized plane equation for that face 66 | vert_idxs = mesh.f[f_idx] 67 | verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 68 | 1]).reshape(-1, 1))) 69 | u, s, v = np.linalg.svd(verts) 70 | eq = v[-1, :].reshape(-1, 1) 71 | eq = eq / (np.linalg.norm(eq[0:3])) 72 | 73 | # Add the outer product of the plane equation to the 74 | # quadrics of the vertices for this face 75 | for k in range(3): 76 | v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq) 77 | 78 | return v_quadrics 79 | 80 | 81 | def setup_deformation_transfer(source, target, use_normals=False): 82 | rows = np.zeros(3 * target.v.shape[0]) 83 | cols = np.zeros(3 * target.v.shape[0]) 84 | coeffs_v = np.zeros(3 * target.v.shape[0]) 85 | coeffs_n = np.zeros(3 * target.v.shape[0]) 86 | 87 | nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree( 88 | ).nearest(target.v, True) 89 | nearest_faces = nearest_faces.ravel().astype(np.int64) 90 | nearest_parts = nearest_parts.ravel().astype(np.int64) 91 | nearest_vertices = nearest_vertices.ravel() 92 | 93 | for i in range(target.v.shape[0]): 94 | # Closest triangle index 95 | f_id = nearest_faces[i] 96 | # Closest triangle vertex ids 97 | nearest_f = source.f[f_id] 98 | 99 | # Closest surface point 100 | nearest_v = nearest_vertices[3 * i:3 * i + 3] 101 | # Distance vector to the closest surface point 102 | dist_vec = target.v[i] - nearest_v 103 | 104 | rows[3 * i:3 * i + 3] = i * np.ones(3) 105 | cols[3 * i:3 * i + 3] = nearest_f 106 | 107 | n_id = nearest_parts[i] 108 | if n_id == 0: 109 | # Closest surface point in triangle 110 | A = np.vstack((source.v[nearest_f])).T 111 | coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v, 112 | rcond=-1)[0] 113 | elif n_id > 0 and n_id <= 3: 114 | # Closest surface point on edge 115 | A = np.vstack((source.v[nearest_f[n_id - 1]], 116 | source.v[nearest_f[n_id % 3]])).T 117 | tmp_coeffs = np.linalg.lstsq(A, target.v[i], rcond=-1)[0] 118 | 
coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0] 119 | coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1] 120 | else: 121 | # Closest surface point is a vertex 122 | coeffs_v[3 * i + n_id - 4] = 1.0 123 | 124 | matrix = sp.csc_matrix((coeffs_v, (rows, cols)), 125 | shape=(target.v.shape[0], source.v.shape[0])) 126 | return matrix 127 | 128 | 129 | def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None): 130 | """Return a simplified version of this mesh. 131 | 132 | A Qslim-style approach is used here. 133 | 134 | :param factor: fraction of the original vertices to retain 135 | :param n_verts_desired: number of the original vertices to retain 136 | :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix 137 | """ 138 | 139 | if factor is None and n_verts_desired is None: 140 | raise Exception('Need either factor or n_verts_desired.') 141 | 142 | if n_verts_desired is None: 143 | n_verts_desired = math.ceil(len(mesh.v) * factor) 144 | 145 | Qv = vertex_quadrics(mesh) 146 | 147 | # fill out a sparse matrix indicating vertex-vertex adjacency 148 | # from psbody.mesh.topology.connectivity import get_vertices_per_edge 149 | vert_adj = get_vertices_per_edge(mesh.v, mesh.f) 150 | # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v))) 151 | # for f_idx in range(len(mesh.f)): 152 | # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1 153 | 154 | vert_adj = sp.csc_matrix( 155 | (vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), 156 | shape=(len(mesh.v), len(mesh.v))) 157 | vert_adj = vert_adj + vert_adj.T 158 | vert_adj = vert_adj.tocoo() 159 | 160 | def collapse_cost(Qv, r, c, v): 161 | Qsum = Qv[r, :, :] + Qv[c, :, :] 162 | p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 163 | p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1))) 164 | 165 | destroy_c_cost = p1.T.dot(Qsum).dot(p1) 166 | destroy_r_cost = p2.T.dot(Qsum).dot(p2) 167 | result = { 168 | 'destroy_c_cost': destroy_c_cost, 169 | 'destroy_r_cost': destroy_r_cost, 170
| 'collapse_cost': min([destroy_c_cost, destroy_r_cost]), 171 | 'Qsum': Qsum 172 | } 173 | return result 174 | 175 | # construct a queue of edges with costs 176 | queue = [] 177 | for k in range(vert_adj.nnz): 178 | r = vert_adj.row[k] 179 | c = vert_adj.col[k] 180 | 181 | if r > c: 182 | continue 183 | 184 | cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost'] 185 | heapq.heappush(queue, (cost, (r, c))) 186 | 187 | # decimate 188 | collapse_list = [] 189 | nverts_total = len(mesh.v) 190 | faces = mesh.f.copy() 191 | while nverts_total > n_verts_desired: 192 | e = heapq.heappop(queue) 193 | r = e[1][0] 194 | c = e[1][1] 195 | if r == c: 196 | continue 197 | 198 | cost = collapse_cost(Qv, r, c, mesh.v) 199 | if cost['collapse_cost'] > e[0]: 200 | heapq.heappush(queue, (cost['collapse_cost'], e[1])) 201 | # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost']) 202 | continue 203 | else: 204 | 205 | # update old vert idxs to new one, 206 | # in queue and in face list 207 | if cost['destroy_c_cost'] < cost['destroy_r_cost']: 208 | to_destroy = c 209 | to_keep = r 210 | else: 211 | to_destroy = r 212 | to_keep = c 213 | 214 | collapse_list.append([to_keep, to_destroy]) 215 | 216 | # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx 217 | np.place(faces, faces == to_destroy, to_keep) 218 | 219 | # same for queue 220 | which1 = [ 221 | idx for idx in range(len(queue)) 222 | if queue[idx][1][0] == to_destroy 223 | ] 224 | which2 = [ 225 | idx for idx in range(len(queue)) 226 | if queue[idx][1][1] == to_destroy 227 | ] 228 | for k in which1: 229 | queue[k] = (queue[k][0], (to_keep, queue[k][1][1])) 230 | for k in which2: 231 | queue[k] = (queue[k][0], (queue[k][1][0], to_keep)) 232 | 233 | Qv[r, :, :] = cost['Qsum'] 234 | Qv[c, :, :] = cost['Qsum'] 235 | 236 | a = faces[:, 0] == faces[:, 1] 237 | b = faces[:, 1] == faces[:, 2] 238 | c = faces[:, 2] == faces[:, 0] 239 | 240 | # remove degenerate faces 241 | def logical_or3(x, y, z): 
242 | return np.logical_or(x, np.logical_or(y, z)) 243 | 244 | faces_to_keep = np.logical_not(logical_or3(a, b, c)) 245 | faces = faces[faces_to_keep, :].copy() 246 | 247 | nverts_total = (len(np.unique(faces.flatten()))) 248 | 249 | new_faces, mtx = _get_sparse_transform(faces, len(mesh.v)) 250 | return new_faces, mtx 251 | 252 | 253 | def _get_sparse_transform(faces, num_original_verts): 254 | verts_left = np.unique(faces.flatten()) 255 | IS = np.arange(len(verts_left)) 256 | JS = verts_left 257 | data = np.ones(len(JS)) 258 | 259 | mp = np.arange(0, np.max(faces.flatten()) + 1) 260 | mp[JS] = IS 261 | new_faces = mp[faces.copy().flatten()].reshape((-1, 3)) 262 | 263 | ij = np.vstack((IS.flatten(), JS.flatten())) 264 | mtx = sp.csc_matrix((data, ij), 265 | shape=(len(verts_left), num_original_verts)) 266 | 267 | return (new_faces, mtx) 268 | 269 | 270 | def generate_transform_matrices(mesh, factors): 271 | """Generates len(factors) successively downsampled meshes, where mesh i+1 keeps about 272 | 1/factors[i] of the vertices of mesh i (rounded up), and computes the transformations between them. 273 | 274 | Returns: 275 | M: the input mesh followed by the len(factors) downsampled meshes.
276 | A: Adjacency matrix for each of the meshes 277 | D: csc_matrix Downsampling transforms between each of the meshes 278 | U: Upsampling transforms between each of the meshes 279 | F: a list of faces 280 | """ 281 | 282 | factors = map(lambda x: 1.0 / x, factors) 283 | M, A, D, U, F = [], [], [], [], [] 284 | F.append(mesh.f) # F[0] 285 | A.append(get_vert_connectivity(mesh.v, mesh.f).astype('float32')) # A[0] 286 | M.append(mesh) # M[0] 287 | 288 | for factor in factors: 289 | ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor) 290 | D.append(ds_D.astype('float32')) 291 | new_mesh_v = ds_D.dot(M[-1].v) 292 | new_mesh = Mesh(v=new_mesh_v, f=ds_f) 293 | F.append(new_mesh.f) 294 | M.append(new_mesh) 295 | A.append( 296 | get_vert_connectivity(new_mesh.v, new_mesh.f).astype('float32')) 297 | U.append(setup_deformation_transfer(M[-1], M[-2]).astype('float32')) 298 | 299 | return M, A, D, U, F 300 | -------------------------------------------------------------------------------- /utils/read.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | from torch_geometric.utils import to_undirected 4 | import openmesh as om 5 | 6 | def read_mesh(path, data_id, pose=None, return_face=False): 7 | mesh = om.read_trimesh(path) 8 | points = mesh.points() 9 | face = torch.from_numpy(mesh.face_vertex_indices()).T.type(torch.long) 10 | 11 | x = torch.tensor(points.astype('float32')) 12 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1) 13 | edge_index = to_undirected(edge_index) 14 | if return_face==True and pose is not None: 15 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id, pose=pose) 16 | if pose is not None: 17 | return Data(x=x, edge_index=edge_index, data_id=data_id, pose=pose) 18 | if return_face==True: 19 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id) 20 | return Data(x=x, edge_index=edge_index, data_id=data_id) 21 
| 22 | -------------------------------------------------------------------------------- /utils/train_eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os, sys 3 | import torch 4 | import torch.nn.functional as F 5 | from psbody.mesh import Mesh 6 | import numpy as np 7 | from models import ARAP 8 | import scipy.io as sio 9 | from tqdm import tqdm 10 | 11 | def run(model, 12 | train_loader, lat_vecs, optimizer, scheduler, 13 | test_loader, test_lat_vecs, optimizer_test, scheduler_test, 14 | epochs, writer, device, results_dir, data_mean, data_std, 15 | template_face, arap_weight=0.0, use_arap_epoch=800, nz_max=10, 16 | continue_train=False, checkpoint=None, test_checkpoint=None, dataset='DFaust'): 17 | 18 | start_epoch = 0 19 | if continue_train: 20 | start_epoch = writer.load_checkpoint( 21 | model, lat_vecs, optimizer, scheduler, checkpoint=checkpoint) 22 | 23 | if (optimizer_test is not None) and continue_train: 24 | test_epoch = writer.load_checkpoint( 25 | model, test_lat_vecs, optimizer_test, scheduler_test, 26 | checkpoint=test_checkpoint, test=True) 27 | 28 | for epoch in range(start_epoch+1, epochs): 29 | t = time.time() 30 | use_arap = epoch > use_arap_epoch 31 | 32 | train_loss, l1_loss, arap_loss, l2_error = \ 33 | train( 34 | model, epoch, optimizer, train_loader, lat_vecs, 35 | device, results_dir, data_mean, data_std, template_face, 36 | arap_weight=arap_weight, nz_max=nz_max, 37 | use_arap=use_arap, lr=scheduler.get_last_lr()[0], 38 | dataset=dataset, 39 | ) 40 | 41 | if optimizer_test is not None: 42 | for k in range(3): 43 | test_loss, test_l1_loss, test_arap_loss, test_l2_error = \ 44 | train( 45 | model, epoch*3+k, optimizer_test, test_loader, test_lat_vecs, 46 | device, results_dir, data_mean, data_std, template_face, 47 | arap_weight=0.0, nz_max=nz_max, use_arap=False, 48 | exp_name='reconstruct', lr=scheduler_test.get_last_lr()[0], 49 | dataset=dataset, 50 | ) 51 | test_info = { 52 |
'test_current_epoch': epoch*3+k, 53 | 'test_epochs': epochs*3, 54 | 'test_loss': test_loss, 55 | 'l1_loss': test_l1_loss, 56 | 'arap_loss': test_arap_loss, 57 | 'mse_error': test_l2_error, 58 | 't_duration': 0.0, 59 | 'lr': scheduler_test.get_last_lr()[0] 60 | } 61 | scheduler_test.step() 62 | writer.print_info_test(test_info) 63 | if epoch % 10 == 0: 64 | writer.save_checkpoint( 65 | model, test_lat_vecs, optimizer_test, scheduler_test, 66 | epoch*3, test=True 67 | ) 68 | 69 | t_duration = time.time() - t 70 | scheduler.step() 71 | info = { 72 | 'current_epoch': epoch, 73 | 'epochs': epochs, 74 | 'train_loss': train_loss, 75 | 'l1_loss': l1_loss, 76 | 'arap_loss': arap_loss, 77 | 'mse_error': l2_error, 78 | 't_duration': t_duration, 79 | 'lr': scheduler.get_last_lr()[0] 80 | } 81 | 82 | writer.print_info(info) 83 | if epoch % 10 == 0: 84 | writer.save_checkpoint(model, lat_vecs, optimizer, scheduler, epoch) 85 | 86 | 87 | def train(model, epoch, optimizer, loader, lat_vecs, device, 88 | results_dir,data_mean, data_std, template_face, 89 | arap_weight=0.0, nz_max=10, use_arap=False, dump=False, 90 | exp_name='train', lr=0.0, dataset='DFaust'): 91 | 92 | model.train() 93 | total_loss = 0 94 | total_l1_loss = 0 95 | total_arap_loss = 0 96 | l2_errors = [] 97 | 98 | data_mean_gpu = torch.as_tensor(data_mean, dtype=torch.float).to(device) 99 | data_std_gpu = torch.as_tensor(data_std, dtype=torch.float).to(device) 100 | 101 | model.train() 102 | pbar = tqdm(total=len(loader), 103 | desc=f'{exp_name} {epoch}') 104 | 105 | # ARAP 106 | arap = ARAP(template_face, template_face.max()+1).to(device) 107 | mse = [0, 0] 108 | 109 | for b, data in enumerate(loader): 110 | optimizer.zero_grad() 111 | x = data.x.to(device) 112 | #x[:] = loader.dataset[4044].x 113 | ids = data.data_id.to(device) 114 | #ids[:] = 4044 115 | batch_vecs = lat_vecs(ids.view(-1)) 116 | out = model(batch_vecs) 117 | 118 | pred_shape = out * data_std_gpu + data_mean_gpu 119 | gt_shape = x * data_std_gpu +
data_mean_gpu 120 | if dataset=='SMAL': 121 | l1_loss = F.l1_loss(x, out, reduction='mean') 122 | arap_ep = 1e-3 123 | else: 124 | l1_loss = F.l1_loss(pred_shape, gt_shape, reduction='mean') 125 | arap_ep = 1e-1 126 | tmp_error = torch.sqrt(torch.sum((pred_shape - gt_shape)**2,dim=2)).detach().cpu() 127 | mse[1] += tmp_error.view(-1).shape[0] 128 | mse[0] += tmp_error.sum().item() 129 | l2_errors.append(tmp_error) 130 | 131 | loss = torch.zeros(1).to(device) 132 | loss += l1_loss 133 | 134 | if use_arap and arap_weight>0: 135 | 136 | jacob = get_jacobian_rand( 137 | pred_shape, batch_vecs, data_mean_gpu, data_std_gpu, 138 | model, device, 139 | epsilon=arap_ep, 140 | nz_max=nz_max) 141 | arap_energy = arap(pred_shape, jacob,) / jacob.shape[-1] 142 | total_arap_loss += arap_energy.item() 143 | loss += arap_weight*arap_energy 144 | 145 | loss.backward() 146 | total_loss += loss.item() 147 | total_l1_loss += l1_loss.item() 148 | pbar.set_postfix({ 149 | 'loss': total_loss / (b+1.0), 150 | 'arap': total_arap_loss / (b+1.0), 151 | 'w_arap': arap_weight, 152 | 'MSE': mse[0] / (mse[1]+1e-6) * 1000, 153 | 'lr': lr, 154 | }) 155 | pbar.update(1) 156 | optimizer.step() 157 | new_errors = torch.cat(l2_errors, dim=0) 158 | mean_error = new_errors.view((-1, )).mean() 159 | 160 | # visualize 161 | #if epoch%20==0 and dump==False: 162 | # gt_meshes = x.detach().cpu().numpy() 163 | # pred_meshes = out.detach().cpu().numpy() 164 | # for b_i in range(min(2, gt_meshes.shape[0])): 165 | # pred_v = pred_meshes[b_i].reshape((-1, 3))*data_std + data_mean 166 | # gt_v = gt_meshes[b_i].reshape((-1, 3))*data_std + data_mean 167 | # mesh = Mesh(v=pred_v, f=template_face) 168 | # mesh.write_ply(os.path.join(results_dir, '%d_%d'%(epoch, b_i)+'_pred.ply')) 169 | # mesh = Mesh(v=gt_v, f=template_face) 170 | # mesh.write_ply(os.path.join(results_dir, '%d_%d'%(epoch, b_i)+'_gt.ply')) 171 | #if dump==True and epoch%200==1: 172 | # model.eval() 173 | # with torch.no_grad(): 174 | # for data in 
loader: 175 | # x = data.x.to(device) 176 | # ids = data.data_id.to(device) 177 | # batch_vecs = lat_vecs(ids.view(-1)) 178 | # out = model(batch_vecs) 179 | # pred_shape = out * data_std_gpu + data_mean_gpu 180 | # gt_shape = x * data_std_gpu + data_mean_gpu 181 | # gt_meshes = gt_shape.detach().cpu().numpy() 182 | # pred_meshes = pred_shape.detach().cpu().numpy() 183 | # ids_np = data.data_id.cpu().numpy() 184 | 185 | # for b_i in range(gt_meshes.shape[0]): 186 | # pred_v = pred_meshes[b_i].reshape((-1, 3)) 187 | # gt_v = gt_meshes[b_i].reshape((-1 , 3)) 188 | # mesh = Mesh(v=pred_v, f=template_face) 189 | # mesh.write_ply(os.path.join(results_dir, '%d'%(ids_np[b_i])+'_pred.ply')) 190 | # mesh = Mesh(v=gt_v, f=template_face) 191 | # mesh.write_ply(os.path.join(results_dir, '%d'%(ids_np[b_i])+'_gt.ply')) 192 | 193 | return total_loss / len(loader), total_l1_loss / len(loader), total_arap_loss / len(loader), mean_error 194 | 195 | def get_jacobian(out, z, data_mean_gpu, data_std_gpu, model, device, epsilon=1e-3): 196 | nb, nz = z.size() 197 | _, n_vert, nc = out.size() 198 | jacobian = torch.zeros((nb, n_vert*nc, nz)).to(device) 199 | 200 | for i in range(nz): 201 | dz = torch.zeros(z.size()).to(device) 202 | dz[:, i] = epsilon 203 | z_new = z + dz 204 | out_new = model(z_new) 205 | dout = (out_new - out).view(nb, -1) 206 | jacobian[:, :, i] = dout/epsilon 207 | 208 | data_std_gpu = data_std_gpu.reshape((1, n_vert*nc, 1)) 209 | jacobian = jacobian*data_std_gpu 210 | return jacobian 211 | 212 | def get_jacobian_rand(cur_shape, z, data_mean_gpu, data_std_gpu, model, device, epsilon=1e-3, nz_max=10): 213 | nb, nz = z.size() 214 | _, n_vert, nc = cur_shape.size() 215 | if nz >= nz_max: 216 | rand_idx = np.random.permutation(nz)[:nz_max] 217 | nz = nz_max 218 | else: 219 | rand_idx = np.arange(nz) 220 | 221 | jacobian = torch.zeros((nb, n_vert*nc, nz)).to(device) 222 | for i, idx in enumerate(rand_idx): 223 | dz = torch.zeros(z.size()).to(device) 224 | dz[:, idx] =
epsilon 225 | z_new = z + dz 226 | out_new = model(z_new) 227 | shape_new = out_new * data_std_gpu + data_mean_gpu 228 | dout = (shape_new - cur_shape).view(nb, -1) 229 | jacobian[:, :, i] = dout/epsilon 230 | return jacobian 231 | 232 | 233 | def test_reconstruct( 234 | model, test_loader, test_lat_vecs, epochs, test_optimizer, scheduler, 235 | writer, device, results_dir, data_mean, data_std, template_face, 236 | checkpoint=None, test_checkpoint=None, dataset='DFaust'): 237 | # load model 238 | start_epoch = writer.load_checkpoint(model, checkpoint=checkpoint) 239 | if test_checkpoint is not None: 240 | start_epoch = writer.load_checkpoint( 241 | model, test_lat_vecs, test_optimizer, scheduler, 242 | checkpoint=test_checkpoint, test=True) 243 | 244 | for epoch in range(1, epochs + 1): 245 | t = time.time() 246 | 247 | test_loss, l1_loss, arap_loss, l2_error = \ 248 | train(model, epoch, test_optimizer, test_loader, 249 | test_lat_vecs, device, results_dir, data_mean, 250 | data_std, template_face, arap_weight=0.0, use_arap=False, 251 | dump=True, exp_name='reconstruct', lr=scheduler.get_last_lr()[0], 252 | dataset=dataset, 253 | ) 254 | 255 | t_duration = time.time() - t 256 | scheduler.step() 257 | info = { 258 | 'test_current_epoch': epoch, 259 | 'test_epochs': epochs, 260 | 'test_loss': test_loss, 261 | 'l1_loss': l1_loss, 262 | 'arap_loss':arap_loss, 263 | 'mse_error':l2_error, 264 | 't_duration': t_duration, 265 | 'lr': scheduler.get_last_lr()[0] 266 | } 267 | if epoch % 200 == 1: 268 | writer.save_checkpoint(model, test_lat_vecs, test_optimizer, 269 | scheduler, epoch, test=True) 270 | writer.print_info_test(info) 271 | writer.save_checkpoint( 272 | model, test_lat_vecs, test_optimizer, 273 | scheduler, epoch, test=True) 274 | 275 | def global_interpolate( 276 | model, lat_vecs, optimizer, scheduler, writer, 277 | device, results_dir, data_mean, data_std, 278 | template_face, interpolate_num): 279 | from scipy.sparse import csr_matrix 280 | from scipy.sparse.csgraph
import minimum_spanning_tree 281 | import numpy as np 282 | 283 | load_epoch = writer.load_checkpoint(model, lat_vecs, optimizer, scheduler,) 284 | 285 | lat_vecs_np = lat_vecs.weight.data.detach().cpu().numpy() 286 | results_dir = os.path.join(results_dir, "interpolate/epoch%d"%(load_epoch)) 287 | if not os.path.exists(results_dir): 288 | os.makedirs(results_dir) 289 | 290 | min_path = np.random.randint(0, lat_vecs_np.shape[0], size=50) 291 | print('min_path', min_path) 292 | 293 | for p in range(len(min_path)-1): 294 | start_vec = lat_vecs.weight.data[min_path[p]] 295 | end_vec = lat_vecs.weight.data[min_path[p+1]] 296 | 297 | for i in range(interpolate_num+1): 298 | vec = start_vec + i*(end_vec - start_vec)/interpolate_num 299 | out = model(vec) 300 | out_numpy = out.detach().cpu().numpy() 301 | out_v = out_numpy.reshape((-1, 3))*data_std + data_mean 302 | 303 | mesh = Mesh(v=out_v, f=template_face) 304 | mesh.write_obj(os.path.join(results_dir, '%d_%d_%d-%d'%(p, i, min_path[p], min_path[p+1])+'.obj')) 305 | 306 | 307 | def extrapolation(model, lat_vecs, optimizer, scheduler, writer, device, results_dir, data_mean, 308 | data_std, template_face, extra_num=5, extra_thres=0.2): 309 | 310 | load_epoch = writer.load_checkpoint(model, lat_vecs, optimizer, scheduler,test=True) 311 | 312 | lat_vecs_np = lat_vecs.weight.data.detach().cpu().numpy() 313 | print('lat_vecs_np', lat_vecs_np) 314 | results_dir = os.path.join(results_dir, "extrapolation/epoch%d"%(load_epoch)) 315 | if not os.path.exists(results_dir): 316 | os.makedirs(results_dir) 317 | 318 | z_dict = {} 319 | num_train_scenes = lat_vecs_np.shape[0] 320 | 321 | z_path = np.random.randint(0, num_train_scenes, size=30) 322 | print(z_path) 323 | for p in range(len(z_path)-1): 324 | or_vec = lat_vecs.weight.data[z_path[p]] 325 | extra_z = torch.zeros((extra_num+1, lat_vecs_np.shape[1])).to(device) 326 | 327 | for i in range(extra_num): 328 | vec = or_vec + 
extra_thres*(torch.rand(lat_vecs_np.shape[1]).to(device)-0.5) 329 | extra_z[i+1] = vec 330 | extra_z[0] = or_vec 331 | out = model(extra_z) 332 | out_numpy = out.detach().cpu().numpy() 333 | for i in range(extra_num + 1): 334 | out_v = out_numpy[i].reshape((-1, 3))*data_std + data_mean 335 | mesh = Mesh(v=out_v, f=template_face) 336 | mesh.write_ply(os.path.join(results_dir, '%d_%d'%(z_path[p],i)+'.ply')) 337 | z_dict['%d_%d'%( z_path[p], i)] = extra_z[i].detach().cpu().numpy() 338 | np.save(os.path.join(results_dir, "name_latentz.npy"), z_dict) 339 | 340 | 341 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from glob import glob 5 | from scipy.spatial import cKDTree as KDTree 6 | from matplotlib import cm 7 | 8 | 9 | 10 | def makedirs(folder): 11 | if not os.path.exists(folder): 12 | os.makedirs(folder) 13 | 14 | 15 | def to_sparse(spmat): 16 | coo = spmat.tocoo() 17 | return torch.sparse_coo_tensor( 18 | torch.from_numpy(np.vstack((coo.row, coo.col))).long(), 19 | torch.from_numpy(coo.data).float(), torch.Size(coo.shape)) 20 | 21 | 22 | def to_edge_index(mat): 23 | return torch.LongTensor(np.vstack(mat.nonzero())) 24 | 25 | def get_colors_from_diff_pc(diff_pc, min_error, max_error): 26 | colors = np.zeros((diff_pc.shape[0],3)) 27 | mix = (diff_pc-min_error)/(max_error-min_error) 28 | mix = np.clip(mix, 0,1) #point_num 29 | cmap=cm.coolwarm 30 | colors = cmap(mix)[:,0:3] 31 | return colors 32 | 33 | 34 | def save_pc_with_color_into_ply(template_ply, pc, color, fn): 35 | plydata=template_ply 36 | #pc = pc.copy()*pc_std + pc_mean 37 | plydata['vertex']['x']=pc[:,0] 38 | plydata['vertex']['y']=pc[:,1] 39 | plydata['vertex']['z']=pc[:,2] 40 | 41 | plydata['vertex']['red']=color[:,0] 42 | plydata['vertex']['green']=color[:,1] 43 | plydata['vertex']['blue']=color[:,2] 44 | 45 |
plydata.write(fn) 46 | plydata['vertex']['red']=plydata['vertex']['red']*0+0.7*255 47 | plydata['vertex']['green']=plydata['vertex']['red']*0+0.7*255 48 | plydata['vertex']['blue']=plydata['vertex']['red']*0+0.7*255 49 | 50 | def compute_trimesh_chamfer(gt_points, gen_points): 51 | """ 52 | This function computes a symmetric chamfer distance, i.e. the sum of both directed chamfer terms. 53 | gt_points: (N, 3) torch.Tensor of ground-truth points. 54 | gen_points: (M, 3) torch.Tensor of generated points. 55 | Each directed term is the mean squared distance from a point in one set 56 | to its nearest neighbour in the other set. 57 | """ 58 | # only need numpy array of points 59 | # gt_points_np = gt_points.vertices 60 | gt_points_np = gt_points.detach().cpu().numpy() 61 | gen_points_sampled = gen_points.detach().cpu().numpy() 62 | 63 | # one direction 64 | gen_points_kd_tree = KDTree(gen_points_sampled) 65 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points_np) 66 | gt_to_gen_chamfer = np.mean(np.square(one_distances)) 67 | 68 | # other direction 69 | gt_points_kd_tree = KDTree(gt_points_np) 70 | two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points_sampled) 71 | gen_to_gt_chamfer = np.mean(np.square(two_distances)) 72 | 73 | return gt_to_gen_chamfer + gen_to_gt_chamfer 74 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import json 5 | from glob import glob 6 | 7 | 8 | class Writer: 9 | def __init__(self, args=None): 10 | self.args = args 11 | 12 | if self.args is not None: 13 | tmp_log_list = glob(os.path.join(args.out_dir, 'log*')) 14 | 15 | self.test_log_file = os.path.join( 16 | args.out_dir, 'test_log_{:s}.txt'.format( 17 | time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()))) 18 | if len(tmp_log_list) == 0: 19 | self.log_file =
os.path.join( 20 | args.out_dir, 'log_{:s}.txt'.format( 21 | time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()))) 22 | else: 23 | self.log_file = tmp_log_list[0] 24 | message = '{}'.format(self.args) 25 | if args.mode=='test': 26 | with open(self.test_log_file, 'a') as log_file: 27 | log_file.write('{:s}\n'.format(message)) 28 | else: 29 | with open(self.log_file, 'a') as log_file: 30 | log_file.write('{:s}\n'.format(message)) 31 | 32 | 33 | def print_info(self, info): 34 | message = 'Epoch: {}/{}, Time: {:.3f}s, Train Loss: {:.5f}, L1: {:.5f}, arap: {:.5f}, MSE: {:.5f}, lr: {:.6f}' \ 35 | .format(info['current_epoch'], info['epochs'], info['t_duration'], \ 36 | info['train_loss'], info['l1_loss'], info['arap_loss'], info['mse_error'], info['lr']) 37 | 38 | with open(self.log_file, 'a') as log_file: 39 | log_file.write('{:s}\n'.format(message)) 40 | #print(message) 41 | 42 | def print_info_test(self, info): 43 | message = '[test] Epoch: {}/{}, Time: {:.3f}s, Test Loss: {:.5f}, L1: {:.5f}, arap: {:.5f}, MSE: {:.6f}, lr: {:.6f}' \ 44 | .format(info['test_current_epoch'], info['test_epochs'], info['t_duration'], \ 45 | info['test_loss'], info['l1_loss'], info['arap_loss'],info['mse_error'],info['lr']) 46 | 47 | with open(self.test_log_file, 'a') as log_file: 48 | log_file.write('{:s}\n'.format(message)) 49 | #print(message) 50 | 51 | def save_checkpoint(self, model, latent_vecs, optimizer, scheduler, 52 | epoch, test=False): 53 | model_path = self.args.checkpoints_dir 54 | if test: 55 | model_path = self.args.checkpoints_dir_test 56 | print('save test checkpoint') 57 | 58 | torch.save( 59 | { 60 | 'epoch': epoch, 61 | 'model_state_dict': model.state_dict(), 62 | 'optimizer_state_dict': optimizer.state_dict(), 63 | 'scheduler_state_dict': scheduler.state_dict(), 64 | 'train_latent_vecs': latent_vecs.state_dict(), 65 | }, 66 | os.path.join( 67 | model_path, 'checkpoint_{:04d}.pt'.format(epoch)) 68 | ) 69 | 70 | def load_checkpoint(self, model, latent_vecs=None,
optimizer=None, 71 | scheduler=None, test=False, checkpoint=None): 72 | model_path = self.args.checkpoints_dir 73 | if test: 74 | model_path = self.args.checkpoints_dir_test 75 | 76 | if checkpoint is None: 77 | checkpoint_list = glob(os.path.join(model_path, 'checkpoint_*')) 78 | if len(checkpoint_list)==0: 79 | print('train from scratch') 80 | return 0 81 | latest_checkpoint = sorted(checkpoint_list)[-1] 82 | else: 83 | latest_checkpoint = checkpoint 84 | print("loading model from ", latest_checkpoint) 85 | data = torch.load(latest_checkpoint) 86 | model.load_state_dict(data["model_state_dict"]) 87 | if latent_vecs is not None: 88 | latent_vecs.load_state_dict(data["train_latent_vecs"]) 89 | if scheduler is not None: 90 | scheduler.load_state_dict(data["scheduler_state_dict"]) 91 | if optimizer is not None: 92 | optimizer.load_state_dict(data["optimizer_state_dict"]) 93 | print("loaded!") 94 | return data["epoch"] 95 | -------------------------------------------------------------------------------- /work_dir/SMAL/out/move_smal_ckpt.zip_here.txt: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /work_dir/move_bone_ckpt.zip_here.txt: -------------------------------------------------------------------------------- 1 | hh 2 | --------------------------------------------------------------------------------
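
For reference, the ARAP term in `train()` relies on `get_jacobian_rand` in `utils/train_eval.py`, which approximates the decoder Jacobian by forward differences: it perturbs at most `nz_max` randomly chosen latent coordinates by a scalar `epsilon` and divides the change in de-normalized shape by `epsilon`. The following is a minimal NumPy sketch of that idea, not the repository's implementation; it assumes a generic `generator` callable in place of the PyTorch model, and the name `fd_jacobian_subset` is hypothetical.

```python
import numpy as np


def fd_jacobian_subset(generator, z, epsilon=1e-3, nz_max=10, rng=None):
    """Forward-difference Jacobian of `generator` w.r.t. a random subset of latent dims.

    generator: hypothetical callable mapping a (nb, nz) array to a
        (nb, n_vert, 3) array of de-normalized vertex positions.
    z: (nb, nz) array of latent codes.
    Returns (jac, idx): jac has shape (nb, n_vert * 3, len(idx)) and column i
    approximates d generator(z) / d z[:, idx[i]].
    """
    rng = np.random.default_rng() if rng is None else rng
    nb, nz = z.shape
    base = generator(z).reshape(nb, -1)  # flattened shapes at the unperturbed codes
    if nz >= nz_max:
        idx = rng.permutation(nz)[:nz_max]  # random subset of latent dimensions
    else:
        idx = np.arange(nz)  # few enough dimensions: use all of them
    jac = np.zeros((nb, base.shape[1], len(idx)))
    for i, d in enumerate(idx):
        dz = np.zeros_like(z)
        dz[:, d] = epsilon  # perturb one latent coordinate for the whole batch
        jac[:, :, i] = (generator(z + dz).reshape(nb, -1) - base) / epsilon
    return jac, idx


if __name__ == "__main__":
    # For a linear generator the forward difference recovers W up to rounding.
    rng = np.random.default_rng(0)
    n_vert, nz, nb = 4, 6, 2
    W = rng.normal(size=(n_vert * 3, nz))

    def gen(z):
        return (z @ W.T).reshape(z.shape[0], n_vert, 3)

    z = rng.normal(size=(nb, nz))
    jac, idx = fd_jacobian_subset(gen, z, nz_max=3, rng=rng)
    print(jac.shape, np.allclose(jac[0], W[:, idx], atol=1e-6))
```

A linear generator makes a convenient sanity check because the finite difference is exact up to floating-point rounding; for the real decoder the estimate depends on `epsilon`, which `train()` sets to 1e-3 for SMAL and 1e-1 for the other datasets.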
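
`compute_trimesh_chamfer` in `utils/utils.py` returns the sum of two directed terms, each the mean squared distance from a point in one set to its nearest neighbour in the other, using `cKDTree` for the queries. For small point sets the same quantity can be cross-checked by brute force. A minimal NumPy sketch (the function name is hypothetical); given the same (N, 3) and (M, 3) points, it should agree with the KD-tree version up to floating-point rounding:

```python
import numpy as np


def symmetric_chamfer_bruteforce(gt_points, gen_points):
    """Symmetric chamfer distance between two point sets, by brute force.

    gt_points: (N, 3) array, gen_points: (M, 3) array.
    Returns mean squared nearest-neighbour distance gt -> gen plus gen -> gt.
    Builds the full (N, M) distance table, so it is only meant for small inputs.
    """
    diff = gt_points[:, None, :] - gen_points[None, :, :]  # (N, M, 3) pairwise offsets
    sq_dist = np.sum(diff ** 2, axis=-1)                    # (N, M) squared distances
    gt_to_gen = sq_dist.min(axis=1).mean()  # each gt point to its nearest generated point
    gen_to_gt = sq_dist.min(axis=0).mean()  # each generated point to its nearest gt point
    return gt_to_gen + gen_to_gt


if __name__ == "__main__":
    a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
    b = np.array([[0.0, 0.0, 0.0]])
    # (0 + 1) / 2 in one direction, 0 in the other
    print(symmetric_chamfer_bruteforce(a, b))  # 0.5
```

Because both directions are summed, the result is symmetric in its arguments, which is also true of `compute_trimesh_chamfer`.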