├── .gitignore
├── ES.py
├── LICENSE
├── README.md
├── checkpoints
│   ├── classifier
│   │   └── model_00160000.ckpt
│   └── generator
│       └── model_00368000.ckpt
├── evaluate_poses.py
├── imgs
│   ├── dataset.png
│   └── demo.gif
├── inference.py
├── pose_check
│   ├── data_loader
│   │   └── pose_check_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── pointnet_utils.py
│   │   ├── uninet.py
│   │   └── uninet_mt.py
│   ├── train_multi_task_var_impl.py
│   └── utils
│       └── utils.py
├── pose_generation
│   ├── data_loader
│   │   └── stable_pose_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── pointnet_utils.py
│   │   ├── uninet.py
│   │   └── vnet.py
│   ├── train_impl.py
│   └── utils
│       └── utils.py
├── real_data
│   ├── plys
│   │   ├── bowl1.ply
│   │   ├── pc_inworld_object_1.ply
│   │   ├── pc_inworld_object_2.ply
│   │   ├── pc_inworld_object_3.ply
│   │   ├── pc_inworld_support_2.ply
│   │   └── pc_inworld_support_3.ply
│   └── real_data.txt
├── requirements.txt
└── scripts
    ├── evaluate_testset.py
    ├── test_real_data.sh
    ├── train_multi_task.sh
    └── train_pose_generation.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | regrasp.pem
4 | *.py[cod]
5 | *$py.class
6 | .DS_Store
7 | .idea*
8 | display
9 | # C extensions
10 | *.so
11 | *.zip
12 | debug
13 | */*/*/debug_gradient.py
14 | model_training/dataset
15 | *debug
16 | test_results
17 | # Distribution / packaging
18 | .Python
19 | #checkpoints
20 | model_training/checkpoints/*/events.*
21 | */*/imgs
22 | Test_Results
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 | 
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 | 
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 | 
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 | 
66 | # Translations
67 | *.mo
68 | *.pot
69 | 
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 | 
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 | 
80 | # Scrapy stuff:
81 | .scrapy
82 | 
83 | # Sphinx documentation
84 | docs/_build/
85 | 
86 | # PyBuilder
87 | target/
88 | 
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 | 
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 | 
96 | # pyenv
97 | .python-version
98 | 
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 | 
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 | 
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 | 
113 | # SageMath parsed files
114 | *.sage.py
115 | 
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 | 
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 | 
129 | # Rope project settings
130 | .ropeproject
131 | 
132 | # mkdocs documentation
133 | /site
134 | 
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 | 
140 | # Pyre type checker
141 | .pyre/
142 | noise_*
143 | 
144 | dataset
145 | test_results
--------------------------------------------------------------------------------
/ES.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import pytorch3d.transforms as torch_transform
5 | from pose_check.utils.utils import dict2cuda
6 | from torch.distributions import MultivariateNormal
7 | from sklearn.cluster import AgglomerativeClustering
8 | 
9 | class CEM():
10 | 
11 |     """
12 |     Cross-entropy method (CEM), adapted to PyTorch.
13 |     """
14 | 
15 |     def __init__(self,
16 |                  num_params,
17 |                  batch_size,
18 |                  pop_size,
19 |                  parents,
20 |                  mu_init=None,
21 |                  sigma_init=1e-3,
22 |                  clip=0.1,
23 |                  damp=0.1,
24 |                  damp_limit=1e-5,
25 |                  elitism=True,
26 |                  device=torch.device('cuda')
27 |                  ):
28 | 
29 |         # misc
30 |         self.num_params = num_params
31 |         self.batch_size = batch_size
32 |         self.device = device
33 |         # distribution parameters
34 |         if mu_init is None:
35 |             self.mu = torch.zeros([self.batch_size, self.num_params], device=device)
36 |         else:
37 |             self.mu = mu_init.clone()
38 |         self.sigma = sigma_init
39 |         self.damp = damp
40 |         self.damp_limit = damp_limit
41 |         self.tau = 0.95
42 |         self.cov = self.sigma * torch.ones([self.batch_size, self.num_params], device=device)
43 |         self.clip = clip
44 | 
45 |         # elite stuff
46 |         self.elitism = elitism
47 |         self.elite = torch.sqrt(torch.tensor(self.sigma, device=device)) * torch.rand(self.batch_size, self.num_params, device=device)
48 |         self.elite_score = None
49 | 
50 |         # sampling stuff
51 |         self.pop_size = pop_size
52 |         if parents is None or parents <= 0:
53 |             self.parents = pop_size // 2
54 |         else:
55 |             self.parents = parents
56 |         self.weights = torch.FloatTensor([np.log((self.parents + 1) / i)
57 |                                           for i in range(1, self.parents + 1)]).to(device)
58 |         self.weights /= self.weights.sum()
59 | 
60 |     def ask(self, pop_size):
61 |         """
62 |         Returns a batch of candidate parameters
63 |         """
64 |         epsilon = torch.randn(self.batch_size, pop_size, self.num_params, device=self.device)
65 |         inds = self.mu.unsqueeze(1) + (epsilon * torch.sqrt(self.cov).unsqueeze(1)).clamp(-self.clip, self.clip)
66 |         if self.elitism:
67 |             inds[:, -1] = self.elite
68 |         return inds
69 | 
70 |     def tell(self, solutions, scores):
71 |         """
72 |         Updates the distribution
73 |         and returns the best solution
74 |         :param solutions: (B, N, 6) 6d representation of transforms
75 |         :param scores: (B, N)
76 |         :return top_solution: (B, 6)
77 |         """
78 |         assert len(scores.shape) == 2
79 | 
80 |         sorted_scores, idx_sorted = torch.sort(scores, dim=1, descending=True)
81 | 
82 |         old_mu = self.mu.clone()
83 |         self.damp = self.damp * self.tau + (1 - self.tau) * self.damp_limit
84 |         idx_sorted = idx_sorted[:, :self.parents]
85 |         top_solutions = torch.gather(solutions, 1, idx_sorted.unsqueeze(2).expand(*idx_sorted.shape, solutions.shape[-1]))
86 |         self.mu = self.weights @ top_solutions
87 |         z = top_solutions - old_mu.unsqueeze(1)
88 |         self.cov = 1 / self.parents * self.weights @ (
89 |                 z * z) + self.damp * torch.ones([self.batch_size, self.num_params], device=self.device)
90 | 
91 |         self.elite = top_solutions[:, 0, :]
92 |         # self.elite_score = scores[:, idx_sorted[0]]
93 | 
94 |         return top_solutions[:, 0, :], sorted_scores[:, 0]
95 | 
96 |     def get_distrib_params(self):
97 |         """
98 |         Returns the parameters of the distribution:
99 |         the mean and the covariance
100 |         """
101 |         return self.mu.clone(), self.cov.clone()
102 | 
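For reference, a minimal sketch of how this ask/tell interface is driven (the real driver is `Searcher.search` below). The toy quadratic objective here is not part of the repo; it merely stands in for the learned stability critic, and the run is pinned to CPU so no GPU is required:

```python
import torch

# assumes CEM from ES.py is in scope (e.g. `from ES import CEM`)
cem = CEM(num_params=6, batch_size=2, pop_size=32, parents=8,
          device=torch.device('cpu'))

target = torch.zeros(2, 6)  # hypothetical "ideal" 6-DoF pose vectors
for _ in range(5):
    candidates = cem.ask(32)                                     # (2, 32, 6)
    scores = -((candidates - target.unsqueeze(1)) ** 2).sum(-1)  # (2, 32), higher is better
    best, best_score = cem.tell(candidates, scores)              # (2, 6), (2,)
```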
103 | class Searcher():
104 |     def __init__(self,
105 |                  action_dim,
106 |                  pop_size=25,
107 |                  parents=5,
108 |                  sigma_init=1e-3,
109 |                  clip=0.1,
110 |                  damp=0.1,
111 |                  damp_limit=0.05,
112 |                  device=torch.device('cuda')):
113 | 
114 |         self.sigma_init = sigma_init
115 |         self.clip = clip
116 |         self.pop_size = pop_size
117 |         self.damp = damp
118 |         self.damp_limit = damp_limit
119 |         self.parents = parents
120 |         self.action_dim = action_dim
121 |         self.device = device
122 | 
123 |     def search(self, action_init, support_ply, object_ply, critic, n_iter=3):
124 |         '''
125 | 
126 |         :param action_init: (B, 4, 4) initial object transforms
127 |         :param support_ply: (N, 3) support point cloud
128 |         :param object_ply: (N, 3) object point cloud
129 |         :param critic: callable that scores candidate 6-DoF poses
130 |         :return: refined (B, 4, 4) transforms and their (B,) scores
131 |         '''
132 |         batch_size = action_init.shape[0]
133 |         action_init = matrix2vectors(action_init) # (B, 6)
134 | 
135 |         cem = CEM(num_params=self.action_dim,
136 |                   batch_size=batch_size,
137 |                   pop_size=self.pop_size,
138 |                   parents=self.parents,
139 |                   mu_init=action_init,
140 |                   sigma_init=self.sigma_init,
141 |                   clip=self.clip,
142 |                   damp=self.damp,
143 |                   damp_limit=self.damp_limit,
144 |                   elitism=True,
145 |                   device=self.device
146 |                   )
147 | 
148 |         best_actions = None
149 |         best_scores = None
150 | 
151 |         with torch.no_grad():
152 |             for iter in range(n_iter):
153 |                 actions = cem.ask(self.pop_size)
154 |                 Qs = critic(tr6d=actions.view(self.pop_size * batch_size, -1),
155 |                             support_ply=support_ply,
156 |                             object_ply=object_ply).view(batch_size, self.pop_size)
157 |                 good_actions, good_scores = cem.tell(actions, Qs)
158 | 
159 |                 if best_scores is None:
160 |                     best_actions = good_actions
161 |                     best_scores = good_scores
162 |                 else:
163 |                     action_index = (best_scores < good_scores).squeeze()
164 |                     best_actions[action_index] = good_actions[action_index]
165 |                     print('before assign: ', best_scores)
166 |                     print('good scores: ', good_scores)
167 |                     best_scores = torch.max(best_scores, good_scores)
168 |                     print('after max: ', best_scores)
169 | 
170 |                 if iter == n_iter - 1:
171 |                     best_actions = vectors2matrix(best_actions) # (B, 4, 4)
172 |         return best_actions, best_scores
173 | 
174 | def calc_min_dist(p_a, p_b):
175 |     '''
176 | 
177 |     :param p_a: (n, 3)
178 |     :param p_b: (m, 3)
179 |     :return: (n, m) pairwise Euclidean distances
180 |     '''
181 |     aa = np.sum(p_a ** 2, axis=1, keepdims=False)
182 |     bb = np.sum(p_b ** 2, axis=1, keepdims=False)
183 |     n = p_a.shape[0]
184 |     m = p_b.shape[0]
185 | 
186 |     a_ = np.reshape(p_a, (n, 1, 1, 3))
187 |     b_ = np.reshape(p_b, (1, m, 3, 1))
188 |     ab_ = np.matmul(a_, b_)[..., 0, 0] # (n, m)
189 |     aa_ = np.repeat(np.reshape(aa, (n, 1)), axis=1, repeats=m)
190 |     bb_ = np.repeat(np.reshape(bb, (1, m)), axis=0, repeats=n)
191 |     dist = np.sqrt(aa_+bb_-2*ab_)
192 |     return dist
193 | 
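Despite its name, `calc_min_dist` returns the full (n, m) pairwise Euclidean distance matrix, computed via the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b; `heuristic_filter` below then takes the row-wise minimum. A quick sanity check of that identity, as a sketch:

```python
import numpy as np

p_a = np.random.rand(5, 3)
p_b = np.random.rand(7, 3)

dist = calc_min_dist(p_a, p_b)                                 # (5, 7), via the expansion
brute = np.linalg.norm(p_a[:, None] - p_b[None, :], axis=-1)   # (5, 7), computed directly
assert np.allclose(dist, brute, atol=1e-6)
```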
194 | def heuristic_filter(points_a, points_b, thresh=0.018, d_th=0.65):
195 |     dist = calc_min_dist(points_a, points_b)
196 |     dist_a = np.min(dist, axis=1, keepdims=False)
197 |     nearby_points = points_a[dist_a < thresh]
198 |     if len(nearby_points) < 10:
199 |         return False
200 | 
201 |     clustering = AgglomerativeClustering(n_clusters=None,
202 |                                          distance_threshold=d_th).fit(nearby_points)
203 |     labels = clustering.labels_
204 |     label_ids = np.unique(labels)
205 | 
206 |     if len(label_ids) < 2:
207 |         return False
208 |     for id_ in label_ids:
209 |         if np.sum(labels == id_) < 5:
210 |             return False
211 |     return True
212 | 
213 | def do_filtering(support_ply, object_ply):
214 |     scores = []
215 |     sup_np = support_ply.detach().cpu().numpy()
216 |     obj_np = object_ply.detach().cpu().numpy()
217 |     B = sup_np.shape[0]
218 | 
219 |     for i in range(B):
220 |         f = heuristic_filter(sup_np[i], obj_np[i])
221 |         scores.append(float(f))
222 |     scores = torch.from_numpy(np.array(scores)).to(support_ply.device)
223 |     return scores
224 | 
225 | 
226 | class Critic(object):
227 |     def __init__(self, model: nn.Module, device, mini_batch=4, use_filter=False):
228 |         self.model = model.to(device)
229 |         self.model.eval()
230 |         self.mini_batch = mini_batch
231 |         self.use_filter = use_filter
232 |         self.device = device
233 | 
234 |     def __call__(self, tr6d, support_ply, object_ply):
235 |         '''
236 | 
237 |         :param tr6d: (B*pop_size, 6) candidate poses
238 |         :param support_ply: (N, 3) support point cloud
239 |         :param object_ply: (N, 3) object point cloud
240 |         :return: scores: (B*pop_size)
241 |         '''
242 | 
243 |         transform = vectors2matrix(tr6d) # (M, 4, 4)
244 |         object_ply = apply_multi_transforms(transform, object_ply) # (B, N, 3)
245 |         B = object_ply.shape[0]
246 |         N1 = object_ply.shape[1]
247 |         support_ply = support_ply.unsqueeze(0).repeat(B, 1, 1) # (B, N, 3)
248 |         N2 = support_ply.shape[1]
249 | 
250 |         data = torch.cat([support_ply, object_ply], 1).permute(0, 2, 1) # (B, 3, 2*N)
251 | 
252 |         mask = torch.zeros((B, 1, N1+N2), device=data.device, dtype=data.dtype)
253 |         mask[:, :, N2:] = 1
254 | 
255 |         data = torch.cat([data, mask], 1)
256 |         sample = {'data': data}
257 |         if 'cuda' in self.device:
258 |             sample = dict2cuda(sample)
259 | 
260 |         ret = infer_mini_batch(self.model, sample, self.mini_batch)
261 |         probs = torch.softmax(ret['preds'][0], 1)[:, 1] # (M, )
262 | 
263 |         if self.use_filter:
264 |             scores = do_filtering(support_ply, object_ply)
265 |             probs *= scores
266 | 
267 |         return probs
268 | 
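A sketch of how this `Critic` is typically constructed and queried, mirroring the setup in `inference.py` below. CPU execution is assumed here for simplicity (`inference.py` instead wraps the model in `DataParallel` on CUDA); the checkpoint path is the repo's released classifier, and the random point clouds are placeholders:

```python
import torch
from pose_check.models.uninet_mt import UniNet_MT_V2

model = UniNet_MT_V2(mask_channel=True, only_test=True)
state = torch.load('./checkpoints/classifier/model_00160000.ckpt', map_location='cpu')
model.load_state_dict(state['model'], strict=True)

critic = Critic(model, device='cpu', mini_batch=4)
support = torch.rand(1024, 3)   # (N, 3) support point cloud
obj = torch.rand(1024, 3)       # (N, 3) object point cloud
poses = torch.zeros(8, 6)       # candidate poses as (tx, ty, tz, rx, ry, rz);
                                # the count must be divisible by mini_batch
probs = critic(tr6d=poses, support_ply=support, object_ply=obj)  # (8,) stability probabilities
```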
269 | def sample_from_gaussian(d, batch_size, num_samples):
270 |     m = MultivariateNormal(torch.zeros(d), torch.eye(d))
271 |     z_noise = m.sample((batch_size, num_samples))
272 |     z_noise = z_noise.permute(0, 2, 1) # (B, 3, M)
273 |     return z_noise
274 | 
275 | def write_ply(points, colors, save_path):
276 |     import os
277 |     import os.path as osp
278 |     import open3d as o3d
279 | 
280 |     if colors.max() > 1:
281 |         div_ = 255.
282 |     else:
283 |         div_ = 1.
284 |     os.makedirs(osp.dirname(save_path), exist_ok=True)
285 |     pcd = o3d.geometry.PointCloud()
286 |     pcd.points = o3d.utility.Vector3dVector(points)
287 |     pcd.colors = o3d.utility.Vector3dVector(colors / div_)
288 |     o3d.io.write_point_cloud(save_path, pcd, write_ascii=False)
289 | 
290 | 
291 | class Actor(object):
292 |     def __init__(self, model: nn.Module, device, z_dim, batch_size=1):
293 |         self.model = model.to(device)
294 |         self.z_dim = z_dim
295 |         self.batch_size = batch_size
296 |         assert batch_size == 1
297 |         self.model.eval()
298 |         self.device = device
299 | 
300 |     def __call__(self, support_ply, object_ply, n_samp=128):
301 |         '''
302 | 
303 |         :param support_ply: (N, 3) support point cloud
304 |         :param object_ply: (N, 3) object point cloud
305 |         :return: pred: (M, 4, 4) candidate transforms
306 |         '''
307 |         support_ply = support_ply.unsqueeze(0).permute(0, 2, 1) # (1, 3, N)
308 |         object_ply = object_ply.unsqueeze(0).permute(0, 2, 1) # (1, 3, N)
309 | 
310 | 
311 |         sample = {'support': support_ply, 'object': object_ply}
312 |         z_noise = sample_from_gaussian(self.z_dim,
313 |                                        self.batch_size,
314 |                                        n_samp)
315 |         z_noise[:, :, 0] = 0
316 |         sample['z_noise'] = z_noise
317 |         if 'cuda' in self.device:
318 |             sample = dict2cuda(sample)
319 | 
320 |         pred = self.model(sample)['pred'] # (1, M, 4, 4)
321 | 
322 |         return pred[0]
323 | 
324 | def matrix2vectors(matrix):
325 |     '''
326 | 
327 |     :param matrix: (B, 4, 4)
328 |     :return: (B, 6)
329 |     '''
330 |     assert matrix.shape[1:] == (4, 4)
331 | 
332 |     rot_vec = torch_transform.matrix_to_euler_angles(matrix[:, :3, :3], 'XYZ') # (B, 3)
333 |     trs_vec = matrix[:, :3, 3] # (B, 3)
334 |     return torch.cat([trs_vec, rot_vec], 1) # (B, 6)
335 | 
336 | def vectors2matrix(vec6d):
337 |     '''
338 | 
339 |     :param vec6d: (B, 6)
340 |     :return: (B, 4, 4)
341 |     '''
342 |     assert vec6d.shape[1:] == (6, )
343 |     B = vec6d.shape[0]
344 |     rot = torch_transform.euler_angles_to_matrix(vec6d[:, 3:], 'XYZ') # (B, 3, 3)
345 |     trs = vec6d[:, :3].unsqueeze(2) # (B, 3, 1)
346 |     transform = torch.cat([rot, trs], dim=2) # (B, 3, 4)
347 |     ones = torch.tensor([0, 0, 0, 1],
348 |                         device=transform.device).view(1, 1, 4)
349 |     ones = ones.repeat((B, 1, 1))
350 |     transform = torch.cat([transform, ones], dim=1) # (B, 4, 4)
351 |     return transform
352 | 
353 | def apply_transform(transform, points):
354 |     '''
355 | 
356 |     :param transform: (4, 4)
357 |     :param points: (N, 3)
358 |     :return: (N, 3) transformed points
359 |     '''
360 |     N = points.shape[0]
361 |     ones = torch.ones((N, 1), device=points.device, dtype=points.dtype)
362 |     points = torch.cat([points, ones], 1).unsqueeze(2) # (N, 4, 1)
363 |     points_t = torch.matmul(transform.unsqueeze(0), points) # (N, 4, 1)
364 |     points_t = points_t[..., :3, 0] # (N, 3)
365 |     return points_t
366 | 
367 | def apply_multi_transforms(transforms, points):
368 |     '''
369 | 
370 |     :param transforms: (M, 4, 4)
371 |     :param points: (N, 3)
372 |     :return: (M, N, 3)
373 |     '''
374 |     ret = []
375 |     for t in transforms:
376 |         ret.append(apply_transform(t, points))
377 |     return torch.stack(ret, 0)
378 | 
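The 6-D pose vectors used throughout are laid out as translation first, then XYZ Euler angles. A round-trip sanity check of the two converters and the transform helpers, as a sketch (the small translation is an arbitrary illustrative value):

```python
import torch

T = torch.eye(4).unsqueeze(0)                   # (1, 4, 4) identity transform
T[0, :3, 3] = torch.tensor([0.10, 0.00, 0.05])  # add a small translation

v = matrix2vectors(T)        # (1, 6): (tx, ty, tz, rx, ry, rz)
T_back = vectors2matrix(v)   # (1, 4, 4)
assert torch.allclose(T, T_back, atol=1e-6)

pts = torch.rand(100, 3)
pts_one = apply_transform(T_back[0], pts)       # (100, 3)
pts_all = apply_multi_transforms(T_back, pts)   # (1, 100, 3)
```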
379 | def infer_mini_batch(model, data:dict, batch_size=16):
380 |     B = data['data'].shape[0]
381 |     assert B % batch_size == 0
382 |     N = B // batch_size
383 |     rets = {}
384 |     for i in range(N):
385 |         batch_data = {}
386 |         for k, v in data.items():
387 |             batch_data[k] = v[i*batch_size:(i+1)*batch_size]
388 |         ret_i:dict = model(batch_data)
389 |         for k, v in ret_i.items():
390 |             rets[k] = rets.get(k, []) + [v]
391 |     for k, v in rets.items():
392 |         if isinstance(v[0], torch.Tensor):
393 |             rets[k] = torch.cat(v, dim=0)
394 |         elif isinstance(v[0], list):
395 |             cat_num = len(v[0])
396 |             ret_k = []
397 |             for i in range(cat_num):
398 |                 cat_i = torch.cat([x[i] for x in v], dim=0)
399 |                 ret_k.append(cat_i)
400 |             rets[k] = ret_k
401 |         else:
402 |             raise TypeError('unsupported model output type: {}'.format(type(v[0])))
403 |     return rets
404 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 成硕
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning2Regrasp
2 | 
3 | ### Learning to Regrasp by Learning to Place, CoRL 2021.
4 | 
5 | ## Introduction
6 | We propose a point-cloud-based system with which a robot can predict a sequence of pick-and-place operations that transforms an initial object grasp pose into a desired one. We also introduce a new and challenging synthetic dataset for learning and evaluating the proposed approach. If you find this project useful for your research, please cite:
7 | 
8 | 
9 | ```
10 | @inproceedings{
11 | cheng2021learning,
12 | title={Learning to Regrasp by Learning to Place},
13 | author={Shuo Cheng and Kaichun Mo and Lin Shao},
14 | booktitle={5th Annual Conference on Robot Learning},
15 | year={2021},
16 | url={https://openreview.net/forum?id=Qdb1ODTQTnL}
17 | }
18 | ```
19 | Real-world regrasping demo:
20 | 
21 | ![regrasp](imgs/demo.gif)
22 | 
23 | ## How to Use
24 | 
25 | ### Environment
26 | * python 3.8 (Anaconda)
27 | * ``pip install -r requirements.txt``
28 | 
29 | ### Dataset
30 | Visualization of sample stable poses:
31 | 
32 | ![regrasp](imgs/dataset.png)
33 | 
34 | Please download the [dataset](https://drive.google.com/file/d/1r-sAMhHJuIJgawzDSEuW5jKdrvIp-7pi/view?usp=sharing) and place it inside this folder.
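To sanity-check the download (or any of the point clouds shipped under ``real_data/plys``), you can view a ``.ply`` file with open3d, which is already listed in ``requirements.txt``; a minimal sketch:

```python
import open3d as o3d

pcd = o3d.io.read_point_cloud('real_data/plys/bowl1.ply')
print(pcd)  # reports the number of points
o3d.visualization.draw_geometries([pcd])
```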
35 | 
36 | ### Reproducing Results
37 | 
38 | * Evaluating synthetic data: ``python scripts/evaluate_testset.py``
39 | * Evaluating real data: ``bash scripts/test_real_data.sh``
40 | 
41 | 
42 | ### Test Your Own Data:
43 | * Please organize your data in the ``real_data`` folder as in the example provided
44 | * Please make your data as clean and complete as possible, since an offset ``(x_mean, y_mean, z_min)`` will be subtracted to centralize each point cloud
45 | 
46 | 
47 | ### Training
48 | * Train the generator: ``bash scripts/train_pose_generation.sh``
49 | * Train the classifier: ``bash scripts/train_multi_task.sh``
--------------------------------------------------------------------------------
/checkpoints/classifier/model_00160000.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/checkpoints/classifier/model_00160000.ckpt
--------------------------------------------------------------------------------
/checkpoints/generator/model_00368000.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/checkpoints/generator/model_00368000.ckpt
--------------------------------------------------------------------------------
/evaluate_poses.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import time
3 | from glob import glob
4 | import os
5 | import numpy as np
6 | import pybullet as p
7 | import pybullet_data
8 | import trimesh.transformations as T
9 | import argparse
10 | import json
11 | from tqdm import tqdm
12 | 
13 | parser = argparse.ArgumentParser(description='Simulate pose.')
14 | parser.add_argument('--obj_path', type=str, default='./test_urdf/folk2r651.urdf')
15 | parser.add_argument('--sup_path', type=str, default='./test_urdf/bowl1r182.urdf')
16 | parser.add_argument('--init_obj_pose', type=str, default='./test_plys/folk2r651_init_pose.txt')
17 | parser.add_argument('--init_sup_pose', type=str, default='./test_plys/bowl1r182_init_pose.txt')
18 | parser.add_argument('--transforms', type=str, default='')
19 | parser.add_argument('--save_dir', type=str, default='./debug',
20 |                     help='path to save records.')
21 | parser.add_argument('--render', action='store_true')
22 | 
23 | args = parser.parse_args()
24 | 
25 | MAX_OBS_TIME=20
26 | LIN_V_TH=0.005
27 | ANG_V_TH=0.1
28 | 
29 | # set up pybullet environment
30 | if args.render:
31 |     physicsClient = p.connect(p.GUI) # with GUI rendering
32 | else:
33 |     physicsClient = p.connect(p.DIRECT) # headless
34 | 
35 | p.setAdditionalSearchPath(pybullet_data.getDataPath())
36 | 
37 | def step_simulation(n, p):
38 |     for i in range(n):
39 |         p.stepSimulation()
40 | 
41 | def p_enable_physics(p):
42 |     p.setPhysicsEngineParameter(enableConeFriction=1)
43 |     p.setPhysicsEngineParameter(contactBreakingThreshold=0.001)
44 |     p.setPhysicsEngineParameter(allowedCcdPenetration=0.0)
45 |     p.setPhysicsEngineParameter(numSolverIterations=40)
46 |     p.setPhysicsEngineParameter(numSubSteps=40)
47 |     p.setPhysicsEngineParameter(constraintSolverType=p.CONSTRAINT_SOLVER_LCP_DANTZIG, globalCFM=0.000001)
48 |     p.setPhysicsEngineParameter(enableFileCaching=0)
49 |     p.setTimeStep(1 / 100.0)
50 |     p.setGravity(0, 0, -9.81)
51 | 
52 | p_enable_physics(p)
53 | p.resetDebugVisualizerCamera(cameraDistance=0.4, cameraYaw=30, cameraPitch=-50, cameraTargetPosition=[0,0,0])
54 | planeId = p.loadURDF("plane.urdf") 55 | 56 | def load_transforms(pose_dir,): 57 | all_files = glob('{}/*.npy'.format(pose_dir)) 58 | all_transforms = np.array([np.load(f)[0] for f in all_files]) # (n, 4, 4) 59 | return all_transforms 60 | 61 | def simulate(p, sup, obj): 62 | p.resetBaseVelocity(sup, [0, 0, 0], [0, 0, 0]) 63 | p.resetBaseVelocity(obj, [0, 0, 0], [0, 0, 0]) 64 | tic = time.time() 65 | obj_trs, obj_quat = p.getBasePositionAndOrientation(obj) 66 | sup_trs, sup_quat = p.getBasePositionAndOrientation(sup) 67 | 68 | while True: 69 | step_simulation(1, p) 70 | cur_obj_trs, cur_obj_quat = p.getBasePositionAndOrientation(obj) 71 | cur_sup_trs, cur_sup_quat = p.getBasePositionAndOrientation(sup) 72 | if not same_pose(init_pose_7q=obj_trs + obj_quat, 73 | cur_pose_7q=cur_obj_trs + cur_obj_quat, 74 | dist_th=0.03, ang_th=30): 75 | # print('object change! ') 76 | return False 77 | 78 | if not same_pose(init_pose_7q=sup_trs + sup_quat, 79 | cur_pose_7q=cur_sup_trs + cur_sup_quat, 80 | dist_th=0.02, ang_th=10): 81 | # print('support change! ') 82 | return False 83 | 84 | if time.time() - tic > MAX_OBS_TIME or simulation_stoped(p, obj): 85 | return True 86 | 87 | def same_pose(init_pose_7q, cur_pose_7q, dist_th, ang_th): 88 | assert len(init_pose_7q) == len(cur_pose_7q) 89 | assert len(init_pose_7q) == 7 90 | 91 | init_pose_7q = np.array(init_pose_7q) 92 | cur_pose_7q = np.array(cur_pose_7q) 93 | init_quat = init_pose_7q[3:] 94 | cur_quat = cur_pose_7q[3:] 95 | init_trs = init_pose_7q[:3] 96 | cur_trs = cur_pose_7q[:3] 97 | 98 | assert len(cur_trs) == 3 99 | assert len(cur_quat) == 4 100 | 101 | tmp = np.clip(np.abs(np.sum(init_quat*cur_quat)), 0., 1., ) 102 | deg_diff = 2 * 180 / np.pi * np.arccos(tmp) 103 | 104 | if deg_diff > ang_th: 105 | # print('degree change: ', deg_diff) 106 | return False 107 | 108 | trs_diff = np.sqrt(((init_trs - cur_trs)**2).sum()) 109 | if trs_diff > dist_th: 110 | # print('position change: ', trs_diff) 111 | return False 112 | 113 | return True 114 | 115 | def simulation_stoped(p, sID): 116 | lin_v, ang_v = p.getBaseVelocity(sID) 117 | if np.allclose(np.array(lin_v), np.zeros((3, )), rtol=1, atol=LIN_V_TH) \ 118 | and np.allclose(np.array(ang_v), np.zeros((3, )), rtol=1, atol=ANG_V_TH): 119 | return True 120 | else: 121 | return False 122 | 123 | def set_transform(subjectId, transform): 124 | trs = transform[:3, 3] 125 | rot = transform[:3, :3] 126 | euler = T.euler_from_matrix(rot, 'sxyz') 127 | quat = p.getQuaternionFromEuler(euler) 128 | p.resetBasePositionAndOrientation(subjectId, trs, quat) 129 | 130 | def transform_to_pose7(transform): 131 | transform = np.array(transform) 132 | euler = T.euler_from_matrix(transform[:3, :3], 'sxyz') 133 | trs = transform[:3, 3] 134 | quat = p.getQuaternionFromEuler(euler) 135 | pose7 = trs.tolist() + list(quat) 136 | return pose7 137 | 138 | def count_diff_pose(all_poses): 139 | pose_buff = [] 140 | for T in all_poses: 141 | if not pose_buff: 142 | pose_buff.append(T) 143 | else: 144 | uniq = True 145 | for Tb in pose_buff: 146 | Tb_7 = transform_to_pose7(Tb) 147 | T_7 = transform_to_pose7(T) 148 | if same_pose(Tb_7, T_7, 0.03, 30): 149 | uniq = False 150 | break 151 | if uniq: 152 | pose_buff.append(T) 153 | 154 | return len(pose_buff) 155 | 156 | def write_file(path, data_list): 157 | dir_name = osp.dirname(path) 158 | if dir_name: 159 | os.makedirs(dir_name, exist_ok=True) 160 | with open(path, 'w') as f: 161 | json.dump(data_list, f) 162 | 163 | def process_one_pair(supportId, objectId, all_transforms, 
init_sup_pose, init_obj_pose): 164 | 165 | results = [] 166 | for transform in tqdm(all_transforms): 167 | 168 | p.resetBaseVelocity(supportId, [0, 0, 0], [0, 0, 0]) 169 | p.resetBaseVelocity(objectId, [0, 0, 0], [0, 0, 0]) 170 | p.resetBasePositionAndOrientation(objectId, [0, 0, 10], 171 | p.getQuaternionFromEuler([0, 0, 0])) 172 | 173 | set_transform(supportId, init_sup_pose) 174 | 175 | tic = time.time() 176 | while True: 177 | step_simulation(1, p) 178 | if simulation_stoped(p, supportId) or time.time() - tic > MAX_OBS_TIME: 179 | step_simulation(5, p) 180 | break 181 | 182 | real_transform = transform @ init_obj_pose 183 | real_transform[2, 3] += 0.01 184 | set_transform(objectId, real_transform) 185 | 186 | p.stepSimulation() 187 | 188 | contacts1 = p.getContactPoints(supportId, objectId) 189 | contacts2 = p.getContactPoints(planeId, objectId) 190 | contacts = contacts1 + contacts2 191 | 192 | fail = False 193 | for c in contacts: 194 | if c[8] < -0.00001: 195 | fail = True 196 | break 197 | 198 | if not fail: 199 | result = simulate(p, supportId, objectId) 200 | else: 201 | result = False 202 | results.append(result) 203 | 204 | return np.array(results) 205 | 206 | def process(args): 207 | supportId = p.loadURDF(args.sup_path,) 208 | objectId = p.loadURDF(args.obj_path,) 209 | 210 | init_sup_pose = np.loadtxt(args.init_sup_pose) 211 | init_obj_pose = np.loadtxt(args.init_obj_pose) 212 | 213 | all_transforms = load_transforms(args.transforms) 214 | 215 | all_paths = glob('{}/*.npy'.format(args.transforms)) 216 | probs = np.array([float(osp.basename(f).split('_')[0]) for f in all_paths]) 217 | thresh = [0.6, 0.7, 0.8] 218 | 219 | all_transforms = all_transforms[probs >= thresh[0]] 220 | probs = probs[probs >= thresh[0]] 221 | 222 | 223 | os.makedirs(args.save_dir, exist_ok=True) 224 | sup_name = osp.basename(args.sup_path).split('.')[0] 225 | obj_name = osp.basename(args.obj_path).split('.')[0] 226 | 227 | results = process_one_pair(supportId=supportId, objectId=objectId, 228 | all_transforms=all_transforms, 229 | init_obj_pose=init_obj_pose, 230 | init_sup_pose=init_sup_pose) 231 | 232 | 233 | report = {'cnt': {}, 'acc': {}} 234 | for th in thresh: 235 | inds = probs > th 236 | acc_t = np.nanmean(results[inds]) 237 | 238 | select_poses = all_transforms[(probs > th) * results] 239 | cnt_t = count_diff_pose(select_poses) 240 | 241 | report['acc'][th] = str(acc_t) 242 | report['cnt'][th] = str(cnt_t) 243 | 244 | print(report) 245 | save_path = '{}/{}-{}.json'.format(args.save_dir, sup_name, obj_name) 246 | write_file(save_path, [report]) 247 | 248 | 249 | process(args) 250 | p.disconnect() 251 | 252 | -------------------------------------------------------------------------------- /imgs/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/imgs/dataset.png -------------------------------------------------------------------------------- /imgs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/imgs/demo.gif -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | 6 | import numpy as np 7 | import open3d as 
o3d 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | from tqdm import tqdm 13 | 14 | from ES import Searcher, Critic, Actor, apply_transform, matrix2vectors 15 | from pose_check.models.uninet_mt import UniNet_MT_V2 16 | from pose_generation.models.vnet import VNet 17 | 18 | cudnn.benchmark = True 19 | 20 | parser = argparse.ArgumentParser(description='Test UCSNet.') 21 | 22 | parser.add_argument('--root_path', type=str, help='path to root directory.') 23 | parser.add_argument('--test_list', type=str, default='./pose_generation/dataset/test_list.txt') 24 | parser.add_argument('--save_path', type=str, help='path to save depth maps.') 25 | parser.add_argument('--real_data', action='store_true') 26 | parser.add_argument('--render_ply', action='store_true') 27 | parser.add_argument('--filter', action='store_true') 28 | 29 | #test parameters 30 | parser.add_argument('--generator_ckpt', type=str, help='the path for pre-trained model.', 31 | default='./checkpoints/') 32 | parser.add_argument('--stable_critic_ckpt', type=str, 33 | default='./checkpoints/') 34 | parser.add_argument('--pose_num', type=int, default=16) 35 | parser.add_argument('--rot_rp', type=str, default='6d') 36 | parser.add_argument('--z_dim', type=int, default=3) 37 | parser.add_argument('--num_iter', type=int, default=2) 38 | 39 | parser.add_argument('--device', type=str, default='cuda') 40 | 41 | args = parser.parse_args() 42 | 43 | def read_ply(path, pc_len): 44 | pcd = o3d.io.read_point_cloud(path) 45 | point_cloud = np.asarray(pcd.points) 46 | colors = np.asarray(pcd.colors) 47 | if len(point_cloud) < pc_len: 48 | ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud)) 49 | point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0) 50 | elif len(point_cloud) > pc_len: 51 | ind = np.random.choice(len(point_cloud), pc_len) 52 | point_cloud = point_cloud[ind] 53 | return point_cloud 54 | 55 | def load_mv_ply(path, num_v=2, pc_len=1024): 56 | assert num_v <= 4 57 | pcs = [] 58 | for i in range(num_v): 59 | dir_name = osp.dirname(path) 60 | base_name = osp.basename(path).split('.')[0]+'.v{:04d}.ply'.format(i) 61 | path_i = osp.join(dir_name, base_name) 62 | pcs.append(read_ply(path_i, pc_len)) 63 | 64 | point_cloud = np.concatenate(pcs, 0) 65 | if len(point_cloud) < pc_len: 66 | ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud)) 67 | point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0) 68 | elif len(point_cloud) > pc_len: 69 | ind = np.random.choice(len(point_cloud), pc_len) 70 | point_cloud = point_cloud[ind] 71 | return point_cloud 72 | 73 | def write_ply(points, colors, save_path): 74 | if colors.max() > 1: 75 | div_ = 255. 76 | else: 77 | div_ = 1. 
78 | 79 | dir_name = osp.dirname(save_path) 80 | os.makedirs(dir_name, exist_ok=True) 81 | 82 | pcd = o3d.geometry.PointCloud() 83 | pcd.points = o3d.utility.Vector3dVector(points) 84 | pcd.colors = o3d.utility.Vector3dVector(colors / div_) 85 | o3d.io.write_point_cloud(save_path, pcd, write_ascii=False) 86 | 87 | 88 | def load_data(data_root, data_list, pc_len=1024, is_real=True): 89 | for subject_names in data_list: 90 | subjects = [] 91 | for name in subject_names: 92 | sub_path = '{}/{}.ply'.format(data_root, name) 93 | if is_real: 94 | subject_ply = read_ply(sub_path, pc_len=pc_len) 95 | else: 96 | subject_ply = load_mv_ply(sub_path, pc_len=pc_len, num_v=2) 97 | print('pc shape: ', subject_ply.shape) 98 | subject_tensor = torch.from_numpy(subject_ply).float().to(args.device) # (N, 3) 99 | subjects.append(subject_tensor) 100 | yield subjects 101 | 102 | def main(args): 103 | # build model 104 | # support = 0, object = 1, mask shape (B, 4, N) 105 | stable_critic = UniNet_MT_V2(mask_channel=True, only_test=True) 106 | 107 | generator = VNet(mask_channel=False, rot_rep=args.rot_rp, 108 | z_dim=args.z_dim, obj_feat=128, sup_feat=128, z_feat=64, 109 | only_test=True) 110 | 111 | # load checkpoint file specified by args.loadckpt 112 | print("Loading model {} ...".format(args.generator_ckpt)) 113 | g_state_dict = torch.load(args.generator_ckpt, map_location=torch.device("cpu")) 114 | generator.load_state_dict(g_state_dict['model'], strict=True) 115 | print('Success!') 116 | 117 | print("Loading model {} ...".format(args.stable_critic_ckpt)) 118 | s_state_dict = torch.load(args.stable_critic_ckpt, map_location=torch.device("cpu")) 119 | stable_critic.load_state_dict(s_state_dict['model'], strict=True) 120 | print('Success!') 121 | 122 | generator = nn.DataParallel(generator) 123 | generator.to(args.device) 124 | generator.eval() 125 | 126 | stable_critic = nn.DataParallel(stable_critic) 127 | stable_critic.to(args.device) 128 | stable_critic.eval() 129 | 130 | critic = Critic(stable_critic, device=args.device, mini_batch=64, use_filter=args.filter) 131 | actor = Actor(generator, device=args.device, z_dim=args.z_dim, batch_size=1) 132 | data_list = open(args.test_list, 'r').readlines() 133 | data_list = list(map(lambda x: str(x).strip().split('-'), data_list)) 134 | 135 | data_loader = load_data(args.root_path, data_list, 136 | pc_len=1024, is_real=args.real_data) 137 | 138 | for j, candidates in enumerate(tqdm(data_loader)): 139 | pair_id = '-'.join(data_list[j]) 140 | print('Processing {} ...'.format(pair_id)) 141 | 142 | solutions = search_solution(candidates=candidates, 143 | actor=actor, critic=critic, 144 | centralize=True, 145 | num_iter=args.num_iter, 146 | n_samp=args.pose_num, 147 | ) 148 | print('Total solutions: ', len(solutions)) 149 | 150 | save_predictions(candidates, solutions, pair_id, render_ply=args.render_ply) 151 | 152 | del candidates 153 | del solutions 154 | 155 | torch.cuda.empty_cache() 156 | 157 | def post_refine(support_ply, object_ply, init_transform, critic, num_iter=2): 158 | ''' 159 | 160 | :param support_ply: (N, 3) 161 | :param object_ply: (N, 3) 162 | :param init_transform: (B, 4, 4) 163 | :param critic: 164 | :return: 165 | ''' 166 | if num_iter == 0: 167 | init_transform_6d = matrix2vectors(init_transform) 168 | scores = critic(tr6d=init_transform_6d, 169 | support_ply=support_ply, 170 | object_ply=object_ply) 171 | return init_transform, scores 172 | 173 | cem_searcher = Searcher(action_dim=6, pop_size=4, parents=2, sigma_init=1e-4, 174 | clip=0.003, 
damp=0.001, damp_limit=0.00001, device=init_transform.device) 175 | refined_transforms, scores = cem_searcher.search(action_init=init_transform, 176 | support_ply=support_ply, 177 | object_ply=object_ply, 178 | critic=critic, 179 | n_iter=num_iter, 180 | ) 181 | return refined_transforms, scores 182 | 183 | def search_solution(candidates, actor, critic, centralize, num_iter=2, n_samp=64,): 184 | solutions = [] 185 | 186 | def dfs(support, layer_id, actions=[]): 187 | if layer_id >= len(candidates): 188 | return 189 | selected = candidates[layer_id] # (N, 3) 190 | 191 | tic = time.time() 192 | 193 | if centralize: 194 | assert support.shape[1] == 3 195 | assert len(support.shape) == 2 196 | assert selected.shape[1] == 3 197 | assert len(selected.shape) == 2 198 | 199 | sup_cent = torch.zeros((1, 3), device=support.device, 200 | dtype=support.dtype) 201 | sup_cent[0, :2] = torch.mean(support, 0, keepdim=True)[0, :2] 202 | sup_cent[0, 2] = torch.min(support, 0, keepdim=True)[0][0, 2] 203 | 204 | obj_cent = torch.zeros((1, 3), device=selected.device, 205 | dtype=selected.dtype) 206 | obj_cent[0, :2] = torch.mean(selected, 0, keepdim=True)[0, :2] 207 | obj_cent[0, 2] = torch.min(selected, 0, keepdim=True)[0][0, 2] 208 | 209 | support -= sup_cent 210 | selected -= obj_cent 211 | 212 | # write_ply(support, np.zeros_like(support), './debug_support.ply') 213 | # write_ply(selected, np.zeros_like(selected), './debug_object.ply') 214 | 215 | 216 | proposals = actor(support, selected, n_samp=n_samp) # (M, 4, 4) 217 | print('# Time [actor]: {:.2f}'.format(time.time() - tic)) 218 | 219 | tic = time.time() 220 | proposals, scores = post_refine(support, selected, proposals, critic, 221 | num_iter=num_iter) # (M, 4, 4), (M, ) 222 | print('# Time [post refine]: {:.2f}'.format(time.time() - tic)) 223 | 224 | 225 | if centralize: 226 | support += sup_cent 227 | selected += obj_cent 228 | base2cent = torch.eye(4, dtype=proposals.dtype, device=proposals.device).view((1, 4, 4)) 229 | base2cent[0, :3, 3] = -obj_cent[0, :3] 230 | cent2base = torch.eye(4, dtype=proposals.dtype, device=proposals.device).view((1, 4, 4)) 231 | cent2base[0, :3, 3] = sup_cent[0, :3] 232 | proposals = cent2base @ (proposals @ base2cent) 233 | 234 | print('layer {} scores: '.format(layer_id), scores) 235 | 236 | # proposals = proposals[scores >= 0.5] 237 | # scores = scores[scores >= 0.5] 238 | print('search layer {}, keep nodes: '.format(layer_id), proposals.shape, scores.shape) 239 | 240 | for action_i, score_i in zip(proposals, scores): 241 | actions.append((action_i.detach(), score_i.detach())) 242 | if layer_id == len(candidates)-1: 243 | # collect action seq 244 | solutions.append(actions.copy()) 245 | else: 246 | selected_t = apply_transform(action_i, selected) # (N, 3) 247 | next_support = torch.cat([support, selected_t], ) # (2*N, 3) 248 | dfs(next_support, layer_id+1, actions) 249 | actions.pop() 250 | 251 | with torch.no_grad(): 252 | # [s, o_i, ...] 
253 |         dfs(candidates[0], 1, [])
254 |     return solutions
255 | 
256 | def save_predictions(candidates, solutions, pair_id, render_ply):
257 |     save_dir = osp.join(args.save_path, pair_id)
258 |     os.makedirs(save_dir, exist_ok=True)
259 |     for ind, solution in enumerate(solutions):
260 |         save_one_pair(candidates, solution, save_dir, ind, render_ply=render_ply)
261 | 
262 | def save_one_pair(point_clouds, solution, save_dir, index, render_ply=True):
263 |     t2n = lambda x: x.detach().cpu().numpy()
264 |     colors = [[20, 20, 160], [20, 160, 200]]
265 | 
266 |     scores = ['{:.2f}'.format((np.round(x[1].item(), 2))) for x in solution]
267 |     transforms = [x[0] for x in solution]
268 | 
269 |     file_name = '_'.join(scores) + '_{:04d}.ply'.format(index)
270 | 
271 |     transforms_np = list(map(t2n, transforms))
272 |     mat_name = file_name.replace('.ply', '.npy')
273 |     np.save(osp.join(save_dir, mat_name), transforms_np)
274 | 
275 |     if not render_ply:
276 |         return
277 | 
278 |     assert len(transforms) + 1 == len(point_clouds)
279 |     assert len(point_clouds) == len(colors)
280 | 
281 |     ret = [point_clouds[0], ]
282 | 
283 |     for i in range(len(point_clouds)-1):
284 |         subject_i = apply_transform(transforms[i], point_clouds[i+1])
285 |         ret.append(subject_i)
286 | 
287 |     ply_buffers = []
288 |     for i in range(len(ret)):
289 |         points = t2n(ret[i])
290 |         color = np.ones((len(points), 1)) @ np.array(colors[i]).reshape((1, 3))
291 |         subject_ply = np.concatenate([points, color], 1) # (N, 6)
292 |         ply_buffers.append(subject_ply)
293 |     full_ply = np.concatenate(ply_buffers, 0)
294 | 
295 |     write_ply(full_ply[:, :3], full_ply[:, 3:], osp.join(save_dir, file_name))
296 | 
297 | 
298 | 
299 | if __name__ == '__main__':
300 |     with torch.no_grad():
301 |         main(args)
--------------------------------------------------------------------------------
/pose_check/data_loader/pose_check_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torch.utils.data import DataLoader
4 | from scipy.spatial.transform import Rotation as R
5 | import os.path as osp
6 | import numpy as np
7 | import json
8 | import os
9 | from collections import OrderedDict
10 | import open3d as o3d
11 | 
12 | def write_file(path, data_list):
13 |     dir_name = osp.dirname(path)
14 |     if dir_name:
15 |         os.makedirs(dir_name, exist_ok=True)
16 |     with open(path, 'w') as f:
17 |         json.dump(data_list, f)
18 | 
19 | def read_file(path):
20 |     with open(path, 'r') as fd:
21 |         data_list = json.load(fd, object_hook=OrderedDict)
22 |     return data_list
23 | 
24 | def load_ply(path, pc_len):
25 |     pcd = o3d.io.read_point_cloud(path)
26 |     point_cloud = np.asarray(pcd.points)
27 |     colors = np.asarray(pcd.colors)
28 |     if len(point_cloud) < pc_len:
29 |         ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud))
30 |         point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0)
31 |     elif len(point_cloud) > pc_len:
32 |         ind = np.random.choice(len(point_cloud), pc_len)
33 |         point_cloud = point_cloud[ind]
34 |     return point_cloud
35 | 
36 | def load_mv_ply(root_dir, path, pc_len):
37 |     flag = float(np.random.uniform(0, 1, ()))
38 |     if flag < 0.25:
39 |         num_v = 1
40 |     else:
41 |         num_v = 2
42 |     v_inds = np.random.choice(range(4), num_v, replace=False)
43 |     pcs = []
44 |     for i in range(num_v):
45 | 
46 |         path_i = path.rsplit('.ply', 1)[0]+'.v{:04d}.ply'.format(int(v_inds[i]))
47 |         pcs.append(load_ply(osp.join(root_dir, path_i), pc_len))
48 |     point_cloud = np.concatenate(pcs, 0)
49 |     if len(point_cloud) <
pc_len: 50 | ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud)) 51 | point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0) 52 | elif len(point_cloud) > pc_len: 53 | ind = np.random.choice(len(point_cloud), pc_len) 54 | point_cloud = point_cloud[ind] 55 | return point_cloud 56 | 57 | def load_collision_data(root_dir, list_path, samples=None): 58 | data_list = open(list_path, 'r').readlines() 59 | data_list = list(map(str.strip, data_list)) 60 | all_odgts = [] 61 | sim_odgts = [] 62 | for line in data_list: 63 | one_odgt = read_file(osp.join(root_dir, line)) 64 | if line.startswith('simulation'): 65 | one_odgt = list(filter(lambda x: x['stable'], one_odgt)) 66 | sim_odgts += one_odgt 67 | else: 68 | all_odgts += one_odgt 69 | 70 | np.random.shuffle(sim_odgts) 71 | sim_odgts = sim_odgts[:min(50000, len(sim_odgts))] 72 | all_odgts += sim_odgts 73 | np.random.shuffle(all_odgts) 74 | print('hard samples: ', len(sim_odgts)) 75 | print('total samples: ', len(all_odgts)) 76 | if samples: 77 | print('Sample ratio: ', samples) 78 | all_odgts = all_odgts[:samples] 79 | 80 | return all_odgts 81 | 82 | def load_multi_task_data(root_dir, list_path): 83 | data_list = open(list_path, 'r').readlines() 84 | data_list = list(map(str.strip, data_list)) 85 | 86 | random_data = [] 87 | simulation_data = [] 88 | for line in data_list: 89 | one_pair_data = read_file(osp.join(root_dir, line)) 90 | file_name = osp.basename(line) 91 | if file_name.startswith('random'): 92 | random_data += one_pair_data 93 | else: 94 | for data in one_pair_data: 95 | if data['stable'] and len(data['support_contact']) == 0: 96 | continue 97 | simulation_data.append(data) 98 | 99 | print('simulation data: ', len(simulation_data)) 100 | print('random data: ', len(random_data)) 101 | 102 | all_data = random_data + simulation_data 103 | np.random.shuffle(all_data) 104 | return all_data 105 | 106 | class PoseCheckDataset(Dataset): 107 | def __init__(self, root_dir, list_path, samples, label_name='collision', pc_len=1024, use_aug=True): 108 | super(PoseCheckDataset, self).__init__() 109 | self.root_dir = root_dir 110 | 111 | odgt_data = load_collision_data(root_dir, list_path, samples=samples) 112 | 113 | self.data_pairs = odgt_data 114 | print('Total pairs: ', len(self.data_pairs)) 115 | self.pc_len = pc_len 116 | self.label_name = label_name 117 | self.use_aug = use_aug 118 | 119 | def __getitem__(self, index): 120 | items:dict = self.data_pairs[index] 121 | 122 | transform = np.array(items['transform']) 123 | cls = int(items[self.label_name]) 124 | 125 | sup_ply = load_mv_ply(self.root_dir, items['sup_init_path'], 126 | pc_len=self.pc_len) 127 | obj_ply = load_mv_ply(self.root_dir, items['obj_init_path'], 128 | pc_len=self.pc_len) 129 | 130 | obj_ply = apply_transform(transform, obj_ply) # (N, 3) 131 | comp_ply = np.concatenate([sup_ply, obj_ply], 0) 132 | if self.use_aug: 133 | comp_ply = random_transform_points(comp_ply) # (N, 3) 134 | comp_ply = random_offset_points(comp_ply) 135 | 136 | mask_ply = np.zeros((len(comp_ply), 1), dtype='float32') 137 | mask_ply[len(sup_ply):, 0] = 1 138 | 139 | full_ply = np.concatenate([comp_ply, mask_ply], 1) 140 | 141 | full_ply = full_ply.T.astype('float32') # (4, N) 142 | 143 | ret = {'data': full_ply, 'label': cls, 144 | 'pair_id': items['pair_id'], 'index': items['index']} 145 | 146 | return ret 147 | 148 | def __len__(self): 149 | return len(self.data_pairs) 150 | 151 | class MultiTaskDatasetV2(Dataset): 152 | def __init__(self, root_dir, list_path, pc_len=1024, 
max_contact=200, use_aug=True): 153 | super(MultiTaskDatasetV2, self).__init__() 154 | self.root_dir = root_dir 155 | 156 | data_pairs = load_multi_task_data(root_dir, list_path) 157 | self.data_pairs = data_pairs 158 | 159 | print('Total pairs [{}]: '.format(list_path), len(self.data_pairs)) 160 | self.pc_len = pc_len 161 | self.max_contact = max_contact 162 | self.use_aug = use_aug 163 | 164 | def __getitem__(self, index): 165 | items: dict = self.data_pairs[index] 166 | 167 | transform = np.array(items['transform']) 168 | 169 | sup_ply = load_mv_ply(self.root_dir, items['sup_init_path'], 170 | pc_len=self.pc_len) 171 | obj_ply = load_mv_ply(self.root_dir, items['obj_init_path'], 172 | pc_len=self.pc_len) 173 | 174 | obj_ply = apply_transform(transform, obj_ply) # (N, 3) 175 | comp_ply = np.concatenate([sup_ply, obj_ply], 0) 176 | 177 | contact_label = int(items['contact']) 178 | stable_label = int(items['stable']) 179 | 180 | contact_points = np.zeros((self.max_contact, 3), dtype='float32') 181 | total_contacts = 0 182 | 183 | if contact_label == 1: 184 | if len(items['support_contact']) == 0: 185 | contact_label = 255 186 | 187 | if contact_label == 1: 188 | all_contacts = items['support_contact'] 189 | np.random.shuffle(all_contacts) 190 | total_contacts = min(len(all_contacts), self.max_contact) 191 | contact_points[:total_contacts] = np.array(all_contacts, dtype='float32')[:total_contacts] 192 | 193 | if self.use_aug: 194 | comp_ply, contact_points = random_transform_pair(comp_ply, contact_points) # (N, 3), (M, 3) 195 | comp_ply, contact_points = random_offset_pair(comp_ply, contact_points) 196 | 197 | 198 | mask_ply = np.zeros((len(comp_ply), 1), dtype='float32') 199 | mask_ply[len(sup_ply):, 0] = 1 200 | 201 | full_ply = np.concatenate([comp_ply, mask_ply], 1) 202 | full_ply = full_ply.T.astype('float32') # (4, N) 203 | 204 | ret = {'data': full_ply, 'contact_points': contact_points.astype('float32'), 205 | 'total_contacts': total_contacts, 206 | 'stable': stable_label, 'contact': contact_label, 207 | 'pair_id': items['pair_id'], 'index': items['index']} 208 | # (4, N) 209 | 210 | return ret 211 | 212 | def __len__(self): 213 | return len(self.data_pairs) 214 | 215 | def apply_transform(t, points): 216 | ''' 217 | 218 | :param t: (4, 4) 219 | :param points: (N, 3) 220 | :return: 221 | ''' 222 | N = points.shape[0] 223 | ones = np.ones((N, 1)) 224 | points = np.concatenate([points, ones], 1) # (N, 4) 225 | points = np.expand_dims(points, 2) # (N, 4, 1) 226 | t = np.expand_dims(t, 0) # (1, 4, 4) 227 | points = np.matmul(t, points)[:, :3, 0] # () 228 | return points 229 | 230 | def random_transform_pair(points, contacts): 231 | ''' 232 | 233 | :param points: (N, 3) 234 | :param contacts: (M, 3) 235 | :return: 236 | ''' 237 | deg = float(np.random.uniform(0, 360, size=())) 238 | r = R.from_euler('z', deg, degrees=True) 239 | t = np.eye(4) 240 | t[:3, :3] = r.as_matrix() 241 | 242 | points = apply_transform(t, points) 243 | contacts = apply_transform(t, contacts) 244 | return points, contacts 245 | 246 | def random_transform_points(points): 247 | ''' 248 | 249 | :param points: (N, 3) 250 | :param contacts: (M, 3) 251 | :return: 252 | ''' 253 | 254 | deg = float(np.random.uniform(0, 360, size=())) 255 | r = R.from_euler('z', deg, degrees=True) 256 | t = np.eye(4) 257 | t[:3, :3] = r.as_matrix() 258 | 259 | points = apply_transform(t, points) 260 | return points 261 | 262 | def random_offset_pair(points, contacts): 263 | ''' 264 | 265 | :param points: 266 | :param contacts: 267 | 
:return: 268 | ''' 269 | xyz_range = np.array([[-0.02, -0.02, -0.002], 270 | [0.02, 0.02, 0.002]]) 271 | offset = np.random.uniform(0, 1, size=(1, 3)) 272 | offset = offset * (xyz_range[1:2]-xyz_range[0:1]) + xyz_range[0:1] 273 | points += offset 274 | contacts += offset 275 | 276 | n = points.shape[0] 277 | sigma = 0.003 278 | noise = np.random.normal(0, sigma, size=(n,)) 279 | points[:, 2] += noise 280 | return points, contacts 281 | 282 | def random_offset_points(points): 283 | ''' 284 | 285 | :param points: 286 | :param contacts: 287 | :return: 288 | ''' 289 | xyz_range = np.array([[-0.02, -0.02, -0.002], 290 | [0.02, 0.02, 0.002]]) 291 | offset = np.random.uniform(0, 1, size=(1, 3)) 292 | offset = offset * (xyz_range[1:2]-xyz_range[0:1]) + xyz_range[0:1] 293 | points += offset 294 | 295 | n = points.shape[0] 296 | sigma = 0.003 297 | noise = np.random.normal(0, sigma, size=(n, 3)) 298 | points += noise 299 | return points 300 | 301 | def debug_visualization(sample): 302 | ''' 303 | 304 | :param transform: (4, 4) 305 | :param result: 306 | :param ind: 307 | :param pair_id: 308 | :return: 309 | ''' 310 | 311 | ply = sample['data'][0].numpy().T # (N, 4) 312 | 313 | n_ = 3000 314 | plane = np.zeros((n_, 3)) 315 | plane[:, :2] = np.random.uniform(-0.2, 0.2, size=(n_, 2)) 316 | plane = np.concatenate([plane, np.zeros((n_, 1))], 1) 317 | ply = np.concatenate([ply, plane], 0) 318 | 319 | points = ply[:, :3] 320 | mask = ply[:, 3:] 321 | 322 | if sample['stable'][0] == 1: 323 | obj_color = [0, 120, 0] 324 | else: 325 | obj_color = [120, 0, 0] 326 | 327 | color = mask @ np.array(obj_color).reshape((1, 3)) + \ 328 | (1-mask) @ np.array([0, 0, 200]).reshape((1, 3)) 329 | 330 | contacts = [] 331 | for pid in range(int(sample['total_contacts'][0])): 332 | cp = sample['contact_points'][0, pid].numpy() 333 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.005) 334 | sphere.translate([cp[0], cp[1], cp[2]]) 335 | sphere.paint_uniform_color([0.9, 0.1, 0.1]) 336 | contacts.append(sphere) 337 | 338 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.005) 339 | sphere.paint_uniform_color([0.5, 0.5, 0.1]) 340 | contacts.append(sphere) 341 | 342 | pcd = o3d.geometry.PointCloud() 343 | pcd.points = o3d.utility.Vector3dVector(points) 344 | pcd.colors = o3d.utility.Vector3dVector(color/255.) 
345 |     o3d.visualization.draw_geometries([pcd] + contacts)
346 | 
347 | if __name__ == '__main__':
348 |     # datasetV = PoseCheckDataset('../../../debug/dataset/plys', '../../../debug/dataset/debug.txt')
349 |     # loaderV = DataLoader(datasetV, 2, sampler=None, num_workers=1,
350 |     #                      drop_last=False, shuffle=False)
351 | 
352 | 
353 |     # train_set = MultiTaskDatasetV2(root_dir='../../dataset',
354 |     #                                list_path='../../dataset/data_list/train_classifier.txt',
355 |     #                                use_aug=True)
356 |     test_set = MultiTaskDatasetV2(root_dir='../../dataset',
357 |                                   list_path='../../dataset/data_list/test_classifier.txt',
358 |                                   use_aug=False)
359 | 
360 |     loaderV = DataLoader(test_set, 4, sampler=None, num_workers=2,
361 |                          drop_last=False, shuffle=False)
362 |     print('all samples: ', len(loaderV))
363 | 
364 |     for sample in loaderV:
365 |         print(sample['data'].shape)
366 |         print(sample['stable'])
367 |         debug_visualization(sample)
368 | 
--------------------------------------------------------------------------------
/pose_check/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/pose_check/models/__init__.py
--------------------------------------------------------------------------------
/pose_check/models/pointnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 | 
7 | def timeit(tag, t):
8 |     print("{}: {}s".format(tag, time() - t))
9 |     return time()
10 | 
11 | def pc_normalize(pc):
12 |     l = pc.shape[0]
13 |     centroid = np.mean(pc, axis=0)
14 |     pc = pc - centroid
15 |     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16 |     pc = pc / m
17 |     return pc
18 | 
19 | def square_distance(src, dst):
20 |     """
21 |     Calculate the squared Euclidean distance between each pair of points.
22 | src^T * dst = xn * xm + yn * ym + zn * zm; 23 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 24 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 25 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 26 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 27 | Input: 28 | src: source points, [B, N, C] 29 | dst: target points, [B, M, C] 30 | Output: 31 | dist: per-point square distance, [B, N, M] 32 | """ 33 | B, N, _ = src.shape 34 | _, M, _ = dst.shape 35 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 36 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 37 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 38 | return dist 39 | 40 | 41 | def index_points(points, idx): 42 | """ 43 | Input: 44 | points: input points data, [B, N, C] 45 | idx: sample index data, [B, S] 46 | Return: 47 | new_points:, indexed points data, [B, S, C] 48 | """ 49 | device = points.device 50 | B = points.shape[0] 51 | view_shape = list(idx.shape) 52 | view_shape[1:] = [1] * (len(view_shape) - 1) 53 | repeat_shape = list(idx.shape) 54 | repeat_shape[0] = 1 55 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 56 | new_points = points[batch_indices, idx, :] 57 | return new_points 58 | 59 | 60 | def farthest_point_sample(xyz, npoint): 61 | """ 62 | Input: 63 | xyz: pointcloud data, [B, N, 3] 64 | npoint: number of samples 65 | Return: 66 | centroids: sampled pointcloud index, [B, npoint] 67 | """ 68 | device = xyz.device 69 | B, N, C = xyz.shape 70 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 71 | distance = torch.ones(B, N).to(device) * 1e10 72 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 73 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 74 | for i in range(npoint): 75 | centroids[:, i] = farthest 76 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 77 | dist = torch.sum((xyz - centroid) ** 2, -1) 78 | mask = dist < distance 79 | distance[mask] = dist[mask] 80 | farthest = torch.max(distance, -1)[1] 81 | return centroids 82 | 83 | 84 | def query_ball_point(radius, nsample, xyz, new_xyz): 85 | """ 86 | Input: 87 | radius: local region radius 88 | nsample: max sample number in local region 89 | xyz: all points, [B, N, 3] 90 | new_xyz: query points, [B, S, 3] 91 | Return: 92 | group_idx: grouped points index, [B, S, nsample] 93 | """ 94 | device = xyz.device 95 | B, N, C = xyz.shape 96 | _, S, _ = new_xyz.shape 97 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 98 | sqrdists = square_distance(new_xyz, xyz) 99 | group_idx[sqrdists > radius ** 2] = N 100 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 101 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 102 | mask = group_idx == N 103 | group_idx[mask] = group_first[mask] 104 | return group_idx 105 | 106 | 107 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 108 | """ 109 | Input: 110 | npoint: 111 | radius: 112 | nsample: 113 | xyz: input points position data, [B, N, 3] 114 | points: input points data, [B, N, D] 115 | Return: 116 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 117 | new_points: sampled points data, [B, npoint, nsample, 3+D] 118 | """ 119 | B, N, C = xyz.shape 120 | S = npoint 121 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 122 | torch.cuda.empty_cache() 123 | new_xyz = index_points(xyz, fps_idx) 124 | torch.cuda.empty_cache() 125 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 126 | 
torch.cuda.empty_cache() 127 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 128 | torch.cuda.empty_cache() 129 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 130 | torch.cuda.empty_cache() 131 | 132 | if points is not None: 133 | grouped_points = index_points(points, idx) 134 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 135 | else: 136 | new_points = grouped_xyz_norm 137 | if returnfps: 138 | return new_xyz, new_points, grouped_xyz, fps_idx 139 | else: 140 | return new_xyz, new_points 141 | 142 | 143 | def sample_and_group_all(xyz, points): 144 | """ 145 | Input: 146 | xyz: input points position data, [B, N, 3] 147 | points: input points data, [B, N, D] 148 | Return: 149 | new_xyz: sampled points position data, [B, 1, 3] 150 | new_points: sampled points data, [B, 1, N, 3+D] 151 | """ 152 | device = xyz.device 153 | B, N, C = xyz.shape 154 | new_xyz = torch.zeros(B, 1, C).to(device) 155 | grouped_xyz = xyz.view(B, 1, N, C) 156 | if points is not None: 157 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 158 | else: 159 | new_points = grouped_xyz 160 | return new_xyz, new_points 161 | 162 | 163 | class PointNetSetAbstraction(nn.Module): 164 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 165 | super(PointNetSetAbstraction, self).__init__() 166 | self.npoint = npoint 167 | self.radius = radius 168 | self.nsample = nsample 169 | self.mlp_convs = nn.ModuleList() 170 | self.mlp_bns = nn.ModuleList() 171 | last_channel = in_channel 172 | for out_channel in mlp: 173 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 174 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 175 | last_channel = out_channel 176 | self.group_all = group_all 177 | 178 | def forward(self, xyz, points): 179 | """ 180 | Input: 181 | xyz: input points position data, [B, C, N] 182 | points: input points data, [B, D, N] 183 | Return: 184 | new_xyz: sampled points position data, [B, C, S] 185 | new_points_concat: sample points feature data, [B, D', S] 186 | """ 187 | xyz = xyz.permute(0, 2, 1) 188 | if points is not None: 189 | points = points.permute(0, 2, 1) 190 | 191 | if self.group_all: 192 | new_xyz, new_points = sample_and_group_all(xyz, points) 193 | else: 194 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 195 | # new_xyz: sampled points position data, [B, npoint, C] 196 | # new_points: sampled points data, [B, npoint, nsample, C+D] 197 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 198 | for i, conv in enumerate(self.mlp_convs): 199 | bn = self.mlp_bns[i] 200 | new_points = F.relu(bn(conv(new_points))) 201 | 202 | new_points = torch.max(new_points, 2)[0] 203 | new_xyz = new_xyz.permute(0, 2, 1) 204 | return new_xyz, new_points 205 | 206 | 207 | class PointNetSetAbstractionMsg(nn.Module): 208 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 209 | super(PointNetSetAbstractionMsg, self).__init__() 210 | self.npoint = npoint 211 | self.radius_list = radius_list 212 | self.nsample_list = nsample_list 213 | self.conv_blocks = nn.ModuleList() 214 | self.bn_blocks = nn.ModuleList() 215 | for i in range(len(mlp_list)): 216 | convs = nn.ModuleList() 217 | bns = nn.ModuleList() 218 | last_channel = in_channel + 3 219 | for out_channel in mlp_list[i]: 220 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 221 | bns.append(nn.BatchNorm2d(out_channel)) 222 | last_channel = 
out_channel 223 | self.conv_blocks.append(convs) 224 | self.bn_blocks.append(bns) 225 | 226 | def forward(self, xyz, points): 227 | """ 228 | Input: 229 | xyz: input points position data, [B, C, N] 230 | points: input points data, [B, D, N] 231 | Return: 232 | new_xyz: sampled points position data, [B, C, S] 233 | new_points_concat: sample points feature data, [B, D', S] 234 | """ 235 | xyz = xyz.permute(0, 2, 1) 236 | if points is not None: 237 | points = points.permute(0, 2, 1) 238 | 239 | B, N, C = xyz.shape 240 | S = self.npoint 241 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 242 | new_points_list = [] 243 | for i, radius in enumerate(self.radius_list): 244 | K = self.nsample_list[i] 245 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 246 | grouped_xyz = index_points(xyz, group_idx) 247 | grouped_xyz -= new_xyz.view(B, S, 1, C) 248 | if points is not None: 249 | grouped_points = index_points(points, group_idx) 250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 251 | else: 252 | grouped_points = grouped_xyz 253 | 254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 255 | for j in range(len(self.conv_blocks[i])): 256 | conv = self.conv_blocks[i][j] 257 | bn = self.bn_blocks[i][j] 258 | grouped_points = F.relu(bn(conv(grouped_points))) 259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 260 | new_points_list.append(new_points) 261 | 262 | new_xyz = new_xyz.permute(0, 2, 1) 263 | new_points_concat = torch.cat(new_points_list, dim=1) 264 | return new_xyz, new_points_concat 265 | 266 | 267 | class PointNetFeaturePropagation(nn.Module): 268 | def __init__(self, in_channel, mlp): 269 | super(PointNetFeaturePropagation, self).__init__() 270 | self.mlp_convs = nn.ModuleList() 271 | self.mlp_bns = nn.ModuleList() 272 | last_channel = in_channel 273 | for out_channel in mlp: 274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 276 | last_channel = out_channel 277 | 278 | def forward(self, xyz1, xyz2, points1, points2): 279 | """ 280 | Input: 281 | xyz1: input points position data, [B, C, N] 282 | xyz2: sampled input points position data, [B, C, S] 283 | points1: input points data, [B, D, N] 284 | points2: input points data, [B, D, S] 285 | Return: 286 | new_points: upsampled points data, [B, D', N] 287 | """ 288 | xyz1 = xyz1.permute(0, 2, 1) 289 | xyz2 = xyz2.permute(0, 2, 1) 290 | 291 | points2 = points2.permute(0, 2, 1) 292 | B, N, C = xyz1.shape 293 | _, S, _ = xyz2.shape 294 | 295 | if S == 1: 296 | interpolated_points = points2.repeat(1, N, 1) 297 | else: 298 | dists = square_distance(xyz1, xyz2) 299 | dists, idx = dists.sort(dim=-1) 300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 301 | 302 | dist_recip = 1.0 / (dists + 1e-8) 303 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 304 | weight = dist_recip / norm 305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 306 | 307 | if points1 is not None: 308 | points1 = points1.permute(0, 2, 1) 309 | new_points = torch.cat([points1, interpolated_points], dim=-1) 310 | else: 311 | new_points = interpolated_points 312 | 313 | new_points = new_points.permute(0, 2, 1) 314 | for i, conv in enumerate(self.mlp_convs): 315 | bn = self.mlp_bns[i] 316 | new_points = F.relu(bn(conv(new_points))) 317 | return new_points 318 | -------------------------------------------------------------------------------- /pose_check/models/uninet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | 7 | from .pointnet_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction 8 | 9 | class FeatureExtraction(nn.Module): 10 | def __init__(self, normal_channel=False, mask_channel=False, out_dim=256): 11 | super(FeatureExtraction, self).__init__() 12 | in_channel = 0 13 | if mask_channel: 14 | in_channel += 1 15 | if normal_channel: 16 | in_channel += 3 17 | 18 | self.ext_channel = in_channel 19 | 20 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 21 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320, [[64, 64, 128], [128, 128, 256], [128, 128, 256]]) 22 | self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True) 23 | 24 | self.fc1 = nn.Linear(1024, 512) 25 | self.bn1 = nn.BatchNorm1d(512) 26 | self.fc2 = nn.Linear(512, out_dim) 27 | self.bn2 = nn.BatchNorm1d(out_dim) 28 | 29 | def forward(self, xyz): 30 | B, C, N = xyz.shape 31 | if self.ext_channel > 0: 32 | norm = xyz[:, 3:, :] 33 | xyz = xyz[:, :3, :] 34 | else: 35 | norm = None 36 | 37 | l1_xyz, l1_points = self.sa1(xyz, norm) 38 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 39 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 40 | 41 | x = l3_points.view(B, -1) 42 | 43 | x = F.relu(self.bn1(self.fc1(x))) 44 | point_feat = F.relu(self.bn2(self.fc2(x))) 45 | return point_feat, l3_points 46 | 47 | class UniNet(nn.Module): 48 | def __init__(self, num_class=2, feat_dim=512, mask_channel=False, normal_channel=False, only_test=False): 49 | super(UniNet, self).__init__() 50 | 51 | self.feat_ext = FeatureExtraction(normal_channel=normal_channel, 52 | mask_channel=mask_channel, 53 | out_dim=feat_dim) 54 | 55 | self.fc1 = nn.Linear(feat_dim, 256) 56 | self.bn1 = nn.BatchNorm1d(256) 57 | self.fc2 = nn.Linear(256, 128) 58 | self.bn2 = nn.BatchNorm1d(128) 59 | self.drop = nn.Dropout(0.3) # drop 30% during training 60 | self.cls = nn.Linear(128, num_class) 61 | # loss function 62 | self.CELoss = nn.CrossEntropyLoss() 63 | self.only_test = only_test 64 | 65 | def forward(self, sample): 66 | points = sample['data'] 67 | 68 | feat, _ = self.feat_ext(points) 69 | feat1 = F.relu(self.bn1(self.fc1(feat))) 70 | feat2 = F.relu(self.bn2(self.fc2(feat1))) 71 | feat2 = self.drop(feat2) 72 | pred = self.cls(feat2) 73 | prob = torch.softmax(pred, dim=1) 74 | if self.only_test: 75 | return {'prob': prob} 76 | 77 | gt = sample['label'] 78 | loss = self.get_loss(pred, gt) 79 | 80 | return {'loss': loss, 'prob': prob} 81 | 82 | def get_loss(self, pred_logits, gt_labels): 83 | loss = self.CELoss(pred_logits, gt_labels) 84 | return loss 85 | 86 | 87 | if __name__ == '__main__': 88 | # setting 1: smoke test; forward expects a dict holding the point tensor under 'data' 89 | point_w_mask = torch.ones((2, 4, 513)) 90 | model1 = UniNet(mask_channel=True, only_test=True) 91 | pred = model1({'data': point_w_mask}) 92 | print(pred['prob'].shape) 93 | -------------------------------------------------------------------------------- /pose_check/models/uninet_mt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | 7 | from .pointnet_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation 8 | 9 | class 
UniNet_MT_V2(nn.Module): 10 | def __init__(self, normal_channel=False, mask_channel=False, bootle_neck=256, only_test=False): 11 | super(UniNet_MT_V2, self).__init__() 12 | 13 | self.normal_channel = normal_channel 14 | self.mask_channel = mask_channel 15 | self.only_test = only_test 16 | 17 | additional_channel = 0 18 | if mask_channel: 19 | additional_channel += 1 20 | if normal_channel: 21 | additional_channel += 3 22 | 23 | self.sa1 = PointNetSetAbstractionMsg(512, [0.05, 0.1], [16, 32], additional_channel, [[8, 8, 16], [16, 16, 32]]) 24 | self.sa2 = PointNetSetAbstractionMsg(128, [0.1, 0.2], [16, 32], 16+32, [[32, 32, 64], [32, 48, 64]]) 25 | self.sa3 = PointNetSetAbstractionMsg(64, [0.2, 0.4], [16, 32], 64+64, [[64, 96, 128], [64, 96, 128]]) 26 | self.sa4 = PointNetSetAbstractionMsg(16, [0.4, 0.8], [16, 32], 128+128, [[128, 128, 256], [128, 172, bootle_neck]]) 27 | 28 | self.sa5 = PointNetSetAbstraction(None, None, None, bootle_neck+256 + 3, [256, 128], True) 29 | 30 | self.fp4 = PointNetFeaturePropagation(bootle_neck+256+128+128, [512, 256]) 31 | self.fp3 = PointNetFeaturePropagation(64+64+256, [256, 128]) 32 | self.fp2 = PointNetFeaturePropagation(16+32+128, [128, 96]) 33 | self.fp1 = PointNetFeaturePropagation(96, [64, 64]) 34 | 35 | self.contact_cls = nn.Sequential(*[ 36 | nn.Linear(128, 64), 37 | nn.BatchNorm1d(64), 38 | nn.ReLU(), 39 | nn.Dropout(0.2), # drop 20% during training 40 | nn.Linear(64, 32), 41 | nn.Linear(32, 2) 42 | ]) 43 | 44 | self.stable_cls = nn.Sequential(*[ 45 | nn.Linear(128, 64), 46 | nn.BatchNorm1d(64), 47 | nn.ReLU(), 48 | nn.Dropout(0.2), # drop 20% during training 49 | nn.Linear(64, 32), 50 | nn.Linear(32, 2) 51 | ]) 52 | 53 | self.offset_reg = nn.Sequential(*[nn.Conv1d(64, 32, kernel_size=(1, )), 54 | nn.BatchNorm1d(32), 55 | nn.ReLU(), 56 | nn.Conv1d(32, 16, kernel_size=(1, )), 57 | nn.Conv1d(16, 3, kernel_size=(1, )) 58 | ]) 59 | 60 | def forward(self, sample): 61 | xyz = sample['data'] 62 | 63 | B, C, N = xyz.shape 64 | l0_xyz = xyz[:, :3, :] 65 | if self.normal_channel or self.mask_channel: 66 | extr_channel = xyz[:, 3:, :] 67 | else: 68 | extr_channel = None 69 | 70 | l1_xyz, l1_points = self.sa1(l0_xyz, extr_channel) 71 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 72 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 73 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 74 | 75 | l5_xyz, l5_points = self.sa5(l4_xyz, l4_points) 76 | 77 | l5_points_flatten = l5_points.view(B, -1) 78 | 79 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 80 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 81 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 82 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points) 83 | 84 | stable_pred = self.stable_cls(l5_points_flatten) 85 | contact_pred = self.contact_cls(l5_points_flatten) 86 | offset_pred = self.offset_reg(l0_points) 87 | 88 | if self.only_test: 89 | return {'preds': [stable_pred, contact_pred, offset_pred]} 90 | else: 91 | loss_items = {'stable': [stable_pred, sample['stable'], 1.], 92 | 'contact': [contact_pred, sample['contact'], 1.], 93 | 'offset': [offset_pred, sample, 10., 2.] 
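# each loss_items entry maps a task to [prediction, target, loss weight]; the 'offset' entry carries two weights: 10.0 scales the smooth-L1 offset term and 2.0 scales the end-point variance term, both consumed in calc_loss below.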
94 | } 95 | ret = self.calc_loss(loss_items) 96 | ret['preds'] = [stable_pred, contact_pred, offset_pred] 97 | return ret 98 | 99 | def calc_loss(self, loss_items: dict): 100 | stable_loss = F.cross_entropy(loss_items['stable'][0], 101 | loss_items['stable'][1], 102 | ignore_index=255) 103 | contact_loss = F.cross_entropy(loss_items['contact'][0], 104 | loss_items['contact'][1], 105 | ignore_index=255) 106 | 107 | 108 | gt_offset, ind = self.compute_vector_field(loss_items['offset'][1]) # (B, 3, N), (B, N) 109 | offset_loss = F.smooth_l1_loss(gt_offset, loss_items['offset'][0], reduction='none') # (B, 3, N) 110 | 111 | end_points = loss_items['offset'][1]['data'][:, :3, :] + loss_items['offset'][0] 112 | 113 | var_loss = self.calc_variance(end_points=end_points, 114 | index=ind, 115 | num_contacts=loss_items['offset'][1]['total_contacts']) 116 | 117 | offset_loss = offset_loss[loss_items['contact'][1] == 1] 118 | offset_loss = torch.mean(offset_loss) 119 | offset_loss = torch.nan_to_num(offset_loss) 120 | 121 | total_loss = stable_loss * loss_items['stable'][2] + \ 122 | contact_loss * loss_items['contact'][2] + \ 123 | offset_loss * loss_items['offset'][2] + \ 124 | var_loss * loss_items['offset'][3] 125 | 126 | return {'loss': total_loss, 'loss_items': [stable_loss, contact_loss, offset_loss, var_loss]} 127 | 128 | def compute_vector_field(self, sample): 129 | ''' 130 | 131 | :param sample: dict with 'data' (B, C, N), 'contact_points' (B, M, 3) 132 | and 'total_contacts' (B, ) 133 | :return: nearest-contact vector field (B, 3, N) and contact indices (B, N) 134 | ''' 135 | points = sample['data'].permute(0, 2, 1)[..., :3] # (B, N, 3) 136 | contacts = sample['contact_points'] # (B, 30, 3) 137 | num_contacts = sample['total_contacts'] # (B, ) 138 | B = points.shape[0] 139 | N = points.shape[1] 140 | inds = [] 141 | vects = [] 142 | with torch.no_grad(): 143 | for i in range(B): 144 | cnt = int(num_contacts[i]) 145 | if cnt > 0: 146 | vec_field = contacts[i:i+1, :cnt] - points[i].unsqueeze(1) # (N, M, 3) 147 | nm_dist = torch.sum(vec_field**2, dim=2, keepdim=False) # (N, M) 148 | ind1 = torch.argmin(nm_dist, dim=1, keepdim=False) # (N, ) 149 | ind0 = torch.arange(nm_dist.shape[0]) # (N, ) 150 | select_vect = vec_field[ind0, ind1].permute(1, 0) # (3, N) 151 | else: 152 | select_vect = torch.zeros((3, N), device=points.device) 153 | ind1 = torch.zeros((N, ), device=points.device, dtype=torch.long) 154 | 155 | vects.append(select_vect) 156 | inds.append(ind1) 157 | 158 | return torch.stack(vects, 0), torch.stack(inds, 0) 159 | 160 | def calc_variance(self, end_points, index, num_contacts): 161 | B = num_contacts.shape[0] 162 | var_loss = torch.zeros((), device=end_points.device, dtype=torch.float) 163 | end_points = end_points.permute(0, 2, 1) # (B, N, 3) 164 | cnt = 0 165 | for i in range(B): 166 | if int(num_contacts[i]) == 0: 167 | continue 168 | cnt += 1 169 | endp_i = end_points[i] # (N, 3) 170 | cont_id = torch.unique(index[i]) 171 | for j in cont_id: 172 | jth_ind = index[i] == int(j) 173 | mean_j = torch.mean(endp_i[jth_ind], dim=0, keepdim=True) # (1, 3) 174 | var_j = torch.sum((endp_i[jth_ind] - mean_j) ** 2, dim=1, keepdim=False) # (n, ) 175 | var_loss = var_loss + torch.mean(var_j) 176 | 177 | var_loss = var_loss / float(cnt + 1e-16) 178 | return var_loss 179 | 180 | 181 | -------------------------------------------------------------------------------- /pose_check/train_multi_task_var_impl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import 
torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from tensorboardX import SummaryWriter 13 | from torch.utils.data import DataLoader 14 | 15 | from data_loader.pose_check_dataset import MultiTaskDatasetV2 16 | from models.uninet_mt import UniNet_MT_V2 17 | from utils.utils import get_step_schedule_with_warmup, dict2cuda, add_summary, \ 18 | DictAverageMeter, calc_stat 19 | 20 | cudnn.benchmark = True 21 | 22 | parser = argparse.ArgumentParser(description='Multi-task pose-check (stability/contact/offset) training.') 23 | parser.add_argument('--root_path', type=str, help='path to root directory.') 24 | parser.add_argument('--train_list', type=str, help='train scene list.', default='') 25 | parser.add_argument('--val_list', type=str, help='val scene list.', default='') 26 | parser.add_argument('--save_path', type=str, help='path to save checkpoints.') 27 | parser.add_argument('--restore_path', type=str, default='') 28 | 29 | parser.add_argument('--epochs', type=int, default=20) 30 | parser.add_argument('--lr', type=float, default=0.001) 31 | parser.add_argument('--lr_idx', type=str, default="50,100,160:0.5") 32 | parser.add_argument('--wd', type=float, default=0.0, help='weight decay') 33 | parser.add_argument('--batch_size', type=int, default=32) 34 | 35 | parser.add_argument('--log_freq', type=int, default=100, help='print and summary frequency') 36 | parser.add_argument('--save_freq', type=int, default=10000, help='save checkpoint frequency.') 37 | parser.add_argument('--eval_freq', type=int, default=10000, help='evaluate frequency.') 38 | 39 | parser.add_argument('--sync_bn', action='store_true', help='Sync BN.') 40 | parser.add_argument('--opt_level', type=str, default="O0") 41 | parser.add_argument('--seed', type=int, default=0) 42 | parser.add_argument('--local_rank', type=int, default=0) 43 | parser.add_argument('--num_workers', type=int, default=4) 44 | parser.add_argument('--distributed', action='store_true') 45 | 46 | args = parser.parse_args() 47 | # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 48 | is_distributed = args.distributed 49 | 50 | torch.manual_seed(args.seed) 51 | torch.cuda.manual_seed(args.seed) 52 | device = torch.device("cuda") 53 | 54 | if args.sync_bn: 55 | import apex 56 | import apex.amp as amp 57 | 58 | 59 | def print_func(data: dict, prefix: str = ''): 60 | for k, v in data.items(): 61 | if isinstance(v, dict): 62 | print_func(v, prefix + '.' 
+ k) 63 | elif isinstance(v, list): 64 | print(prefix+'.'+k, v) 65 | else: 66 | print(prefix+'.'+k, v.shape) 67 | 68 | def main(args, model, optimizer, scheduler, train_loader, val_loader, train_sampler, start_step=0): 69 | 70 | train_step = start_step 71 | start_ep = start_step // len(train_loader) 72 | 73 | model.train() 74 | for ep in range(start_ep, args.epochs): 75 | np.random.seed() 76 | train_scores = DictAverageMeter() 77 | if train_sampler is not None: 78 | train_sampler.set_epoch(ep) 79 | 80 | for batch_idx, sample in enumerate(train_loader): 81 | tic = time.time() 82 | 83 | sample_cuda = dict2cuda(sample) 84 | 85 | # print_func(sample_cuda) 86 | optimizer.zero_grad() 87 | ret = model(sample_cuda) 88 | loss = ret['loss'].mean() 89 | preds = ret['preds'] 90 | loss_items = [l.mean() for l in ret['loss_items']] 91 | 92 | # print_func(outputs) 93 | if is_distributed and args.sync_bn: 94 | with amp.scale_loss(loss, optimizer) as scaled_loss: 95 | scaled_loss.backward() 96 | else: 97 | loss.backward() 98 | 99 | optimizer.step() 100 | scheduler.step() 101 | 102 | if train_step % args.log_freq == 0: 103 | train_scores.update({'loss': float(loss), 104 | 'stable_loss': float(loss_items[0]), 105 | 'contact_loss': float(loss_items[1]), 106 | 'offset_loss': float(loss_items[2]), 107 | 'variance_loss': float(loss_items[3]) 108 | }) 109 | 110 | calc_stat(sample_cuda, preds[0], train_scores, label_type='stable') 111 | calc_stat(sample_cuda, preds[1], train_scores, label_type='contact') 112 | 113 | avg_stat = train_scores.mean() 114 | print("[Rank: {}] time={:.2f} Epoch {}/{}, Iter {}/{}, lr {:.6f}, stats: {}".format( 115 | args.local_rank, time.time() - tic, 116 | ep, args.epochs, batch_idx, len(train_loader), 117 | optimizer.param_groups[0]["lr"], 118 | avg_stat)) 119 | if on_main: 120 | add_summary([{'type': 'scalars', 'tags': list(avg_stat.keys()), 121 | 'vals': list(avg_stat.values())}], 122 | logger=logger, step=train_step, flag='train') 123 | 124 | del sample_cuda 125 | del avg_stat 126 | gc.collect() 127 | 128 | if on_main and train_step % args.save_freq == 0: 129 | torch.save({"step": train_step, 130 | "model": model.module.state_dict(), 131 | "optimizer": optimizer.state_dict(), 132 | "scheduler": scheduler.state_dict(), 133 | }, 134 | "{}/model_{:08d}.ckpt".format(args.save_path, train_step)) 135 | 136 | if train_step % args.eval_freq == 0: 137 | print('evaluating model_{:08d}.ckpt ...'.format(train_step)) 138 | with torch.no_grad(): 139 | test(args, model, val_loader, train_step) 140 | model.train() 141 | 142 | train_step += 1 143 | 144 | del train_scores 145 | gc.collect() 146 | 147 | def test(args, model, test_loader, train_step): 148 | model.eval() 149 | val_scores = DictAverageMeter() 150 | for batch_idx, sample in enumerate(test_loader): 151 | sample_cuda = dict2cuda(sample) 152 | ret = model(sample_cuda) 153 | preds = ret['preds'] 154 | calc_stat(sample_cuda, preds[0], val_scores, label_type='stable') 155 | calc_stat(sample_cuda, preds[1], val_scores, label_type='contact') 156 | 157 | avg_stat = val_scores.mean() 158 | print("[Rank: {}] step {:06d}, stats: {}".format(args.local_rank, train_step, avg_stat)) 159 | if on_main: 160 | add_summary([{'type': 'scalars', 'tags': list(avg_stat.keys()), 161 | 'vals': list(avg_stat.values())}], 162 | logger=logger, step=train_step, flag='val') 163 | 164 | del sample_cuda 165 | del avg_stat 166 | del val_scores 167 | gc.collect() 168 | 169 | def distribute_model(args): 170 | def sync(): 171 | if not dist.is_available(): 172 | return 173 | if 
not dist.is_initialized(): 174 | return 175 | if dist.get_world_size() == 1: 176 | return 177 | dist.barrier() 178 | if is_distributed: 179 | torch.cuda.set_device(args.local_rank) 180 | torch.distributed.init_process_group( 181 | backend="nccl", init_method="env://" 182 | ) 183 | sync() 184 | 185 | start_step = 0 186 | 187 | model: torch.nn.Module = UniNet_MT_V2(mask_channel=True, bootle_neck=256) 188 | if args.restore_path: 189 | checkpoint = torch.load(args.restore_path, map_location=torch.device("cpu")) 190 | model.load_state_dict(checkpoint['model'], strict=True) 191 | 192 | model.to(device) 193 | 194 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), 195 | weight_decay=args.wd) 196 | 197 | train_set = MultiTaskDatasetV2(root_dir=args.root_path, list_path=args.train_list, 198 | use_aug=True, ) 199 | print('train set ready.') 200 | val_set = MultiTaskDatasetV2(root_dir=args.root_path, list_path=args.val_list, 201 | use_aug=False,) 202 | print('val set ready.') 203 | if is_distributed: 204 | if args.sync_bn: 205 | model = apex.parallel.convert_syncbn_model(model) 206 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, ) 207 | print('Convert BN to Sync_BN successful.') 208 | 209 | model = nn.parallel.DistributedDataParallel( 210 | model, device_ids=[args.local_rank], output_device=args.local_rank,) 211 | 212 | train_sampler = torch.utils.data.DistributedSampler(train_set, num_replicas=dist.get_world_size(), 213 | rank=dist.get_rank()) 214 | val_sampler = torch.utils.data.DistributedSampler(val_set, num_replicas=dist.get_world_size(), 215 | rank=dist.get_rank()) 216 | else: 217 | model = nn.DataParallel(model) 218 | train_sampler, val_sampler = None, None 219 | 220 | def worker_init_fn(worker_id): 221 | np.random.seed(np.random.get_state()[1][0] + worker_id) 222 | 223 | train_loader = DataLoader(train_set, args.batch_size, sampler=train_sampler, 224 | num_workers=args.num_workers, pin_memory=True, 225 | drop_last=True, shuffle=not is_distributed, worker_init_fn=worker_init_fn) 226 | val_loader = DataLoader(val_set, 64, sampler=val_sampler, 227 | num_workers=1, pin_memory=True, 228 | drop_last=False, shuffle=False, worker_init_fn=worker_init_fn) 229 | 230 | milestones = list(map(float, args.lr_idx.split(':')[0].split(','))) 231 | assert np.max(milestones) <= 1.0, milestones 232 | milestones = list(map(lambda x: int(float(x) * float(len(train_loader) * args.epochs)), milestones)) 233 | gamma = float(args.lr_idx.split(':')[1]) 234 | warpup_iters = min(500, int(0.05*len(train_loader))) 235 | 236 | scheduler = get_step_schedule_with_warmup(optimizer=optimizer, milestones=milestones, 237 | gamma=gamma, warmup_iters=warpup_iters) 238 | 239 | if args.restore_path: 240 | optimizer.load_state_dict(checkpoint['optimizer']) 241 | scheduler.load_state_dict(checkpoint['scheduler']) 242 | start_step = checkpoint['step'] 243 | print("Restoring checkpoint {} ...".format(args.restore_path)) 244 | 245 | return model, optimizer, scheduler, train_loader, val_loader, train_sampler, start_step 246 | 247 | if __name__ == '__main__': 248 | model, optimizer, scheduler, train_loader, val_loader, train_sampler, start_step = distribute_model(args) 249 | on_main = (not is_distributed) or (dist.get_rank() == 0) 250 | if on_main: 251 | os.makedirs(args.save_path, exist_ok=True) 252 | logger = SummaryWriter(args.save_path) 253 | print(args) 254 | 255 | main(args=args, model=model, optimizer=optimizer, scheduler=scheduler, 256 | 
train_loader=train_loader, val_loader=val_loader, train_sampler=train_sampler, start_step=start_step) 257 | 258 | -------------------------------------------------------------------------------- /pose_check/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import LambdaLR 3 | import torchvision.utils as vutils 4 | import torch.distributed as dist 5 | 6 | import errno 7 | import os 8 | import re 9 | import sys 10 | import numpy as np 11 | from bisect import bisect_right 12 | 13 | 14 | def dict2cuda(data: dict): 15 | new_dic = {} 16 | for k, v in data.items(): 17 | if isinstance(v, dict): 18 | v = dict2cuda(v) 19 | elif isinstance(v, torch.Tensor): 20 | v = v.cuda() 21 | new_dic[k] = v 22 | return new_dic 23 | 24 | def dict2numpy(data: dict): 25 | new_dic = {} 26 | for k, v in data.items(): 27 | if isinstance(v, dict): 28 | v = dict2numpy(v) 29 | elif isinstance(v, torch.Tensor): 30 | v = v.detach().cpu().numpy().copy() 31 | new_dic[k] = v 32 | return new_dic 33 | 34 | def dict2float(data: dict): 35 | new_dic = {} 36 | for k, v in data.items(): 37 | if isinstance(v, dict): 38 | v = dict2float(v) 39 | elif isinstance(v, torch.Tensor): 40 | v = v.detach().cpu().item() 41 | new_dic[k] = v 42 | return new_dic 43 | 44 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 45 | """ Create a schedule with a learning rate that decreases linearly after 46 | linearly increasing during a warmup period. 47 | """ 48 | def lr_lambda(current_step): 49 | if current_step < num_warmup_steps: 50 | return float(current_step) / float(max(1, num_warmup_steps)) 51 | return max( 52 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 53 | ) 54 | 55 | return LambdaLR(optimizer, lr_lambda, last_epoch) 56 | 57 | def get_step_schedule_with_warmup(optimizer, milestones, gamma=0.1, warmup_factor=1.0/3, warmup_iters=500, last_epoch=-1,): 58 | def lr_lambda(current_step): 59 | if current_step < warmup_iters: 60 | alpha = float(current_step) / warmup_iters 61 | current_factor = warmup_factor * (1. - alpha) + alpha 62 | else: 63 | current_factor = 1. 
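# e.g. (illustrative numbers) with warmup_iters=500, milestones=[1000, 2000] and gamma=0.5, the multiplier ramps linearly from warmup_factor to 1.0 over the first 500 steps, then drops to 0.5 from step 1000 and to 0.25 from step 2000.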
64 | 65 | return max(0.0, current_factor * (gamma ** bisect_right(milestones, current_step))) 66 | 67 | return LambdaLR(optimizer, lr_lambda, last_epoch) 68 | 69 | def add_summary(data_items: list, logger, step: int, flag: str, max_disp=4): 70 | for data_item in data_items: 71 | tags = data_item['tags'] 72 | vals = data_item['vals'] 73 | dtype = data_item['type'] 74 | if dtype == 'points': 75 | for i in range(min(max_disp, len(tags))): 76 | logger.add_mesh('{}/{}'.format(flag, tags[i]), 77 | vertices=vals[0], colors=vals[1], global_step=step) 78 | elif dtype == 'scalars': 79 | for tag, val in zip(tags, vals): 80 | if val == 'None': 81 | val = 0 82 | logger.add_scalar('{}/{}'.format(flag, tag), 83 | val, global_step=step) 84 | else: 85 | raise NotImplementedError 86 | 87 | class DictAverageMeter(object): 88 | def __init__(self): 89 | self.data = {} 90 | 91 | def update(self, new_input: dict): 92 | for k, v in new_input.items(): 93 | if isinstance(v, list): 94 | self.data[k] = self.data.get(k, []) + v 95 | else: 96 | assert (isinstance(v, float) or isinstance(v, int)), type(v) 97 | self.data[k] = self.data.get(k, []) + [v] 98 | 99 | def mean(self): 100 | ret = {} 101 | for k, v in self.data.items(): 102 | if not v: 103 | ret[k] = 'None' 104 | else: 105 | ret[k] = np.round(np.mean(v), 4) 106 | return ret 107 | 108 | def reset(self): 109 | self.data = {} 110 | 111 | def calc_stat(sample, prob, scores, label_type='label', ignore_id=255): 112 | T2L = lambda x: x.float().detach().cpu().numpy().tolist() 113 | 114 | labels = sample[label_type] 115 | max_probs, preds = torch.max(prob, dim=1, keepdim=False) 116 | 117 | # remove ignore cases 118 | valid_inds = labels != ignore_id 119 | labels = labels[valid_inds] 120 | max_probs = max_probs[valid_inds] 121 | preds = preds[valid_inds] 122 | # 123 | 124 | all_acc = torch.mean((preds == labels).float()).item() 125 | scores.update({'{}_all_acc'.format(label_type): all_acc}) 126 | 127 | for i in range(2): 128 | pst_inds = (preds == i) 129 | if torch.sum(pst_inds) > 0: 130 | precision = T2L(preds[pst_inds] == labels[pst_inds]) 131 | else: 132 | precision = [] 133 | scores.update({'{}_precision_{}'.format(label_type, i): precision}) 134 | 135 | for thresh in [0.1, 0.4]: 136 | sel_inds = (torch.abs(max_probs-0.5) > thresh) 137 | ratio = torch.mean(sel_inds.float()).item() 138 | if ratio > 0: 139 | th_acc = T2L(preds[sel_inds] == labels[sel_inds]) 140 | else: 141 | th_acc = [] 142 | 143 | scores.update({'{}_P{}_ratio'.format(label_type, thresh): ratio, 144 | '{}_P{}_acc'.format(label_type, thresh): th_acc, 145 | }) -------------------------------------------------------------------------------- /pose_generation/data_loader/stable_pose_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | import numpy as np 5 | import os.path as osp 6 | import open3d as o3d 7 | import json 8 | from collections import OrderedDict 9 | from scipy.spatial.transform import Rotation as R 10 | import os 11 | 12 | def write_file(path, data_list): 13 | dir_name = osp.dirname(path) 14 | if dir_name: 15 | os.makedirs(dir_name, exist_ok=True) 16 | with open(path, 'w') as f: 17 | json.dump(data_list, f) 18 | 19 | def read_file(path): 20 | with open(path, 'r') as fd: 21 | data_list = json.load(fd, object_hook=OrderedDict) 22 | return data_list 23 | 24 | def parse_data(data_pairs): 25 | name2data = {} 26 | for pair in data_pairs: 27 | pair_id = pair['pair_id'] 28 | if not pair['stable']: 
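# pairs flagged as unstable cannot serve as positive examples for the pose generator, so parse_data filters them (and, below, pairs without support contacts) out of the training set.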
29 | print("{}_{:06d} not stable, skip.".format(pair_id, pair['index'])) 30 | continue 31 | if len(pair['support_contact']) == 0: 32 | print("{}_{:06d} no support contact, skip.".format(pair_id, pair['index'])) 33 | continue 34 | 35 | transform = np.array(pair['transform'], dtype='float32') 36 | if pair_id in name2data: 37 | name2data[pair_id]['stable_transforms'].append(transform) 38 | else: 39 | name2data[pair_id] = {'sup_init_path': pair['sup_init_path'], 40 | 'obj_init_path': pair['obj_init_path'], 41 | 'stable_transforms': [transform, ] 42 | } 43 | print('stable data pairs: ', name2data.keys()) 44 | return list(name2data.values()) 45 | 46 | def load_data(root_dir, list_path): 47 | data_list = open(list_path, 'r').readlines() 48 | data_list = list(map(str.strip, data_list)) 49 | all_pairs = [] 50 | for line in data_list: 51 | one_pair = read_file(osp.join(root_dir, line)) 52 | all_pairs += one_pair 53 | return parse_data(all_pairs) 54 | 55 | def load_ply(path, pc_len): 56 | pcd = o3d.io.read_point_cloud(path) 57 | point_cloud = np.asarray(pcd.points) 58 | colors = np.asarray(pcd.colors) 59 | if len(point_cloud) < pc_len: 60 | ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud)) 61 | point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0) 62 | elif len(point_cloud) > pc_len: 63 | ind = np.random.choice(len(point_cloud), pc_len) 64 | point_cloud = point_cloud[ind] 65 | return point_cloud 66 | 67 | def load_mv_ply(root_dir, path, pc_len): 68 | flag = float(np.random.uniform(0, 1, ())) 69 | if flag < 0.25: 70 | num_v = 1 71 | else: 72 | num_v = 2 73 | v_inds = np.random.choice(range(4), num_v, replace=False) 74 | pcs = [] 75 | for i in range(num_v): 76 | path_i = path.rsplit('.ply', 1)[0]+'.v{:04d}.ply'.format(int(v_inds[i])) 77 | pcs.append(load_ply(osp.join(root_dir, path_i), pc_len)) 78 | point_cloud = np.concatenate(pcs, 0) 79 | if len(point_cloud) < pc_len: 80 | ind = np.random.choice(len(point_cloud), pc_len-len(point_cloud)) 81 | point_cloud = np.concatenate([point_cloud, point_cloud[ind]], 0) 82 | elif len(point_cloud) > pc_len: 83 | ind = np.random.choice(len(point_cloud), pc_len) 84 | point_cloud = point_cloud[ind] 85 | return point_cloud 86 | 87 | class StablePoseDataset(Dataset): 88 | def __init__(self, root_dir, list_path, pose_num=256, pc_len=1024, use_aug=True): 89 | super(StablePoseDataset, self).__init__() 90 | self.root_dir = root_dir 91 | self.pose_num = pose_num 92 | self.pc_len = pc_len 93 | self.use_aug = use_aug 94 | 95 | data_pairs = load_data(root_dir, list_path) 96 | np.random.shuffle(data_pairs) 97 | 98 | self.data_pairs = data_pairs 99 | print('Total pairs [{}]: '.format(list_path), len(self.data_pairs)) 100 | 101 | def __getitem__(self, index): 102 | sup_ply = load_mv_ply(self.root_dir, self.data_pairs[index]['sup_init_path'], 103 | pc_len=self.pc_len) 104 | obj_ply = load_mv_ply(self.root_dir, self.data_pairs[index]['obj_init_path'], 105 | pc_len=self.pc_len) 106 | # (N, 3), (N, 3) 107 | 108 | stable_transform = np.array(self.data_pairs[index]['stable_transforms'], dtype='float32') 109 | # (M, 4, 4) 110 | 111 | if len(stable_transform) >= self.pose_num: 112 | select_inds = np.random.choice(len(stable_transform), 113 | self.pose_num, replace=False) 114 | else: 115 | select_inds = np.random.choice(len(stable_transform), 116 | self.pose_num, replace=True) 117 | stable_transform = stable_transform[select_inds.tolist()] 118 | assert stable_transform.shape == (self.pose_num, 4, 4), stable_transform.shape 119 | 120 | if self.use_aug: 121 | 
sup_ply, obj_ply, stable_transform = random_transform_pair(support=sup_ply, 122 | object=obj_ply, 123 | transforms=stable_transform) 124 | 125 | ret = {'support': sup_ply.T, 'object': obj_ply.T, 126 | 'transforms': stable_transform, 127 | 'sup_path': self.data_pairs[index]['sup_init_path'], 128 | 'obj_path': self.data_pairs[index]['obj_init_path']} 129 | return ret 130 | 131 | def __len__(self): 132 | return len(self.data_pairs) 133 | 134 | def apply_transform(t, points): 135 | ''' 136 | 137 | :param t: (4, 4) 138 | :param points: (N, 3) 139 | :return: 140 | ''' 141 | N = points.shape[0] 142 | ones = np.ones((N, 1)) 143 | points = np.concatenate([points, ones], 1) # (N, 4) 144 | points = np.expand_dims(points, 2) # (N, 4, 1) 145 | t = np.expand_dims(t, 0) # (1, 4, 4) 146 | points = np.matmul(t, points)[:, :3, 0] # () 147 | return points 148 | 149 | 150 | def random_transform_pair(support, object, transforms): 151 | ''' 152 | 153 | :param support: (N, 3) 154 | :param object: (N, 3) 155 | :param transforms: (M, 4, 4) 156 | :return: 157 | ''' 158 | 159 | degs = np.random.uniform(0, 360, size=(2, )) 160 | r = R.from_euler('z', degs, degrees=True) 161 | 162 | t0 = np.eye(4) 163 | t1 = np.eye(4) 164 | t0[:3, :3] = r.as_matrix()[0] 165 | t1[:3, :3] = r.as_matrix()[1] 166 | 167 | xyz_range = np.array([[-0.005, -0.005, -0.005], 168 | [0.005, 0.005, 0.005]]) 169 | scales = np.random.uniform(0, 1, size=(2, 3)) 170 | offset = scales * (xyz_range[1:2]-xyz_range[0:1]) + xyz_range[0:1] 171 | 172 | t0[:3, 3] = offset[0] 173 | t1[:3, 3] = offset[1] 174 | 175 | object_t = apply_transform(t0, object) 176 | support_t = apply_transform(t1, support) 177 | 178 | t0 = np.expand_dims(t0, axis=0) 179 | t1 = np.expand_dims(t1, axis=0) 180 | transforms_t = np.matmul(t1, np.matmul(transforms, np.linalg.inv(t0))) 181 | 182 | sigma = 0.003 183 | noise_s = np.random.normal(0, sigma, size=(support_t.shape[0], 3)) 184 | support_t += noise_s 185 | noise_o = np.random.normal(0, sigma, size=(object_t.shape[0], 3)) 186 | object_t += noise_o 187 | 188 | return support_t, object_t, transforms_t 189 | 190 | 191 | 192 | if __name__ == '__main__': 193 | from torch.utils.data import DataLoader 194 | 195 | dataset = StablePoseDataset('../../dataset', 196 | list_path='../../dataset/data_list/generator_debug.txt',) 197 | loaderV = DataLoader(dataset, 1, sampler=None, num_workers=1, 198 | drop_last=False, shuffle=False) 199 | 200 | def visualize(sup, obj, ): 201 | pcd = o3d.geometry.PointCloud() 202 | points = np.concatenate([sup, obj], 0) 203 | colors = np.zeros((len(points), 3)) 204 | colors[:len(sup), 2] = 1. 205 | colors[len(sup):, 1] = 1. 
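# debug color coding: support points are rendered blue (channel 2), object points green (channel 1).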
206 | pcd.points = o3d.utility.Vector3dVector(points) 207 | pcd.colors = o3d.utility.Vector3dVector(colors) 208 | 209 | o3d.visualization.draw_geometries([pcd]) 210 | 211 | for t in range(10): 212 | for i, data in enumerate(iter(dataset)): 213 | print(data.keys()) 214 | sup_ply = data['support'].T 215 | obj_ply = data['object'].T 216 | transforms = data['transforms'] 217 | print(sup_ply.shape, transforms.shape) 218 | transforms = transforms[0] 219 | obj_ply = apply_transform(transforms, obj_ply) 220 | visualize(sup_ply, obj_ply) 221 | 222 | 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /pose_generation/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/pose_generation/models/__init__.py -------------------------------------------------------------------------------- /pose_generation/models/pointnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def square_distance(src, dst): 12 | """ 13 | Calculate Euclid distance between each two points. 14 | src^T * dst = xn * xm + yn * ym + zn * zm; 15 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 16 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 17 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 18 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 19 | Input: 20 | src: source points, [B, N, C] 21 | dst: target points, [B, M, C] 22 | Output: 23 | dist: per-point square distance, [B, N, M] 24 | """ 25 | B, N, _ = src.shape 26 | _, M, _ = dst.shape 27 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 28 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 29 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 30 | return dist 31 | 32 | def index_points(points, idx): 33 | """ 34 | Input: 35 | points: input points data, [B, N, C] 36 | idx: sample index data, [B, S] 37 | Return: 38 | new_points:, indexed points data, [B, S, C] 39 | """ 40 | device = points.device 41 | B = points.shape[0] 42 | view_shape = list(idx.shape) 43 | view_shape[1:] = [1] * (len(view_shape) - 1) 44 | repeat_shape = list(idx.shape) 45 | repeat_shape[0] = 1 46 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 47 | new_points = points[batch_indices, idx, :] 48 | return new_points 49 | 50 | def farthest_point_sample(xyz, npoint): 51 | """ 52 | Input: 53 | xyz: pointcloud data, [B, N, 3] 54 | npoint: number of samples 55 | Return: 56 | centroids: sampled pointcloud index, [B, npoint] 57 | """ 58 | device = xyz.device 59 | B, N, C = xyz.shape 60 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 61 | distance = torch.ones(B, N).to(device) * 1e10 62 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 63 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 64 | for i in range(npoint): 65 | centroids[:, i] = farthest 66 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 67 | dist = torch.sum((xyz - centroid) ** 2, -1) 68 | mask = dist < distance 69 | distance[mask] = dist[mask] 70 | farthest = torch.max(distance, -1)[1] 71 | return centroids 72 | 73 | def query_ball_point(radius, nsample, xyz, new_xyz): 74 | """ 75 | 
Input: 76 | radius: local region radius 77 | nsample: max sample number in local region 78 | xyz: all points, [B, N, 3] 79 | new_xyz: query points, [B, S, 3] 80 | Return: 81 | group_idx: grouped points index, [B, S, nsample] 82 | """ 83 | device = xyz.device 84 | B, N, C = xyz.shape 85 | _, S, _ = new_xyz.shape 86 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 87 | sqrdists = square_distance(new_xyz, xyz) 88 | group_idx[sqrdists > radius ** 2] = N 89 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 90 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 91 | mask = group_idx == N 92 | group_idx[mask] = group_first[mask] 93 | return group_idx 94 | 95 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 96 | """ 97 | Input: 98 | npoint: 99 | radius: 100 | nsample: 101 | xyz: input points position data, [B, N, 3] 102 | points: input points data, [B, N, D] 103 | Return: 104 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 105 | new_points: sampled points data, [B, npoint, nsample, 3+D] 106 | """ 107 | B, N, C = xyz.shape 108 | S = npoint 109 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 110 | torch.cuda.empty_cache() 111 | new_xyz = index_points(xyz, fps_idx) 112 | torch.cuda.empty_cache() 113 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 114 | torch.cuda.empty_cache() 115 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 116 | torch.cuda.empty_cache() 117 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 118 | torch.cuda.empty_cache() 119 | 120 | if points is not None: 121 | grouped_points = index_points(points, idx) 122 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 123 | else: 124 | new_points = grouped_xyz_norm 125 | if returnfps: 126 | return new_xyz, new_points, grouped_xyz, fps_idx 127 | else: 128 | return new_xyz, new_points 129 | 130 | def sample_and_group_all(xyz, points): 131 | """ 132 | Input: 133 | xyz: input points position data, [B, N, 3] 134 | points: input points data, [B, N, D] 135 | Return: 136 | new_xyz: sampled points position data, [B, 1, 3] 137 | new_points: sampled points data, [B, 1, N, 3+D] 138 | """ 139 | device = xyz.device 140 | B, N, C = xyz.shape 141 | new_xyz = torch.zeros(B, 1, C).to(device) 142 | grouped_xyz = xyz.unsqueeze(1) 143 | if points is not None: 144 | new_points = torch.cat([grouped_xyz, points.unsqueeze(1)], dim=-1) 145 | else: 146 | new_points = grouped_xyz 147 | return new_xyz, new_points 148 | 149 | class PointNetSetAbstraction(nn.Module): 150 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 151 | super(PointNetSetAbstraction, self).__init__() 152 | self.npoint = npoint 153 | self.radius = radius 154 | self.nsample = nsample 155 | self.mlp_convs = nn.ModuleList() 156 | self.mlp_bns = nn.ModuleList() 157 | last_channel = in_channel 158 | for out_channel in mlp: 159 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 160 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 161 | last_channel = out_channel 162 | self.group_all = group_all 163 | 164 | def forward(self, xyz, points): 165 | """ 166 | Input: 167 | xyz: input points position data, [B, C, N] 168 | points: input points data, [B, D, N] 169 | Return: 170 | new_xyz: sampled points position data, [B, C, S] 171 | new_points_concat: sample points feature data, [B, D', S] 172 | """ 173 | xyz = xyz.permute(0, 2, 1) 174 | if points is not None: 175 | 
points = points.permute(0, 2, 1) 176 | 177 | if self.group_all: 178 | new_xyz, new_points = sample_and_group_all(xyz, points) 179 | else: 180 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 181 | # new_xyz: sampled points position data, [B, npoint, C] 182 | # new_points: sampled points data, [B, npoint, nsample, C+D] 183 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 184 | for i, conv in enumerate(self.mlp_convs): 185 | bn = self.mlp_bns[i] 186 | new_points = F.relu(bn(conv(new_points))) 187 | 188 | new_points = torch.max(new_points, 2)[0] 189 | new_xyz = new_xyz.permute(0, 2, 1) 190 | return new_xyz, new_points 191 | 192 | class PointNetSetAbstractionMsg(nn.Module): 193 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 194 | super(PointNetSetAbstractionMsg, self).__init__() 195 | self.npoint = npoint 196 | self.radius_list = radius_list 197 | self.nsample_list = nsample_list 198 | self.conv_blocks = nn.ModuleList() 199 | self.bn_blocks = nn.ModuleList() 200 | for i in range(len(mlp_list)): 201 | convs = nn.ModuleList() 202 | bns = nn.ModuleList() 203 | last_channel = in_channel + 3 204 | for out_channel in mlp_list[i]: 205 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 206 | bns.append(nn.BatchNorm2d(out_channel)) 207 | last_channel = out_channel 208 | self.conv_blocks.append(convs) 209 | self.bn_blocks.append(bns) 210 | 211 | def forward(self, xyz, points): 212 | """ 213 | Input: 214 | xyz: input points position data, [B, C, N] 215 | points: input points data, [B, D, N] 216 | Return: 217 | new_xyz: sampled points position data, [B, C, S] 218 | new_points_concat: sample points feature data, [B, D', S] 219 | """ 220 | xyz = xyz.permute(0, 2, 1) 221 | if points is not None: 222 | points = points.permute(0, 2, 1) 223 | 224 | B, N, C = xyz.shape 225 | S = self.npoint 226 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 227 | new_points_list = [] 228 | for i, radius in enumerate(self.radius_list): 229 | K = self.nsample_list[i] 230 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 231 | grouped_xyz = index_points(xyz, group_idx) 232 | grouped_xyz -= new_xyz.view(B, S, 1, C) 233 | if points is not None: 234 | grouped_points = index_points(points, group_idx) 235 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 236 | else: 237 | grouped_points = grouped_xyz 238 | 239 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 240 | for j in range(len(self.conv_blocks[i])): 241 | conv = self.conv_blocks[i][j] 242 | bn = self.bn_blocks[i][j] 243 | grouped_points = F.relu(bn(conv(grouped_points))) 244 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 245 | new_points_list.append(new_points) 246 | 247 | new_xyz = new_xyz.permute(0, 2, 1) 248 | new_points_concat = torch.cat(new_points_list, dim=1) 249 | return new_xyz, new_points_concat 250 | 251 | -------------------------------------------------------------------------------- /pose_generation/models/uninet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | 7 | from pointnet_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction 8 | 9 | class FeatureExtraction(nn.Module): 10 | def __init__(self, normal_channel=False, mask_channel=False, out_dim=256): 11 | super(FeatureExtraction, self).__init__() 12 | 
in_channel = 0 13 | if mask_channel: 14 | in_channel += 1 15 | if normal_channel: 16 | in_channel += 3 17 | 18 | self.ext_channel = in_channel 19 | 20 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 21 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320, [[64, 64, 128], [128, 128, 256], [128, 128, 256]]) 22 | self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True) 23 | 24 | self.fc1 = nn.Linear(1024, 512) 25 | self.bn1 = nn.BatchNorm1d(512) 26 | self.fc2 = nn.Linear(512, out_dim) 27 | self.bn2 = nn.BatchNorm1d(out_dim) 28 | 29 | def forward(self, xyz): 30 | B, C, N = xyz.shape 31 | if self.ext_channel > 0: 32 | norm = xyz[:, 3:, :] 33 | xyz = xyz[:, :3, :] 34 | else: 35 | norm = None 36 | 37 | l1_xyz, l1_points = self.sa1(xyz, norm) 38 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 39 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 40 | 41 | x = l3_points.view(B, -1) 42 | 43 | x = F.relu(self.bn1(self.fc1(x))) 44 | point_feat = F.relu(self.bn2(self.fc2(x))) 45 | return point_feat, l3_points 46 | 47 | class UniNet(nn.Module): 48 | def __init__(self, num_class=2, feat_dim=512, mask_channel=False, normal_channel=False): 49 | super(UniNet, self).__init__() 50 | 51 | self.feat_ext = FeatureExtraction(normal_channel=normal_channel, 52 | mask_channel=mask_channel, 53 | out_dim=feat_dim) 54 | 55 | self.fc1 = nn.Linear(feat_dim, 256) 56 | self.bn1 = nn.BatchNorm1d(256) 57 | self.fc2 = nn.Linear(256, 128) 58 | self.bn2 = nn.BatchNorm1d(128) 59 | self.drop = nn.Dropout(0.3) # drop 30% during training 60 | self.cls = nn.Linear(128, num_class) 61 | # loss function 62 | self.CELoss = nn.CrossEntropyLoss(reduction='none') 63 | 64 | def forward(self, points): 65 | feat, _ = self.feat_ext(points) 66 | feat1 = F.relu(self.bn1(self.fc1(feat))) 67 | feat2 = F.relu(self.bn2(self.fc2(feat1))) 68 | feat2 = self.drop(feat2) 69 | pred = self.cls(feat2) 70 | 71 | return pred 72 | 73 | def get_loss(self, pred_logits, gt_labels): 74 | loss = self.CELoss(pred_logits, gt_labels) 75 | return loss 76 | 77 | 78 | if __name__ == '__main__': 79 | # setting 1 80 | point_w_mask = torch.ones((2, 4, 513)) 81 | model1 = UniNet(mask_channel=True) 82 | pred = model1(point_w_mask) 83 | print(pred.shape) 84 | -------------------------------------------------------------------------------- /pose_generation/models/vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | import pytorch3d.transforms as torch_transform 7 | 8 | from .pointnet_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction 9 | 10 | def assert_not_nan(x, info): 11 | assert not torch.isnan(x).sum(), info 12 | 13 | class FeatureExtraction(nn.Module): 14 | def __init__(self, normal_channel=False, mask_channel=False, out_dim=128): 15 | super(FeatureExtraction, self).__init__() 16 | in_channel = 0 17 | if mask_channel: 18 | in_channel += 1 19 | if normal_channel: 20 | in_channel += 3 21 | 22 | self.ext_channel = in_channel 23 | 24 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel, [[16, 16, 32], [32, 32, 64], [32, 32, 64]]) 25 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 160, [[32, 32, 64], [64, 64, 128], [64, 64, 128]]) 26 | self.sa3 = PointNetSetAbstraction(None, None, None, 320 
+ 3, [128, 256, 512], True) 27 | 28 | self.fc1 = nn.Linear(512, 256) 29 | self.bn1 = nn.BatchNorm1d(256) 30 | self.fc2 = nn.Linear(256, out_dim) 31 | self.bn2 = nn.BatchNorm1d(out_dim) 32 | 33 | def forward(self, xyz): 34 | B, C, N = xyz.shape 35 | if self.ext_channel > 0: 36 | norm = xyz[:, 3:, :] 37 | xyz = xyz[:, :3, :] 38 | else: 39 | norm = None 40 | 41 | l1_xyz, l1_points = self.sa1(xyz, norm) 42 | 43 | assert_not_nan(l1_xyz, '!l1_xyz, xyz: {} {}'.format( 44 | xyz.min(), xyz.max())) 45 | assert_not_nan(l1_points, '!l1_points, xyz: {} {}'.format( 46 | xyz.min(), xyz.max() 47 | )) 48 | 49 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 50 | 51 | assert_not_nan(l2_xyz, '!l2_xyz, l1_xyz: {} {}, l1_point: {} {}'.format( 52 | l1_xyz.min(), l1_xyz.max(), l1_points.min(), l1_points.max())) 53 | 54 | assert_not_nan(l2_points, '!l2_point, l1_xyz: {} {}, l1_point: {} {}'.format( 55 | l1_xyz.min(), l1_xyz.max(), l1_points.min(), l1_points.max())) 56 | 57 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 58 | 59 | assert_not_nan(l3_xyz, '!l3_xyz, l2_xyz: {} {}, l2_point: {} {}'.format( 60 | l2_xyz.min(), l2_xyz.max(), l2_points.min(), l2_points.max())) 61 | 62 | assert_not_nan(l3_points, '!l3_points, l2_xyz: {} {}, l2_point: {} {}'.format( 63 | l2_xyz.min(), l2_xyz.max(), l2_points.min(), l2_points.max())) 64 | 65 | x = l3_points.view(B, -1) 66 | 67 | x = F.relu(self.bn1(self.fc1(x))) 68 | 69 | assert_not_nan(x, '!pre_l, l3: {} {}'.format( 70 | l3_points.min(), l3_points.max() 71 | )) 72 | 73 | point_feat = F.relu(self.bn2(self.fc2(x))) 74 | 75 | assert_not_nan(point_feat, '!point feat, x: {} {}'.format( 76 | x.min(), x.max() 77 | )) 78 | 79 | return point_feat, l3_points 80 | 81 | class VNet(nn.Module): 82 | def __init__(self, sup_feat=128, obj_feat=128, z_feat=64, z_dim=3, rot_rep='6d', 83 | mask_channel=False, normal_channel=False, only_test=False): 84 | super(VNet, self).__init__() 85 | 86 | self.rotation_rep = rot_rep 87 | 88 | self.sup_feat_ext = FeatureExtraction(normal_channel=normal_channel, 89 | mask_channel=mask_channel, 90 | out_dim=sup_feat) 91 | self.obj_feat_ext = FeatureExtraction(normal_channel=normal_channel, 92 | mask_channel=mask_channel, 93 | out_dim=obj_feat) 94 | 95 | self.z_feat_ext = nn.Sequential(*[nn.Conv1d(z_dim, 32, kernel_size=(1, )), 96 | nn.BatchNorm1d(32), 97 | nn.ReLU(), 98 | nn.Conv1d(32, z_feat, kernel_size=(1, )), 99 | nn.BatchNorm1d(z_feat), 100 | nn.ReLU()]) 101 | 102 | 103 | # feature fusing (input width must equal sup_feat + obj_feat + z_feat) 104 | self.fc1 = nn.Conv1d(sup_feat+obj_feat+z_feat, 128, kernel_size=(1, )) 105 | self.bn1 = nn.BatchNorm1d(128) 106 | self.fc2 = nn.Conv1d(128, 64, kernel_size=(1, )) 107 | self.bn2 = nn.BatchNorm1d(64) 108 | 109 | self.drop = nn.Dropout(0.1) # drop 10% during training 110 | 111 | if rot_rep == 'axis_angle': 112 | d = 6 113 | elif rot_rep == '6d': 114 | d = 9 115 | else: 116 | raise NotImplementedError 117 | 118 | self.pose_reg = nn.Sequential(*[nn.Conv1d(64, 32, kernel_size=(1,)), 119 | nn.Conv1d(32, d, kernel_size=(1,))]) 120 | 121 | self.only_test = only_test 122 | 123 | def forward(self, samp_dict): 124 | ''' 125 | 126 | :param samp_dict: dict with 'support' (B, 3, N), 'object' (B, 3, N), 127 | 'z_noise' (B, C, M) and, during training, 'transforms' (B, M2, 4, 4) 128 | 129 | :return: dict with predicted transforms 'pred' (B, M, 4, 4), plus losses at train time 130 | ''' 131 | 132 | sup_points = samp_dict['support'] 133 | obj_points = samp_dict['object'] 134 | z = samp_dict['z_noise'] 135 | 136 | assert not torch.isnan(sup_points).sum(), '# {}, {} contain NaN'.format(samp_dict['sup_path'], sup_points.shape) 137 | assert not torch.isnan(obj_points).sum(), '# {}, {} contain 
NaN'.format(samp_dict['obj_path'], obj_points.shape) 138 | 139 | sup_feat, _ = self.sup_feat_ext(sup_points) # (B, C1) 140 | 141 | assert not torch.isnan(sup_feat).sum(), '# Support {} [{} {}] feature contain NaN.'.format(sup_points.shape, 142 | sup_points.min(), 143 | sup_points.max() 144 | ) 145 | 146 | obj_feat, _ = self.obj_feat_ext(obj_points) # (B, C2) 147 | 148 | assert not torch.isnan(obj_feat).sum(), '# Object {} [{} {}] feature contain NaN.'.format(obj_points.shape, 149 | obj_points.min(), 150 | obj_points.max() 151 | ) 152 | 153 | z_feat = self.z_feat_ext(z) # (B, C3, M) 154 | 155 | assert not torch.isnan(z_feat).sum(), '# Z feature contain NaN.' 156 | 157 | M = z_feat.shape[2] 158 | 159 | sup_feat_rpt = sup_feat.unsqueeze(2).repeat((1, 1, M)) 160 | obj_feat_rpt = obj_feat.unsqueeze(2).repeat((1, 1, M)) 161 | 162 | fuse_feat = torch.cat([sup_feat_rpt, obj_feat_rpt, z_feat], dim=1) 163 | feat1 = F.relu(self.bn1(self.fc1(fuse_feat))) 164 | 165 | assert not torch.isnan(feat1).sum(), '# Deep Feature1 contain NaN.' 166 | 167 | feat2 = F.relu(self.bn2(self.fc2(feat1))) 168 | 169 | assert not torch.isnan(feat2).sum(), '# Deep Feature2 contain NaN.' 170 | 171 | pred = self.pose_reg(feat2) # (B, 6, M) 172 | 173 | assert not torch.isnan(pred).sum(), '# Raw predictions contain NaN.' 174 | 175 | pred_transforms = self.compute_transforms(pred, self.rotation_rep) 176 | 177 | assert not torch.isnan(pred_transforms).sum(), '# Transforms contain NaN.' 178 | 179 | ret = {'pred': pred_transforms} 180 | if self.only_test: 181 | return ret 182 | 183 | p2g_loss, g2p_loss, pred_pc, selt_pc = self.get_projection_loss(pred_transforms=pred_transforms, 184 | gt_transforms=samp_dict['transforms'], 185 | object_pc=obj_points) 186 | 187 | loss = p2g_loss + g2p_loss 188 | ret['loss'] = loss 189 | ret['pred_pc'] = pred_pc 190 | ret['selt_pc'] = selt_pc 191 | 192 | assert not torch.isnan(loss).sum(), '# Loss contain NaN.' 
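# p2g_loss pulls each predicted placement toward its nearest ground-truth transform, while g2p_loss ensures every ground-truth transform is covered by some prediction; both measure the mean squared point-to-point distance of the transformed object cloud (see get_projection_loss).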
193 |         return ret
194 | 
195 |     def compute_transforms(self, pred, rep=''):
196 |         '''
197 |         :param pred: (B, d, M), d = 6 for axis-angle or 9 for the 6d rotation representation
198 |         :return: (B, M, 4, 4)
199 |         '''
200 |         B, d, M = pred.shape
201 |         pred = pred.permute(0, 2, 1) # (B, M, d)
202 |         pred_trs = pred[..., :3].unsqueeze(3) # (B, M, 3, 1)
203 | 
204 |         if rep == 'axis_angle':
205 |             assert d == 6, pred.shape
206 |             pred_rot = torch_transform.axis_angle_to_matrix(pred[..., 3:])
207 |         elif rep == '6d':
208 |             assert d == 9, pred.shape
209 |             pred_rot = torch_transform.rotation_6d_to_matrix(pred[..., 3:])
210 |         else:
211 |             raise NotImplementedError
212 | 
213 |         transform = torch.cat([pred_rot, pred_trs], dim=3) # (B, M, 3, 4)
214 |         ones = torch.tensor([0, 0, 0, 1],
215 |                             device=transform.device).view(1, 1, 1, 4)
216 |         ones = ones.repeat((B, M, 1, 1))
217 |         transform = torch.cat([transform, ones], dim=2) # (B, M, 4, 4)
218 |         return transform
219 | 
220 |     def get_projection_loss(self, pred_transforms: torch.Tensor, gt_transforms: torch.Tensor, object_pc: torch.Tensor):
221 |         '''
222 |         Bidirectional chamfer-style loss between predicted and GT object placements.
223 |         :param pred_transforms: (B, M1, 4, 4)
224 |         :param gt_transforms: (B, M2, 4, 4)
225 |         :param object_pc: (B, 3, N)
226 |         :return:
227 |         '''
228 | 
229 | 
230 |         N = object_pc.shape[2]
231 |         B = object_pc.shape[0]
232 |         M1 = pred_transforms.shape[1]
233 |         M2 = gt_transforms.shape[1]
234 | 
235 |         ones = torch.ones((B, 1, N), device=object_pc.device)
236 |         object_pc = torch.cat([object_pc, ones], dim=1).permute(0, 2, 1).unsqueeze(1).unsqueeze(4) # (B, 1, N, 4, 1)
237 | 
238 |         gt_object_ = torch.matmul(gt_transforms.unsqueeze(2), object_pc)[..., :3, 0] # (B, M2, N, 3)
239 |         pred_object_ = torch.matmul(pred_transforms.unsqueeze(2), object_pc)[..., :3, 0] # (B, M1, N, 3)
240 | 
241 |         pred_object = pred_object_.unsqueeze(2) # (B, M1, 1, N, 3)
242 |         gt_object = gt_object_.unsqueeze(1) # (B, 1, M2, N, 3)
243 | 
244 |         m1m2_dist = torch.sum((pred_object - gt_object) ** 2, dim=4, keepdim=False) # (B, M1, M2, N)
245 | 
246 |         m1m2_dist = torch.mean(m1m2_dist, dim=3, keepdim=False) # (B, M1, M2)
247 | 
248 |         p2g_loss, g_ind = torch.min(m1m2_dist, dim=2, keepdim=False) # (B, M1)
249 |         g2p_loss, p_ind = torch.min(m1m2_dist, dim=1, keepdim=False) # (B, M2)
250 | 
251 |         # arrange gt object using selected index
252 |         g_ind = g_ind.view((B, M1, 1, 1)).repeat((1, 1, N, 3))
253 |         selt_object = torch.gather(gt_object_, dim=1, index=g_ind) # (B, M1, N, 3)
254 | 
255 |         return torch.mean(p2g_loss), torch.mean(g2p_loss), pred_object_, selt_object
256 | 
257 |     def get_direct_loss(self, pred_transforms: torch.Tensor, gt_transforms: torch.Tensor, object_pc: torch.Tensor):
258 |         '''
259 |         Bidirectional L2 loss computed directly on the 4x4 transform entries.
260 |         :param pred_transforms: (B, M1, 4, 4)
261 |         :param gt_transforms: (B, M2, 4, 4)
262 |         :param object_pc: (B, 3, N)
263 |         :return:
264 |         '''
265 |         N = object_pc.shape[2]
266 |         B = object_pc.shape[0]
267 |         pred_transforms = pred_transforms.unsqueeze(2)
268 |         gt_transforms = gt_transforms.unsqueeze(1)
269 |         m1m2_dist = torch.sum(torch.sum((pred_transforms - gt_transforms) ** 2,
270 |                                         dim=4, keepdim=False), dim=3, keepdim=False) # (B, M1, M2)
271 | 
272 |         p2g_loss, _ = torch.min(m1m2_dist, dim=2, keepdim=False)
273 |         g2p_loss, _ = torch.min(m1m2_dist, dim=1, keepdim=False)
274 |         return torch.mean(p2g_loss), torch.mean(g2p_loss)
275 | 
276 | 
277 | if __name__ == '__main__':
278 |     import torch.optim as optim
279 | 
280 |     point1 = torch.ones((2, 3, 320)) # dummy support point cloud
281 |     point2 = torch.ones((2, 3, 200)) # dummy object point cloud
282 |     z_noise = torch.ones((2, 4, 7))
283 |     # setting 2
284 |     model = VNet(mask_channel=False, rot_rep='6d',
285 |                  z_dim=4, obj_feat=16, sup_feat=16)
286 | 
287 |     gt_trans = torch.eye(4).unsqueeze(0).unsqueeze(1).repeat((2, 3, 1, 1)) # (2, 3, 4, 4)
288 |     gt_trans[0, 0, 2, 3] = 1
289 |     print('gt shape: ', gt_trans.shape)
290 |     print(gt_trans[0, 0])
291 |     optimizer = optim.Adam(model.parameters(), lr=0.01)
292 | 
293 |     sample = {'support': point1, 'object': point2,
294 |               'z_noise': z_noise, 'transforms': gt_trans}
295 | 
296 |     for ep in range(1000):
297 |         optimizer.zero_grad()
298 |         ret = model.forward(sample)
299 |         tot_loss = ret['loss']
300 |         T = ret['pred']
301 |         print('step: ', ep)
302 |         print('loss: ', tot_loss)
303 |         print('T: ', T[0, 0])
304 |         print('pred_pc: ', ret['pred_pc'].shape)
305 |         print('selt_pc: ', ret['selt_pc'].shape)
306 |         tot_loss.backward()
307 |         optimizer.step()
308 | 
309 | 
--------------------------------------------------------------------------------
/pose_generation/train_impl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.backends.cudnn as cudnn
4 | import torch.optim as optim
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader
7 | from torch.distributions import MultivariateNormal
8 | import torch.nn.functional as F
9 | import os
10 | import numpy as np
11 | import open3d as o3d
12 | from tensorboardX import SummaryWriter
13 | 
14 | from data_loader.stable_pose_dataset import StablePoseDataset
15 | import argparse, sys, time, datetime
16 | import os.path as osp
17 | from utils.utils import get_linear_schedule_with_warmup, \
18 |     get_step_schedule_with_warmup, dict2cuda, add_summary, \
19 |     DictAverageMeter
20 | 
21 | cudnn.benchmark = True
22 | 
23 | parser = argparse.ArgumentParser(description='Train the stable pose generator (VNet).')
24 | parser.add_argument('--root_path', type=str, help='path to root directory.')
25 | parser.add_argument('--list_path', type=str, help='train scene list.', default='')
26 | parser.add_argument('--save_path', type=str, help='path to save checkpoints.')
27 | parser.add_argument('--restore_path', type=str, default='')
28 | parser.add_argument('--net_arch', type=str, default='vnet2')
29 | 
30 | parser.add_argument('--epochs', type=int, default=20)
31 | parser.add_argument('--lr', type=float, default=0.001)
32 | parser.add_argument('--lr_idx', type=str, default="0.95:0.9", help='LR milestones (as fractions of total steps) and decay gamma, "m1,m2:gamma".')
33 | parser.add_argument('--wd', type=float, default=0.0, help='weight decay')
34 | parser.add_argument('--batch_size', type=int, default=32)
35 | parser.add_argument('--z_dim', type=int, default=3)
36 | parser.add_argument('--pose_num', type=int, default=128)
37 | parser.add_argument('--rot_rp', type=str, default='6d')
38 | 
39 | parser.add_argument('--log_freq', type=int, default=1, help='print and summary frequency')
40 | parser.add_argument('--save_freq', type=int, default=2000, help='save checkpoint frequency.')
41 | 
42 | parser.add_argument('--sync_bn', action='store_true', help='Sync BN.')
43 | parser.add_argument('--opt_level', type=str, default="O0")
44 | parser.add_argument('--seed', type=int, default=0)
45 | parser.add_argument('--local_rank', type=int, default=0)
46 | parser.add_argument('--num_workers', type=int, default=4)
47 | parser.add_argument('--distributed', action='store_true')
48 | 
49 | args = parser.parse_args()
50 | 
51 | # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
52 | is_distributed = args.distributed
53 | 
54 | torch.manual_seed(args.seed)
55 | torch.cuda.manual_seed(args.seed)
56 | device = torch.device("cuda")
57 | 
58 | if args.sync_bn:
59 |     import apex
60 |     import apex.amp as amp
61 | 
62 | def print_func(data: dict, prefix: str = ''):
63 |     for k, v in data.items():
64 |         if isinstance(v, dict):
65 |             print_func(v, prefix + '.' + k)
66 |         elif isinstance(v, list):
67 |             print(prefix+'.'+k, v)
68 |         else:
69 |             print(prefix+'.'+k, v.shape)
70 | 
71 | def sample_from_gaussian(d, batch_size, num_samples):
72 |     m = MultivariateNormal(torch.zeros(d), torch.eye(d))
73 |     z_noise = m.sample((batch_size, num_samples))
74 |     z_noise = z_noise.permute(0, 2, 1) # (B, d, M)
75 |     return z_noise
76 | 
77 | def add_point_cloud(pred_pc, gt_pc, sup_pc, logger, step, flag):
78 |     t2n = lambda x: x.detach().cpu().numpy()
79 |     M = pred_pc.shape[1]
80 |     N1 = pred_pc.shape[2]
81 |     pred_pc = t2n(pred_pc[0])
82 |     gt_pc = t2n(gt_pc[0])
83 | 
84 |     sup_pc: np.ndarray = t2n(sup_pc[0]).T # (N, 3)
85 |     sup_pc = np.repeat(np.expand_dims(sup_pc, axis=0), repeats=M, axis=0)
86 |     N2 = sup_pc.shape[1]
87 | 
88 |     pred_tags = ['pred_sample_{}'.format(x) for x in range(M)]
89 |     gt_tags = ['gt_sample_{}'.format(x) for x in range(M)]
90 | 
91 |     pred_pc = np.concatenate([pred_pc, sup_pc], axis=1)
92 |     gt_pc = np.concatenate([gt_pc, sup_pc], axis=1)
93 | 
94 |     gt_colors = np.zeros((M, N1+N2, 3))
95 |     gt_colors[:, :N1, 1] = 255
96 |     gt_colors[:, N1:, 0] = 100
97 |     gt_colors[:, N1:, 1] = 100
98 |     gt_colors[:, N1:, 2] = 200
99 | 
100 |     pred_colors = np.zeros((M, N1+N2, 3))
101 |     pred_colors[:, :N1, 1] = 80
102 |     pred_colors[:, N1:, 0] = 100
103 |     pred_colors[:, N1:, 1] = 100
104 |     pred_colors[:, N1:, 2] = 200
105 | 
106 |     add_summary([{'type': 'points', 'tags': pred_tags,
107 |                   'vals': [pred_pc, pred_colors]},
108 |                  {'type': 'points', 'tags': gt_tags,
109 |                   'vals': [gt_pc, gt_colors]}],
110 |                 logger=logger, step=step, flag=flag, max_disp=2)
111 | 
112 | def main(args, model, optimizer, scheduler, train_loader, train_sampler, start_step=0):
113 | 
114 |     train_step = start_step
115 |     start_ep = start_step // len(train_loader)
116 | 
117 |     model.train()
118 |     for ep in range(start_ep, args.epochs):
119 |         np.random.seed()
120 |         train_scores = DictAverageMeter()
121 |         if train_sampler is not None:
122 |             train_sampler.set_epoch(ep)
123 | 
124 |         for batch_idx, sample in enumerate(train_loader):
125 |             tic = time.time()
126 |             sample['z_noise'] = sample_from_gaussian(args.z_dim,
127 |                                                      args.batch_size,
128 |                                                      args.pose_num)
129 |             sample_cuda = dict2cuda(sample)
130 | 
131 |             # print_func(sample_cuda)
132 |             optimizer.zero_grad()
133 |             ret = model(sample_cuda)
134 |             loss = ret['loss']
135 | 
136 |             # print_func(ret)
137 |             if is_distributed and args.sync_bn:
138 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
139 |                     scaled_loss.backward()
140 |             else:
141 |                 loss.backward()
142 | 
143 |             optimizer.step()
144 |             scheduler.step()
145 | 
146 |             train_scores.update({'loss': loss.item()})
147 | 
148 |             train_step += 1
149 |             if train_step % args.log_freq == 0:
150 |                 avg_stat = train_scores.mean()
151 |                 print("[Rank: {}] time={:.2f} Epoch {}/{}, Iter {}/{}, lr {:.6f}, stats: {}".format(
152 |                     args.local_rank, time.time() - tic,
153 |                     ep+1, args.epochs, batch_idx+1, len(train_loader),
154 |                     optimizer.param_groups[0]["lr"],
155 |                     avg_stat))
156 |                 if on_main:
157 |                     add_point_cloud(ret['pred_pc'], ret['selt_pc'], sample['support'],
158 |                                     logger=logger, step=train_step, flag='train')
159 |                     add_summary([{'type': 'scalars', 'tags': list(avg_stat.keys()),
160 |                                   'vals': list(avg_stat.values())}],
161 |                                 logger=logger, step=train_step, flag='train')
162 | 
163 |             if on_main and train_step % args.save_freq == 0:
164 |                 torch.save({"step": train_step,
165 |                             "model": model.module.state_dict(),
166 |                             "optimizer": optimizer.state_dict(),
167 |                             "scheduler": scheduler.state_dict(),
168 |                             },
169 |                            "{}/model_{:08d}.ckpt".format(args.save_path, train_step))
170 | 
171 | def distribute_model(args):
172 |     def sync():
173 |         if not dist.is_available():
174 |             return
175 |         if not dist.is_initialized():
176 |             return
177 |         if dist.get_world_size() == 1:
178 |             return
179 |         dist.barrier()
180 |     if is_distributed:
181 |         torch.cuda.set_device(args.local_rank)
182 |         torch.distributed.init_process_group(
183 |             backend="nccl", init_method="env://"
184 |         )
185 |         sync()
186 | 
187 |     start_step = 0
188 | 
189 |     if args.net_arch == 'vnet2':
190 |         from models.vnet import VNet
191 |         print('Using net_arch vnet2.')
192 |     else:
193 |         raise NotImplementedError
194 | 
195 |     model: torch.nn.Module = VNet(mask_channel=False, rot_rep=args.rot_rp,
196 |                                   z_dim=args.z_dim, obj_feat=128, sup_feat=128, z_feat=64)
197 |     if args.restore_path:
198 |         checkpoint = torch.load(args.restore_path, map_location=torch.device("cpu"))
199 |         model.load_state_dict(checkpoint['model'], strict=True)
200 | 
201 |     model.to(device)
202 | 
203 |     optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999),
204 |                            weight_decay=args.wd)
205 | 
206 |     train_set = StablePoseDataset(root_dir=args.root_path, list_path=args.list_path,
207 |                                   pose_num=args.pose_num, pc_len=1024, use_aug=True)
208 | 
209 |     if is_distributed:
210 |         if args.sync_bn:
211 |             model = apex.parallel.convert_syncbn_model(model)
212 |             model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, )
213 |             print('Converted BN to SyncBN successfully.')
214 | 
215 |         model = torch.nn.parallel.DistributedDataParallel(
216 |             model, device_ids=[args.local_rank], output_device=args.local_rank,)
217 | 
218 |         train_sampler = torch.utils.data.DistributedSampler(train_set, num_replicas=dist.get_world_size(),
219 |                                                             rank=dist.get_rank())
220 | 
221 |     else:
222 |         model = nn.DataParallel(model)
223 |         train_sampler = None
224 | 
225 |     def worker_init_fn(worker_id):
226 |         np.random.seed(np.random.get_state()[1][0] + worker_id)
227 | 
228 |     train_loader = DataLoader(train_set, args.batch_size, sampler=train_sampler,
229 |                               num_workers=args.num_workers, pin_memory=True,
230 |                               drop_last=True, shuffle=not is_distributed, worker_init_fn=worker_init_fn)
231 | 
232 |     milestones = list(map(float, args.lr_idx.split(':')[0].split(','))) # fractions of total training steps
233 |     assert np.max(milestones) <= 1.0, milestones
234 |     milestones = list(map(lambda x: int(float(x) * float(len(train_loader) * args.epochs)), milestones))
235 |     gamma = float(args.lr_idx.split(':')[1])
236 |     warmup_iters = min(500, int(0.05*len(train_loader))) # warm up for at most 500 iterations
237 | 
238 |     scheduler = get_step_schedule_with_warmup(optimizer=optimizer, milestones=milestones,
239 |                                               gamma=gamma, warmup_iters=warmup_iters)
240 | 
241 |     if args.restore_path:
242 |         print("Restoring checkpoint {} ...".format(args.restore_path))
243 |         optimizer.load_state_dict(checkpoint['optimizer'])
244 |         scheduler.load_state_dict(checkpoint['scheduler'])
245 |         start_step = checkpoint['step']
246 | 
247 |     return model, optimizer, scheduler, train_loader, train_sampler, start_step
248 | 
249 | 
250 | if __name__ == '__main__':
251 |     model, optimizer, scheduler, train_loader, train_sampler, start_step = distribute_model(args)
252 |     on_main = (not is_distributed) or (dist.get_rank() == 0)
253 |     if on_main:
254 |         os.makedirs(args.save_path, exist_ok=True)
255 |         logger = 
SummaryWriter(args.save_path)
256 |         print(args)
257 | 
258 |     main(args=args, model=model, optimizer=optimizer, scheduler=scheduler,
259 |          train_loader=train_loader, train_sampler=train_sampler, start_step=start_step)
260 | 
--------------------------------------------------------------------------------
/pose_generation/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import LambdaLR
3 | import torchvision.utils as vutils
4 | import torch.distributed as dist
5 | 
6 | import errno
7 | import os
8 | import re
9 | import sys
10 | import numpy as np
11 | from bisect import bisect_right
12 | 
13 | 
14 | def dict2cuda(data: dict):
15 |     new_dic = {}
16 |     for k, v in data.items():
17 |         if isinstance(v, dict):
18 |             v = dict2cuda(v)
19 |         elif isinstance(v, torch.Tensor):
20 |             v = v.cuda()
21 |         new_dic[k] = v
22 |     return new_dic
23 | 
24 | def dict2numpy(data: dict):
25 |     new_dic = {}
26 |     for k, v in data.items():
27 |         if isinstance(v, dict):
28 |             v = dict2numpy(v)
29 |         elif isinstance(v, torch.Tensor):
30 |             v = v.detach().cpu().numpy().copy()
31 |         new_dic[k] = v
32 |     return new_dic
33 | 
34 | def dict2float(data: dict):
35 |     new_dic = {}
36 |     for k, v in data.items():
37 |         if isinstance(v, dict):
38 |             v = dict2float(v)
39 |         elif isinstance(v, torch.Tensor):
40 |             v = v.detach().cpu().item()
41 |         new_dic[k] = v
42 |     return new_dic
43 | 
44 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
45 |     """ Create a schedule where the learning rate increases linearly during
46 |     the warmup period and then decreases linearly to zero.
47 |     """
48 |     def lr_lambda(current_step):
49 |         if current_step < num_warmup_steps:
50 |             return float(current_step) / float(max(1, num_warmup_steps))
51 |         return max(
52 |             0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
53 |         )
54 | 
55 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
56 | 
57 | def get_step_schedule_with_warmup(optimizer, milestones, gamma=0.1, warmup_factor=1.0/3, warmup_iters=500, last_epoch=-1,):
58 |     def lr_lambda(current_step):
59 |         if current_step < warmup_iters:
60 |             alpha = float(current_step) / warmup_iters
61 |             current_factor = warmup_factor * (1. - alpha) + alpha # ramps from warmup_factor up to 1
62 |         else:
63 |             current_factor = 1.
64 | 65 | return max(0.0, current_factor * (gamma ** bisect_right(milestones, current_step))) 66 | 67 | return LambdaLR(optimizer, lr_lambda, last_epoch) 68 | 69 | def add_summary(data_items: list, logger, step: int, flag: str, max_disp=2): 70 | for data_item in data_items: 71 | tags = data_item['tags'] 72 | vals = data_item['vals'] 73 | dtype = data_item['type'] 74 | if dtype == 'points': 75 | b = min(max_disp, len(tags)) 76 | logger.add_mesh('{}/{}'.format(flag, tags[:b]), 77 | vertices=vals[0][:b], colors=vals[1][:b], global_step=step) 78 | 79 | elif dtype == 'scalars': 80 | for tag, val in zip(tags, vals): 81 | if val == 'None': 82 | val = 0 83 | logger.add_scalar('{}/{}'.format(flag, tag), 84 | val, global_step=step) 85 | else: 86 | raise NotImplementedError 87 | 88 | class DictAverageMeter(object): 89 | def __init__(self): 90 | self.data = {} 91 | 92 | def update(self, new_input: dict): 93 | for k, v in new_input.items(): 94 | if isinstance(v, list): 95 | self.data[k] = self.data.get(k, []) + v 96 | else: 97 | assert (isinstance(v, float) or isinstance(v, int)), type(v) 98 | self.data[k] = self.data.get(k, []) + [v] 99 | 100 | def mean(self): 101 | ret = {} 102 | for k, v in self.data.items(): 103 | if not v: 104 | ret[k] = 'None' 105 | else: 106 | ret[k] = np.round(np.mean(v), 4) 107 | return ret 108 | 109 | def reset(self): 110 | self.data = {} 111 | -------------------------------------------------------------------------------- /real_data/plys/bowl1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/bowl1.ply -------------------------------------------------------------------------------- /real_data/plys/pc_inworld_object_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/pc_inworld_object_1.ply -------------------------------------------------------------------------------- /real_data/plys/pc_inworld_object_2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/pc_inworld_object_2.ply -------------------------------------------------------------------------------- /real_data/plys/pc_inworld_object_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/pc_inworld_object_3.ply -------------------------------------------------------------------------------- /real_data/plys/pc_inworld_support_2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/pc_inworld_support_2.ply -------------------------------------------------------------------------------- /real_data/plys/pc_inworld_support_3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/touristCheng/Learning2Regrasp/8152c539af6538fd8f4b9fe328ec4ca314abd74c/real_data/plys/pc_inworld_support_3.ply -------------------------------------------------------------------------------- /real_data/real_data.txt: 
-------------------------------------------------------------------------------- 1 | pc_inworld_support_2-pc_inworld_object_2 2 | pc_inworld_support_3-pc_inworld_object_3 3 | bowl1-pc_inworld_object_1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | torch==1.8 3 | pytorch3d 4 | torchvision==0.9 5 | open3d 6 | dominate 7 | plyfile==0.7.3 8 | scikit-learn==0.24.1 9 | scipy==1.6.2 10 | pybullet 11 | trimesh 12 | -------------------------------------------------------------------------------- /scripts/evaluate_testset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import json 4 | from collections import OrderedDict 5 | import numpy as np 6 | import os.path as osp 7 | import matplotlib.pyplot as plt 8 | 9 | ####### 10 | 11 | 12 | g_ckpt="./checkpoints/generator/model_00368000.ckpt" 13 | c_ckpt="./checkpoints/classifier/model_00160000.ckpt" 14 | 15 | root_path="./dataset/test_plys" 16 | all_test_list = glob('./dataset/data_list/test_groups/*.txt') 17 | 18 | test_root_dir = './test_results/' 19 | 20 | ####### 21 | 22 | mesh_dir = './dataset/test_urdf' 23 | init_dir = './dataset/test_plys' 24 | total_rounds = range(1, 6) 25 | 26 | def eval_a_category(data_list, pred_dir, save_dir, mesh_dir, init_dir): 27 | data_pairs = open(data_list, 'r').readlines() 28 | data_pairs = [str(x).strip().split('-') for x in data_pairs] 29 | 30 | for sup_name, obj_name in data_pairs: 31 | sup_urdf = '{}/{}.urdf'.format(mesh_dir, sup_name) 32 | obj_urdf = '{}/{}.urdf'.format(mesh_dir, obj_name) 33 | init_sup_pose = '{}/{}_init_pose.txt'.format(init_dir, sup_name) 34 | init_obj_pose = '{}/{}_init_pose.txt'.format(init_dir, obj_name) 35 | pose_dir = '{}/{}-{}'.format(pred_dir, sup_name, obj_name) 36 | 37 | cmd = 'python3 evaluate_poses.py --obj_path {} --sup_path {} ' \ 38 | '--init_obj_pose {} --init_sup_pose {} --transforms {} --save_dir {}' \ 39 | .format(obj_urdf, sup_urdf, init_obj_pose, init_sup_pose, pose_dir, save_dir) 40 | os.system(cmd) 41 | 42 | def read_file(path): 43 | with open(path, 'r') as fd: 44 | data_list = json.load(fd, object_hook=OrderedDict) 45 | return data_list 46 | 47 | def write_file(path, data_list): 48 | dir_name = osp.dirname(path) 49 | if dir_name: 50 | os.makedirs(dir_name, exist_ok=True) 51 | with open(path, 'w') as f: 52 | json.dump(data_list, f) 53 | 54 | def plot(x_names, means, stds, ax, max_y=1, title=''): 55 | xs = np.arange(len(x_names)) 56 | ax.bar(xs, means, yerr=stds, align='center', alpha=0.5, ecolor='green', capsize=3) 57 | ax.set_xticks(xs, ) 58 | ax.set_xticklabels(x_names, rotation=75) 59 | ax.set_ylim(0, max_y) 60 | ax.yaxis.grid(True) 61 | ax.axhline(y=np.mean(means), color='r', linestyle='--') 62 | ax.set_ylabel(title) 63 | 64 | for x, val in zip(xs, means): 65 | ax.annotate("{:.3f}".format(val), 66 | xy = (x, 0.1), # top left corner of the histogram bar 67 | xytext = (0,0.2), # offsetting label position above its bar 68 | textcoords = "offset points", # Offset (in points) from the *xy* value 69 | ha = 'center', va = 'bottom' 70 | ) 71 | 72 | def load_eval_results(acc_stats, cnt_stats, root_dir, cat_name, thresh): 73 | all_paths = glob('{}/*.json'.format(root_dir, )) 74 | for path in all_paths: 75 | # pair_name = osp.basename(path).split('.')[0] 76 | # cat_name = pair_name # 77 | 78 | data = read_file(path)[0] 79 | acc_val = 
float(data['acc'][str(thresh)]) 80 | cnt_val = float(data['cnt'][str(thresh)]) 81 | acc_stats[cat_name] = acc_stats.get(cat_name, []) + [acc_val] 82 | cnt_stats[cat_name] = cnt_stats.get(cat_name, []) + [cnt_val] 83 | 84 | def eval_all(): 85 | 86 | for path in all_test_list: 87 | pair_name = osp.basename(path).split('.')[0] 88 | 89 | for i in total_rounds: 90 | testset_root = '{}/round_{}'.format(test_root_dir, i) 91 | save_dir = '{}/results/{}'.format(testset_root, pair_name) 92 | pred_dir = '{}/{}'.format(testset_root, pair_name) 93 | eval_a_category(data_list=path, 94 | pred_dir=pred_dir, 95 | save_dir=save_dir, 96 | mesh_dir=mesh_dir, 97 | init_dir=init_dir) 98 | 99 | def stats(thresh=0.8, ): 100 | acc_stats = {} 101 | cnt_stats = {} 102 | 103 | for path in all_test_list: 104 | pair_name = osp.basename(path).split('.')[0] 105 | for i in total_rounds: 106 | testset_root = '{}/round_{}'.format(test_root_dir, i) 107 | save_dir = '{}/results/{}'.format(testset_root, pair_name) 108 | load_eval_results(acc_stats=acc_stats, cnt_stats=cnt_stats, 109 | root_dir=save_dir, thresh=thresh, cat_name=pair_name) 110 | 111 | def process_stat(stat, ): 112 | x_names = sorted(list(stat.keys())) 113 | means = [] 114 | stds = [] 115 | for k in x_names: 116 | mean_k = np.mean(stat[k]) 117 | std_k = np.std(stat[k]) 118 | means.append(mean_k) 119 | stds.append(std_k) 120 | x_names = x_names + ['mean', ] 121 | mean_ = np.nanmean(means) 122 | std_ = np.nanstd(means) 123 | means.append(mean_) 124 | stds.append(std_) 125 | return x_names, means, stds 126 | 127 | x_names, acc_mean, acc_std = process_stat(acc_stats) 128 | _, cnt_mean, cnt_std = process_stat(cnt_stats) 129 | 130 | _, ax = plt.subplots(2, 1, figsize=(30, 10)) 131 | plot(x_names=x_names, means=acc_mean, stds=acc_std, ax=ax[0], title='accuracy') 132 | plot(x_names=x_names, means=cnt_mean, stds=cnt_std, ax=ax[1], max_y=128, title='diversity #') 133 | 134 | plt.tight_layout() 135 | plt.savefig('{}/hist_{}.pdf'.format(test_root_dir, thresh)) 136 | plt.show() 137 | 138 | def infer_all(): 139 | for path in all_test_list: 140 | pair_name = osp.basename(path).split('.')[0] 141 | for i in total_rounds: 142 | save_path="{}/round_{}/{}".format(test_root_dir, i, pair_name) 143 | os.makedirs(save_path, exist_ok=True) 144 | cmd = 'python3 inference.py --root_path {} --test_list {} --save_path {} ' \ 145 | '--generator_ckpt {} --stable_critic_ckpt {} --z_dim 3 --num_iter 1 ' \ 146 | '--pose_num 128 --rot_rp axis_angle --device cuda --render_ply' \ 147 | .format(root_path, path, save_path, g_ckpt, c_ckpt, ) 148 | os.system(cmd) 149 | 150 | def main(): 151 | infer_all() 152 | eval_all() 153 | stats(0.8) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() -------------------------------------------------------------------------------- /scripts/test_real_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | #export CUDA_VISIBLE_DEVICES=1 3 | 4 | root_path="./real_data/plys" 5 | test_list="./real_data/real_data.txt" 6 | ########## 7 | 8 | pose_num=128 9 | rot_rp="axis_angle" 10 | 11 | g_ckpt="./checkpoints/generator/model_00368000.ckpt" 12 | c_ckpt="./checkpoints/classifier/model_00160000.ckpt" 13 | 14 | save_path="./test_results/real_data" 15 | 16 | mkdir -p $save_path 17 | 18 | python inference.py \ 19 | --root_path $root_path \ 20 | --test_list $test_list \ 21 | --save_path $save_path \ 22 | --generator_ckpt $g_ckpt \ 23 | --stable_critic_ckpt $c_ckpt \ 24 | --z_dim 3 \ 25 | --num_iter 1 \ 26 | 
--pose_num $pose_num \ 27 | --rot_rp $rot_rp \ 28 | --device 'cpu' \ 29 | --real_data \ 30 | --filter \ 31 | --render 32 | 33 | -------------------------------------------------------------------------------- /scripts/train_multi_task.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | root_path="./dataset" 4 | train_list_path="./dataset/data_list/train_classifier.txt" 5 | test_list_path="./dataset/data_list/test_classifier.txt" 6 | save_path="./classifier_checkpoints/training_$(date +"%F-%T")" 7 | 8 | batch=32 9 | num_proc=2 10 | 11 | mkdir -p $save_path 12 | 13 | python -m torch.distributed.launch \ 14 | --nproc_per_node=$num_proc \ 15 | pose_check/train_multi_task_var_impl.py \ 16 | --root_path $root_path \ 17 | --train_list $train_list_path \ 18 | --val_list $test_list_path \ 19 | --save_path $save_path \ 20 | --batch_size $batch \ 21 | --epochs 10 \ 22 | --lr 0.001 \ 23 | --lr_idx "0.8:0.9" \ 24 | --sync_bn \ 25 | --num_workers 2 \ 26 | --distributed \ 27 | | tee -a $save_path/log.txt 28 | 29 | #ps -ef | grep train_impl | grep -v grep | cut -c 9-15 | xargs kill -9 30 | -------------------------------------------------------------------------------- /scripts/train_pose_generation.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | root_path="./dataset" 4 | list_path="./dataset/data_list/train_generator.txt" 5 | save_path="./generator_checkpoints/training_$(date +"%F-%T")" 6 | 7 | batch=4 8 | num_proc=2 9 | pose_num=128 10 | rot_rp="axis_angle" 11 | 12 | mkdir -p $save_path 13 | 14 | python -m torch.distributed.launch \ 15 | --nproc_per_node=$num_proc \ 16 | pose_generation/train_impl.py \ 17 | --root_path $root_path \ 18 | --list_path $list_path \ 19 | --save_path $save_path \ 20 | --net_arch "vnet2" \ 21 | --batch_size $batch \ 22 | --epochs 20000 \ 23 | --lr 0.001 \ 24 | --lr_idx "0.95:0.9" \ 25 | --z_dim 3 \ 26 | --pose_num $pose_num \ 27 | --rot_rp $rot_rp \ 28 | --sync_bn \ 29 | --num_workers 4 \ 30 | --save_freq 4000 \ 31 | --log_freq 500 \ 32 | --distributed \ 33 | | tee -a $save_path/log.txt 34 | 35 | #ps -ef | grep train_pose_generation | grep -v grep | cut -c 9-15 | xargs kill -9 36 | --------------------------------------------------------------------------------