├── assets └── network.png ├── common ├── metircs_util.py ├── torch_util.py ├── cache.py ├── misc_util.py ├── parallel_util.py ├── visualization_util.py ├── rendering_util.py └── geometry_util.py ├── networks ├── common.py ├── pointnet.py ├── residual_block.py ├── multihead_attention.py ├── resunet.py ├── transformer.py └── tracking_network.py ├── config ├── eval_tracking_default.yaml ├── predict_tracking_gt.yaml ├── predict_tracking_noise.yaml └── train_tracking_default.yaml ├── .gitignore ├── train_tracking.py ├── README.md ├── components ├── mlp.py ├── gridding.py └── unet3d.py ├── predict_tracking_gt.py ├── predict_tracking_noise.py └── eval_tracking.py /assets/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoxiaoxh/GarmentTracking/HEAD/assets/network.png -------------------------------------------------------------------------------- /common/metircs_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def pdist(A, B, dist_type='L2'): 3 | if dist_type == 'L2': 4 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 5 | return torch.sqrt(D2 + 1e-7) 6 | elif dist_type == 'SquareL2': 7 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 8 | else: 9 | raise NotImplementedError('Not implemented') 10 | -------------------------------------------------------------------------------- /networks/common.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | 3 | 4 | def get_norm(norm_type, num_feats, bn_momentum=0.05, D=-1): 5 | if norm_type == 'BN': 6 | return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) 7 | elif norm_type == 'IN': 8 | return ME.MinkowskiInstanceNorm(num_feats, dimension=D) 9 | else: 10 | raise ValueError(f'Type {norm_type}, not defined') 11 | -------------------------------------------------------------------------------- /common/torch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_geometric.data as tgd 4 | 5 | def to_numpy(x): 6 | return x.detach().to('cpu').numpy() 7 | 8 | def get_batch_size(obj): 9 | if isinstance(obj, torch.Tensor): 10 | return obj.shape[0] 11 | elif isinstance(obj, tgd.Batch): 12 | return obj.num_graphs 13 | else: 14 | raise TypeError("Unsupported Type") 15 | -------------------------------------------------------------------------------- /config/eval_tracking_default.yaml: -------------------------------------------------------------------------------- 1 | main: 2 | prediction_output_dir: ~/dev/garmentnets/outputs/2021-07-31/01-43-33 3 | # negative for using all cores avaliable 4 | num_workers: 1 5 | eval: 6 | compute_pc_metrics: 7 | enabled: True 8 | compute_chamfer: 9 | enabled: True 10 | compute_euclidian: 11 | enabled: True 12 | vis: 13 | random_sample_regular: False 14 | samples_per_instance: 50 15 | samples_per_video: 20 16 | first_samples_num: 200 17 | vis_sample_idxs_range: [0, 200] 18 | rank_metric: 'euclidian_sim' 19 | num_normal: 4 20 | num_best: 4 21 | num_worst: 10 22 | task_mesh_vis: 23 | offset: [0.8,0,0] 24 | nocs_mesh_vis: 25 | offset: [0.3,0,0] 26 | value_delta: 0.1 27 | nocs_pc_vis: 28 | offset: [1.0,0,0] 29 | save_point_cloud: False 30 | logger: 31 | mode: offline 32 | name: null 33 | tags: [] 34 | -------------------------------------------------------------------------------- /common/cache.py: 
-------------------------------------------------------------------------------- 1 | import pathlib 2 | import hashlib 3 | import pickle 4 | 5 | def file_attr_cache(target_file, cache_dir='~/local/.cache/file_attr_cache'): 6 | cache_dir_path = pathlib.Path(cache_dir).expanduser() 7 | target_file_path = pathlib.Path(target_file).expanduser() 8 | assert(target_file_path.exists()) 9 | target_key = hashlib.md5( 10 | str(target_file_path.absolute()).encode()).hexdigest() 11 | def decorator(func): 12 | def wrapped(*args, **kwargs): 13 | if not cache_dir_path.exists(): 14 | cache_dir_path.mkdir(parents=True, exist_ok=True) 15 | else: 16 | assert(cache_dir_path.is_dir()) 17 | cache_file_path = cache_dir_path.joinpath(target_key) 18 | if cache_file_path.exists(): 19 | target_time = target_file_path.stat().st_mtime 20 | cache_time = cache_file_path.stat().st_mtime 21 | if target_time < cache_time: 22 | # exist and older than target 23 | obj = pickle.load(cache_file_path.open('rb')) 24 | return obj 25 | 26 | # run function 27 | obj = func(*args, **kwargs) 28 | pickle.dump(obj, cache_file_path.open('wb')) 29 | return obj 30 | return wrapped 31 | return decorator 32 | -------------------------------------------------------------------------------- /networks/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MiniPointNetfeat(nn.Module): 7 | def __init__(self, nn_channels=(3, 64, 128, 256)): 8 | super(MiniPointNetfeat, self).__init__() 9 | self.nn_channels = nn_channels 10 | assert len(nn_channels) == 4 11 | self.conv1 = torch.nn.Conv1d(nn_channels[0], nn_channels[1], 1) 12 | self.conv2 = torch.nn.Conv1d(nn_channels[1], nn_channels[2], 1) 13 | self.conv3 = torch.nn.Conv1d(nn_channels[2], nn_channels[3], 1) 14 | self.bn1 = nn.BatchNorm1d(nn_channels[1]) 15 | self.bn2 = nn.BatchNorm1d(nn_channels[2]) 16 | self.bn3 = nn.BatchNorm1d(nn_channels[3]) 17 | 18 | def forward(self, x): 19 | """ 20 | 21 | :param x: (B, C, N) input points 22 | :return: global feature (B, C') or dense feature (B, C', N) 23 | """ 24 | n_pts = x.size()[2] 25 | x = F.relu(self.bn1(self.conv1(x))) 26 | 27 | pointfeat = x 28 | x = F.relu(self.bn2(self.conv2(x))) 29 | x = self.bn3(self.conv3(x)) 30 | x = torch.max(x, 2, keepdim=True)[0] # (B, C', 1) 31 | x = x.view(-1, self.nn_channels[-1]) # (B, C') 32 | global_feat = x 33 | x = x.view(-1, self.nn_channels[-1], 1).repeat(1, 1, n_pts) 34 | return torch.cat([x, pointfeat], dim=1), global_feat # (B, C', N), (B, C') 35 | -------------------------------------------------------------------------------- /common/misc_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from common.metircs_util import pdist 5 | 6 | 7 | def _hash(arr, M): 8 | if isinstance(arr, np.ndarray): 9 | N, D = arr.shape 10 | else: 11 | N, D = len(arr[0]), len(arr) 12 | 13 | hash_vec = np.zeros(N, dtype=np.int64) 14 | for d in range(D): 15 | if isinstance(arr, np.ndarray): 16 | hash_vec += arr[:, d] * M**d 17 | else: 18 | hash_vec += arr[d] * M**d 19 | return hash_vec 20 | 21 | 22 | def find_nn_gpu(F0, F1, nn_max_n=-1, return_distance=False, dist_type='SquareL2'): 23 | # Too much memory if F0 or F1 large. 
Divide the F0 24 | if nn_max_n > 1: 25 | N = len(F0) 26 | C = int(np.ceil(N / nn_max_n)) 27 | stride = nn_max_n 28 | dists, inds = [], [] 29 | for i in range(C): 30 | dist = pdist(F0[i * stride:(i + 1) * stride], F1, dist_type=dist_type) 31 | min_dist, ind = dist.min(dim=1) 32 | dists.append(min_dist.detach().unsqueeze(1).cpu()) 33 | inds.append(ind.cpu()) 34 | 35 | if C * stride < N: 36 | dist = pdist(F0[C * stride:], F1, dist_type=dist_type) 37 | min_dist, ind = dist.min(dim=1) 38 | dists.append(min_dist.detach().unsqueeze(1).cpu()) 39 | inds.append(ind.cpu()) 40 | 41 | dists = torch.cat(dists) 42 | inds = torch.cat(inds) 43 | assert len(inds) == N 44 | else: 45 | dist = pdist(F0, F1, dist_type=dist_type) 46 | min_dist, inds = dist.min(dim=1) 47 | dists = min_dist.detach().unsqueeze(1).cpu() 48 | inds = inds.cpu() 49 | if return_distance: 50 | return inds, dists 51 | else: 52 | return inds 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vscode 2 | .vscode 3 | 4 | # Pycharm 5 | .idea/ 6 | 7 | # MacOS 8 | .DS_Store 9 | 10 | # data 11 | data 12 | vis_data* 13 | wandb 14 | 15 | # python 16 | **/__pycache__ 17 | **/*.pyc 18 | **/.ipynb_checkpoints 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | 125 | # macOS 126 | .DS_Store 127 | 128 | # Output 129 | output/ 130 | outputs 131 | -------------------------------------------------------------------------------- /config/predict_tracking_gt.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | # must use the absolute path of dataset to work with hydra 3 | zarr_path: ~/dev/garmentnets/data/garmentnets_dataset.zarr/Dress 4 | metadata_cache_dir: ~/local/.cache/metadata_cache_dir 5 | batch_size: 1 6 | num_workers: 0 7 | # sample size 8 | num_pc_sample: 8000 9 | num_pc_sample_final: 4000 10 | num_volume_sample: 0 11 | num_surface_sample: 6000 12 | num_surface_sample_init: 10000 13 | # mixed sampling config 14 | surface_sample_ratio: 0 15 | surface_sample_std: 0.05 16 | # surface sample gt-sim points noise 17 | surface_gt_normal_noise_ratio: 0 18 | surface_gt_normal_std: 0.01 19 | # feature config 20 | use_rgb: True 21 | use_nocs_as_feature: False 22 | # voxelization config 23 | voxel_size: 0.0025 24 | # data augumentation 25 | enable_augumentation: False 26 | enable_zero_center: False 27 | num_views: 4 28 | pc_noise_std: 0 29 | use_pc_nocs_frame1_aug: False 30 | use_fist_frame_pc_nocs_aug_in_test: False 31 | pc_nocs_global_scale_aug_range: [0.8, 1.2] 32 | pc_nocs_global_max_offset_aug: 0.1 33 | pc_nocs_gaussian_std: 0 34 | use_mesh_nocs_aug: False 35 | use_fist_frame_mesh_nocs_aug_in_test: False 36 | mesh_nocs_global_scale_aug_range: [0.8, 1.2] 37 | mesh_nocs_global_max_offset_aug: 0 38 | # random seed 39 | static_epoch_seed: False 40 | # datamodule config 41 | dataset_split: [8,1,1] 42 | split_seed: 0 43 | remove_invalid_interval_in_train: False 44 | # first-frame fitting config 45 | alpha: 1000.0 46 | finetune_offset: [0., -0.03, 0.] 
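# note: as with training, any of the datamodule values above can be overridden from the hydra command line, e.g. `datamodule.num_workers=2` (illustrative override)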
47 | 48 | main: 49 | garmentnets_prediction_output_dir: ~/dev/garmentnets/outputs/2021-07-31/01-43-33 50 | checkpoint_path: ~/dev/garmentnets/data/garmentnets_checkpoints/pipeline_checkpoints/Dress_pipeline.ckpt 51 | gpu_id: 0 52 | prediction: 53 | # val or test 54 | subset: test 55 | volume_size: 128 56 | max_refine_mesh_step: 0 57 | use_valid_grip_interval: True 58 | use_cross_interval_tracking: True 59 | use_garmentnets_prediction: False 60 | disable_mesh_nocs_refine_in_test: False 61 | disable_pc_nocs_refine_in_test: False 62 | alpha: 1000.0 63 | value_threshold: 0.128 64 | debug: False 65 | logger: 66 | mode: offline 67 | name: null 68 | tags: [] 69 | -------------------------------------------------------------------------------- /config/predict_tracking_noise.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | # must use the absolute path of dataset to work with hydra 3 | zarr_path: ~/dev/garmentnets/data/garmentnets_dataset.zarr/Dress 4 | metadata_cache_dir: ~/local/.cache/metadata_cache_dir 5 | batch_size: 1 6 | num_workers: 0 7 | # sample size 8 | num_pc_sample: 8000 9 | num_pc_sample_final: 4000 10 | num_volume_sample: 0 11 | num_surface_sample: 6000 12 | num_surface_sample_init: 10000 13 | # mixed sampling config 14 | surface_sample_ratio: 0 15 | surface_sample_std: 0.05 16 | # surface sample gt-sim points noise 17 | surface_gt_normal_noise_ratio: 0 18 | surface_gt_normal_std: 0.01 19 | # feature config 20 | use_rgb: True 21 | use_nocs_as_feature: False 22 | # voxelization config 23 | voxel_size: 0.0025 24 | # data augumentation 25 | enable_augumentation: False 26 | enable_zero_center: False 27 | num_views: 4 28 | pc_noise_std: 0 29 | use_pc_nocs_frame1_aug: True 30 | use_fist_frame_pc_nocs_aug_in_test: True 31 | pc_nocs_global_scale_aug_range: [0.8, 1.2] 32 | pc_nocs_global_max_offset_aug: 0.1 33 | pc_nocs_gaussian_std: 0.05 34 | use_mesh_nocs_aug: True 35 | use_fist_frame_mesh_nocs_aug_in_test: True 36 | mesh_nocs_global_scale_aug_range: [0.8, 1.2] 37 | mesh_nocs_global_max_offset_aug: 0 38 | # random seed 39 | static_epoch_seed: False 40 | # datamodule config 41 | dataset_split: [8,1,1] 42 | split_seed: 0 43 | remove_invalid_interval_in_train: False 44 | # first-frame fitting config 45 | alpha: 1000.0 46 | finetune_offset: [0., -0.03, 0.] 
47 | 48 | main: 49 | garmentnets_prediction_output_dir: ~/dev/garmentnets/outputs/2021-07-31/01-43-33 50 | checkpoint_path: ~/dev/garmentnets/data/garmentnets_checkpoints/pipeline_checkpoints/Dress_pipeline.ckpt 51 | gpu_id: 0 52 | prediction: 53 | # val or test 54 | subset: test 55 | volume_size: 128 56 | max_refine_mesh_step: 1 57 | use_valid_grip_interval: True 58 | use_cross_interval_tracking: True 59 | use_garmentnets_prediction: False 60 | disable_mesh_nocs_refine_in_test: False 61 | disable_pc_nocs_refine_in_test: False 62 | alpha: 1000.0 63 | value_threshold: 0.128 64 | debug: False 65 | logger: 66 | mode: offline 67 | name: null 68 | tags: [] 69 | -------------------------------------------------------------------------------- /train_tracking.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import 3 | import os 4 | import pathlib 5 | import yaml 6 | import hydra 7 | from omegaconf import DictConfig, OmegaConf 8 | import pytorch_lightning as pl 9 | 10 | from datasets.tracking_dataset import SparseUnet3DTrackingDataModule2 11 | from networks.tracking_network import GarmentTrackingPipeline 12 | 13 | # %% 14 | # main script 15 | @hydra.main(config_path="config", config_name="train_tracking_default") 16 | def main(cfg: DictConfig) -> None: 17 | # hydra creates working directory automatically 18 | print(os.getcwd()) 19 | os.mkdir("checkpoints") 20 | 21 | datamodule = SparseUnet3DTrackingDataModule2(**cfg.datamodule) 22 | batch_size = datamodule.kwargs['batch_size'] 23 | 24 | pipeline_model = GarmentTrackingPipeline( 25 | batch_size=batch_size, **cfg.garment_tracking_model) 26 | 27 | category = pathlib.Path(cfg.datamodule.zarr_path).stem 28 | cfg.logger.tags.append(category) 29 | logger = pl.loggers.WandbLogger( 30 | project=os.path.basename(__file__), 31 | **cfg.logger) 32 | # logger.watch(pipeline_model, **cfg.logger_watch) 33 | wandb_run = logger.experiment 34 | wandb_meta = { 35 | 'run_name': wandb_run.name, 36 | 'run_id': wandb_run.id 37 | } 38 | 39 | all_config = { 40 | 'config': OmegaConf.to_container(cfg, resolve=True), 41 | 'output_dir': os.getcwd(), 42 | 'wandb': wandb_meta 43 | } 44 | yaml.dump(all_config, open('config.yaml', 'w'), default_flow_style=False) 45 | logger.log_hyperparams(all_config) 46 | 47 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 48 | dirpath="checkpoints", 49 | filename="{epoch}-{val_loss:.4f}", 50 | monitor='val_loss', 51 | save_last=True, 52 | save_top_k=5, 53 | mode='min', 54 | save_weights_only=False, 55 | every_n_epochs=1, 56 | save_on_train_epoch_end=True) 57 | trainer = pl.Trainer( 58 | callbacks=[checkpoint_callback], 59 | checkpoint_callback=True, 60 | logger=logger, 61 | check_val_every_n_epoch=1, 62 | **cfg.trainer) 63 | trainer.fit(model=pipeline_model, datamodule=datamodule) 64 | 65 | # %% 66 | # driver 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /common/parallel_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import traceback 3 | from tqdm import tqdm 4 | 5 | import pandas as pd 6 | import dask 7 | import dask.bag as db 8 | from dask.diagnostics import ProgressBar 9 | 10 | # helper functions 11 | # =============== 12 | def interpret_num_workers(num_workers): 13 | if num_workers < 1: 14 | num_workers = multiprocessing.cpu_count() 15 | return num_workers 16 | 17 | def get_catch_all_warpper(func): 18 | def wrapper(*args, **kwargs): 
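# run the wrapped function inside try/except so a single failing item records its error and stack trace instead of aborting the whole parallel map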
19 | result = None 20 | err = None 21 | stack_trace = None 22 | try: 23 | result = func(*args, **kwargs) 24 | except Exception as e: 25 | err = e 26 | stack_trace = traceback.format_exc() 27 | return { 28 | 'result': result, 29 | 'error': err, 30 | 'stack_trace': stack_trace 31 | } 32 | return wrapper 33 | 34 | # high level API 35 | # ============== 36 | def parallel_map( 37 | func, sequence, 38 | num_workers=-1, 39 | scheduler="processes", 40 | include_input=False, 41 | preserve_index=True 42 | ): 43 | # process input 44 | num_workers = interpret_num_workers(num_workers) 45 | input_sequence = list(sequence) 46 | safe_func = get_catch_all_warpper(func) 47 | 48 | # map 49 | output_sequence = None 50 | if num_workers == 1: 51 | output_sequence = list() 52 | for x in tqdm(input_sequence): 53 | output_sequence.append(safe_func(x)) 54 | else: 55 | input_sequence_b = db.from_sequence(input_sequence) 56 | output_sequence_b = input_sequence_b.map(safe_func) 57 | with dask.config.set({ 58 | 'scheduler': scheduler, 59 | 'multiprocessing.context': 'fork', 60 | 'num_workers': num_workers 61 | }): 62 | with ProgressBar(): 63 | output_sequence = output_sequence_b.compute() 64 | 65 | # consolidate result 66 | index = None 67 | if isinstance(sequence, pd.Series) and preserve_index: 68 | index = sequence.index 69 | result_df = pd.DataFrame(output_sequence, 70 | columns=['result', 'error', 'stack_trace'], 71 | index=index) 72 | if include_input: 73 | result_df['input'] = input_sequence 74 | return result_df 75 | -------------------------------------------------------------------------------- /networks/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from networks.common import get_norm 4 | 5 | import MinkowskiEngine as ME 6 | import MinkowskiEngine.MinkowskiFunctional as MEF 7 | import pytorch_lightning as pl 8 | 9 | 10 | class BasicBlockBase(pl.LightningModule): 11 | expansion = 1 12 | NORM_TYPE = 'BN' 13 | 14 | def __init__(self, 15 | inplanes, 16 | planes, 17 | stride=1, 18 | dilation=1, 19 | downsample=None, 20 | bn_momentum=0.1, 21 | D=3): 22 | super(BasicBlockBase, self).__init__() 23 | self.save_hyperparameters() 24 | 25 | self.conv1 = ME.MinkowskiConvolution( 26 | inplanes, planes, kernel_size=3, stride=stride, dimension=D) 27 | self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 28 | self.conv2 = ME.MinkowskiConvolution( 29 | planes, 30 | planes, 31 | kernel_size=3, 32 | stride=1, 33 | dilation=dilation, 34 | bias=False, 35 | dimension=D) 36 | self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 37 | self.downsample = downsample 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.norm1(out) 44 | out = MEF.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.norm2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = MEF.relu(out) 54 | 55 | return out 56 | 57 | 58 | class BasicBlockBN(BasicBlockBase): 59 | NORM_TYPE = 'BN' 60 | 61 | 62 | class BasicBlockIN(BasicBlockBase): 63 | NORM_TYPE = 'IN' 64 | 65 | 66 | def get_block(norm_type, 67 | inplanes, 68 | planes, 69 | stride=1, 70 | dilation=1, 71 | downsample=None, 72 | bn_momentum=0.1, 73 | D=3): 74 | if norm_type == 'BN': 75 | return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) 76 | elif norm_type == 'IN': 77 | return BasicBlockIN(inplanes, planes, stride, dilation, 
downsample, bn_momentum, D) 78 | else: 79 | raise ValueError(f'Type {norm_type}, not defined') 80 | -------------------------------------------------------------------------------- /common/visualization_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from common.rendering_util import render_nocs, render_wnf, render_wnf_points, render_points_confidence 4 | 5 | # high-level API 6 | # ============== 7 | def overlay_grip(img, grip_nocs, color=(1,0,0,1), side='front', kernel_size=4): 8 | assert(img.shape[0] == img.shape[1]) 9 | img_size = img.shape[0] 10 | grip_img = render_nocs( 11 | points=np.expand_dims(grip_nocs, axis=0), colors=np.array([color]), 12 | side=side, img_size=img_size, kernel_size=kernel_size) 13 | is_grip = grip_img[:,:,3] > 0 14 | new_img = img.copy() 15 | new_img[is_grip] = grip_img[is_grip] 16 | return new_img 17 | 18 | 19 | def render_nocs_pair(gt_nocs, pred_nocs, gt_grip_nocs=None, 20 | pred_grip_nocs=None, pred_grip_nocs_nn=None, 21 | side='front', img_size=256, kernel_size=4): 22 | """ 23 | Both colored using gt_nocs's nocs as rgb 24 | """ 25 | 26 | dtype = gt_nocs.dtype 27 | colors = np.concatenate( 28 | [gt_nocs, np.ones( 29 | (len(gt_nocs), 1), dtype=dtype)], axis=1) 30 | gt_img = render_nocs( 31 | gt_nocs, colors=colors, 32 | side=side, img_size=img_size, kernel_size=kernel_size) 33 | pred_img = render_nocs( 34 | pred_nocs, colors=colors, 35 | side=side, img_size=img_size, kernel_size=kernel_size) 36 | if gt_grip_nocs is not None: 37 | gt_img = overlay_grip(gt_img, gt_grip_nocs, 38 | side=side, kernel_size=kernel_size*2) 39 | if pred_grip_nocs is not None: 40 | pred_img = overlay_grip(pred_img, pred_grip_nocs, 41 | side=side, kernel_size=kernel_size*2) 42 | if pred_grip_nocs_nn is not None: 43 | pred_img = overlay_grip(pred_img, pred_grip_nocs_nn, color=(0,1,0,1), 44 | side=side, kernel_size=kernel_size*2) 45 | pair_img = np.concatenate([gt_img, pred_img], axis=1) 46 | return pair_img 47 | 48 | 49 | def render_confidence_pair(gt_nocs, pred_nocs, confidence, 50 | side='front', img_size=256, kernel_size=4): 51 | gt_img = render_points_confidence( 52 | gt_nocs, confidence) 53 | pred_img = render_points_confidence( 54 | pred_nocs, confidence) 55 | pair_img = np.concatenate([gt_img, pred_img], axis=1) 56 | return pair_img 57 | 58 | 59 | def render_wnf_pair(gt_wnf_img, pred_wnf_img, img_size=256): 60 | gt_img = render_wnf(gt_wnf_img, img_size=img_size) 61 | pred_img = render_wnf(pred_wnf_img, img_size=img_size) 62 | pair_img = np.concatenate([gt_img, pred_img], axis=1) 63 | return pair_img 64 | 65 | def render_wnf_points_pair(query_points, gt_wnf, pred_wnf, img_size=256): 66 | gt_img = render_wnf_points( 67 | query_points=query_points, wnf_values=gt_wnf, img_size=img_size) 68 | pred_img = render_wnf_points( 69 | query_points=query_points, wnf_values=pred_wnf, img_size=img_size) 70 | pair_img = np.concatenate([gt_img, pred_img], axis=1) 71 | return pair_img 72 | 73 | def get_vis_idxs(batch_idx, 74 | batch_size=None, this_batch_size=None, 75 | vis_per_items=1, max_vis_per_epoch=None): 76 | assert((batch_size is not None) or (this_batch_size is not None)) 77 | if this_batch_size is None: 78 | this_batch_size = batch_size 79 | if batch_size is None: 80 | batch_size = this_batch_size 81 | 82 | global_idxs = list() 83 | selected_idxs = list() 84 | vis_idxs = list() 85 | for i in range(this_batch_size): 86 | global_idx = batch_size * batch_idx + i 87 | global_idxs.append(global_idx) 88 | vis_idx = 
global_idx // vis_per_items 89 | vis_modulo = global_idx % vis_per_items 90 | if (vis_modulo == 0) and (vis_idx < max_vis_per_epoch): 91 | selected_idxs.append(i) 92 | vis_idxs.append(vis_idx) 93 | return global_idxs, selected_idxs, vis_idxs 94 | 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GarmentTracking 2 | 3 | This repository contains the source code for the paper [GarmentTracking: Category-Level Garment Pose Tracking](https://garment-tracking.robotflow.ai/). This paper has been accepted to CVPR 2023. 4 | 5 | ![network](assets/network.png) 6 | 7 | ## Datasets 8 | The dataset was collected with the [VR-Garment](https://github.com/xiaoxiaoxh/VR-Garment) recording system. 9 | 10 | Please download the [VR-Folding Dataset](https://huggingface.co/datasets/robotflow/vr-folding) from Hugging Face. All data are stored in [zarr](https://zarr.readthedocs.io/en/stable/) format. You can put the data under `%PROJECT_DIR/data` or any other location. 11 | 12 | ## Pre-trained Model 13 | 14 | Please download the [checkpoints](https://drive.google.com/file/d/1ATy_rWQ13P_AAaP8wvQd41uVvL6UOco2/view?usp=share_link) from Google Drive. 15 | 16 | ## Environment 17 | 18 | ### Requirements 19 | 20 | - Python >= 3.8 21 | - PyTorch >= 1.9.1 22 | - CUDA >= 11.1 23 | 24 | Please use the following commands to set up the environment (we highly recommend installing PyTorch with pip for compatibility). The 3D feature extractor used in our paper is based on [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). 25 | 26 | ``` 27 | conda create -n garment_tracking python=3.9 28 | conda activate garment_tracking 29 | ``` 30 | 31 | ```bash 32 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 33 | ``` 34 | ```bash 35 | conda install -y openblas-devel igl -c anaconda -c conda-forge 36 | ``` 37 | 38 | ```bash 39 | pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=${CONDA_PREFIX}/include" --install-option="--blas=openblas" 40 | ``` 41 | ```bash 42 | pip install torch-geometric torch-scatter torch_sparse torch_cluster torchmetrics==0.5.1 open3d pandas wandb pytorch-lightning==1.4.9 hydra-core scipy==1.7.0 scikit-image matplotlib zarr numcodecs tqdm dask numba 43 | ``` 44 | 45 | 46 | 47 | ## Training 48 | 49 | Here is an example of training (`Tshirt` category, `Folding` task): 50 | 51 | ```bash 52 | python train_tracking.py datamodule.zarr_path=data/vr_folding_dataset.zarr/Tshirt logger.offline=False logger.name=Tshirt-folding-tracking 53 | ``` 54 | 55 | Here `logger.offline=False` enables online syncing (e.g. losses, logs, visualization) for [wandb](https://wandb.ai). You can use offline syncing mode by setting `logger.offline=True`. You can set `datamodule.batch_size=8` if the GPU memory is not large enough. 56 | 57 | Each run will create a new working directory (e.g. `2022-11-03/12-33-00`) under `%PROJECT_DIR/outputs`, which contains all the checkpoints and logs.
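For example, a lower-memory variant of the command above could look like this (the `logger.name` label and batch size are only illustrative; adjust them to your own setup):

```bash
# illustrative: same training run with a smaller batch size and offline wandb logging
python train_tracking.py datamodule.zarr_path=data/vr_folding_dataset.zarr/Tshirt datamodule.batch_size=8 logger.offline=True logger.name=Tshirt-folding-tracking-bs8
```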
58 | 59 | ## Inference 60 | 61 | Here are some examples of inference (`Tshirt` category, `Folding` task): 62 | 63 | - First-frame initialization with GT: 64 | 65 | ```bash 66 | python predict_tracking_gt.py datamodule.zarr_path=data/vr_folding_dataset.zarr/Tshirt prediction.max_refine_mesh_step=0 main.checkpoint_path=outputs/2022-11-03/12-33-00/checkpoints/last.ckpt logger.name=Tshirt-folding-tracking_test-gt 67 | ``` 68 | 69 | - First-frame initialization with noise: 70 | 71 | ```bash 72 | python predict_tracking_noise.py datamodule.zarr_path=data/vr_folding_dataset.zarr/Tshirt prediction.max_refine_mesh_step=1 main.checkpoint_path=outputs/2022-11-03/12-33-00/checkpoints/last.ckpt logger.name=Tshirt-folding-tracking_test-noise 73 | ``` 74 | 75 | For the *Folding* task, we recommend `prediction.max_refine_mesh_step=1`; for the *Flattening* task, we recommend `prediction.max_refine_mesh_step=15`. 76 | 77 | ## Evaluation 78 | 79 | Here is an example of evaluation (`Tshirt` category, `Folding` task): 80 | 81 | ```bash 82 | python eval_tracking.py main.prediction_output_dir=outputs/2022-11-07/14-48-52 logger.name=Tshirt-folding-tracking-base10_test-gt 83 | ``` 84 | 85 | The evaluation also generates some visualization examples as logs in [wandb](https://wandb.ai). You can set `logger.offline=False` to enable automatic online syncing for [wandb](https://wandb.ai); by default the logger runs in offline mode, and you can manually sync the logs later. 86 | -------------------------------------------------------------------------------- /config/train_tracking_default.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | # must use the absolute path of the dataset to work with hydra 3 | zarr_path: ~/dev/garmentnets/data/garmentnets_dataset.zarr/Dress 4 | video_info_path: None 5 | metadata_cache_dir: ~/local/.cache/metadata_cache_dir 6 | use_file_attr_cache: True 7 | batch_size: 16 8 | num_workers: 4 9 | # sample size 10 | num_pc_sample: 8000 11 | num_pc_sample_final: 4000 12 | num_volume_sample: 0 13 | num_surface_sample: 6000 14 | num_surface_sample_init: 10000 15 | # mixed sampling config 16 | surface_sample_ratio: 0 17 | surface_sample_std: 0.05 18 | # surface sample gt-sim points noise 19 | surface_gt_normal_noise_ratio: 0 20 | surface_gt_normal_std: 0.01 21 | # feature config 22 | use_rgb: True 23 | use_nocs_as_feature: False 24 | # voxelization config 25 | voxel_size: 0.0025 26 | # data augmentation 27 | enable_augumentation: True 28 | random_rot_range: [-180,180] 29 | enable_zero_center: False 30 | num_views: 4 31 | pc_noise_std: 0 32 | use_pc_nocs_frame1_aug: True 33 | pc_nocs_global_scale_aug_range: [0.8, 1.2] 34 | pc_nocs_global_max_offset_aug: 0.1 35 | pc_nocs_gaussian_std: 0 36 | use_pc_nocs_ball_offset_aug: False 37 | pc_nocs_ball_query_radius_range: [0.0, 0.2] 38 | pc_nocs_ball_query_max_nn: 400 39 | use_mesh_nocs_aug: True 40 | mesh_nocs_global_scale_aug_range: [0.8, 1.2] 41 | mesh_nocs_global_max_offset_aug: 0 42 | # volume 43 | volume_size: 128 44 | # or nocs_signed_distance_field or nocs_occupancy_grid or sim_nocs_winding_number_field or nocs_distance_field 45 | volume_group: nocs_winding_number_field 46 | # use 0.05 47 | tsdf_clip_value: null 48 | volume_absolute_value: False 49 | # random seed 50 | static_epoch_seed: False 51 | is_val: True 52 | # datamodule config 53 | dataset_split: [8,1,1] 54 | split_seed: 0 55 | remove_invalid_interval_in_train: True 56 | # first-frame fitting config 57 | alpha: 1000.0 58 | finetune_offset: [0., -0.03, 0.]
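# the garment_tracking_model block below is passed as keyword arguments to GarmentTrackingPipeline in train_tracking.py (see networks/tracking_network.py)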
59 | 60 | garment_tracking_model: 61 | sparse_unet3d_encoder_params: 62 | in_channels: 3 63 | out_channels: 64 64 | conv1_kernel_size: 5 65 | normalize_feature: False 66 | predict_segm: False 67 | CHANNELS: [None, 64, 64, 128, 256] 68 | TR_CHANNELS: [None, 64, 64, 64, 128] 69 | transformer_params: 70 | input_channels: 3 71 | use_xyz: True 72 | input_size: 4000 73 | d_model: 64 74 | num_layers: 1 75 | key_feature_dim: 64 76 | with_pos_embed: True 77 | encoder_pos_embed_input_dim: 6 78 | decoder_pos_embed_input_dim: 3 79 | inverse_source_template: False 80 | fea_channels: [64, 128, 128] 81 | feat_slim_last_layer: True 82 | nocs_slim_last_layer: True 83 | nocs_bins: 64 84 | nocs_channels: [128, 128, 128, 192] 85 | nocs_refiner_params: 86 | detach_input_pc_feature: True 87 | detach_global_pc_feature: False 88 | detach_global_mesh_feature: True 89 | volume_agg_params: 90 | nn_channels: [134, 256, 128] 91 | batch_norm: True 92 | lower_corner: [0,0,0] 93 | upper_corner: [1,1,1] 94 | grid_shape: [32,32,32] 95 | reduce_method: max 96 | include_point_feature: True 97 | use_gt_nocs_for_train: True 98 | use_mlp_v2: True 99 | unet3d_params: 100 | in_channels: 128 101 | out_channels: 128 102 | f_maps: 32 103 | layer_order: gcr 104 | num_groups: 8 105 | num_levels: 4 106 | surface_decoder_params: 107 | nn_channels: [128,256,256,3] 108 | batch_norm: True 109 | use_mlp_v2: True 110 | warp_loss_weight: 10.0 111 | nocs_loss_weight: 1.0 112 | mesh_loss_weight: 10.0 113 | use_nocs_refiner: True 114 | learning_rate: 0.0001 115 | optimizer_type: Adam 116 | loss_type: l2 117 | vis_per_items: 200 118 | max_vis_per_epoch_train: 10 119 | max_vis_per_epoch_val: 10 120 | debug: False 121 | trainer: 122 | gpus: [0] 123 | logger: 124 | offline: True 125 | name: null 126 | tags: [] 127 | logger_watch: 128 | log: gradients 129 | log_freq: 100 130 | -------------------------------------------------------------------------------- /networks/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class TransNonlinear(nn.Module): 9 | def __init__(self, d_model, dim_feedforward, dropout=0.1): 10 | super().__init__() 11 | self.linear1 = nn.Linear(d_model, dim_feedforward) 12 | self.dropout = nn.Dropout(dropout) 13 | self.linear2 = nn.Linear(dim_feedforward, d_model) 14 | 15 | self.norm2 = nn.LayerNorm(d_model) 16 | self.dropout1 = nn.Dropout(dropout) 17 | self.dropout2 = nn.Dropout(dropout) 18 | self.activation = nn.ReLU() 19 | 20 | def forward(self, src): 21 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 22 | src = src + self.dropout2(src2) 23 | src = self.norm2(src) 24 | return src 25 | 26 | 27 | class MultiheadAttention(nn.Module): 28 | def __init__(self, feature_dim=512, n_head=8, key_feature_dim=64, 29 | extra_nonlinear=True): 30 | super(MultiheadAttention, self).__init__() 31 | self.Nh = n_head 32 | self.head = nn.ModuleList() 33 | self.extra_nonlinear = nn.ModuleList() 34 | for N in range(self.Nh): 35 | self.head.append(RelationUnit(feature_dim, key_feature_dim)) 36 | if extra_nonlinear: 37 | self.extra_nonlinear.append(TransNonlinear(feature_dim, key_feature_dim)) 38 | else: 39 | self.extra_nonlinear = None 40 | 41 | def forward(self, query=None, key=None, value=None, 42 | ): 43 | """ 44 | query : #pixel x batch x dim 45 | 46 | """ 47 | isFirst = True 48 | for N in range(self.Nh): 49 | if(isFirst): 50 | concat = 
self.head[N](query, key, value) 51 | if self.extra_nonlinear: 52 | concat = self.extra_nonlinear[N](concat) 53 | isFirst = False 54 | else: 55 | tmp = self.head[N](query, key, value) 56 | if self.extra_nonlinear: 57 | tmp = self.extra_nonlinear[N](tmp) 58 | concat = torch.cat((concat, tmp), -1) 59 | 60 | output = concat 61 | return output 62 | 63 | 64 | class RelationUnit(nn.Module): 65 | def __init__(self, feature_dim=512, key_feature_dim=64): 66 | super(RelationUnit, self).__init__() 67 | self.temp = 1 68 | self.WK = nn.Linear(feature_dim, key_feature_dim, bias=False) 69 | self.WQ = nn.Linear(feature_dim, key_feature_dim, bias=False) 70 | self.WV = nn.Linear(feature_dim, feature_dim, bias=False) 71 | self.after_norm = nn.BatchNorm1d(feature_dim) 72 | self.trans_conv = nn.Linear(feature_dim, feature_dim, bias=False) 73 | 74 | # Init weights 75 | for m in self.WK.modules(): 76 | m.weight.data.normal_(0, math.sqrt(2. / m.out_features)) 77 | if m.bias is not None: 78 | m.bias.data.zero_() 79 | 80 | for m in self.WQ.modules(): 81 | m.weight.data.normal_(0, math.sqrt(2. / m.out_features)) 82 | if m.bias is not None: 83 | m.bias.data.zero_() 84 | 85 | for m in self.WV.modules(): 86 | m.weight.data.normal_(0, math.sqrt(2. / m.out_features)) 87 | if m.bias is not None: 88 | m.bias.data.zero_() 89 | 90 | def forward(self, query=None, key=None, value=None, mask=None): 91 | w_k = self.WK(key) 92 | w_k = F.normalize(w_k, p=2, dim=-1) 93 | w_k = w_k.permute(1, 2, 0) # Batch, Dim, Len_1 94 | 95 | w_q = self.WQ(query) 96 | w_q = F.normalize(w_q, p=2, dim=-1) 97 | w_q = w_q.permute(1, 0, 2) # Batch, Len_2, Dim 98 | 99 | dot_prod = torch.bmm(w_q, w_k) # Batch, Len_2, Len_1 100 | if mask is not None: 101 | dot_prod = dot_prod.masked_fill(mask == 0, -1e9) 102 | affinity = F.softmax(dot_prod * self.temp, dim=-1) 103 | affinity = affinity / (1e-9 + affinity.sum(dim=1, keepdim=True)) 104 | 105 | w_v = self.WV(value) 106 | w_v = w_v.permute(1,0,2) # Batch, Len_1, Dim 107 | output = torch.bmm(affinity, w_v) # Batch, Len_2, Dim 108 | output = output.permute(1,0,2) 109 | 110 | output = self.trans_conv(query - output) 111 | 112 | return F.relu(output) 113 | 114 | -------------------------------------------------------------------------------- /components/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import pytorch_lightning as pl 3 | 4 | class PointBatchNorm1D(nn.BatchNorm1d): 5 | def __init__(self, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | def forward(self, x): 8 | return super().forward(x.view(-1, x.shape[-1])).view(x.shape) 9 | 10 | 11 | def MLP(channels, batch_norm=True, last_layer=False, drop_out=False): 12 | layers = list() 13 | for i in range(1, len(channels)): 14 | if i == len(channels) - 1 and last_layer: 15 | module_layers = [ 16 | nn.Linear(channels[i - 1], channels[i])] 17 | else: 18 | module_layers = [ 19 | nn.Linear(channels[i - 1], channels[i])] 20 | module_layers.append(nn.ReLU()) 21 | if batch_norm: 22 | module_layers.append( 23 | PointBatchNorm1D(channels[i])) 24 | if drop_out: 25 | module_layers.append(nn.Dropout(0.5)) 26 | module = nn.Sequential(*module_layers) 27 | layers.append(module) 28 | return nn.Sequential(*layers) 29 | 30 | 31 | class MLP_V2(pl.LightningModule): 32 | def __init__(self, channels, batch_norm=True, transpose_input=False): 33 | super(MLP_V2, self).__init__() 34 | layers = [] 35 | norm_type = 'batch' if batch_norm else None 36 | for i in range(1, len(channels)): 37 | if i == 
len(channels) - 1: 38 | layers.append(EquivariantLayer(channels[i - 1], channels[i], 39 | activation=None, normalization=None)) 40 | else: 41 | layers.append(EquivariantLayer(channels[i - 1], channels[i], normalization=norm_type)) 42 | self.layers = nn.ModuleList(layers) 43 | self.transpose_input = transpose_input 44 | 45 | def forward(self, x): 46 | expand_dim = False 47 | if self.transpose_input: 48 | if len(x.shape) == 2: 49 | x = x.unsqueeze(1) # (B, 1, C) 50 | expand_dim = True 51 | x = x.transpose(-1, -2) # (B, C, 1) or (B, C, M) 52 | for layer in self.layers: 53 | x = layer(x) 54 | if self.transpose_input: 55 | if expand_dim: 56 | x = x.squeeze(-1) # (B, C') 57 | else: 58 | x = x.transpose(-1, -2) # (B, M, C') 59 | return x 60 | 61 | 62 | class EquivariantLayer(nn.Module): 63 | def __init__(self, num_in_channels, num_out_channels, activation='relu', normalization=None, momentum=0.1, 64 | num_groups=16): 65 | super(EquivariantLayer, self).__init__() 66 | 67 | self.num_in_channels = num_in_channels 68 | self.num_out_channels = num_out_channels 69 | self.activation = activation 70 | self.normalization = normalization 71 | 72 | self.conv = nn.Conv1d(self.num_in_channels, self.num_out_channels, kernel_size=1, stride=1, padding=0) 73 | 74 | if 'batch' == self.normalization: 75 | self.norm = nn.BatchNorm1d(self.num_out_channels, momentum=momentum, affine=True) 76 | elif 'instance' == self.normalization: 77 | self.norm = nn.InstanceNorm1d(self.num_out_channels, momentum=momentum, affine=True) 78 | elif 'group' == self.normalization: 79 | self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=self.num_out_channels) 80 | 81 | if 'relu' == self.activation: 82 | self.act = nn.ReLU() 83 | elif 'leakyrelu' == self.activation: 84 | self.act = nn.LeakyReLU(0.01) 85 | 86 | self.weight_init() 87 | 88 | def weight_init(self): 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv1d): 91 | if self.activation == 'relu' or self.activation == 'leakyrelu': 92 | nn.init.kaiming_normal_(m.weight, nonlinearity=self.activation) 93 | else: 94 | m.weight.data.normal_(0, std=0.01) 95 | if m.bias is not None: 96 | m.bias.data.zero_() 97 | elif isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.InstanceNorm1d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | 101 | def forward(self, x): 102 | y = self.conv(x) 103 | 104 | if self.normalization == 'batch': 105 | y = self.norm(y) 106 | elif self.normalization is not None: 107 | y = self.norm(y) 108 | 109 | if self.activation is not None: 110 | y = self.act(y) 111 | 112 | return y -------------------------------------------------------------------------------- /common/rendering_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import jit 3 | from matplotlib.cm import get_cmap 4 | from skimage.transform import resize 5 | 6 | # helper functions 7 | # ================ 8 | @jit(nopython=True, nogil=True) 9 | def _render_points_idx(xy_idx, z, idx_img, min_img, kernel_size, kernel_offset): 10 | for i in range(len(z)): 11 | x, y = xy_idx[i] 12 | this_z = z[i] 13 | min_z = min_img[y, x] 14 | for dy in range(kernel_offset, kernel_offset + kernel_size): 15 | ny = min(max(y + dy, 0), idx_img.shape[0] - 1) 16 | for dx in range(kernel_offset, kernel_offset + kernel_size): 17 | nx = min(max(x + dx, 0), idx_img.shape[1] - 1) 18 | min_z = min_img[ny, nx] 19 | if this_z < min_z: 20 | min_img[ny, nx] = this_z 21 | idx_img[ny, nx] = i 22 | 23 | 24 | # low-level API 25 | # ============= 26 | 
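# render_points_idx rasterizes normalized points into a per-pixel point-index image (keeping the closest point for each pixel); color_idx_img then converts those indices into an RGBA image using per-point colors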
def render_points_idx(points, img_size=256, kernel_size=4): 27 | # assumes points are normized bewteen 0 and 1 28 | # assume s colros are rgb 0 to 1 29 | # images are in cv coordiante: (y, x) 30 | idx_dtype = np.uint32 31 | default_idx = np.iinfo(idx_dtype).max 32 | idx_img = np.full(shape=(img_size, img_size), 33 | fill_value=default_idx, dtype=idx_dtype) 34 | min_img = np.full(shape=(img_size, img_size), 35 | fill_value=float('inf'), dtype=points.dtype) 36 | xy_idx = np.clip( 37 | (points[:,:2] * (img_size-1)).astype(idx_dtype), 38 | 0, img_size-1) 39 | z = points[:, 2] 40 | kernel_offset = -(kernel_size // 2) 41 | _render_points_idx(xy_idx, z, idx_img, min_img, kernel_size, kernel_offset) 42 | return idx_img 43 | 44 | 45 | def color_idx_img(idx_img, colors, default_color=np.array([1,1,1])): 46 | h, w = idx_img.shape 47 | default_idx = np.iinfo(idx_img.dtype).max 48 | img_not_null = idx_img < default_idx 49 | idxs = idx_img[img_not_null] 50 | color_img = np.zeros((h, w, len(default_color)), dtype=np.float32) 51 | color_img[:, :] = default_color 52 | color_img[img_not_null] = colors[idxs] 53 | return color_img 54 | 55 | 56 | def get_extrinsic(side='front'): 57 | # world to camera 58 | if side == 'front': 59 | return np.array([ 60 | [1, 0, 0, 0], 61 | [0, 0,-1, 1], 62 | [0, 1, 0, 0], 63 | [0, 0, 0, 1] 64 | ]) 65 | elif side == 'top': 66 | return np.array([ 67 | [1, 0, 0, 0], 68 | [0,-1, 0, 1], 69 | [0, 0,-1, 1], 70 | [0, 0, 0, 1] 71 | ]) 72 | elif side == 'left': 73 | return np.array([ 74 | [0,-1, 0, 1], 75 | [0, 0,-1, 1], 76 | [1, 0, 0, 0], 77 | [0, 0, 0, 1] 78 | ]) 79 | else: 80 | assert(False) 81 | 82 | 83 | def to_camera(points, extrinsic): 84 | rx = extrinsic[:3, :3] 85 | tx = extrinsic[:3, 3] 86 | result = points @ rx.T + tx 87 | return result 88 | 89 | 90 | # high-level API 91 | # ============== 92 | def render_nocs( 93 | points, colors=None, 94 | side='front', img_size=256, kernel_size=4, 95 | default_color=np.array([1,1,1,0])): 96 | extrinsic = get_extrinsic(side) 97 | camera_points = to_camera(points, extrinsic) 98 | if colors is None: 99 | colors = np.concatenate( 100 | [points, np.ones((len(points), 1), dtype=points.dtype)], axis=1) 101 | 102 | idx_img = render_points_idx(camera_points, 103 | img_size=img_size, kernel_size=kernel_size) 104 | color_img = color_idx_img( 105 | idx_img, colors, 106 | default_color=default_color) 107 | return color_img 108 | 109 | 110 | def render_wnf(wnf_img, img_size=256, cmap='viridis', min_value=-0.5, max_value=1.5): 111 | cmap = get_cmap(cmap) 112 | value_img = (wnf_img - min_value) / (max_value - min_value) 113 | color_img = cmap(value_img) 114 | final_img = resize(color_img, (img_size, img_size), anti_aliasing=False) 115 | return final_img 116 | 117 | def get_wnf_cmap(cmap='viridis', min_value=-0.5, max_value=1.5): 118 | cmap = get_cmap(cmap) 119 | def cmap_func(x): 120 | values = (x - min_value) / (max_value - min_value) 121 | colors = cmap(values) 122 | return colors 123 | return cmap_func 124 | 125 | def render_wnf_points(query_points, wnf_values, slice_range=(0.5, 0.6), side='front', **kwargs): 126 | cmap = get_wnf_cmap() 127 | colors = cmap(wnf_values) 128 | assert side == 'front' 129 | dim_idx = 1 130 | is_selected = (slice_range[0] < query_points[...,dim_idx]) \ 131 | & (query_points[...,dim_idx] < slice_range[1]) 132 | 133 | color_img = render_nocs( 134 | points=query_points[is_selected], 135 | colors=colors[is_selected], side=side, **kwargs) 136 | return color_img 137 | 138 | def render_points_confidence(points, confidence, 
side='front', **kwargs): 139 | cmap = get_wnf_cmap(min_value=0.0, max_value=1.0) 140 | colors = cmap(confidence) 141 | color_img = render_nocs( 142 | points=points, 143 | colors=colors, side=side, **kwargs) 144 | return color_img 145 | 146 | -------------------------------------------------------------------------------- /networks/resunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | import MinkowskiEngine.MinkowskiFunctional as MEF 5 | 6 | from networks.residual_block import get_block 7 | from networks.common import get_norm 8 | from components.mlp import MLP 9 | 10 | 11 | class SparseResUNet2(ME.MinkowskiNetwork): 12 | # To use the model, must call initialize_coords before forward pass. 13 | # Once data is processed, call clear to reset the model before calling initialize_coords 14 | def __init__(self, 15 | in_channels=3, 16 | out_channels=32, 17 | bn_momentum=0.1, 18 | normalize_feature=True, 19 | conv1_kernel_size=5, 20 | D=3, 21 | NORM_TYPE='BN', 22 | BLOCK_NORM_TYPE='BN', 23 | CHANNELS=(None, 32, 64, 128, 256), 24 | TR_CHANNELS=(None, 64, 64, 64, 128), 25 | predict_segm=False, 26 | SEGM_CHANNELS=(32, 64, 64, 2) 27 | ): 28 | ME.MinkowskiNetwork.__init__(self, D) 29 | self.normalize_feature = normalize_feature 30 | self.conv1 = ME.MinkowskiConvolution( 31 | in_channels=in_channels, 32 | out_channels=CHANNELS[1], 33 | kernel_size=conv1_kernel_size, 34 | stride=1, 35 | dilation=1, 36 | bias=False, 37 | dimension=D) 38 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 39 | 40 | self.block1 = get_block( 41 | BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D) 42 | 43 | self.conv2 = ME.MinkowskiConvolution( 44 | in_channels=CHANNELS[1], 45 | out_channels=CHANNELS[2], 46 | kernel_size=3, 47 | stride=2, 48 | dilation=1, 49 | bias=False, 50 | dimension=D) 51 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 52 | 53 | self.block2 = get_block( 54 | BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D) 55 | 56 | self.conv3 = ME.MinkowskiConvolution( 57 | in_channels=CHANNELS[2], 58 | out_channels=CHANNELS[3], 59 | kernel_size=3, 60 | stride=2, 61 | dilation=1, 62 | bias=False, 63 | dimension=D) 64 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 65 | 66 | self.block3 = get_block( 67 | BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D) 68 | 69 | self.conv4 = ME.MinkowskiConvolution( 70 | in_channels=CHANNELS[3], 71 | out_channels=CHANNELS[4], 72 | kernel_size=3, 73 | stride=2, 74 | dilation=1, 75 | bias=False, 76 | dimension=D) 77 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 78 | 79 | self.block4 = get_block( 80 | BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D) 81 | 82 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 83 | in_channels=CHANNELS[4], 84 | out_channels=TR_CHANNELS[4], 85 | kernel_size=3, 86 | stride=2, 87 | dilation=1, 88 | bias=False, 89 | dimension=D) 90 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 91 | 92 | self.block4_tr = get_block( 93 | BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 94 | 95 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 96 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 97 | out_channels=TR_CHANNELS[3], 98 | kernel_size=3, 99 | stride=2, 100 | dilation=1, 101 | 
bias=False, 102 | dimension=D) 103 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 104 | 105 | self.block3_tr = get_block( 106 | BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 107 | 108 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 109 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 110 | out_channels=TR_CHANNELS[2], 111 | kernel_size=3, 112 | stride=2, 113 | dilation=1, 114 | bias=False, 115 | dimension=D) 116 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 117 | 118 | self.block2_tr = get_block( 119 | BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 120 | 121 | self.conv1_tr = ME.MinkowskiConvolution( 122 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 123 | out_channels=TR_CHANNELS[1], 124 | kernel_size=1, 125 | stride=1, 126 | dilation=1, 127 | bias=False, 128 | dimension=D) 129 | 130 | # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 131 | 132 | self.final = ME.MinkowskiConvolution( 133 | in_channels=TR_CHANNELS[1], 134 | out_channels=out_channels, 135 | kernel_size=1, 136 | stride=1, 137 | dilation=1, 138 | bias=True, 139 | dimension=D) 140 | 141 | self.predict_segm = predict_segm 142 | if predict_segm: 143 | self.segm_mlp = MLP(SEGM_CHANNELS, batch_norm=True, last_layer=True) 144 | self.weight_init(self.segm_mlp) 145 | 146 | def weight_init(self, module): 147 | for m in module.modules(): 148 | if isinstance(m, nn.Linear): 149 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 150 | if m.bias is not None: 151 | m.bias.data.zero_() 152 | 153 | def forward(self, x): 154 | out_s1 = self.conv1(x) 155 | out_s1 = self.norm1(out_s1) 156 | out_s1 = self.block1(out_s1) 157 | out = MEF.relu(out_s1) 158 | 159 | out_s2 = self.conv2(out) 160 | out_s2 = self.norm2(out_s2) 161 | out_s2 = self.block2(out_s2) 162 | out = MEF.relu(out_s2) 163 | 164 | out_s4 = self.conv3(out) 165 | out_s4 = self.norm3(out_s4) 166 | out_s4 = self.block3(out_s4) 167 | out = MEF.relu(out_s4) 168 | 169 | out_s8 = self.conv4(out) 170 | out_s8 = self.norm4(out_s8) 171 | out_s8 = self.block4(out_s8) 172 | out = MEF.relu(out_s8) 173 | 174 | out = self.conv4_tr(out) 175 | out = self.norm4_tr(out) 176 | out = self.block4_tr(out) 177 | out_s4_tr = MEF.relu(out) 178 | 179 | out = ME.cat(out_s4_tr, out_s4) 180 | 181 | out = self.conv3_tr(out) 182 | out = self.norm3_tr(out) 183 | out = self.block3_tr(out) 184 | out_s2_tr = MEF.relu(out) 185 | 186 | out = ME.cat(out_s2_tr, out_s2) 187 | 188 | out = self.conv2_tr(out) 189 | out = self.norm2_tr(out) 190 | out = self.block2_tr(out) 191 | out_s1_tr = MEF.relu(out) 192 | 193 | out = ME.cat(out_s1_tr, out_s1) 194 | out = self.conv1_tr(out) 195 | out = MEF.relu(out) 196 | out = self.final(out) 197 | 198 | if self.normalize_feature: 199 | features = ME.SparseTensor( 200 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 201 | coordinate_map_key=out.coordinate_map_key, 202 | coordinate_manager=out.coordinate_manager) 203 | else: 204 | features = out 205 | 206 | if self.predict_segm: 207 | dense_features = features.F 208 | segm_score = self.segm_mlp(dense_features) 209 | return features, segm_score 210 | else: 211 | return features, None 212 | -------------------------------------------------------------------------------- /common/geometry_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | import igl 5 
| import open3d as o3d 6 | import copy 7 | import torch 8 | 9 | 10 | def build_line(length=1.0, num_nodes=50): 11 | verts = np.zeros((num_nodes, 3), dtype=np.float32) 12 | verts[:, 0] = np.linspace(0, length, num_nodes) 13 | edges = np.empty((num_nodes - 1, 2), dtype=np.uint32) 14 | edges[:, 0] = range(0, num_nodes - 1) 15 | edges[:, 1] = range(1, num_nodes) 16 | return verts, edges 17 | 18 | 19 | def build_rectangle(width=0.45, height=0.32, width_num_node=23, height_num_node=17): 20 | """ 21 | Row major, row corresponds to width 22 | """ 23 | #width_num_node = int(np.round(width / grid_size)) + 1 24 | #height_num_node = int(np.round(height / grid_size)) + 1 25 | 26 | print("Creating a rectangular grid with the following parameters:") 27 | print("Width:", width) 28 | print("Height:", height) 29 | print("W nodes::", width_num_node) 30 | print("H nodes:", height_num_node) 31 | 32 | def xy_to_index(x_idx, y_idx): 33 | # Assumes the following layout in imagespace - 0 is to the top left of the image 34 | # 35 | # 0 cloth_x_size+0 ... cloth_y_size*cloth_x_size - cloth_x_size + 0 36 | # 1 cloth_x_size+1 ... cloth_y_size*cloth_x_size - cloth_x_size + 1 37 | # 2 cloth_x_size+2 ... cloth_y_size*cloth_x_size - cloth_x_size + 2 38 | # ... 39 | # cloth_x_size-1 cloth_x_size*2-1 ... cloth_y_size*cloth_x_size - 1 40 | # return x_idx * width_num_node + y_idx 41 | return y_idx * height_num_node + x_idx 42 | 43 | verts = np.zeros((width_num_node * height_num_node, 3), dtype=np.float32) 44 | uv = np.zeros((width_num_node * height_num_node, 2), dtype=np.float32) 45 | edges_temp = [] 46 | faces_temp = [] 47 | for x in range(height_num_node): 48 | for y in range(width_num_node): 49 | curr_idx = xy_to_index(x, y) 50 | verts[curr_idx, 0] = x * height / (height_num_node - 1) 51 | verts[curr_idx, 1] = y * width / (width_num_node - 1) 52 | uv[curr_idx, 0] = x / (height_num_node - 1) 53 | uv[curr_idx, 1] = y / (width_num_node - 1) 54 | 55 | if x + 1 < height_num_node: 56 | edges_temp.append([curr_idx, xy_to_index(x + 1, y)]) 57 | if y + 1 < width_num_node: 58 | edges_temp.append([curr_idx, xy_to_index(x, y + 1)]) 59 | if x + 1 < height_num_node and y + 1 < width_num_node: 60 | faces_temp.append([curr_idx, xy_to_index(x + 1, y), xy_to_index(x + 1, y + 1), xy_to_index(x, y + 1)]) 61 | 62 | edges = np.array(edges_temp, dtype=np.uint32) 63 | faces = np.array(faces_temp, dtype=np.uint32) 64 | return verts, edges, faces, uv 65 | 66 | def faces_to_edges(faces): 67 | edges_set = set() 68 | for face in faces: 69 | for i in range(1, len(face)): 70 | edge_pair = (face[i-1], face[i]) 71 | edge_pair = tuple(sorted(edge_pair)) 72 | edges_set.add(edge_pair) 73 | edges = np.array(list(edges_set), dtype=np.int) 74 | return edges 75 | 76 | class AABBNormalizer: 77 | def __init__(self, aabb): 78 | center = np.mean(aabb, axis=0) 79 | edge_lengths = aabb[1] - aabb[0] 80 | scale = 1 / np.max(edge_lengths) 81 | result_center = np.ones((3,), dtype=aabb.dtype) / 2 82 | 83 | self.center = center 84 | self.scale = scale 85 | self.result_center = result_center 86 | 87 | def __call__(self, data): 88 | center = self.center 89 | scale = self.scale 90 | result_center = self.result_center 91 | 92 | result = (data - center) * scale + result_center 93 | return result 94 | 95 | def inverse(self, data): 96 | center = self.center 97 | scale = self.scale 98 | result_center = self.result_center 99 | 100 | result = (data - result_center) / scale + center 101 | return result 102 | 103 | class AABBGripNormalizer: 104 | """ 105 | Assumes that the origin is 
gripping point. 106 | Only translate the aabb in z direction and scale to fit. 107 | """ 108 | def __init__(self, aabb, padding=0.05): 109 | nocs_radius = 0.5 - padding 110 | radius = np.max(np.abs(aabb), axis=0)[:2] 111 | radius_scale = np.min(nocs_radius / radius) 112 | nocs_z = nocs_radius * 2 113 | z_length = aabb[1,2] - aabb[0,2] 114 | z_scale = nocs_z / z_length 115 | scale = min(radius_scale, z_scale) 116 | 117 | z_max = aabb[1,2] * scale 118 | offset = np.array([0.5, 0.5, 1-padding-z_max], dtype=aabb.dtype) 119 | self.scale = scale 120 | self.offset = offset 121 | 122 | def __call__(self, data): 123 | scale = self.scale 124 | offset = self.offset 125 | result = (data * scale) + offset 126 | return result 127 | 128 | def inverse(self, data): 129 | scale = self.scale 130 | offset = self.offset 131 | result = (data - offset) / scale 132 | return result 133 | 134 | 135 | def get_aabb(coords): 136 | """ 137 | Axis Aligned Bounding Box 138 | Input: 139 | coords: (N, C) array 140 | Output: 141 | aabb: (2, C) array 142 | """ 143 | min_coords = np.min(coords, axis=0) 144 | max_coords = np.max(coords, axis=0) 145 | aabb = np.stack([min_coords, max_coords]) 146 | return aabb 147 | 148 | 149 | def buffer_aabb(aabb, buffer): 150 | result_aabb = aabb.copy() 151 | result_aabb[0] -= buffer 152 | result_aabb[1] += buffer 153 | return result_aabb 154 | 155 | 156 | def quads2tris(quads): 157 | assert(isinstance(quads, np.ndarray)) 158 | assert(len(quads.shape) == 2) 159 | assert(quads.shape[1] == 4) 160 | 161 | # allocate new array 162 | tris = np.zeros((quads.shape[0] * 2, 3), dtype=quads.dtype) 163 | tris[0::2] = quads[:, [0,1,2]] 164 | tris[1::2] = quads[:, [0,2,3]] 165 | return tris 166 | 167 | 168 | def barycentric_interpolation(query_coords: np.array, verts: np.array, faces: np.array) -> np.array: 169 | """ 170 | Input: 171 | query_coords: np.array[M, 3] float barycentric coorindates 172 | verts: np.array[N, 3] float vertecies 173 | faces: np.array[M, 3] int face index into verts, 1:1 coorespondace to query_coords 174 | 175 | Output 176 | result: np.array[M, 3] float interpolated points 177 | """ 178 | assert(len(verts.shape) == 2) 179 | result = np.zeros((len(query_coords), verts.shape[1]), dtype=verts.dtype) 180 | for c in range(verts.shape[1]): 181 | for i in range(query_coords.shape[1]): 182 | result[:, c] += \ 183 | query_coords[:, i] * verts[:,c][faces[:,i]] 184 | return result 185 | 186 | 187 | def mesh_sample_barycentric( 188 | verts: np.ndarray, faces: np.ndarray, 189 | num_samples: int, seed: Optional[int] = None, 190 | face_areas: np.ndarray = None) -> Tuple[np.ndarray, np.ndarray]: 191 | """ 192 | Uniformly sample points (as their barycentric coordinate) on suface 193 | 194 | Input: 195 | verts: np.array[N, 3] float mesh vertecies 196 | faces: np.array[M, 3] int mesh face index into verts 197 | num_sampels: int 198 | seed: int random seed 199 | face_areas: np.array[M, 3] per-face areas 200 | 201 | Output: 202 | barycentric_all: np.array[num_samples, 3] float sampled barycentric coordinates 203 | selected_face_idx: np.array[num_samples,3] int sampled faces, 1:1 coorespondance to barycentric_all 204 | """ 205 | # generate face area 206 | if face_areas is None: 207 | face_areas = igl.doublearea(verts, faces) 208 | face_areas = face_areas / np.sum(face_areas) 209 | assert(len(face_areas) == len(faces)) 210 | 211 | rs = np.random.RandomState(seed=seed) 212 | # select faces 213 | selected_face_idx = rs.choice( 214 | len(faces), size=num_samples, 215 | replace=True, 
p=face_areas).astype(faces.dtype) 216 | 217 | # generate random barycentric coordinate 218 | barycentric_uv = rs.uniform(0, 1, size=(num_samples, 2)) 219 | not_triangle = (np.sum(barycentric_uv, axis=1) >= 1) 220 | barycentric_uv[not_triangle] = 1 - barycentric_uv[not_triangle] 221 | 222 | barycentric_all = np.zeros((num_samples, 3), dtype=barycentric_uv.dtype) 223 | barycentric_all[:, :2] = barycentric_uv 224 | barycentric_all[:, 2] = 1 - np.sum(barycentric_uv, axis=1) 225 | 226 | return barycentric_all, selected_face_idx 227 | 228 | 229 | def get_matching_index_numpy(source_pts, target_pts, is_train=False): 230 | source_pts_torch = torch.from_numpy(source_pts).to('cpu' if is_train else 'cuda:0') 231 | target_pts_torch = torch.from_numpy(target_pts).to('cpu' if is_train else 'cuda:0') 232 | num_source_pts = source_pts.shape[0] # N1 233 | num_target_pts = target_pts.shape[0] # N2 234 | source_pts_torch_expand = source_pts_torch.unsqueeze(1).expand(-1, num_target_pts, -1) # (N1, N2, 3) 235 | tartet_pts_torch_expand = target_pts_torch.unsqueeze(0).expand(num_source_pts, -1, -1) # (N1, N2, 3) 236 | dist_matrix = torch.sum((source_pts_torch_expand - tartet_pts_torch_expand)**2, dim=-1) # (N1, N2) 237 | matching_inds = np.stack([np.arange(num_source_pts), dist_matrix.min(dim=1)[1].cpu().numpy()], axis=1) 238 | return matching_inds -------------------------------------------------------------------------------- /components/gridding.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torch_scatter 6 | 7 | 8 | def batch_to_volume(batch, volume_size, reduce='mean'): 9 | batch_idx = batch.batch 10 | points = batch.pos 11 | features = batch.x 12 | # debug 13 | # features = torch.arange(-12000, len(points)-12000, device=points.device).repeat(3, 1).T.type(torch.float32) 14 | # points = features / volume_size 15 | 16 | batch_size = batch.num_graphs 17 | feature_size = features.shape[1] 18 | # compute volume index 19 | volume_idx_flat = None 20 | with torch.no_grad(): 21 | grid_coord_f = points * volume_size 22 | grid_coord_i = torch.clamp(grid_coord_f.type(torch.int64), 0, volume_size-1) 23 | volume_idx_flat = \ 24 | batch_idx * (volume_size ** 3) \ 25 | + grid_coord_i[:,0] * (volume_size ** 2) \ 26 | + grid_coord_i[:,1] * volume_size \ 27 | + grid_coord_i[:,2] 28 | 29 | # scatter and aggregate to volume 30 | features = features.type(torch.float32) 31 | volume_feature_flat = torch_scatter.scatter( 32 | src=features.T, index=volume_idx_flat, 33 | dim=-1, dim_size=batch_size*volume_size**3, 34 | reduce=reduce) 35 | 36 | # reshape to volume 37 | volume_feature = volume_feature_flat.reshape( 38 | (feature_size, batch_size, 39 | volume_size, volume_size, volume_size)).permute((1,0,2,3,4)) 40 | 41 | return volume_feature 42 | 43 | 44 | def nocs_grid_sample(feature_volume: torch.Tensor, query_points: torch.Tensor, 45 | mode:str='bilinear', padding_mode:str='border', 46 | align_corners:bool=True) -> torch.Tensor: 47 | """ 48 | feature_volume: (N,C,D,H,W) or (N,D,H,W) or (D,H,W) 49 | query_points: (N,M,3) or (M,3) 50 | return: (N,M,C) or (M,C) 51 | """ 52 | # 1. 
processs query_points 53 | # normalize query points to (-1, 1), which is 54 | # requried by grid_sample 55 | query_points_normalized = 2.0 * query_points - 1.0 56 | 57 | query_points_shape = None 58 | if len(query_points.shape) == 2: 59 | shape = tuple(query_points.shape) 60 | query_points_shape = (1, shape[0], 1, 1, shape[1]) 61 | elif len(query_points.shape) == 3: 62 | shape = tuple(query_points.shape) 63 | query_points_shape = (shape[0], shape[1], 1, 1, shape[2]) 64 | else: 65 | raise RuntimeError("Invalid query_points shape {}".format( 66 | str(query_points.shape))) 67 | query_points_reshaped = query_points_normalized.view(*query_points_shape) 68 | # sample_features uses zyx convension in coordinate, not xyz 69 | query_points_reshaped = query_points_reshaped.flip(-1) 70 | 71 | 72 | # 2. process feature_volume 73 | feature_volume_shape = None 74 | if len(feature_volume.shape) == 5: 75 | feature_volume_shape = tuple(feature_volume.shape) 76 | elif len(feature_volume.shape) == 4: 77 | shape = tuple(feature_volume.shape) 78 | feature_volume_shape = (shape[0], 1, shape[1], shape[2], shape[3]) 79 | elif len(feature_volume.shape) == 3: 80 | feature_volume_shape = (1,1) + feature_volume.shape 81 | feature_volume_reshaped = feature_volume.view(*feature_volume_shape) 82 | 83 | # 3. sample 84 | # shape (N,C,M,1,1) 85 | sampled_features = F.grid_sample( 86 | input=feature_volume_reshaped, grid=query_points_reshaped, 87 | mode=mode, padding_mode=padding_mode, align_corners=align_corners) 88 | 89 | # 4. reshape output 90 | out_features_shape = None 91 | if len(query_points.shape) == 2: 92 | out_features_shape = (query_points.shape[0], sampled_features.shape[1]) 93 | elif len(query_points.shape) == 3: 94 | out_features_shape = query_points.shape[:2] + (sampled_features.shape[1],) 95 | out_features = sampled_features.permute(0,2,1,3,4).view(*out_features_shape) 96 | 97 | return out_features 98 | 99 | 100 | class VirtualGrid: 101 | def __init__(self, 102 | lower_corner=(0,0,0), 103 | upper_corner=(1,1,1), 104 | grid_shape=(32, 32, 32), 105 | batch_size=8, 106 | device=torch.device('cpu'), 107 | int_dtype=torch.int64, 108 | float_dtype=torch.float32, 109 | ): 110 | self.lower_corner = tuple(lower_corner) 111 | self.upper_corner = tuple(upper_corner) 112 | self.grid_shape = tuple(grid_shape) 113 | self.batch_size = int(batch_size) 114 | self.device = device 115 | self.int_dtype = int_dtype 116 | self.float_dtype = float_dtype 117 | 118 | @property 119 | def num_grids(self): 120 | grid_shape = self.grid_shape 121 | batch_size = self.batch_size 122 | return int(np.prod((batch_size,) + grid_shape)) 123 | 124 | def get_grid_idxs(self, include_batch=True): 125 | batch_size = self.batch_size 126 | grid_shape = self.grid_shape 127 | device = self.device 128 | int_dtype = self.int_dtype 129 | dims = grid_shape 130 | if include_batch: 131 | dims = (batch_size,) + grid_shape 132 | axis_coords = [torch.arange(0, x, device=device, dtype=int_dtype) 133 | for x in dims] 134 | coords_per_axis = torch.meshgrid(*axis_coords) 135 | grid_idxs = torch.stack(coords_per_axis, axis=-1) 136 | return grid_idxs 137 | 138 | def get_grid_points(self, include_batch=True): 139 | lower_corner = self.lower_corner 140 | upper_corner = self.upper_corner 141 | grid_shape = self.grid_shape 142 | float_dtype = self.float_dtype 143 | device = self.device 144 | grid_idxs = self.get_grid_idxs(include_batch=include_batch) 145 | 146 | lc = torch.tensor(lower_corner, dtype=float_dtype, device=device) 147 | uc = torch.tensor(upper_corner, 
dtype=float_dtype, device=device) 148 | idx_scale = torch.tensor(grid_shape, 149 | dtype=float_dtype, device=device) - 1 150 | scales = (uc - lc) / idx_scale 151 | offsets = -lc 152 | 153 | grid_idxs_no_batch = grid_idxs 154 | if include_batch: 155 | grid_idxs_no_batch = grid_idxs[:,:,:,:,1:] 156 | grid_idxs_f = grid_idxs_no_batch.to(float_dtype) 157 | grid_points = grid_idxs_f * scales + offsets 158 | return grid_points 159 | 160 | def get_points_grid_idxs(self, points, batch_idx=None): 161 | lower_corner = self.lower_corner 162 | upper_corner = self.upper_corner 163 | grid_shape = self.grid_shape 164 | int_dtype = self.int_dtype 165 | float_dtype = self.float_dtype 166 | device = self.device 167 | lc = torch.tensor(lower_corner, dtype=float_dtype, device=device) 168 | uc = torch.tensor(upper_corner, dtype=float_dtype, device=device) 169 | idx_scale = torch.tensor(grid_shape, 170 | dtype=float_dtype, device=device) - 1 171 | offsets = -lc 172 | scales = idx_scale / (uc - lc) 173 | points_idxs_f = (points + offsets) * scales 174 | points_idxs_i = points_idxs_f.to(dtype=int_dtype) 175 | points_idxs = torch.empty_like(points_idxs_i) 176 | for i in range(3): 177 | points_idxs[...,i] = torch.clamp( 178 | points_idxs_i[...,i], min=0, max=grid_shape[i]-1) 179 | final_points_idxs = points_idxs 180 | if batch_idx is not None: 181 | final_points_idxs = torch.cat( 182 | [batch_idx.view(*points.shape[:-1], 1).to( 183 | dtype=points_idxs.dtype), points_idxs], 184 | axis=-1) 185 | return final_points_idxs 186 | 187 | def flatten_idxs(self, idxs, keepdim=False): 188 | grid_shape = self.grid_shape 189 | batch_size = self.batch_size 190 | 191 | coord_size = idxs.shape[-1] 192 | target_shape = None 193 | if coord_size == 4: 194 | # with batch 195 | target_shape = (batch_size,) + grid_shape 196 | elif coord_size == 3: 197 | # without batch 198 | target_shape = grid_shape 199 | else: 200 | raise RuntimeError("Invalid shape {}".format(str(idxs.shape))) 201 | target_stride = tuple(np.cumprod(np.array(target_shape)[::-1])[::-1])[1:] + (1,) 202 | flat_idxs = (idxs * torch.tensor(target_stride, 203 | dtype=idxs.dtype, device=idxs.device)).sum( 204 | axis=-1, keepdim=keepdim, dtype=idxs.dtype) 205 | return flat_idxs 206 | 207 | def unflatten_idxs(self, flat_idxs, include_batch=True): 208 | grid_shape = self.grid_shape 209 | batch_size = self.batch_size 210 | target_shape = grid_shape 211 | if include_batch: 212 | target_shape = (batch_size,) + grid_shape 213 | target_stride = tuple(np.cumprod(np.array(target_shape)[::-1])[::-1])[1:] + (1,) 214 | 215 | source_shape = tuple(flat_idxs.shape) 216 | if source_shape[-1] == 1: 217 | source_shape = source_shape[:-1] 218 | flat_idxs = flat_idxs[...,0] 219 | source_shape += (4,) if include_batch else (3,) 220 | 221 | idxs = torch.empty(size=source_shape, 222 | dtype=flat_idxs.dtype, device=flat_idxs.device) 223 | mod = flat_idxs 224 | for i in range(source_shape[-1]): 225 | idxs[...,i] = mod / target_stride[i] 226 | mod = mod % target_stride[i] 227 | return idxs 228 | 229 | def idxs_to_points(self, idxs): 230 | lower_corner = self.lower_corner 231 | upper_corner = self.upper_corner 232 | grid_shape = self.grid_shape 233 | float_dtype = self.float_dtype 234 | int_dtype = idxs.dtype 235 | device = idxs.device 236 | 237 | source_shape = idxs.shape 238 | point_idxs = None 239 | if source_shape[-1] == 4: 240 | # has batch idx 241 | point_idxs = idxs[...,1:] 242 | elif source_shape[-1] == 3: 243 | point_idxs = idxs 244 | else: 245 | raise RuntimeError("Invalid shape 
{}".format(tuple(source_shape))) 246 | 247 | lc = torch.tensor(lower_corner, dtype=float_dtype, device=device) 248 | uc = torch.tensor(upper_corner, dtype=float_dtype, device=device) 249 | idx_scale = torch.tensor(grid_shape, 250 | dtype=float_dtype, device=device) - 1 251 | offsets = lc 252 | scales = (uc - lc) / idx_scale 253 | 254 | idxs_points = point_idxs * scales + offsets 255 | return idxs_points 256 | 257 | 258 | def ceil_div(a, b): 259 | return -(-a // b) 260 | 261 | class ArraySlicer: 262 | def __init__(self, shape: tuple, chunks: tuple): 263 | assert(len(chunks) <= len(shape)) 264 | relevent_shape = shape[:len(chunks)] 265 | chunk_size = tuple(ceil_div(*x) \ 266 | for x in zip(relevent_shape, chunks)) 267 | 268 | self.relevent_shape = relevent_shape 269 | self.chunks = chunks 270 | self.chunk_size = chunk_size 271 | 272 | def __len__(self): 273 | chunk_size = self.chunk_size 274 | return int(np.prod(chunk_size)) 275 | 276 | def __getitem__(self, idx): 277 | relevent_shape = self.relevent_shape 278 | chunks = self.chunks 279 | chunk_size = self.chunk_size 280 | chunk_stride = np.cumprod((chunk_size[1:] + (1,))[::-1])[::-1] 281 | chunk_idx = list() 282 | mod = idx 283 | for x in chunk_stride: 284 | chunk_idx.append(mod // x) 285 | mod = mod % x 286 | 287 | slices = list() 288 | for i in range(len(chunk_idx)): 289 | start = chunks[i] * chunk_idx[i] 290 | end = min(relevent_shape[i], 291 | chunks[i] * (chunk_idx[i] + 1)) 292 | slices.append(slice(start, end)) 293 | return slices 294 | 295 | def __iter__(self): 296 | for i in range(len(self)): 297 | yield self[i] 298 | -------------------------------------------------------------------------------- /networks/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional, List 4 | from torch import nn, Tensor 5 | from components.gridding import VirtualGrid 6 | from components.mlp import MLP 7 | from networks.multihead_attention import MultiheadAttention 8 | 9 | 10 | class PositionEmbeddingLearned(nn.Module): 11 | """ 12 | Absolute pos embedding, learned. 
13 | """ 14 | 15 | def __init__(self, input_channel=3, num_pos_feats=256): 16 | super().__init__() 17 | self.position_embedding_head = nn.Sequential( 18 | nn.Conv1d(input_channel, num_pos_feats, kernel_size=1), 19 | nn.BatchNorm1d(num_pos_feats), 20 | nn.ReLU(inplace=True), 21 | nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1)) 22 | 23 | def forward(self, xyz): 24 | # xyz : BxNx3 25 | xyz = xyz.transpose(1, 2).contiguous() 26 | # Bx3xN 27 | position_embedding = self.position_embedding_head(xyz) 28 | return position_embedding 29 | 30 | 31 | class TransformerSiamese(nn.Module): 32 | def __init__(self, 33 | input_channels=3, 34 | use_xyz=True, 35 | input_size=4000, 36 | d_model=64, 37 | num_layers=1, 38 | key_feature_dim=128, 39 | with_pos_embed=True, 40 | encoder_pos_embed_input_dim=3, 41 | decoder_pos_embed_input_dim=3, 42 | fea_channels=(128, 256, 256), 43 | nocs_bins=64, 44 | nocs_channels=(256, 256, 3), 45 | feat_slim_last_layer=True, 46 | nocs_slim_last_layer=True, 47 | inverse_source_template=False, 48 | ): 49 | super(TransformerSiamese, self).__init__() 50 | self.input_channels = input_channels 51 | self.use_xyz = use_xyz 52 | self.input_size = input_size 53 | self.d_model = d_model 54 | self.num_layers = num_layers 55 | self.nocs_bins = nocs_bins 56 | self.encoder_pos_embed_input_dim = encoder_pos_embed_input_dim 57 | self.decoder_pos_embed_input_dim = decoder_pos_embed_input_dim 58 | assert encoder_pos_embed_input_dim in (3, 6) 59 | self.with_pos_embed = with_pos_embed 60 | self.inverse_source_template = inverse_source_template 61 | 62 | multihead_attn = MultiheadAttention( 63 | feature_dim=d_model, n_head=1, key_feature_dim=key_feature_dim) 64 | 65 | if self.with_pos_embed: 66 | encoder_pos_embed = PositionEmbeddingLearned(encoder_pos_embed_input_dim, d_model) 67 | decoder_pos_embed = PositionEmbeddingLearned(decoder_pos_embed_input_dim, d_model) 68 | else: 69 | encoder_pos_embed = None 70 | decoder_pos_embed = None 71 | 72 | self.fea_layer = MLP(fea_channels, batch_norm=True, last_layer=feat_slim_last_layer) 73 | 74 | output_dim = 3 75 | if nocs_bins is not None: 76 | output_dim = nocs_bins * 3 77 | assert nocs_channels[-1] == output_dim 78 | self.nocs_layer = MLP(nocs_channels, batch_norm=True, last_layer=nocs_slim_last_layer) 79 | 80 | self.encoder = TransformerEncoder( 81 | multihead_attn=multihead_attn, FFN=None, 82 | d_model=d_model, num_encoder_layers=num_layers, 83 | self_posembed=encoder_pos_embed) 84 | self.decoder = TransformerDecoder( 85 | multihead_attn=multihead_attn, FFN=None, 86 | d_model=d_model, num_decoder_layers=num_layers, 87 | key_feature_dim=key_feature_dim, 88 | self_posembed=decoder_pos_embed) 89 | 90 | def transform_fuse(self, search_feature, search_coord, 91 | template_feature, template_coord): 92 | """Use transformer to fuse feature. 
93 | 94 | template_feature : (B, C, N) 95 | template_coord : (B, N, 3) or (B, N, 6) 96 | """ 97 | # BxCxN -> NxBxC 98 | search_feature = search_feature.permute(2, 0, 1) 99 | template_feature = template_feature.permute(2, 0, 1) 100 | 101 | ## encoder 102 | encoded_memory = self.encoder(template_feature, 103 | query_pos=template_coord if self.with_pos_embed else None) 104 | 105 | encoded_feat = self.decoder(search_feature, 106 | memory=encoded_memory, 107 | query_pos=search_coord) # NxBxC 108 | 109 | # NxBxC -> BxNxC 110 | encoded_feat = encoded_feat.permute(1, 0, 2) 111 | encoded_feat = self.fea_layer(encoded_feat) # BxNxC 112 | 113 | return encoded_feat 114 | 115 | def forward(self, template_feature, template_coord, 116 | search_feature, search_coord): 117 | """ 118 | template_feature: (B*N, C) 119 | template_coord: (B*N, 3) or (B*N, 6) 120 | search_feature: (B*N, C) 121 | search_coord: (B*N, 3) 122 | """ 123 | feature_size = template_feature.shape[-1] 124 | template_feature = template_feature.reshape(-1, self.input_size, feature_size).permute(0, 2, 1) # (B, C, N) 125 | search_feature = search_feature.reshape(-1, self.input_size, feature_size).permute(0, 2, 1) # (B, C, N) 126 | template_coord = template_coord.reshape(-1, self.input_size, template_coord.shape[-1]) # (B, N, 3 or 6) 127 | search_coord = search_coord.reshape(-1, self.input_size, 3) # (B, N, 3) 128 | batch_size = template_feature.shape[0] 129 | 130 | if self.inverse_source_template: 131 | fusion_feature = self.transform_fuse( 132 | template_feature, template_coord, search_feature, search_coord) # (B, N, C) 133 | else: 134 | fusion_feature = self.transform_fuse( 135 | search_feature, search_coord, template_feature, template_coord) # (B, N, C) 136 | pred_nocs = self.nocs_layer(fusion_feature) # (B, N, C'*3) 137 | if self.nocs_bins is None: 138 | # direct regression 139 | pred_nocs = torch.sigmoid(pred_nocs) # (B, N, C'*3) 140 | 141 | fusion_feature = fusion_feature.reshape(batch_size * self.input_size, -1) # (B*N, C) 142 | pred_nocs = pred_nocs.reshape(batch_size * self.input_size, -1) # (B*N, C'*3) 143 | return fusion_feature, pred_nocs 144 | 145 | def logits_to_nocs(self, logits): 146 | nocs_bins = self.nocs_bins 147 | if nocs_bins is None: 148 | # directly regress from nn 149 | return logits 150 | 151 | # reshape 152 | logits_bins = None 153 | if len(logits.shape) == 2: 154 | logits_bins = logits.reshape((logits.shape[0], nocs_bins, 3)) 155 | elif len(logits.shape) == 1: 156 | logits_bins = logits.reshape((nocs_bins, 3)) 157 | 158 | bin_idx_pred = torch.argmax(logits_bins, dim=1, keepdim=False) 159 | 160 | # turn into per-channel classification problem 161 | vg = self.get_virtual_grid(logits.get_device()) 162 | points_pred = vg.idxs_to_points(bin_idx_pred) 163 | return points_pred 164 | 165 | def get_virtual_grid(self, device): 166 | nocs_bins = self.nocs_bins 167 | vg = VirtualGrid(lower_corner=(0, 0, 0), upper_corner=(1, 1, 1), 168 | grid_shape=(nocs_bins,) * 3, batch_size=1, 169 | device=device, int_dtype=torch.int64, 170 | float_dtype=torch.float32) 171 | return vg 172 | 173 | 174 | class InstanceL2Norm(nn.Module): 175 | """Instance L2 normalization. 
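    Rescales each sample so that its flattened feature map has L2 norm equal to `scale`
    (or root-mean-square magnitude equal to `scale` when size_average=True).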
176 | """ 177 | def __init__(self, size_average=True, eps=1e-5, scale=1.0): 178 | super().__init__() 179 | self.size_average = size_average 180 | self.eps = eps 181 | self.scale = scale 182 | 183 | def forward(self, input): 184 | if self.size_average: 185 | return input * (self.scale * ((input.shape[1] * input.shape[2] * input.shape[3]) / ( 186 | torch.sum((input * input).reshape(input.shape[0], 1, 1, -1), dim=3, keepdim=True) + self.eps)).sqrt()) # view 187 | else: 188 | return input * (self.scale / (torch.sum((input * input).reshape(input.shape[0], 1, 1, -1), dim=3, keepdim=True) + self.eps).sqrt()) 189 | 190 | 191 | class TransformerEncoderLayer(nn.Module): 192 | def __init__(self, multihead_attn, FFN, d_model, self_posembed=None): 193 | super().__init__() 194 | self.self_attn = multihead_attn 195 | # Implementation of Feedforward model 196 | self.FFN = FFN 197 | self.norm = nn.InstanceNorm1d(d_model) 198 | self.self_posembed = self_posembed 199 | 200 | self.dropout = nn.Dropout(0.1) 201 | 202 | def with_pos_embed(self, tensor, pos_embed: Optional[Tensor]): 203 | return tensor if pos_embed is None else tensor + pos_embed 204 | 205 | def forward(self, src, query_pos=None): 206 | # BxNxC -> BxCxN -> NxBxC 207 | if self.self_posembed is not None and query_pos is not None: 208 | query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1) 209 | else: 210 | query_pos_embed = None 211 | query = key = value = self.with_pos_embed(src, query_pos_embed) 212 | 213 | # self-attention 214 | # NxBxC 215 | src2 = self.self_attn(query=query, key=key, value=value) 216 | src = src + src2 217 | 218 | # NxBxC -> BxCxN -> NxBxC 219 | src = self.norm(src.permute(1, 2, 0)).permute(2, 0, 1) 220 | return F.relu(src) 221 | # return src 222 | 223 | 224 | class TransformerEncoder(nn.Module): 225 | def __init__(self, multihead_attn, FFN, 226 | d_model=512, 227 | num_encoder_layers=6, 228 | activation="relu", 229 | self_posembed=None): 230 | super().__init__() 231 | encoder_layer = TransformerEncoderLayer( 232 | multihead_attn, FFN, d_model, self_posembed=self_posembed) 233 | self.layers = _get_clones(encoder_layer, num_encoder_layers) 234 | 235 | def forward(self, src, query_pos=None): 236 | num_imgs, batch, dim = src.shape 237 | output = src 238 | 239 | for layer in self.layers: 240 | output = layer(output, query_pos=query_pos) 241 | 242 | # import pdb; pdb.set_trace() 243 | # [L,B,D] -> [B,D,L] 244 | # output_feat = output.reshape(num_imgs, batch, dim) 245 | return output 246 | 247 | 248 | class TransformerDecoderLayer(nn.Module): 249 | def __init__(self, multihead_attn, FFN, d_model, key_feature_dim, self_posembed=None): 250 | super().__init__() 251 | self.self_attn = multihead_attn 252 | self.cross_attn = MultiheadAttention( 253 | feature_dim=d_model, 254 | n_head=1, key_feature_dim=key_feature_dim) 255 | 256 | self.FFN = FFN 257 | self.norm1 = nn.InstanceNorm1d(d_model) 258 | self.norm2 = nn.InstanceNorm1d(d_model) 259 | self.self_posembed = self_posembed 260 | 261 | self.dropout1 = nn.Dropout(0.1) 262 | self.dropout2 = nn.Dropout(0.1) 263 | 264 | def with_pos_embed(self, tensor, pos_embed: Optional[Tensor]): 265 | return tensor if pos_embed is None else tensor + pos_embed 266 | 267 | def forward(self, tgt, memory, query_pos=None): 268 | if self.self_posembed is not None and query_pos is not None: 269 | query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1) 270 | else: 271 | query_pos_embed = None 272 | # NxBxC 273 | 274 | # self-attention 275 | query = key = value = self.with_pos_embed(tgt, 
query_pos_embed) 276 | 277 | tgt2 = self.self_attn(query=query, key=key, value=value) 278 | # tgt2 = self.dropout1(tgt2) 279 | tgt = tgt + tgt2 280 | # tgt = F.relu(tgt) 281 | # tgt = self.instance_norm(tgt, input_shape) 282 | # NxBxC 283 | # tgt = self.norm(tgt) 284 | tgt = self.norm1(tgt.permute(1, 2, 0)).permute(2, 0, 1) 285 | tgt = F.relu(tgt) 286 | 287 | mask = self.cross_attn( 288 | query=tgt, key=memory, value=memory) 289 | # mask = self.dropout2(mask) 290 | tgt2 = tgt + mask 291 | tgt2 = self.norm2(tgt2.permute(1, 2, 0)).permute(2, 0, 1) 292 | 293 | tgt2 = F.relu(tgt2) 294 | return tgt2 295 | 296 | 297 | class TransformerDecoder(nn.Module): 298 | def __init__(self, multihead_attn, FFN, 299 | d_model=512, 300 | num_decoder_layers=6, 301 | key_feature_dim=64, 302 | self_posembed=None): 303 | super().__init__() 304 | decoder_layer = TransformerDecoderLayer( 305 | multihead_attn, FFN, d_model, key_feature_dim, self_posembed=self_posembed) 306 | self.layers = _get_clones(decoder_layer, num_decoder_layers) 307 | 308 | def forward(self, tgt, memory, query_pos=None): 309 | assert tgt.dim() == 3, 'Expect 3 dimensional inputs' 310 | tgt_shape = tgt.shape 311 | num_imgs, batch, dim = tgt.shape 312 | 313 | output = tgt 314 | for layer in self.layers: 315 | output = layer(output, memory, query_pos=query_pos) 316 | return output 317 | 318 | 319 | def _get_clones(module, N): 320 | return nn.ModuleList([module for i in range(N)]) 321 | 322 | 323 | def _get_activation_fn(activation): 324 | """Return an activation function given a string""" 325 | if activation == "relu": 326 | return F.relu 327 | if activation == "gelu": 328 | return F.gelu 329 | if activation == "glu": 330 | return F.glu 331 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 332 | -------------------------------------------------------------------------------- /predict_tracking_gt.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import 3 | import os 4 | import pathlib 5 | import time 6 | 7 | import hydra 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import wandb 12 | import yaml 13 | import zarr 14 | from numcodecs import Blosc 15 | from omegaconf import DictConfig, OmegaConf 16 | from tqdm import tqdm 17 | 18 | from common.torch_util import to_numpy 19 | from common.geometry_util import barycentric_interpolation, mesh_sample_barycentric 20 | from datasets.tracking_dataset import SparseUnet3DTrackingDataModule2 21 | from networks.tracking_network import GarmentTrackingPipeline 22 | 23 | # %% 24 | # helper functions 25 | def get_checkpoint_df(checkpoint_dir): 26 | all_checkpoint_paths = sorted(pathlib.Path(checkpoint_dir).glob('*.ckpt')) 27 | rows = list() 28 | for path in all_checkpoint_paths: 29 | fname = path.stem 30 | row = dict() 31 | for item in fname.split('-'): 32 | key, value = item.split('=') 33 | row[key] = float(value) 34 | row['path'] = str(path.absolute()) 35 | rows.append(row) 36 | checkpoint_df = pd.DataFrame(rows) 37 | return checkpoint_df 38 | 39 | 40 | def get_mc_surface(pred_samples_group, group_key, 41 | init_num_points=12000, 42 | final_num_points=6000, 43 | value_threshold=0.13, seed=0, 44 | value_key='marching_cubes_mesh/volume_gradient_magnitude'): 45 | sample_group = pred_samples_group[group_key] 46 | # io 47 | pred_mc_group = sample_group['marching_cubes_mesh'] 48 | pred_mc_verts = pred_mc_group['verts'][:] 49 | pred_mc_faces = pred_mc_group['faces'][:] 50 | pred_mc_sim_verts = pred_mc_group['warp_field'][:] 
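    # The block below converts the predicted marching-cubes mesh into a fixed-size surface sample:
    # area-weighted barycentric sampling draws init_num_points points on the mesh, the barycentric
    # coordinates are interpolated to paired NOCS and task-space (warp-field) positions, samples whose
    # interpolated volume-gradient magnitude falls below value_threshold are dropped as hole artifacts,
    # and the remainder is randomly subsampled or repeated so exactly final_num_points points are returned.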
51 | # point sample 52 | num_samples = int(init_num_points) 53 | pred_sample_bc, pred_sample_face_idx = mesh_sample_barycentric( 54 | pred_mc_verts, pred_mc_faces, 55 | num_samples=num_samples, seed=seed) 56 | pred_sample_nocs_points = barycentric_interpolation( 57 | pred_sample_bc, 58 | pred_mc_verts, 59 | pred_mc_faces[pred_sample_face_idx]) 60 | pred_sample_sim_points = barycentric_interpolation( 61 | pred_sample_bc, 62 | pred_mc_sim_verts, 63 | pred_mc_faces[pred_sample_face_idx]) 64 | # remove holes 65 | pred_value = sample_group[value_key][:] 66 | pred_sample_value = np.squeeze(barycentric_interpolation( 67 | pred_sample_bc, 68 | np.expand_dims(pred_value, axis=1), 69 | pred_mc_faces[pred_sample_face_idx])) 70 | is_valid_sample = pred_sample_value > value_threshold 71 | valid_pred_sample_nocs_points = pred_sample_nocs_points[is_valid_sample] 72 | valid_pred_sample_sim_points = pred_sample_sim_points[is_valid_sample] 73 | 74 | valid_num_samples = valid_pred_sample_nocs_points.shape[0] 75 | if valid_num_samples >= final_num_points: 76 | np.random.seed(seed) 77 | valid_idxs = np.random.choice(np.arange(valid_num_samples), size=final_num_points) 78 | valid_pred_sample_nocs_points = valid_pred_sample_nocs_points[valid_idxs, :] 79 | valid_pred_sample_sim_points = valid_pred_sample_sim_points[valid_idxs, :] 80 | else: 81 | np.random.seed(seed) 82 | shuffle_idxs = np.arange(valid_num_samples) 83 | np.random.shuffle(shuffle_idxs) 84 | valid_pred_sample_nocs_points = valid_pred_sample_nocs_points[shuffle_idxs, :] 85 | valid_pred_sample_sim_points = valid_pred_sample_sim_points[shuffle_idxs, :] 86 | res_num = final_num_points - valid_num_samples 87 | valid_pred_sample_nocs_points = np.concatenate([valid_pred_sample_nocs_points, 88 | valid_pred_sample_nocs_points[:res_num, :]], axis=0) 89 | valid_pred_sample_sim_points = np.concatenate([valid_pred_sample_sim_points, 90 | valid_pred_sample_sim_points[:res_num, :]], axis=0) 91 | assert valid_pred_sample_nocs_points.shape[0] == final_num_points 92 | 93 | return valid_pred_sample_nocs_points, valid_pred_sample_sim_points 94 | 95 | 96 | # %% 97 | # main script 98 | @hydra.main(config_path="config", 99 | config_name="predict_tracking_gt") 100 | def main(cfg: DictConfig) -> None: 101 | # hydra creates working directory automatically 102 | pred_output_dir = os.getcwd() 103 | print(pred_output_dir) 104 | 105 | # determine checkpoint 106 | checkpoint_path = os.path.expanduser(cfg.main.checkpoint_path) 107 | assert (pathlib.Path(checkpoint_path).exists()) 108 | 109 | # load datamodule 110 | datamodule = SparseUnet3DTrackingDataModule2(**cfg.datamodule) 111 | datamodule.prepare_data() 112 | batch_size = datamodule.kwargs['batch_size'] 113 | assert (batch_size == 1) 114 | # val and test dataloader both uses val_dataset 115 | val_dataset = datamodule.val_dataset 116 | # subset = getattr(datamodule, '{}_subset'.format(cfg.prediction.subset)) 117 | dataloader = getattr(datamodule, '{}_dataloader'.format(cfg.prediction.subset))() 118 | num_samples = len(dataloader) 119 | 120 | # load input zarr 121 | input_zarr_path = os.path.expanduser(cfg.datamodule.zarr_path) 122 | input_root = zarr.open(input_zarr_path, 'r') 123 | input_samples_group = input_root['samples'] 124 | 125 | if cfg.prediction.use_garmentnets_prediction: 126 | # garmentnets prediction zarr 127 | pred_zarr_path = os.path.join(cfg.main.garmentnets_prediction_output_dir, 'prediction.zarr') 128 | assert (pathlib.Path(pred_zarr_path).exists()) 129 | assert cfg.prediction.use_valid_grip_interval, \ 130 | 
'only support grip interval evluation for garmentnets prediction' 131 | pred_root = zarr.open(pred_zarr_path, 'r') 132 | pred_samples_group = pred_root['samples'] 133 | 134 | # create output zarr 135 | output_zarr_path = os.path.join(pred_output_dir, 'prediction.zarr') 136 | store = zarr.DirectoryStore(output_zarr_path) 137 | compressor = Blosc(cname='zstd', clevel=6, shuffle=Blosc.BITSHUFFLE) 138 | output_root = zarr.group(store=store, overwrite=False) 139 | output_samples_group = output_root.require_group('samples', overwrite=False) 140 | 141 | root_attrs = { 142 | 'subset': cfg.prediction.subset 143 | } 144 | output_root.attrs.put(root_attrs) 145 | 146 | # init wandb 147 | wandb_path = os.path.join(pred_output_dir, 'wandb') 148 | os.mkdir(wandb_path) 149 | wandb_run = wandb.init( 150 | project=os.path.basename(__file__), 151 | **cfg.logger) 152 | wandb_meta = { 153 | 'run_name': wandb_run.name, 154 | 'run_id': wandb_run.id 155 | } 156 | meta = { 157 | 'script_path': __file__ 158 | } 159 | 160 | # load module to gpu 161 | model_cpu = GarmentTrackingPipeline.load_from_checkpoint(checkpoint_path) 162 | device = torch.device('cuda:{}'.format(cfg.main.gpu_id)) 163 | model = model_cpu.to(device) 164 | model.eval() 165 | model.requires_grad_(False) 166 | model.batch_size = batch_size 167 | assert model.batch_size == 1 168 | assert cfg.datamodule.num_workers == 0 169 | model.disable_mesh_nocs_refine_in_test = cfg.prediction.disable_mesh_nocs_refine_in_test 170 | model.disable_pc_nocs_refine_in_test = cfg.prediction.disable_pc_nocs_refine_in_test 171 | 172 | # dump final cfg 173 | all_config = { 174 | 'config': OmegaConf.to_container(cfg, resolve=True), 175 | 'output_dir': pred_output_dir, 176 | 'wandb': wandb_meta, 177 | 'meta': meta 178 | } 179 | yaml.dump(all_config, open('config.yaml', 'w'), default_flow_style=False) 180 | wandb.config.update(all_config) 181 | 182 | if cfg.prediction.use_garmentnets_prediction: 183 | # get pc_sim of the first frame 184 | start_idx = datamodule.test_idxs[0] 185 | test_group_row = val_dataset.groups_df.iloc[start_idx] 186 | group_key = test_group_row.group_key 187 | while group_key not in pred_samples_group: 188 | start_idx += 1 189 | test_group_row = val_dataset.groups_df.iloc[start_idx] 190 | group_key = test_group_row.group_key 191 | # calculate PC nocs of the first frame with garmentnets prediction 192 | next_sample_surface_nocs_points, next_sample_surface_sim_points = \ 193 | get_mc_surface(pred_samples_group, group_key, 194 | final_num_points=cfg.datamodule.num_surface_sample, 195 | value_threshold=cfg.prediction.value_threshold, seed=start_idx) 196 | 197 | val_dataset.set_prev_pose(None, None, start_idx, 198 | next_sample_surface_nocs_points, next_sample_surface_sim_points) 199 | 200 | current_video_id = None 201 | current_interval_id = None 202 | current_video_frame_idx = 0 203 | current_valid_video_frame_idx = 0 204 | current_interval_frame_idx = 0 205 | use_grip_interval = cfg.prediction.use_valid_grip_interval 206 | current_mesh_nocs_points = None 207 | # loop 208 | for batch_idx, batch_cpu in enumerate(tqdm(dataloader)): 209 | if len(batch_cpu) == 0: 210 | continue 211 | # locate raw info 212 | dataset_idx = int(batch_cpu['dataset_idx1'][0]) 213 | video_id = int(batch_cpu['video_id1'][0]) 214 | assert batch_cpu['video_id1'][0] == batch_cpu['video_id2'][0] 215 | is_new_video = current_video_id is None or current_video_id != video_id 216 | if is_new_video: 217 | # move to a new video 218 | current_video_id = video_id 219 | current_video_frame_idx = 0 
220 | current_valid_video_frame_idx = 0 221 | current_mesh_nocs_points = None 222 | 223 | if use_grip_interval: 224 | grip_interval_id = val_dataset.idx_to_interval_list[dataset_idx] 225 | if grip_interval_id == -1: 226 | # not in valid grip interval 227 | current_interval_id = None 228 | current_video_frame_idx += 1 229 | continue 230 | is_new_interval = current_interval_id is None or current_interval_id != grip_interval_id 231 | if is_new_interval: 232 | current_interval_id = grip_interval_id 233 | current_interval_frame_idx = 0 234 | else: 235 | grip_interval_id = video_id 236 | current_interval_id = current_video_id 237 | current_interval_frame_idx = current_video_frame_idx 238 | 239 | val_group_row = val_dataset.groups_df.iloc[dataset_idx] 240 | group_key = val_group_row.group_key 241 | attr_keys = ['scale', 'sample_id', 'garment_name'] 242 | attrs = dict((x, val_group_row[x]) for x in attr_keys) 243 | attrs['batch_idx'] = batch_idx 244 | attrs['video_id'] = video_id 245 | attrs['video_frame_idx'] = current_video_frame_idx 246 | if use_grip_interval: 247 | attrs['interval_id'] = grip_interval_id 248 | attrs['interval_frame_idx'] = current_interval_frame_idx 249 | 250 | # load input zarr 251 | input_group = input_samples_group[group_key] 252 | 253 | # create zarr group 254 | output_group = output_samples_group.require_group( 255 | group_key, overwrite=False) 256 | output_group.attrs.put(attrs) 257 | 258 | batch = {key: value.to(device=device) for key, value in batch_cpu.items()} 259 | 260 | use_refine_mesh_for_query = current_valid_video_frame_idx < cfg.prediction.max_refine_mesh_step 261 | start_time = time.time() 262 | # stage 1 263 | result = model(batch, use_refine_mesh_for_query=use_refine_mesh_for_query) 264 | time_stage1 = time.time() 265 | print('Stage 1 used {} s....'.format(time_stage1 - start_time)) 266 | 267 | # save nocs data 268 | nocs_data = result['encoder_result'] 269 | if 'refined_pos_frame2' in nocs_data: 270 | pred_pc_nocs = nocs_data['refined_pos_frame2'] 271 | else: 272 | pred_pc_nocs = nocs_data['pos_frame2'] 273 | pc_data_torch = { 274 | 'input_prev_nocs': batch['y1'], 275 | 'pred_nocs': pred_pc_nocs, 276 | 'input_points': batch['pos2'], 277 | 'input_rgb': (batch['x2'] * 255).to(torch.uint8), 278 | 'gt_nocs': batch['y2'] 279 | } 280 | pc_data = dict((x[0], to_numpy(x[1])) for x in pc_data_torch.items()) 281 | output_pc_group = output_group.require_group( 282 | 'point_cloud', overwrite=False) 283 | for key, data in pc_data.items(): 284 | output_pc_group.array( 285 | name=key, data=data, chunks=data.shape, 286 | compressor=compressor, overwrite=True) 287 | 288 | # save predicted mesh data 289 | rot_mat_torch = batch['input_aug_rot_mat1'] 290 | gt_mesh_nocs_points = batch['gt_surf_query_points2'] 291 | if 'refined_surf_query_points2' in nocs_data: 292 | if use_refine_mesh_for_query: 293 | pred_mesh_nocs_points = nocs_data['refined_surf_query_points2'] 294 | else: 295 | pred_mesh_nocs_points = batch['surf_query_points2'] 296 | else: 297 | if current_mesh_nocs_points is not None: 298 | pred_mesh_nocs_points = current_mesh_nocs_points.clone() 299 | else: 300 | pred_mesh_nocs_points = gt_mesh_nocs_points.clone() 301 | 302 | surface_decoder_result = result['surface_decoder_result'] 303 | pred_warpfield = surface_decoder_result['out_features'] 304 | pred_sim_points = pred_warpfield 305 | 306 | pred_sim_points = pred_sim_points @ rot_mat_torch.T 307 | pred_sim_points = pred_sim_points.squeeze(-1).transpose(0, 1) 308 | gt_sim_points = batch['gt_sim_points2'] 309 | 
gt_sim_points = gt_sim_points @ rot_mat_torch.T 310 | gt_sim_points = gt_sim_points.squeeze(-1).transpose(0, 1) 311 | 312 | mesh_data_torch = {'pred_nocs_points': pred_mesh_nocs_points, 313 | 'pred_sim_points': pred_sim_points, 314 | 'gt_sim_points': gt_sim_points, 315 | 'gt_nocs_points': gt_mesh_nocs_points} 316 | l2_norm_error = torch.mean(torch.norm(pred_sim_points - gt_sim_points, dim=1)) 317 | print('interval {}, interval frame {}, video {}, video frame {}, error: {}'.format( 318 | grip_interval_id, current_interval_frame_idx, video_id, current_video_frame_idx, l2_norm_error.item())) 319 | mesh_data = dict((x[0], to_numpy(x[1])) for x in mesh_data_torch.items()) 320 | output_mesh_points_group = output_group.require_group( 321 | 'mesh_points', overwrite=False) 322 | for key, data in mesh_data.items(): 323 | output_mesh_points_group.array( 324 | name=key, data=data, chunks=data.shape, 325 | compressor=compressor, overwrite=True) 326 | 327 | # copy gt mesh data 328 | rot_mat = np.squeeze(to_numpy(batch_cpu['input_aug_rot_mat1'])) 329 | aug_keys = ['cloth_verts'] 330 | input_mesh_group = input_group['mesh'] 331 | output_mesh_group = output_group.require_group('gt_mesh', overwrite=False) 332 | for key, value in input_mesh_group.arrays(): 333 | data = value[:] 334 | if key in aug_keys: 335 | data = data @ rot_mat.T 336 | output_mesh_group.array( 337 | name=key, data=data, chunks=data.shape, 338 | compressor=compressor, overwrite=True) 339 | 340 | # logging 341 | log_data = { 342 | 'prediction_batch_idx': batch_idx, 343 | 'prediction_video_id': video_id, 344 | 'prediction_video_frame_idx': current_video_frame_idx, 345 | } 346 | wandb.log( 347 | data=log_data, 348 | step=batch_idx) 349 | 350 | is_last_video_frame = len(val_dataset.video_to_idxs_dict[str(video_id).zfill(6)]) \ 351 | == current_video_frame_idx + 1 352 | if use_grip_interval: 353 | is_last_interval_frame = len(val_dataset.interval_to_idxs_dict[grip_interval_id]) \ 354 | == current_interval_frame_idx + 1 355 | last_grip_interval_ids = [val_dataset.idx_to_interval_list[idx] 356 | for idx in range(dataset_idx + 1, 357 | val_dataset.video_to_idxs_dict[str(video_id).zfill(6)][-1])] 358 | is_last_interval = True 359 | for grip_interval_id in last_grip_interval_ids: 360 | if grip_interval_id != -1: 361 | is_last_interval = False 362 | if cfg.prediction.use_cross_interval_tracking: 363 | is_last_frame = is_last_video_frame or is_last_interval 364 | else: 365 | is_last_frame = is_last_interval_frame or is_last_video_frame 366 | else: 367 | is_last_frame = is_last_video_frame 368 | 369 | if is_last_frame: 370 | if cfg.prediction.use_garmentnets_prediction and dataset_idx + 1 < len(val_dataset): 371 | # get next frame pc_sim 372 | start_idx = dataset_idx + 1 373 | test_group_row = val_dataset.groups_df.iloc[start_idx] 374 | group_key = test_group_row.group_key 375 | while group_key not in pred_samples_group and start_idx + 1 < len(val_dataset): 376 | start_idx += 1 377 | test_group_row = val_dataset.groups_df.iloc[start_idx] 378 | group_key = test_group_row.group_key 379 | # calculate PC nocs of the first frame with garmentnets prediction 380 | next_sample_surface_nocs_points, next_sample_surface_sim_points = \ 381 | get_mc_surface(pred_samples_group, group_key, 382 | final_num_points=cfg.datamodule.num_surface_sample, 383 | value_threshold=cfg.prediction.value_threshold, seed=start_idx) 384 | val_dataset.set_prev_pose(None, None, start_idx, 385 | next_sample_surface_nocs_points, next_sample_surface_sim_points) 386 | else: 387 | 
val_dataset.set_prev_pose(None, None, None, None, None) 388 | else: 389 | assert model.batch_size == 1 390 | if current_valid_video_frame_idx < cfg.prediction.max_refine_mesh_step: 391 | current_mesh_nocs_points = pred_mesh_nocs_points.clone().contiguous() 392 | else: 393 | if val_dataset.prev_surf_query_points1 is not None: 394 | current_mesh_nocs_points = torch.from_numpy(val_dataset.prev_surf_query_points1.copy()) 395 | else: 396 | current_mesh_nocs_points = gt_mesh_nocs_points.clone().contiguous() 397 | current_pc_nocs = pred_pc_nocs.clone().contiguous() 398 | val_dataset.set_prev_pose(to_numpy(current_mesh_nocs_points), to_numpy(current_pc_nocs), dataset_idx + 1, 399 | None, None) 400 | 401 | current_video_frame_idx += 1 402 | current_valid_video_frame_idx += 1 403 | current_interval_frame_idx += 1 404 | 405 | 406 | # %% 407 | # driver 408 | if __name__ == "__main__": 409 | main() 410 | -------------------------------------------------------------------------------- /predict_tracking_noise.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # import 3 | import os 4 | import pathlib 5 | import time 6 | 7 | import hydra 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import wandb 12 | import yaml 13 | import zarr 14 | from numcodecs import Blosc 15 | from omegaconf import DictConfig, OmegaConf 16 | from tqdm import tqdm 17 | 18 | from common.torch_util import to_numpy 19 | from common.geometry_util import barycentric_interpolation, mesh_sample_barycentric 20 | from datasets.tracking_dataset import SparseUnet3DTrackingDataModule2 21 | from networks.tracking_network import GarmentTrackingPipeline 22 | 23 | # %% 24 | # helper functions 25 | def get_checkpoint_df(checkpoint_dir): 26 | all_checkpoint_paths = sorted(pathlib.Path(checkpoint_dir).glob('*.ckpt')) 27 | rows = list() 28 | for path in all_checkpoint_paths: 29 | fname = path.stem 30 | row = dict() 31 | for item in fname.split('-'): 32 | key, value = item.split('=') 33 | row[key] = float(value) 34 | row['path'] = str(path.absolute()) 35 | rows.append(row) 36 | checkpoint_df = pd.DataFrame(rows) 37 | return checkpoint_df 38 | 39 | 40 | def get_mc_surface(pred_samples_group, group_key, 41 | init_num_points=12000, 42 | final_num_points=6000, 43 | value_threshold=0.13, seed=0, 44 | value_key='marching_cubes_mesh/volume_gradient_magnitude'): 45 | sample_group = pred_samples_group[group_key] 46 | # io 47 | pred_mc_group = sample_group['marching_cubes_mesh'] 48 | pred_mc_verts = pred_mc_group['verts'][:] 49 | pred_mc_faces = pred_mc_group['faces'][:] 50 | pred_mc_sim_verts = pred_mc_group['warp_field'][:] 51 | # point sample 52 | num_samples = int(init_num_points) 53 | pred_sample_bc, pred_sample_face_idx = mesh_sample_barycentric( 54 | pred_mc_verts, pred_mc_faces, 55 | num_samples=num_samples, seed=seed) 56 | pred_sample_nocs_points = barycentric_interpolation( 57 | pred_sample_bc, 58 | pred_mc_verts, 59 | pred_mc_faces[pred_sample_face_idx]) 60 | pred_sample_sim_points = barycentric_interpolation( 61 | pred_sample_bc, 62 | pred_mc_sim_verts, 63 | pred_mc_faces[pred_sample_face_idx]) 64 | # remove holes 65 | pred_value = sample_group[value_key][:] 66 | pred_sample_value = np.squeeze(barycentric_interpolation( 67 | pred_sample_bc, 68 | np.expand_dims(pred_value, axis=1), 69 | pred_mc_faces[pred_sample_face_idx])) 70 | is_valid_sample = pred_sample_value > value_threshold 71 | valid_pred_sample_nocs_points = pred_sample_nocs_points[is_valid_sample] 72 | 
valid_pred_sample_sim_points = pred_sample_sim_points[is_valid_sample] 73 | 74 | valid_num_samples = valid_pred_sample_nocs_points.shape[0] 75 | if valid_num_samples >= final_num_points: 76 | np.random.seed(seed) 77 | valid_idxs = np.random.choice(np.arange(valid_num_samples), size=final_num_points) 78 | valid_pred_sample_nocs_points = valid_pred_sample_nocs_points[valid_idxs, :] 79 | valid_pred_sample_sim_points = valid_pred_sample_sim_points[valid_idxs, :] 80 | else: 81 | np.random.seed(seed) 82 | shuffle_idxs = np.arange(valid_num_samples) 83 | np.random.shuffle(shuffle_idxs) 84 | valid_pred_sample_nocs_points = valid_pred_sample_nocs_points[shuffle_idxs, :] 85 | valid_pred_sample_sim_points = valid_pred_sample_sim_points[shuffle_idxs, :] 86 | res_num = final_num_points - valid_num_samples 87 | valid_pred_sample_nocs_points = np.concatenate([valid_pred_sample_nocs_points, 88 | valid_pred_sample_nocs_points[:res_num, :]], axis=0) 89 | valid_pred_sample_sim_points = np.concatenate([valid_pred_sample_sim_points, 90 | valid_pred_sample_sim_points[:res_num, :]], axis=0) 91 | assert valid_pred_sample_nocs_points.shape[0] == final_num_points 92 | 93 | return valid_pred_sample_nocs_points, valid_pred_sample_sim_points 94 | 95 | 96 | # %% 97 | # main script 98 | @hydra.main(config_path="config", 99 | config_name="predict_tracking_noise") 100 | def main(cfg: DictConfig) -> None: 101 | # hydra creates working directory automatically 102 | pred_output_dir = os.getcwd() 103 | print(pred_output_dir) 104 | 105 | # determine checkpoint 106 | checkpoint_path = os.path.expanduser(cfg.main.checkpoint_path) 107 | assert (pathlib.Path(checkpoint_path).exists()) 108 | 109 | # load datamodule 110 | datamodule = SparseUnet3DTrackingDataModule2(**cfg.datamodule) 111 | datamodule.prepare_data() 112 | batch_size = datamodule.kwargs['batch_size'] 113 | assert (batch_size == 1) 114 | # val and test dataloader both uses val_dataset 115 | val_dataset = datamodule.val_dataset 116 | # subset = getattr(datamodule, '{}_subset'.format(cfg.prediction.subset)) 117 | dataloader = getattr(datamodule, '{}_dataloader'.format(cfg.prediction.subset))() 118 | num_samples = len(dataloader) 119 | 120 | # load input zarr 121 | input_zarr_path = os.path.expanduser(cfg.datamodule.zarr_path) 122 | input_root = zarr.open(input_zarr_path, 'r') 123 | input_samples_group = input_root['samples'] 124 | 125 | if cfg.prediction.use_garmentnets_prediction: 126 | # garmentnets prediction zarr 127 | pred_zarr_path = os.path.join(cfg.main.garmentnets_prediction_output_dir, 'prediction.zarr') 128 | assert (pathlib.Path(pred_zarr_path).exists()) 129 | assert cfg.prediction.use_valid_grip_interval, \ 130 | 'only support grip interval evluation for garmentnets prediction' 131 | pred_root = zarr.open(pred_zarr_path, 'r') 132 | pred_samples_group = pred_root['samples'] 133 | 134 | # create output zarr 135 | output_zarr_path = os.path.join(pred_output_dir, 'prediction.zarr') 136 | store = zarr.DirectoryStore(output_zarr_path) 137 | compressor = Blosc(cname='zstd', clevel=6, shuffle=Blosc.BITSHUFFLE) 138 | output_root = zarr.group(store=store, overwrite=False) 139 | output_samples_group = output_root.require_group('samples', overwrite=False) 140 | 141 | root_attrs = { 142 | 'subset': cfg.prediction.subset 143 | } 144 | output_root.attrs.put(root_attrs) 145 | 146 | # init wandb 147 | wandb_path = os.path.join(pred_output_dir, 'wandb') 148 | os.mkdir(wandb_path) 149 | wandb_run = wandb.init( 150 | project=os.path.basename(__file__), 151 | **cfg.logger) 152 | 
wandb_meta = { 153 | 'run_name': wandb_run.name, 154 | 'run_id': wandb_run.id 155 | } 156 | meta = { 157 | 'script_path': __file__ 158 | } 159 | 160 | # load module to gpu 161 | model_cpu = GarmentTrackingPipeline.load_from_checkpoint(checkpoint_path) 162 | device = torch.device('cuda:{}'.format(cfg.main.gpu_id)) 163 | model = model_cpu.to(device) 164 | model.eval() 165 | model.requires_grad_(False) 166 | model.batch_size = batch_size 167 | assert model.batch_size == 1 168 | assert cfg.datamodule.num_workers == 0 169 | model.disable_mesh_nocs_refine_in_test = cfg.prediction.disable_mesh_nocs_refine_in_test 170 | model.disable_pc_nocs_refine_in_test = cfg.prediction.disable_pc_nocs_refine_in_test 171 | 172 | # dump final cfg 173 | all_config = { 174 | 'config': OmegaConf.to_container(cfg, resolve=True), 175 | 'output_dir': pred_output_dir, 176 | 'wandb': wandb_meta, 177 | 'meta': meta 178 | } 179 | yaml.dump(all_config, open('config.yaml', 'w'), default_flow_style=False) 180 | wandb.config.update(all_config) 181 | 182 | if cfg.prediction.use_garmentnets_prediction: 183 | # get pc_sim of the first frame 184 | start_idx = datamodule.test_idxs[0] 185 | test_group_row = val_dataset.groups_df.iloc[start_idx] 186 | group_key = test_group_row.group_key 187 | while group_key not in pred_samples_group: 188 | start_idx += 1 189 | test_group_row = val_dataset.groups_df.iloc[start_idx] 190 | group_key = test_group_row.group_key 191 | # calculate PC nocs of the first frame with garmentnets prediction 192 | next_sample_surface_nocs_points, next_sample_surface_sim_points = \ 193 | get_mc_surface(pred_samples_group, group_key, 194 | final_num_points=cfg.datamodule.num_surface_sample, 195 | value_threshold=cfg.prediction.value_threshold, seed=start_idx) 196 | 197 | val_dataset.set_prev_pose(None, None, start_idx, 198 | next_sample_surface_nocs_points, next_sample_surface_sim_points) 199 | 200 | current_video_id = None 201 | current_interval_id = None 202 | current_video_frame_idx = 0 203 | current_valid_video_frame_idx = 0 204 | current_interval_frame_idx = 0 205 | use_grip_interval = cfg.prediction.use_valid_grip_interval 206 | current_mesh_nocs_points = None 207 | # loop 208 | for batch_idx, batch_cpu in enumerate(tqdm(dataloader)): 209 | if len(batch_cpu) == 0: 210 | continue 211 | # locate raw info 212 | dataset_idx = int(batch_cpu['dataset_idx1'][0]) 213 | video_id = int(batch_cpu['video_id1'][0]) 214 | assert batch_cpu['video_id1'][0] == batch_cpu['video_id2'][0] 215 | is_new_video = current_video_id is None or current_video_id != video_id 216 | if is_new_video: 217 | # move to a new video 218 | current_video_id = video_id 219 | current_video_frame_idx = 0 220 | current_valid_video_frame_idx = 0 221 | current_mesh_nocs_points = None 222 | 223 | if use_grip_interval: 224 | grip_interval_id = val_dataset.idx_to_interval_list[dataset_idx] 225 | if grip_interval_id == -1: 226 | # not in valid grip interval 227 | current_interval_id = None 228 | current_video_frame_idx += 1 229 | continue 230 | is_new_interval = current_interval_id is None or current_interval_id != grip_interval_id 231 | if is_new_interval: 232 | current_interval_id = grip_interval_id 233 | current_interval_frame_idx = 0 234 | else: 235 | grip_interval_id = video_id 236 | current_interval_id = current_video_id 237 | current_interval_frame_idx = current_video_frame_idx 238 | 239 | val_group_row = val_dataset.groups_df.iloc[dataset_idx] 240 | group_key = val_group_row.group_key 241 | attr_keys = ['scale', 'sample_id', 'garment_name'] 242 | 
attrs = dict((x, val_group_row[x]) for x in attr_keys) 243 | attrs['batch_idx'] = batch_idx 244 | attrs['video_id'] = video_id 245 | attrs['video_frame_idx'] = current_video_frame_idx 246 | if use_grip_interval: 247 | attrs['interval_id'] = grip_interval_id 248 | attrs['interval_frame_idx'] = current_interval_frame_idx 249 | 250 | # load input zarr 251 | input_group = input_samples_group[group_key] 252 | 253 | # create zarr group 254 | output_group = output_samples_group.require_group( 255 | group_key, overwrite=False) 256 | output_group.attrs.put(attrs) 257 | 258 | batch = {key: value.to(device=device) for key, value in batch_cpu.items()} 259 | 260 | use_refine_mesh_for_query = current_valid_video_frame_idx < cfg.prediction.max_refine_mesh_step 261 | start_time = time.time() 262 | # stage 1 263 | result = model(batch, use_refine_mesh_for_query=use_refine_mesh_for_query) 264 | time_stage1 = time.time() 265 | print('Stage 1 used {} s....'.format(time_stage1 - start_time)) 266 | 267 | # save nocs data 268 | nocs_data = result['encoder_result'] 269 | if 'refined_pos_frame2' in nocs_data: 270 | pred_pc_nocs = nocs_data['refined_pos_frame2'] 271 | else: 272 | pred_pc_nocs = nocs_data['pos_frame2'] 273 | pc_data_torch = { 274 | 'input_prev_nocs': batch['y1'], 275 | 'pred_nocs': pred_pc_nocs, 276 | 'input_points': batch['pos2'], 277 | 'input_rgb': (batch['x2'] * 255).to(torch.uint8), 278 | 'gt_nocs': batch['y2'] 279 | } 280 | pc_data = dict((x[0], to_numpy(x[1])) for x in pc_data_torch.items()) 281 | output_pc_group = output_group.require_group( 282 | 'point_cloud', overwrite=False) 283 | for key, data in pc_data.items(): 284 | output_pc_group.array( 285 | name=key, data=data, chunks=data.shape, 286 | compressor=compressor, overwrite=True) 287 | 288 | # save predicted mesh data 289 | rot_mat_torch = batch['input_aug_rot_mat1'] 290 | gt_mesh_nocs_points = batch['gt_surf_query_points2'] 291 | if 'refined_surf_query_points2' in nocs_data: 292 | if use_refine_mesh_for_query: 293 | pred_mesh_nocs_points = nocs_data['refined_surf_query_points2'] 294 | else: 295 | pred_mesh_nocs_points = batch['surf_query_points2'] 296 | else: 297 | if current_mesh_nocs_points is not None: 298 | pred_mesh_nocs_points = current_mesh_nocs_points.clone() 299 | else: 300 | pred_mesh_nocs_points = gt_mesh_nocs_points.clone() 301 | 302 | surface_decoder_result = result['surface_decoder_result'] 303 | pred_warpfield = surface_decoder_result['out_features'] 304 | pred_sim_points = pred_warpfield 305 | 306 | pred_sim_points = pred_sim_points @ rot_mat_torch.T 307 | pred_sim_points = pred_sim_points.squeeze(-1).transpose(0, 1) 308 | gt_sim_points = batch['gt_sim_points2'] 309 | gt_sim_points = gt_sim_points @ rot_mat_torch.T 310 | gt_sim_points = gt_sim_points.squeeze(-1).transpose(0, 1) 311 | 312 | mesh_data_torch = {'pred_nocs_points': pred_mesh_nocs_points, 313 | 'pred_sim_points': pred_sim_points, 314 | 'gt_sim_points': gt_sim_points, 315 | 'gt_nocs_points': gt_mesh_nocs_points} 316 | l2_norm_error = torch.mean(torch.norm(pred_sim_points - gt_sim_points, dim=1)) 317 | print('interval {}, interval frame {}, video {}, video frame {}, error: {}'.format( 318 | grip_interval_id, current_interval_frame_idx, video_id, current_video_frame_idx, l2_norm_error.item())) 319 | mesh_data = dict((x[0], to_numpy(x[1])) for x in mesh_data_torch.items()) 320 | output_mesh_points_group = output_group.require_group( 321 | 'mesh_points', overwrite=False) 322 | for key, data in mesh_data.items(): 323 | output_mesh_points_group.array( 324 | 
name=key, data=data, chunks=data.shape, 325 | compressor=compressor, overwrite=True) 326 | 327 | # copy gt mesh data 328 | rot_mat = np.squeeze(to_numpy(batch_cpu['input_aug_rot_mat1'])) 329 | aug_keys = ['cloth_verts'] 330 | input_mesh_group = input_group['mesh'] 331 | output_mesh_group = output_group.require_group('gt_mesh', overwrite=False) 332 | for key, value in input_mesh_group.arrays(): 333 | data = value[:] 334 | if key in aug_keys: 335 | data = data @ rot_mat.T 336 | output_mesh_group.array( 337 | name=key, data=data, chunks=data.shape, 338 | compressor=compressor, overwrite=True) 339 | 340 | # logging 341 | log_data = { 342 | 'prediction_batch_idx': batch_idx, 343 | 'prediction_video_id': video_id, 344 | 'prediction_video_frame_idx': current_video_frame_idx, 345 | } 346 | wandb.log( 347 | data=log_data, 348 | step=batch_idx) 349 | 350 | is_last_video_frame = len(val_dataset.video_to_idxs_dict[str(video_id).zfill(6)]) \ 351 | == current_video_frame_idx + 1 352 | if use_grip_interval: 353 | is_last_interval_frame = len(val_dataset.interval_to_idxs_dict[grip_interval_id]) \ 354 | == current_interval_frame_idx + 1 355 | last_grip_interval_ids = [val_dataset.idx_to_interval_list[idx] 356 | for idx in range(dataset_idx + 1, 357 | val_dataset.video_to_idxs_dict[str(video_id).zfill(6)][-1])] 358 | is_last_interval = True 359 | for grip_interval_id in last_grip_interval_ids: 360 | if grip_interval_id != -1: 361 | is_last_interval = False 362 | if cfg.prediction.use_cross_interval_tracking: 363 | is_last_frame = is_last_video_frame or is_last_interval 364 | else: 365 | is_last_frame = is_last_interval_frame or is_last_video_frame 366 | else: 367 | is_last_frame = is_last_video_frame 368 | 369 | if is_last_frame: 370 | if cfg.prediction.use_garmentnets_prediction and dataset_idx + 1 < len(val_dataset): 371 | # get next frame pc_sim 372 | start_idx = dataset_idx + 1 373 | test_group_row = val_dataset.groups_df.iloc[start_idx] 374 | group_key = test_group_row.group_key 375 | while group_key not in pred_samples_group and start_idx + 1 < len(val_dataset): 376 | start_idx += 1 377 | test_group_row = val_dataset.groups_df.iloc[start_idx] 378 | group_key = test_group_row.group_key 379 | # calculate PC nocs of the first frame with garmentnets prediction 380 | next_sample_surface_nocs_points, next_sample_surface_sim_points = \ 381 | get_mc_surface(pred_samples_group, group_key, 382 | final_num_points=cfg.datamodule.num_surface_sample, 383 | value_threshold=cfg.prediction.value_threshold, seed=start_idx) 384 | val_dataset.set_prev_pose(None, None, start_idx, 385 | next_sample_surface_nocs_points, next_sample_surface_sim_points) 386 | else: 387 | val_dataset.set_prev_pose(None, None, None, None, None) 388 | else: 389 | assert model.batch_size == 1 390 | if current_valid_video_frame_idx < cfg.prediction.max_refine_mesh_step: 391 | current_mesh_nocs_points = pred_mesh_nocs_points.clone().contiguous() 392 | else: 393 | if val_dataset.prev_surf_query_points1 is not None: 394 | current_mesh_nocs_points = torch.from_numpy(val_dataset.prev_surf_query_points1.copy()) 395 | else: 396 | current_mesh_nocs_points = gt_mesh_nocs_points.clone().contiguous() 397 | current_pc_nocs = pred_pc_nocs.clone().contiguous() 398 | val_dataset.set_prev_pose(to_numpy(current_mesh_nocs_points), to_numpy(current_pc_nocs), dataset_idx + 1, 399 | None, None) 400 | 401 | current_video_frame_idx += 1 402 | current_valid_video_frame_idx += 1 403 | current_interval_frame_idx += 1 404 | 405 | 406 | # %% 407 | # driver 408 | if __name__ 
== "__main__": 409 | main() 410 | -------------------------------------------------------------------------------- /eval_tracking.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # set numpy threads 3 | import os 4 | os.environ["OMP_NUM_THREADS"] = "20" 5 | os.environ["OPENBLAS_NUM_THREADS"] = "20" 6 | os.environ["MKL_NUM_THREADS"] = "20" 7 | os.environ["VECLIB_MAXIMUM_THREADS"] = "20" 8 | os.environ["NUMEXPR_NUM_THREADS"] = "20" 9 | # currently requires custom-built igl-python binding 10 | os.environ["IGL_PARALLEL_FOR_NUM_THREADS"] = "1" 11 | import numpy as np 12 | import igl 13 | 14 | # %% 15 | # import 16 | import pathlib 17 | from pprint import pprint 18 | import json 19 | 20 | import yaml 21 | import hydra 22 | from omegaconf import DictConfig, OmegaConf 23 | import wandb 24 | import zarr 25 | from numcodecs import Blosc 26 | from tqdm import tqdm 27 | 28 | import pandas as pd 29 | import numpy as np 30 | from scipy.spatial import ckdtree 31 | import igl 32 | 33 | from common.parallel_util import parallel_map 34 | from common.geometry_util import ( 35 | AABBNormalizer, AABBGripNormalizer) 36 | 37 | 38 | # %% 39 | # helper functions 40 | def write_dict_to_group(data, group, compressor): 41 | for key, data in data.items(): 42 | if isinstance(data, np.ndarray): 43 | group.array( 44 | name=key, data=data, chunks=data.shape, 45 | compressor=compressor, overwrite=True) 46 | else: 47 | group[key] = data 48 | 49 | def compute_pc_metrics(sample_key, samples_group, nocs_aabb, **kwargs): 50 | sample_group = samples_group[sample_key] 51 | # io 52 | pc_group = sample_group['point_cloud'] 53 | gt_nocs = pc_group['gt_nocs'][:] 54 | pred_nocs = pc_group['pred_nocs'][:] 55 | 56 | # transform 57 | normalizer = AABBNormalizer(nocs_aabb) 58 | gt_nocs = normalizer.inverse(gt_nocs) 59 | pred_nocs = normalizer.inverse(pred_nocs) 60 | 61 | # compute 62 | nocs_diff = pred_nocs - gt_nocs 63 | nocs_error_mean_per_dim = np.mean(np.abs(nocs_diff), axis=0) 64 | nocs_diff_std_per_dim = np.std(nocs_diff, axis=0) 65 | 66 | mirror_gt_nocs = gt_nocs.copy() 67 | mirror_gt_nocs[:, 0] = -mirror_gt_nocs[:, 0] 68 | mirror_nocs_error = pred_nocs - mirror_gt_nocs 69 | nocs_error_dist = np.linalg.norm(nocs_diff, axis=1) 70 | mirror_nocs_error_dist = np.linalg.norm(mirror_nocs_error, axis=1) 71 | mirror_min_nocs_error_dist = np.minimum(nocs_error_dist, mirror_nocs_error_dist) 72 | 73 | metrics = { 74 | 'nocs_pc_error_distance': np.mean(nocs_error_dist), 75 | 'nocs_pc_mirror_error_distance': np.mean(mirror_nocs_error_dist), 76 | 'nocs_pc_min_agg_error_distance': np.mean(mirror_min_nocs_error_dist), 77 | 'nocs_pc_agg_min_error_distance': np.minimum(np.mean(nocs_error_dist), np.mean(mirror_nocs_error_dist)) 78 | } 79 | axis_order = ['x', 'y', 'z'] 80 | per_dim_features = { 81 | 'nocs_pc_diff_std': nocs_diff_std_per_dim, 82 | 'nocs_pc_error': nocs_error_mean_per_dim, 83 | } 84 | for key, value in per_dim_features.items(): 85 | for i in range(3): 86 | metrics['_'.join([key, axis_order[i]])] = value[i] 87 | return metrics 88 | 89 | 90 | def compute_chamfer(sample_key, samples_group, nocs_aabb, 91 | **kwargs): 92 | sample_group = samples_group[sample_key] 93 | 94 | mesh_points_group = sample_group['mesh_points'] 95 | pred_sim_points = mesh_points_group['pred_sim_points'][:] 96 | gt_sim_points = mesh_points_group['gt_sim_points'][:] 97 | 98 | # compute chamfer distance 99 | def get_chamfer(pred_points, gt_points): 100 | pred_tree = ckdtree.cKDTree(pred_points) 101 | gt_tree = 
ckdtree.cKDTree(gt_points) 102 | forward_distance, forward_nn_idx = gt_tree.query(pred_points, k=1) 103 | backward_distance, backward_nn_idx = pred_tree.query(gt_points, k=1) 104 | forward_chamfer = np.mean(forward_distance) 105 | backward_chamfer = np.mean(backward_distance) 106 | symmetrical_chamfer = np.mean([forward_chamfer, backward_chamfer]) 107 | result = { 108 | # 'chamfer_forward': forward_chamfer, 109 | # 'chamfer_backward': backward_chamfer, 110 | 'chamfer_symmetrical': symmetrical_chamfer 111 | } 112 | return result 113 | 114 | in_data = { 115 | 'sim': { 116 | 'pred_points': pred_sim_points, 117 | 'gt_points': gt_sim_points 118 | }, 119 | } 120 | 121 | key_order = ['sim'] 122 | old_in_data = in_data 123 | in_data = dict([(x, old_in_data[x]) for x in key_order if x in old_in_data]) 124 | 125 | result = dict() 126 | for category, kwargs in in_data.items(): 127 | out_data = get_chamfer(**kwargs) 128 | for key, value in out_data.items(): 129 | result['_'.join([key, category])] = value 130 | return result 131 | 132 | 133 | def compute_euclidian( 134 | sample_key, 135 | samples_group, 136 | **kwargs): 137 | sample_group = samples_group[sample_key] 138 | 139 | mesh_points_group = sample_group['mesh_points'] 140 | pred_sim_points = mesh_points_group['pred_sim_points'][:] 141 | gt_sim_points = mesh_points_group['gt_sim_points'][:] 142 | 143 | # compute chamfer distance 144 | def get_euclidian(pred_points, gt_points): 145 | euclidian = np.mean(np.linalg.norm(pred_sim_points - gt_sim_points, axis=1)) 146 | result = { 147 | 'euclidian': euclidian 148 | } 149 | return result 150 | 151 | in_data = { 152 | 'sim': { 153 | 'pred_points': pred_sim_points, 154 | 'gt_points': gt_sim_points 155 | }, 156 | } 157 | 158 | key_order = ['sim'] 159 | old_in_data = in_data 160 | in_data = dict([(x, old_in_data[x]) for x in key_order if x in old_in_data]) 161 | 162 | result = dict() 163 | for category, kwargs in in_data.items(): 164 | out_data = get_euclidian(**kwargs) 165 | for key, value in out_data.items(): 166 | result['_'.join([key, category])] = value 167 | return result 168 | 169 | 170 | # %% 171 | # visualization functions 172 | def get_task_mesh_vis( 173 | sample_key, 174 | samples_group, 175 | offset=(0.8,0,0), 176 | save_path=None, 177 | **kwargs): 178 | """ 179 | Visualizes task space result as a point cloud 180 | Order: GT sim mesh Pred sim mesh Sim point cloud 181 | """ 182 | sample_group = samples_group[sample_key] 183 | # io 184 | mesh_points_group = sample_group['mesh_points'] 185 | pred_sim_points = mesh_points_group['pred_sim_points'][:] 186 | pred_nocs_points = mesh_points_group['pred_nocs_points'][:] 187 | gt_nocs_points = mesh_points_group['gt_nocs_points'][:] 188 | gt_sim_points = mesh_points_group['gt_sim_points'][:] 189 | if 'attention_score' in mesh_points_group: 190 | attention_score = mesh_points_group['attention_score'][:].repeat(3, axis=1) 191 | else: 192 | attention_score = None 193 | 194 | pc_group = sample_group['point_cloud'] 195 | gt_input_pc = pc_group['input_points'][:] 196 | gt_input_rgb = pc_group['input_rgb'][:].astype(np.float32) 197 | pred_input_nocs = pc_group['pred_nocs'][:] 198 | gt_nocs_pc = pc_group['gt_nocs'][:] 199 | 200 | mesh_group = sample_group['gt_mesh'] 201 | cloth_faces_tri = mesh_group['cloth_faces_tri'][:] 202 | 203 | # vis 204 | offset_vec = np.array(offset) 205 | gt_sim_pc = np.concatenate([gt_sim_points - offset_vec, gt_nocs_points * 255], axis=1) 206 | pred_sim_pc = np.concatenate([pred_sim_points, pred_nocs_points * 255], axis=1) 207 | 
pred_nocs_pc = np.concatenate([gt_input_pc + 2 * offset_vec, pred_input_nocs * 255], axis=1) 208 | gt_rgb_pc = np.concatenate([gt_input_pc + offset_vec, gt_input_rgb], axis=1) 209 | gt_nocs_pc = np.concatenate([gt_input_pc + 3 * offset_vec, gt_nocs_pc * 255], axis=1) 210 | if attention_score is not None: 211 | pred_att_pc = np.concatenate([pred_sim_points - 2 * offset_vec, attention_score * 255], axis=1) 212 | all_pc = np.concatenate([pred_att_pc, gt_sim_pc, pred_sim_pc, gt_rgb_pc, pred_nocs_pc, gt_nocs_pc], axis=0).astype(np.float32) 213 | else: 214 | all_pc = np.concatenate([gt_sim_pc, pred_sim_pc, gt_rgb_pc, pred_nocs_pc, gt_nocs_pc], axis=0).astype(np.float32) 215 | if save_path is not None: 216 | num_mesh_points = pred_sim_pc.shape[0] 217 | num_pc_points = gt_rgb_pc.shape[0] 218 | padding = np.array([[num_mesh_points, num_mesh_points, num_mesh_points, 219 | num_pc_points, num_pc_points, num_pc_points]]).astype(np.float32) 220 | all_pc = np.concatenate([all_pc, padding], axis=0).astype(np.float32) 221 | np.save(save_path, all_pc) 222 | print('Saving to {}!'.format(save_path)) 223 | np.save(save_path.replace('vis', 'vis_faces'), cloth_faces_tri) 224 | print('Saving to {}!'.format(save_path.replace('vis', 'vis_faces'))) 225 | vis_obj = wandb.Object3D(all_pc) 226 | return vis_obj 227 | 228 | 229 | def get_nocs_pc_vis( 230 | sample_key, 231 | samples_group, 232 | offset=[1.0,0,0], **kwargs): 233 | """ 234 | GT nocs pc Pred nocs pc (colored with gt nocs) 235 | """ 236 | sample_group = samples_group[sample_key] 237 | # io 238 | pc_group = sample_group['point_cloud'] 239 | gt_nocs_pc = pc_group['gt_nocs'][:] 240 | pred_nocs_pc = pc_group['pred_nocs'][:] 241 | input_prev_nocs_pc = pc_group['input_prev_nocs'][:] 242 | if 'pred_nocs_confidence' in pc_group: 243 | pred_nocs_confidence = pc_group['pred_nocs_confidence'][:] 244 | else: 245 | pred_nocs_confidence = None 246 | 247 | # vis 248 | offset_vec = np.array(offset) 249 | gt_nocs_vis = np.concatenate([gt_nocs_pc - offset_vec, gt_nocs_pc * 255], axis=1) 250 | pred_nocs_vis = np.concatenate([pred_nocs_pc, gt_nocs_pc * 255], axis=1) 251 | input_prev_nocs_vis = np.concatenate([input_prev_nocs_pc + offset_vec, input_prev_nocs_pc * 255], axis=1) 252 | if pred_nocs_confidence is not None: 253 | pred_confidence_vis = np.concatenate([pred_nocs_pc + 2 * offset_vec, pred_nocs_confidence * 255], axis=1) 254 | all_pc = np.concatenate([gt_nocs_vis, pred_nocs_vis, input_prev_nocs_vis, pred_confidence_vis]) 255 | else: 256 | all_pc = np.concatenate([gt_nocs_vis, pred_nocs_vis, input_prev_nocs_vis]) 257 | vis_obj = wandb.Object3D(all_pc) 258 | return vis_obj 259 | 260 | 261 | def get_nocs_mesh_vis(sample_key, 262 | samples_group, 263 | offset=[1.0,0,0], **kwargs): 264 | """ 265 | GT nocs pc Pred nocs pc (colored with gt nocs) 266 | """ 267 | sample_group = samples_group[sample_key] 268 | # io 269 | mesh_group = sample_group['mesh_points'] 270 | gt_nocs_mesh = mesh_group['gt_nocs_points'][:] 271 | pred_nocs_mesh = mesh_group['pred_nocs_points'][:] 272 | 273 | # vis 274 | offset_vec = np.array(offset) 275 | gt_nocs_vis = np.concatenate([gt_nocs_mesh - offset_vec, gt_nocs_mesh * 255], axis=1) 276 | pred_nocs_vis = np.concatenate([pred_nocs_mesh, gt_nocs_mesh * 255], axis=1) 277 | all_pc = np.concatenate([gt_nocs_vis, pred_nocs_vis]) 278 | vis_obj = wandb.Object3D(all_pc) 279 | return vis_obj 280 | 281 | 282 | # %% 283 | # main script 284 | @hydra.main(config_path="config", 285 | config_name="eval_tracking_default.yaml") 286 | def main(cfg: DictConfig) -> None: 
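# High-level flow of this evaluation entry point:
#   1. load the prediction run's config and open its prediction.zarr store
#   2. compute the enabled per-sample metrics (PC NOCS error, chamfer, euclidian) in parallel and write them under summary/metrics
#   3. aggregate the metrics (means plus euclidian@threshold ratios) and save them to CSV/JSON
#   4. render regular/best/worst samples with the *_vis helpers above and log metrics and visualizations to wandb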
287 | # load datase 288 | pred_output_dir = os.path.expanduser(cfg.main.prediction_output_dir) 289 | pred_config_path = os.path.join(pred_output_dir, 'config.yaml') 290 | pred_config_all = OmegaConf.load(pred_config_path) 291 | 292 | # setup wandb 293 | output_dir = os.getcwd() 294 | print(output_dir) 295 | 296 | wandb_path = os.path.join(output_dir, 'wandb') 297 | os.mkdir(wandb_path) 298 | wandb_run = wandb.init( 299 | project=os.path.basename(__file__), 300 | **cfg.logger) 301 | wandb_meta = { 302 | 'run_name': wandb_run.name, 303 | 'run_id': wandb_run.id 304 | } 305 | meta = { 306 | 'script_path': __file__ 307 | } 308 | all_config = { 309 | 'config': OmegaConf.to_container(cfg, resolve=True), 310 | 'prediction_config': OmegaConf.to_container(pred_config_all, resolve=True), 311 | 'output_dir': output_dir, 312 | 'wandb': wandb_meta, 313 | 'meta': meta 314 | } 315 | yaml.dump(all_config, open('config.yaml', 'w'), default_flow_style=False) 316 | wandb.config.update(all_config) 317 | 318 | # setup zarr 319 | pred_zarr_path = os.path.join(pred_output_dir, 'prediction.zarr') 320 | pred_root = zarr.open(pred_zarr_path, 'r+') 321 | samples_group = pred_root['samples'] 322 | summary_group = pred_root.require_group('summary', overwrite=False) 323 | compressor = Blosc(cname='zstd', clevel=6, shuffle=Blosc.BITSHUFFLE) 324 | 325 | sample_key, sample_group = next(iter(samples_group.groups())) 326 | print(sample_group.tree()) 327 | all_sample_keys = list() 328 | all_sample_groups = list() 329 | for sample_key, sample_group in samples_group.groups(): 330 | all_sample_keys.append(sample_key) 331 | all_sample_groups.append(sample_group) 332 | 333 | global_metrics_group = summary_group.require_group('metrics', overwrite=False) 334 | global_per_sample_group = global_metrics_group.require_group('per_sample', overwrite=False) 335 | global_agg_group = global_metrics_group.require_group('aggregate', overwrite=False) 336 | 337 | # write instance order 338 | sample_keys_arr = np.array(all_sample_keys) 339 | global_per_sample_group.array('sample_keys', sample_keys_arr, 340 | chunks=sample_keys_arr.shape, compressor=compressor, overwrite=True) 341 | 342 | # load aabb 343 | input_zarr_path = os.path.expanduser( 344 | pred_config_all.config.datamodule.zarr_path) 345 | input_root = zarr.open(input_zarr_path, 'r') 346 | input_samples_group = input_root['samples'] 347 | input_summary_group = input_root['summary'] 348 | nocs_aabb = input_summary_group['cloth_canonical_aabb_union'][:] 349 | sim_aabb = input_summary_group['cloth_aabb_union'][:] 350 | 351 | num_workers = cfg.main.num_workers 352 | sample_keys_series = pd.Series(all_sample_keys) 353 | result_df = parallel_map( 354 | lambda x: False, 355 | sample_keys_series, 356 | num_workers=num_workers, 357 | preserve_index=True) 358 | is_sample_null = result_df.result 359 | not_null_sample_keys_series = sample_keys_series.loc[~is_sample_null] 360 | 361 | # compute metrics 362 | metric_func_dict = { 363 | 'compute_pc_metrics': compute_pc_metrics, 364 | 'compute_chamfer': compute_chamfer, 365 | 'compute_euclidian': compute_euclidian, 366 | } 367 | 368 | num_workers = cfg.main.num_workers 369 | all_metrics = dict() 370 | for func_key, func in metric_func_dict.items(): 371 | print("Running {}".format(func_key)) 372 | metric_args = OmegaConf.to_container(cfg.eval[func_key], resolve=True) 373 | if not metric_args['enabled']: 374 | print("Disabled, skipping") 375 | continue 376 | 377 | print("Config:") 378 | pprint(metric_args) 379 | result_df = parallel_map( 380 | lambda x: 
func( 381 | sample_key=x, 382 | samples_group=samples_group, 383 | input_samples_group=input_samples_group, 384 | nocs_aabb=nocs_aabb, 385 | sim_aabb=sim_aabb, 386 | **metric_args), 387 | not_null_sample_keys_series, 388 | num_workers=num_workers, 389 | preserve_index=True) 390 | # print error 391 | errors_series = result_df.loc[result_df.error.notnull()].error 392 | if len(errors_series) > 0: 393 | print("Errors:") 394 | print(errors_series) 395 | 396 | result_dict = dict() 397 | for key in sample_keys_series.index: 398 | data = dict() 399 | if key in result_df.index: 400 | value = result_df.result.loc[key] 401 | if value is not None: 402 | data = value 403 | result_dict[key] = data 404 | this_metric_df = pd.DataFrame( 405 | list(result_dict.values()), 406 | index=sample_keys_series.index) 407 | 408 | for column in this_metric_df: 409 | all_metrics[column] = this_metric_df[column] 410 | value = np.array(this_metric_df[column]) 411 | global_per_sample_group.array( 412 | name=column, data=value, chunks=value.shape, 413 | compressor=compressor, overwrite=True) 414 | value_agg = np.nanmean(value) 415 | global_agg_group[column] = value_agg 416 | 417 | all_metrics_df = pd.DataFrame( 418 | all_metrics, 419 | index=sample_keys_series.index) 420 | all_metrics_df['null_percentage'] = is_sample_null.astype(np.float32) 421 | 422 | all_metrics_agg = all_metrics_df.mean() 423 | for column in all_metrics_df: 424 | if 'euclidian' in column: 425 | all_metrics_agg[column + '@0.03'] = (all_metrics_df[column] <= 0.03).sum() \ 426 | / len(all_metrics_df[column]) 427 | all_metrics_agg[column + '@0.05'] = (all_metrics_df[column] <= 0.05).sum() \ 428 | / len(all_metrics_df[column]) 429 | all_metrics_agg[column + '@0.08'] = (all_metrics_df[column] <= 0.08).sum() \ 430 | / len(all_metrics_df[column]) 431 | all_metrics_agg[column + '@0.1'] = (all_metrics_df[column] <= 0.1).sum() \ 432 | / len(all_metrics_df[column]) 433 | all_metrics_agg[column + '@0.15'] = (all_metrics_df[column] <= 0.15).sum() \ 434 | / len(all_metrics_df[column]) 435 | 436 | print(all_metrics_agg) 437 | # save metric to disk 438 | all_metrics_path = os.path.join(output_dir, 'all_metrics.csv') 439 | agg_path = os.path.join(output_dir, 'all_metrics_agg.csv') 440 | summary_path = os.path.join(output_dir, 'summary.json') 441 | all_metrics_df.to_csv(all_metrics_path) 442 | all_metrics_df.describe().to_csv(agg_path) 443 | json.dump(dict(all_metrics_agg), open(summary_path, 'w'), indent=2) 444 | 445 | if cfg.vis.samples_per_instance <= 0: 446 | print("Done!") 447 | return 448 | 449 | # visualization 450 | # pick best and worst 451 | rank_column = all_metrics_df[cfg.vis.rank_metric] 452 | sorted_rank_column = rank_column.sort_values() 453 | best_idxs = sorted_rank_column.index[:cfg.vis.num_best] 454 | worst_idxs = sorted_rank_column.index[-cfg.vis.num_best:][::-1] 455 | if cfg.vis.random_sample_regular: 456 | num_samples = len(sorted_rank_column) 457 | vis_idxs = np.random.choice(num_samples, size=cfg.vis.num_normal) 458 | else: 459 | start_idx, end_idx = cfg.vis.vis_sample_idxs_range 460 | vis_idxs = np.arange(start_idx, end_idx+1) 461 | 462 | print('vis_idxs: {}'.format(vis_idxs.tolist())) 463 | vis_idx_dict = dict() 464 | for i, idx in enumerate(vis_idxs): 465 | vis_idx_dict[idx] = "regular_{0:02d}".format(i) 466 | for i, idx in enumerate(best_idxs): 467 | vis_idx_dict[idx] = "best_{0:02d}".format(i) 468 | for i, idx in enumerate(worst_idxs): 469 | vis_idx_dict[idx] = "worst_{0:02d}".format(i) 470 | 471 | vis_func_dict = { 472 | 'task_mesh_vis': 
get_task_mesh_vis, 473 | 'nocs_pc_vis': get_nocs_pc_vis, 474 | 'nocs_mesh_vis': get_nocs_mesh_vis, 475 | } 476 | no_override_keys = list() 477 | # all_log_data = list() 478 | print("Logging visualization to wandb") 479 | for i in tqdm(range(len(all_metrics_df))): 480 | log_data = dict(all_metrics_df.loc[i]) 481 | if i in vis_idx_dict: 482 | vis_key = vis_idx_dict[i] 483 | if cfg.vis.save_point_cloud: 484 | save_dir = os.path.join(output_dir, 'vis') 485 | if not os.path.exists(save_dir): 486 | os.mkdir(save_dir) 487 | os.makedirs(save_dir.replace('vis', 'vis_faces'), exist_ok=True) 488 | save_path = os.path.join(save_dir, '{:0>4d}.npy'.format(i)) 489 | else: 490 | save_path = None 491 | for func_key, func in vis_func_dict.items(): 492 | metric_args = OmegaConf.to_container(cfg.vis[func_key], resolve=True) 493 | sample_key = sample_keys_series.loc[i] 494 | vis_obj = func(sample_key, samples_group, 495 | nocs_aabb=nocs_aabb, 496 | sim_aabb=sim_aabb, 497 | save_path=save_path, 498 | **metric_args) 499 | vis_name = '_'.join([func_key, vis_key]) 500 | log_data[vis_name] = vis_obj 501 | # all_log_data.append(log_data) 502 | wandb_run.log(log_data, step=i) 503 | 504 | print("Logging summary to wandb") 505 | for key, value in tqdm(all_metrics_agg.items()): 506 | wandb_run.summary[key] = value 507 | print("Done!") 508 | 509 | # %% 510 | # driver 511 | if __name__ == "__main__": 512 | main() 513 | -------------------------------------------------------------------------------- /components/unet3d.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code from the 3D UNet implementation: 3 | https://github.com/wolny/pytorch-3dunet/ 4 | ''' 5 | import importlib 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from functools import partial 10 | 11 | def number_of_features_per_level(init_channel_number, num_levels): 12 | return [init_channel_number * 2 ** k for k in range(num_levels)] 13 | 14 | 15 | def conv3d(in_channels, out_channels, kernel_size, bias, padding=1): 16 | return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 17 | 18 | 19 | def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1): 20 | """ 21 | Create a list of modules with together constitute a single conv layer with non-linearity 22 | and optional batchnorm/groupnorm. 23 | 24 | Args: 25 | in_channels (int): number of input channels 26 | out_channels (int): number of output channels 27 | order (string): order of things, e.g. 
28 | 'cr' -> conv + ReLU 29 | 'gcr' -> groupnorm + conv + ReLU 30 | 'cl' -> conv + LeakyReLU 31 | 'ce' -> conv + ELU 32 | 'bcr' -> batchnorm + conv + ReLU 33 | num_groups (int): number of groups for the GroupNorm 34 | padding (int): add zero-padding to the input 35 | 36 | Return: 37 | list of tuple (name, module) 38 | """ 39 | assert 'c' in order, "Conv layer MUST be present" 40 | assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' 41 | 42 | modules = [] 43 | for i, char in enumerate(order): 44 | if char == 'r': 45 | modules.append(('ReLU', nn.ReLU(inplace=True))) 46 | elif char == 'l': 47 | modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True))) 48 | elif char == 'e': 49 | modules.append(('ELU', nn.ELU(inplace=True))) 50 | elif char == 'c': 51 | # add learnable bias only in the absence of batchnorm/groupnorm 52 | bias = not ('g' in order or 'b' in order) 53 | modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding))) 54 | elif char == 'g': 55 | is_before_conv = i < order.index('c') 56 | if is_before_conv: 57 | num_channels = in_channels 58 | else: 59 | num_channels = out_channels 60 | 61 | # use only one group if the given number of groups is greater than the number of channels 62 | if num_channels < num_groups: 63 | num_groups = 1 64 | 65 | assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' 66 | modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) 67 | elif char == 'b': 68 | is_before_conv = i < order.index('c') 69 | if is_before_conv: 70 | modules.append(('batchnorm', nn.BatchNorm3d(in_channels))) 71 | else: 72 | modules.append(('batchnorm', nn.BatchNorm3d(out_channels))) 73 | else: 74 | raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']") 75 | 76 | return modules 77 | 78 | 79 | class SingleConv(nn.Sequential): 80 | """ 81 | Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order 82 | of operations can be specified via the `order` parameter 83 | 84 | Args: 85 | in_channels (int): number of input channels 86 | out_channels (int): number of output channels 87 | kernel_size (int): size of the convolving kernel 88 | order (string): determines the order of layers, e.g. 89 | 'cr' -> conv + ReLU 90 | 'crg' -> conv + ReLU + groupnorm 91 | 'cl' -> conv + LeakyReLU 92 | 'ce' -> conv + ELU 93 | num_groups (int): number of groups for the GroupNorm 94 | """ 95 | 96 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1): 97 | super(SingleConv, self).__init__() 98 | 99 | for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding): 100 | self.add_module(name, module) 101 | 102 | 103 | class DoubleConv(nn.Sequential): 104 | """ 105 | A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). 106 | We use (Conv3d+ReLU+GroupNorm3d) by default. 107 | This can be changed however by providing the 'order' argument, e.g. in order 108 | to change to Conv3d+BatchNorm3d+ELU use order='cbe'. 109 | Use padded convolutions to make sure that the output (H_out, W_out) is the same 110 | as (H_in, W_in), so that you don't have to crop in the decoder path. 
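For example (an illustrative call, not a value used in this repo): DoubleConv(32, 64, encoder=True, order='cbr') builds SingleConv(32 -> 32) followed by SingleConv(32 -> 64), each applying Conv3d + BatchNorm3d + ReLU with padding=1, so the spatial size is unchanged.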
111 | 112 | Args: 113 | in_channels (int): number of input channels 114 | out_channels (int): number of output channels 115 | encoder (bool): if True we're in the encoder path, otherwise we're in the decoder 116 | kernel_size (int): size of the convolving kernel 117 | order (string): determines the order of layers, e.g. 118 | 'cr' -> conv + ReLU 119 | 'crg' -> conv + ReLU + groupnorm 120 | 'cl' -> conv + LeakyReLU 121 | 'ce' -> conv + ELU 122 | num_groups (int): number of groups for the GroupNorm 123 | """ 124 | 125 | def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8): 126 | super(DoubleConv, self).__init__() 127 | if encoder: 128 | # we're in the encoder path 129 | conv1_in_channels = in_channels 130 | conv1_out_channels = out_channels // 2 131 | if conv1_out_channels < in_channels: 132 | conv1_out_channels = in_channels 133 | conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels 134 | else: 135 | # we're in the decoder path, decrease the number of channels in the 1st convolution 136 | conv1_in_channels, conv1_out_channels = in_channels, out_channels 137 | conv2_in_channels, conv2_out_channels = out_channels, out_channels 138 | 139 | # conv1 140 | self.add_module('SingleConv1', 141 | SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups)) 142 | # conv2 143 | self.add_module('SingleConv2', 144 | SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups)) 145 | 146 | 147 | class ExtResNetBlock(nn.Module): 148 | """ 149 | Basic UNet block consisting of a SingleConv followed by the residual block. 150 | The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number 151 | of output channels is compatible with the residual block that follows. 152 | This block can be used instead of standard DoubleConv in the Encoder module. 153 | Motivated by: https://arxiv.org/pdf/1706.00120.pdf 154 | 155 | Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. 
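Example (illustrative sketch; the channel sizes below are assumptions, not values used elsewhere in this repo):
    block = ExtResNetBlock(in_channels=32, out_channels=64, order='cge', num_groups=8)
    out = block(torch.randn(1, 32, 16, 16, 16))  # -> (1, 64, 16, 16, 16), spatial size preserved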
156 | """ 157 | 158 | def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs): 159 | super(ExtResNetBlock, self).__init__() 160 | 161 | # first convolution 162 | self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 163 | # residual block 164 | self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups) 165 | # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual 166 | n_order = order 167 | for c in 'rel': 168 | n_order = n_order.replace(c, '') 169 | self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, 170 | num_groups=num_groups) 171 | 172 | # create non-linearity separately 173 | if 'l' in order: 174 | self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) 175 | elif 'e' in order: 176 | self.non_linearity = nn.ELU(inplace=True) 177 | else: 178 | self.non_linearity = nn.ReLU(inplace=True) 179 | 180 | def forward(self, x): 181 | # apply first convolution and save the output as a residual 182 | out = self.conv1(x) 183 | residual = out 184 | 185 | # residual block 186 | out = self.conv2(out) 187 | out = self.conv3(out) 188 | 189 | out += residual 190 | out = self.non_linearity(out) 191 | 192 | return out 193 | 194 | 195 | class Encoder(nn.Module): 196 | """ 197 | A single module from the encoder path consisting of the optional max 198 | pooling layer (one may specify the MaxPool kernel_size to be different 199 | than the standard (2,2,2), e.g. if the volumetric data is anisotropic 200 | (make sure to use complementary scale_factor in the decoder path) followed by 201 | a DoubleConv module. 202 | Args: 203 | in_channels (int): number of input channels 204 | out_channels (int): number of output channels 205 | conv_kernel_size (int): size of the convolving kernel 206 | apply_pooling (bool): if True use MaxPool3d before DoubleConv 207 | pool_kernel_size (tuple): the size of the window to take a max over 208 | pool_type (str): pooling layer: 'max' or 'avg' 209 | basic_module(nn.Module): either ResNetBlock or DoubleConv 210 | conv_layer_order (string): determines the order of layers 211 | in `DoubleConv` module. See `DoubleConv` for more info. 212 | num_groups (int): number of groups for the GroupNorm 213 | """ 214 | 215 | def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, 216 | pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg', 217 | num_groups=8): 218 | super(Encoder, self).__init__() 219 | assert pool_type in ['max', 'avg'] 220 | if apply_pooling: 221 | if pool_type == 'max': 222 | self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) 223 | else: 224 | self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) 225 | else: 226 | self.pooling = None 227 | 228 | self.basic_module = basic_module(in_channels, out_channels, 229 | encoder=True, 230 | kernel_size=conv_kernel_size, 231 | order=conv_layer_order, 232 | num_groups=num_groups) 233 | 234 | def forward(self, x): 235 | if self.pooling is not None: 236 | x = self.pooling(x) 237 | x = self.basic_module(x) 238 | return x 239 | 240 | 241 | class Decoder(nn.Module): 242 | """ 243 | A single module for decoder path consisting of the upsampling layer 244 | (either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock). 
245 | Args: 246 | in_channels (int): number of input channels 247 | out_channels (int): number of output channels 248 | kernel_size (int): size of the convolving kernel 249 | scale_factor (tuple): used as the multiplier for the image H/W/D in 250 | case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation 251 | from the corresponding encoder 252 | basic_module(nn.Module): either ResNetBlock or DoubleConv 253 | conv_layer_order (string): determines the order of layers 254 | in `DoubleConv` module. See `DoubleConv` for more info. 255 | num_groups (int): number of groups for the GroupNorm 256 | """ 257 | 258 | def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=(2, 2, 2), basic_module=DoubleConv, 259 | conv_layer_order='crg', num_groups=8, mode='nearest'): 260 | super(Decoder, self).__init__() 261 | if basic_module == DoubleConv: 262 | # if DoubleConv is the basic_module use interpolation for upsampling and concatenation joining 263 | self.upsampling = Upsampling(transposed_conv=False, in_channels=in_channels, out_channels=out_channels, 264 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) 265 | # concat joining 266 | self.joining = partial(self._joining, concat=True) 267 | else: 268 | # if basic_module=ExtResNetBlock use transposed convolution upsampling and summation joining 269 | self.upsampling = Upsampling(transposed_conv=True, in_channels=in_channels, out_channels=out_channels, 270 | kernel_size=kernel_size, scale_factor=scale_factor, mode=mode) 271 | # sum joining 272 | self.joining = partial(self._joining, concat=False) 273 | # adapt the number of in_channels for the ExtResNetBlock 274 | in_channels = out_channels 275 | 276 | self.basic_module = basic_module(in_channels, out_channels, 277 | encoder=False, 278 | kernel_size=kernel_size, 279 | order=conv_layer_order, 280 | num_groups=num_groups) 281 | 282 | def forward(self, encoder_features, x): 283 | x = self.upsampling(encoder_features=encoder_features, x=x) 284 | x = self.joining(encoder_features, x) 285 | x = self.basic_module(x) 286 | return x 287 | 288 | @staticmethod 289 | def _joining(encoder_features, x, concat): 290 | if concat: 291 | return torch.cat((encoder_features, x), dim=1) 292 | else: 293 | return encoder_features + x 294 | 295 | 296 | class Upsampling(nn.Module): 297 | """ 298 | Upsamples a given multi-channel 3D data using either interpolation or learned transposed convolution. 299 | 300 | Args: 301 | transposed_conv (bool): if True uses ConvTranspose3d for upsampling, otherwise uses interpolation 302 | concat_joining (bool): if True uses concatenation joining between encoder and decoder features, otherwise 303 | uses summation joining (see Residual U-Net) 304 | in_channels (int): number of input channels for transposed conv 305 | out_channels (int): number of output channels for transpose conv 306 | kernel_size (int or tuple): size of the convolving kernel 307 | scale_factor (int or tuple): stride of the convolution 308 | mode (str): algorithm used for upsampling: 309 | 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. 
Default: 'nearest' 310 | """ 311 | 312 | def __init__(self, transposed_conv, in_channels=None, out_channels=None, kernel_size=3, 313 | scale_factor=(2, 2, 2), mode='nearest'): 314 | super(Upsampling, self).__init__() 315 | 316 | if transposed_conv: 317 | # make sure that the output size reverses the MaxPool3d from the corresponding encoder 318 | # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0]) 319 | self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=scale_factor, 320 | padding=1) 321 | else: 322 | self.upsample = partial(self._interpolate, mode=mode) 323 | 324 | def forward(self, encoder_features, x): 325 | output_size = encoder_features.size()[2:] 326 | return self.upsample(x, output_size) 327 | 328 | @staticmethod 329 | def _interpolate(x, size, mode): 330 | return F.interpolate(x, size=size, mode=mode) 331 | 332 | 333 | class FinalConv(nn.Sequential): 334 | """ 335 | A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution 336 | which reduces the number of channels to 'out_channels'. 337 | with the number of output channels 'out_channels // 2' and 'out_channels' respectively. 338 | We use (Conv3d+ReLU+GroupNorm3d) by default. 339 | This can be change however by providing the 'order' argument, e.g. in order 340 | to change to Conv3d+BatchNorm3d+ReLU use order='cbr'. 341 | Args: 342 | in_channels (int): number of input channels 343 | out_channels (int): number of output channels 344 | kernel_size (int): size of the convolving kernel 345 | order (string): determines the order of layers, e.g. 346 | 'cr' -> conv + ReLU 347 | 'crg' -> conv + ReLU + groupnorm 348 | num_groups (int): number of groups for the GroupNorm 349 | """ 350 | 351 | def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8): 352 | super(FinalConv, self).__init__() 353 | 354 | # conv1 355 | self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups)) 356 | 357 | # in the last layer a 1×1 convolution reduces the number of output channels to out_channels 358 | final_conv = nn.Conv3d(in_channels, out_channels, 1) 359 | self.add_module('final_conv', final_conv) 360 | 361 | class Abstract3DUNet(nn.Module): 362 | """ 363 | Base class for standard and residual UNet. 364 | 365 | Args: 366 | in_channels (int): number of input channels 367 | out_channels (int): number of output segmentation masks; 368 | Note that that the of out_channels might correspond to either 369 | different semantic classes or to different binary segmentation mask. 370 | It's up to the user of the class to interpret the out_channels and 371 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 372 | or BCEWithLogitsLoss (two-class) respectively) 373 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 374 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 375 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 376 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 377 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 378 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) 379 | layer_order (string): determines the order of layers 380 | in `SingleConv` module. e.g. 
'crg' stands for Conv3d+ReLU+GroupNorm3d. 381 | See `SingleConv` for more info 382 | f_maps (int, tuple): if int: number of feature maps in the first conv layer of the encoder (default: 64); 383 | if tuple: number of feature maps at each level 384 | num_groups (int): number of groups for the GroupNorm 385 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) 386 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied 387 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end 388 | testing (bool): if True (testing mode) the `final_activation` (if present, i.e. `is_segmentation=true`) 389 | will be applied as the last operation during the forward pass; if False the model is in training mode 390 | and the `final_activation` (even if present) won't be applied; default: False 391 | """ 392 | 393 | def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', 394 | num_groups=8, num_levels=4, is_segmentation=False, testing=False, **kwargs): 395 | super(Abstract3DUNet, self).__init__() 396 | 397 | self.testing = testing 398 | 399 | if isinstance(f_maps, int): 400 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) 401 | 402 | # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` 403 | encoders = [] 404 | for i, out_feature_num in enumerate(f_maps): 405 | if i == 0: 406 | encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=basic_module, 407 | conv_layer_order=layer_order, num_groups=num_groups) 408 | else: 409 | # TODO: adapt for anisotropy in the data, i.e. use proper pooling kernel to make the data isotropic after 1-2 pooling operations 410 | # currently pools with a constant kernel: (2, 2, 2) 411 | encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=basic_module, 412 | conv_layer_order=layer_order, num_groups=num_groups) 413 | encoders.append(encoder) 414 | 415 | self.encoders = nn.ModuleList(encoders) 416 | 417 | # create decoder path consisting of the Decoder modules. 
The length of the decoder is equal to `len(f_maps) - 1` 418 | decoders = [] 419 | reversed_f_maps = list(reversed(f_maps)) 420 | for i in range(len(reversed_f_maps) - 1): 421 | if basic_module == DoubleConv: 422 | in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] 423 | else: 424 | in_feature_num = reversed_f_maps[i] 425 | 426 | out_feature_num = reversed_f_maps[i + 1] 427 | # TODO: if non-standard pooling was used, make sure to use correct striding for transpose conv 428 | # currently strides with a constant stride: (2, 2, 2) 429 | decoder = Decoder(in_feature_num, out_feature_num, basic_module=basic_module, 430 | conv_layer_order=layer_order, num_groups=num_groups) 431 | decoders.append(decoder) 432 | 433 | self.decoders = nn.ModuleList(decoders) 434 | 435 | # in the last layer a 1×1 convolution reduces the number of output 436 | # channels to the number of labels 437 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 438 | 439 | if is_segmentation: 440 | # semantic segmentation problem 441 | if final_sigmoid: 442 | self.final_activation = nn.Sigmoid() 443 | else: 444 | self.final_activation = nn.Softmax(dim=1) 445 | else: 446 | # regression problem 447 | self.final_activation = None 448 | 449 | def forward(self, x): 450 | # encoder part 451 | encoders_features = [] 452 | for encoder in self.encoders: 453 | x = encoder(x) 454 | # reverse the encoder outputs to be aligned with the decoder 455 | encoders_features.insert(0, x) 456 | 457 | # remove the last encoder's output from the list 458 | # !!remember: it's the 1st in the list 459 | encoders_features = encoders_features[1:] 460 | 461 | # decoder part 462 | for decoder, encoder_features in zip(self.decoders, encoders_features): 463 | # pass the output from the corresponding encoder and the output 464 | # of the previous decoder 465 | x = decoder(encoder_features, x) 466 | 467 | x = self.final_conv(x) 468 | 469 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs 470 | # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric 471 | if self.testing and self.final_activation is not None: 472 | x = self.final_activation(x) 473 | 474 | return x 475 | 476 | 477 | class UNet3D(Abstract3DUNet): 478 | """ 479 | 3DUnet model from 480 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 481 | `. 482 | 483 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder 484 | """ 485 | 486 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 487 | num_groups=8, num_levels=4, is_segmentation=True, **kwargs): 488 | super(UNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, final_sigmoid=final_sigmoid, 489 | basic_module=DoubleConv, f_maps=f_maps, layer_order=layer_order, 490 | num_groups=num_groups, num_levels=num_levels, is_segmentation=is_segmentation, 491 | **kwargs) 492 | 493 | 494 | class ResidualUNet3D(Abstract3DUNet): 495 | """ 496 | Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. 497 | Uses ExtResNetBlock as a basic building block, summation joining instead 498 | of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). 499 | Since the model effectively becomes a residual net, in theory it allows for deeper UNet. 
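Example (minimal sketch; the argument values are illustrative assumptions):
    net = ResidualUNet3D(in_channels=1, out_channels=2, f_maps=16, num_levels=4)
    logits = net(torch.randn(1, 1, 64, 64, 64))  # -> (1, 2, 64, 64, 64); raw logits, since final_activation is only applied when self.testing is True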
500 | """ 501 | 502 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 503 | num_groups=8, num_levels=5, is_segmentation=True, **kwargs): 504 | super(ResidualUNet3D, self).__init__(in_channels=in_channels, out_channels=out_channels, 505 | final_sigmoid=final_sigmoid, 506 | basic_module=ExtResNetBlock, f_maps=f_maps, layer_order=layer_order, 507 | num_groups=num_groups, num_levels=num_levels, 508 | is_segmentation=is_segmentation, 509 | **kwargs) 510 | 511 | -------------------------------------------------------------------------------- /networks/tracking_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import pytorch_lightning as pl 6 | import wandb 7 | import torch_scatter 8 | 9 | from components.unet3d import Abstract3DUNet, DoubleConv 10 | from components.mlp import MLP, MLP_V2 11 | from networks.resunet import SparseResUNet2 12 | from networks.transformer import TransformerSiamese 13 | from networks.pointnet import MiniPointNetfeat 14 | from common.torch_util import to_numpy 15 | from common.visualization_util import ( 16 | get_vis_idxs, render_nocs_pair) 17 | from components.gridding import VirtualGrid 18 | 19 | import MinkowskiEngine as ME 20 | 21 | 22 | class VolumeTrackingFeatureAggregator(pl.LightningModule): 23 | def __init__(self, 24 | nn_channels=(38, 64, 64), 25 | batch_norm=True, 26 | lower_corner=(0,0,0), 27 | upper_corner=(1,1,1), 28 | grid_shape=(32, 32, 32), 29 | reduce_method='max', 30 | include_point_feature=True, 31 | use_gt_nocs_for_train=True, 32 | use_mlp_v2=True, 33 | ): 34 | super().__init__() 35 | self.save_hyperparameters() 36 | if use_mlp_v2: 37 | self.local_nn2 = MLP_V2(nn_channels, batch_norm=batch_norm, transpose_input=True) 38 | else: 39 | self.local_nn2 = MLP(nn_channels, batch_norm=batch_norm) 40 | self.lower_corner = tuple(lower_corner) 41 | self.upper_corner = tuple(upper_corner) 42 | self.grid_shape = tuple(grid_shape) 43 | self.reduce_method = reduce_method 44 | self.include_point_feature = include_point_feature 45 | self.use_gt_nocs_for_train = use_gt_nocs_for_train 46 | 47 | def forward(self, nocs_data, batch_size, is_train=False): 48 | lower_corner = self.lower_corner 49 | upper_corner = self.upper_corner 50 | grid_shape = self.grid_shape 51 | include_point_feature = self.include_point_feature 52 | reduce_method = self.reduce_method 53 | 54 | sim_points_frame2 = nocs_data['sim_points_frame2'] 55 | if is_train and self.use_gt_nocs_for_train: 56 | points_frame2 = nocs_data['pos_gt_frame2'] 57 | else: 58 | if 'refined_pos_frame2' in nocs_data: 59 | points_frame2 = nocs_data['refined_pos_frame2'] 60 | else: 61 | points_frame2 = nocs_data['pos_frame2'] 62 | nocs_features_frame2 = nocs_data['x_frame2'] 63 | batch_idx_frame2 = nocs_data['batch_frame2'] 64 | device = points_frame2.device 65 | float_dtype = points_frame2.dtype 66 | int_dtype = torch.int64 67 | 68 | vg = VirtualGrid( 69 | lower_corner=lower_corner, 70 | upper_corner=upper_corner, 71 | grid_shape=grid_shape, 72 | batch_size=batch_size, 73 | device=device, 74 | int_dtype=int_dtype, 75 | float_dtype=float_dtype) 76 | 77 | # get aggregation target index 78 | points_grid_idxs_frame2 = vg.get_points_grid_idxs(points_frame2, batch_idx=batch_idx_frame2) 79 | flat_idxs_frame2 = vg.flatten_idxs(points_grid_idxs_frame2, keepdim=False) 80 | 81 | # get features 82 | features_list_frame2 = 
[nocs_features_frame2] 83 | if include_point_feature: 84 | points_grid_points_frame2 = vg.idxs_to_points(points_grid_idxs_frame2) 85 | local_offset_frame2 = points_frame2 - points_grid_points_frame2 86 | features_list_frame2.append(local_offset_frame2) 87 | features_list_frame2.append(sim_points_frame2) 88 | 89 | features_frame2 = torch.cat(features_list_frame2, axis=-1) 90 | 91 | # per-point transform 92 | if self.local_nn2 is not None: 93 | features_frame2 = self.local_nn2(features_frame2) 94 | 95 | # scatter 96 | volume_feature_flat_frame2 = torch_scatter.scatter( 97 | src=features_frame2.T, index=flat_idxs_frame2, dim=-1, 98 | dim_size=vg.num_grids, reduce=reduce_method) 99 | 100 | # reshape to volume 101 | feature_size = volume_feature_flat_frame2.shape[0] 102 | volume_feature_all = volume_feature_flat_frame2.reshape( 103 | (feature_size, batch_size) + grid_shape).permute((1,0,2,3,4)) 104 | return volume_feature_all 105 | 106 | 107 | class UNet3DTracking(pl.LightningModule): 108 | def __init__(self, in_channels, out_channels, f_maps=64, 109 | layer_order='gcr', num_groups=8, num_levels=4): 110 | super().__init__() 111 | self.save_hyperparameters() 112 | self.abstract_3d_unet = Abstract3DUNet( 113 | in_channels=in_channels, out_channels=out_channels, 114 | final_sigmoid=False, basic_module=DoubleConv, f_maps=f_maps, 115 | layer_order=layer_order, num_groups=num_groups, 116 | num_levels=num_levels, is_segmentation=False) 117 | 118 | def forward(self, data): 119 | result = self.abstract_3d_unet(data) 120 | return result 121 | 122 | 123 | class ImplicitWNFDecoder(pl.LightningModule): 124 | def __init__(self, 125 | nn_channels=(128,256,256,3), 126 | batch_norm=True, 127 | last_layer_mlp=False, 128 | use_mlp_v2=False 129 | ): 130 | super().__init__() 131 | self.save_hyperparameters() 132 | if use_mlp_v2: 133 | self.mlp = MLP_V2(nn_channels, batch_norm=batch_norm, transpose_input=True) 134 | else: 135 | self.mlp = MLP(nn_channels, batch_norm=batch_norm, last_layer=last_layer_mlp) 136 | 137 | def forward(self, features_grid, query_points): 138 | """ 139 | features_grid: (N,C,D,H,W) 140 | query_points: (N,M,3) 141 | """ 142 | batch_size = features_grid.shape[0] 143 | if len(query_points.shape) == 2: 144 | query_points = query_points.view(batch_size, -1, 3) 145 | # normalize query points to (-1, 1), which is 146 | # requried by grid_sample 147 | query_points_normalized = 2.0 * query_points - 1.0 148 | # shape (N,C,M,1,1) 149 | sampled_features = F.grid_sample( 150 | input=features_grid, 151 | grid=query_points_normalized.view( 152 | *(query_points_normalized.shape[:2] + (1,1,3))), 153 | mode='bilinear', padding_mode='border', 154 | align_corners=True) 155 | # shape (N,M,C) 156 | sampled_features = sampled_features.view( 157 | sampled_features.shape[:3]).permute(0,2,1) 158 | 159 | # shape (N,M,C) 160 | out_features = self.mlp(sampled_features) 161 | 162 | return out_features 163 | 164 | 165 | class PointMeshNocsRefiner(pl.LightningModule): 166 | def __init__(self, 167 | pc_pointnet_channels=(326, 256, 256, 1024), 168 | mesh_pointnet_channels=(3, 64, 128, 1024), 169 | pc_refine_mlp_channels=(2304, 1024, 512, 192), 170 | mesh_refine_pointnet_channels=(2112, 512, 512, 1024), 171 | mesh_refine_mlp_channels=(1024, 512, 256, 6), 172 | detach_input_pc_feature=True, 173 | detach_global_pc_feature=True, 174 | detach_global_mesh_feature=True, 175 | **kwargs): 176 | super(PointMeshNocsRefiner, self).__init__() 177 | self.pc_pointnet = MiniPointNetfeat(nn_channels=pc_pointnet_channels) 178 | 
self.mesh_pointnet = MiniPointNetfeat(nn_channels=mesh_pointnet_channels) 179 | self.mesh_refine_pointnet = MiniPointNetfeat(nn_channels=mesh_refine_pointnet_channels) 180 | self.pc_refine_mlp = MLP_V2(channels=pc_refine_mlp_channels, batch_norm=True) 181 | self.mesh_refine_mlp = MLP_V2(channels=mesh_refine_mlp_channels, batch_norm=True) 182 | self.detach_input_pc_feature = detach_input_pc_feature 183 | self.detach_global_pc_feature = detach_global_pc_feature 184 | self.detach_global_mesh_feature = detach_global_mesh_feature 185 | 186 | def forward(self, pc_nocs, pc_sim, pc_cls_logits, pc_feat, mesh_nocs, batch_size=16, **kwargs): 187 | """ 188 | NOCS refiner to refine predicted PC NOCS 189 | :param pc_nocs: (B*N, 3) 190 | :param pc_sim: (B*N, 3) 191 | :param pc_cls_logits: (B*N, 64*3) 192 | :param pc_feat: (B*N, C) 193 | :param mesh_nocs: (B*M, 3) 194 | :param batch_size: int 195 | :return: refined PC NOCS logits (B*N, 64*3), refined mesh NOCS (B*M, 3) 196 | """ 197 | pc_nocs = pc_nocs.view(batch_size, -1, 3) # (B, N, 3) 198 | pc_sim = pc_sim.view(batch_size, -1, 3) # (B, N, 3) 199 | mesh_nocs = mesh_nocs.view(batch_size, -1, 3) # (B, M, 3) 200 | num_pc_points = pc_nocs.shape[1] # N 201 | num_mesh_points = mesh_nocs.shape[1] # M 202 | # detach logits 203 | pc_cls_logits = pc_cls_logits.view(batch_size, num_pc_points, -1).detach() # (B, N, 64*3) 204 | pc_feat = pc_feat.view(batch_size, num_pc_points, -1) # (B, N, C) 205 | if self.detach_input_pc_feature: 206 | pc_feat = pc_feat.detach() 207 | pc_input_all = torch.cat([pc_nocs, pc_sim, pc_cls_logits, pc_feat], dim=-1) # (B, N, 3+3+64*3+C) 208 | pc_input_all = pc_input_all.transpose(1, 2) # (B, 3+3+64*3+C, N) 209 | 210 | # pc pointnet 211 | pc_feat_dense, pc_feat_global = self.pc_pointnet(pc_input_all) # (B, C', N), (B, C') 212 | 213 | # mesh pointnet 214 | mesh_input_all = mesh_nocs.transpose(1, 2) # (B, 3, M) 215 | mesh_feat_dense, _ = self.mesh_pointnet(mesh_input_all) # (B, C', M) 216 | pc_feat_global_expand = pc_feat_global.unsqueeze(2).expand(-1, -1, num_mesh_points) # (B, C', M) 217 | if self.detach_global_pc_feature: 218 | pc_feat_global_expand = pc_feat_global_expand.detach() 219 | mesh_feat_dense_cat = torch.cat([mesh_feat_dense, pc_feat_global_expand], dim=1) # (B, C'+C', M) 220 | _, mesh_refine_feat_global = self.mesh_refine_pointnet(mesh_feat_dense_cat) # (B, C') 221 | 222 | # refine pc-nocs 223 | mesh_refine_feat_global_expand = mesh_refine_feat_global.unsqueeze(2).expand(-1, -1, num_pc_points) # (B, C', N) 224 | if self.detach_global_mesh_feature: 225 | mesh_refine_feat_global_expand = mesh_refine_feat_global_expand.detach() 226 | pc_feat_cat = torch.cat([pc_feat_dense, mesh_refine_feat_global_expand], dim=1) # (B, C'+C', N) 227 | delta_pc_cls_logits = self.pc_refine_mlp(pc_feat_cat) # (B, 64*3, N) 228 | delta_pc_cls_logits = delta_pc_cls_logits.transpose(1, 2) # (B, N, 64*3) 229 | assert delta_pc_cls_logits.shape[-1] == pc_cls_logits.shape[-1] 230 | refine_pc_cls_logits = pc_cls_logits + delta_pc_cls_logits # (B, N, 64*3) 231 | refine_pc_cls_logits = refine_pc_cls_logits.reshape(batch_size * num_pc_points, -1) # (B*N, 64*3) 232 | 233 | # refine mesh-nocs 234 | mesh_refine_feat_global = mesh_refine_feat_global.unsqueeze(2) # (B, C', 1) 235 | mesh_nocs_delta_logits = self.mesh_refine_mlp(mesh_refine_feat_global) # (B, 6, 1) 236 | refine_offset, refine_scale = mesh_nocs_delta_logits[:, :3, 0].unsqueeze(1), \ 237 | mesh_nocs_delta_logits[:, 3:, 0].unsqueeze(1) # (B, 1, 3) 238 | nocs_center = torch.tensor([[[0.5, 0.5, 
0.5]]]).to(mesh_nocs.get_device()) # (1, 1, 3) 239 | refined_mesh_nocs = (mesh_nocs - nocs_center) * refine_scale + nocs_center + refine_offset # (B, M, 3) 240 | refined_mesh_nocs = refined_mesh_nocs.reshape(batch_size * num_mesh_points, -1) # (B*M, 3) 241 | return refine_pc_cls_logits, refined_mesh_nocs 242 | 243 | 244 | class GarmentTrackingPipeline(pl.LightningModule): 245 | """ 246 | Use sparse ResUNet as backbone 247 | Use point-cloud pair as input(not mesh) 248 | Add transformer for self-attention and cross-attention 249 | """ 250 | def __init__(self, 251 | # sparse uned3d encoder params 252 | sparse_unet3d_encoder_params, 253 | # self-attention and cross-attention transformer params 254 | transformer_params, 255 | # pc nocs and mesh nocs refiner params 256 | nocs_refiner_params, 257 | # VolumeFeaturesAggregator params 258 | volume_agg_params, 259 | # unet3d params 260 | unet3d_params, 261 | # ImplicitWNFDecoder params 262 | surface_decoder_params, 263 | # training params 264 | learning_rate=1e-4, 265 | optimizer_type='Adam', 266 | loss_type='l2', 267 | volume_loss_weight=1.0, 268 | warp_loss_weight=10.0, 269 | nocs_loss_weight=1.0, 270 | mesh_loss_weight=10.0, 271 | use_nocs_refiner=True, 272 | disable_pc_nocs_refine_in_test=False, 273 | disable_mesh_nocs_refine_in_test=False, 274 | # vis params 275 | vis_per_items=0, 276 | max_vis_per_epoch_train=0, 277 | max_vis_per_epoch_val=0, 278 | batch_size=None, 279 | # debug params 280 | debug=False 281 | ): 282 | super().__init__() 283 | self.save_hyperparameters() 284 | 285 | criterion = None 286 | if loss_type == 'l2': 287 | criterion = nn.MSELoss(reduction='mean') 288 | elif loss_type == 'smooth_l1': 289 | criterion = nn.SmoothL1Loss(reduction='mean') 290 | else: 291 | raise RuntimeError("Invalid loss_type: {}".format(loss_type)) 292 | 293 | self.sparse_unet3d_encoder = SparseResUNet2(**sparse_unet3d_encoder_params) 294 | self.transformer_siamese = TransformerSiamese(**transformer_params) 295 | self.nocs_refiner = PointMeshNocsRefiner(**nocs_refiner_params) 296 | self.volume_agg = VolumeTrackingFeatureAggregator(**volume_agg_params) 297 | self.unet_3d = UNet3DTracking(**unet3d_params) 298 | self.surface_decoder = ImplicitWNFDecoder(**surface_decoder_params) 299 | self.use_nocs_refiner = use_nocs_refiner 300 | self.disable_pc_nocs_refine_in_test = disable_pc_nocs_refine_in_test 301 | self.disable_mesh_nocs_refine_in_test = disable_mesh_nocs_refine_in_test 302 | self.mesh_loss_weight = mesh_loss_weight 303 | 304 | self.surface_criterion = criterion 305 | if self.transformer_siamese.nocs_bins is not None: 306 | self.nocs_criterion = nn.CrossEntropyLoss() 307 | else: 308 | self.nocs_criterion = criterion 309 | self.mesh_criterion = criterion 310 | 311 | self.volume_loss_weight = volume_loss_weight 312 | self.nocs_loss_weight = nocs_loss_weight 313 | self.warp_loss_weight = warp_loss_weight 314 | 315 | self.learning_rate = learning_rate 316 | assert optimizer_type in ('Adam', 'SGD') 317 | self.optimizer_type = optimizer_type 318 | self.vis_per_items = vis_per_items 319 | self.max_vis_per_epoch_train = max_vis_per_epoch_train 320 | self.max_vis_per_epoch_val = max_vis_per_epoch_val 321 | self.batch_size = batch_size 322 | self.debug = debug 323 | 324 | # forward function for each stage 325 | # =============================== 326 | def encoder_forward(self, data, is_train=False): 327 | input1 = ME.SparseTensor(data['feat1'], coordinates=data['coords1']) 328 | input2 = ME.SparseTensor(data['feat2'], coordinates=data['coords2']) 329 | 330 | 
features_frame1_sparse, _ = self.sparse_unet3d_encoder(input1) 331 | features_frame2_sparse, _ = self.sparse_unet3d_encoder(input2) 332 | features_frame1 = features_frame1_sparse.F 333 | features_frame2 = features_frame2_sparse.F 334 | 335 | pc_nocs_frame1 = data['y1'] 336 | pos_gt_frame2 = data['y2'] 337 | per_point_batch_idx_frame1 = data['pc_batch_idx1'] 338 | per_point_batch_idx_frame2 = data['pc_batch_idx2'] 339 | sim_points_frame1 = data['pos1'] 340 | sim_points_frame2 = data['pos2'] 341 | 342 | if (self.transformer_siamese.encoder_pos_embed_input_dim == 6 and 343 | not self.transformer_siamese.inverse_source_template)\ 344 | or (self.transformer_siamese.decoder_pos_embed_input_dim == 6 and 345 | self.transformer_siamese.inverse_source_template): 346 | frame1_coord = torch.cat([sim_points_frame1, pc_nocs_frame1], dim=-1) 347 | elif self.transformer_siamese.encoder_pos_embed_input_dim == 3: 348 | frame1_coord = sim_points_frame1 349 | else: 350 | raise NotImplementedError 351 | 352 | fusion_feature, pc_logits_frame2 = self.transformer_siamese(features_frame1, frame1_coord, 353 | features_frame2, sim_points_frame2) 354 | 355 | if self.transformer_siamese.nocs_bins is not None: 356 | # NOCS classification 357 | vg = self.transformer_siamese.get_virtual_grid(pc_logits_frame2.get_device()) 358 | nocs_bins = self.transformer_siamese.nocs_bins 359 | pred_logits_bins = pc_logits_frame2.reshape( 360 | (pc_logits_frame2.shape[0], nocs_bins, 3)) 361 | nocs_bin_idx_pred = torch.argmax(pred_logits_bins, dim=1) 362 | pc_nocs_frame2 = vg.idxs_to_points(nocs_bin_idx_pred) 363 | else: 364 | # NOCS regression 365 | pc_nocs_frame2 = pc_logits_frame2 366 | 367 | nocs_data = dict( 368 | x_frame2=fusion_feature, 369 | pos_frame1=pc_nocs_frame1, 370 | pos_frame2=pc_nocs_frame2, 371 | logits_frame2=pc_logits_frame2, 372 | pos_gt_frame2=pos_gt_frame2, 373 | batch_frame1=per_point_batch_idx_frame1, 374 | batch_frame2=per_point_batch_idx_frame2, 375 | sim_points_frame1=sim_points_frame1, 376 | sim_points_frame2=sim_points_frame2) 377 | 378 | if self.use_nocs_refiner: 379 | assert self.transformer_siamese.nocs_bins is not None 380 | mesh_nocs = data['surf_query_points2'] 381 | batch_size = torch.max(per_point_batch_idx_frame2).item() + 1 382 | gt_mesh_nocs = data['gt_surf_query_points2'] if is_train else None 383 | refined_pc_logits_frame2, refined_surf_query_points2 = \ 384 | self.nocs_refiner(pc_nocs_frame2, sim_points_frame2, pc_logits_frame2, fusion_feature, mesh_nocs, 385 | batch_size, is_train=is_train, gt_mesh_nocs=gt_mesh_nocs) 386 | if self.disable_mesh_nocs_refine_in_test: 387 | nocs_data['refined_surf_query_points2'] = mesh_nocs 388 | else: 389 | nocs_data['refined_surf_query_points2'] = refined_surf_query_points2 390 | if self.disable_pc_nocs_refine_in_test: 391 | refined_pc_logits_frame2 = pc_logits_frame2 392 | 393 | nocs_data['refined_logits_frame2'] = refined_pc_logits_frame2 394 | 395 | # get NOCS coordinates from logits 396 | vg = self.transformer_siamese.get_virtual_grid(pc_logits_frame2.get_device()) 397 | nocs_bins = self.transformer_siamese.nocs_bins 398 | refined_pred_logits_bins = refined_pc_logits_frame2.reshape( 399 | (refined_pc_logits_frame2.shape[0], nocs_bins, 3)) 400 | refined_nocs_bin_idx_pred = torch.argmax(refined_pred_logits_bins, dim=1) 401 | refined_pc_nocs_frame2 = vg.idxs_to_points(refined_nocs_bin_idx_pred) 402 | nocs_data['refined_pos_frame2'] = refined_pc_nocs_frame2 403 | 404 | return nocs_data 405 | 406 | def unet3d_forward(self, encoder_result, is_train=False): 407 | 
407 |         # volume agg
408 |         in_feature_volume = self.volume_agg(encoder_result, self.batch_size, is_train)
409 | 
410 |         # unet3d
411 |         out_feature_volume = self.unet_3d(in_feature_volume)
412 |         unet3d_result = {
413 |             'out_feature_volume': out_feature_volume
414 |         }
415 |         return unet3d_result
416 | 
417 |     def surface_decoder_forward(self, unet3d_result, query_points):
418 |         out_feature_volume = unet3d_result['out_feature_volume']
419 |         out_features = self.surface_decoder(out_feature_volume, query_points)
420 |         decoder_result = {
421 |             'out_features': out_features
422 |         }
423 |         return decoder_result
424 | 
425 |     # forward
426 |     # =======
427 |     def forward(self, data, is_train=False, use_refine_mesh_for_query=True):
428 |         encoder_result = self.encoder_forward(data, is_train)
429 |         if is_train:
430 |             surface_query_points = data['gt_surf_query_points2']
431 |         else:
432 |             if use_refine_mesh_for_query:
433 |                 surface_query_points = encoder_result['refined_surf_query_points2']
434 |             else:
435 |                 surface_query_points = data['surf_query_points2']
436 |         unet3d_result = self.unet3d_forward(encoder_result, is_train)
437 | 
438 |         surface_decoder_result = self.surface_decoder_forward(
439 |             unet3d_result, surface_query_points)
440 | 
441 |         result = {
442 |             'encoder_result': encoder_result,
443 |             'unet3d_result': unet3d_result,
444 |             'surface_decoder_result': surface_decoder_result
445 |         }
446 |         return result
447 | 
448 |     # training
449 |     # ========
450 |     def configure_optimizers(self):
451 |         if self.optimizer_type == 'Adam':
452 |             return optim.Adam(self.parameters(), lr=self.learning_rate)
453 |         else:
454 |             raise NotImplementedError
455 | 
456 |     def vis_batch(self, batch, batch_idx, result, is_train=False, img_size=256):
457 |         nocs_data = result['encoder_result']
458 |         pred_nocs = nocs_data['pos_frame2'].detach()
459 |         if self.use_nocs_refiner:
460 |             refined_pred_pc_nocs = nocs_data['refined_pos_frame2'].detach()
461 |             refined_pred_mesh_nocs = nocs_data['refined_surf_query_points2'].detach()
462 |         gt_nocs = nocs_data['pos_gt_frame2']
463 |         batch_idxs = nocs_data['batch_frame2']
464 |         this_batch_size = len(batch['dataset_idx1'])
465 |         gt_mesh_nocs = batch['gt_surf_query_points2']
466 | 
467 |         vis_per_items = self.vis_per_items
468 |         batch_size = self.batch_size
469 |         if is_train:
470 |             max_vis_per_epoch = self.max_vis_per_epoch_train
471 |             prefix = 'train_'
472 |         else:
473 |             max_vis_per_epoch = self.max_vis_per_epoch_val
474 |             prefix = 'val_'
475 | 
476 |         _, selected_idxs, vis_idxs = get_vis_idxs(batch_idx,
477 |             batch_size=batch_size, this_batch_size=this_batch_size,
478 |             vis_per_items=vis_per_items, max_vis_per_epoch=max_vis_per_epoch)
479 | 
480 |         log_data = dict()
481 |         for i, vis_idx in zip(selected_idxs, vis_idxs):
482 |             label = prefix + str(vis_idx)
483 |             is_this_item = (batch_idxs == i)
484 |             this_gt_nocs = to_numpy(gt_nocs[is_this_item])
485 |             this_pred_nocs = to_numpy(pred_nocs[is_this_item])
486 |             if self.use_nocs_refiner:
487 |                 refined_label = prefix + 'refine_pc_' + str(vis_idx)
488 |                 this_refined_pred_pc_nocs = to_numpy(refined_pred_pc_nocs[is_this_item])
489 |                 refined_pc_nocs_img = render_nocs_pair(this_gt_nocs, this_refined_pred_pc_nocs,
490 |                                                        None, None, img_size=img_size)
491 |                 log_data[refined_label] = [wandb.Image(refined_pc_nocs_img, caption=refined_label)]
492 | 
493 |                 refined_label = prefix + 'refine_mesh_' + str(vis_idx)
494 |                 this_refined_pred_mesh_nocs = to_numpy(refined_pred_mesh_nocs.reshape(this_batch_size, -1, 3)[i])
495 |                 this_gt_mesh_nocs = to_numpy(gt_mesh_nocs.reshape(this_batch_size, -1, 3)[i])
496 |                 refined_mesh_nocs_img = render_nocs_pair(this_gt_mesh_nocs, this_refined_pred_mesh_nocs,
497 |                                                          None, None, img_size=img_size)
498 |                 log_data[refined_label] = [wandb.Image(refined_mesh_nocs_img, caption=refined_label)]
499 | 
500 |             nocs_img = render_nocs_pair(this_gt_nocs, this_pred_nocs,
501 |                                         None, None, img_size=img_size)
502 |             img = nocs_img
503 |             log_data[label] = [wandb.Image(img, caption=label)]
504 | 
505 |         return log_data
506 | 
507 |     def infer(self, batch, batch_idx, is_train=True):
508 |         if len(batch) == 0:
509 |             return dict(loss=torch.tensor(0., device='cuda:0', requires_grad=True))
510 |         try:
511 |             result = self(batch, is_train=is_train)
512 |             encoder_result = result['encoder_result']
513 | 
514 |             # NOCS error distance
515 |             pred_nocs = encoder_result['pos_frame2']
516 |             gt_nocs = encoder_result['pos_gt_frame2']
517 |             nocs_err_dist = torch.norm(pred_nocs - gt_nocs, dim=-1).mean()
518 | 
519 |             # NOCS loss
520 |             if self.transformer_siamese.nocs_bins is not None:
521 |                 # classification
522 |                 nocs_bins = self.transformer_siamese.nocs_bins
523 |                 vg = self.transformer_siamese.get_virtual_grid(pred_nocs.get_device())
524 |                 pred_logits = encoder_result['logits_frame2']
525 |                 pred_logits_bins = pred_logits.reshape(
526 |                     (pred_logits.shape[0], nocs_bins, 3))
527 |                 gt_nocs_idx = vg.get_points_grid_idxs(gt_nocs)
528 |                 nocs_loss = self.nocs_criterion(pred_logits_bins, gt_nocs_idx) * self.nocs_loss_weight
529 |             else:
530 |                 # regression
531 |                 nocs_loss = self.nocs_criterion(pred_nocs, gt_nocs) * self.nocs_loss_weight
532 | 
533 |             if self.use_nocs_refiner:
534 |                 # only support classification
535 |                 assert self.transformer_siamese.nocs_bins is not None
536 |                 nocs_bins = self.transformer_siamese.nocs_bins
537 |                 refined_pred_nocs = encoder_result['refined_pos_frame2']
538 |                 refined_nocs_err_dist = torch.norm(refined_pred_nocs - gt_nocs, dim=-1).mean()
539 | 
540 |                 vg = self.transformer_siamese.get_virtual_grid(refined_pred_nocs.get_device())
541 |                 refined_pred_logits = encoder_result['refined_logits_frame2']
542 |                 refined_pred_logits_bins = refined_pred_logits.reshape(
543 |                     (refined_pred_logits.shape[0], nocs_bins, 3))
544 |                 gt_nocs_idx = vg.get_points_grid_idxs(gt_nocs)
545 |                 refined_nocs_loss = self.nocs_criterion(refined_pred_logits_bins, gt_nocs_idx) * self.nocs_loss_weight
546 | 
547 |                 refined_mesh_nocs = encoder_result['refined_surf_query_points2']
548 |                 gt_mesh_nocs = batch['gt_surf_query_points2']
549 |                 mesh_loss = self.mesh_loss_weight * self.mesh_criterion(refined_mesh_nocs, gt_mesh_nocs)
550 |                 refined_mesh_err_dist = torch.norm(refined_mesh_nocs - gt_mesh_nocs, dim=-1).mean()
551 | 
552 |             # warp field loss (surface loss)
553 |             surface_decoder_result = result['surface_decoder_result']
554 |             surface_criterion = self.surface_criterion
555 |             pred_warp_field = surface_decoder_result['out_features']
556 |             gt_sim_points_frame2 = batch['gt_sim_points2']
557 |             gt_warpfield = gt_sim_points_frame2.reshape(pred_warp_field.shape)
558 |             warp_loss = surface_criterion(pred_warp_field, gt_warpfield)
559 |             warp_loss = self.warp_loss_weight * warp_loss
560 | 
561 |             loss_dict = {
562 |                 'nocs_err_dist': nocs_err_dist,
563 |                 'nocs_loss': nocs_loss,
564 |                 'warp_loss': warp_loss,
565 |             }
566 |             if self.use_nocs_refiner:
567 |                 loss_dict['refined_nocs_loss'] = refined_nocs_loss
568 |                 loss_dict['refined_nocs_err_dist'] = refined_nocs_err_dist
569 |                 loss_dict['mesh_loss'] = mesh_loss
570 |                 loss_dict['refined_mesh_err_dist'] = refined_mesh_err_dist
571 | 
572 |             metrics = dict(loss_dict)
573 |             metrics['loss'] = nocs_loss + warp_loss
574 |             if self.use_nocs_refiner:
575 |                 metrics['loss'] += refined_nocs_loss
576 |                 metrics['loss'] += mesh_loss
577 | 
578 |             for key, value in metrics.items():
579 |                 log_key = ('train_' if is_train else 'val_') + key
580 |                 self.log(log_key, value)
581 |             log_data = self.vis_batch(batch, batch_idx, result, is_train=is_train)
582 |             self.logger.log_metrics(log_data, step=self.global_step)
583 |         except Exception as e:
584 |             raise e
585 | 
586 |         return metrics
587 | 
588 |     def training_step(self, batch, batch_idx):
589 |         # torch.cuda.empty_cache()
590 |         metrics = self.infer(batch, batch_idx, is_train=True)
591 |         return metrics['loss']
592 | 
593 |     def validation_step(self, batch, batch_idx):
594 |         # torch.cuda.empty_cache()
595 |         metrics = self.infer(batch, batch_idx, is_train=False)
596 |         return metrics['loss']
597 | 
--------------------------------------------------------------------------------
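Note on the NOCS-bin classification used in the file above: `pc_logits_frame2` has shape (N, nocs_bins * 3) and is reshaped to (N, nocs_bins, 3), so each of the three NOCS axes gets its own distribution over bins, and the predicted coordinate is recovered from the per-axis argmax bin via the virtual grid (`vg.get_points_grid_idxs` / `vg.idxs_to_points`, defined elsewhere in the repository). The snippet below is a minimal standalone sketch of that shape logic only; the helper names (`points_to_bin_idxs`, `bin_idxs_to_points`, `decode_nocs_from_logits`), the bin-center mapping, and the assumption that `nocs_criterion` is a cross-entropy loss are illustrative assumptions, not the repository's actual implementation.

import torch

def points_to_bin_idxs(points, nocs_bins):
    # points: (N, 3) NOCS coordinates in [0, 1] -> integer bin index per axis, (N, 3)
    return torch.clamp((points * nocs_bins).long(), 0, nocs_bins - 1)

def bin_idxs_to_points(idxs, nocs_bins):
    # bin index -> bin-center coordinate in [0, 1], (N, 3)
    return (idxs.float() + 0.5) / nocs_bins

def decode_nocs_from_logits(logits, nocs_bins):
    # logits: (N, nocs_bins * 3) -> per-axis argmax over bins -> NOCS coordinates (N, 3)
    logits_bins = logits.reshape(logits.shape[0], nocs_bins, 3)
    bin_idx = torch.argmax(logits_bins, dim=1)  # (N, 3)
    return bin_idxs_to_points(bin_idx, nocs_bins)

# Matching classification loss under the same assumptions: CrossEntropyLoss accepts
# (N, C, d) logits against (N, d) targets, so (N, nocs_bins, 3) logits can be scored
# directly against (N, 3) ground-truth bin indices, one distribution per axis.
# criterion = torch.nn.CrossEntropyLoss()
# loss = criterion(logits.reshape(-1, nocs_bins, 3), points_to_bin_idxs(gt_nocs, nocs_bins))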