├── README.md
├── Teaser.png
├── data
│   ├── SMAL
│   │   └── raw
│   │       └── move_smal.zip_here.txt
│   └── move_bone.zip_here.txt
├── datasets
│   ├── __init__.py
│   ├── bone.py
│   ├── dfaust.py
│   ├── meshdata.py
│   └── smal.py
├── main.py
├── models
│   ├── __init__.py
│   ├── arap.py
│   ├── conv
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── cheb_conv.cpython-36.pyc
│   │   │   └── message_passing.cpython-36.pyc
│   │   ├── cheb_conv.py
│   │   └── message_passing.py
│   ├── inits.py
│   └── networks.py
├── template
│   ├── dfaust.ply
│   ├── smal_0.ply
│   └── smpl_male_template.ply
├── test_bone.sh
├── test_dfaust.sh
├── test_smal.sh
├── train_and_test_dfaust.sh
├── train_bone.sh
├── train_smal.sh
├── utils
│   ├── __init__.py
│   ├── dataloader.py
│   ├── mesh_sampling.py
│   ├── read.py
│   ├── train_eval.py
│   ├── utils.py
│   └── writer.py
└── work_dir
    ├── SMAL
    │   └── out
    │       └── move_smal_ckpt.zip_here.txt
    └── move_bone_ckpt.zip_here.txt
/README.md:
--------------------------------------------------------------------------------
1 | # ARAPReg
2 | Code for the ICCV 2021 paper [ARAPReg: An As-Rigid-As Possible Regularization Loss for Learning Deformable Shape Generators](https://arxiv.org/pdf/2108.09432.pdf).
3 |
4 |
5 |
6 |
7 | ## Installation
8 | The code was developed with `Python 3.6` and `CUDA 10.2` on Ubuntu 18.04, using the following dependencies:
9 | * [PyTorch](https://pytorch.org/) (1.9.0)
10 | * [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric)
11 | * [OpenMesh](https://github.com/nmaxwell/OpenMesh-Python) (1.1.3)
12 | * [MPI-IS Mesh](https://github.com/MPI-IS/mesh): we suggest installing this library from source.
13 | * [tqdm](https://github.com/tqdm/tqdm)
14 |
15 | Note that the appropriate PyTorch and PyTorch Geometric versions depend on your CUDA version. `datasets/meshdata.py` also imports `scikit-learn` (for PCA).
16 |
17 |
18 | ## Data Preparation
19 | We provide data for three datasets: [DFAUST](https://dfaust.is.tue.mpg.de/), [SMAL](https://smal.is.tue.mpg.de/), and a Bone dataset.
20 |
21 | ### DFAUST
22 | We use 4264 test shapes and 32933 training shapes from the DFAUST dataset.
23 | You can download the dataset [here](https://drive.google.com/file/d/1BaACAdJO0uoH5P084Gw11a_j3nKVSUjn/view?usp=sharing).
24 | Please place `dfaust.zip` in `data/DFaust/raw/`.
25 |
26 | ### SMAL
27 | We use 400 shapes from family 0 of the SMAL dataset. We generated the shapes with the SMAL demo, setting the mean and variance of the pose vectors to 0 and 0.2, respectively. We split them into 300 training and 100 testing samples.
28 |
29 | You can download the generated dataset [here](https://drive.google.com/file/d/1L3n6i097bgZtNYAmnGM9NwOWBNd4c1Fr/view?usp=sharing).
30 | After downloading, please move the downloaded `smal.zip` to `./data/SMAL/raw`.
31 |
32 | ### Bone
33 | We created a conventional bone dataset with 4 categories: tibia, pelvis, scapula, and femur. Each category has about 50 shapes, which we split into 40 training and 10 testing samples per category.
34 | You can download the dataset [here](https://drive.google.com/file/d/1Naq1F6V-Oxw4AQZJeaCKfRrOCQneF0gT/view?usp=sharing).
35 | After downloading, please move `bone.zip` to `./data`, then extract it.
36 |
37 |
38 | ## Testing
39 | ### Pretrained checkpoints
40 | You can find pre-trained models and training logs at the following links:
41 |
42 | **DFAUST**: [checkpoints.zip](https://drive.google.com/file/d/1mCiF-XkMWPNDih4mmxRn6aaPnTAHdpK0/view?usp=sharing). Uncompressing it under the repository root places two checkpoints in `DFaust/out/arap/checkpoints/` and `DFaust/out/arap/test_checkpoints/`.
43 |
44 | **SMAL**: [smal_ckpt.zip](https://drive.google.com/file/d/1IIAW5SmylMHsFpU-croeu-uNPdKP_fnL/view?usp=sharing). Move it to `./work_dir/SMAL/out`, then extract it.
45 |
46 | **Bone**: [bone_ckpt.zip](https://drive.google.com/file/d/15I-uABi6_-2qM3QG40Df9G9oNh-K55Nl/view?usp=sharing). Move it to `./work_dir`, then extract it. It contains checkpoints for 4 bone categories.
47 |
48 | ### Run testing
49 | After placing the pre-trained checkpoints in their corresponding paths, run the following scripts to optimize latent vectors for shape reconstruction. Because our model uses an auto-decoder architecture, test shapes still go through a latent-vector optimization stage.
50 |
51 | Note that both the SMAL and Bone checkpoints were trained on a single GPU, so keep `args.distributed` set to `False` (i.e., do not pass `--distributed` to `main.py`) when testing them. You can use multiple GPUs in your own training.
52 |
53 | **DFAUST**:
54 | ```
55 | bash test_dfaust.sh
56 | ```
57 | **SMAL**:
58 | ```
59 | bash test_smal.sh
60 | ```
61 | **Bone**:
62 | ```
63 | bash test_bone.sh
64 | ```
65 | Note that for the Bone dataset, we train and test the 4 categories separately. The training and testing scripts currently use `tibia`; replace it with `femur`, `pelvis`, or `scapula` to get results for the other 3 categories.
66 |
67 |
68 | ## Model training
69 | To retrain our model, run the following scripts after downloading and extracting datasets.
70 |
71 | **DFAUST**:
72 | On DFAUST, multiple GPUs are recommended for efficiency. The DFAUST script tracks the reconstruction error to avoid over-fitting.
73 | ```
74 | bash train_and_test_dfaust.sh
75 | ```
76 | **SMAL**:
77 | ```
78 | bash train_smal.sh
79 | ```
80 | Note: `batch_size` is set to 16 in the script. Training will run out of CUDA memory when the ARAPReg loss is added at epoch 800, so you will have to reduce the batch size to 8 after 800 epochs.
81 | **Bone**:
82 | ```
83 | bash train_bone.sh
84 | ```
85 |
86 |
87 | ## Train on a new dataset
88 | Data preprocessing and loading scripts are in `./datasets`.
89 | To train on a new dataset, write a data loading file similar to `./datasets/dfaust.py`, then add the dataset to `./datasets/meshdata.py` and `main.py`. Finally, write a training script similar to `train_and_test_dfaust.sh`.
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
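The README's testing section notes that, because the model is an auto-decoder, each test shape still needs its latent vector optimized against a frozen decoder. A minimal sketch of that idea follows, under simplifying assumptions: the decoder is a hypothetical linear map `W @ z + b` rather than the repository's mesh network, and plain gradient descent stands in for the Adam optimizer that `main.py` attaches to `test_lat_vecs`. `fit_latent` is an illustrative name, not a function in this repository.

```python
import numpy as np


def fit_latent(W, b, x, steps=2000, lr=0.1):
    """Fit a latent code z so that a frozen linear decoder W @ z + b matches x.

    W: (P, D) hypothetical decoder weights, b: (P,) bias, x: (P,) flattened target shape.
    Only z is updated; W and b stay fixed, mirroring test-time latent
    optimization in an auto-decoder.
    """
    z = np.zeros(W.shape[1])
    for _ in range(steps):
        residual = W @ z + b - x                 # per-coordinate reconstruction error
        grad = 2.0 * W.T @ residual / x.size     # gradient of the mean squared error w.r.t. z
        z -= lr * grad                           # gradient-descent step on the latent code only
    return z


# Usage: recover the code that generated a synthetic "shape".
rng = np.random.default_rng(0)
W, b = rng.normal(size=(30, 4)), rng.normal(size=30)
z_true = rng.normal(size=4)
print(fit_latent(W, b, W @ z_true + b), z_true)
```

Because only `z` changes, testing can reuse a trained checkpoint while still running an optimization loop for every test shape.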
/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/Teaser.png
--------------------------------------------------------------------------------
/data/SMAL/raw/move_smal.zip_here.txt:
--------------------------------------------------------------------------------
1 | hh
2 |
--------------------------------------------------------------------------------
/data/move_bone.zip_here.txt:
--------------------------------------------------------------------------------
1 | hh
2 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .smal import SMAL
2 | from .bone import Bone
3 | from .dfaust import DFaust
4 | from .meshdata import MeshData
5 |
6 | __all__ = ['MeshData','Bone', 'DFaust', 'SMAL',]
7 |
--------------------------------------------------------------------------------
/datasets/bone.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from glob import glob
4 |
5 | import torch
6 | from torch_geometric.data import InMemoryDataset, extract_zip
7 | from utils.read import read_mesh
8 |
9 | from tqdm import tqdm
10 |
11 |
12 | class Bone(InMemoryDataset):
13 | url = 'Not needed'
14 |
15 | categories = [
16 | 'femur',
17 | 'pelvis',
18 | 'scapula',
19 | 'tibia',
20 | ]
21 |
22 | def __init__(self,
23 | root,
24 | train=True,
25 | transform=None,
26 | pre_transform=None):
27 | if not osp.exists(osp.join(root, 'processed',)):
28 | os.makedirs(osp.join(root, 'processed',))
29 |
30 | super().__init__(root, transform, pre_transform)
31 | self.train_list = []
32 | self.test_list = []
33 | path = self.processed_paths[0] if train else self.processed_paths[1]
34 | self.data, self.slices = torch.load(path)
35 |
36 | @property
37 | def raw_file_names(self):
38 | return 'bone_data.zip'
39 |
40 | @property
41 | def processed_file_names(self):
42 | return ['training.pt', 'test.pt']
43 |
44 | def process(self):
45 | print('Processing...')
46 |
47 | fps = glob(osp.join(self.raw_dir, '*.stl'))
48 | if len(fps) == 0:
49 | extract_zip(self.raw_paths[0], self.raw_dir, log=False)
50 | fps = glob(osp.join(self.raw_dir, '*.stl'))
51 | train_data_list, test_data_list = [], []
52 | train_id = 0
53 | val_id = 0
54 | for idx, fp in enumerate(tqdm(fps)):
55 | if (idx % 100) < 10:
56 | data_id = val_id
57 | val_id = val_id + 1
58 | else:
59 | data_id = train_id
60 | train_id = train_id + 1
61 | data = read_mesh(fp, data_id)
62 | # data = read_mesh(fp)
63 | if self.pre_transform is not None:
64 | data = self.pre_transform(data)
65 |
66 | if (idx % 100) < 10:
67 | test_data_list.append(data)
68 | else:
69 | train_data_list.append(data)
70 |
71 | torch.save(self.collate(train_data_list), self.processed_paths[0])
72 | torch.save(self.collate(test_data_list), self.processed_paths[1])
73 |
--------------------------------------------------------------------------------
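`Bone.process` assigns a shape to the test split when its position in the file list satisfies `idx % 100 < 10`, and numbers shapes within each split with a separate counter. The pure-Python sketch below (hypothetical helper `split_bone_indices`) reproduces that rule so the split sizes can be checked; with 50 files it yields 10 test and 40 training shapes, consistent with the README. One caveat: `glob` returns paths in an order that depends on the file system, so which files land in the test split is only reproducible if the list is sorted (e.g. `sorted(glob(...))`), which the original code does not do.

```python
def split_bone_indices(num_files):
    """Mimic the train/test assignment in Bone.process (sketch).

    A file at position idx goes to the test split when idx % 100 < 10,
    otherwise to the training split. Each split keeps its own running data_id.
    Returns two lists of (file_index, data_id) pairs.
    """
    train, test = [], []
    for idx in range(num_files):
        if idx % 100 < 10:
            test.append((idx, len(test)))    # data_id = number of test shapes so far
        else:
            train.append((idx, len(train)))  # data_id = number of training shapes so far
    return train, test


train, test = split_bone_indices(50)
print(len(train), len(test))
```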
/datasets/dfaust.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from glob import glob
4 |
5 | import torch
6 | from torch_geometric.data import InMemoryDataset, extract_zip, Data
7 | from torch_geometric.utils import to_undirected
8 |
9 | import numpy as np
10 |
11 | class DFaust(InMemoryDataset):
12 |
13 | def __init__(self,
14 | root='data/dfaust/',
15 | train=True,
16 | transform=None,
17 | pre_transform=None):
18 | if not osp.exists(osp.join(root, 'processed')):
19 | os.makedirs(osp.join(root, 'processed'))
20 |
21 | super().__init__(root, transform, pre_transform)
22 | path = self.processed_paths[0] if train else self.processed_paths[1]
23 | self.data, self.slices = torch.load(path)
24 |
25 | @property
26 | def raw_file_names(self):
27 | return 'dfaust.zip'
28 |
29 | @property
30 | def processed_file_names(self):
31 | return [ 'training.pt', 'test.pt']
32 |
33 | def download(self):
34 | raise RuntimeError(
35 | 'Dataset not found. Please '
36 | 'move {} to {}'.format(self.raw_file_names, self.raw_dir))
37 |
38 | def process(self):
39 |
40 | def convert_to_data(face, points, data_id, pose=None):
41 | face = torch.from_numpy(face).T.type(torch.long)
42 | x = torch.tensor(points, dtype=torch.float)
43 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
44 | edge_index = to_undirected(edge_index)
45 | if pose is not None:
46 | return Data(x=x, edge_index=edge_index, data_id=data_id, pose=pose)
47 | return Data(x=x, edge_index=edge_index, data_id=data_id)
48 |
49 | if not osp.exists(osp.join(self.raw_dir, 'train.npy')):
50 | extract_zip(self.raw_paths[0], self.raw_dir, log=False)
51 |
52 | train_points = np.load(osp.join(self.raw_dir, 'train.npy')).astype(
53 | np.float32).reshape((-1, 6890, 3))
54 |
55 | test_points = np.load(osp.join(self.raw_dir, 'test.npy')).astype(
56 | np.float32).reshape((-1, 6890, 3))
57 |
58 | #eval_points = np.load(osp.join(self.raw_dir, 'eval.npy')).astype(
59 | # np.float32).reshape((-1, 6890, 3))
60 |
61 | faces = np.load(osp.join(self.raw_dir, 'faces.npy')).astype(
62 | np.int32).reshape((-1, 3))
63 |
64 | print('Processing...')
65 |
66 | train_data_list = []
67 | for i in range(train_points.shape[0]):
68 | data_id = i
69 | train = train_points[i]
70 | if i%100==0:
71 | print('processing training data %d/%d'%(i, train_points.shape[0]))
72 |
73 | data = convert_to_data(faces, train, data_id)
74 | if self.pre_transform is not None:
75 | data = self.pre_transform(data)
76 | train_data_list.append(data)
77 | torch.save(self.collate(train_data_list), self.processed_paths[0])
78 |
79 |
80 | """
81 | eval_data_list = []
82 | for data_id, eval in enumerate(tqdm(eval_points)):
83 | data = convert_to_data(faces, eval, data_id)
84 | if self.pre_transform is not None:
85 | data = self.pre_transform(data)
86 | eval_data_list.append(data)
87 | torch.save(self.collate(eval_data_list), self.processed_paths[1])
88 | """
89 |
90 | test_data_list = []
91 | data_id = 0
92 | for i in range(test_points.shape[0]):
93 | print('processing testing data %d/%d'%(i, test_points.shape[0]))
94 | test = test_points[i]
95 |
96 | data = convert_to_data(faces, test, data_id)
97 | if self.pre_transform is not None:
98 | data = self.pre_transform(data)
99 | test_data_list.append(data)
100 | data_id += 1
101 | torch.save(self.collate(test_data_list), self.processed_paths[1])
102 |
103 |
104 | if __name__ == '__main__':
105 | dfaust = DFaust()
106 |
107 |
--------------------------------------------------------------------------------
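`convert_to_data` in `datasets/dfaust.py` turns the triangle list into graph connectivity: for every face it takes the edges (v0, v1), (v1, v2), and (v0, v2) via `face[:2]`, `face[1:]`, and `face[::2]`, then calls `to_undirected` to add reverse edges and remove duplicates. The NumPy sketch below mimics that construction without PyTorch Geometric. `faces_to_edge_index` is a hypothetical name, and `np.unique` stands in for the coalescing done by `to_undirected`; it should produce the same set of directed edges.

```python
import numpy as np


def faces_to_edge_index(faces):
    """Build an undirected, duplicate-free edge list from triangle faces.

    faces: (F, 3) integer array of vertex indices.
    Returns a (2, E) array holding every mesh edge in both directions.
    """
    f = np.asarray(faces, dtype=np.int64).T                 # (3, F), like face.T in dfaust.py
    edges = np.concatenate([f[:2], f[1:], f[::2]], axis=1)  # (v0,v1), (v1,v2), (v0,v2) per face
    edges = np.concatenate([edges, edges[::-1]], axis=1)    # add the reverse direction
    return np.unique(edges, axis=1)                         # drop edges shared by adjacent faces


# Two triangles sharing the edge (0, 2): 5 undirected edges -> 10 directed ones.
print(faces_to_edge_index(np.array([[0, 1, 2], [0, 2, 3]])))
```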
/datasets/meshdata.py:
--------------------------------------------------------------------------------
1 | import openmesh as om
2 | from datasets import SMAL, DFaust, Bone
3 | from sklearn.decomposition import PCA
4 | import numpy as np
5 | import sys
6 | sys.path.append("./")
7 | import torch
8 | import os
9 |
10 | class MeshData(object):
11 | def __init__(self,
12 | root,
13 | template_fp,
14 | dataset='DFaust',
15 | transform=None,
16 | pre_transform=None,
17 | pca_n_comp=8,
18 | vert_pca=False,
19 | heat_kernel=False):
20 | self.root = root
21 | self.template_fp = template_fp
22 | self.dataset = dataset
23 | self.transform = transform
24 | self.pre_transform = pre_transform
25 | self.train_dataset = None
26 | self.test_dataset = None
27 | self.template_points = None
28 | self.template_face = None
29 | self.mean = None
30 | self.std = None
31 | self.num_nodes = None
32 | self.use_vert_pca = vert_pca
33 | self.use_heat_kernel = heat_kernel
34 | self.pca = PCA(n_components=pca_n_comp)
35 |
36 | self.load()
37 |
38 | def load(self):
39 | if self.dataset=='SMAL':
40 | self.train_dataset = SMAL(self.root,
41 | train=True,
42 | transform=self.transform,
43 | pre_transform=self.pre_transform)
44 | self.test_dataset = SMAL(self.root,
45 | train=False,
46 | transform=self.transform,
47 | pre_transform=self.pre_transform)
48 |
49 | elif self.dataset=='DFaust':
50 | self.train_dataset = DFaust(self.root,
51 | train=True,
52 | transform=self.transform,
53 | pre_transform=self.pre_transform)
54 | self.test_dataset = DFaust(self.root,
55 | train=False,
56 | transform=self.transform,
57 | pre_transform=self.pre_transform)
58 | elif self.dataset=='Bone':
59 | self.train_dataset = Bone(self.root,
60 | train=True,
61 | transform=self.transform,
62 | pre_transform=self.pre_transform)
63 | self.test_dataset = Bone(self.root,
64 | train=False,
65 | transform=self.transform,
66 | pre_transform=self.pre_transform)
67 |
68 | tmp_mesh = om.read_trimesh(self.template_fp)
69 | self.template_points = tmp_mesh.points()
70 | self.template_face = tmp_mesh.face_vertex_indices()
71 | self.num_nodes = self.train_dataset[0].num_nodes
72 |
73 | self.num_train_graph = len(self.train_dataset)
74 | self.num_test_graph = len(self.test_dataset)
75 |
76 | self.mean = self.train_dataset.data.x.view(self.num_train_graph, -1,
77 | 3).mean(dim=0)
78 | self.std = self.train_dataset.data.x.view(self.num_train_graph, -1,
79 | 3).std(dim=0)
80 | if self.dataset=='SMAL':
81 | self.std = torch.ones(self.std.size())*0.2
82 | self.normalize()
83 |
84 | def normalize(self):
85 |
86 | vertices_train = self.train_dataset.data.x.view(self.num_train_graph, -1, 3).numpy()
87 | vertices_test = self.test_dataset.data.x.view(self.num_test_graph, -1, 3).numpy()
88 |
89 | print('Normalizing...')
90 | self.train_dataset.data.x = (
91 | (self.train_dataset.data.x.view(self.num_train_graph, -1, 3) -
92 | self.mean) / self.std).view(-1, 3)
93 | self.test_dataset.data.x = (
94 | (self.test_dataset.data.x.view(self.num_test_graph, -1, 3) -
95 | self.mean) / self.std).view(-1, 3)
96 |
97 | if self.use_vert_pca:
98 | print("Computing vertex PCA...")
99 | self.pca.fit(np.reshape(vertices_train, (self.num_train_graph, -1)))
100 | pca_axes = self.pca.components_
101 | train_pca_sv= np.matmul(np.reshape(vertices_train, (self.num_train_graph, -1)), pca_axes.transpose())
102 | test_pca_sv = np.matmul(np.reshape(vertices_test, (self.num_test_graph, -1)), pca_axes.transpose())
103 | pca_sv_mean = np.mean(train_pca_sv, axis=0)
104 | pca_sv_std = np.std(train_pca_sv, axis=0)
105 | self.train_pca_sv = (train_pca_sv - pca_sv_mean)/pca_sv_std
106 | self.test_pca_sv = (test_pca_sv - pca_sv_mean)/pca_sv_std
107 |
108 | print('Done!')
109 |
110 | def save_mesh(self, fp, x):
111 | x = x * self.std + self.mean
112 | om.write_mesh(fp, om.TriMesh(x.numpy(), self.template_face))
113 |
--------------------------------------------------------------------------------
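`MeshData` standardizes every vertex coordinate with a per-vertex mean and standard deviation computed over the training shapes (for SMAL the standard deviation is replaced by a constant 0.2), and `save_mesh` undoes this before writing. When `vert_pca` is enabled, it also projects flattened shapes onto the top principal axes to initialize latent codes. The NumPy sketch below mirrors both steps under a few assumptions: it uses an SVD in place of scikit-learn's `PCA` (axes may differ from `pca.components_` in sign), it uses `ddof=1` to match the unbiased default of `torch.Tensor.std`, and it fits the PCA on un-normalized coordinates, because `vertices_train` in `normalize` appears to be captured before `data.x` is reassigned. The function names are illustrative.

```python
import numpy as np


def normalize_shapes(train, test):
    """Per-vertex standardization using training statistics only.

    train: (N_train, V, 3), test: (N_test, V, 3) vertex arrays.
    """
    mean = train.mean(axis=0)
    std = train.std(axis=0, ddof=1)  # unbiased, like torch.Tensor.std's default
    return (train - mean) / std, (test - mean) / std, mean, std


def denormalize(x, mean, std):
    """Inverse of normalize_shapes, as applied by MeshData.save_mesh."""
    return x * std + mean


def pca_latent_init(train, test, n_comp):
    """Project flattened shapes onto the top principal axes, then standardize."""
    x_tr = train.reshape(len(train), -1)
    x_te = test.reshape(len(test), -1)
    _, _, vt = np.linalg.svd(x_tr - x_tr.mean(axis=0), full_matrices=False)
    axes = vt[:n_comp]                      # (n_comp, V*3), analogous to pca.components_
    tr, te = x_tr @ axes.T, x_te @ axes.T   # uncentered projection, as in MeshData.normalize
    mu, sd = tr.mean(axis=0), tr.std(axis=0)
    return (tr - mu) / sd, (te - mu) / sd   # standardized with training statistics


rng = np.random.default_rng(1)
train, test = rng.normal(size=(20, 5, 3)), rng.normal(size=(6, 5, 3))
z_train, z_test = pca_latent_init(train, test, n_comp=4)
print(z_train.shape, z_test.shape)
```

Because the projection is followed by standardization with training statistics, skipping the centering step before the projection only shifts the scores by a constant that the subtraction of `mu` removes.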
/datasets/smal.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from glob import glob
4 |
5 | import torch
6 | from torch_geometric.data import InMemoryDataset, extract_zip
7 | from utils.read import read_mesh
8 |
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | class SMAL(InMemoryDataset):
13 |
14 | def __init__(self,
15 | root,
16 | train=True,
17 | transform=None,
18 | pre_transform=None):
19 | if not osp.exists(osp.join(root, 'processed')):
20 | os.makedirs(osp.join(root, 'processed'))
21 |
22 | super().__init__(root, transform, pre_transform)
23 | path = self.processed_paths[0] if train else self.processed_paths[1]
24 | self.data, self.slices = torch.load(path)
25 |
26 | @property
27 | def raw_file_names(self):
28 | return 'smal.zip'
29 |
30 | @property
31 | def processed_file_names(self):
32 | return [ 'training.pt', 'test.pt']
33 |
34 | def download(self):
35 | raise RuntimeError(
36 | 'Dataset not found. Please '
37 | 'move {} to {}'.format(self.raw_file_names, self.raw_dir))
38 |
39 | def process(self):
40 | print('Processing...')
41 | fps = glob(osp.join(self.raw_dir, '*/*.ply'))
42 | if len(fps) == 0:
43 | extract_zip(self.raw_paths[0], self.raw_dir, log=False)
44 | fps = glob(osp.join(self.raw_dir, '*/*.ply'))
45 | train_data_list, test_data_list = [], []
46 | train_id = 0
47 | val_id = 0
48 |
49 | poses = np.load(osp.join(self.raw_dir, 'results_pose_800/poses.npy'))
50 |
51 | for i in range(min(poses.shape[0], 400)):  # only poses 0-399 are used (300 train + 100 test)
52 | print('processing %d/%d'%(i, poses.shape[0]))
53 |
54 | fp = osp.join(self.raw_dir, 'results_pose_800/%d.ply'%(i+1))
55 | pose_id = int(fp.split('/')[-1].split('.')[0]) - 1
56 |
57 | # 300 train, 100 test
58 | if pose_id> 299 and pose_id<400:
59 | data_id = val_id
60 | val_id = val_id + 1
61 | elif pose_id < 300:
62 | data_id = train_id
63 | train_id = train_id + 1
64 |
65 | data = read_mesh(fp, data_id, pose=poses[pose_id], return_face=False)
66 | if self.pre_transform is not None:
67 | data = self.pre_transform(data)
68 |
69 | if pose_id > 299 and pose_id < 400:
70 | test_data_list.append(data)
71 |
72 | elif pose_id < 300:
73 | train_data_list.append(data)
74 |
75 | torch.save(self.collate(train_data_list), self.processed_paths[0])
76 | torch.save(self.collate(test_data_list), self.processed_paths[1])
77 |
78 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.backends.cudnn as cudnn
9 | import torch_geometric.transforms as T
10 | from psbody.mesh import Mesh
11 |
12 | from models import AE, AE_single, Pool
13 | from datasets import MeshData
14 | from utils import utils, writer, train_eval, DataLoader, mesh_sampling
15 |
16 | parser = argparse.ArgumentParser(description='mesh autoencoder')
17 | parser.add_argument('--exp_name', type=str, default='model0')
18 | parser.add_argument('--dataset', type=str, default='DFaust', help='[DFaust, SMAL, Bone]')
19 | parser.add_argument('--n_threads', type=int, default=4)
20 | parser.add_argument('--device_idx', type=int, default=0)
21 | parser.add_argument('--cpu', action='store_true', help='if True, use CPU only')
22 | parser.add_argument('--mode', type=str, default='train', help='[train, test, interpolate, extraplate]')
23 | parser.add_argument('--work_dir', type=str, default='./out')
24 | parser.add_argument('--data_dir', type=str, default='./data')
25 | parser.add_argument('--checkpoint', type=str, default=None)
26 | parser.add_argument('--test_checkpoint', type=str, default=None)
27 | parser.add_argument('--distributed', action='store_true')
28 | parser.add_argument('--alsotest', action='store_true')
29 |
30 | # network hyperparameters
31 | parser.add_argument('--out_channels',
32 | nargs='+',
33 | #default=[32, 32, 32, 64],
34 | default=[16, 16, 16, 32],
35 | type=int)
36 |
37 | parser.add_argument('--ds_factors',
38 | nargs='+',
39 | #default=[32, 32, 32, 64],
40 | default=[4, 4, 4, 4],
41 | type=int)
42 |
43 | parser.add_argument('--latent_channels', type=int, default=8)
44 | parser.add_argument('--in_channels', type=int, default=3)
45 | parser.add_argument('--K', type=int, default=6)
46 |
47 | # optimizer hyperparameters
48 | parser.add_argument('--optimizer', type=str, default='Adam')
49 | parser.add_argument('--lr', type=float, default=8e-3,)
50 | parser.add_argument('--lr_decay', type=float, default=0.99)
51 | parser.add_argument('--decay_step', type=int, default=1)
52 | parser.add_argument('--weight_decay', type=float, default=0)
53 | parser.add_argument('--test_lr', type=float, default=0.01)
54 | parser.add_argument('--test_decay_step', type=int, default=1)
55 |
56 | parser.add_argument('--arap_weight', type=float, default=0.05)
57 | parser.add_argument('--use_arap_epoch', type=int, default=600, help='epoch that we start to use arap loss')
58 | parser.add_argument('--nz_max', type=int, default=60, help='random sample nz_max latent channels to compute ARAP energy')
59 |
60 | # training hyperparameters
61 | parser.add_argument('--batch_size', type=int, default=16)
62 | parser.add_argument('--epochs', type=int, default=1500)
63 | parser.add_argument('--test_epochs', type=int, default=2000)
64 | parser.add_argument('--continue_train', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, help='If True, continue training from last checkpoint')  # type=bool would treat the string 'False' as True
65 |
66 |
67 | # interpolate
68 | parser.add_argument('--inter_num', type=int, default=10, help='number of intermediate shapes between two shapes in interpolation')
69 | parser.add_argument('--extra_num', type=int, default=5, help='number of extrapolation perturbations per shape')
70 | parser.add_argument('--extra_thres', type=float, default=0.2)
71 |
72 |
73 | # others
74 | parser.add_argument('--seed', type=int, default=1)
75 | parser.add_argument('--use_vert_pca', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, help='If True, use the vertex PCA as the latent vector initialization [DFAUST, Bone]')  # type=bool would treat the string 'False' as True
76 | parser.add_argument('--use_pose_init', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False, help='If True, use the provided pose vector as the latent vector initialization in training [SMAL]')
77 |
78 | args = parser.parse_args()
79 |
80 | args.data_fp =osp.join(args.data_dir)
81 | args.out_dir = osp.join(args.work_dir, 'out', args.exp_name) # save checkpoints and logs
82 | args.results_dir = osp.join(args.work_dir, 'results', args.exp_name) # save training and testing results
83 | args.checkpoints_dir = osp.join(args.out_dir, 'checkpoints')
84 | args.checkpoints_dir_test = osp.join(args.out_dir, 'test_checkpoints')
85 | print(args)
86 |
87 | utils.makedirs(args.out_dir)
88 | utils.makedirs(args.checkpoints_dir)
89 | utils.makedirs(args.checkpoints_dir_test)
90 | utils.makedirs(args.results_dir)
91 | results_dir_train = os.path.join(args.results_dir, "train")
92 | results_dir_test = os.path.join(args.results_dir, "test")
93 | utils.makedirs(results_dir_train)
94 | utils.makedirs(results_dir_test)
95 | writer = writer.Writer(args)
96 |
97 | if args.cpu:
98 | device = torch.device('cpu')
99 | else:
100 | device = torch.device('cuda', args.device_idx)
101 |
102 |
103 | torch.set_num_threads(args.n_threads)
104 |
105 | # deterministic
106 | torch.manual_seed(args.seed)
107 | cudnn.benchmark = False
108 | cudnn.deterministic = True
109 |
110 | # load dataset
111 | if args.dataset=='SMAL':
112 | template_fp = osp.join('template', 'smal_0.ply')
113 | elif args.dataset=='DFaust':
114 | template_fp = osp.join('template', 'smpl_male_template.ply')
115 | elif args.dataset=='Bone':
116 | template_fp = osp.join(args.data_dir, 'template.obj')
117 | else:
118 | print('invalid dataset!')
119 | exit(-1)
120 |
121 | meshdata = MeshData(args.data_fp,
122 | template_fp,
123 | dataset=args.dataset,
124 | pca_n_comp=args.latent_channels,
125 | vert_pca=args.use_vert_pca)
126 |
127 | train_loader = DataLoader(meshdata.train_dataset,
128 | batch_size=args.batch_size,
129 | shuffle=True)
130 | test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size, shuffle=False)
131 |
132 | # generate/load transform matrices
133 | transform_fp = osp.join(args.data_fp, 'transform.pkl')
134 | if not osp.exists(transform_fp):
135 | print('Generating transform matrices...')
136 | mesh = Mesh(filename=template_fp)
137 | ds_factors = args.ds_factors
138 | if args.dataset=='SMAL':
139 | ds_factors = [1, 1, 1, 1]
140 | _, A, D, U, F = mesh_sampling.generate_transform_matrices(mesh, ds_factors)
141 | tmp = {'face': F, 'adj': A, 'down_transform': D, 'up_transform': U}
142 |
143 | with open(transform_fp, 'wb') as fp:
144 | pickle.dump(tmp, fp)
145 | print('Done!')
146 | print('Transform matrices are saved in \'{}\''.format(transform_fp))
147 | else:
148 | with open(transform_fp, 'rb') as f:
149 | tmp = pickle.load(f, encoding='latin1')
150 |
151 | edge_index_list = [utils.to_edge_index(adj).to(device) for adj in tmp['adj']]
152 | down_transform_list = [
153 | utils.to_sparse(down_transform).to(device)
154 | for down_transform in tmp['down_transform']
155 | ]
156 | up_transform_list = [
157 | utils.to_sparse(up_transform).to(device)
158 | for up_transform in tmp['up_transform']
159 | ]
160 |
161 |
162 |
163 |
164 | if args.distributed:
165 | model = AE(args.in_channels,
166 | args.out_channels,
167 | args.latent_channels,
168 | edge_index_list,
169 | down_transform_list,
170 | up_transform_list,
171 | K=args.K)
172 | #from mmcv.parallel import MMDistributedDataParallel
173 | #from mmcv.runner import get_dist_info, init_dist
174 | #init_dist('pytorch')
175 |
176 | #model = MMDistributedDataParallel(
177 | # model.cuda(),
178 | # device_ids=[torch.cuda.current_device()],
179 | # broadcast_buffers=False,
180 | # find_unused_parameters=False
181 | #)
182 | model = torch.nn.DataParallel(model)
183 | model = model.to(device)
184 | else:
185 | model = AE_single(args.in_channels,
186 | args.out_channels,
187 | args.latent_channels,
188 | edge_index_list,
189 | down_transform_list,
190 | up_transform_list,
191 | K=args.K)
192 | model = model.to(device)
193 | #print(model)
194 |
195 | rand_std = 1.0
196 | if args.dataset=='SMAL' and args.use_pose_init:
197 | rand_std = 0.2
198 | test_num_scenes = len(meshdata.test_dataset)
199 | test_lat_vecs = torch.nn.Embedding(test_num_scenes, args.latent_channels,)
200 | torch.nn.init.normal_(test_lat_vecs.weight.data, 0.0, rand_std)
201 | test_lat_vecs = test_lat_vecs.to(device)
202 |
203 | if args.use_vert_pca:
204 | pca_init = torch.from_numpy(meshdata.train_pca_sv)
205 | lat_vecs = torch.nn.Embedding.from_pretrained(pca_init, freeze=False)
206 | print(meshdata.train_pca_sv.mean(), np.std(meshdata.train_pca_sv))
207 |
208 | test_pca_init = torch.from_numpy(meshdata.test_pca_sv)
209 | test_lat_vecs = torch.nn.Embedding.from_pretrained(test_pca_init, freeze=False)
210 | test_lat_vecs = test_lat_vecs.to(device)
211 | elif args.use_pose_init:
212 | pose_init = torch.from_numpy(np.array(meshdata.train_dataset.data.pose, np.float32))
213 | pose_init = pose_init.reshape(meshdata.num_train_graph, -1)
214 | lat_vecs = torch.nn.Embedding.from_pretrained(pose_init, freeze=False)
215 | else:
216 | train_num_scenes = len(meshdata.train_dataset)
217 | lat_vecs = torch.nn.Embedding(train_num_scenes, args.latent_channels,)
218 | torch.nn.init.normal_(lat_vecs.weight.data,0.0,rand_std)
219 |
220 | lat_vecs = lat_vecs.to(device)
221 | if args.continue_train:
222 | start_epoch = writer.load_checkpoint(model, lat_vecs, None,
223 | None, checkpoint=args.checkpoint)
224 |
225 | if args.dataset=='SMAL':
226 | train_vec_lr = 8e-3
227 | else:
228 | train_vec_lr = args.lr
229 | optimizer_all = torch.optim.Adam(
230 | [
231 | {
232 | "params": model.parameters(),
233 | "lr": args.lr,
234 | "weight_decay": args.weight_decay
235 | },
236 | {
237 | "params": lat_vecs.parameters(),
238 | "lr": train_vec_lr,
239 | "weight_decay": args.weight_decay
240 | },
241 | ]
242 | )
243 |
244 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_all,
245 | args.decay_step,
246 | gamma=args.lr_decay)
247 | if args.mode=='train':
248 | if args.alsotest:
249 | optimizer_test = torch.optim.Adam(
250 | test_lat_vecs.parameters(), lr=args.test_lr,
251 | weight_decay=args.weight_decay)
252 | scheduler_test = torch.optim.lr_scheduler.StepLR(
253 | optimizer_test, args.test_decay_step, gamma=args.lr_decay)
254 | else:
255 | optimizer_test = None
256 | scheduler_test = None
257 |
258 | train_eval.run(model,
259 | train_loader, lat_vecs, optimizer_all, scheduler,
260 | test_loader, test_lat_vecs, optimizer_test, scheduler_test,
261 | args.epochs, writer, device, results_dir_train,
262 | meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face,
263 | arap_weight=args.arap_weight, use_arap_epoch=args.use_arap_epoch,
264 | nz_max=args.nz_max, continue_train=args.continue_train,
265 | checkpoint=args.checkpoint, test_checkpoint=args.test_checkpoint, dataset=args.dataset)
266 |
267 | elif args.mode=='test':
268 | optimizer_test = torch.optim.Adam(
269 | test_lat_vecs.parameters(), lr=args.test_lr,
270 | weight_decay=args.weight_decay)
271 | scheduler_test = torch.optim.lr_scheduler.StepLR(
272 | optimizer_test, args.test_decay_step, gamma=args.lr_decay)
273 |
274 | train_eval.test_reconstruct(model, test_loader, test_lat_vecs,
275 | args.test_epochs, optimizer_test, scheduler_test, writer,
276 | device, results_dir_test, meshdata.mean.numpy(),
277 | meshdata.std.numpy(), meshdata.template_face,
278 | checkpoint=args.checkpoint, test_checkpoint=args.test_checkpoint, dataset=args.dataset)
279 |
280 | elif args.mode=='interpolate':
281 | train_eval.global_interpolate(model, lat_vecs, optimizer_all, scheduler,
282 | writer, device, args.results_dir, meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face, args.inter_num)
283 |
284 | elif args.mode=='extraplate':
285 | optimizer_test = torch.optim.Adam(test_lat_vecs.parameters(),
286 | lr=args.test_lr,
287 | weight_decay=args.weight_decay)
288 | scheduler_test = torch.optim.lr_scheduler.StepLR(optimizer_test,
289 | args.test_decay_step,
290 | gamma=args.lr_decay)
291 | train_eval.extrapolation(model, test_lat_vecs, optimizer_test, scheduler_test,
292 | writer, device, args.results_dir, meshdata.mean.numpy(), meshdata.std.numpy(), meshdata.template_face, args.extra_num, args.extra_thres)
293 |
294 |
--------------------------------------------------------------------------------
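`main.py` pairs Adam with `torch.optim.lr_scheduler.StepLR(optimizer_all, args.decay_step, gamma=args.lr_decay)`, which multiplies the learning rate by `gamma` every `step_size` scheduler steps. Assuming `scheduler.step()` is called once per epoch (the loop in `utils/train_eval.py` is not shown here), the defaults `lr=8e-3`, `decay_step=1`, and `lr_decay=0.99` give `8e-3 * 0.99**epoch`; at epoch 600, the default `use_arap_epoch`, that is roughly 0.0024 times the initial rate. The plain-Python helper below (hypothetical name `step_lr`) evaluates that closed form for quick what-if checks on these flags.

```python
def step_lr(base_lr, epoch, step_size=1, gamma=0.99):
    """Learning rate after `epoch` scheduler steps under StepLR's rule
    lr = base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)


# Defaults from main.py: lr=8e-3, decay_step=1, lr_decay=0.99.
for epoch in (0, 600, 1500):
    print(epoch, step_lr(8e-3, epoch))
```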
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .networks import AE,AE_single, Pool
2 | from .arap import ARAP
3 | __all__ = ['Pool', 'AE', 'AE_single', 'ARAP']
4 |
--------------------------------------------------------------------------------
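`models/arap.py` below builds a 3N x 3N operator from the mesh graph in `get_laplacian_kron3x3`: it takes the combinatorial Laplacian L = D - A of the template (unit edge weights), doubles it, and copies each entry onto the x, y, and z coordinates of the corresponding vertices. With the interleaved layout used there (vertex i occupies rows 3i, 3i+1, 3i+2), that equals `np.kron(2 * L, np.eye(3))`. Inside `ARAP.forward`, the first six `B_vals` appended per edge (rows `3*e0 + a`, columns `3*e1 + b`) follow the pattern of the skew-symmetric cross-product matrix of the edge vector `x[e0] - x[e1]`. The dense NumPy sketch below illustrates both pieces for small meshes; it is not a replacement for the sparse implementation, and the function names are hypothetical.

```python
import numpy as np


def laplacian_kron3(edges, n):
    """Dense analogue of get_laplacian_kron3x3 for unit edge weights.

    edges: iterable of undirected (i, j) vertex pairs; n: number of vertices.
    Returns kron(2 * L, I_3) with L = D - A.
    """
    A = np.zeros((n, n))
    for i, j in edges:
        A[i, j] = A[j, i] = 1.0
    L = np.diag(A.sum(axis=1)) - A        # combinatorial graph Laplacian
    return np.kron(2.0 * L, np.eye(3))    # one diagonal 3x3 block per Laplacian entry


def cross_matrix(e):
    """Skew-symmetric matrix [e]_x such that cross_matrix(e) @ w == np.cross(e, w)."""
    x, y, z = e
    return np.array([[0.0, -z, y],
                     [z, 0.0, -x],
                     [-y, x, 0.0]])


K = laplacian_kron3([(0, 1), (1, 2)], 3)   # path graph 0-1-2
print(K @ np.tile([1.0, 2.0, 3.0], 3))     # rigid translations lie in the null space
```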
/models/arap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.utils import degree, get_laplacian
3 | import torch_sparse as ts
4 | import numpy as np
5 | import sys
6 |
7 | def get_laplacian_kron3x3(edge_index, edge_weights, N):
8 | edge_index, edge_weight = get_laplacian(edge_index, edge_weights, num_nodes=N)
9 | edge_weight *= 2
10 | e0, e1 = edge_index
11 | i0 = [e0*3, e0*3+1, e0*3+2]
12 | i1 = [e1*3, e1*3+1, e1*3+2]
13 | vals = [edge_weight, edge_weight, edge_weight]
14 | i0 = torch.cat(i0, 0)
15 | i1 = torch.cat(i1, 0)
16 | vals = torch.cat(vals, 0)
17 | indices, vals = ts.coalesce([i0, i1], vals, N*3, N*3)
18 | return indices, vals
19 |
20 | class ARAP(torch.nn.Module):
21 | def __init__(self, template_face, num_points):
22 | super(ARAP, self).__init__()
23 | N = num_points
24 | self.template_face = template_face
25 | adj = np.zeros((num_points, num_points))
26 | adj[template_face[:, 0], template_face[:, 1]] = 1
27 | adj[template_face[:, 1], template_face[:, 2]] = 1
28 | adj[template_face[:, 0], template_face[:, 2]] = 1
29 | adj = adj + adj.T
30 | edge_index = torch.as_tensor(np.stack(np.where(adj > 0), 0),
31 | dtype=torch.long)
32 | e0, e1 = edge_index
33 | deg = degree(e0, N)
34 | edge_weight = torch.ones_like(e0)
35 |
36 | L_indices, L_vals = get_laplacian_kron3x3(edge_index, edge_weight, N)
37 | self.register_buffer('L_indices', L_indices)
38 | self.register_buffer('L_vals', L_vals)
39 | self.register_buffer('edge_weight', edge_weight)
40 | self.register_buffer('edge_index', edge_index)
41 |
42 | def forward(self, x, J, k=0):
43 | """
44 | x: [B, N, 3] point locations.
45 | J: [B, N*3, D] Jacobian of generator.
46 | k: unused. Returns the batch mean of sum(sqrt(max(eigvalsh(J^T L J - J^T B C^-1 B^T J), 0))).
47 | """
48 | num_batches, N = x.shape[:2]
49 | e0, e1 = self.edge_index
50 | edge_vecs = x[:, e0, :] - x[:, e1, :]
51 | trace_ = []
52 |
53 | for i in range(num_batches):
54 | LJ = ts.spmm(self.L_indices, self.L_vals, N*3, N*3, J[i])
55 | JTLJ = J[i].T.matmul(LJ)
56 |
57 | B0, B1, B_vals = [], [], []
58 | B0.append(e0*3 ); B1.append(e1*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight)
59 | B0.append(e0*3 ); B1.append(e1*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight)
60 | B0.append(e0*3+1); B1.append(e1*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight)
61 | B0.append(e0*3+1); B1.append(e1*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight)
62 | B0.append(e0*3+2); B1.append(e1*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight)
63 | B0.append(e0*3+2); B1.append(e1*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight)
64 |
65 | B0.append(e0*3 ); B1.append(e0*3+1); B_vals.append(-edge_vecs[i, :, 2]*self.edge_weight)
66 | B0.append(e0*3 ); B1.append(e0*3+2); B_vals.append( edge_vecs[i, :, 1]*self.edge_weight)
67 | B0.append(e0*3+1); B1.append(e0*3+0); B_vals.append( edge_vecs[i, :, 2]*self.edge_weight)
68 | B0.append(e0*3+1); B1.append(e0*3+2); B_vals.append(-edge_vecs[i, :, 0]*self.edge_weight)
69 | B0.append(e0*3+2); B1.append(e0*3+0); B_vals.append(-edge_vecs[i, :, 1]*self.edge_weight)
70 | B0.append(e0*3+2); B1.append(e0*3+1); B_vals.append( edge_vecs[i, :, 0]*self.edge_weight)
71 | B0 = torch.cat(B0, 0)
72 | B1 = torch.cat(B1, 0)
73 | B_vals = torch.cat(B_vals, 0)
74 | B_indices, B_vals = ts.coalesce([B0, B1], B_vals, N*3, N*3)
75 | BT_indices, BT_vals = ts.transpose(B_indices, B_vals, N*3, N*3)
76 |
77 | C0, C1, C_vals = [], [], []
78 | edge_vecs_sq = (edge_vecs[i] * edge_vecs[i]).sum(-1)
79 | evi = edge_vecs[i]
80 | for di in range(3):
81 | for dj in range(3):
82 | C0.append(e0*3+di); C1.append(e0*3+dj); C_vals.append(-evi[:, di]*evi[:, dj]*self.edge_weight)
83 | C0.append(e0*3+di); C1.append(e0*3+di); C_vals.append(edge_vecs_sq*self.edge_weight)
84 | C0 = torch.cat(C0, 0)
85 | C1 = torch.cat(C1, 0)
86 | C_vals = torch.cat(C_vals, 0)
87 | C_indices, C_vals = ts.coalesce([C0, C1], C_vals, N*3, N*3)
88 | C_vals = C_vals.view(N, 3, 3).inverse().reshape(-1)
89 | BTJ = ts.spmm(BT_indices, BT_vals, N*3, N*3, J[i])
90 | CBTJ = ts.spmm(C_indices, C_vals, N*3, N*3, BTJ)
91 | JTBCBTJ = BTJ.T.mm(CBTJ)
92 |
93 | e = torch.linalg.eigvalsh(JTLJ-JTBCBTJ).clip(0)
94 |
95 | e = e ** 0.5
96 |
97 | trace = e.sum()
98 |
99 | trace_.append(trace)
100 |
101 | trace_ = torch.stack(trace_, )
102 | return trace_.mean()
103 |
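The blocks assembled in `ARAP.forward` can be cross-checked with a small dense NumPy sketch. Our reading of the code (an interpretation, not quoted from the paper): `L` is the graph Laplacian doubled and expanded with a Kronecker product in vertex-major x, y, z order (as in `get_laplacian_kron3x3`), `B` holds cross-product blocks of the edge vectors `d_ij = x_i - x_j`, and `C` is block diagonal with `C_i = sum_j w (|d_ij|^2 I - d_ij d_ij^T)`. Then `L - B C^{-1} B^T` is what is left of the linearized energy `sum_(i,j) w |(u_i - u_j) - omega_i x d_ij|^2` after minimizing over one infinitesimal rotation `omega_i` per vertex, so rigid motions should lie in its null space. The helpers `cross_matrix`, `arap_hessian` and `arap_trace` are hypothetical names, edge weights are assumed uniform, and dense matrices are only practical for toy meshes like the tetrahedron used here.

```python
import numpy as np


def cross_matrix(d):
    """Skew-symmetric [d]_x, so that cross_matrix(d) @ v == np.cross(d, v)."""
    return np.array([[0.0, -d[2], d[1]],
                     [d[2], 0.0, -d[0]],
                     [-d[1], d[0], 0.0]])


def arap_hessian(verts, faces, w=1.0):
    """Dense 3N x 3N matrix L - B C^{-1} B^T using the index layout of ARAP.forward."""
    n = len(verts)
    adj = np.zeros((n, n), dtype=bool)
    adj[faces[:, 0], faces[:, 1]] = True
    adj[faces[:, 1], faces[:, 2]] = True
    adj[faces[:, 0], faces[:, 2]] = True
    adj |= adj.T
    eye = np.eye(3)
    L = np.zeros((3 * n, 3 * n))
    B = np.zeros((3 * n, 3 * n))
    C = np.zeros((3 * n, 3 * n))
    for i, j in zip(*np.nonzero(adj)):  # every directed edge (e0, e1)
        si, sj = slice(3 * i, 3 * i + 3), slice(3 * j, 3 * j + 3)
        d = verts[i] - verts[j]  # edge_vecs = x[e0] - x[e1]
        # (D - A) kron I_3; visiting both directions doubles it, like edge_weight *= 2
        L[si, si] += w * eye
        L[sj, sj] += w * eye
        L[si, sj] -= w * eye
        L[sj, si] -= w * eye
        # cross-product blocks at (e0, e1) and (e0, e0)
        B[si, sj] += w * cross_matrix(d)
        B[si, si] += w * cross_matrix(d)
        # block-diagonal C_i = sum_j w (|d|^2 I - d d^T)
        C[si, si] += w * ((d @ d) * eye - np.outer(d, d))
    C_inv = np.zeros_like(C)
    for i in range(n):
        s = slice(3 * i, 3 * i + 3)
        C_inv[s, s] = np.linalg.inv(C[s, s])
    return L - B @ C_inv @ B.T


def arap_trace(H, J):
    """Sum of square roots of the eigenvalues of J^T H J, clipped at 0."""
    e = np.clip(np.linalg.eigvalsh(J.T @ H @ J), 0.0, None)
    return np.sqrt(e).sum()


verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
H = arap_hessian(verts, faces)
omega = np.array([0.3, -0.2, 0.5])
rotation = np.cross(omega, verts).ravel()  # infinitesimal rigid rotation
translation = np.tile([1.0, 0.0, 0.0], len(verts))
print(arap_trace(H, np.stack([rotation, translation], axis=1)))  # close to 0
```

For a non-degenerate tetrahedron this reduced Hessian should have exactly six near-zero eigenvalues (three translations, three rotations), which makes a quick sanity check before trusting the sparse, per-sample version.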
--------------------------------------------------------------------------------
/models/conv/__init__.py:
--------------------------------------------------------------------------------
1 | from .message_passing import MessagePassing
2 | from .cheb_conv import ChebConv
3 |
4 |
5 | __all__ = [
6 | 'MessagePassing',
7 | 'ChebConv',
8 | ]
9 |
--------------------------------------------------------------------------------
/models/conv/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/conv/__pycache__/cheb_conv.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/cheb_conv.cpython-36.pyc
--------------------------------------------------------------------------------
/models/conv/__pycache__/message_passing.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/models/conv/__pycache__/message_passing.cpython-36.pyc
--------------------------------------------------------------------------------
/models/conv/cheb_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from torch_geometric.utils import remove_self_loops, add_self_loops
4 | from torch_geometric.utils import get_laplacian
5 | from .message_passing import MessagePassing
6 |
7 | from ..inits import glorot, zeros
8 |
9 |
10 | class ChebConv(MessagePassing):
11 | """
12 | Args:
13 | in_channels (int): Size of each input sample.
14 | out_channels (int): Size of each output sample.
15 | K (int): Chebyshev filter size, *i.e.* number of hops :math:`K`.
16 | normalization (str, optional): The normalization scheme for the graph
17 | Laplacian (default: :obj:`"sym"`):
18 |
19 | 1. :obj:`None`: No normalization
20 | :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`
21 |
22 | 2. :obj:`"sym"`: Symmetric normalization
23 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
24 | \mathbf{D}^{-1/2}`
25 |
26 | 3. :obj:`"rw"`: Random-walk normalization
27 | :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`
28 |
29 | You need to pass :obj:`lambda_max` to the :meth:`forward` method of
30 | this operator in case the normalization is non-symmetric.
31 | :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
32 | :obj:`[num_graphs]` in a mini-batch scenario and a scalar when
33 | operating on single graphs.
34 | You can pre-compute :obj:`lambda_max` via the
35 | :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
36 | cached (bool, optional): If set to :obj:`True`, the layer will cache
37 | the computation of the scaled and normalized Laplacian
38 | :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}` on first execution,
39 | and will use the cached version for further executions.
40 | This parameter should only be set to :obj:`True` in
41 | fixed graph scenarios. (default: :obj:`True`)
42 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
43 | an additive bias. (default: :obj:`True`)
44 | **kwargs (optional): Additional arguments of
45 | :class:`torch_geometric.nn.conv.MessagePassing`.
46 | """
47 |
48 | def __init__(self,
49 | in_channels,
50 | out_channels,
51 | K,
52 | normalization='sym',
53 | cached=True,
54 | bias=True,
55 | **kwargs):
56 | super(ChebConv, self).__init__(aggr='add', **kwargs)
57 |
58 | assert K > 0
59 | assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'
60 |
61 | self.in_channels = in_channels
62 | self.out_channels = out_channels
63 | self.normalization = normalization
64 | self.cached = cached
65 | self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
66 |
67 | if bias:
68 | self.bias = Parameter(torch.Tensor(out_channels))
69 | else:
70 | self.register_parameter('bias', None)
71 |
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | glorot(self.weight)
76 | zeros(self.bias)
77 | self.cached_result = None
78 | self.cached_num_edges = None
79 |
80 | @staticmethod
81 | def norm(edge_index,
82 | num_nodes,
83 | edge_weight,
84 | normalization,
85 | lambda_max,
86 | dtype=None,
87 | batch=None):
88 |
89 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
90 |
91 | edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
92 | normalization, dtype,
93 | num_nodes)
94 |
95 | if batch is not None and torch.is_tensor(lambda_max):
96 | lambda_max = lambda_max[batch[edge_index[0]]]
97 |
98 | edge_weight = (2.0 * edge_weight) / lambda_max
99 | edge_weight[edge_weight == float('inf')] = 0
100 |
101 | edge_index, edge_weight = add_self_loops(edge_index,
102 | edge_weight,
103 | fill_value=-1,
104 | num_nodes=num_nodes)
105 |
106 | return edge_index, edge_weight
107 |
108 | def forward(self,
109 | x,
110 | edge_index,
111 | edge_weight=None,
112 | batch=None,
113 | lambda_max=None):
114 | """"""
115 | if self.normalization != 'sym' and lambda_max is None:
116 | raise ValueError('You need to pass `lambda_max` to `forward()` in '
117 | 'case the normalization is non-symmetric.')
118 | lambda_max = 2.0 if lambda_max is None else lambda_max
119 |
120 | if not self.cached or self.cached_result is None:
121 | edge_index, norm = self.norm(edge_index,
122 | x.size(1),
123 | edge_weight,
124 | self.normalization,
125 | lambda_max,
126 | dtype=x.dtype,
127 | batch=batch)
128 | self.cached_result = edge_index, norm
129 |
130 | edge_index, norm = self.cached_result
131 | Tx_0 = x
132 | out = torch.matmul(Tx_0, self.weight[0])
133 |
134 | if self.weight.size(0) > 1:
135 | Tx_1 = self.propagate(edge_index, x=x, norm=norm)
136 | out = out + torch.matmul(Tx_1, self.weight[1])
137 |
138 | for k in range(2, self.weight.size(0)):
139 | Tx_2 = 2 * self.propagate(edge_index, x=Tx_1, norm=norm) - Tx_0
140 | out = out + torch.matmul(Tx_2, self.weight[k])
141 | Tx_0, Tx_1 = Tx_1, Tx_2
142 |
143 | if self.bias is not None:
144 | out = out + self.bias
145 |
146 | return out
147 |
148 | def message(self, x_j, norm):
149 | return norm.view(-1, 1) * x_j
150 |
151 | def __repr__(self):
152 | return '{}({}, {}, K={}, normalization={})'.format(
153 | self.__class__.__name__, self.in_channels, self.out_channels,
154 | self.weight.size(0), self.normalization)
155 |
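For a single graph and the default `'sym'` normalization, the recurrence in `ChebConv.forward` should match the dense computation below: `norm()` builds `2 L / lambda_max - I` (the `fill_value=-1` self loops supply the `- I`), and each `propagate` call applies that operator. This is a sketch under stated assumptions: unbatched `[N, F]` input instead of the module's `[B, N, F]`, no isolated vertices, and hypothetical function names.

```python
import numpy as np


def scaled_sym_laplacian(adj, lambda_max=2.0):
    """Return 2 L / lambda_max - I for L = I - D^{-1/2} A D^{-1/2}."""
    n = len(adj)
    d_inv_sqrt = 1.0 / np.sqrt(adj.sum(axis=1))  # assumes no isolated vertices
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return 2.0 * lap / lambda_max - np.eye(n)


def cheb_conv_dense(x, adj, weight, bias=None, lambda_max=2.0):
    """x: [N, F_in], weight: [K, F_in, F_out] -> sum_k T_k(L~) x W_k (+ bias)."""
    lt = scaled_sym_laplacian(adj, lambda_max)
    tx0 = x
    out = tx0 @ weight[0]
    if weight.shape[0] > 1:
        tx1 = lt @ x
        out = out + tx1 @ weight[1]
        for k in range(2, weight.shape[0]):
            tx2 = 2.0 * (lt @ tx1) - tx0  # T_k = 2 L~ T_{k-1} - T_{k-2}
            out = out + tx2 @ weight[k]
            tx0, tx1 = tx1, tx2
    if bias is not None:
        out = out + bias
    return out


# 5-vertex ring graph, K = 3, random features and filters
n = 5
adj = np.zeros((n, n))
for v in range(n):
    adj[v, (v + 1) % n] = adj[(v + 1) % n, v] = 1.0
rng = np.random.default_rng(0)
x = rng.normal(size=(n, 3))
weight = rng.normal(size=(3, 3, 4))
out = cheb_conv_dense(x, adj, weight)
```

With `lambda_max = 2` (the value `forward()` falls back to), the scaled operator reduces to `-D^{-1/2} A D^{-1/2}`, which on a degree-2 ring is simply `-A / 2`.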
--------------------------------------------------------------------------------
/models/conv/message_passing.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 |
4 | import torch
5 | from torch_scatter import scatter
6 |
7 | special_args = [
8 | 'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
9 | ]
10 | __size_error_msg__ = ('All tensors which should get mapped to the same source '
11 | 'or target nodes must be of same size in dimension 0.')
12 |
13 | is_python2 = sys.version_info[0] < 3
14 | getargspec = inspect.getargspec if is_python2 else inspect.getfullargspec
15 |
16 |
17 | class MessagePassing(torch.nn.Module):
18 | def __init__(self, aggr='add', flow='source_to_target'):
19 | super(MessagePassing, self).__init__()
20 |
21 | self.aggr = aggr
22 | assert self.aggr in ['add', 'mean', 'max']
23 |
24 | self.flow = flow
25 | assert self.flow in ['source_to_target', 'target_to_source']
26 |
27 | self.__message_args__ = getargspec(self.message)[0][1:]
28 | self.__special_args__ = [(i, arg)
29 | for i, arg in enumerate(self.__message_args__)
30 | if arg in special_args]
31 | self.__message_args__ = [
32 | arg for arg in self.__message_args__ if arg not in special_args
33 | ]
34 | self.__update_args__ = getargspec(self.update)[0][2:]
35 |
36 | def propagate(self, edge_index, size=None, dim=0, **kwargs):
37 | dim = 1 # aggregate messages wrt nodes for batched_data: [batch_size, nodes, features]
38 | size = [None, None] if size is None else list(size)
39 | assert len(size) == 2
40 |
41 | i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
42 | ij = {"_i": i, "_j": j}
43 |
44 | message_args = []
45 | for arg in self.__message_args__:
46 | if arg[-2:] in ij.keys():
47 | tmp = kwargs.get(arg[:-2], None)
48 | if tmp is None: # pragma: no cover
49 | message_args.append(tmp)
50 | else:
51 | idx = ij[arg[-2:]]
52 | if isinstance(tmp, tuple) or isinstance(tmp, list):
53 | assert len(tmp) == 2
54 | if tmp[1 - idx] is not None:
55 | if size[1 - idx] is None:
56 | size[1 - idx] = tmp[1 - idx].size(dim)
57 | if size[1 - idx] != tmp[1 - idx].size(dim):
58 | raise ValueError(__size_error_msg__)
59 | tmp = tmp[idx]
60 |
61 | if tmp is None:
62 | message_args.append(tmp)
63 | else:
64 | if size[idx] is None:
65 | size[idx] = tmp.size(dim)
66 | if size[idx] != tmp.size(dim):
67 | raise ValueError(__size_error_msg__)
68 |
69 | tmp = torch.index_select(tmp, dim, edge_index[idx])
70 | message_args.append(tmp)
71 | else:
72 | message_args.append(kwargs.get(arg, None))
73 |
74 | size[0] = size[1] if size[0] is None else size[0]
75 | size[1] = size[0] if size[1] is None else size[1]
76 |
77 | kwargs['edge_index'] = edge_index
78 | kwargs['size'] = size
79 |
80 | for (idx, arg) in self.__special_args__:
81 | if arg[-2:] in ij.keys():
82 | message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
83 | else:
84 | message_args.insert(idx, kwargs[arg])
85 |
86 | update_args = [kwargs[arg] for arg in self.__update_args__]
87 |
88 | out = self.message(*message_args)
89 | out = scatter(out, edge_index[i], dim=dim, dim_size=size[i], reduce=self.aggr)
90 | out = self.update(out, *update_args)
91 |
92 | return out
93 |
94 | def message(self, x_j): # pragma: no cover
95 | return x_j
96 |
97 | def update(self, aggr_out): # pragma: no cover
98 | return aggr_out
99 |
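`propagate` boils down to a gather along the vertex axis followed by a scatter into target vertices. The NumPy sketch below reproduces that for batched input of shape `[B, N, F]` with the default `flow='source_to_target'`, where `x_j` is read from `edge_index[0]` and results are aggregated at `edge_index[1]`. `Pool` in `models/networks.py` follows the same gather/scatter pattern, using a sparse transform's column indices as sources and its row indices as targets. The function name is hypothetical, only `'add'` and `'mean'` are covered, and `np.add.at` stands in for `torch_scatter.scatter`.

```python
import numpy as np


def propagate(x, edge_index, norm=None, aggr="add"):
    """Gather/scatter over the vertex axis of x: [B, N, F], flow 'source_to_target'."""
    src, dst = edge_index  # x_j comes from edge_index[0]; results land on edge_index[1]
    msgs = x[:, src, :]  # like torch.index_select(x, 1, edge_index[0])
    if norm is not None:
        msgs = msgs * norm[None, :, None]  # ChebConv.message: norm.view(-1, 1) * x_j
    out = np.zeros(x.shape, dtype=float)
    np.add.at(out, (slice(None), dst), msgs)  # unbuffered, so repeated targets accumulate
    if aggr == "mean":
        counts = np.bincount(dst, minlength=x.shape[1])
        out = out / np.maximum(counts, 1)[None, :, None]  # vertices without inputs stay 0
    return out


edge_index = np.array([[0, 1, 2, 2], [1, 2, 0, 1]])
norm = np.array([0.5, 2.0, -1.0, 3.0])
x = np.arange(12, dtype=float).reshape(2, 3, 2)  # B=2 meshes, N=3 vertices, F=2
out = propagate(x, edge_index, norm)
```

Equivalently, the `'add'` case multiplies each batch element by a sparse matrix `M` with `M[dst, src] = norm`, which is why the same layer can be read as applying a scaled Laplacian.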
--------------------------------------------------------------------------------
/models/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def uniform(size, tensor):
5 | bound = 1.0 / math.sqrt(size)
6 | if tensor is not None:
7 | tensor.data.uniform_(-bound, bound)
8 |
9 |
10 | def kaiming_uniform(tensor, fan, a):
11 | if tensor is not None:
12 | bound = math.sqrt(6 / ((1 + a**2) * fan))
13 | tensor.data.uniform_(-bound, bound)
14 |
15 |
16 | def glorot(tensor):
17 | if tensor is not None:
18 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
19 | tensor.data.uniform_(-stdv, stdv)
20 |
21 |
22 | def zeros(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0)
25 |
26 |
27 | def ones(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(1)
30 |
31 |
32 | def reset(nn):
33 | def _reset(item):
34 | if hasattr(item, 'reset_parameters'):
35 | item.reset_parameters()
36 |
37 | if nn is not None:
38 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
39 | for item in nn.children():
40 | _reset(item)
41 | else:
42 | _reset(nn)
43 |
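`ChebConv.reset_parameters` initializes its `[K, in_channels, out_channels]` weight with `glorot` above, but `Enblock`, `Deblock`, `AE` and `AE_single` then re-initialize every non-bias parameter with `torch.nn.init.xavier_uniform_`. As we understand PyTorch's fan computation, a 3-D tensor takes `fan_in` from dim 1 and `fan_out` from dim 0, each multiplied by the product of the remaining dims, so the effective range differs from `glorot` for Chebyshev weights while agreeing for 2-D `nn.Linear` weights. The pure-Python sketch below compares the two bounds; the helper names and the shape `(6, 32, 64)` are illustrative assumptions.

```python
import math


def glorot_bound(shape):
    """Bound used by glorot() in models/inits.py: only the last two dims count."""
    return math.sqrt(6.0 / (shape[-2] + shape[-1]))


def xavier_uniform_bound(shape, gain=1.0):
    """Bound of torch.nn.init.xavier_uniform_ (fans from dims 1 and 0 times receptive field)."""
    receptive_field = 1
    for s in shape[2:]:
        receptive_field *= s
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


cheb_weight = (6, 32, 64)  # illustrative [K, in_channels, out_channels]
print(glorot_bound(cheb_weight))          # 0.25
print(xavier_uniform_bound(cheb_weight))  # sqrt(6 / 2432), roughly 0.0497
```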
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.conv import ChebConv
5 |
6 | from .inits import reset
7 | from torch_scatter import scatter_add
8 |
9 |
10 | def Pool(x, trans, dim=1):
11 | row, col = trans._indices()
12 | value = trans._values().unsqueeze(-1)
13 | out = torch.index_select(x, dim, col) * value
14 | out = scatter_add(out, row, dim, dim_size=trans.size(0))
15 | return out
16 |
17 |
18 | class Enblock(nn.Module):
19 | def __init__(self, in_channels, out_channels, K, **kwargs):
20 | super(Enblock, self).__init__()
21 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs)
22 | self.reset_parameters()
23 |
24 | def reset_parameters(self):
25 | for name, param in self.conv.named_parameters():
26 | if 'bias' in name:
27 | nn.init.constant_(param, 0)
28 | else:
29 | nn.init.xavier_uniform_(param)
30 |
31 | def forward(self, x, edge_index, down_transform):
32 | out = F.elu(self.conv(x, edge_index))
33 | out = Pool(out, down_transform)
34 | return out
35 |
36 |
37 | class Deblock(nn.Module):
38 | def __init__(self, in_channels, out_channels, K, **kwargs):
39 | super(Deblock, self).__init__()
40 | self.conv = ChebConv(in_channels, out_channels, K, **kwargs)
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | for name, param in self.conv.named_parameters():
45 | if 'bias' in name:
46 | nn.init.constant_(param, 0)
47 | else:
48 | nn.init.xavier_uniform_(param)
49 |
50 | def forward(self, x, edge_index, up_transform):
51 | out = Pool(x, up_transform)
52 | out = F.elu(self.conv(out, edge_index))
53 | return out
54 |
55 |
56 | class AE(nn.Module):
57 | def __init__(self, in_channels, out_channels, latent_channels,
58 | edge_index, down_transform, up_transform, K, **kwargs):
59 | super(AE, self).__init__()
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 | #self.edge_index = edge_index
63 | self.num_edge_index = len(edge_index)
64 | for i in range(self.num_edge_index):
65 | self.register_buffer(f'edge_index_{i}', edge_index[i])
66 | setattr(self, f'edge_index_{i}', edge_index[i])
67 |
68 | #self.down_transform = down_transform
69 | self.num_down_transform = len(down_transform)
70 | for i in range(self.num_down_transform):
71 | self.register_buffer(f'down_transform_{i}', down_transform[i])
72 | setattr(self, f'down_transform_{i}', down_transform[i])
73 |
74 | #self.up_transform = up_transform
75 | self.num_up_transform = len(up_transform)
76 | for i in range(self.num_up_transform):
77 | self.register_buffer(f'up_transform_{i}', up_transform[i])
78 | setattr(self, f'up_transform_{i}', up_transform[i])
79 | # self.num_vert used in the last and the first layer of encoder and decoder
80 | self.num_vert = down_transform[-1].size(0)
81 |
82 | # encoder
83 | #self.en_layers = nn.ModuleList()
84 | #for idx in range(len(out_channels)):
85 | # if idx == 0:
86 | # self.en_layers.append(
87 | # Enblock(in_channels, out_channels[idx], K, **kwargs))
88 | # else:
89 | # self.en_layers.append(
90 | # Enblock(out_channels[idx - 1], out_channels[idx], K,
91 | # **kwargs))
92 | #self.en_layers.append(
93 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels))
94 |
95 | # decoder
96 | self.de_layers = nn.ModuleList()
97 | self.de_layers.append(
98 | nn.Linear(latent_channels, self.num_vert * out_channels[-1]))
99 | for idx in range(len(out_channels)):
100 | if idx == 0:
101 | self.de_layers.append(
102 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K,
103 | **kwargs))
104 | else:
105 | self.de_layers.append(
106 | Deblock(out_channels[-idx], out_channels[-idx - 1], K,
107 | **kwargs))
108 | # reconstruction
109 | self.de_layers.append(
110 | ChebConv(out_channels[0], in_channels, K, **kwargs))
111 |
112 | self.reset_parameters()
113 |
114 | def reset_parameters(self):
115 | for name, param in self.named_parameters():
116 | if 'bias' in name:
117 | nn.init.constant_(param, 0)
118 | else:
119 | nn.init.xavier_uniform_(param)
120 |
121 | def encoder(self, x):
122 | for i, layer in enumerate(self.en_layers):
123 | if i != len(self.en_layers) - 1:
124 | x = layer(x, getattr(self, f'edge_index_{i}'),
125 | getattr(self, f'down_transform_{i}'))
126 | else:
127 | x = x.view(-1, layer.weight.size(1))
128 | x = layer(x)
129 | return x
130 |
131 | def decoder(self, x):
132 | num_layers = len(self.de_layers)
133 | num_deblocks = num_layers - 2
134 | for i, layer in enumerate(self.de_layers):
135 | if i == 0:
136 | x = layer(x)
137 | x = x.view(-1, self.num_vert, self.out_channels[-1])
138 | elif i != num_layers - 1:
139 | x = layer(x, getattr(self, f'edge_index_{num_deblocks - i}'),
140 | getattr(self, f'up_transform_{num_deblocks - i}'))
141 | else:
142 | # last layer
143 | x = layer(x, getattr(self, 'edge_index_0'))
144 | return x
145 |
146 | def forward(self, x):
147 | # x - batched feature matrix
148 | #z = self.encoder(x)
149 | out = self.decoder(x)
150 | return out
151 |
152 | class AE_single(nn.Module):
153 | def __init__(self, in_channels, out_channels, latent_channels,
154 | edge_index, down_transform, up_transform, K, **kwargs):
155 | super(AE_single, self).__init__()
156 | self.in_channels = in_channels
157 | self.out_channels = out_channels
158 | self.edge_index = edge_index
159 | self.down_transform = down_transform
160 | self.up_transform = up_transform
161 | # self.num_vert used in the last and the first layer of encoder and decoder
162 | self.num_vert = self.down_transform[-1].size(0)
163 |
164 | # encoder
165 | #self.en_layers = nn.ModuleList()
166 | #for idx in range(len(out_channels)):
167 | # if idx == 0:
168 | # self.en_layers.append(
169 | # Enblock(in_channels, out_channels[idx], K, **kwargs))
170 | # else:
171 | # self.en_layers.append(
172 | # Enblock(out_channels[idx - 1], out_channels[idx], K,
173 | # **kwargs))
174 | #self.en_layers.append(
175 | # nn.Linear(self.num_vert * out_channels[-1], latent_channels))
176 |
177 | # decoder
178 | self.de_layers = nn.ModuleList()
179 | self.de_layers.append(
180 | nn.Linear(latent_channels, self.num_vert * out_channels[-1]))
181 | for idx in range(len(out_channels)):
182 | if idx == 0:
183 | self.de_layers.append(
184 | Deblock(out_channels[-idx - 1], out_channels[-idx - 1], K,
185 | **kwargs))
186 | else:
187 | self.de_layers.append(
188 | Deblock(out_channels[-idx], out_channels[-idx - 1], K,
189 | **kwargs))
190 | # reconstruction
191 | self.de_layers.append(
192 | ChebConv(out_channels[0], in_channels, K, **kwargs))
193 |
194 | self.reset_parameters()
195 |
196 | def reset_parameters(self):
197 | for name, param in self.named_parameters():
198 | if 'bias' in name:
199 | nn.init.constant_(param, 0)
200 | else:
201 | nn.init.xavier_uniform_(param)
202 |
203 | def encoder(self, x):
204 | for i, layer in enumerate(self.en_layers):
205 | if i != len(self.en_layers) - 1:
206 | x = layer(x, self.edge_index[i], self.down_transform[i])
207 | else:
208 | x = x.view(-1, layer.weight.size(1))
209 | x = layer(x)
210 | return x
211 |
212 | def decoder(self, x):
213 | num_layers = len(self.de_layers)
214 | num_deblocks = num_layers - 2
215 | for i, layer in enumerate(self.de_layers):
216 | if i == 0:
217 | x = layer(x)
218 | x = x.view(-1, self.num_vert, self.out_channels[-1])
219 | elif i != num_layers - 1:
220 | x = layer(x, self.edge_index[num_deblocks - i],
221 | self.up_transform[num_deblocks - i])
222 | else:
223 | # last layer
224 | x = layer(x, self.edge_index[0])
225 | return x
226 |
227 | def forward(self, x):
228 | # x - batched feature matrix
229 | #z = self.encoder(x)
230 | out = self.decoder(x)
231 | return out
232 |
233 |
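The index arithmetic in `AE.decoder` (`num_deblocks - i`) is easier to follow when listed out. `de_layers[0]` is the `nn.Linear` that maps a latent code to `num_vert * out_channels[-1]` values, which are reshaped to `[B, num_vert, out_channels[-1]]`; each `Deblock` then receives `up_transform_k` and `edge_index_k` with `k` counting down to 0, and the final `ChebConv` maps `out_channels[0]` to `in_channels` on `edge_index_0`. The sketch below enumerates that plan. It assumes `len(ds_factors) == len(out_channels)` and `in_channels=3` (xyz), and `decoder_plan` is a hypothetical helper. Note also that `encoder()` in both classes refers to `self.en_layers`, which is commented out, so only `decoder()` is usable; `forward` decodes latent codes directly.

```python
def decoder_plan(out_channels, in_channels=3):
    """(layer, c_in, c_out, k) in the order AE.decoder applies them after the Linear layer.

    k is the suffix of edge_index_k / up_transform_k passed to that layer.
    """
    num_deblocks = len(out_channels)  # len(de_layers) - 2
    plan = []
    for idx in range(num_deblocks):
        c_in = out_channels[-1] if idx == 0 else out_channels[-idx]
        c_out = out_channels[-idx - 1]
        # de_layers index i = idx + 1, so the transform index is num_deblocks - i
        plan.append(("Deblock", c_in, c_out, num_deblocks - (idx + 1)))
    plan.append(("ChebConv", out_channels[0], in_channels, 0))  # last layer: edge_index_0
    return plan


for step in decoder_plan([32, 32, 32, 64]):  # out_channels from train_and_test_dfaust.sh
    print(step)
```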
--------------------------------------------------------------------------------
/template/dfaust.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/dfaust.ply
--------------------------------------------------------------------------------
/template/smal_0.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/smal_0.ply
--------------------------------------------------------------------------------
/template/smpl_male_template.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GitBoSun/ARAPReg/88c8a75596c1151f3ed02af4b51f2fa7497e5561/template/smpl_male_template.ply
--------------------------------------------------------------------------------
/test_bone.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python main.py \
3 | --exp_name 'model0' \
4 | --device_idx 0 \
5 | --epochs 4500 \
6 | --lr 8e-3 \
7 | --arap_weight 0.05 \
8 | --use_arap_epoch 1500 \
9 | --decay_step 10 \
10 | --latent_channels 8 \
11 | --use_vert_pca True \
12 | --work_dir ./work_dir/bone/tibia \
13 | --dataset Bone \
14 | --data_dir ./data/bone/tibia \
15 | --continue_train True \
16 | --test_lr 8e-2 \
17 | --test_epochs 2500 \
18 | --test_decay_step 30 \
19 | --mode test \
20 |
--------------------------------------------------------------------------------
/test_dfaust.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore main.py \
3 | --out_channels 32 32 32 64 \
4 | --ds_factors 2 2 2 2 \
5 | --exp_name 'arap' \
6 | --device_idx 0 \
7 | --batch_size 64 \
8 | --epochs 2000 \
9 | --n_threads 4 \
10 | --lr 1e-4 \
11 | --arap_weight 0.0 \
12 | --use_arap_epoch 800 \
13 | --decay_step 3 \
14 | --latent_channels 72 \
15 | --use_vert_pca True \
16 | --work_dir ./work_dir/DFaust \
17 | --dataset DFaust \
18 | --data_dir ./data/DFaust \
19 | --continue_train True \
20 | --test_lr 1e-3 \
21 | --test_epochs 2500 \
22 | --test_decay_step 5 \
23 | --mode test \
24 | --distributed \
25 | --checkpoint work_dir/DFaust/out/arap/checkpoints/checkpoint_0410.pt \
26 | --test_checkpoint work_dir/DFaust/out/arap/test_checkpoints/checkpoint_1230.pt \
27 |
--------------------------------------------------------------------------------
/test_smal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py \
3 | --exp_name 'model0' \
4 | --device_idx 0 \
5 | --batch_size 8 \
6 | --epochs 2000 \
7 | --lr 0.01 \
8 | --arap_weight 0.0 \
9 | --use_arap_epoch 800 \
10 | --nz_max 50 \
11 | --decay_step 3 \
12 | --latent_channels 96 \
13 | --use_pose_init True \
14 | --work_dir ./work_dir/SMAL \
15 | --dataset SMAL \
16 | --data_dir ./data/SMAL \
17 | --continue_train True \
18 | --test_lr 0.01 \
19 | --test_epochs 2500 \
20 | --test_decay_step 5 \
21 | --mode test \
22 |
23 |
--------------------------------------------------------------------------------
/train_and_test_dfaust.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore main.py \
3 | --out_channels 32 32 32 64 \
4 | --ds_factors 2 2 2 2 \
5 | --exp_name 'arap' \
6 | --device_idx 0 \
7 | --batch_size 64 \
8 | --epochs 450 \
9 | --lr 1e-4 \
10 | --test_lr 1e-3 \
11 | --test_decay_step 5 \
12 | --arap_weight 5e-4 \
13 | --use_arap_epoch 150 \
14 | --decay_step 3 \
15 | --latent_channels 72 \
16 | --use_vert_pca True \
17 | --work_dir ./work_dir/DFaust \
18 | --dataset DFaust \
19 | --data_dir ./data/DFaust \
20 | --distributed \
21 | --alsotest \
22 | #--continue_train True \
23 | #--checkpoint work_dir/DFaust/out/arap/checkpoints/checkpoint_0410.pt \
24 | #--test_checkpoint work_dir/DFaust/out/arap/test_checkpoints/checkpoint_01200.pt \
25 |
--------------------------------------------------------------------------------
/train_bone.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python main.py \
3 | --exp_name 'model0' \
4 | --device_idx 0 \
5 | --epochs 4500 \
6 | --lr 8e-3 \
7 | --arap_weight 0.05 \
8 | --use_arap_epoch 1000 \
9 | --decay_step 10 \
10 | --latent_channels 8 \
11 | --use_vert_pca True \
12 | --work_dir ./work_dir/bone/tibia \
13 | --dataset Bone \
14 | --data_dir ./data/bone/tibia \
15 |
--------------------------------------------------------------------------------
/train_smal.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py \
3 | --exp_name 'model0' \
4 | --device_idx 0 \
5 | --batch_size 16 \
6 | --epochs 2000 \
7 | --lr 0.01 \
8 | --arap_weight 0.05 \
9 | --use_arap_epoch 800 \
10 | --nz_max 96 \
11 | --decay_step 3 \
12 | --latent_channels 96 \
13 | --use_pose_init True \
14 | --work_dir ./work_dir/SMAL \
15 | --dataset SMAL \
16 | --data_dir ./data/SMAL \
17 | --continue_train True \
18 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import DataLoader
2 |
3 | __all__ = [
4 | 'DataLoader',
5 | ]
6 |
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from torch.utils.data.dataloader import default_collate
3 |
4 | from torch_geometric.data import Data, Batch
5 |
6 |
7 | class DataLoader(torch.utils.data.DataLoader):
8 | def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
9 | def dense_collate(data_list):
10 | batch = Batch()
11 | batch.batch = []
12 | for key in data_list[0].keys:
13 | batch[key] = default_collate([d[key] for d in data_list])
14 | for i, data in enumerate(data_list):
15 | num_nodes = data.num_nodes
16 | if num_nodes is not None:
17 | item = torch.full((num_nodes, ), i, dtype=torch.long)
18 | batch.batch.append(item)
19 | batch.batch = torch.cat(batch.batch, dim=0)
20 | return batch
21 |
22 | super(DataLoader, self).__init__(dataset,
23 | batch_size,
24 | shuffle,
25 | collate_fn=dense_collate,
26 | **kwargs)
27 |
--------------------------------------------------------------------------------
/utils/mesh_sampling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import heapq
3 | import numpy as np
4 | import os
5 | import scipy.sparse as sp
6 | from psbody.mesh import Mesh
7 | from scipy.spatial import KDTree
8 |
9 | def row(A):
10 | return A.reshape((1, -1))
11 |
12 | def col(A):
13 | return A.reshape((-1, 1))
14 |
15 | def get_vert_connectivity(mesh_v, mesh_f):
16 | """Returns a sparse matrix (of size #verts x #verts) where each nonzero
17 | element indicates a neighborhood relation. For example, if there is a
18 | nonzero element in position (15,12), that means vertex 15 is connected
19 | by an edge to vertex 12."""
20 |
21 | vpv = sp.csc_matrix((len(mesh_v), len(mesh_v)))
22 |
23 | # for each column in the faces...
24 | for i in range(3):
25 | IS = mesh_f[:, i]
26 | JS = mesh_f[:, (i + 1) % 3]
27 | data = np.ones(len(IS))
28 | ij = np.vstack((row(IS.ravel()), row(JS.ravel())))
29 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
30 | vpv = vpv + mtx + mtx.T
31 |
32 | return vpv
33 |
34 |
35 | def get_vertices_per_edge(mesh_v, mesh_f):
36 | """Returns an Ex2 array of adjacencies between vertices, where
37 | each element in the array is a vertex index. Each edge is included
38 | only once, as the pair (i, j) with i < j. The adjacency is taken from
39 | get_vert_connectivity()."""
40 |
41 | vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f))
42 | result = np.hstack((col(vc.row), col(vc.col)))
43 | result = result[result[:, 0] < result[:, 1]] # for uniqueness
44 |
45 | return result
46 |
47 |
48 | def vertex_quadrics(mesh):
49 | """Computes a quadric for each vertex in the Mesh.
50 |
51 | Returns:
52 | v_quadrics: an (N x 4 x 4) array, where N is # vertices.
53 | """
54 |
55 | # Allocate quadrics
56 | v_quadrics = np.zeros((
57 | len(mesh.v),
58 | 4,
59 | 4,
60 | ))
61 |
62 | # For each face...
63 | for f_idx in range(len(mesh.f)):
64 |
65 | # Compute normalized plane equation for that face
66 | vert_idxs = mesh.f[f_idx]
67 | verts = np.hstack((mesh.v[vert_idxs],
68 | np.array([1, 1, 1]).reshape(-1, 1)))
69 | u, s, v = np.linalg.svd(verts)
70 | eq = v[-1, :].reshape(-1, 1)
71 | eq = eq / (np.linalg.norm(eq[0:3]))
72 |
73 | # Add the outer product of the plane equation to the
74 | # quadrics of the vertices for this face
75 | for k in range(3):
76 | v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq)
77 |
78 | return v_quadrics
79 |
80 |
81 | def setup_deformation_transfer(source, target, use_normals=False):
82 | rows = np.zeros(3 * target.v.shape[0], dtype=np.int64)  # sparse row indices
83 | cols = np.zeros(3 * target.v.shape[0], dtype=np.int64)  # sparse column indices
84 | coeffs_v = np.zeros(3 * target.v.shape[0])
85 | coeffs_n = np.zeros(3 * target.v.shape[0])
86 |
87 | nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree(
88 | ).nearest(target.v, True)
89 | nearest_faces = nearest_faces.ravel().astype(np.int64)
90 | nearest_parts = nearest_parts.ravel().astype(np.int64)
91 | nearest_vertices = nearest_vertices.ravel()
92 |
93 | for i in range(target.v.shape[0]):
94 | # Closest triangle index
95 | f_id = nearest_faces[i]
96 | # Closest triangle vertex ids
97 | nearest_f = source.f[f_id]
98 |
99 | # Closest surface point
100 | nearest_v = nearest_vertices[3 * i:3 * i + 3]
101 | # Distance vector to the closest surface point
102 | dist_vec = target.v[i] - nearest_v
103 |
104 | rows[3 * i:3 * i + 3] = i * np.ones(3)
105 | cols[3 * i:3 * i + 3] = nearest_f
106 |
107 | n_id = nearest_parts[i]
108 | if n_id == 0:
109 | # Closest surface point in triangle
110 | A = np.vstack((source.v[nearest_f])).T
111 | coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v,
112 | rcond=-1)[0]
113 | elif n_id > 0 and n_id <= 3:
114 | # Closest surface point on edge
115 | A = np.vstack((source.v[nearest_f[n_id - 1]],
116 | source.v[nearest_f[n_id % 3]])).T
117 | tmp_coeffs = np.linalg.lstsq(A, target.v[i], rcond=-1)[0]
118 | coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0]
119 | coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1]
120 | else:
121 | # Closest surface point is a vertex
122 | coeffs_v[3 * i + n_id - 4] = 1.0
123 |
124 | matrix = sp.csc_matrix((coeffs_v, (rows, cols)),
125 | shape=(target.v.shape[0], source.v.shape[0]))
126 | return matrix
127 |
128 |
129 | def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None):
130 | """Return a simplified version of this mesh.
131 |
132 | A Qslim-style approach is used here.
133 |
134 | :param factor: fraction of the original vertices to retain
135 | :param n_verts_desired: number of the original vertices to retain
136 | :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix
137 | """
138 |
139 | if factor is None and n_verts_desired is None:
140 | raise Exception('Need either factor or n_verts_desired.')
141 |
142 | if n_verts_desired is None:
143 | n_verts_desired = math.ceil(len(mesh.v) * factor)
144 |
145 | Qv = vertex_quadrics(mesh)
146 |
147 | # fill out a sparse matrix indicating vertex-vertex adjacency
148 | # from psbody.mesh.topology.connectivity import get_vertices_per_edge
149 | vert_adj = get_vertices_per_edge(mesh.v, mesh.f)
150 | # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v)))
151 | # for f_idx in range(len(mesh.f)):
152 | # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1
153 |
154 | vert_adj = sp.csc_matrix(
155 | (vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])),
156 | shape=(len(mesh.v), len(mesh.v)))
157 | vert_adj = vert_adj + vert_adj.T
158 | vert_adj = vert_adj.tocoo()
159 |
160 | def collapse_cost(Qv, r, c, v):
161 | Qsum = Qv[r, :, :] + Qv[c, :, :]
162 | p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
163 | p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
164 |
165 | destroy_c_cost = p1.T.dot(Qsum).dot(p1)
166 | destroy_r_cost = p2.T.dot(Qsum).dot(p2)
167 | result = {
168 | 'destroy_c_cost': destroy_c_cost,
169 | 'destroy_r_cost': destroy_r_cost,
170 | 'collapse_cost': min([destroy_c_cost, destroy_r_cost]),
171 | 'Qsum': Qsum
172 | }
173 | return result
174 |
175 | # construct a queue of edges with costs
176 | queue = []
177 | for k in range(vert_adj.nnz):
178 | r = vert_adj.row[k]
179 | c = vert_adj.col[k]
180 |
181 | if r > c:
182 | continue
183 |
184 | cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost']
185 | heapq.heappush(queue, (cost, (r, c)))
186 |
187 | # decimate
188 | collapse_list = []
189 | nverts_total = len(mesh.v)
190 | faces = mesh.f.copy()
191 | while nverts_total > n_verts_desired:
192 | e = heapq.heappop(queue)
193 | r = e[1][0]
194 | c = e[1][1]
195 | if r == c:
196 | continue
197 |
198 | cost = collapse_cost(Qv, r, c, mesh.v)
199 | if cost['collapse_cost'] > e[0]:
200 | heapq.heappush(queue, (cost['collapse_cost'], e[1]))
201 | # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost'])
202 | continue
203 | else:
204 |
205 | # update old vert idxs to new one,
206 | # in queue and in face list
207 | if cost['destroy_c_cost'] < cost['destroy_r_cost']:
208 | to_destroy = c
209 | to_keep = r
210 | else:
211 | to_destroy = r
212 | to_keep = c
213 |
214 | collapse_list.append([to_keep, to_destroy])
215 |
216 | # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx
217 | np.place(faces, faces == to_destroy, to_keep)
218 |
219 | # same for queue
220 | which1 = [
221 | idx for idx in range(len(queue))
222 | if queue[idx][1][0] == to_destroy
223 | ]
224 | which2 = [
225 | idx for idx in range(len(queue))
226 | if queue[idx][1][1] == to_destroy
227 | ]
228 | for k in which1:
229 | queue[k] = (queue[k][0], (to_keep, queue[k][1][1]))
230 | for k in which2:
231 | queue[k] = (queue[k][0], (queue[k][1][0], to_keep))
232 |
233 | Qv[r, :, :] = cost['Qsum']
234 | Qv[c, :, :] = cost['Qsum']
235 |
236 | deg_01 = faces[:, 0] == faces[:, 1]
237 | deg_12 = faces[:, 1] == faces[:, 2]
238 | deg_20 = faces[:, 2] == faces[:, 0]  # renamed so the loop's edge index c is not shadowed
239 |
240 | # remove degenerate faces
241 | def logical_or3(x, y, z):
242 | return np.logical_or(x, np.logical_or(y, z))
243 |
244 | faces_to_keep = np.logical_not(logical_or3(deg_01, deg_12, deg_20))
245 | faces = faces[faces_to_keep, :].copy()
246 |
247 | nverts_total = (len(np.unique(faces.flatten())))
248 |
249 | new_faces, mtx = _get_sparse_transform(faces, len(mesh.v))
250 | return new_faces, mtx
251 |
252 |
253 | def _get_sparse_transform(faces, num_original_verts):
254 | verts_left = np.unique(faces.flatten())
255 | IS = np.arange(len(verts_left))
256 | JS = verts_left
257 | data = np.ones(len(JS))
258 |
259 | mp = np.arange(0, np.max(faces.flatten()) + 1)
260 | mp[JS] = IS
261 | new_faces = mp[faces.copy().flatten()].reshape((-1, 3))
262 |
263 | ij = np.vstack((IS.flatten(), JS.flatten()))
264 | mtx = sp.csc_matrix((data, ij),
265 | shape=(len(verts_left), num_original_verts))
266 |
267 | return (new_faces, mtx)
268 |
269 |
270 | def generate_transform_matrices(mesh, factors):
271 | """Generates len(factors) downsampled meshes, where mesh i + 1 keeps about 1/factors[i]
272 | of the vertices of mesh i, and computes the transformations between them.
273 |
274 | Returns:
275 | M: the input mesh followed by its len(factors) downsampled versions.
276 | A: Adjacency matrix for each of the meshes
277 | D: csc_matrix Downsampling transforms between each of the meshes
278 | U: Upsampling transforms between each of the meshes
279 | F: a list of faces
280 | """
281 |
282 | factors = map(lambda x: 1.0 / x, factors)
283 | M, A, D, U, F = [], [], [], [], []
284 | F.append(mesh.f) # F[0]
285 | A.append(get_vert_connectivity(mesh.v, mesh.f).astype('float32')) # A[0]
286 | M.append(mesh) # M[0]
287 |
288 | for factor in factors:
289 | ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor)
290 | D.append(ds_D.astype('float32'))
291 | new_mesh_v = ds_D.dot(M[-1].v)
292 | new_mesh = Mesh(v=new_mesh_v, f=ds_f)
293 | F.append(new_mesh.f)
294 | M.append(new_mesh)
295 | A.append(
296 | get_vert_connectivity(new_mesh.v, new_mesh.f).astype('float32'))
297 | U.append(setup_deformation_transfer(M[-1], M[-2]).astype('float32'))
298 |
299 | return M, A, D, U, F
300 |
--------------------------------------------------------------------------------
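The quadric error metric used by `vertex_quadrics` and `collapse_cost` can be checked on a single triangle. The sketch below is a minimal re-statement under our own assumptions (the helpers `face_quadric` and `quadric_error` are hypothetical, not part of the repository): it builds the plane quadric with the same SVD trick and shows that the error p^T Q p equals the squared distance from a point to the face plane. `vertex_quadrics` sums such quadrics over all faces touching a vertex, and `collapse_cost` evaluates the summed quadric Q_r + Q_c at each endpoint of an edge.

```python
import numpy as np


def face_quadric(tri):
    """Return the 4x4 plane quadric K = e e^T for a triangle given as a 3x3 array."""
    # Homogeneous coordinates: each row becomes [x, y, z, 1].
    verts = np.hstack((tri, np.ones((3, 1))))
    # The plane coefficients e = [a, b, c, d] span the null space of `verts`,
    # i.e. the right-singular vector of the smallest singular value.
    _, _, vt = np.linalg.svd(verts)
    eq = vt[-1, :].reshape(-1, 1)
    # Scale so that (a, b, c) is a unit normal; e^T p is then a signed distance.
    eq = eq / np.linalg.norm(eq[0:3])
    return eq @ eq.T


def quadric_error(Q, x):
    """Quadric error p^T Q p of a 3D point x, using homogeneous p = [x, 1]."""
    p = np.append(x, 1.0).reshape(-1, 1)
    return (p.T @ Q @ p).item()


# A triangle lying in the z = 0 plane.
tri = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
Q = face_quadric(tri)
on_plane = quadric_error(Q, np.array([0.3, 0.3, 0.0]))  # should be close to 0
above = quadric_error(Q, np.array([0.3, 0.3, 2.0]))     # should be close to 2**2
print(on_plane, above)
```

Because the error is a squared point-to-plane distance, edges in flat regions get a low collapse cost and are removed first by the decimator.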
/utils/read.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Data
3 | from torch_geometric.utils import to_undirected
4 | import openmesh as om
5 |
6 | def read_mesh(path, data_id, pose=None, return_face=False):
7 | mesh = om.read_trimesh(path)
8 | points = mesh.points()
9 | face = torch.from_numpy(mesh.face_vertex_indices()).T.type(torch.long)
10 |
11 | x = torch.tensor(points.astype('float32'))
12 | edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
13 | edge_index = to_undirected(edge_index)
14 | if return_face==True and pose is not None:
15 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id, pose=pose)
16 | if pose is not None:
17 | return Data(x=x, edge_index=edge_index, data_id=data_id, pose=pose)
18 | if return_face==True:
19 | return Data(x=x, edge_index=edge_index,face=face, data_id=data_id)
20 | return Data(x=x, edge_index=edge_index, data_id=data_id)
21 |
22 |
--------------------------------------------------------------------------------
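In `read_mesh`, `torch.cat([face[:2], face[1:], face[::2]], dim=1)` turns a (3, F) face tensor into the three edges (v0, v1), (v1, v2) and (v0, v2) of every triangle, and `to_undirected` then makes the edge set symmetric without duplicates. The sketch below reproduces that effect with plain PyTorch on two triangles sharing an edge. It assumes only symmetry and de-duplication matter; `faces_to_undirected_edges` is a hypothetical helper, and `torch.unique` stands in for the coalescing done by `torch_geometric.utils.to_undirected`.

```python
import torch


def faces_to_undirected_edges(face):
    """face: LongTensor of shape (3, F), one column per triangle."""
    # face[:2] -> (v0, v1), face[1:] -> (v1, v2), face[::2] -> (v0, v2)
    edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
    # Add reversed copies so every edge appears in both directions,
    # then drop the duplicates contributed by neighbouring triangles.
    both = torch.cat([edge_index, edge_index.flip([0])], dim=1)
    return torch.unique(both, dim=1)


# Two triangles, (0, 1, 2) and (1, 3, 2), sharing the edge (1, 2).
face = torch.tensor([[0, 1, 2], [1, 3, 2]], dtype=torch.long).T  # shape (3, 2)
edges = faces_to_undirected_edges(face)
print(edges.shape)
```

The shared edge (1, 2) is produced by both triangles but survives only once per direction, giving 5 undirected (10 directed) edges.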
/utils/train_eval.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os, sys
3 | import torch
4 | import torch.nn.functional as F
5 | from psbody.mesh import Mesh
6 | import numpy as np
7 | from models import ARAP
8 | import scipy.io as sio
9 | from tqdm import tqdm
10 |
11 | def run(model,
12 | train_loader, lat_vecs, optimizer, scheduler,
13 | test_loader, test_lat_vecs, optimizer_test, scheduler_test,
14 | epochs, writer, device, results_dir, data_mean, data_std,
15 | template_face, arap_weight=0.0, use_arap_epoch=800, nz_max=10,
16 | continue_train=False, checkpoint=None, test_checkpoint=None, dataset='DFaust'):
17 |
18 | start_epoch = 0
19 | if continue_train:
20 | start_epoch = writer.load_checkpoint(
21 | model, lat_vecs, optimizer, scheduler, checkpoint=checkpoint)
22 |
23 | if (optimizer_test is not None) and continue_train:
24 | test_epoch = writer.load_checkpoint(
25 | model, test_lat_vecs, optimizer_test, scheduler_test,
26 | checkpoint=test_checkpoint, test=True)
27 |
28 | for epoch in range(start_epoch+1, epochs):
29 | t = time.time()
30 | use_arap = epoch > use_arap_epoch
31 |
32 | train_loss, l1_loss, arap_loss, l2_error = \
33 | train(
34 | model, epoch, optimizer, train_loader, lat_vecs,
35 | device, results_dir, data_mean, data_std, template_face,
36 | arap_weight=arap_weight, nz_max=nz_max,
37 | use_arap=use_arap, lr=scheduler.get_last_lr()[0],
38 | dataset=dataset,
39 | )
40 |
41 | if optimizer_test is not None:
42 | for k in range(3):
43 | test_loss, test_l1_loss, test_arap_loss, test_l2_error = \
44 | train(
45 | model, epoch*3+k, optimizer_test, test_loader, test_lat_vecs,
46 | device, results_dir, data_mean, data_std, template_face,
47 | arap_weight=0.0, nz_max=nz_max, use_arap=False,
48 | exp_name='reconstruct', lr=scheduler_test.get_last_lr()[0],
49 | dataset=dataset,
50 | )
51 | test_info = {
52 | 'test_current_epoch': epoch*3+k,
53 | 'test_epochs': epochs*3,
54 | 'test_loss': test_loss,
55 | 'l1_loss': test_l1_loss,
56 | 'arap_loss': test_arap_loss,
57 | 'mse_error': test_l2_error,
58 | 't_duration': 0.0,
59 | 'lr': scheduler_test.get_last_lr()[0]
60 | }
61 | scheduler_test.step()
62 | writer.print_info_test(test_info)
63 | if epoch % 10 == 0:
64 | writer.save_checkpoint(
65 | model, test_lat_vecs, optimizer_test, scheduler_test,
66 | epoch*3, test=True
67 | )
68 |
69 | t_duration = time.time() - t
70 | scheduler.step()
71 | info = {
72 | 'current_epoch': epoch,
73 | 'epochs': epochs,
74 | 'train_loss': train_loss,
75 | 'l1_loss': l1_loss,
76 | 'arap_loss': arap_loss,
77 | 'mse_error': l2_error,
78 | 't_duration': t_duration,
79 | 'lr': scheduler.get_last_lr()[0]
80 | }
81 |
82 | writer.print_info(info)
83 | if epoch % 10 == 0:
84 | writer.save_checkpoint(model, lat_vecs, optimizer, scheduler, epoch)
85 |
86 |
87 | def train(model, epoch, optimizer, loader, lat_vecs, device,
88 | results_dir,data_mean, data_std, template_face,
89 | arap_weight=0.0, nz_max=10, use_arap=False, dump=False,
90 | exp_name='train', lr=0.0, dataset='DFaust'):
91 |
92 | model.train()
93 | total_loss = 0
94 | total_l1_loss = 0
95 | total_arap_loss = 0
96 | l2_errors = []
97 |
98 | data_mean_gpu = torch.as_tensor(data_mean, dtype=torch.float).to(device)
99 | data_std_gpu = torch.as_tensor(data_std, dtype=torch.float).to(device)
100 |
101 | model.train()
102 | pbar = tqdm(total=len(loader),
103 | desc=f'{exp_name} {epoch}')
104 |
105 | # ARAP
106 | arap = ARAP(template_face, template_face.max()+1).to(device)
107 | mse = [0, 0]
108 |
109 | for b, data in enumerate(loader):
110 | optimizer.zero_grad()
111 | x = data.x.to(device)
112 | #x[:] = loader.dataset[4044].x
113 | ids = data.data_id.to(device)
114 | #ids[:] = 4044
115 | batch_vecs = lat_vecs(ids.view(-1))
116 | out = model(batch_vecs)
117 |
118 | pred_shape = out * data_std_gpu + data_mean_gpu
119 | gt_shape = x * data_std_gpu + data_mean_gpu
120 | if dataset=='SMAL':
121 | l1_loss = F.l1_loss(x, out, reduction='mean')
122 | arap_ep = 1e-3
123 | else:
124 | l1_loss = F.l1_loss(pred_shape, gt_shape, reduction='mean')
125 | arap_ep = 1e-1
126 | tmp_error = torch.sqrt(torch.sum((pred_shape - gt_shape)**2,dim=2)).detach().cpu()  # per-vertex Euclidean distance; logged as "MSE" but not squared
127 | mse[1] += tmp_error.view(-1).shape[0]
128 | mse[0] += tmp_error.sum().item()
129 | l2_errors.append(tmp_error)
130 |
131 | loss = torch.zeros(1).to(device)
132 | loss += l1_loss
133 |
134 | if use_arap and arap_weight>0:
135 |
136 | jacob = get_jacobian_rand(
137 | pred_shape, batch_vecs, data_mean_gpu, data_std_gpu,
138 | model, device,
139 | epsilon=arap_ep,
140 | nz_max=nz_max)
141 | arap_energy = arap(pred_shape, jacob,) / jacob.shape[-1]
142 | total_arap_loss += arap_energy.item()
143 | loss += arap_weight*arap_energy
144 |
145 | loss.backward()
146 | total_loss += loss.item()
147 | total_l1_loss += l1_loss.item()
148 | pbar.set_postfix({
149 | 'loss': total_loss / (b+1.0),
150 | 'arap': total_arap_loss / (b+1.0),
151 | 'w_arap': arap_weight,
152 | 'MSE': mse[0] / (mse[1]+1e-6) * 1000,
153 | 'lr': lr,
154 | })
155 | pbar.update(1)
156 | optimizer.step()
157 | new_errors = torch.cat(l2_errors, dim=0)
158 | mean_error = new_errors.view((-1, )).mean()
159 |
160 | # visualize
161 | #if epoch%20==0 and dump==False:
162 | # gt_meshes = x.detach().cpu().numpy()
163 | # pred_meshes = out.detach().cpu().numpy()
164 | # for b_i in range(min(2, gt_meshes.shape[0])):
165 | # pred_v = pred_meshes[b_i].reshape((-1, 3))*data_std + data_mean
166 | # gt_v = gt_meshes[b_i].reshape((-1, 3))*data_std + data_mean
167 | # mesh = Mesh(v=pred_v, f=template_face)
168 | # mesh.write_ply(os.path.join(results_dir, '%d_%d'%(epoch, b_i)+'_pred.ply'))
169 | # mesh = Mesh(v=gt_v, f=template_face)
170 | # mesh.write_ply(os.path.join(results_dir, '%d_%d'%(epoch, b_i)+'_gt.ply'))
171 | #if dump==True and epoch%200==1:
172 | # model.eval()
173 | # with torch.no_grad():
174 | # for data in loader:
175 | # x = data.x.to(device)
176 | # ids = data.data_id.to(device)
177 | # batch_vecs = lat_vecs(ids.view(-1))
178 | # out = model(batch_vecs)
179 | # pred_shape = out * data_std_gpu + data_mean_gpu
180 | # gt_shape = x * data_std_gpu + data_mean_gpu
181 | # gt_meshes = gt_shape.detach().cpu().numpy()
182 | # pred_meshes = pred_shape.detach().cpu().numpy()
183 | # ids_np = data.data_id.cpu().numpy()
184 |
185 | # for b_i in range(gt_meshes.shape[0]):
186 | # pred_v = pred_meshes[b_i].reshape((-1, 3))
187 | # gt_v = gt_meshes[b_i].reshape((-1 , 3))
188 | # mesh = Mesh(v=pred_v, f=template_face)
189 | # mesh.write_ply(os.path.join(results_dir, '%d'%(ids_np[b_i])+'_pred.ply'))
190 | # mesh = Mesh(v=gt_v, f=template_face)
191 | # mesh.write_ply(os.path.join(results_dir, '%d'%(ids_np[b_i])+'_gt.ply'))
192 |
193 | return total_loss / len(loader), total_l1_loss / len(loader), total_arap_loss / len(loader), mean_error
194 |
195 | def get_jacobian(out, z, data_mean_gpu, data_std_gpu, model, device, epsilon=1e-3):
196 | nb, nz = z.size()
197 | _, n_vert, nc = out.size()
198 | jacobian = torch.zeros((nb, n_vert*nc, nz)).to(device)
199 |
200 | for i in range(nz):
201 | dz = torch.zeros(z.size()).to(device)
202 | dz[:, i] = epsilon
203 | z_new = z + dz
204 | out_new = model(z_new)
205 | dout = (out_new - out).view(nb, -1)
206 | jacobian[:, :, i] = dout/epsilon
207 |
208 | data_std_gpu = data_std_gpu.reshape((1, n_vert*nc, 1))
209 | jacobian = jacobian*data_std_gpu
210 | return jacobian
211 |
212 | def get_jacobian_rand(cur_shape, z, data_mean_gpu, data_std_gpu, model, device, epsilon=1e-3, nz_max=10):
213 | nb, nz = z.size()
214 | _, n_vert, nc = cur_shape.size()
215 | if nz >= nz_max:
216 | rand_idx = np.random.permutation(nz)[:nz_max]
217 | nz = nz_max
218 | else:
219 | rand_idx = np.arange(nz)
220 |
221 | jacobian = torch.zeros((nb, n_vert*nc, nz)).to(device)
222 | for i, idx in enumerate(rand_idx):
223 | dz = torch.zeros(z.size()).to(device)
224 | dz[:, idx] = epsilon
225 | z_new = z + dz
226 | out_new = model(z_new)
227 | shape_new = out_new * data_std_gpu + data_mean_gpu
228 | dout = (shape_new - cur_shape).view(nb, -1)
229 | jacobian[:, :, i] = dout/epsilon
230 | return jacobian
231 |
232 |
233 | def test_reconstruct(
234 | model, test_loader, test_lat_vecs, epochs, test_optimizer, scheduler,
235 | writer, device, results_dir, data_mean, data_std, template_face,
236 | checkpoint=None, test_checkpoint=None, dataset='DFaust'):
237 | # load model
238 | start_epoch = writer.load_checkpoint(model, checkpoint=checkpoint)
239 | if test_checkpoint is not None:
240 | start_epoch = writer.load_checkpoint(
241 | model, test_lat_vecs, test_optimizer, scheduler,
242 | checkpoint=test_checkpoint, test=True)
243 |
244 | for epoch in range(1, epochs + 1):
245 | t = time.time()
246 |
247 | test_loss, l1_loss, arap_loss, l2_error = \
248 | train(model, epoch, test_optimizer, test_loader,
249 | test_lat_vecs, device, results_dir, data_mean,
250 | data_std, template_face, arap_weight=0.0, use_arap=False,
251 | dump=True, exp_name='reconstruct', lr=scheduler.get_last_lr()[0],
252 | dataset=dataset,
253 | )
254 |
255 | t_duration = time.time() - t
256 | scheduler.step()
257 | info = {
258 | 'test_current_epoch': epoch,
259 | 'test_epochs': epochs,
260 | 'test_loss': test_loss,
261 | 'l1_loss': l1_loss,
262 | 'arap_loss':arap_loss,
263 | 'mse_error':l2_error,
264 | 't_duration': t_duration,
265 | 'lr': scheduler.get_last_lr()[0]
266 | }
267 | if epoch % 200 == 1:
268 | writer.save_checkpoint(model, test_lat_vecs, test_optimizer,
269 | scheduler, epoch, test=True)
270 | writer.print_info_test(info)
271 | writer.save_checkpoint(
272 | model, test_lat_vecs, test_optimizer,
273 | scheduler, epoch, test=True)
274 |
275 | def global_interpolate(
276 | model, lat_vecs, optimizer, scheduler, writer,
277 | device, results_dir, data_mean, data_std,
278 | template_face, interpolate_num):
279 | from scipy.sparse import csr_matrix
280 | from scipy.sparse.csgraph import minimum_spanning_tree
281 | import numpy as np
282 |
283 | load_epoch = writer.load_checkpoint(model, lat_vecs, optimizer, scheduler,)
284 |
285 | lat_vecs_np = lat_vecs.weight.data.detach().cpu().numpy()
286 | results_dir = os.path.join(results_dir, "interpolate/epoch%d"%(load_epoch))
287 | if not os.path.exists(results_dir):
288 | os.makedirs(results_dir)
289 |
290 | min_path = np.random.randint(0, lat_vecs_np.shape[0], size=50)
291 | print('min_path', min_path)
292 |
293 | for p in range(len(min_path)-1):
294 | start_vec = lat_vecs.weight.data[min_path[p]]
295 | end_vec = lat_vecs.weight.data[min_path[p+1]]
296 |
297 | for i in range(interpolate_num+1):
298 | vec = start_vec + i*(end_vec - start_vec)/interpolate_num
299 | out = model(vec)
300 | out_numpy = out.detach().cpu().numpy()
301 | out_v = out_numpy.reshape((-1, 3))*data_std + data_mean
302 |
303 | mesh = Mesh(v=out_v, f=template_face)
304 | mesh.write_obj(os.path.join(results_dir, '%d_%d_%d-%d'%(p, i, min_path[p], min_path[p+1])+'.obj'))
305 |
306 |
307 | def extrapolation(model, lat_vecs, optimizer, scheduler, writer, device, results_dir, data_mean,
308 | data_std, template_face, extra_num=5, extra_thres=0.2):
309 |
310 | load_epoch = writer.load_checkpoint(model, lat_vecs, optimizer, scheduler,test=True)
311 |
312 | lat_vecs_np = lat_vecs.weight.data.detach().cpu().numpy()
313 | print('lat_vecs_np', lat_vecs_np)
314 | results_dir = os.path.join(results_dir, "extrapolation/epoch%d"%(load_epoch))
315 | if not os.path.exists(results_dir):
316 | os.makedirs(results_dir)
317 |
318 | z_dict = {}
319 | num_train_scenes = lat_vecs_np.shape[0]
320 |
321 | z_path = np.random.randint(0, num_train_scenes, size=30)
322 | print(z_path)
323 | for p in range(len(z_path)-1):
324 | or_vec = lat_vecs.weight.data[z_path[p]]
325 | extra_z = torch.zeros((extra_num+1, lat_vecs_np.shape[1])).to(device)
326 |
327 | for i in range(extra_num):
328 | vec = or_vec + extra_thres*(torch.rand(lat_vecs_np.shape[1]).to(device)-0.5)
329 | extra_z[i+1] = vec
330 | extra_z[0] = or_vec
331 | out = model(extra_z)
332 | out_numpy = out.detach().cpu().numpy()
333 | for i in range(extra_num + 1):  # row 0 is the original code, rows 1..extra_num are perturbed
334 | out_v = out_numpy[i].reshape((-1, 3))*data_std + data_mean
335 | mesh = Mesh(v=out_v, f=template_face)
336 | mesh.write_ply(os.path.join(results_dir, '%d_%d'%(z_path[p],i)+'.ply'))
337 | z_dict['%d_%d'%( z_path[p], i)] = extra_z[i].detach().cpu().numpy()
338 | np.save(os.path.join(results_dir, "name_latentz.npy"), z_dict)
339 |
340 |
341 |
--------------------------------------------------------------------------------
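`get_jacobian_rand` approximates the decoder Jacobian with forward differences: it perturbs up to `nz_max` randomly chosen latent coordinates by `epsilon` and divides the change in the decoded shape by `epsilon`. The sketch below restates that idea with a toy linear decoder, for which the exact Jacobian column of latent dimension k is the row `W[k]`, so the approximation can be checked directly. The decoder, the shapes, the helper name `fd_jacobian_subset` and the seeded `np.random.default_rng` generator are assumptions made for this example; the repository uses `np.random.permutation` and also maps the output back through `data_std`/`data_mean`.

```python
import numpy as np
import torch


def fd_jacobian_subset(decoder, z, n_vert, epsilon=1e-3, nz_max=10, seed=0):
    """Forward-difference Jacobian of decoder(z) w.r.t. a random subset of latent dims.

    decoder maps (nb, nz) latent codes to (nb, n_vert, 3) vertex positions.
    Returns the Jacobian of shape (nb, n_vert * 3, k) and the k chosen indices.
    """
    nb, nz = z.shape
    rng = np.random.default_rng(seed)
    idx = rng.permutation(nz)[:nz_max] if nz >= nz_max else np.arange(nz)
    base = decoder(z).reshape(nb, -1)  # flattened shape (nb, n_vert * 3)
    jac = torch.zeros(nb, n_vert * 3, len(idx), dtype=z.dtype)
    for i, k in enumerate(idx.tolist()):
        dz = torch.zeros_like(z)
        dz[:, k] = epsilon  # nudge a single latent coordinate
        jac[:, :, i] = (decoder(z + dz).reshape(nb, -1) - base) / epsilon
    return jac, idx


# Toy linear decoder: the exact Jacobian column for latent dim k is W[k].
torch.manual_seed(0)
nb, nz, n_vert = 2, 16, 5
W = torch.randn(nz, n_vert * 3, dtype=torch.float64)


def toy_decoder(codes):
    # Linear map (nb, nz) -> (nb, n_vert, 3)
    return (codes @ W).reshape(codes.shape[0], n_vert, 3)


z = torch.randn(nb, nz, dtype=torch.float64)
jac, idx = fd_jacobian_subset(toy_decoder, z, n_vert, nz_max=4)
print(jac.shape, idx)
```

Sampling only `nz_max` columns costs `nz_max` extra decoder passes per batch instead of one per latent dimension. Since the perturbed passes in `train` are not wrapped in `torch.no_grad()`, the ARAP energy computed from this Jacobian can still be back-propagated into the model.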
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from glob import glob
5 | from scipy.spatial import cKDTree as KDTree
6 | from matplotlib import cm
7 |
8 |
9 |
10 | def makedirs(folder):
11 | if not os.path.exists(folder):
12 | os.makedirs(folder)
13 |
14 |
15 | def to_sparse(spmat):
16 | coo = spmat.tocoo()  # convert once instead of four times
17 | indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
18 | return torch.sparse.FloatTensor(
19 | indices, torch.FloatTensor(coo.data), torch.Size(coo.shape))
20 |
21 |
22 | def to_edge_index(mat):
23 | return torch.LongTensor(np.vstack(mat.nonzero()))
24 |
25 | def get_colors_from_diff_pc(diff_pc, min_error, max_error):
26 | colors = np.zeros((diff_pc.shape[0],3))
27 | mix = (diff_pc-min_error)/(max_error-min_error)
28 | mix = np.clip(mix, 0,1) #point_num
29 | cmap=cm.get_cmap('coolwarm')
30 | colors = cmap(mix)[:,0:3]
31 | return colors
32 |
33 |
34 | def save_pc_with_color_into_ply(template_ply, pc, color, fn):
35 | plydata=template_ply
36 | #pc = pc.copy()*pc_std + pc_mean
37 | plydata['vertex']['x']=pc[:,0]
38 | plydata['vertex']['y']=pc[:,1]
39 | plydata['vertex']['z']=pc[:,2]
40 |
41 | plydata['vertex']['red']=color[:,0]
42 | plydata['vertex']['green']=color[:,1]
43 | plydata['vertex']['blue']=color[:,2]
44 |
45 | plydata.write(fn)
46 | plydata['vertex']['red']=plydata['vertex']['red']*0+0.7*255
47 | plydata['vertex']['green']=plydata['vertex']['red']*0+0.7*255
48 | plydata['vertex']['blue']=plydata['vertex']['red']*0+0.7*255
49 |
50 | def compute_trimesh_chamfer(gt_points, gen_points):
51 | """
52 | Computes a symmetric chamfer distance, i.e. the sum of both one-directional chamfers,
53 | each being the mean squared nearest-neighbour distance.
54 | gt_points: torch.Tensor of shape (N, 3) with ground-truth points sampled from the
55 | surface.
56 | gen_points: torch.Tensor of shape (M, 3) with points of the reconstructed shape.
57 | """
58 | # only need numpy array of points
59 | # gt_points_np = gt_points.vertices
60 | gt_points_np = gt_points.detach().cpu().numpy()
61 | gen_points_sampled = gen_points.detach().cpu().numpy()
62 |
63 | # one direction
64 | gen_points_kd_tree = KDTree(gen_points_sampled)
65 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points_np)
66 | gt_to_gen_chamfer = np.mean(np.square(one_distances))
67 |
68 | # other direction
69 | gt_points_kd_tree = KDTree(gt_points_np)
70 | two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points_sampled)
71 | gen_to_gt_chamfer = np.mean(np.square(two_distances))
72 |
73 | return gt_to_gen_chamfer + gen_to_gt_chamfer
74 |
--------------------------------------------------------------------------------
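For reference, `compute_trimesh_chamfer` returns the sum of two one-directional terms, each the mean of squared nearest-neighbour distances found with a KD-tree. Writing P for the ground-truth points and Q for the generated points (notation introduced here, not in the code), the returned value is:

```latex
d_{\mathrm{CD}}(P, Q) =
  \frac{1}{|P|} \sum_{p \in P} \min_{q \in Q} \lVert p - q \rVert_2^2
  + \frac{1}{|Q|} \sum_{q \in Q} \min_{p \in P} \lVert q - p \rVert_2^2
```

No square root is taken, so the result is in squared length units.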
/utils/writer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import json
5 | from glob import glob
6 |
7 |
8 | class Writer:
9 | def __init__(self, args=None):
10 | self.args = args
11 |
12 | if self.args is not None:
13 | tmp_log_list = glob(os.path.join(args.out_dir, 'log*'))
14 |
15 | self.test_log_file = os.path.join(
16 | args.out_dir, 'test_log_{:s}.txt'.format(
17 | time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())))
18 | if len(tmp_log_list) == 0:
19 | self.log_file = os.path.join(
20 | args.out_dir, 'log_{:s}.txt'.format(
21 | time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())))
22 | else:
23 | self.log_file = tmp_log_list[0]
24 | message = '{}'.format(self.args)
25 | if args.mode=='test':
26 | with open(self.test_log_file, 'a') as log_file:
27 | log_file.write('{:s}\n'.format(message))
28 | else:
29 | with open(self.log_file, 'a') as log_file:
30 | log_file.write('{:s}\n'.format(message))
31 |
32 |
33 | def print_info(self, info):
34 | message = 'Epoch: {}/{}, Time: {:.3f}s, Train Loss: {:.5f}, L1: {:.5f}, arap: {:.5f}, MSE: {:.5f}, lr: {:.6f}' \
35 | .format(info['current_epoch'], info['epochs'], info['t_duration'], \
36 | info['train_loss'], info['l1_loss'], info['arap_loss'], info['mse_error'], info['lr'])
37 |
38 | with open(self.log_file, 'a') as log_file:
39 | log_file.write('{:s}\n'.format(message))
40 | #print(message)
41 |
42 | def print_info_test(self, info):
43 | message = '[test] Epoch: {}/{}, Time: {:.3f}s, Loss: {:.5f}, L1: {:.5f}, arap: {:.5f}, MSE: {:.6f}, lr: {:.6f}' \
44 | .format(info['test_current_epoch'], info['test_epochs'], info['t_duration'], \
45 | info['test_loss'], info['l1_loss'], info['arap_loss'],info['mse_error'],info['lr'])
46 |
47 | with open(self.test_log_file, 'a') as log_file:
48 | log_file.write('{:s}\n'.format(message))
49 | #print(message)
50 |
51 | def save_checkpoint(self, model, latent_vecs, optimizer, scheduler,
52 | epoch, test=False):
53 | model_path = self.args.checkpoints_dir
54 | if test==True:
55 | model_path = self.args.checkpoints_dir_test
56 | print('save test checkpoint')
57 |
58 | torch.save(
59 | {
60 | 'epoch': epoch,
61 | 'model_state_dict': model.state_dict(),
62 | 'optimizer_state_dict': optimizer.state_dict(),
63 | 'scheduler_state_dict': scheduler.state_dict(),
64 | 'train_latent_vecs': latent_vecs.state_dict(),
65 | },
66 | os.path.join(
67 | model_path, 'checkpoint_{:04d}.pt'.format(epoch))
68 | )
69 |
70 | def load_checkpoint(self, model, latent_vecs=None, optimizer=None,
71 | scheduler=None, test=False, checkpoint=None):
72 | model_path = self.args.checkpoints_dir
73 | if test==True:
74 | model_path = self.args.checkpoints_dir_test
75 |
76 | if checkpoint is None:
77 | checkpoint_list = glob(os.path.join(model_path, 'checkpoint_*'))
78 | if len(checkpoint_list)==0:
79 | print('train from scratch')
80 | return 0
81 | latest_checkpoint = sorted(checkpoint_list)[-1]
82 | else:
83 | latest_checkpoint = checkpoint
84 | print("loading model from ", latest_checkpoint)
85 | data = torch.load(latest_checkpoint)
86 | model.load_state_dict(data["model_state_dict"])
87 | if latent_vecs is not None:
88 | latent_vecs.load_state_dict(data["train_latent_vecs"])
89 | if scheduler is not None:
90 | scheduler.load_state_dict(data["scheduler_state_dict"])
91 | if optimizer is not None:
92 | optimizer.load_state_dict(data["optimizer_state_dict"])
93 | print("loaded!")
94 | return data["epoch"]
95 |
--------------------------------------------------------------------------------
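`Writer.load_checkpoint` picks the most recent file with `sorted(checkpoint_list)[-1]`, which relies on the zero-padded names written by `save_checkpoint` (`checkpoint_{:04d}.pt`). Once an epoch number reaches 10000 the name gains a fifth digit and lexicographic order no longer matches epoch order; this can happen for test checkpoints, which are saved under `epoch*3`. A possible alternative is sketched below, assuming file names keep the `checkpoint_<epoch>.pt` pattern; `latest_checkpoint` is a hypothetical helper, not part of the repository. It parses the epoch number and compares it numerically.

```python
import os
import re


def latest_checkpoint(paths):
    """Return the path whose checkpoint_<epoch>.pt name has the largest epoch, or None."""
    pattern = re.compile(r'checkpoint_(\d+)\.pt$')
    numbered = []
    for p in paths:
        m = pattern.search(os.path.basename(p))
        if m:
            numbered.append((int(m.group(1)), p))
    return max(numbered)[1] if numbered else None


names = ['ckpt/checkpoint_0990.pt', 'ckpt/checkpoint_10000.pt', 'ckpt/checkpoint_9990.pt']
print(sorted(names)[-1])         # lexicographic pick: ckpt/checkpoint_9990.pt
print(latest_checkpoint(names))  # numeric pick: ckpt/checkpoint_10000.pt
```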
/work_dir/SMAL/out/move_smal_ckpt.zip_here.txt:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/work_dir/move_bone_ckpt.zip_here.txt:
--------------------------------------------------------------------------------
1 | hh
2 |
--------------------------------------------------------------------------------