├── utils ├── __pycache__ │ ├── log.cpython-36.pyc │ ├── log_SPST.cpython-36.pyc │ ├── pc_utils.cpython-36.pyc │ ├── trans_norm.cpython-36.pyc │ ├── pc_utils_Norm.cpython-36.pyc │ └── log_SPST_selection.cpython-36.pyc ├── log.py ├── log_SPST.py ├── loss.py ├── metasets_data_utils.py ├── trans_norm.py ├── pc_utils.py └── pc_utils_Norm.py ├── data ├── download.py ├── dataloader_MetaSets.py ├── dataloader_GraspNetPC.py ├── grasp_datautils.py └── dataloader_PointDA_initial.py ├── requirements.txt ├── run_test.sh ├── README.md ├── models ├── pointnet_util.py └── model.py ├── critic.py ├── augmentation.py └── SPST_finetune_PCFEA_cls.py /utils/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log_SPST.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/log_SPST.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pc_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/pc_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/trans_norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/trans_norm.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pc_utils_Norm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/pc_utils_Norm.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log_SPST_selection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyao3302/PCFEA/HEAD/utils/__pycache__/log_SPST_selection.cpython-36.pyc -------------------------------------------------------------------------------- /data/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | 4 | url = 'https://drive.google.com/uc?id=1-LfJWL5geF9h0Z2QpdTL0n4lShy8wy2J' 5 | output = 'PointDA_data.zip' 6 | gdown.download(url, output, quiet=False) 7 | 8 | os.system('unzip PointDA_data.zip') 9 | os.system('rm PointDA_data.zip') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==3.10.3 2 | numpy==1.18.1 3 | torchsummary==1.5.1 4 | pandas==0.25.1 5 | h5py==2.10.0 6 | torch==1.3.1 7 | scikit_learn==0.22.2.post1 8 | open3d 9 | opencv-python 10 | sklearn 11 | tensorboard 12 | tensorboardX 13 | 14 | 15 | 16 | conda create -n GTSA python=3.8 17 | pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 18 | pip install pandas 19 | pip install h5py 20 | pip install -U scikit-learn 21 | pip install open3d 22 | pip install tensorboard 23 | pip install tensorboardX 24 | pip install numpy==1.23 -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train_GTSA_cls.py \ 3 | --num_workers 6 \ 4 | --seed 1024 \ 5 | --use_avg_pool False \ 6 | --src_dataset 'modelnet' \ 7 | --trgt_dataset 'shapenet' \ 8 | 
--epochs 100 \ 9 | --gpus '0' \ 10 | --batch_size 8 \ 11 | --test_batch_size 8 \ 12 | --use_aug False \ 13 | --lambda_0 0.25 \ 14 | --epoch_warmup 10 \ 15 | --selection_strategy 'ratio' \ 16 | --use_gradual_src_threshold False \ 17 | --use_gradual_trgt_threshold True \ 18 | --mode_src_threshold 'nonlinear' \ 19 | --mode_trgt_threshold 'nonlinear' \ 20 | --exp_k 0.15 \ 21 | --src_threshold 0.0 \ 22 | --trgt_threshold 1.0 \ 23 | --use_gradual_src_ratio True \ 24 | --use_gradual_trgt_ratio True \ 25 | --src_ratio 1.0 \ 26 | --trgt_ratio 1.0 \ 27 | --period_update_pool 10 \ 28 | --use_model_eval True \ 29 | --loss_function 'use_mean' \ 30 | --use_EMA False \ 31 | --EMA_update_warmup False \ 32 | --EMA_decay 0.99 \ 33 | --use_src_IDFA True \ 34 | --use_trgt_IDFA True \ 35 | --tao 2.0 \ 36 | --w_PCFEA 1.0 \ 37 | --w_src_IDFA 1.0 \ 38 | --w_trgt_IDFA 1.0 \ 39 | --exp_name 'test' -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pandas as pd 3 | import copy 4 | import torch 5 | import os 6 | import sklearn.metrics as metrics 7 | import pdb 8 | 9 | 10 | class IOStream(): 11 | """ 12 | Logging to screen and file 13 | """ 14 | def __init__(self, args): 15 | self.path = args.out_path + '/' + args.src_dataset + '_' + args.trgt_dataset + '/' + args.model + '/' 16 | if not os.path.exists(self.path): 17 | os.makedirs(self.path) 18 | if args.exp_name is None: 19 | timestamp = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 20 | self.path = self.path + '/' + timestamp 21 | else: 22 | self.path = self.path + '/' + args.exp_name 23 | if not os.path.exists(self.path): 24 | os.makedirs(self.path) 25 | self.f = open(self.path + '/run.log', 'a') 26 | self.args = args 27 | 28 | def cprint(self, text): 29 | datetime_string = datetime.datetime.now().strftime("%d-%m-%y %H:%M:%S") 30 | to_print = "%s: %s" % (datetime_string, text) 
31 | print(to_print) 32 | self.f.write(to_print + "\n") 33 | self.f.flush() 34 | 35 | def close(self): 36 | self.f.close() 37 | 38 | def save_model(self, model, epoch, additional_info): 39 | path = self.path + '/' + additional_info 40 | if not os.path.exists(path): 41 | os.makedirs(path) 42 | path = path + '/model.pt' 43 | 44 | best_model = copy.deepcopy(model) 45 | if len(self.args.gpus) > 1: 46 | torch.save(model.module.state_dict(), path) 47 | else: 48 | state = { 49 | 'epoch': epoch, 50 | 'model': best_model.state_dict() 51 | } 52 | torch.save(state, path) 53 | return best_model 54 | 55 | def print_progress(self, domain_set, partition, epoch, print_losses, true=None, pred=None): 56 | outstr = "%s - %s %d" % (partition, domain_set, epoch) 57 | acc = 0 58 | if true is not None and pred is not None: 59 | acc = metrics.accuracy_score(true, pred) 60 | avg_per_class_acc = metrics.balanced_accuracy_score(true, pred) 61 | outstr += ", acc: %.4f, avg acc: %.4f" % (acc, avg_per_class_acc) 62 | 63 | for loss, loss_val in print_losses.items(): 64 | outstr += ", %s loss: %.4f" % (loss, loss_val) 65 | self.cprint(outstr) 66 | return acc 67 | -------------------------------------------------------------------------------- /utils/log_SPST.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pandas as pd 3 | import copy 4 | import torch 5 | import os 6 | import sklearn.metrics as metrics 7 | import pdb 8 | 9 | 10 | class IOStream(): 11 | """ 12 | Logging to screen and file 13 | """ 14 | def __init__(self, args): 15 | self.path = args.out_path + '/' + args.src_dataset + '_' + args.trgt_dataset + '/' + args.model + '/' 16 | if not os.path.exists(self.path): 17 | os.makedirs(self.path) 18 | if args.exp_name is None: 19 | timestamp = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 20 | self.path = self.path + '/' + timestamp 21 | else: 22 | self.path = self.path + '/' + args.exp_name 23 | if not 
os.path.exists(self.path): 24 | os.makedirs(self.path) 25 | self.f = open(self.path + '/run_SPST.log', 'a') 26 | self.args = args 27 | 28 | def cprint(self, text): 29 | datetime_string = datetime.datetime.now().strftime("%d-%m-%y %H:%M:%S") 30 | to_print = "%s: %s" % (datetime_string, text) 31 | print(to_print) 32 | self.f.write(to_print + "\n") 33 | self.f.flush() 34 | 35 | def close(self): 36 | self.f.close() 37 | 38 | def save_model(self, model, epoch, additional_info): 39 | path = self.path + '/' + additional_info 40 | if not os.path.exists(path): 41 | os.makedirs(path) 42 | path = path + '/model.pt' 43 | 44 | best_model = copy.deepcopy(model) 45 | if len(self.args.gpus) > 1: 46 | torch.save(model.module.state_dict(), path) 47 | else: 48 | state = { 49 | 'epoch': epoch, 50 | 'model': best_model.state_dict() 51 | } 52 | torch.save(state, path) 53 | return best_model 54 | 55 | def print_progress(self, domain_set, partition, epoch, print_losses, true=None, pred=None): 56 | outstr = "%s - %s %d" % (partition, domain_set, epoch) 57 | acc = 0 58 | if true is not None and pred is not None: 59 | acc = metrics.accuracy_score(true, pred) 60 | avg_per_class_acc = metrics.balanced_accuracy_score(true, pred) 61 | outstr += ", acc: %.4f, avg acc: %.4f" % (acc, avg_per_class_acc) 62 | 63 | for loss, loss_val in print_losses.items(): 64 | outstr += ", %s loss: %.4f" % (loss, loss_val) 65 | self.cprint(outstr) 66 | return acc 67 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | def __init__(self, eps=0.1, reduction='mean'): 8 | super(LabelSmoothingCrossEntropy, self).__init__() 9 | self.eps = eps 10 | self.reduction = reduction 11 | 12 | def forward(self, output, target): 13 | c = output.size()[-1] 14 | log_preds 
= F.log_softmax(output, dim=-1) 15 | if self.reduction == 'sum': 16 | loss = -log_preds.sum() 17 | else: 18 | loss = -log_preds.sum(dim=-1) 19 | if self.reduction == 'mean': 20 | loss = loss.mean() 21 | return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction) 22 | 23 | 24 | class OrthogonalMatrixLoss(nn.Module): 25 | def __init__(self): 26 | super(OrthogonalMatrixLoss, self).__init__() 27 | 28 | def forward(self, x): 29 | batch_size = x.size()[0] 30 | m = torch.bmm(x, x.transpose(1, 2)) 31 | d = m.size()[1] 32 | diag_sum = 0 33 | for i in range(batch_size): 34 | for j in range(d): 35 | diag_sum += m[i][j][j] 36 | return (m.sum() - diag_sum) / batch_size 37 | 38 | 39 | # barlow twins 40 | class OrthogonalMatrixLoss_BT(nn.Module): 41 | def __init__(self, lamb=0.1): 42 | super(OrthogonalMatrixLoss_BT, self).__init__() 43 | self.lamb = lamb 44 | 45 | def forward(self, x): 46 | batch_size = x.size()[0] 47 | m = torch.bmm(x, x.transpose(1, 2)) 48 | m_square = m.pow(2) 49 | d = m.size()[1] 50 | diag_sum = 0 51 | off_diag_sum = m_square.sum() 52 | for i in range(batch_size): 53 | for j in range(d): 54 | diag_sum += (1 - 2 * m[i][j][j] + m_square[i][j][j]) 55 | off_diag_sum -= m_square[i][j][j] 56 | return (diag_sum + self.lamb * off_diag_sum) / batch_size 57 | 58 | 59 | class BarlowTwins(nn.Module): 60 | def __init__(self, lamb=0.02): 61 | super().__init__() 62 | self.lamb = lamb 63 | 64 | def forward(self, y1, y2): 65 | # empirical cross-correlation matrix 66 | c = torch.mm(y1.T, y2) 67 | c.div_(y1.shape[0]) 68 | 69 | # use --scale-loss to multiply the loss by a constant factor 70 | # see the Issues section of the readme 71 | on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() 72 | off_diag = self.off_diagonal(c).pow_(2).sum() 73 | loss = on_diag + self.lamb * off_diag 74 | return loss 75 | 76 | def off_diagonal(self, x): 77 | # return a flattened view of the off-diagonal elements of a square matrix 78 | n, m = x.shape 79 | 
assert n == m 80 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /data/dataloader_MetaSets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import sys 5 | import h5py 6 | import numpy as np 7 | from multiprocessing.dummy import Pool 8 | from torchvision import transforms 9 | import glob 10 | import random 11 | import threading 12 | import time 13 | from utils.pc_utils import * 14 | from augmentation import density, drop_hole, p_scan 15 | import pdb 16 | 17 | 18 | class PaddingData(data.Dataset): 19 | def __init__(self, pc_root, status='train', swapax=False, pc_input_num=2048): 20 | super(PaddingData, self).__init__() 21 | 22 | self.status = status 23 | 24 | self.pc_list = [] 25 | self.lbl_list = [] 26 | self.transforms = transforms.Compose( 27 | [ 28 | PointcloudToTensor(), 29 | PointcloudScale(), 30 | PointcloudRotate(), 31 | PointcloudRotatePerturbation(), 32 | PointcloudTranslate(), 33 | PointcloudJitter(), 34 | ] 35 | ) 36 | self.pc_input_num = pc_input_num 37 | 38 | categorys = glob.glob(os.path.join(pc_root, '*')) 39 | categorys = [c.split(os.path.sep)[-1] for c in categorys] 40 | categorys = sorted(categorys) 41 | print(categorys) 42 | 43 | if status == 'train': 44 | npy_list = glob.glob(os.path.join(pc_root, '*', 'train', '*.npy')) 45 | else: 46 | npy_list = glob.glob(os.path.join(pc_root, '*', 'test', '*.npy')) 47 | 48 | for idx, _dir in enumerate(npy_list): 49 | print("\r%d/%d" % (idx, len(npy_list)), end="") 50 | pc = np.load(_dir).astype(np.float32) 51 | if swapax: 52 | pc[:, 1] = pc[:, 2] + pc[:, 1] 53 | pc[:, 2] = pc[:, 1] - pc[:, 2] 54 | pc[:, 1] = pc[:, 1] - pc[:, 2] 55 | self.pc_list.append(pc) 56 | self.lbl_list.append(categorys.index(_dir.split('/')[-3])) 57 | print() 58 | 59 | print(f'{status} data num: 
{len(self.pc_list)}') 60 | 61 | def __getitem__(self, idx): 62 | lbl = self.lbl_list[idx] 63 | pc = self.pc_list[idx] # 2048, 3 64 | pc = normal_pc(pc) 65 | 66 | pn = min(pc.shape[0], self.pc_input_num) 67 | if self.status == 'train': 68 | pc_aug = pc 69 | if np.random.random() > 0.5: 70 | pc_aug = density(pc_aug, num_point=2048) 71 | if np.random.random() > 0.5: 72 | pc_aug = drop_hole(pc_aug, num_point=2048) 73 | if np.random.random() > 0.5: 74 | pc_aug = p_scan(pc_aug, num_point=2048) 75 | 76 | pc_aug = self.transforms(pc_aug) 77 | pc_aug = pc_aug.numpy() 78 | 79 | pc = self.transforms(pc) 80 | pc = pc.numpy() 81 | else: 82 | pc_aug = pc 83 | 84 | if pn < self.pc_input_num: 85 | pc = np.append(pc, np.zeros((self.pc_input_num - pc.shape[0], 3)), axis=0) 86 | pc_aug = np.append(pc_aug, np.zeros((self.pc_input_num - pc_aug.shape[0], 3)), axis=0) 87 | pc = pc[:self.pc_input_num] 88 | pc_aug = pc_aug[:self.pc_input_num] 89 | 90 | return (idx, pc, lbl, pc_aug) 91 | 92 | def __len__(self): 93 | return len(self.pc_list) 94 | 95 | 96 | if __name__ == '__main__': 97 | root = '../../data/MetaSets/' 98 | dataset = 'scanobjectnn_9' 99 | data_root = root + dataset 100 | dataset = PaddingData(data_root, status='train') 101 | print(dataset[1]) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PCFEA 2 | 3 | This is the official PyTorch implementation of our paper: 4 | 5 | > **[Progressive Classifier and Feature Extractor Adaptation for Unsupervised Domain Adaptation on Point Clouds](https://arxiv.org/abs/2311.16474)** 6 | > *In European Conference on Computer Vision (ECCV), 2024* 7 | 8 | 9 | > **Abstract.** 10 | > Unsupervised domain adaptation (UDA) is a critical challenge in the field of point cloud analysis. 
Previous works tackle the problem either by feature extractor adaptation to enable a shared classifier to distinguish domain-invariant features, or by classifier adaptation to evolve the classifier to recognize target-styled source features to increase its adaptation ability. However, by learning domain-invariant features, feature extractor adaptation methods fail to encode semantically meaningful target-specific information, while classifier adaptation methods rely heavily on the accurate estimation of the target distribution. In this work, we propose a novel framework that deeply couples the classifier and feature extractor adaptation for 3D UDA, dubbed Progressive Classifier and Feature Extractor Adaptation (PCFEA). Our PCFEA conducts 3D UDA from two distinct perspectives: macro and micro levels. On the macro level, we propose a progressive target-styled feature augmentation (PTFA) that establishes a series of intermediate domains to enable the model to progressively adapt to the target domain. Throughout this process, the source classifier is evolved to recognize target-styled source features (i.e., classifier adaptation). On the micro level, we develop an intermediate domain feature extractor adaptation (IDFA) that performs a compact feature alignment to gradually encourage target-styled feature extraction. In this way, PTFA and IDFA can mutually benefit each other: IDFA contributes to the distribution estimation of PTFA while PTFA constructs smoother intermediate domains to encourage an accurate feature alignment of IDFA. We validate our method on popular benchmark datasets, where our method achieves new state-of-the-art performance. 11 | 12 | 13 | 14 | ## Getting Started 15 | 16 | ### Installation 17 | 18 | Please follow the steps in the requirements.txt file to prepare the environment. 19 | 20 | 21 | ### Datasets 22 | 23 | Our code supports the PointDA-10 and GraspNetPC-10 datasets.
24 | 25 | - Please download the PointDA-10 dataset from https://drive.google.com/file/d/1-LfJWL5geF9h0Z2QpdTL0n4lShy8wy2J/view?usp=sharing. 26 | 27 | - Please download the GraspNetPC-10 dataset from https://drive.google.com/file/d/1VVHmsSToFMVccge-LsYJW67IS94rNxWR/view?usp=sharing. 28 | 29 | 30 | Please unzip the datasets and modify the dataset paths in the configuration files. 31 | 32 | For example, if you place your data under the data folder as shown below, you can run the code directly with `bash run_test.sh`. 33 | ``` 34 | 35 | ├── data 36 | ├── GraspNetPointClouds 37 | ├── test 38 | └── train 39 | └── PointDA_data 40 | ├── modelnet 41 | ├── scannet 42 | └── shapenet 43 | 44 | ├── PCFEA 45 | ├── data 46 | ├── dataloader_XXXX.py 47 | ├── .... 48 | └── grasp_datautils 49 | ├── log 50 | ├── XXX.txt 51 | └── XXX.txt 52 | ├── models 53 | ├── model.py 54 | └── pointnet_util.py 55 | ├── utils 56 | ├── log_SPST.py 57 | ├── .... 58 | └── trans_norm.py 59 | ├── augmentation.py 60 | ├── .... 61 | └── train_PCFEA_cls.py 62 | 63 | ``` 64 | 65 | 66 | 67 | ## Usage 68 | 69 | We experimented with many different methods before arriving at PCFEA, and the code for some of these alternatives is included as well, which is why the configuration may seem somewhat cumbersome. 70 | To reproduce our results, simply run `bash run_test.sh`. 71 | 72 | To run with different settings, modify the arguments in the shell script (run_test.sh). 73 | 74 | Our training logs are provided in the log folder. 75 | 76 | Note that all of our experiments were run on an NVIDIA RTX 2080 Ti, RTX A5000, or RTX 3090 GPU. 77 | 78 | 79 | 80 | ## Citation 81 | 82 | If you find this project useful, please consider citing our paper.
83 | 84 | 85 | 86 | 87 | ## Acknowledgement 88 | 89 | We thank [GAST](https://github.com/zou-longkun/GAST), [MLSP](https://github.com/VITA-Group/MLSP), [DefRec_and_PCM](https://github.com/IdanAchituve/DefRec_and_PCM), [PointDAN](https://github.com/canqin001/PointDAN), [ImplicitPCDA](https://github.com/Jhonve/ImplicitPCDA), [DGCNN](https://github.com/WangYueFt/dgcnn), [PointNet](https://github.com/charlesq34/pointnet), [PointNet++](https://github.com/charlesq34/pointnet2) and other relevant works for their amazing open-sourced projects! 90 | -------------------------------------------------------------------------------- /data/dataloader_GraspNetPC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import h5py 5 | import numpy as np 6 | import open3d as o3d 7 | from augmentation import density, drop_hole, p_scan 8 | import torch.utils.data as data 9 | 10 | # from grasp_datautils import jitter_pointcloud, scale_to_unit_cube, rotate_pc, random_rotate_one_axis, jitter_pointcloud_adaptive 11 | # from utils.pc_utils_Norm import farthest_point_sample_no_curv_np 12 | 13 | import pdb 14 | 15 | 16 | class GraspNetPointClouds(data.Dataset): 17 | def __init__(self, dataroot, partition='train'): 18 | super(GraspNetPointClouds).__init__() 19 | self.partition = partition 20 | 21 | def __getitem__(self, item): 22 | o3d_pointcloud = o3d.io.read_point_cloud(self.pc_list[item]) 23 | pointcloud = np.asarray(o3d_pointcloud.points) # 1024, 3 24 | 25 | pointcloud = pointcloud.astype(np.float32) 26 | path = self.pc_list[item].split('.x')[0] 27 | label = np.copy(self.label[item]) 28 | 29 | if self.partition == 'train': 30 | pointcloud_aug = pointcloud 31 | if np.random.random() > 0.5: 32 | pointcloud_aug = density(pointcloud_aug) 33 | if np.random.random() > 0.5: 34 | pointcloud_aug = drop_hole(pointcloud_aug) 35 | if np.random.random() > 0.5: 36 | pointcloud_aug = p_scan(pointcloud_aug) 37 | else: 38 | pointcloud_aug 
= pointcloud 39 | 40 | data_item = {} 41 | data_item['PC'] = pointcloud 42 | data_item['Label'] = label 43 | data_item['PC_Aug'] = pointcloud_aug 44 | 45 | return (item, pointcloud, label, pointcloud_aug) 46 | 47 | def __len__(self): 48 | return len(self.pc_list) 49 | 50 | # def get_data_loader(self, batch_size, num_workers, drop_last, shuffle=True): 51 | # return data.DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, drop_last=drop_last) 52 | 53 | class GraspNetRealPointClouds(GraspNetPointClouds): 54 | def __init__(self, dataroot, mode, partition='train'): 55 | super(GraspNetRealPointClouds).__init__() 56 | self.partition = partition 57 | 58 | dataroot = os.path.join(dataroot, 'GraspNetPC/GraspNetPointClouds') 59 | 60 | # pdb.set_trace() 61 | 62 | DATA_DIR = os.path.join(dataroot, partition, "Real", mode) # mode can be 'kinect' or 'realsense' 63 | # read data 64 | xyzs_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', '*.xyz'))) 65 | 66 | self.pc_list = [] 67 | self.lbl_list = [] 68 | 69 | for xyz_path in xyzs_list: 70 | self.pc_list.append(xyz_path) 71 | self.lbl_list.append(int(xyz_path.split('/')[-2])) 72 | 73 | self.label = np.asarray(self.lbl_list) 74 | self.num_examples = len(self.pc_list) 75 | 76 | if partition == "train": 77 | self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int) 78 | np.random.shuffle(self.train_ind) 79 | self.val_ind = np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int) 80 | np.random.shuffle(self.val_ind) 81 | 82 | 83 | class GraspNetSynthetictPointClouds(GraspNetPointClouds): 84 | def __init__(self, dataroot, partition='train', device=None, use_density=True, use_drop=True, use_scan=True): 85 | super(GraspNetSynthetictPointClouds).__init__() 86 | self.partition = partition 87 | 88 | dataroot = os.path.join(dataroot, 'GraspNetPC/GraspNetPointClouds') 89 | 90 | # pdb.set_trace() 91 | 92 | if device == None: 93 
| DATA_DIR_kinect = os.path.join(dataroot, partition, "Synthetic", "kinect") 94 | DATA_DIR_realsense = os.path.join(dataroot, partition, "Synthetic", "realsense") 95 | xyzs_list = sorted(glob.glob(os.path.join(DATA_DIR_kinect, '*', '*.xyz'))) 96 | xyzs_list_realsense = sorted(glob.glob(os.path.join(DATA_DIR_realsense, '*', '*.xyz'))) 97 | 98 | xyzs_list.extend(xyzs_list_realsense) 99 | elif device == 'kinect': 100 | DATA_DIR = os.path.join(dataroot, partition, "Synthetic", "kinect") 101 | xyzs_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', '*.xyz'))) 102 | elif device == 'realsense': 103 | DATA_DIR = os.path.join(dataroot, partition, "Synthetic", "realsense") 104 | xyzs_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', '*.xyz'))) 105 | 106 | self.pc_list = [] 107 | self.lbl_list = [] 108 | 109 | for xyz_path in xyzs_list: 110 | self.pc_list.append(xyz_path) 111 | self.lbl_list.append(int(xyz_path.split('/')[-2])) 112 | 113 | self.label = np.asarray(self.lbl_list) 114 | self.num_examples = len(self.pc_list) 115 | 116 | if partition == "train": 117 | self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int) 118 | np.random.shuffle(self.train_ind) 119 | self.val_ind = np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int) 120 | np.random.shuffle(self.val_ind) 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | root = '../../data/GraspNetPC-10/GraspNetPointClouds/' 126 | device = 'kinect' 127 | dataset = GraspNetSynthetictPointClouds(root, partition='train', device=device) -------------------------------------------------------------------------------- /utils/metasets_data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def angle_axis(angle, axis): 6 | u = axis / np.linalg.norm(axis) 7 | cosval, sinval = np.cos(angle), np.sin(angle) 8 | 9 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 10 | [u[2], 
0.0, -u[0]], 11 | [-u[1], u[0], 0.0]]) 12 | 13 | R = torch.from_numpy( 14 | cosval * np.eye(3) 15 | + sinval * cross_prod_mat 16 | + (1.0 - cosval) * np.outer(u, u) 17 | ) 18 | return R.float() 19 | 20 | 21 | class PointcloudScale(object): 22 | def __init__(self, lo=0.8, hi=1.25): 23 | self.lo, self.hi = lo, hi 24 | 25 | def __call__(self, points): 26 | scaler = np.random.uniform(self.lo, self.hi) 27 | points[:, 0:3] *= scaler 28 | return points 29 | 30 | 31 | class PointcloudRotate(object): 32 | def __init__(self, axis=np.array([0.0, 1.0, 0.0])): 33 | self.axis = axis 34 | 35 | def __call__(self, points): 36 | rotation_angle = np.random.uniform() * 2 * np.pi 37 | rotation_matrix = angle_axis(rotation_angle, self.axis) 38 | 39 | normals = points.size(1) > 3 40 | if not normals: 41 | return torch.matmul(points, rotation_matrix.t()) 42 | else: 43 | pc_xyz = points[:, 0:3] 44 | pc_normals = points[:, 3:] 45 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 46 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 47 | 48 | return points 49 | 50 | 51 | class PointcloudRotatePerturbation(object): 52 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 53 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip 54 | 55 | def _get_angles(self): 56 | angles = np.clip( 57 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip 58 | ) 59 | 60 | return angles 61 | 62 | def __call__(self, points): 63 | angles = self._get_angles() 64 | Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0])) 65 | Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0])) 66 | Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0])) 67 | 68 | rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx) 69 | 70 | normals = points.size(1) > 3 71 | if not normals: 72 | return torch.matmul(points, rotation_matrix.t()) 73 | else: 74 | pc_xyz = points[:, 0:3] 75 | pc_normals = points[:, 3:] 76 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 77 | points[:, 
3:] = torch.matmul(pc_normals, rotation_matrix.t()) 78 | 79 | return points 80 | 81 | 82 | class PointcloudJitter(object): 83 | def __init__(self, std=0.01, clip=0.05): 84 | self.std, self.clip = std, clip 85 | 86 | def __call__(self, points): 87 | jittered_data = ( 88 | points.new(points.size(0), 3) 89 | .normal_(mean=0.0, std=self.std) 90 | .clamp_(-self.clip, self.clip) 91 | ) 92 | points[:, 0:3] += jittered_data 93 | return points 94 | 95 | 96 | class PointcloudTranslate(object): 97 | def __init__(self, translate_range=0.1): 98 | self.translate_range = translate_range 99 | 100 | def __call__(self, points): 101 | translation = np.random.uniform(-self.translate_range, self.translate_range) 102 | points[:, 0:3] += translation 103 | return points 104 | 105 | 106 | class PointcloudToTensor(object): 107 | def __call__(self, points): 108 | return torch.from_numpy(points).float() 109 | 110 | 111 | def normal_pc(pc): 112 | pc_mean = pc.mean(axis=0) 113 | pc = pc - pc_mean 114 | pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1))) 115 | pc = pc/pc_L_max 116 | return pc 117 | 118 | def density(pc, v_point=np.array([1, 0, 0]), gate=1): 119 | dist = np.sqrt((v_point ** 2).sum()) 120 | max_dist = dist + 1 121 | min_dist = dist - 1 122 | dist = np.linalg.norm(pc - v_point.reshape(1,3), axis=1) 123 | dist = (dist - min_dist) / (max_dist - min_dist) 124 | r_list = np.random.uniform(0, 1, pc.shape[0]) 125 | tmp_pc = pc[dist * gate < (r_list)] 126 | return tmp_pc 127 | 128 | def p_scan(pc, pixel_size=0.017): 129 | pixel = int(2 / pixel_size) 130 | rotated_pc = rotate_point_cloud_3d(pc) 131 | pc_compress = (rotated_pc[:,2] + 1) / 2 * pixel * pixel + (rotated_pc[:,1] + 1) / 2 * pixel 132 | points_list = [None for i in range((pixel + 5) * (pixel + 5))] 133 | pc_compress = pc_compress.astype(np.int) 134 | for index, point in enumerate(rotated_pc): 135 | compress_index = pc_compress[index] 136 | if compress_index > len(points_list): 137 | print('out of index:', compress_index, 
len(points_list), point, pc[index], (pc[index] ** 2).sum(), (point ** 2).sum()) 138 | if points_list[compress_index] is None: 139 | points_list[compress_index] = index 140 | elif point[0] > rotated_pc[points_list[compress_index]][0]: 141 | points_list[compress_index] = index 142 | points_list = list(filter(lambda x: x is not None, points_list)) 143 | points_list = pc[points_list] 144 | return points_list 145 | 146 | def drop_hole(pc, p): 147 | random_point = np.random.randint(0, pc.shape[0]) 148 | index = np.linalg.norm(pc - pc[random_point].reshape(1,3), axis=1).argsort() 149 | return pc[index[int(pc.shape[0] * p):]] 150 | 151 | def rotate_point_cloud_3d(pc): 152 | rotation_angle = np.random.rand(3) * 2 * np.pi 153 | cosval = np.cos(rotation_angle) 154 | sinval = np.sin(rotation_angle) 155 | rotation_matrix_1 = np.array([[cosval[0], 0, sinval[0]], 156 | [0, 1, 0], 157 | [-sinval[0], 0, cosval[0]]]) 158 | rotation_matrix_2 = np.array([[1, 0, 0], 159 | [0, cosval[1], -sinval[1]], 160 | [0, sinval[1], cosval[1]]]) 161 | rotation_matrix_3 = np.array([[cosval[2], -sinval[2], 0], 162 | [sinval[2], cosval[2], 0], 163 | [0, 0, 1]]) 164 | rotation_matrix = np.matmul(np.matmul(rotation_matrix_1, rotation_matrix_2), rotation_matrix_3) 165 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix) 166 | 167 | return rotated_data 168 | -------------------------------------------------------------------------------- /data/grasp_datautils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | import torch 6 | 7 | def label_processing(label_path, label_id): 8 | label = cv2.imread(label_path) 9 | label_max = np.max(label) 10 | 11 | if label_max == 255: # is rendered synthetic label 12 | label = label[:, :, 0] 13 | foreground_mask = (label == label_id + 1) 14 | background_mask = 1 - foreground_mask 15 | else: 16 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 17 | 
foreground_mask = (label == label_id + 1) 18 | background_mask = 1 - foreground_mask 19 | 20 | return background_mask.astype(np.bool8) 21 | 22 | def label_background_processing(label_path): 23 | label = cv2.imread(label_path) 24 | label_max = np.max(label) 25 | 26 | if label_max == 255: # is rendered synthetic label 27 | label = label[:, :, 0] 28 | background_mask = (label >= 90) # there are 88 models in total, reduce rasterization error 29 | else: 30 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 31 | background_mask = (label <= 0) 32 | 33 | return background_mask 34 | 35 | def lable_cropping(label_path, label_id, bias=8): 36 | label_img = cv2.imread(label_path) 37 | h = label_img.shape[0] 38 | w = label_img.shape[1] 39 | 40 | idx_h, idx_w, _ = np.where(label_img == label_id + 1) 41 | 42 | if idx_h.shape[0] == 0: # no such label 43 | up = int((h / 2) - 128) 44 | down = int((h / 2) + 128) 45 | left = int((w / 2) - 128) 46 | right = int((w / 2) + 128) 47 | return up, down, left, right 48 | 49 | up = np.min(idx_h) - bias 50 | down = np.max(idx_h) + bias 51 | left = np.min(idx_w) - bias 52 | right = np.max(idx_w) + bias 53 | 54 | if right - left >= down - up: 55 | mid_h = (up + down) / 2 56 | len = right - left 57 | up = int(mid_h - (len / 2)) 58 | down = int(mid_h + (len / 2)) 59 | 60 | # check bbox 61 | if up < 0: 62 | up = 0 63 | down = up + len 64 | elif down >= h: 65 | down = h - 1 66 | up = down - len 67 | 68 | if left < 0: 69 | left = 0 70 | right = left + len 71 | elif right >= w: 72 | right = w - 1 73 | left = right - len 74 | else: 75 | mid_w = (left + right) / 2 76 | len = down - up 77 | left = int(mid_w - (len / 2)) 78 | right = int(mid_w + (len / 2)) 79 | 80 | # check bbox 81 | if up < 0: 82 | up = 0 83 | down = up + len 84 | elif down >= h: 85 | down = h - 1 86 | up = down - len 87 | 88 | if left < 0: 89 | left = 0 90 | right = left + len 91 | elif right >= w: 92 | right = w - 1 93 | left = right - len 94 | 95 | return up, down, left, right 96 | 
97 | def label_processing_cropping(label_path, label_id, bias=8): 98 | label_img = cv2.imread(label_path) 99 | h = label_img.shape[0] 100 | w = label_img.shape[1] 101 | 102 | bg_h, bh_w, _ = np.where(label_img == 0) 103 | if bg_h.shape[0] >= 16: 104 | # is real mask 105 | up, down, left, right = lable_cropping(label_path, label_id=254) 106 | else: 107 | up, down, left, right = lable_cropping(label_path, label_id * 3 - 1) 108 | 109 | return up, down, left, right 110 | 111 | def get_camera_parameters(extrinsic_mat=None, camera=''): 112 | param = o3d.camera.PinholeCameraParameters() 113 | 114 | if extrinsic_mat == None: 115 | param.extrinsic = np.eye(4, dtype=np.float64) 116 | else: 117 | param.extrinsic = extrinsic_mat 118 | 119 | # param.intrinsic = o3d.camera.PinholeCameraIntrinsic() 120 | 121 | if 'kinect' in camera: 122 | param.intrinsic.set_intrinsics(1280, 720, 631.5, 631.2, 639.5, 359.5) 123 | elif 'realsense' in camera: 124 | param.intrinsic.set_intrinsics(1280, 720, 927.17, 927.37, 639.5, 359.5) 125 | else: 126 | print("Unknow camera type") 127 | exit(0) 128 | 129 | return param 130 | 131 | def scale_to_unit_cube(x): 132 | if len(x) == 0: 133 | return x 134 | centroid = np.mean(x, axis=0) 135 | x -= centroid 136 | furthest_distance = np.max(np.sqrt(np.sum(abs(x) ** 2, axis=-1))) 137 | x /= furthest_distance 138 | return x 139 | 140 | def rotate_shape(x, axis, angle): 141 | """ 142 | Input: 143 | x: pointcloud data, [B, C, N] 144 | axis: axis to do rotation about 145 | angle: rotation angle 146 | Return: 147 | A rotated shape 148 | """ 149 | R_x = np.asarray([[1, 0, 0], [0, np.cos(angle), -np.sin(angle)], [0, np.sin(angle), np.cos(angle)]]) 150 | R_y = np.asarray([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) 151 | R_z = np.asarray([[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0], [0, 0, 1]]) 152 | 153 | if axis == "x": 154 | return x.dot(R_x).astype('float32') 155 | elif axis == "y": 156 | return 
x.dot(R_y).astype('float32') 157 | else: 158 | return x.dot(R_z).astype('float32') 159 | 160 | def rotate_pc(pc): 161 | pc = rotate_shape(pc, 'x', -np.pi / 2) 162 | return pc 163 | 164 | def random_rotate_one_axis(X, axis): 165 | """ 166 | Apply random rotation about one axis 167 | Input: 168 | x: pointcloud data, [B, C, N] 169 | axis: axis to do rotation about 170 | Return: 171 | A rotated shape 172 | """ 173 | rotation_angle = np.random.uniform() * 2 * np.pi 174 | cosval = np.cos(rotation_angle) 175 | sinval = np.sin(rotation_angle) 176 | if axis == 'x': 177 | R_x = [[1, 0, 0], [0, cosval, -sinval], [0, sinval, cosval]] 178 | X = np.matmul(X, R_x) 179 | elif axis == 'y': 180 | R_y = [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]] 181 | X = np.matmul(X, R_y) 182 | else: 183 | R_z = [[cosval, -sinval, 0], [sinval, cosval, 0], [0, 0, 1]] 184 | X = np.matmul(X, R_z) 185 | return X.astype('float32') 186 | 187 | def translate_pointcloud(pointcloud): 188 | """ 189 | Input: 190 | pointcloud: pointcloud data, [B, C, N] 191 | Return: 192 | A translated shape 193 | """ 194 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 195 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 196 | 197 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 198 | return translated_pointcloud 199 | 200 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 201 | """ 202 | Input: 203 | pointcloud: pointcloud data, [B, C, N] 204 | sigma: 205 | clip: 206 | Return: 207 | A jittered shape 208 | """ 209 | N, C = pointcloud.shape 210 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 211 | return pointcloud.astype('float32') 212 | 213 | def jitter_pointcloud_adaptive(pointcloud): 214 | """ 215 | Input: 216 | pointcloud: pointcloud data, [B, C, N] 217 | sigma: 218 | clip: 219 | Return: 220 | A jittered shape 221 | """ 222 | N, C = pointcloud.shape 223 | 224 | inner = np.matmul(pointcloud, np.transpose(pointcloud, (1, 0))) 
225 | pc_2 = np.sum(pointcloud ** 2, axis = 1, keepdims=True) 226 | pairwise_distances = pc_2 - 2 * inner + np.transpose(pc_2, (1, 0)) 227 | zero_mask = np.where(pairwise_distances <= 1e-4) 228 | pairwise_distances[zero_mask] = 9999. 229 | min_distances = np.min(pairwise_distances, axis=1) 230 | min_distances = np.sqrt(min_distances) 231 | 232 | min_distances_expdim = np.expand_dims(min_distances, axis=1) 233 | min_distances_expdim = np.repeat(min_distances_expdim, C, axis=1) 234 | 235 | # pointcloud += np.clip(min_distances_expdim * np.random.randn(N, C), -1 * min_distances_expdim, min_distances_expdim) # normal sampling 236 | pointcloud += np.clip(min_distances_expdim * (np.random.rand(N, C) * 2. - 1.), -1 * min_distances_expdim, min_distances_expdim) # uniform sampling 237 | return pointcloud.astype('float32') 238 | 239 | def pc_preprocessing(pc): 240 | mean = np.mean(pc, axis=0) 241 | pc = pc - mean 242 | pc_max = np.max(np.abs(pc)) 243 | 244 | pc = pc / (pc_max * 1.1) 245 | pc = (pc + 1.) / 2. 
246 | return pc 247 | 248 | def nearest_distances(x, y): 249 | # x query, y target 250 | inner = -2 * torch.matmul(x.transpose(2, 1), y) # x B 3 N; y B 3 M 251 | xx = torch.sum(x**2, dim=1, keepdim=True) 252 | yy = torch.sum(y**2, dim=1, keepdim=True) 253 | 254 | pairwise_distance = xx.transpose(2, 1) + inner + yy 255 | nearest_distance = torch.sqrt(torch.min(pairwise_distance, dim=2, keepdim=True).values) 256 | 257 | return nearest_distance 258 | 259 | def self_nearest_distances(x): 260 | inner = -2 * torch.matmul(x.transpose(2, 1), x) # x B 3 N 261 | xx = torch.sum(x**2, dim=1, keepdim=True) 262 | 263 | pairwise_distance = xx.transpose(2, 1) + inner + xx 264 | pairwise_distance += torch.eye(x.shape[2]).to(pairwise_distance.device) * 2 265 | nearest_distance = torch.sqrt(torch.min(pairwise_distance, dim=2, keepdim=True).values) 266 | 267 | return nearest_distance 268 | 269 | def self_nearest_distances_K(x, k=3): 270 | inner = -2 * torch.matmul(x.transpose(2, 1), x) # x B 3 N 271 | xx = torch.sum(x**2, dim=1, keepdim=True) 272 | 273 | pairwise_distance = xx.transpose(2, 1) + inner + xx 274 | pairwise_distance += torch.eye(x.shape[2]).to(pairwise_distance.device) * 2 275 | pairwise_distance *= -1 276 | k_nearest_distance = pairwise_distance.topk(k=k, dim=2)[0] 277 | k_nearest_distance *= -1 278 | 279 | nearest_distance = torch.sqrt(torch.mean(k_nearest_distance, dim=2, keepdim=True)) 280 | 281 | return nearest_distance 282 | 283 | def write_pc(point_cloud, output_path): 284 | point_cloud_o3d = o3d.geometry.PointCloud() 285 | point_cloud_o3d.points = o3d.utilit.Vector3dVector(point_cloud) 286 | o3d.io.write_point_cloud(output_path, point_cloud_o3d) -------------------------------------------------------------------------------- /data/dataloader_PointDA_initial.py: -------------------------------------------------------------------------------- 1 | # without metasets processing 2 | import os 3 | import glob 4 | import h5py 5 | import numpy as np 6 | from 
torch.utils.data import Dataset 7 | from utils.pc_utils_Norm import (farthest_point_sample_no_curv_np, scale_to_unit_cube, jitter_pointcloud, 8 | rotate_shape, random_rotate_one_axis) 9 | from augmentation import density, drop_hole, p_scan 10 | 11 | 12 | eps = 10e-4 13 | NUM_POINTS = 1024 14 | idx_to_label = {0: "bathtub", 1: "bed", 2: "bookshelf", 3: "cabinet", 15 | 4: "chair", 5: "lamp", 6: "monitor", 16 | 7: "plant", 8: "sofa", 9: "table"} 17 | label_to_idx = {"bathtub": 0, "bed": 1, "bookshelf": 2, "cabinet": 3, 18 | "chair": 4, "lamp": 5, "monitor": 6, 19 | "plant": 7, "sofa": 8, "table": 9} 20 | 21 | 22 | def load_data_h5py_scannet10(partition, dataroot): 23 | """ 24 | Input: 25 | partition - train/test 26 | Return: 27 | data,label arrays 28 | """ 29 | DATA_DIR = dataroot + '/PointDA_data/scannet/' 30 | # DATA_DIR = dataroot + '/PointDA_data/scannet' 31 | all_data = [] 32 | all_label = [] 33 | for h5_name in sorted(glob.glob(os.path.join(DATA_DIR, '%s_*.h5' % partition))): 34 | f = h5py.File(h5_name, 'r') 35 | data = f['data'][:] 36 | label = f['label'][:] 37 | f.close() 38 | all_data.append(data) 39 | all_label.append(label) 40 | all_data = np.concatenate(all_data, axis=0) 41 | all_label = np.concatenate(all_label, axis=0) 42 | return np.array(all_data).astype('float32'), np.array(all_label).astype('int64') 43 | 44 | 45 | class ScanNet(Dataset): 46 | """ 47 | scannet dataset for pytorch dataloader 48 | """ 49 | def __init__(self, io, dataroot, partition='train', random_rotation=True): 50 | self.partition = partition 51 | self.random_rotation = random_rotation 52 | 53 | # read data 54 | self.data, self.label = load_data_h5py_scannet10(self.partition, dataroot) 55 | self.num_examples = self.data.shape[0] 56 | 57 | # split train to train part and validation part 58 | if partition == "train": 59 | self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int) 60 | np.random.shuffle(self.train_ind) 61 | self.val_ind = 
np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int) 62 | np.random.shuffle(self.val_ind) 63 | 64 | io.cprint("number of " + partition + " examples in scannet" + ": " + str(self.data.shape[0])) 65 | unique, counts = np.unique(self.label, return_counts=True) 66 | io.cprint("Occurrences count of classes in scannet " + partition + " set: " + str(dict(zip(unique, counts)))) 67 | 68 | def __getitem__(self, item): 69 | pointcloud = np.copy(self.data[item])[:, :3] 70 | label = np.copy(self.label[item]) 71 | pointcloud = scale_to_unit_cube(pointcloud) 72 | # Rotate ScanNet by -90 degrees 73 | pointcloud = self.rotate_pc(pointcloud) 74 | # sample according to farthest point sampling 75 | if pointcloud.shape[0] > NUM_POINTS: 76 | pointcloud = np.swapaxes(np.expand_dims(pointcloud, 0), 1, 2) 77 | _, pointcloud = farthest_point_sample_no_curv_np(pointcloud, NUM_POINTS) 78 | pointcloud = np.swapaxes(pointcloud.squeeze(), 1, 0).astype('float32') 79 | 80 | # apply data rotation and augmentation on train samples 81 | if self.partition == 'train': 82 | pointcloud = jitter_pointcloud(pointcloud) 83 | if self.random_rotation==True: 84 | pointcloud = random_rotate_one_axis(pointcloud, "z") 85 | 86 | if self.partition == 'train': 87 | pointcloud_aug = pointcloud 88 | if np.random.random() > 0.5: 89 | pointcloud_aug = density(pointcloud_aug) 90 | if np.random.random() > 0.5: 91 | pointcloud_aug = drop_hole(pointcloud_aug) 92 | if np.random.random() > 0.5: 93 | pointcloud_aug = p_scan(pointcloud_aug) 94 | else: 95 | pointcloud_aug = pointcloud 96 | 97 | return (item, pointcloud, label, pointcloud_aug) 98 | 99 | def __len__(self): 100 | return self.data.shape[0] 101 | 102 | # scannet is rotated such that the up direction is the y axis 103 | def rotate_pc(self, pointcloud): 104 | pointcloud = rotate_shape(pointcloud, 'x', -np.pi / 2) 105 | return pointcloud 106 | 107 | 108 | class ModelNet(Dataset): 109 | """ 110 | modelnet dataset for pytorch dataloader 111 | 
""" 112 | def __init__(self, io, dataroot, partition='train', random_rotation=True): 113 | self.partition = partition 114 | self.random_rotation = random_rotation 115 | 116 | self.pc_list = [] 117 | self.lbl_list = [] 118 | DATA_DIR = os.path.join(dataroot, "PointDA_data", "modelnet") 119 | 120 | npy_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', partition, '*.npy'))) 121 | 122 | for _dir in npy_list: 123 | self.pc_list.append(_dir) 124 | self.lbl_list.append(label_to_idx[_dir.split('/')[-3]]) 125 | 126 | self.label = np.asarray(self.lbl_list) 127 | self.num_examples = len(self.pc_list) 128 | 129 | # split train to train part and validation part 130 | if partition == "train": 131 | self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int) 132 | np.random.shuffle(self.train_ind) 133 | self.val_ind = np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int) 134 | np.random.shuffle(self.val_ind) 135 | 136 | io.cprint("number of " + partition + " examples in modelnet : " + str(len(self.pc_list))) 137 | unique, counts = np.unique(self.label, return_counts=True) 138 | io.cprint("Occurrences count of classes in modelnet " + partition + " set: " + str(dict(zip(unique, counts)))) 139 | 140 | def __getitem__(self, item): 141 | pointcloud = np.load(self.pc_list[item])[:, :3].astype(np.float32) 142 | label = np.copy(self.label[item]) 143 | pointcloud = scale_to_unit_cube(pointcloud) 144 | # sample according to farthest point sampling 145 | if pointcloud.shape[0] > NUM_POINTS: 146 | pointcloud = np.swapaxes(np.expand_dims(pointcloud, 0), 1, 2) 147 | _, pointcloud = farthest_point_sample_no_curv_np(pointcloud, NUM_POINTS) 148 | pointcloud = np.swapaxes(pointcloud.squeeze(), 1, 0).astype('float32') 149 | 150 | # apply data rotation and augmentation on train samples 151 | if self.partition == 'train': 152 | pointcloud = jitter_pointcloud(pointcloud) 153 | if self.random_rotation==True: 154 | pointcloud = 
random_rotate_one_axis(pointcloud, "z") 155 | 156 | if self.partition == 'train': 157 | pointcloud_aug = pointcloud 158 | if np.random.random() > 0.5: 159 | pointcloud_aug = density(pointcloud_aug) 160 | if np.random.random() > 0.5: 161 | pointcloud_aug = drop_hole(pointcloud_aug) 162 | if np.random.random() > 0.5: 163 | pointcloud_aug = p_scan(pointcloud_aug) 164 | else: 165 | pointcloud_aug = pointcloud 166 | 167 | return (item, pointcloud, label, pointcloud_aug) 168 | 169 | def __len__(self): 170 | return len(self.pc_list) 171 | 172 | 173 | class ShapeNet(Dataset): 174 | """ 175 | Sahpenet dataset for pytorch dataloader 176 | """ 177 | def __init__(self, io, dataroot, partition='train', random_rotation=True): 178 | self.partition = partition 179 | self.random_rotation = random_rotation 180 | 181 | self.pc_list = [] 182 | self.lbl_list = [] 183 | DATA_DIR = os.path.join(dataroot, "PointDA_data", "shapenet") 184 | npy_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', partition, '*.npy'))) 185 | 186 | for _dir in npy_list: 187 | self.pc_list.append(_dir) 188 | self.lbl_list.append(label_to_idx[_dir.split('/')[-3]]) 189 | 190 | self.label = np.asarray(self.lbl_list) 191 | self.num_examples = len(self.pc_list) 192 | 193 | # split train to train part and validation part 194 | if partition == "train": 195 | self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int) 196 | np.random.shuffle(self.train_ind) 197 | self.val_ind = np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int) 198 | np.random.shuffle(self.val_ind) 199 | 200 | io.cprint("number of " + partition + " examples in shapenet: " + str(len(self.pc_list))) 201 | unique, counts = np.unique(self.label, return_counts=True) 202 | io.cprint("Occurrences count of classes in shapenet " + partition + " set: " + str(dict(zip(unique, counts)))) 203 | 204 | def __getitem__(self, item): 205 | pointcloud = np.load(self.pc_list[item])[:, :3].astype(np.float32) 
206 | label = np.copy(self.label[item]) 207 | pointcloud = scale_to_unit_cube(pointcloud) 208 | # Rotate ShapeNet by -90 degrees 209 | pointcloud = self.rotate_pc(pointcloud, label) 210 | # sample according to farthest point sampling 211 | if pointcloud.shape[0] > NUM_POINTS: 212 | pointcloud = np.swapaxes(np.expand_dims(pointcloud, 0), 1, 2) 213 | _, pointcloud = farthest_point_sample_no_curv_np(pointcloud, NUM_POINTS) 214 | pointcloud = np.swapaxes(pointcloud.squeeze(), 1, 0).astype('float32') 215 | 216 | # apply data rotation and augmentation on train samples 217 | if self.partition == 'train': 218 | pointcloud = jitter_pointcloud(pointcloud) 219 | if self.random_rotation == True: 220 | pointcloud = random_rotate_one_axis(pointcloud, "z") 221 | 222 | if self.partition == 'train': 223 | pointcloud_aug = pointcloud 224 | if np.random.random() > 0.5: 225 | pointcloud_aug = density(pointcloud_aug) 226 | if np.random.random() > 0.5: 227 | pointcloud_aug = drop_hole(pointcloud_aug) 228 | if np.random.random() > 0.5: 229 | pointcloud_aug = p_scan(pointcloud_aug) 230 | else: 231 | pointcloud_aug = pointcloud 232 | 233 | return (item, pointcloud, label, pointcloud_aug) 234 | 235 | def __len__(self): 236 | return len(self.pc_list) 237 | 238 | # shpenet is rotated such that the up direction is the y axis in all shapes except plant 239 | def rotate_pc(self, pointcloud, label): 240 | if label.item(0) != label_to_idx["plant"]: 241 | pointcloud = rotate_shape(pointcloud, 'x', -np.pi / 2) 242 | return pointcloud 243 | -------------------------------------------------------------------------------- /models/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | 
centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | def square_distance(src, dst): 20 | """ 21 | Calculate Euclid distance between each two points. 22 | 23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | 29 | Input: 30 | src: source points, [B, N, C] 31 | dst: target points, [B, M, C] 32 | Output: 33 | dist: per-point square distance, [B, N, M] 34 | """ 35 | B, N, _ = src.shape 36 | _, M, _ = dst.shape 37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 38 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 40 | return dist 41 | 42 | 43 | def index_points(points, idx): 44 | """ 45 | 46 | Input: 47 | points: input points data, [B, N, C] 48 | idx: sample index data, [B, S] 49 | Return: 50 | new_points:, indexed points data, [B, S, C] 51 | """ 52 | device = points.device 53 | B = points.shape[0] 54 | view_shape = list(idx.shape) 55 | view_shape[1:] = [1] * (len(view_shape) - 1) 56 | repeat_shape = list(idx.shape) 57 | repeat_shape[0] = 1 58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, 3] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, 
dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 | farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max sample number in local region 92 | xyz: all points, [B, N, 3] 93 | new_xyz: query points, [B, S, 3] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: 114 | radius: 115 | nsample: 116 | xyz: input points position data, [B, N, 3] 117 | points: input points data, [B, N, D] 118 | Return: 119 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 120 | new_points: sampled points data, [B, npoint, nsample, 3+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | if S == N: 125 | new_xyz = xyz 126 | else: 127 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 128 | new_xyz = index_points(xyz, fps_idx) 129 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 130 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 131 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 132 | 133 | if points is not None: 134 | 
grouped_points = index_points(points, idx) 135 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 136 | else: 137 | new_points = grouped_xyz_norm 138 | if returnfps: 139 | return new_xyz, new_points, grouped_xyz, fps_idx 140 | else: 141 | return new_xyz, new_points 142 | 143 | 144 | def sample_and_group_all(xyz, points): 145 | """ 146 | Input: 147 | xyz: input points position data, [B, N, 3] 148 | points: input points data, [B, N, D] 149 | Return: 150 | new_xyz: sampled points position data, [B, 1, 3] 151 | new_points: sampled points data, [B, 1, N, 3+D] 152 | """ 153 | device = xyz.device 154 | B, N, C = xyz.shape 155 | new_xyz = torch.zeros(B, 1, C).to(device) 156 | grouped_xyz = xyz.view(B, 1, N, C) 157 | if points is not None: 158 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 159 | else: 160 | new_points = grouped_xyz 161 | return new_xyz, new_points 162 | 163 | 164 | class PointNetSetAbstraction(nn.Module): 165 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 166 | super(PointNetSetAbstraction, self).__init__() 167 | self.npoint = npoint 168 | self.radius = radius 169 | self.nsample = nsample 170 | self.mlp_convs = nn.ModuleList() 171 | self.mlp_bns = nn.ModuleList() 172 | last_channel = in_channel 173 | for out_channel in mlp: 174 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 175 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 176 | last_channel = out_channel 177 | self.group_all = group_all 178 | 179 | def forward(self, xyz, points): 180 | """ 181 | Input: 182 | xyz: input points position data, [B, C, N] 183 | points: input points data, [B, D, N] 184 | Return: 185 | new_xyz: sampled points position data, [B, C, S] 186 | new_points_concat: sample points feature data, [B, D', S] 187 | """ 188 | xyz = xyz.permute(0, 2, 1) 189 | if points is not None: 190 | points = points.permute(0, 2, 1) 191 | 192 | if self.group_all: 193 | new_xyz, 
new_points = sample_and_group_all(xyz, points) 194 | else: 195 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 196 | # new_xyz: sampled points position data, [B, npoint, C] 197 | # new_points: sampled points data, [B, npoint, nsample, C+D] 198 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 199 | for i, conv in enumerate(self.mlp_convs): 200 | bn = self.mlp_bns[i] 201 | new_points = F.relu(bn(conv(new_points))) 202 | 203 | if self.group_all: 204 | new_points = new_points 205 | else: 206 | new_points = torch.max(new_points, 2)[0] 207 | new_xyz = new_xyz.permute(0, 2, 1) 208 | return new_xyz, new_points 209 | 210 | 211 | class PointNetSetAbstractionMsg(nn.Module): 212 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 213 | super(PointNetSetAbstractionMsg, self).__init__() 214 | self.npoint = npoint 215 | self.radius_list = radius_list 216 | self.nsample_list = nsample_list 217 | self.conv_blocks = nn.ModuleList() 218 | self.bn_blocks = nn.ModuleList() 219 | for i in range(len(mlp_list)): 220 | convs = nn.ModuleList() 221 | bns = nn.ModuleList() 222 | last_channel = in_channel + 3 223 | for out_channel in mlp_list[i]: 224 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 225 | bns.append(nn.BatchNorm2d(out_channel)) 226 | last_channel = out_channel 227 | self.conv_blocks.append(convs) 228 | self.bn_blocks.append(bns) 229 | 230 | def forward(self, xyz, points): 231 | """ 232 | Input: 233 | xyz: input points position data, [B, C, N] 234 | points: input points data, [B, D, N] 235 | Return: 236 | new_xyz: sampled points position data, [B, C, S] 237 | new_points_concat: sample points feature data, [B, D', S] 238 | """ 239 | xyz = xyz.permute(0, 2, 1) 240 | if points is not None: 241 | points = points.permute(0, 2, 1) 242 | 243 | B, N, C = xyz.shape 244 | S = self.npoint 245 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 246 | new_points_list = [] 247 | 
for i, radius in enumerate(self.radius_list): 248 | K = self.nsample_list[i] 249 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 250 | grouped_xyz = index_points(xyz, group_idx) 251 | grouped_xyz -= new_xyz.view(B, S, 1, C) 252 | if points is not None: 253 | grouped_points = index_points(points, group_idx) 254 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 255 | else: 256 | grouped_points = grouped_xyz 257 | 258 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 259 | for j in range(len(self.conv_blocks[i])): 260 | conv = self.conv_blocks[i][j] 261 | bn = self.bn_blocks[i][j] 262 | grouped_points = F.relu(bn(conv(grouped_points))) 263 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 264 | new_points_list.append(new_points) 265 | 266 | new_xyz = new_xyz.permute(0, 2, 1) 267 | new_points_concat = torch.cat(new_points_list, dim=1) 268 | return new_xyz, new_points_concat 269 | 270 | 271 | class PointNetFeaturePropagation(nn.Module): 272 | def __init__(self, in_channel, mlp): 273 | super(PointNetFeaturePropagation, self).__init__() 274 | self.mlp_convs = nn.ModuleList() 275 | self.mlp_bns = nn.ModuleList() 276 | last_channel = in_channel 277 | for out_channel in mlp: 278 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 279 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 280 | last_channel = out_channel 281 | 282 | def forward(self, xyz1, xyz2, points1, points2): 283 | """ 284 | Input: 285 | xyz1: input points position data, [B, C, N] 286 | xyz2: sampled input points position data, [B, C, S] 287 | points1: input points data, [B, D, N] 288 | points2: input points data, [B, D, S] 289 | Return: 290 | new_points: upsampled points data, [B, D', N] 291 | """ 292 | xyz1 = xyz1.permute(0, 2, 1) 293 | xyz2 = xyz2.permute(0, 2, 1) 294 | 295 | points2 = points2.permute(0, 2, 1) 296 | B, N, C = xyz1.shape 297 | _, S, _ = xyz2.shape 298 | 299 | if S == 1: 300 | interpolated_points = points2.repeat(1, N, 1) 
301 | else: 302 | dists = square_distance(xyz1, xyz2) 303 | dists, idx = dists.sort(dim=-1) 304 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 305 | 306 | dist_recip = 1.0 / (dists + 1e-8) 307 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 308 | weight = dist_recip / norm 309 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 310 | 311 | if points1 is not None: 312 | points1 = points1.permute(0, 2, 1) 313 | new_points = torch.cat([points1, interpolated_points], dim=-1) 314 | else: 315 | new_points = interpolated_points 316 | 317 | new_points = new_points.permute(0, 2, 1) 318 | for i, conv in enumerate(self.mlp_convs): 319 | bn = self.mlp_bns[i] 320 | new_points = F.relu(bn(conv(new_points))) 321 | return new_points 322 | 323 | -------------------------------------------------------------------------------- /critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | import gc 6 | import math 7 | from sklearn import manifold 8 | 9 | 10 | class EstimatorCV(): 11 | def __init__(self, feature_num, class_num, device): 12 | super(EstimatorCV, self).__init__() 13 | self.class_num = class_num 14 | self.device = device 15 | 16 | self.CoVariance = torch.zeros(class_num, feature_num, feature_num).to(device) 17 | self.Ave = torch.zeros(class_num, feature_num).to(device) 18 | self.Amount = torch.zeros(class_num).to(device) 19 | 20 | def update_CV(self, features, labels): 21 | N = features.size(0) 22 | C = self.class_num 23 | A = features.size(1) 24 | 25 | NxCxFeatures = features.view( 26 | N, 1, A 27 | ).expand( 28 | N, C, A 29 | ) 30 | onehot = torch.zeros(N, C).to(self.device) 31 | onehot.scatter_(1, labels.view(-1, 1), 1) 32 | 33 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 34 | 35 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) # feature of a certain class 36 | 37 | 
Amount_CxA = NxCxA_onehot.sum(0) 38 | Amount_CxA[Amount_CxA == 0] = 1 39 | 40 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 41 | 42 | var_temp = features_by_sort - \ 43 | ave_CxA.expand(N, C, A).mul(NxCxA_onehot) 44 | 45 | var_temp = torch.bmm( 46 | var_temp.permute(1, 2, 0), 47 | var_temp.permute(1, 0, 2) 48 | ).div(Amount_CxA.view(C, A, 1).expand(C, A, A)) 49 | 50 | sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 51 | 52 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 53 | 54 | weight_CV = sum_weight_CV.div( 55 | sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A) 56 | ) 57 | weight_CV[weight_CV != weight_CV] = 0 58 | 59 | weight_AV = sum_weight_AV.div( 60 | sum_weight_AV + self.Amount.view(C, 1).expand(C, A) 61 | ) 62 | weight_AV[weight_AV != weight_AV] = 0 63 | 64 | additional_CV = weight_CV.mul(1 - weight_CV).mul( 65 | torch.bmm( 66 | (self.Ave - ave_CxA).view(C, A, 1), 67 | (self.Ave - ave_CxA).view(C, 1, A) 68 | ) 69 | ) 70 | 71 | self.CoVariance = (self.CoVariance.mul(1 - weight_CV) + var_temp 72 | .mul(weight_CV)).detach() + additional_CV.detach() 73 | 74 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 75 | 76 | self.Amount += onehot.sum(0) 77 | 78 | 79 | class ISDALoss(nn.Module): 80 | def __init__(self, feature_num, class_num, device): 81 | super(ISDALoss, self).__init__() 82 | 83 | self.estimator = EstimatorCV(feature_num, class_num, device) 84 | 85 | self.class_num = class_num 86 | self.device = device 87 | 88 | self.cross_entropy = nn.CrossEntropyLoss() 89 | 90 | def isda_aug(self, fc, features, y, labels, cv_matrix, ratio): 91 | 92 | N = features.size(0) 93 | C = self.class_num 94 | A = features.size(1) 95 | 96 | weight_m = list(fc.parameters())[0] 97 | 98 | NxW_ij = weight_m.expand(N, C, A) 99 | 100 | NxW_kj = torch.gather(NxW_ij, 101 | 1, 102 | labels.view(N, 1, 1) 103 | .expand(N, C, A)) 104 | 105 | CV_temp = cv_matrix[labels] 106 | 107 | # sigma2 = ratio * \ 108 | # 
torch.bmm(torch.bmm(NxW_ij - NxW_kj, 109 | # CV_temp).view(N * C, 1, A), 110 | # (NxW_ij - NxW_kj).view(N * C, A, 1)).view(N, C) 111 | 112 | sigma2 = ratio * \ 113 | torch.bmm(torch.bmm(NxW_ij - NxW_kj, 114 | CV_temp), 115 | (NxW_ij - NxW_kj).permute(0, 2, 1)) 116 | 117 | sigma2 = sigma2.mul(torch.eye(C).to(self.device) 118 | .expand(N, C, C)).sum(2).view(N, C) 119 | 120 | aug_result = y + 0.5 * sigma2 121 | 122 | return aug_result 123 | 124 | def forward(self, model, fc, x, target_x, ratio): 125 | 126 | logits = model(x) 127 | features = logits['feature'] 128 | y = logits['pred'] 129 | 130 | self.estimator.update_CV(features.detach(), target_x) 131 | 132 | isda_aug_y = self.isda_aug(fc, features, y, target_x, self.estimator.CoVariance.detach(), ratio) 133 | 134 | loss = self.cross_entropy(isda_aug_y, target_x) 135 | 136 | return loss, y 137 | 138 | 139 | def MI(outputs_target): 140 | batch_size = outputs_target.size(0) 141 | softmax_outs_t = nn.Softmax(dim=1)(outputs_target) 142 | avg_softmax_outs_t = torch.sum(softmax_outs_t, dim=0) / float(batch_size) 143 | log_avg_softmax_outs_t = torch.log(avg_softmax_outs_t) 144 | item1 = -torch.sum(avg_softmax_outs_t * log_avg_softmax_outs_t) 145 | item2 = -torch.sum(softmax_outs_t * torch.log(softmax_outs_t)) / float(batch_size) 146 | return item1 - item2 147 | 148 | def CalculateMean(features, labels, class_num): 149 | device = features.device 150 | N = features.size(0) # size of the pool 151 | C = class_num 152 | A = features.size(1) # dimension of the feature 153 | 154 | avg_CxA = torch.zeros(C, A).to(device) 155 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 156 | 157 | onehot = torch.zeros(N, C).to(device) 158 | onehot.scatter_(1, labels.view(-1, 1), 1) 159 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 160 | 161 | Amount_CxA = NxCxA_onehot.sum(0) 162 | Amount_CxA[Amount_CxA == 0] = 1.0 163 | 164 | del onehot 165 | gc.collect() 166 | for c in range(class_num): 167 | c_temp = NxCxFeatures[:, c, 
:].mul(NxCxA_onehot[:, c, :]) 168 | c_temp = torch.sum(c_temp, dim=0) 169 | avg_CxA[c] = c_temp / Amount_CxA[c] 170 | return avg_CxA.detach() 171 | 172 | def Calculate_CV(features, labels, ave_CxA, class_num): 173 | device = features.device 174 | N = features.size(0) 175 | C = class_num 176 | A = features.size(1) 177 | 178 | var_temp = torch.zeros(C, A, A).to(device) 179 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 180 | 181 | onehot = torch.zeros(N, C).to(device) 182 | onehot.scatter_(1, labels.view(-1, 1), 1) 183 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 184 | 185 | Amount_CxA = NxCxA_onehot.sum(0) 186 | Amount_CxA[Amount_CxA == 0] = 1 187 | Amount_CxAxA = Amount_CxA.view(C, A, 1).expand(C, A, A) 188 | del Amount_CxA, onehot 189 | gc.collect() 190 | 191 | avg_NxCxA = ave_CxA.expand(N, C, A) 192 | for c in range(C): 193 | features_by_sort_c = NxCxFeatures[:, c, :].mul(NxCxA_onehot[:, c, :]) 194 | avg_by_sort_c = avg_NxCxA[:, c, :].mul(NxCxA_onehot[:, c, :]) 195 | var_temp_c = features_by_sort_c - avg_by_sort_c 196 | var_temp[c] = torch.mm(var_temp_c.permute(1,0), var_temp_c).div(Amount_CxAxA[c]) 197 | return var_temp.detach() 198 | 199 | 200 | class TSALoss(nn.Module): 201 | def __init__(self, class_num): 202 | super(TSALoss, self).__init__() 203 | self.class_num = class_num 204 | self.cross_entropy = nn.CrossEntropyLoss() 205 | 206 | def aug(self, s_mean_matrix, t_mean_matrix, fc, features, y_s, labels_s, t_cv_matrix, Lambda): 207 | device = features.device 208 | N = features.size(0) 209 | C = self.class_num 210 | A = features.size(1) 211 | 212 | weight_m = list(fc.parameters())[0] 213 | NxW_ij = weight_m.expand(N, C, A) 214 | NxW_kj = torch.gather(NxW_ij, 1, labels_s.view(N, 1, 1).expand(N, C, A)) 215 | 216 | t_CV_temp = t_cv_matrix[labels_s] 217 | 218 | sigma2 = Lambda * torch.bmm(torch.bmm(NxW_ij - NxW_kj, t_CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)) 219 | sigma2 = sigma2.mul(torch.eye(C).to(device).expand(N, C, C)).sum(2).view(N, C) 
220 | 221 | sourceMean_NxA = s_mean_matrix[labels_s] 222 | targetMean_NxA = t_mean_matrix[labels_s] 223 | dataMean_NxA = (targetMean_NxA - sourceMean_NxA) 224 | dataMean_NxAx1 = dataMean_NxA.expand(1, N, A).permute(1, 2, 0) 225 | 226 | del t_CV_temp, sourceMean_NxA, targetMean_NxA, dataMean_NxA 227 | gc.collect() 228 | 229 | dataW_NxCxA = NxW_ij - NxW_kj 230 | dataW_x_detaMean_NxCx1 = torch.bmm(dataW_NxCxA, dataMean_NxAx1) 231 | datW_x_detaMean_NxC = dataW_x_detaMean_NxCx1.view(N, C) 232 | 233 | aug_result = y_s + 0.5 * sigma2 + Lambda * datW_x_detaMean_NxC 234 | return aug_result 235 | 236 | def forward(self, fc, features_source: torch.Tensor, y_s, labels_source, Lambda, mean_source, mean_target, covariance_target): 237 | aug_y = self.aug(mean_source, mean_target, fc, features_source, y_s, labels_source, covariance_target, Lambda) 238 | loss = self.cross_entropy(aug_y, labels_source) 239 | return loss 240 | 241 | 242 | def CalculateSelectedMean(feature, label, indicator, class_num): 243 | # feature: feature pool [N, A] 244 | # label: label pool [N] 245 | # indicator: indicate whether the sample is selected or not [N] 246 | device = feature.device 247 | 248 | N = feature.size(0) # size of the pool, s_len + t_len 249 | C = class_num 250 | A = feature.size(1) # dimension of the feature 251 | 252 | avg_CxA = torch.zeros(C, A).to(device) 253 | 254 | NxCxFeatures = feature.view(N, 1, A).expand(N, C, A) 255 | 256 | onehot = torch.zeros(N, C).to(device) 257 | onehot.scatter_(1, label.view(-1, 1), 1) 258 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) # if class_i = j, [i, :, j] = 1, else 0 259 | NxCxA_onehot = NxCxA_onehot * indicator.view(N, 1, 1).expand(N, C, A) # incase some samples are not selected and would be treated as label = 0 260 | 261 | Amount_CxA = NxCxA_onehot.sum(0) # [C, A], we need to calculate the mean for each channel for each category 262 | Amount_CxA[Amount_CxA == 0] = 1.0 263 | 264 | del onehot 265 | gc.collect() 266 | 267 | for c in 
range(class_num): 268 | c_temp = NxCxFeatures[:, c, :].mul(NxCxA_onehot[:, c, :]) 269 | c_temp = torch.sum(c_temp, dim=0) 270 | avg_CxA[c] = c_temp / Amount_CxA[c] 271 | return avg_CxA.detach() 272 | 273 | 274 | def CalculateSelectedCV(feature, label, indicator, mean_pool, class_num): 275 | # feature: feature pool [N, A] 276 | # label: label pool [N] 277 | # indicator: indicate whether the sample is selected or not [N] 278 | # mean_pool: mean value of the pool [C, A] 279 | device = feature.device 280 | 281 | N = feature.size(0) 282 | C = class_num 283 | A = feature.size(1) 284 | 285 | var_temp = torch.zeros(C, A, A).to(device) 286 | 287 | NxCxFeatures = feature.view(N, 1, A).expand(N, C, A) 288 | 289 | onehot = torch.zeros(N, C).to(device) 290 | onehot.scatter_(1, label.view(-1, 1), 1) 291 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 292 | NxCxA_onehot = NxCxA_onehot * indicator.view(N, 1, 1).expand(N, C, A) 293 | 294 | Amount_CxA = NxCxA_onehot.sum(0) 295 | Amount_CxA[Amount_CxA == 0] = 1 296 | Amount_CxAxA = Amount_CxA.view(C, A, 1).expand(C, A, A) 297 | 298 | del Amount_CxA, onehot 299 | gc.collect() 300 | 301 | avg_NxCxA = mean_pool.expand(N, C, A) 302 | for c in range(C): 303 | features_by_sort_c = NxCxFeatures[:, c, :].mul(NxCxA_onehot[:, c, :]) 304 | avg_by_sort_c = avg_NxCxA[:, c, :].mul(NxCxA_onehot[:, c, :]) 305 | var_temp_c = features_by_sort_c - avg_by_sort_c 306 | var_temp[c] = torch.mm(var_temp_c.permute(1,0), var_temp_c).div(Amount_CxAxA[c]) 307 | return var_temp.detach() 308 | 309 | 310 | class PCFEALoss_no_mean(nn.Module): 311 | def __init__(self, class_num): 312 | super(PCFEALoss_no_mean, self).__init__() 313 | self.class_num = class_num 314 | self.cross_entropy = nn.CrossEntropyLoss() 315 | 316 | def aug(self, fc, features, pred, labels, cv_matrix, Lambda): 317 | device = features.device 318 | 319 | N = features.size(0) 320 | C = self.class_num 321 | A = features.size(1) 322 | 323 | weight_m = list(fc.parameters())[0] 324 | NxW_ij = 
weight_m.expand(N, C, A) 325 | NxW_kj = torch.gather(NxW_ij, 1, labels.view(N, 1, 1).expand(N, C, A)) 326 | 327 | CV_temp = cv_matrix[labels] 328 | 329 | sigma2 = Lambda * torch.bmm(torch.bmm(NxW_ij - NxW_kj, CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)) 330 | sigma2 = sigma2.mul(torch.eye(C).to(device).expand(N, C, C)).sum(2).view(N, C) 331 | 332 | aug_result = pred + 0.5 * sigma2 333 | 334 | return aug_result 335 | 336 | def forward(self, fc, features, pred, labels, Lambda, covariance_sample): 337 | aug_y = self.aug(fc, features, pred, labels, covariance_sample, Lambda) 338 | loss = self.cross_entropy(aug_y, labels) 339 | return loss 340 | 341 | 342 | class PCFEALoss(nn.Module): 343 | def __init__(self, class_num): 344 | super(PCFEALoss, self).__init__() 345 | self.class_num = class_num 346 | self.cross_entropy = nn.CrossEntropyLoss() 347 | 348 | def aug(self, mean_matrix, mean_source, fc, features, pred, labels, cv_matrix, Lambda): 349 | device = features.device 350 | 351 | N = features.size(0) 352 | C = self.class_num 353 | A = features.size(1) 354 | 355 | weight_m = list(fc.parameters())[0] 356 | NxW_ij = weight_m.expand(N, C, A) 357 | NxW_kj = torch.gather(NxW_ij, 1, labels.view(N, 1, 1).expand(N, C, A)) 358 | 359 | CV_temp = cv_matrix[labels] 360 | 361 | sigma2 = Lambda * torch.bmm(torch.bmm(NxW_ij - NxW_kj, CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)) 362 | sigma2 = sigma2.mul(torch.eye(C).to(device).expand(N, C, C)).sum(2).view(N, C) 363 | 364 | sourceMean_NxA = mean_source[labels] 365 | poolMean_NxA = mean_matrix[labels] 366 | dataMean_NxA = (poolMean_NxA - sourceMean_NxA) 367 | dataMean_NxAx1 = dataMean_NxA.expand(1, N, A).permute(1, 2, 0) 368 | 369 | dataW_NxCxA = NxW_ij - NxW_kj 370 | dataW_x_detaMean_NxCx1 = torch.bmm(dataW_NxCxA, dataMean_NxAx1) 371 | datW_x_detaMean_NxC = dataW_x_detaMean_NxCx1.view(N, C) 372 | 373 | aug_result = pred + 0.5 * sigma2 + datW_x_detaMean_NxC 374 | 375 | return aug_result 376 | 377 | def forward(self, fc, features, 
pred, labels, Lambda, mean_sample, mean_source, covariance_sample): 378 | aug_y = self.aug(mean_sample, mean_source, fc, features, pred, labels, covariance_sample, Lambda) 379 | loss = self.cross_entropy(aug_y, labels) 380 | return loss 381 | 382 | 383 | class Focal_loss(nn.Module): 384 | def __init__(self, alpha=0.25, gamma=2, reduction='mean'): 385 | super(Focal_loss, self).__init__() 386 | self.alpha = alpha 387 | self.gamma = gamma 388 | self.reduction = reduction 389 | 390 | def forward(self, pred, label): 391 | device = pred.device 392 | B, D = pred.shape # D is num_class 393 | # pred = torch.sigmoid(pred) # which function can be used here 394 | pred = F.softmax(pred, dim=-1) 395 | ones = torch.sparse.torch.eye(D).to(device) 396 | label_one_hot = ones.index_select(0, label) 397 | pred = (pred * label_one_hot).sum(dim=-1) 398 | loss = -self.alpha * (torch.pow((1 - pred), self.gamma)) * pred.log() 399 | 400 | if self.reduction == 'mean': 401 | loss = loss.mean() 402 | elif self.reduction == 'sum': 403 | loss = loss.sum() 404 | else: 405 | loss = loss 406 | 407 | return loss 408 | 409 | 410 | def dot(x, y): 411 | return torch.sum(x * y, dim=-1) 412 | 413 | 414 | # SimCLR 415 | class InfoNCE(nn.Module): 416 | def __init__(self, temperature): 417 | super(InfoNCE, self).__init__() 418 | self.T = temperature 419 | self.cossim = nn.CosineSimilarity(dim=-1) 420 | self.CELoss = nn.CrossEntropyLoss() 421 | 422 | def forward(self, data1, data2): 423 | data = torch.cat([data1, data2], dim=0) # 2*B, D 424 | device = data1.device 425 | B, D = data1.shape 426 | sim = self.cossim(data.unsqueeze(0), data.unsqueeze(1)) / self.T 427 | sim_pos = torch.cat([torch.diag(sim, B), torch.diag(sim, -B)], dim=0).reshape(2*B, 1) 428 | mask = torch.ones_like(sim).long().to(device) 429 | idx1 = torch.ones([B]).long().to(device) 430 | idx2 = torch.ones([2*B]).long().to(device) 431 | mask = mask - (torch.diag_embed(idx1, B) + torch.diag_embed(idx1, -B) + torch.diag_embed(idx2, 0)) 432 | mask 
= mask.bool() 433 | sim_neg = sim[mask].reshape(2*B, -1) 434 | 435 | sim = torch.cat([sim_pos, sim_neg], dim=-1) 436 | label = torch.zeros(2*B).long().to(device) 437 | loss = self.CELoss(sim, label) 438 | return loss 439 | 440 | 441 | def IDFALoss(prototype, feature, label, tao): 442 | # adopted from PCS 443 | # prototype: num_class, D 444 | # feature: batch_size, D 445 | proto_feature = prototype[label] 446 | sim_pos = torch.exp(torch.cosine_similarity(feature, proto_feature, dim=-1) / tao) 447 | sim_neg = torch.exp(torch.cosine_similarity(feature.unsqueeze(1), prototype.unsqueeze(0), dim=-1) / tao).sum(dim=-1) 448 | # sim_pos = torch.exp(torch.sum(feature * proto_feature, dim=-1) / tao) 449 | # sim_neg = torch.sum(torch.exp(torch.mm(feature, prototype.T) / tao), dim=-1) 450 | ratio = sim_pos / (sim_neg + 1e-6) 451 | proto_loss = -1 * torch.sum(torch.log(ratio)) / (ratio.size(0) + 1e-6) 452 | 453 | return proto_loss 454 | 455 | 456 | def sigmoid_function(input, k): 457 | return 1.0/(1 + math.exp(-input * k)) 458 | 459 | 460 | # calculate similarity between samples and center of clusters in PCS 461 | # this function is in torch utils in utils folder in pcs folder 462 | def contrastive_sim(instances, proto=None, tao=0.05): 463 | # prob_matrix [bs, dim] 464 | # proto_dim [nums, dim] 465 | if proto is None: 466 | proto = instances 467 | ins_ext = instances.unsqueeze(1).repeat(1, proto.size(0), 1) 468 | sim_matrix = torch.exp(torch.sum(ins_ext * proto, dim=-1) / tao) 469 | return sim_matrix 470 | 471 | 472 | def cosine_similarity(data, center): 473 | data = data.view(data.shape[0], -1) 474 | center = center.view(center.shape[0], -1) 475 | data = F.normalize(data) 476 | center = F.normalize(center) 477 | distance = data.mm(center.t()) 478 | 479 | return distance 480 | 481 | 482 | if __name__ == '__main__': 483 | data1 = torch.rand([24, 3, 1024]).cuda() 484 | data2 = torch.rand([24, 3, 512]).cuda() 485 | 
-------------------------------------------------------------------------------- /utils/trans_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.nn.modules.module import Module 4 | from torch.nn.parameter import Parameter 5 | import torch 6 | import itertools 7 | 8 | class _TransNorm(Module): 9 | 10 | """http: // ise.thss.tsinghua.edu.cn / ~mlong / doc / transferable - normalization - nips19.pdf""" 11 | 12 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 13 | super(_TransNorm, self).__init__() 14 | self.num_features = num_features 15 | self.eps = eps 16 | self.momentum = momentum 17 | self.affine = affine 18 | self.track_running_stats = track_running_stats 19 | if self.affine: 20 | self.weight = Parameter(torch.Tensor(num_features)) 21 | self.bias = Parameter(torch.Tensor(num_features)) 22 | else: 23 | self.register_parameter('weight', None) 24 | self.register_parameter('bias', None) 25 | 26 | if self.track_running_stats: 27 | self.register_buffer('running_mean_source', torch.zeros(num_features)) 28 | self.register_buffer('running_mean_target', torch.zeros(num_features)) 29 | self.register_buffer('running_var_source', torch.ones(num_features)) 30 | self.register_buffer('running_var_target', torch.ones(num_features)) 31 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 32 | else: 33 | self.register_parameter('running_mean_source', None) 34 | self.register_parameter('running_mean_target', None) 35 | self.register_parameter('running_var_source', None) 36 | self.register_parameter('running_var_target', None) 37 | self.reset_parameters() 38 | 39 | def reset_parameters(self): 40 | if self.track_running_stats: 41 | self.running_mean_source.zero_() 42 | self.running_mean_target.zero_() 43 | self.running_var_source.fill_(1) 44 | self.running_var_target.fill_(1) 45 | if 
self.affine: 46 | self.weight.data.uniform_() 47 | self.bias.data.zero_() 48 | 49 | def _check_input_dim(self, input): 50 | return NotImplemented 51 | 52 | def _load_from_state_dict_from_pretrained_model(self, state_dict, prefix, metadata, strict, missing_keys, unexpected_keys, error_msgs): 53 | r"""Copies parameters and buffers from :attr:`state_dict` into only 54 | this module, but not its descendants. This is called on every submodule 55 | in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this 56 | module in input :attr:`state_dict` is provided as :attr`metadata`. 57 | For state dicts without meta data, :attr`metadata` is empty. 58 | Subclasses can achieve class-specific backward compatible loading using 59 | the version number at `metadata.get("version", None)`. 60 | .. note:: 61 | :attr:`state_dict` is not the same object as the input 62 | :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So 63 | it can be modified. 64 | Arguments: 65 | state_dict (dict): a dict containing parameters and 66 | persistent buffers. 67 | prefix (str): the prefix for parameters and buffers used in this 68 | module 69 | metadata (dict): a dict containing the metadata for this moodule. 
70 | See 71 | strict (bool): whether to strictly enforce that the keys in 72 | :attr:`state_dict` with :attr:`prefix` match the names of 73 | parameters and buffers in this module 74 | missing_keys (list of str): if ``strict=False``, add missing keys to 75 | this list 76 | unexpected_keys (list of str): if ``strict=False``, add unexpected 77 | keys to this list 78 | error_msgs (list of str): error messages should be added to this 79 | list, and will be reported together in 80 | :meth:`~torch.nn.Module.load_state_dict` 81 | """ 82 | local_name_params = itertools.chain(self._parameters.items(), self._buffers.items()) 83 | local_state = {k: v.data for k, v in local_name_params if v is not None} 84 | 85 | for name, param in local_state.items(): 86 | key = prefix + name 87 | # if 'source' in key or 'target' in key: 88 | # key = key[:-7] 89 | # print(key) 90 | if key in state_dict: 91 | input_param = state_dict[key] 92 | if input_param.shape != param.shape: 93 | # local shape should match the one in checkpoint 94 | error_msgs.append('size mismatch for {}: copying a param of {} from checkpoint, ' 95 | 'where the shape is {} in current model.' 96 | .format(key, param.shape, input_param.shape)) 97 | continue 98 | if isinstance(input_param, Parameter): 99 | # backwards compatibility for serialized parameters 100 | input_param = input_param.data 101 | try: 102 | param.copy_(input_param) 103 | except Exception: 104 | error_msgs.append('While copying the parameter named "{}", ' 105 | 'whose dimensions in the model are {} and ' 106 | 'whose dimensions in the checkpoint are {}.' 107 | .format(key, param.size(), input_param.size())) 108 | elif strict: 109 | missing_keys.append(key) 110 | 111 | 112 | 113 | def forward(self, input): 114 | self._check_input_dim(input) 115 | if self.training : ## train mode 116 | 117 | ## 1. Domain Specific Mean and Variance. 
118 | batch_size = input.size()[0] // 2 119 | input_source = input[:batch_size] 120 | input_target = input[batch_size:] 121 | 122 | ## 2. Domain Sharing Gamma and Beta. 123 | z_source = F.batch_norm( 124 | input_source, self.running_mean_source, self.running_var_source, self.weight, self.bias, 125 | self.training or not self.track_running_stats, self.momentum, self.eps) 126 | 127 | z_target = F.batch_norm( 128 | input_target, self.running_mean_target, self.running_var_target, self.weight, self.bias, 129 | self.training or not self.track_running_stats, self.momentum, self.eps) 130 | z = torch.cat((z_source, z_target), dim=0) 131 | 132 | if input.dim() == 4: ## TransNorm2d 133 | input_source = input_source.permute(0,2,3,1).contiguous().view(-1,self.num_features) 134 | input_target = input_target.permute(0,2,3,1).contiguous().view(-1,self.num_features) 135 | 136 | cur_mean_source = torch.mean(input_source, dim=0) 137 | cur_var_source = torch.var(input_source,dim=0) 138 | cur_mean_target = torch.mean(input_target, dim=0) 139 | cur_var_target = torch.var(input_target, dim=0) 140 | 141 | ## 3. Domain Adaptive Alpha. 
142 | 143 | ### 3.1 Calculating Distance 144 | dis = torch.abs(cur_mean_source / torch.sqrt(cur_var_source + self.eps) - 145 | cur_mean_target / torch.sqrt(cur_var_target + self.eps)) 146 | 147 | ### 3.2 Generating Probability 148 | prob = 1.0 / (1.0 + dis) 149 | alpha = self.num_features * prob / sum(prob) 150 | 151 | if input.dim() == 2: 152 | alpha = alpha.view(1, self.num_features) 153 | elif input.dim() == 4: 154 | alpha = alpha.view(1, self.num_features, 1, 1) 155 | 156 | ## 3.3 Residual Connection 157 | return z * (1 + alpha.detach()) 158 | 159 | 160 | else: ##test mode 161 | z = F.batch_norm( 162 | input, self.running_mean_target, self.running_var_target, self.weight, self.bias, 163 | self.training or not self.track_running_stats, self.momentum, self.eps) 164 | 165 | dis = torch.abs(self.running_mean_source / torch.sqrt(self.running_var_source + self.eps) 166 | - self.running_mean_target / torch.sqrt(self.running_var_target + self.eps)) 167 | prob = 1.0 / (1.0 + dis) 168 | alpha = self.num_features * prob / sum(prob) 169 | 170 | if input.dim() == 2: 171 | alpha = alpha.view(1, self.num_features) 172 | elif input.dim() == 4: 173 | alpha = alpha.view(1, self.num_features, 1, 1) 174 | return z * (1 + alpha.detach()) 175 | 176 | def extra_repr(self): 177 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 178 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 179 | 180 | def _load_from_state_dict(self, state_dict, prefix, metadata, strict, 181 | missing_keys, unexpected_keys, error_msgs): 182 | version = metadata.get('version', None) 183 | if (version is None or version < 2) and self.track_running_stats: 184 | # at version 2: added num_batches_tracked buffer 185 | # this should have a default value of 0 186 | num_batches_tracked_key = prefix + 'num_batches_tracked' 187 | if num_batches_tracked_key not in state_dict: 188 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 189 | 190 | 
self._load_from_state_dict_from_pretrained_model( 191 | state_dict, prefix, metadata, strict, 192 | missing_keys, unexpected_keys, error_msgs) 193 | 194 | 195 | class TransNorm1d(_TransNorm): 196 | r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D 197 | inputs with optional additional channel dimension) as described in the paper 198 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 199 | .. math:: 200 | y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 201 | The mean and standard-deviation are calculated per-dimension over 202 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 203 | of size `C` (where `C` is the input size). 204 | By default, during training this layer keeps running estimates of its 205 | computed mean and variance, which are then used for normalization during 206 | evaluation. The running estimates are kept with a default :attr:`momentum` 207 | of 0.1. 208 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 209 | keep running estimates, and batch statistics are instead used during 210 | evaluation time as well. 211 | .. note:: 212 | This :attr:`momentum` argument is different from one used in optimizer 213 | classes and the conventional notion of momentum. Mathematically, the 214 | update rule for running statistics here is 215 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 216 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 217 | new observed value. 218 | Because the Batch Normalization is done over the `C` dimension, computing statistics 219 | on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. 
220 | Args: 221 | num_features: :math:`C` from an expected input of size 222 | :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` 223 | eps: a value added to the denominator for numerical stability. 224 | Default: 1e-5 225 | momentum: the value used for the running_mean and running_var 226 | computation. Can be set to ``None`` for cumulative moving average 227 | (i.e. simple average). Default: 0.1 228 | affine: a boolean value that when set to ``True``, this module has 229 | learnable affine parameters. Default: ``True`` 230 | track_running_stats: a boolean value that when set to ``True``, this 231 | module tracks the running mean and variance, and when set to ``False``, 232 | this module does not track such statistics and always uses batch 233 | statistics in both training and eval modes. Default: ``True`` 234 | Shape: 235 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 236 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 237 | Examples:: 238 | >>> # With Learnable Parameters 239 | >>> m = nn.BatchNorm1d(100) 240 | >>> # Without Learnable Parameters 241 | >>> m = nn.BatchNorm1d(100, affine=False) 242 | >>> input = torch.randn(20, 100) 243 | >>> output = m(input) 244 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 245 | https://arxiv.org/abs/1502.03167 246 | """ 247 | 248 | def _check_input_dim(self, input): 249 | if input.dim() != 2 and input.dim() != 3: 250 | raise ValueError('expected 2D or 3D input (got {}D input)' 251 | .format(input.dim())) 252 | 253 | 254 | class TransNorm2d(_TransNorm): 255 | r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs 256 | with additional channel dimension) as described in the paper 257 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 258 | .. 
math:: 259 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 260 | The mean and standard-deviation are calculated per-dimension over 261 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 262 | of size `C` (where `C` is the input size). 263 | By default, during training this layer keeps running estimates of its 264 | computed mean and variance, which are then used for normalization during 265 | evaluation. The running estimates are kept with a default :attr:`momentum` 266 | of 0.1. 267 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 268 | keep running estimates, and batch statistics are instead used during 269 | evaluation time as well. 270 | .. note:: 271 | This :attr:`momentum` argument is different from one used in optimizer 272 | classes and the conventional notion of momentum. Mathematically, the 273 | update rule for running statistics here is 274 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 275 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 276 | new observed value. 277 | Because the Batch Normalization is done over the `C` dimension, computing statistics 278 | on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. 279 | Args: 280 | num_features: :math:`C` from an expected input of size 281 | :math:`(N, C, H, W)` 282 | eps: a value added to the denominator for numerical stability. 283 | Default: 1e-5 284 | momentum: the value used for the running_mean and running_var 285 | computation. Can be set to ``None`` for cumulative moving average 286 | (i.e. simple average). Default: 0.1 287 | affine: a boolean value that when set to ``True``, this module has 288 | learnable affine parameters. 
Default: ``True`` 289 | track_running_stats: a boolean value that when set to ``True``, this 290 | module tracks the running mean and variance, and when set to ``False``, 291 | this module does not track such statistics and always uses batch 292 | statistics in both training and eval modes. Default: ``True`` 293 | Shape: 294 | - Input: :math:`(N, C, H, W)` 295 | - Output: :math:`(N, C, H, W)` (same shape as input) 296 | Examples:: 297 | >>> # With Learnable Parameters 298 | >>> m = nn.BatchNorm2d(100) 299 | >>> # Without Learnable Parameters 300 | >>> m = nn.BatchNorm2d(100, affine=False) 301 | >>> input = torch.randn(20, 100, 35, 45) 302 | >>> output = m(input) 303 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 304 | https://arxiv.org/abs/1502.03167 305 | """ 306 | 307 | def _check_input_dim(self, input): 308 | if input.dim() != 4: 309 | raise ValueError('expected 4D input (got {}D input)' 310 | .format(input.dim())) 311 | 312 | 313 | class TransNorm3d(_TransNorm): 314 | r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs 315 | with additional channel dimension) as described in the paper 316 | `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ . 317 | .. math:: 318 | y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 319 | The mean and standard-deviation are calculated per-dimension over 320 | the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors 321 | of size `C` (where `C` is the input size). 322 | By default, during training this layer keeps running estimates of its 323 | computed mean and variance, which are then used for normalization during 324 | evaluation. The running estimates are kept with a default :attr:`momentum` 325 | of 0.1. 
326 | If :attr:`track_running_stats` is set to ``False``, this layer then does not 327 | keep running estimates, and batch statistics are instead used during 328 | evaluation time as well. 329 | .. note:: 330 | This :attr:`momentum` argument is different from one used in optimizer 331 | classes and the conventional notion of momentum. Mathematically, the 332 | update rule for running statistics here is 333 | :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momemtum} \times x_t`, 334 | where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the 335 | new observed value. 336 | Because the Batch Normalization is done over the `C` dimension, computing statistics 337 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization 338 | or Spatio-temporal Batch Normalization. 339 | Args: 340 | num_features: :math:`C` from an expected input of size 341 | :math:`(N, C, D, H, W)` 342 | eps: a value added to the denominator for numerical stability. 343 | Default: 1e-5 344 | momentum: the value used for the running_mean and running_var 345 | computation. Can be set to ``None`` for cumulative moving average 346 | (i.e. simple average). Default: 0.1 347 | affine: a boolean value that when set to ``True``, this module has 348 | learnable affine parameters. Default: ``True`` 349 | track_running_stats: a boolean value that when set to ``True``, this 350 | module tracks the running mean and variance, and when set to ``False``, 351 | this module does not track such statistics and always uses batch 352 | statistics in both training and eval modes. 
Default: ``True`` 353 | Shape: 354 | - Input: :math:`(N, C, D, H, W)` 355 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 356 | Examples:: 357 | >>> # With Learnable Parameters 358 | >>> m = nn.BatchNorm3d(100) 359 | >>> # Without Learnable Parameters 360 | >>> m = nn.BatchNorm3d(100, affine=False) 361 | >>> input = torch.randn(20, 100, 35, 45, 10) 362 | >>> output = m(input) 363 | .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`: 364 | https://arxiv.org/abs/1502.03167 365 | """ 366 | 367 | def _check_input_dim(self, input): 368 | if input.dim() != 5: 369 | raise ValueError('expected 5D input (got {}D input)' 370 | .format(input.dim())) -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | import os 7 | import glob 8 | import open3d as o3d 9 | import pdb 10 | # from models.pointnet2_utils import farthest_point_sample, query_ball_point, index_points 11 | 12 | 13 | def reshape_num(data, num_point): 14 | """ 15 | 16 | :param data: N * 3 17 | :param num_point: 18 | :return: 19 | """ 20 | len_data = data.shape[0] 21 | if len_data < num_point: 22 | num_cat = math.ceil(num_point / len_data) 23 | data = data.repeat(num_cat, 0)[:num_point, :] 24 | else: 25 | data = data[:num_point, :] 26 | 27 | return data 28 | 29 | 30 | def normalize_data(batch_data): 31 | """ Normalize the batch data, use coordinates of the block centered at origin, 32 | Input: 33 | BxNxC array 34 | Output: 35 | BxNxC array 36 | """ 37 | B, N, C = batch_data.shape 38 | normal_data = np.zeros((B, N, C)) 39 | for b in range(B): 40 | pc = batch_data[b] 41 | centroid = np.mean(pc, axis=0) 42 | pc = pc - centroid 43 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 44 | pc = pc / m 45 | 
normal_data[b] = pc 46 | return normal_data 47 | 48 | 49 | def shuffle_data(data, labels): 50 | """ Shuffle data and labels. 51 | Input: 52 | data: B,N,... numpy array 53 | label: B,... numpy array 54 | Return: 55 | shuffled data, label and shuffle indices 56 | """ 57 | idx = np.arange(len(labels)) 58 | np.random.shuffle(idx) 59 | return data[idx, ...], labels[idx], idx 60 | 61 | 62 | def shuffle_points(batch_data): 63 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 64 | Use the same shuffling idx for the entire batch. 65 | Input: 66 | BxNxC array 67 | Output: 68 | BxNxC array 69 | """ 70 | idx = np.arange(batch_data.shape[1]) 71 | np.random.shuffle(idx) 72 | return batch_data[:, idx, :] 73 | 74 | 75 | def rotate_point_cloud(batch_data): 76 | """ Randomly rotate the point clouds to augument the dataset 77 | rotation is per shape based along up direction 78 | Input: 79 | BxNx3 array, original batch of point clouds 80 | Return: 81 | BxNx3 array, rotated batch of point clouds 82 | """ 83 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 84 | for k in range(batch_data.shape[0]): 85 | rotation_angle = np.random.uniform() * 2 * np.pi 86 | cosval = np.cos(rotation_angle) 87 | sinval = np.sin(rotation_angle) 88 | rotation_matrix = np.array([[cosval, 0, sinval], 89 | [0, 1, 0], 90 | [-sinval, 0, cosval]]) 91 | shape_pc = batch_data[k, ...] 92 | rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 93 | return rotated_data 94 | 95 | 96 | def rotate_point_cloud_z(batch_data): 97 | """ Randomly rotate the point clouds to augument the dataset 98 | rotation is per shape based along up direction 99 | Input: 100 | BxNx3 array, original batch of point clouds 101 | Return: 102 | BxNx3 array, rotated batch of point clouds 103 | """ 104 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 105 | for k in range(batch_data.shape[0]): 106 | rotation_angle = np.random.uniform() * 2 * np.pi 107 | cosval = np.cos(rotation_angle) 108 | sinval = np.sin(rotation_angle) 109 | rotation_matrix = np.array([[cosval, sinval, 0], 110 | [-sinval, cosval, 0], 111 | [0, 0, 1]]) 112 | shape_pc = batch_data[k, ...] 113 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 114 | return rotated_data 115 | 116 | 117 | def rotate_point_cloud_with_normal(batch_xyz_normal): 118 | """ Randomly rotate XYZ, normal point cloud. 119 | Input: 120 | batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal 121 | Output: 122 | B,N,6, rotated XYZ, normal point cloud 123 | """ 124 | for k in range(batch_xyz_normal.shape[0]): 125 | rotation_angle = np.random.uniform() * 2 * np.pi 126 | cosval = np.cos(rotation_angle) 127 | sinval = np.sin(rotation_angle) 128 | rotation_matrix = np.array([[cosval, 0, sinval], 129 | [0, 1, 0], 130 | [-sinval, 0, cosval]]) 131 | shape_pc = batch_xyz_normal[k, :, 0:3] 132 | shape_normal = batch_xyz_normal[k, :, 3:6] 133 | batch_xyz_normal[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 134 | batch_xyz_normal[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 135 | return batch_xyz_normal 136 | 137 | 138 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 139 | """ Randomly perturb the point clouds by small rotations 140 | Input: 141 | BxNx6 array, original batch of point clouds and point normals 142 | 
Return: 143 | BxNx3 array, rotated batch of point clouds 144 | """ 145 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 146 | for k in range(batch_data.shape[0]): 147 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 148 | Rx = np.array([[1, 0, 0], 149 | [0, np.cos(angles[0]), -np.sin(angles[0])], 150 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 151 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 152 | [0, 1, 0], 153 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 154 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 155 | [np.sin(angles[2]), np.cos(angles[2]), 0], 156 | [0, 0, 1]]) 157 | R = np.dot(Rz, np.dot(Ry, Rx)) 158 | shape_pc = batch_data[k, :, 0:3] 159 | shape_normal = batch_data[k, :, 3:6] 160 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 161 | rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 162 | return rotated_data 163 | 164 | 165 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 166 | """ Rotate the point cloud along up direction with certain angle. 167 | Input: 168 | BxNx3 array, original batch of point clouds 169 | Return: 170 | BxNx3 array, rotated batch of point clouds 171 | """ 172 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 173 | for k in range(batch_data.shape[0]): 174 | # rotation_angle = np.random.uniform() * 2 * np.pi 175 | cosval = np.cos(rotation_angle) 176 | sinval = np.sin(rotation_angle) 177 | rotation_matrix = np.array([[cosval, 0, sinval], 178 | [0, 1, 0], 179 | [-sinval, 0, cosval]]) 180 | shape_pc = batch_data[k, :, 0:3] 181 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 182 | return rotated_data 183 | 184 | 185 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 186 | """ Rotate the point cloud along up direction with certain angle. 
187 | Input: 188 | BxNx6 array, original batch of point clouds with normal 189 | scalar, angle of rotation 190 | Return: 191 | BxNx6 array, rotated batch of point clouds iwth normal 192 | """ 193 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 194 | for k in range(batch_data.shape[0]): 195 | # rotation_angle = np.random.uniform() * 2 * np.pi 196 | cosval = np.cos(rotation_angle) 197 | sinval = np.sin(rotation_angle) 198 | rotation_matrix = np.array([[cosval, 0, sinval], 199 | [0, 1, 0], 200 | [-sinval, 0, cosval]]) 201 | shape_pc = batch_data[k, :, 0:3] 202 | shape_normal = batch_data[k, :, 3:6] 203 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 204 | rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 205 | return rotated_data 206 | 207 | 208 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 209 | """ Randomly perturb the point clouds by small rotations 210 | Input: 211 | BxNx3 array, original batch of point clouds 212 | Return: 213 | BxNx3 array, rotated batch of point clouds 214 | """ 215 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 216 | for k in range(batch_data.shape[0]): 217 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 218 | Rx = np.array([[1, 0, 0], 219 | [0, np.cos(angles[0]), -np.sin(angles[0])], 220 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 221 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 222 | [0, 1, 0], 223 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 224 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 225 | [np.sin(angles[2]), np.cos(angles[2]), 0], 226 | [0, 0, 1]]) 227 | R = np.dot(Rz, np.dot(Ry, Rx)) 228 | shape_pc = batch_data[k, ...] 229 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 230 | return rotated_data 231 | 232 | 233 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 234 | """ Randomly jitter points. 
jittering is per point. 235 | Input: 236 | BxNx3 array, original batch of point clouds 237 | Return: 238 | BxNx3 array, jittered batch of point clouds 239 | """ 240 | B, N, C = batch_data.shape 241 | assert (clip > 0) 242 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip) 243 | jittered_data += batch_data 244 | return jittered_data 245 | 246 | 247 | def shift_point_cloud(batch_data, shift_range=0.1): 248 | """ Randomly shift point cloud. Shift is per point cloud. 249 | Input: 250 | BxNx3 array, original batch of point clouds 251 | Return: 252 | BxNx3 array, shifted batch of point clouds 253 | """ 254 | B, N, C = batch_data.shape 255 | shifts = np.random.uniform(-shift_range, shift_range, (B, 3)) 256 | for batch_index in range(B): 257 | batch_data[batch_index, :, :] += shifts[batch_index, :] 258 | return batch_data 259 | 260 | 261 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 262 | """ Randomly scale the point cloud. Scale is per point cloud. 
263 | Input: 264 | BxNx3 array, original batch of point clouds 265 | Return: 266 | BxNx3 array, scaled batch of point clouds 267 | """ 268 | B, N, C = batch_data.shape 269 | scales = np.random.uniform(scale_low, scale_high, B) 270 | for batch_index in range(B): 271 | batch_data[batch_index, :, :] *= scales[batch_index] 272 | return batch_data 273 | 274 | 275 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 276 | """ batch_pc: BxNx3 """ 277 | for b in range(batch_pc.shape[0]): 278 | dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875 279 | drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0] 280 | if len(drop_idx) > 0: 281 | batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point 282 | return batch_pc 283 | 284 | 285 | def crop_point(data): 286 | """ 287 | # random crop: along one axis 288 | :param data: B, N, C 0 - 1 289 | :return: 290 | """ 291 | B, N, C = data.shape 292 | data_new = [] 293 | for ii in range(B): 294 | len_x = data[ii, :, 0].max() - data[ii, :, 0].min() 295 | threshold = data[ii, :, 0].min() + len_x * (np.random.random() * 0.2 + 0.4) # 0.4 - 0.6 296 | data1_indicate = data[ii, :, 0] < threshold 297 | data2_indicate = data[ii, :, 0] >= threshold 298 | 299 | num_data1 = data1_indicate.sum() 300 | num_data2 = data2_indicate.sum() 301 | 302 | data_indicate = data1_indicate if num_data1 > num_data2 else data2_indicate 303 | data_select = data[ii, data_indicate, :] 304 | 305 | data_select = reshape_num(data_select, num_point=1024) 306 | 307 | if len(data_new) == 0: 308 | data_new = np.expand_dims(data_select, axis=0) 309 | else: 310 | data_new = np.concatenate([data_new, np.expand_dims(data_select, axis=0)], axis=0) 311 | 312 | return data_new 313 | 314 | 315 | def rotate_point_cloud_3d(pc): 316 | rotation_angle = np.random.rand(3) * 2 * np.pi 317 | cosval = np.cos(rotation_angle) 318 | sinval = np.sin(rotation_angle) 319 | rotation_matrix_1 = np.array([[cosval[0], 0, sinval[0]], 320 | 
[0, 1, 0], 321 | [-sinval[0], 0, cosval[0]]]) 322 | rotation_matrix_2 = np.array([[1, 0, 0], 323 | [0, cosval[1], -sinval[1]], 324 | [0, sinval[1], cosval[1]]]) 325 | rotation_matrix_3 = np.array([[cosval[2], -sinval[2], 0], 326 | [sinval[2], cosval[2], 0], 327 | [0, 0, 1]]) 328 | rotation_matrix = np.matmul(np.matmul(rotation_matrix_1, rotation_matrix_2), rotation_matrix_3) 329 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix) 330 | 331 | return rotated_data 332 | 333 | 334 | def density(pc, num_point=1024): 335 | # N, C 336 | try: 337 | rand_points = np.random.uniform(-1, 1, 40000) 338 | x1 = rand_points[:20000] 339 | x2 = rand_points[20000:] 340 | power_sum = x1 ** 2 + x2 ** 2 341 | p_filter = power_sum < 1 342 | power_sum = power_sum[p_filter] 343 | sqrt_sum = np.sqrt(1 - power_sum) 344 | x1 = x1[p_filter] 345 | x2 = x2[p_filter] 346 | x = (2 * x1 * sqrt_sum).reshape(-1, 1) 347 | y = (2 * x2 * sqrt_sum).reshape(-1, 1) 348 | z = (1 - 2 * power_sum).reshape(-1, 1) 349 | density_points = np.hstack([x, y, z]) 350 | v_point = density_points[np.random.choice(density_points.shape[0])] 351 | 352 | gate = np.random.uniform(low=1.3, high=1.6) 353 | dist = np.sqrt((v_point ** 2).sum()) 354 | max_dist = dist + 1 355 | min_dist = dist - 1 356 | dist = np.linalg.norm(pc - v_point.reshape(1, 3), axis=1) 357 | dist = (dist - min_dist) / (max_dist - min_dist) 358 | r_list = np.random.uniform(0.75, 1, pc.shape[0]) 359 | tmp_pc = pc[dist * gate < (r_list)] 360 | 361 | num_pad = np.ceil(num_point / tmp_pc.shape[0]).astype(np.long) 362 | pc = np.tile(tmp_pc, (num_pad, 1))[:num_point] 363 | except: 364 | pc = pc 365 | 366 | return pc 367 | 368 | 369 | def drop_hole(pc, num_point=1024): 370 | # N, C 371 | try: 372 | p = np.random.uniform(low=0.25, high=0.45) 373 | random_point = np.random.randint(0, pc.shape[0]) 374 | index = np.linalg.norm(pc - pc[random_point].reshape(1, 3), axis=1).argsort() 375 | 376 | tmp_pc = pc[index[int(pc.shape[0] * p):]] 377 | num_pad = 
np.ceil(num_point / tmp_pc.shape[0]).astype(np.long) 378 | pc = np.tile(tmp_pc, (num_pad, 1))[:num_point] 379 | except: 380 | pc = pc 381 | 382 | return pc 383 | 384 | 385 | def p_scan(pc, pixel_size=0.022, num_point=1024): 386 | # N, C 387 | try: 388 | pixel = int(2 / pixel_size) 389 | rotated_pc = rotate_point_cloud_3d(pc) 390 | pc_compress = (rotated_pc[:, 2] + 1) / 2 * pixel * pixel + (rotated_pc[:, 1] + 1) / 2 * pixel 391 | points_list = [None for i in range((pixel + 5) * (pixel + 5))] 392 | pc_compress = pc_compress.astype(np.int) 393 | for index, point in enumerate(rotated_pc): 394 | compress_index = pc_compress[index] 395 | if compress_index > len(points_list): 396 | print('out of index:', compress_index, len(points_list), point, pc[index], (pc[index] ** 2).sum(), 397 | (point ** 2).sum()) 398 | if points_list[compress_index] is None: 399 | points_list[compress_index] = index 400 | elif point[0] > rotated_pc[points_list[compress_index]][0]: 401 | points_list[compress_index] = index 402 | points_list = list(filter(lambda x: x is not None, points_list)) 403 | points_list = pc[points_list] 404 | 405 | num_pad = np.ceil(num_point / points_list.shape[0]).astype(np.long) 406 | points_list = np.tile(points_list, (num_pad, 1))[:num_point] 407 | except: 408 | points_list = pc 409 | 410 | return points_list 411 | 412 | 413 | def add_noise(data, noise=0.005): 414 | B, N, C = data.shape 415 | noise = (noise ** 0.5) * np.random.randn(B, N, C) 416 | data = data + noise 417 | return data 418 | 419 | 420 | def weak_aug(data): 421 | # data: B, N, C 422 | device = data.device 423 | data = data.cpu().numpy() 424 | # data = random_point_dropout(data) 425 | data[:, :, 0:3] = random_scale_point_cloud(data[:, :, 0:3]) 426 | data[:, :, 0:3] = shift_point_cloud(data[:, :, 0:3]) 427 | return torch.Tensor(data).to(device) 428 | 429 | 430 | def strong_aug(data, aug_type='2'): 431 | # data: B, N, C 432 | # v1 433 | device = data.device 434 | data = data.cpu().numpy() 435 | if aug_type 
== '1': 436 | data = normalize_data(data) 437 | # data = crop_point(data) 438 | data = density(data) 439 | # data = random_point_dropout(data) 440 | data = random_scale_point_cloud(data) 441 | data = shift_point_cloud(data) 442 | data = shuffle_points(data) 443 | data = rotate_point_cloud(data) 444 | data = add_noise(data) 445 | 446 | # v2 447 | else: 448 | data = normalize_data(data) 449 | # data = crop_point(data) 450 | data = drop_hole(data) 451 | # data = random_point_dropout(data) 452 | data = random_scale_point_cloud(data) 453 | data = shift_point_cloud(data) 454 | data = shuffle_points(data) 455 | data = rotate_point_cloud(data) 456 | data = add_noise(data) 457 | 458 | # fig = plt.figure() 459 | # ax = Axes3D(fig) 460 | # ax.scatter(data[0, :, 0], data[0, :, 1], data[0, :, 2]) 461 | 462 | return torch.Tensor(data).to(device) # B, N, C 463 | 464 | 465 | # def mask_data(xyz, mask_p=0.5, npoint=64, radius=0.1, nsample=16): 466 | # if xyz.shape[1] < xyz.shape[2]: 467 | # xyz = xyz.transpose(2, 1) 468 | # B, N, C = xyz.shape 469 | # fps_idx = farthest_point_sample(xyz, npoint) # 选出128个关键点 470 | # select_p = np.random.random_integers(0, npoint, [B, npoint]) 471 | # mask_idx = torch.from_numpy(select_p > (mask_p * npoint)) # 以一定的概率随机选关键点,大于p的是留下来的 472 | # mask_idx = fps_idx * mask_idx 473 | # select_xyz = index_points(xyz, mask_idx) 474 | # idx = query_ball_point(radius, nsample, xyz, select_xyz) 475 | # grouped_xyz = index_points(xyz, idx).reshape(B, -1, C) 476 | 477 | # fig = plt.figure() 478 | # ax = Axes3D(fig) 479 | # ax.scatter(grouped_xyz[3, :, 0], grouped_xyz[3, :, 1], grouped_xyz[3, :, 2]) 480 | 481 | # return grouped_xyz 482 | 483 | 484 | if __name__ == '__main__': 485 | dataroot = '../data/' 486 | DATA_DIR = os.path.join(dataroot, "PointDA_data", "shapenet_norm_curv_angle") 487 | npy_list = sorted(glob.glob(os.path.join(DATA_DIR, '*', 'train', '*.npy'))) 488 | pc_list = [] 489 | for _dir in npy_list: 490 | pc_list.append(_dir) 491 | 492 | data = 
np.load(pc_list[18])[:, :3].astype(np.float32) 493 | 494 | data_raw = data 495 | pcd = o3d.geometry.PointCloud() 496 | pcd.points = o3d.utility.Vector3dVector(data_raw) 497 | o3d.io.write_point_cloud('raw.ply', pcd) 498 | 499 | density_data = density(data) 500 | pcd = o3d.geometry.PointCloud() 501 | pcd.points = o3d.utility.Vector3dVector(density_data) 502 | o3d.io.write_point_cloud('density.ply', pcd) 503 | 504 | drop_data = drop_hole(data) 505 | pcd = o3d.geometry.PointCloud() 506 | pcd.points = o3d.utility.Vector3dVector(drop_data) 507 | o3d.io.write_point_cloud('drop.ply', pcd) 508 | 509 | scan_data = p_scan(data) 510 | pcd = o3d.geometry.PointCloud() 511 | pcd.points = o3d.utility.Vector3dVector(scan_data) 512 | o3d.io.write_point_cloud('scan.ply', pcd) 513 | 514 | # pdb.set_trace() 515 | print(data.shape) 516 | 517 | -------------------------------------------------------------------------------- /SPST_finetune_PCFEA_cls.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import math 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils.data import Dataset 8 | from utils.pc_utils import random_rotate_one_axis 9 | from torch.optim.lr_scheduler import CosineAnnealingLR 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | from torch.utils.data import DataLoader 12 | import sklearn.metrics as metrics 13 | import argparse 14 | import copy 15 | import utils.log_SPST 16 | from data.dataloader_GraspNetPC import GraspNetRealPointClouds, GraspNetSynthetictPointClouds 17 | from data.dataloader_PointDA_initial import ScanNet, ModelNet, ShapeNet 18 | from models.model import linear_DGCNN_model 19 | import pdb 20 | 21 | MAX_LOSS = 9 * (10 ** 9) 22 | 23 | 24 | def str2bool(v): 25 | """ 26 | Input: 27 | v - string 28 | output: 29 | True/False 30 | """ 31 | if isinstance(v, bool): 32 | return v 33 | if v.lower() in ('yes', 'true', 't', 'y', 
'1'): 34 | return True 35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 36 | return False 37 | else: 38 | raise argparse.ArgumentTypeError('Boolean value expected.') 39 | 40 | 41 | # ================== 42 | # Argparse 43 | # ================== 44 | parser = argparse.ArgumentParser(description='DA on Point Clouds') 45 | parser.add_argument('--dataroot', type=str, default='../data/', metavar='N', help='data path') 46 | parser.add_argument('--out_path', type=str, default='./experiments/', help='log folder path') 47 | parser.add_argument('--num_workers', type=int, default=2, help='number of workers in dataloader') 48 | parser.add_argument('--exp_name', type=str, default='test2', help='Name of the experiment') 49 | 50 | # model 51 | parser.add_argument('--model', type=str, default='DGCNN', choices=['PointNet++', 'DGCNN'], help='Model to use') 52 | parser.add_argument('--num_class', type=int, default=10, help='number of classes per dataset') 53 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate') 54 | parser.add_argument('--use_avg_pool', type=str2bool, default=False, help='Using average pooling & max pooling or max pooling only') 55 | 56 | # training details 57 | parser.add_argument('--epochs', type=int, default=10, help='number of episode per iteration to train') 58 | parser.add_argument('--num_iterations', type=int, default=10, help='number of SPST iterations') 59 | parser.add_argument('--src_dataset', type=str, default='Syn', choices=['Syn', 'Kin', 'RS', 'modelnet', 'shapenet', 'scannet']) 60 | parser.add_argument('--trgt_dataset', type=str, default='Kin', choices=['Kin', 'RS', 'modelnet', 'shapenet', 'scannet']) 61 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 62 | parser.add_argument('--gpus', type=lambda s: [int(item.strip()) for item in s.split(',')], default='1', 63 | help='comma delimited of gpu ids to use. 
Use "-1" for cpu usage') 64 | parser.add_argument('--batch_size', type=int, default=12, metavar='batch_size', help='Size of train batch per domain') 65 | parser.add_argument('--test_batch_size', type=int, default=12, metavar='batch_size', help='Size of test batch per domain') 66 | 67 | # method 68 | parser.add_argument('--base_threshold', type=float, default=0.8, help="base threshold to select target samples") 69 | parser.add_argument('--use_SPL', type=str2bool, default=True, help='Using self paced self train or not') 70 | parser.add_argument('--use_aug', type=str2bool, default=False, help='Using target augmentation or not (maybe can increase generalization)') 71 | parser.add_argument('--save_iter_model_by_val', type=str2bool, default=True, help='Saving model by val or test in each iteration') 72 | parser.add_argument('--mode_checkpoint', type=str, default='val', help='Using saved best model according to val or test') 73 | 74 | # optimizer 75 | parser.add_argument('--optimizer', type=str, default='ADAM', choices=['ADAM', 'SGD']) 76 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 77 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') 78 | parser.add_argument('--wd', type=float, default=5e-5, help='weight decay') 79 | 80 | args = parser.parse_args() 81 | 82 | # ================== 83 | # init 84 | # ================== 85 | io = utils.log_SPST.IOStream(args) 86 | io.cprint(str(args)) 87 | 88 | random.seed(1) 89 | # np.random.seed(1) # to get the same point choice in ModelNet and ScanNet leave it fixed 90 | torch.manual_seed(args.seed) 91 | args.cuda = (args.gpus[0] >= 0) and torch.cuda.is_available() 92 | device = torch.device("cuda:" + str(args.gpus[0]) if args.cuda else "cpu") 93 | if args.cuda: 94 | io.cprint('Using GPUs ' + str(args.gpus) + ',' + ' from ' + 95 | str(torch.cuda.device_count()) + ' devices available') 96 | torch.cuda.manual_seed_all(args.seed) 97 | torch.backends.cudnn.enabled = False 98 | 
torch.backends.cudnn.benchmark = False 99 | torch.backends.cudnn.deterministic = True 100 | else: 101 | io.cprint('Using CPU') 102 | 103 | 104 | # ================== 105 | # Utils 106 | # ================== 107 | def split_set(dataset, domain, set_type="source"): 108 | """ 109 | Input: 110 | dataset 111 | domain - modelnet/shapenet/scannet 112 | type_set - source/target 113 | output: 114 | train_sampler, valid_sampler 115 | """ 116 | train_indices = dataset.train_ind 117 | val_indices = dataset.val_ind 118 | unique, counts = np.unique(dataset.label[train_indices], return_counts=True) 119 | io.cprint("Occurrences count of classes in " + set_type + " " + domain + 120 | " train part: " + str(dict(zip(unique, counts)))) 121 | unique, counts = np.unique(dataset.label[val_indices], return_counts=True) 122 | io.cprint("Occurrences count of classes in " + set_type + " " + domain + 123 | " validation part: " + str(dict(zip(unique, counts)))) 124 | # Creating PT data samplers and loaders: 125 | train_sampler = SubsetRandomSampler(train_indices) 126 | valid_sampler = SubsetRandomSampler(val_indices) 127 | return train_sampler, valid_sampler 128 | 129 | 130 | class DataLoad(Dataset): 131 | def __init__(self, io, data): 132 | self.pc, self.aug_pc, self.label, self.real_label, self.num_data = data 133 | self.num_examples = len(self.pc) 134 | 135 | io.cprint("number of selected examples in train set: " + str(len(self.pc))) 136 | unique, counts = np.unique(self.label, return_counts=True) 137 | io.cprint("Occurrences count of classes in train set: " + str(dict(zip(unique, counts)))) 138 | 139 | def __getitem__(self, item): 140 | pc = np.copy(self.pc[item]) 141 | aug_pc = np.copy(self.aug_pc[item]) 142 | label = np.copy(self.label[item]) 143 | real_label = np.copy(self.real_label[item]) 144 | return (pc, aug_pc, label, real_label) 145 | 146 | def __len__(self): 147 | return len(self.pc) 148 | 149 | 150 | def select_sample_by_conf(device, threshold, data_loader, model=None): 151 | 
pc_list = [] 152 | aug_pc_list = [] 153 | label_list = [] 154 | real_label_list = [] 155 | sfm = nn.Softmax(dim=1) 156 | 157 | total_number = 0 158 | 159 | with torch.no_grad(): 160 | model.eval() 161 | for data_all in data_loader: 162 | data = data_all[1] 163 | labels = data_all[2] 164 | aug_data = data_all[3] 165 | data, labels, aug_data = data.to(device), labels.long().to(device), aug_data.to(device) 166 | 167 | if data.shape[2] < data.shape[1]: 168 | data = data.permute(0, 2, 1) # data: B, C, N 169 | if aug_data.shape[2] < aug_data.shape[1]: 170 | aug_data = aug_data.permute(0, 2, 1) # data: B, C, N 171 | 172 | batch_size = data.size()[0] 173 | total_number += batch_size 174 | 175 | logits = model(data) 176 | cls_pred = logits["pred"] 177 | cls_pred_sfm = sfm(cls_pred) 178 | cls_pred_conf, cls_pred_label = torch.max(cls_pred_sfm, 1) # 2 * b 179 | 180 | index = 0 181 | for ii in range(batch_size): 182 | if cls_pred_conf[ii] > threshold: 183 | # pdb.set_trace() 184 | if len(pc_list) is 0: 185 | pc_list = data[index].detach().cpu().unsqueeze(0) 186 | label_list = cls_pred_label[index].detach().cpu().unsqueeze(0) 187 | real_label_list = labels[index].detach().cpu().unsqueeze(0) 188 | aug_pc_list = aug_data[index].detach().cpu().unsqueeze(0) 189 | else: 190 | pc_list = torch.cat((pc_list, data[index].detach().cpu().unsqueeze(0)), dim=0) 191 | label_list = torch.cat((label_list, cls_pred_label[index].detach().cpu().unsqueeze(0)), dim=0) 192 | real_label_list = torch.cat((real_label_list, labels[index].detach().cpu().unsqueeze(0)), dim=0) 193 | aug_pc_list = torch.cat((aug_pc_list, aug_data[index].detach().cpu().unsqueeze(0)), dim=0) 194 | 195 | index += 1 196 | 197 | return pc_list, aug_pc_list, label_list, real_label_list, total_number 198 | 199 | 200 | # ================== 201 | # Data loader 202 | # ================== 203 | 204 | src_dataset = args.src_dataset 205 | trgt_dataset = args.trgt_dataset 206 | 207 | # source 208 | if src_dataset == 'modelnet': 209 | 
src_trainset = ModelNet(io, args.dataroot, 'train') 210 | src_testset = ModelNet(io, args.dataroot, 'test') 211 | 212 | elif src_dataset == 'shapenet': 213 | src_trainset = ShapeNet(io, args.dataroot, 'train') 214 | src_testset = ShapeNet(io, args.dataroot, 'test') 215 | 216 | elif src_dataset == 'scannet': 217 | src_trainset = ScanNet(io, args.dataroot, 'train') 218 | src_testset = ScanNet(io, args.dataroot, 'test') 219 | 220 | elif src_dataset == 'Syn': 221 | if trgt_dataset == 'RS': 222 | trgt_device = 'realsense' 223 | if trgt_dataset == 'Kin': 224 | trgt_device = 'kinect' 225 | src_trainset = GraspNetSynthetictPointClouds(args.dataroot, partition='train') 226 | src_testset = GraspNetRealPointClouds(args.dataroot, mode=trgt_device, partition='test') 227 | 228 | elif src_dataset == 'Kin': 229 | src_trainset = GraspNetRealPointClouds(args.dataroot, mode='kinect', partition='train') 230 | src_testset = GraspNetRealPointClouds(args.dataroot, mode='kinect', partition='test') 231 | 232 | elif src_dataset == 'RS': 233 | src_trainset = GraspNetRealPointClouds(args.dataroot, mode='realsense', partition='train') 234 | src_testset = GraspNetRealPointClouds(args.dataroot, mode='realsense', partition='test') 235 | 236 | else: 237 | io.cprint('unknown src dataset') 238 | 239 | # target 240 | if trgt_dataset == 'modelnet': 241 | trgt_trainset = ModelNet(io, args.dataroot, 'train') 242 | trgt_testset = ModelNet(io, args.dataroot, 'test') 243 | 244 | elif trgt_dataset == 'shapenet': 245 | trgt_trainset = ShapeNet(io, args.dataroot, 'train') 246 | trgt_testset = ShapeNet(io, args.dataroot, 'test') 247 | 248 | elif trgt_dataset == 'scannet': 249 | trgt_trainset = ScanNet(io, args.dataroot, 'train') 250 | trgt_testset = ScanNet(io, args.dataroot, 'test') 251 | 252 | elif trgt_dataset == 'Kin': 253 | trgt_trainset = GraspNetRealPointClouds(args.dataroot, mode='kinect', partition='train') 254 | trgt_testset = GraspNetRealPointClouds(args.dataroot, mode='kinect', partition='test') 
255 | 256 | elif trgt_dataset == 'RS': 257 | trgt_trainset = GraspNetRealPointClouds(args.dataroot, mode='realsense', partition='train') 258 | trgt_testset = GraspNetRealPointClouds(args.dataroot, mode='realsense', partition='test') 259 | 260 | else: 261 | io.cprint('unknown trgt dataset') 262 | 263 | src_train_sampler, src_valid_sampler = split_set(src_trainset, src_dataset, "source") 264 | trgt_train_sampler, trgt_valid_sampler = split_set(trgt_trainset, trgt_dataset, "target") 265 | 266 | # dataloaders for source and target 267 | src_train_loader = DataLoader(src_trainset, num_workers=args.num_workers, batch_size=args.batch_size, sampler=src_train_sampler, drop_last=True) 268 | src_val_loader = DataLoader(src_trainset, num_workers=args.num_workers, batch_size=args.test_batch_size, sampler=src_valid_sampler) 269 | src_test_loader = DataLoader(src_testset, num_workers=args.num_workers, batch_size=args.test_batch_size) 270 | 271 | trgt_train_loader = DataLoader(trgt_trainset, num_workers=args.num_workers, batch_size=args.batch_size, sampler=trgt_train_sampler, drop_last=True) 272 | trgt_val_loader = DataLoader(trgt_trainset, num_workers=args.num_workers, batch_size=args.test_batch_size, sampler=trgt_valid_sampler) 273 | trgt_test_loader = DataLoader(trgt_testset, num_workers=args.num_workers, batch_size=args.test_batch_size) 274 | 275 | 276 | # ================== 277 | # Init Model 278 | # ================== 279 | model = linear_DGCNN_model(args) 280 | model = model.to(device) 281 | io.cprint("------------------------------------------------------------------") 282 | try: 283 | if args.mode_checkpoint == 'val': 284 | checkpoint = torch.load(args.out_path + '/' + args.src_dataset + '_' + args.trgt_dataset + '/' + args.model + '/' + args.exp_name + '/save_best_by_val/model.pt') 285 | else: 286 | checkpoint = torch.load(args.out_path + '/' + args.src_dataset + '_' + args.trgt_dataset + '/' + args.model + '/' + args.exp_name + '/save_best_by_test/model.pt') 287 | 
start_epoch = checkpoint['epoch'] 288 | model.load_state_dict(checkpoint['model']) 289 | io.cprint('load saved model') 290 | except: 291 | start_epoch = 0 292 | io.cprint('no saved model') 293 | 294 | # Handle multi-gpu 295 | if (device.type == 'cuda') and len(args.gpus) > 1: 296 | model = nn.DataParallel(model, args.gpus) 297 | best_model = copy.deepcopy(model) 298 | 299 | 300 | # ================== 301 | # Optimizer 302 | # ================== 303 | if args.optimizer == "SGD": 304 | opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd) 305 | else: 306 | opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 307 | 308 | scheduler = CosineAnnealingLR(opt, args.epochs) 309 | 310 | criterion_cls = nn.CrossEntropyLoss() 311 | 312 | 313 | # ================== 314 | # Validation/test 315 | # ================== 316 | def test(loader, model=None, set_type="Target", partition="Val", epoch=0): 317 | # Run on cpu or gpu 318 | count = 0.0 319 | print_losses = {'cls': 0.0} 320 | batch_idx = 0 321 | 322 | with torch.no_grad(): 323 | model.eval() 324 | 325 | test_pred = [] 326 | test_true = [] 327 | 328 | num_sample = 0 329 | 330 | for data_all in loader: 331 | data, labels = data_all[1], data_all[2] 332 | data, labels = data.to(device), labels.to(device).squeeze() 333 | 334 | if data.shape[0] == 1: 335 | labels = labels.unsqueeze(0) 336 | 337 | if data.shape[1] > data.shape[2]: 338 | data = data.permute(0, 2, 1) 339 | 340 | batch_size = data.size()[0] 341 | num_point = data.shape[-1] 342 | 343 | num_sample = num_sample + batch_size 344 | 345 | logits = model(data) 346 | loss = criterion_cls(logits["pred"], labels) 347 | print_losses['cls'] += loss.item() * batch_size 348 | 349 | # evaluation metrics 350 | preds = logits["pred"].max(dim=1)[1] 351 | 352 | test_true.append(labels.cpu().numpy()) 353 | test_pred.append(preds.detach().cpu().numpy()) 354 | 355 | count += batch_size 356 | batch_idx += 1 357 | 358 | test_true = 
np.concatenate(test_true) 359 | test_pred = np.concatenate(test_pred) 360 | 361 | print_losses = {k: v * 1.0 / count for (k, v) in print_losses.items()} 362 | 363 | test_acc = io.print_progress(set_type, partition, epoch, print_losses, test_true, test_pred) 364 | 365 | conf_mat = metrics.confusion_matrix(test_true, test_pred, labels=list(range(args.num_class))).astype(int) 366 | 367 | return test_acc, print_losses['cls'], conf_mat 368 | 369 | 370 | # ================== 371 | # Train 372 | # ================== 373 | # first test the performance of the loaded model 374 | io.cprint("------------------------------------------------------------------") 375 | trgt_test_acc, trgt_test_loss, trgt_test_conf_mat = test(trgt_test_loader, model, "Target", "Test", 0) 376 | io.cprint("------------------------------------------------------------------") 377 | io.cprint("the performance of the loaded model is: %.4f" % (trgt_test_acc)) 378 | 379 | trgt_best_acc_by_val = trgt_test_acc 380 | trgt_best_acc_by_test = trgt_test_acc 381 | 382 | best_epoch_by_val = 0 383 | best_epoch_by_test = 0 384 | 385 | threshold_epoch = args.base_threshold 386 | 387 | sfm = nn.Softmax(dim=1) 388 | 389 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 390 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 391 | 392 | for ii in range(args.num_iterations): 393 | 394 | # determine threshold 395 | if trgt_best_acc_by_test > 0.9: 396 | threshold_epoch = 0.95 397 | 398 | trgt_best_acc_by_val_by_iter = 0 399 | trgt_best_acc_by_test_by_iter = 0 400 | 401 | best_epoch_by_val_by_iter = 0 402 | best_epoch_by_test_by_iter = 0 403 | 404 | io.cprint("==================================================================") 405 | io.cprint("iteration: %d, current threshold: %.4f" % (ii, threshold_epoch)) 406 | io.cprint("------------------------------------------------------------------") 407 | 408 | trgt_select_data = select_sample_by_conf(device, 
threshold_epoch, trgt_train_loader, model) 409 | trgt_new_data = DataLoad(io, trgt_select_data) 410 | trgt_new_train_loader = DataLoader(trgt_new_data, num_workers=args.num_workers, batch_size=args.batch_size, drop_last=True) 411 | io.cprint("------------------------------------------------------------------") 412 | 413 | count = 0.0 414 | print_losses = {'cls': 0.0, 'total': 0.0} 415 | 416 | for epoch in range(args.epochs): 417 | 418 | model.train() 419 | 420 | for trgt_data_all in trgt_new_train_loader: 421 | 422 | opt.zero_grad() 423 | 424 | trgt_data, aug_trgt_data, trgt_label, trgt_real_label = trgt_data_all[0].to(device), trgt_data_all[1].to(device), trgt_data_all[2].long().to(device), trgt_data_all[3].long().to(device) 425 | 426 | if trgt_data.shape[1] > trgt_data.shape[2]: 427 | trgt_data = trgt_data.permute(0, 2, 1) 428 | if aug_trgt_data.shape[1] > aug_trgt_data.shape[2]: 429 | aug_trgt_data = aug_trgt_data.permute(0, 2, 1) 430 | 431 | batch_size = trgt_data.shape[0] 432 | num_point = trgt_data.shape[-1] 433 | 434 | # start training process 435 | if args.use_aug: 436 | trgt_logits = model(aug_trgt_data) 437 | else: 438 | trgt_logits = model(trgt_data) 439 | 440 | # ============== # 441 | # calculate loss # 442 | # ============== # 443 | trgt_feature = trgt_logits['feature'] 444 | trgt_pred = trgt_logits['pred'] 445 | trgt_pred_sfm = sfm(trgt_pred) 446 | 447 | cls_loss = criterion_cls(trgt_pred, trgt_label) 448 | total_loss = cls_loss 449 | 450 | print_losses['cls'] += total_loss.item() * batch_size 451 | print_losses['total'] += total_loss.item() * batch_size 452 | 453 | total_loss.backward() 454 | 455 | count += batch_size 456 | 457 | opt.step() 458 | 459 | scheduler.step() 460 | 461 | print_losses = {k: v * 1.0 / (count + 1e-6) for (k, v) in print_losses.items()} 462 | io.print_progress("Target", "Trn", epoch, print_losses) 463 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 464 | 465 | # =================== 466 | # 
Test 467 | # =================== 468 | io.cprint("------------------------------------------------------------------") 469 | trgt_val_acc, trgt_val_loss, trgt_val_conf_mat = test(trgt_val_loader, model, "Target", "Val", epoch) 470 | io.cprint("------------------------------------------------------------------") 471 | trgt_test_acc, trgt_test_loss, trgt_test_conf_mat = test(trgt_test_loader, model, "Target", "Test", epoch) 472 | io.cprint("------------------------------------------------------------------") 473 | 474 | if trgt_val_acc > trgt_best_acc_by_val_by_iter: 475 | trgt_best_acc_by_val_by_iter = trgt_test_acc 476 | best_epoch_by_val_by_iter = epoch 477 | best_epoch_conf_mat_by_val_by_iter = trgt_test_conf_mat 478 | best_model_by_val_by_iter = copy.deepcopy(model) 479 | 480 | if trgt_test_acc > trgt_best_acc_by_test_by_iter: 481 | trgt_best_acc_by_test_by_iter = trgt_test_acc 482 | best_epoch_by_test_by_iter = epoch 483 | best_epoch_conf_mat_by_test_by_iter = trgt_test_conf_mat 484 | best_model_by_test_by_iter = copy.deepcopy(model) 485 | 486 | io.cprint("------------------------------------------------------------------") 487 | io.cprint("iteration: %d, epoch: %d, " % (ii, epoch)) 488 | io.cprint("previous best target test accuracy saved by val during each iteration: %.4f" % (trgt_best_acc_by_val_by_iter)) 489 | io.cprint("previous best target test accuracy saved by test during each iteration: %.4f" % (trgt_best_acc_by_test_by_iter)) 490 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 491 | 492 | # update 493 | if args.save_iter_model_by_val: 494 | model = copy.deepcopy(trgt_best_acc_by_val_by_iter) 495 | else: 496 | model = copy.deepcopy(best_model_by_test_by_iter) 497 | 498 | if args.use_SPL: 499 | threshold_epoch += 0.01 500 | if threshold_epoch > 0.95: 501 | threshold_epoch = 0.95 502 | 503 | if trgt_best_acc_by_val_by_iter > trgt_best_acc_by_val: 504 | trgt_best_acc_by_val = trgt_best_acc_by_val_by_iter 505 | 
best_epoch_by_val = best_epoch_by_val_by_iter 506 | best_epoch_conf_mat_by_val = best_epoch_conf_mat_by_val_by_iter 507 | best_model_by_val = io.save_model(model, epoch, 'save_best_by_SPST_val') 508 | 509 | if trgt_best_acc_by_test_by_iter > trgt_best_acc_by_test: 510 | trgt_best_acc_by_test = trgt_best_acc_by_test_by_iter 511 | best_epoch_by_test = best_epoch_by_test_by_iter 512 | best_epoch_conf_mat_by_test = best_epoch_conf_mat_by_test_by_iter 513 | best_model_by_test = io.save_model(model, epoch, 'save_best_by_SPST_test') 514 | 515 | io.cprint("Best model searched by val was found at epoch %d, target test accuracy: %.4f" 516 | % (best_epoch_by_val, trgt_best_acc_by_val)) 517 | io.cprint("Best test model confusion matrix:") 518 | io.cprint('\n' + str(best_epoch_conf_mat_by_val)) 519 | 520 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 521 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 522 | 523 | io.cprint("Best model searched by test was found at epoch %d, target test accuracy: %.4f" 524 | % (best_epoch_by_test, trgt_best_acc_by_test)) 525 | io.cprint("Best test model confusion matrix:") 526 | io.cprint('\n' + str(best_epoch_conf_mat_by_test)) 527 | 528 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 529 | io.cprint("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") 530 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # add a average feature cls 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.trans_norm import TransNorm2d 6 | import pdb 7 | import argparse 8 | from models.pointnet_util import PointNetSetAbstraction 9 | 10 | K = 20 11 | 12 | 13 | def index_points(points, idx): 14 | ''' 15 | 16 | Input: 17 | points: input points data, [B, N, C] 18 | idx: sample 
index data, [B, S] 19 | Return: 20 | new_points:, indexed points data, [B, S, C] 21 | ''' 22 | device = points.device 23 | B = points.shape[0] 24 | view_shape = list(idx.shape) 25 | view_shape[1:] = [1] * (len(view_shape) - 1) 26 | repeat_shape = list(idx.shape) 27 | repeat_shape[0] = 1 28 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 29 | new_points = points[batch_indices, idx, :] 30 | return new_points 31 | 32 | 33 | class Mapping(nn.Module): 34 | def __init__(self, input_channel, hidden_channel, output_channel): 35 | super(Mapping, self).__init__() 36 | self.fc1 = nn.Conv1d(input_channel, hidden_channel, 1) 37 | self.fc2 = nn.Conv1d(hidden_channel, output_channel, 1) 38 | self.bn1 = nn.BatchNorm1d(hidden_channel) 39 | 40 | def forward(self, x): 41 | x = F.relu(self.bn1(self.fc1(x))) 42 | x = self.fc2(x) 43 | x = x / torch.norm(x, p=2, dim=-2, keepdim=True) 44 | return x 45 | 46 | 47 | def knn(x, k): 48 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 49 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 50 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 51 | 52 | # pdb.set_trace() 53 | # (batch_size, num_points, k) 54 | idx = pairwise_distance.topk(k=k, dim=-1)[1] 55 | return idx 56 | 57 | 58 | def get_graph_feature(x, args, k=20, idx=None): 59 | batch_size = x.size(0) 60 | num_points = x.size(2) 61 | x = x.view(batch_size, -1, num_points) 62 | if idx is None: 63 | idx = knn(x, k=k) # (batch_size, num_points, k) 64 | # Run on cpu or gpu 65 | device = x.device 66 | 67 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 68 | 69 | idx = idx + idx_base 70 | 71 | idx = idx.view(-1) 72 | 73 | _, num_dims, _ = x.size() 74 | 75 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) 76 | feature = x.view(batch_size * num_points, -1)[idx, :] # matrix [k*num_points*batch_size,3] 77 | feature = feature.view(batch_size, num_points, 
k, num_dims) 78 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 79 | 80 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2) 81 | 82 | return feature 83 | 84 | 85 | def l2_norm(input, axit=1): 86 | norm = torch.norm(input, 2, axit, True) 87 | output = torch.div(input, norm) 88 | return output 89 | 90 | 91 | class conv_2d(nn.Module): 92 | def __init__(self, in_ch, out_ch, kernel, activation='relu', bias=True): 93 | super(conv_2d, self).__init__() 94 | if activation == 'relu': 95 | self.conv = nn.Sequential( 96 | nn.Conv2d(in_ch, out_ch, kernel_size=kernel, bias=bias), 97 | # nn.BatchNorm2d(out_ch), 98 | # nn.InstanceNorm2d(out_ch), 99 | TransNorm2d(out_ch), 100 | nn.ReLU(inplace=True) 101 | ) 102 | elif activation == 'leakyrelu': 103 | self.conv = nn.Sequential( 104 | nn.Conv2d(in_ch, out_ch, kernel_size=kernel, bias=bias), 105 | # nn.BatchNorm2d(out_ch), 106 | # nn.InstanceNorm2d(out_ch), 107 | TransNorm2d(out_ch), 108 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 109 | ) 110 | 111 | def forward(self, x): 112 | x = self.conv(x) 113 | return x 114 | 115 | 116 | class fc_layer(nn.Module): 117 | def __init__(self, in_ch, out_ch, bn=False, activation='relu', bias=True): 118 | super(fc_layer, self).__init__() 119 | if activation == 'relu': 120 | self.ac = nn.ReLU(inplace=True) 121 | elif activation == 'leakyrelu': 122 | self.ac = nn.LeakyReLU(negative_slope=0.2, inplace=True) 123 | if bn: 124 | self.fc = nn.Sequential( 125 | nn.Linear(in_ch, out_ch, bias=bias), 126 | nn.LayerNorm(out_ch), 127 | self.ac 128 | ) 129 | else: 130 | self.fc = nn.Sequential( 131 | nn.Linear(in_ch, out_ch, bias=bias), 132 | self.ac 133 | ) 134 | 135 | def forward(self, x): 136 | x = l2_norm(x, 1) 137 | x = self.fc(x) 138 | return x 139 | 140 | 141 | class transform_net(nn.Module): 142 | ''' Input (XYZ) Transform Net, input is BxNx3 gray image 143 | Return: Transformation matrix of size 3xK ''' 144 | 145 | def __init__(self, args, in_ch, out=3): 146 | 
super(transform_net, self).__init__() 147 | self.K = out 148 | self.args = args 149 | 150 | activation = 'leakyrelu' if args.model == 'DGCNN' else 'relu' 151 | bias = False if args.model == 'DGCNN' else True 152 | 153 | self.conv2d1 = conv_2d(in_ch, 64, kernel=1, activation=activation, bias=bias) 154 | self.conv2d2 = conv_2d(64, 128, kernel=1, activation=activation, bias=bias) 155 | self.conv2d3 = conv_2d(128, 1024, kernel=1, activation=activation, bias=bias) 156 | self.fc1 = fc_layer(1024, 512, activation=activation, bias=bias, bn=True) 157 | self.fc2 = fc_layer(512, 256, activation=activation, bn=True) 158 | self.fc3 = nn.Linear(256, out * out) 159 | 160 | def forward(self, x): 161 | device = x.device 162 | 163 | x = self.conv2d1(x) 164 | x = self.conv2d2(x) 165 | if self.args.model == 'DGCNN': 166 | x = x.max(dim=-1, keepdim=False)[0] 167 | x = torch.unsqueeze(x, dim=3) 168 | x = self.conv2d3(x) 169 | x, _ = torch.max(x, dim=2, keepdim=False) 170 | x = x.view(x.size(0), -1) 171 | x = self.fc1(x) 172 | x = self.fc2(x) 173 | x = self.fc3(x) 174 | 175 | iden = torch.eye(self.K).view(1, self.K * self.K).repeat(x.size(0), 1) 176 | iden = iden.to(device) 177 | x = x + iden 178 | x = x.view(x.size(0), self.K, self.K) 179 | return x 180 | 181 | 182 | class DGCNN_encoder(nn.Module): 183 | def __init__(self, args): 184 | super(DGCNN_encoder, self).__init__() 185 | self.args = args 186 | self.k = K 187 | self.use_avg_pool = args.use_avg_pool 188 | 189 | self.input_transform_net = transform_net(args, 6, 3) 190 | 191 | self.conv1 = conv_2d(6, 64, kernel=1, bias=False, activation='leakyrelu') 192 | self.conv2 = conv_2d(64 * 2, 64, kernel=1, bias=False, activation='leakyrelu') 193 | self.conv3 = conv_2d(64 * 2, 128, kernel=1, bias=False, activation='leakyrelu') 194 | self.conv4 = conv_2d(128 * 2, 256, kernel=1, bias=False, activation='leakyrelu') 195 | num_f_prev = 64 + 64 + 128 + 256 196 | 197 | if self.use_avg_pool: 198 | # use avepooling + maxpooling 199 | self.conv5 = 
nn.Conv1d(num_f_prev, 512, kernel_size=1, bias=False) 200 | self.bn5 = nn.BatchNorm1d(512) 201 | else: 202 | # use only maxpooling 203 | self.conv5 = nn.Conv1d(num_f_prev, 1024, kernel_size=1, bias=False) 204 | self.bn5 = nn.BatchNorm1d(1024) 205 | 206 | def forward(self, x): 207 | batch_size = x.size(0) 208 | num_points = x.size(2) 209 | cls_logits = {} 210 | 211 | x = get_graph_feature(x, self.args, k=self.k) # x: [b, 6, 1024, 20] 212 | x = self.conv1(x) # x: [b, 64, 1024, 20] 213 | x1 = x.max(dim=-1, keepdim=False)[0] # B, 64, 1024 214 | 215 | x = get_graph_feature(x1, self.args, k=self.k) # [b, 128, 1024, 20] 216 | x = self.conv2(x) # [b, 64, 1024, 20] 217 | x2 = x.max(dim=-1, keepdim=False)[0] # [b, 64, 1024] 218 | 219 | x = get_graph_feature(x2, self.args, k=self.k) # [b, 128, 1024, 20] 220 | x = self.conv3(x) # [b, 128, 1024, 20] 221 | x3 = x.max(dim=-1, keepdim=False)[0] # [b, 128, 1024] 222 | 223 | x = get_graph_feature(x3, self.args, k=self.k) # [b, 256, 1024, 20] 224 | x = self.conv4(x) # [b, 256, 1024, 20] 225 | x4 = x.max(dim=-1, keepdim=False)[0] # [b, 256, 1024] 226 | 227 | x_cat = torch.cat((x1, x2, x3, x4), dim=1) # [b, 512, 1024] 228 | 229 | if self.use_avg_pool: 230 | x5 = self.conv5(x_cat) # [b, 512, 1024] 231 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 232 | x5_1 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 233 | x5_2 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1) 234 | x5_pool = torch.cat((x5_1, x5_2), 1) 235 | else: 236 | x5 = self.conv5(x_cat) # [b, 512, 1024] 237 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 238 | x5_pool = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 239 | 240 | x = x5_pool 241 | 242 | cls_logits['feature'] = x 243 | 244 | return cls_logits 245 | 246 | 247 | class DGCNN_model(nn.Module): 248 | def __init__(self, args): 249 | super(DGCNN_model, self).__init__() 250 | self.args = args 251 | self.k = K 252 | self.use_avg_pool = args.use_avg_pool 253 | 254 | self.input_transform_net = 
transform_net(args, 6, 3) 255 | 256 | self.conv1 = conv_2d(6, 64, kernel=1, bias=False, activation='leakyrelu') 257 | self.conv2 = conv_2d(64 * 2, 64, kernel=1, bias=False, activation='leakyrelu') 258 | self.conv3 = conv_2d(64 * 2, 128, kernel=1, bias=False, activation='leakyrelu') 259 | self.conv4 = conv_2d(128 * 2, 256, kernel=1, bias=False, activation='leakyrelu') 260 | num_f_prev = 64 + 64 + 128 + 256 261 | 262 | if self.use_avg_pool: 263 | # use avepooling + maxpooling 264 | self.conv5 = nn.Conv1d(num_f_prev, 512, kernel_size=1, bias=False) 265 | self.bn5 = nn.BatchNorm1d(512) 266 | else: 267 | # use only maxpooling 268 | self.conv5 = nn.Conv1d(num_f_prev, 1024, kernel_size=1, bias=False) 269 | self.bn5 = nn.BatchNorm1d(1024) 270 | 271 | self.cls = class_classifier(args, 1024, args.num_class) 272 | 273 | def forward(self, x): 274 | batch_size = x.size(0) 275 | num_points = x.size(2) 276 | cls_logits = {} 277 | 278 | x = get_graph_feature(x, self.args, k=self.k) # x: [b, 6, 1024, 20] 279 | x = self.conv1(x) # x: [b, 64, 1024, 20] 280 | x1 = x.max(dim=-1, keepdim=False)[0] # B, 64, 1024 281 | 282 | x = get_graph_feature(x1, self.args, k=self.k) # [b, 128, 1024, 20] 283 | x = self.conv2(x) # [b, 64, 1024, 20] 284 | x2 = x.max(dim=-1, keepdim=False)[0] # [b, 64, 1024] 285 | 286 | x = get_graph_feature(x2, self.args, k=self.k) # [b, 128, 1024, 20] 287 | x = self.conv3(x) # [b, 128, 1024, 20] 288 | x3 = x.max(dim=-1, keepdim=False)[0] # [b, 128, 1024] 289 | 290 | x = get_graph_feature(x3, self.args, k=self.k) # [b, 256, 1024, 20] 291 | x = self.conv4(x) # [b, 256, 1024, 20] 292 | x4 = x.max(dim=-1, keepdim=False)[0] # [b, 256, 1024] 293 | 294 | x_cat = torch.cat((x1, x2, x3, x4), dim=1) # [b, 512, 1024] 295 | 296 | if self.use_avg_pool: 297 | x5 = self.conv5(x_cat) # [b, 512, 1024] 298 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 299 | x5_1 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 300 | x5_2 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1) 301 
| x5_pool = torch.cat((x5_1, x5_2), 1) 302 | else: 303 | x5 = self.conv5(x_cat) # [b, 512, 1024] 304 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 305 | x5_pool = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 306 | 307 | x = x5_pool 308 | 309 | cls_logits['feature'] = x 310 | 311 | cls_logits['pred'] = self.cls(x) 312 | 313 | return cls_logits 314 | 315 | 316 | class linear_DGCNN_model(nn.Module): 317 | def __init__(self, args): 318 | super(linear_DGCNN_model, self).__init__() 319 | self.args = args 320 | self.k = K 321 | self.use_avg_pool = args.use_avg_pool 322 | 323 | self.input_transform_net = transform_net(args, 6, 3) 324 | 325 | self.conv1 = conv_2d(6, 64, kernel=1, bias=False, activation='leakyrelu') 326 | self.conv2 = conv_2d(64 * 2, 64, kernel=1, bias=False, activation='leakyrelu') 327 | self.conv3 = conv_2d(64 * 2, 128, kernel=1, bias=False, activation='leakyrelu') 328 | self.conv4 = conv_2d(128 * 2, 256, kernel=1, bias=False, activation='leakyrelu') 329 | num_f_prev = 64 + 64 + 128 + 256 330 | 331 | if self.use_avg_pool: 332 | # use avepooling + maxpooling 333 | self.conv5 = nn.Conv1d(num_f_prev, 512, kernel_size=1, bias=False) 334 | self.bn5 = nn.BatchNorm1d(512) 335 | else: 336 | # use only maxpooling 337 | self.conv5 = nn.Conv1d(num_f_prev, 1024, kernel_size=1, bias=False) 338 | self.bn5 = nn.BatchNorm1d(1024) 339 | 340 | self.cls = linear_classifier(1024, args.num_class) 341 | 342 | def forward(self, x): 343 | batch_size = x.size(0) 344 | num_points = x.size(2) 345 | cls_logits = {} 346 | 347 | x = get_graph_feature(x, self.args, k=self.k) # x: [b, 6, 1024, 20] 348 | x = self.conv1(x) # x: [b, 64, 1024, 20] 349 | x1 = x.max(dim=-1, keepdim=False)[0] # B, 64, 1024 350 | 351 | x = get_graph_feature(x1, self.args, k=self.k) # [b, 128, 1024, 20] 352 | x = self.conv2(x) # [b, 64, 1024, 20] 353 | x2 = x.max(dim=-1, keepdim=False)[0] # [b, 64, 1024] 354 | 355 | x = get_graph_feature(x2, self.args, k=self.k) # [b, 128, 1024, 20] 356 | x = 
self.conv3(x) # [b, 128, 1024, 20] 357 | x3 = x.max(dim=-1, keepdim=False)[0] # [b, 128, 1024] 358 | 359 | x = get_graph_feature(x3, self.args, k=self.k) # [b, 256, 1024, 20] 360 | x = self.conv4(x) # [b, 256, 1024, 20] 361 | x4 = x.max(dim=-1, keepdim=False)[0] # [b, 256, 1024] 362 | 363 | x_cat = torch.cat((x1, x2, x3, x4), dim=1) # [b, 512, 1024] 364 | 365 | if self.use_avg_pool: 366 | x5 = self.conv5(x_cat) # [b, 512, 1024] 367 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 368 | x5_1 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 369 | x5_2 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1) 370 | x5_pool = torch.cat((x5_1, x5_2), 1) 371 | else: 372 | x5 = self.conv5(x_cat) # [b, 512, 1024] 373 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 374 | x5_pool = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 375 | 376 | x = x5_pool 377 | 378 | cls_logits['feature'] = x 379 | 380 | cls_logits['pred'] = self.cls(x) 381 | 382 | return cls_logits 383 | 384 | 385 | class segmentation(nn.Module): 386 | def __init__(self, args, input_size, output_size): 387 | super(segmentation, self).__init__() 388 | self.args = args 389 | self.of1 = 256 390 | self.of2 = 256 391 | self.of3 = output_size 392 | 393 | self.bn1 = nn.BatchNorm1d(self.of1) 394 | self.bn2 = nn.BatchNorm1d(self.of2) 395 | self.bn3 = nn.BatchNorm1d(self.of3) 396 | self.dp1 = nn.Dropout(p=args.dropout) 397 | self.dp2 = nn.Dropout(p=args.dropout) 398 | 399 | self.conv1 = nn.Conv1d(input_size, self.of1, kernel_size=1, bias=True) 400 | self.conv2 = nn.Conv1d(self.of1, self.of2, kernel_size=1, bias=True) 401 | self.conv3 = nn.Conv1d(self.of2, self.of3, kernel_size=1, bias=True) 402 | 403 | def forward(self, x): 404 | x = self.dp1(F.relu(self.bn1(self.conv1(x)))) 405 | x = self.dp2(F.relu(self.bn2(self.conv2(x)))) 406 | x = F.relu(self.bn3(self.conv3(x))) 407 | return x.permute(0, 2, 1) # [b, 1024, 128] 408 | 409 | 410 | class linear_DGCNN_seg_model(nn.Module): 411 | def __init__(self, args): 412 | 
super(linear_DGCNN_seg_model, self).__init__() 413 | self.args = args 414 | self.k = K 415 | self.use_avg_pool = args.use_avg_pool 416 | 417 | self.input_transform_net = transform_net(args, 6, 3) 418 | 419 | self.conv1 = conv_2d(6, 64, kernel=1, bias=False, activation='leakyrelu') 420 | self.conv2 = conv_2d(64 * 2, 64, kernel=1, bias=False, activation='leakyrelu') 421 | self.conv3 = conv_2d(64 * 2, 128, kernel=1, bias=False, activation='leakyrelu') 422 | self.conv4 = conv_2d(128 * 2, 256, kernel=1, bias=False, activation='leakyrelu') 423 | num_f_prev = 64 + 64 + 128 + 256 424 | 425 | if self.use_avg_pool: 426 | # use avepooling + maxpooling 427 | self.conv5 = nn.Conv1d(num_f_prev, 512, kernel_size=1, bias=False) 428 | self.bn5 = nn.BatchNorm1d(512) 429 | else: 430 | # use only maxpooling 431 | self.conv5 = nn.Conv1d(num_f_prev, 1024, kernel_size=1, bias=False) 432 | self.bn5 = nn.BatchNorm1d(1024) 433 | 434 | self.seg = segmentation(args, input_size=1024 + 512, output_size=args.feature_dim) 435 | self.seg_cls = linear_classifier(args.feature_dim, args.num_class) # default: 8 436 | 437 | def forward(self, x): 438 | batch_size = x.size(0) 439 | num_points = x.size(2) 440 | cls_logits = {} 441 | 442 | x = get_graph_feature(x, self.args, k=self.k) # x: [b, 6, 1024, 20] 443 | x = self.conv1(x) # x: [b, 64, 1024, 20] 444 | x1 = x.max(dim=-1, keepdim=False)[0] # B, 64, 1024 445 | 446 | x = get_graph_feature(x1, self.args, k=self.k) # [b, 128, 1024, 20] 447 | x = self.conv2(x) # [b, 64, 1024, 20] 448 | x2 = x.max(dim=-1, keepdim=False)[0] # [b, 64, 1024] 449 | 450 | x = get_graph_feature(x2, self.args, k=self.k) # [b, 128, 1024, 20] 451 | x = self.conv3(x) # [b, 128, 1024, 20] 452 | x3 = x.max(dim=-1, keepdim=False)[0] # [b, 128, 1024] 453 | 454 | x = get_graph_feature(x3, self.args, k=self.k) # [b, 256, 1024, 20] 455 | x = self.conv4(x) # [b, 256, 1024, 20] 456 | x4 = x.max(dim=-1, keepdim=False)[0] # [b, 256, 1024] 457 | 458 | x_cat = torch.cat((x1, x2, x3, x4), dim=1) # 
[b, 512, 1024] 459 | 460 | if self.use_avg_pool: 461 | x5 = self.conv5(x_cat) # [b, 512, 1024] 462 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 463 | x5_1 = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 464 | x5_2 = F.adaptive_avg_pool1d(x5, 1).view(batch_size, -1) 465 | x5_pool = torch.cat((x5_1, x5_2), 1) 466 | else: 467 | x5 = self.conv5(x_cat) # [b, 1024, 1024] 468 | x5 = F.leaky_relu(self.bn5(x5), negative_slope=0.2) 469 | x5_pool = F.adaptive_max_pool1d(x5, 1).view(batch_size, -1) 470 | 471 | x = torch.cat((x_cat, x5_pool.unsqueeze(2).repeat(1, 1, num_points)), dim=1) # [b, 1536, 1024] 472 | 473 | seg_feature = self.seg(x) # [b, 1024, 128] 474 | seg_pred = self.seg_cls(seg_feature) 475 | 476 | cls_logits['feature'] = seg_feature 477 | 478 | cls_logits['pred'] = seg_pred 479 | 480 | return cls_logits 481 | 482 | 483 | class class_classifier(nn.Module): 484 | def __init__(self, args, input_dim, num_class=10): 485 | super(class_classifier, self).__init__() 486 | 487 | activate = 'leakyrelu' if args.model == 'DGCNN' else 'relu' 488 | bias = True if args.model == 'DGCNN' else False 489 | 490 | self.mlp1 = fc_layer(input_dim, 512, bias=bias, activation=activate, bn=True) 491 | self.dp1 = nn.Dropout(p=args.dropout) 492 | self.mlp2 = fc_layer(512, 256, bias=True, activation=activate, bn=True) 493 | self.dp2 = nn.Dropout(p=args.dropout) 494 | self.mlp3 = nn.Linear(256, num_class) 495 | 496 | def forward(self, x): 497 | x = self.dp1(self.mlp1(x)) 498 | x = self.dp2(self.mlp2(x)) 499 | x = self.mlp3(x) 500 | return x 501 | 502 | 503 | class linear_classifier(nn.Module): 504 | def __init__(self, input_dim, num_class): 505 | super(linear_classifier, self).__init__() 506 | 507 | self.mlp = nn.Linear(input_dim, num_class) 508 | 509 | def forward(self, x): 510 | x = self.mlp(x) 511 | return x 512 | 513 | 514 | if __name__ == '__main__': 515 | 516 | def str2bool(v): 517 | ''' 518 | Input: 519 | v - string 520 | output: 521 | True/False 522 | ''' 523 | if 
isinstance(v, bool): 524 | return v 525 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 526 | return True 527 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 528 | return False 529 | else: 530 | raise argparse.ArgumentTypeError('Boolean value expected.') 531 | 532 | parser = argparse.ArgumentParser(description='DA on Point Clouds') 533 | parser.add_argument('--dataroot', type=str, default='../gast/data/', metavar='N', help='data path') 534 | parser.add_argument('--model', type=str, default='DGCNN', choices=['pointnet', 'DGCNN'], help='Model to use') 535 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 536 | parser.add_argument('--gpus', type=lambda s: [int(item.strip()) for item in s.split(',')], default='1', 537 | help='comma delimited of gpu ids to use. Use -1 for cpu usage') 538 | parser.add_argument('--num_class', type=int, default=10, help='number of classes per dataset') 539 | parser.add_argument('--use_avg_pool', type=str2bool, default=False, help='Using average pooling & max pooling or max pooling only') 540 | parser.add_argument('--batch_size', type=int, default=20, metavar='batch_size', 541 | help='Size of train batch per domain') 542 | parser.add_argument('--test_batch_size', type=int, default=20, metavar='batch_size', 543 | help='Size of test batch per domain') 544 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 545 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') 546 | parser.add_argument('--wd', type=float, default=5e-5, help='weight decay') 547 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate') 548 | parser.add_argument('--gamma', type=float, default=0.1, help='threshold for pseudo label') 549 | parser.add_argument('--out_path', type=str, default='./experiments', help='log folder path') 550 | parser.add_argument('--exp_name', type=str, default='test', help='Name of the experiment') 551 | 552 | args = 
parser.parse_args() 553 | 554 | args.cuda = (args.gpus[0] >= 0) and torch.cuda.is_available() 555 | 556 | data = torch.rand(4, 3, 1024).cuda() 557 | model = DGCNN_model(args).cuda() 558 | out = model(data) 559 | # print(1) 560 | 561 | -------------------------------------------------------------------------------- /utils/pc_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pdb 5 | 6 | eps = 10e-4 7 | eps2 = 10e-6 8 | KL_SCALER = 10.0 9 | MIN_POINTS = 20 10 | RADIUS = 0.5 11 | NREGIONS = 3 12 | NROTATIONS = 4 13 | N = 16 14 | K = 4 15 | NUM_FEATURES = K * 3 + 1 16 | 17 | 18 | def region_mean(num_regions): 19 | """ 20 | Input: 21 | num_regions - number of regions 22 | Return: 23 | means of regions 24 | """ 25 | 26 | n = num_regions 27 | lookup = [] 28 | d = 2 / n # the cube size length 29 | # construct all possibilities on the line [-1, 1] in the 3 axes 30 | for i in range(n - 1, -1, -1): 31 | for j in range(n - 1, -1, -1): 32 | for k in range(n - 1, -1, -1): 33 | lookup.append([1 - d * (i + 0.5), 1 - d * (j + 0.5), 1 - d * (k + 0.5)]) 34 | lookup = np.array(lookup) # n**3 x 3 35 | return lookup 36 | 37 | 38 | def assign_region_to_point(X, device, NREGIONS=3): 39 | """ 40 | Input: 41 | X: point cloud [B, C, N] 42 | device: cuda:0, cpu 43 | Return: 44 | Y: Region assignment per point [B, N] 45 | """ 46 | 47 | n = NREGIONS 48 | d = 2 / n 49 | X_clip = torch.clamp(X, -0.99999999, 0.99999999) # [B, C, N] 50 | batch_size, _, num_points = X.shape 51 | Y = torch.zeros((batch_size, num_points), device=device, dtype=torch.long) # label matrix [B, N] 52 | 53 | # The code below partitions all points in the shape to voxels. 
54 | # At each iteration find per axis the lower threshold and the upper threshold values 55 | # of the range according to n (e.g., if n=3, then: -1, -1/3, 1/3, 1 - there are 3 ranges) 56 | # and save points in the corresponding voxel if they fall in the examined range for all axis. 57 | region_id = 0 58 | for x in range(n): 59 | for y in range(n): 60 | for z in range(n): 61 | # lt= lower threshold, ut = upper threshold 62 | x_axis_lt = -1 + x * d < X_clip[:, 0, :] # [B, 1, N] 63 | x_axis_ut = X_clip[:, 0, :] < -1 + (x + 1) * d # [B, 1, N] 64 | y_axis_lt = -1 + y * d < X_clip[:, 1, :] # [B, 1, N] 65 | y_axis_ut = X_clip[:, 1, :] < -1 + (y + 1) * d # [B, 1, N] 66 | z_axis_lt = -1 + z * d < X_clip[:, 2, :] # [B, 1, N] 67 | z_axis_ut = X_clip[:, 2, :] < -1 + (z + 1) * d # [B, 1, N] 68 | # get a mask indicating for each coordinate of each point of each shape whether 69 | # it falls inside the current inspected ranges 70 | in_range = torch.cat([x_axis_lt, x_axis_ut, y_axis_lt, y_axis_ut, 71 | z_axis_lt, z_axis_ut], dim=1).view(batch_size, 6, -1) # [B, 6, N] 72 | # per each point decide if it falls in the current region only if in all 73 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 74 | mask, _ = torch.min(in_range, dim=1) # [B, N] 75 | Y[mask] = region_id # label each point with the region id 76 | region_id += 1 77 | return Y 78 | 79 | 80 | def collapse_to_point(x, device): 81 | """ 82 | Input: 83 | X: point cloud [C, N] 84 | device: cuda:0, cpu 85 | Return: 86 | x: A deformed point cloud. Randomly sample a point and cluster all point 87 | within a radius of RADIUS around it with some Gaussian noise. 
88 | indices: the points that were clustered around x 89 | """ 90 | # get pairwise distances 91 | inner = -2 * torch.matmul(x.transpose(1, 0), x) 92 | xx = torch.sum(x ** 2, dim=0, keepdim=True) 93 | pairwise_distance = xx + inner + xx.transpose(1, 0) 94 | 95 | # get mask of points in threshold 96 | mask = pairwise_distance.clone() 97 | mask[mask > RADIUS ** 2] = 100 98 | mask[mask <= RADIUS ** 2] = 1 99 | mask[mask == 100] = 0 100 | 101 | # Choose only from points that have more than MIN_POINTS within a RADIUS of them 102 | pts_pass = torch.sum(mask, dim=1) 103 | pts_pass[pts_pass < MIN_POINTS] = 0 104 | pts_pass[pts_pass >= MIN_POINTS] = 1 105 | indices = (pts_pass != 0).nonzero() 106 | 107 | # pick a point from the ones that passed the threshold 108 | point_ind = np.random.choice(indices.squeeze().cpu().numpy()) 109 | point = x[:, point_ind] # get point 110 | point_mask = mask[point_ind, :] # get point mask 111 | 112 | # draw a gaussian centered at the point for points falling in the region 113 | indices = (point_mask != 0).nonzero().squeeze() 114 | x[:, indices] = torch.tensor(draw_from_gaussian(point.cpu().numpy(), len(indices)), dtype=torch.float).to(device) 115 | return x, indices 116 | 117 | 118 | def draw_from_gaussian(mean, num_points): 119 | """ 120 | Input: 121 | mean: a numpy vector 122 | num_points: number of points to sample 123 | Return: 124 | points sampled around the mean with small std 125 | """ 126 | return np.random.multivariate_normal(mean, np.eye(3) * 0.001, num_points).T # 0.001 127 | 128 | 129 | def draw_from_uniform(gap, region_mean, num_points): 130 | """ 131 | Input: 132 | gap: a numpy vector of region x,y,z length in each direction from the mean 133 | region_mean: 134 | num_points: number of points to sample 135 | Return: 136 | points sampled uniformly in the region 137 | """ 138 | return np.random.uniform(region_mean - gap, region_mean + gap, (num_points, 3)).T 139 | 140 | 141 | def farthest_point_sample(args, xyz, npoint): 142 | """ 
143 | Input: 144 | xyz: pointcloud data, [B, C, N] 145 | npoint: number of samples 146 | Return: 147 | centroids: sampled pointcloud index, [B, npoint] 148 | """ 149 | device = torch.device("cuda:" + str(xyz.get_device()) if args.cuda else "cpu") 150 | 151 | B, C, N = xyz.shape 152 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # B x npoint 153 | distance = torch.ones(B, N).to(device) * 1e10 154 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 155 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 156 | centroids_vals = torch.zeros(B, C, npoint).to(device) 157 | for i in range(npoint): 158 | centroids[:, i] = farthest # save current chosen point index 159 | centroid = xyz[batch_indices, :, farthest].view(B, C, 1) # get the current chosen point value 160 | centroids_vals[:, :, i] = centroid[:, :, 0].clone() 161 | dist = torch.sum((xyz - centroid) ** 2, 1) # euclidean distance of points from the current centroid 162 | mask = dist < distance # save index of all point that are closer than the current max distance 163 | distance[mask] = dist[mask] # save the minimal distance of each point from all points that were chosen until now 164 | farthest = torch.max(distance, -1)[1] # get the index of the point farthest away 165 | return centroids, centroids_vals 166 | 167 | 168 | def farthest_point_sample_np(xyz, npoint): 169 | """ 170 | Input: 171 | xyz: pointcloud data, [B, C, N] 172 | npoint: number of samples 173 | Return: 174 | centroids: sampled pointcloud index, [B, npoint] 175 | """ 176 | 177 | B, C, N = xyz.shape 178 | centroids = np.zeros((B, npoint), dtype=np.int64) 179 | distance = np.ones((B, N)) * 1e10 180 | farthest = np.random.randint(0, N, (B,), dtype=np.int64) 181 | batch_indices = np.arange(B, dtype=np.int64) 182 | centroids_vals = np.zeros((B, C, npoint)) 183 | for i in range(npoint): 184 | centroids[:, i] = farthest # save current chosen point index 185 | centroid = xyz[batch_indices, :, farthest].reshape(B, 
C, 1) # get the current chosen point value 186 | centroids_vals[:, :, i] = centroid[:, :, 0].copy() 187 | dist = np.sum((xyz - centroid) ** 2, 1) # euclidean distance of points from the current centroid 188 | mask = dist < distance # save index of all point that are closer than the current max distance 189 | distance[mask] = dist[mask] # save the minimal distance of each point from all points that were chosen until now 190 | farthest = np.argmax(distance, axis=1) # get the index of the point farthest away 191 | return centroids, centroids_vals 192 | 193 | 194 | def rotate_shape(x, axis, angle): 195 | """ 196 | Input: 197 | x: pointcloud data, [B, C, N] 198 | axis: axis to do rotation about 199 | angle: rotation angle 200 | Return: 201 | A rotated shape 202 | """ 203 | R_x = np.asarray([[1, 0, 0], [0, np.cos(angle), -np.sin(angle)], [0, np.sin(angle), np.cos(angle)]]) 204 | R_y = np.asarray([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) 205 | R_z = np.asarray([[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0], [0, 0, 1]]) 206 | 207 | if axis == "x": 208 | return x.dot(R_x).astype('float32') 209 | elif axis == "y": 210 | return x.dot(R_y).astype('float32') 211 | else: 212 | return x.dot(R_z).astype('float32') 213 | 214 | 215 | def random_rotate_one_axis(X, axis): 216 | """ 217 | Apply random rotation about one axis 218 | Input: 219 | x: pointcloud data, [B, C, N] 220 | axis: axis to do rotation about 221 | Return: 222 | A rotated shape 223 | """ 224 | rotation_angle = np.random.uniform() * 2 * np.pi 225 | cosval = np.cos(rotation_angle) 226 | sinval = np.sin(rotation_angle) 227 | if axis == 'x': 228 | R_x = [[1, 0, 0], [0, cosval, -sinval], [0, sinval, cosval]] 229 | X = np.matmul(X, R_x) 230 | elif axis == 'y': 231 | R_y = [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]] 232 | X = np.matmul(X, R_y) 233 | else: 234 | R_z = [[cosval, -sinval, 0], [sinval, cosval, 0], [0, 0, 1]] 235 | X = np.matmul(X, R_z) 
236 | return X.astype('float32') 237 | 238 | 239 | def translate_pointcloud(pointcloud): 240 | """ 241 | Input: 242 | pointcloud: pointcloud data, [B, C, N] 243 | Return: 244 | A translated shape 245 | """ 246 | xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3]) 247 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 248 | 249 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 250 | return translated_pointcloud 251 | 252 | 253 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 254 | """ 255 | Input: 256 | pointcloud: pointcloud data, [B, C, N] 257 | sigma: 258 | clip: 259 | Return: 260 | A jittered shape 261 | """ 262 | N, C = pointcloud.shape 263 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip) 264 | return pointcloud.astype('float32') 265 | 266 | 267 | def scale_to_unit_cube(x): 268 | """ 269 | Input: 270 | x: pointcloud data, [B, C, N] 271 | Return: 272 | A point cloud scaled to unit cube 273 | """ 274 | if len(x) == 0: 275 | return x 276 | 277 | centroid = np.mean(x, axis=0) 278 | x -= centroid 279 | furthest_distance = np.max(np.sqrt(np.sum(abs(x) ** 2, axis=-1))) 280 | x /= furthest_distance 281 | return x 282 | 283 | 284 | def dropout_points(x, norm_curv, num_points): 285 | """ 286 | Randomly dropout num_points, and randomly duplicate num_points 287 | Input: 288 | x: pointcloud data, [B, C, N] 289 | Return: 290 | A point cloud dropouted num_points 291 | """ 292 | ind = random.sample(range(0, x.shape[1]), num_points) 293 | ind_dpl = random.sample(range(0, x.shape[1]), num_points) 294 | x[:, ind, :] = x[:, ind_dpl, :] 295 | norm_curv[:, ind, :] = norm_curv[:, ind_dpl, :] 296 | return x, norm_curv 297 | 298 | 299 | def remove_region_points(x, norm_curv, device): 300 | """ 301 | Remove all points of a randomly selected region in the point cloud. 
302 | Input: 303 | X - Point cloud [B, N, C] 304 | norm_curv: norm and curvature, [B, N, C] 305 | Return: 306 | X - Point cloud where points in a certain region are removed 307 | """ 308 | # get points' regions 309 | regions = assign_region_to_point(x, device) # [B, N] N:the number of region_id 310 | n = NREGIONS 311 | region_ids = np.random.permutation(n ** 3) 312 | for b in range(x.shape[0]): 313 | for i in region_ids: 314 | ind = regions[b, :] == i # [N] 315 | # if there are enough points in the region 316 | if torch.sum(ind) >= 50: 317 | num_points = int(torch.sum(ind)) 318 | rnd_ind = random.sample(range(0, x.shape[1]), num_points) 319 | x[b, ind, :] = x[b, rnd_ind, :] 320 | norm_curv[b, ind, :] = norm_curv[b, rnd_ind, :] 321 | break # move to the next shape in the batch 322 | return x, norm_curv 323 | 324 | 325 | def extract_feature_points(x, norm_curv, num_points, device="cuda:0"): 326 | """ 327 | Input: 328 | x: pointcloud data, [B, N, C] 329 | norm_curv: norm and curvature, [B, N, C] 330 | Return: 331 | Feature points, [B, num_points, C] 332 | """ 333 | IND = torch.zeros([x.size(0), num_points]).to(device) 334 | fea_pc = torch.zeros([x.size(0), num_points, x.size(2)]).to(device) 335 | for b in range(x.size(0)): 336 | curv = norm_curv[b, :, -1] 337 | curv = abs(curv) 338 | ind = torch.argsort(curv) 339 | ind = ind[:num_points] 340 | IND[b] = ind 341 | fea_pc[b] = x[b, ind, :] 342 | return fea_pc 343 | 344 | 345 | def pc2voxel(x): 346 | # Args: 347 | # x: size n x F where n is the number of points and F is feature size 348 | # Returns: 349 | # voxel: N x N x N x (K x 3 + 1) 350 | # index: N x N x N x K 351 | num_points = x.shape[0] 352 | data = np.zeros((N, N, N, NUM_FEATURES), dtype=np.float32) 353 | index = np.zeros((N, N, N, K), dtype=np.float32) 354 | x /= 1.05 355 | idx = np.floor((x + 1.0) / 2.0 * N) 356 | L = [[] for _ in range(N * N * N)] 357 | for p in range(num_points): 358 | k = int(idx[p, 0] * N * N + idx[p, 1] * N + idx[p, 2]) 359 | 
L[k].append(p) 360 | for i in range(N): 361 | for j in range(N): 362 | for k in range(N): 363 | u = int(i * N * N + j * N + k) 364 | if not L[u]: 365 | data[i, j, k, :] = np.zeros((NUM_FEATURES), dtype=np.float32) 366 | elif len(L[u]) >= K: 367 | choice = np.random.choice(L[u], size=K, replace=False) 368 | local_points = x[choice, :] - np.array( 369 | [-1.0 + (i + 0.5) * 2.0 / N, -1.0 + (j + 0.5) * 2.0 / N, 370 | -1.0 + (k + 0.5) * 2.0 / N], dtype=np.float32) 371 | data[i, j, k, 0: K * 3] = np.reshape(local_points, (K * 3)) 372 | data[i, j, k, K * 3] = 1.0 373 | index[i, j, k, :] = choice 374 | else: 375 | choice = np.random.choice(L[u], size=K, replace=True) 376 | local_points = x[choice, :] - np.array( 377 | [-1.0 + (i + 0.5) * 2.0 / N, -1.0 + (j + 0.5) * 2.0 / N, 378 | -1.0 + (k + 0.5) * 2.0 / N], dtype=np.float32) 379 | data[i, j, k, 0: K * 3] = np.reshape(local_points, (K * 3)) 380 | data[i, j, k, K * 3] = 1.0 381 | index[i, j, k, :] = choice 382 | return data, index 383 | 384 | 385 | def pc2voxel_B(x): 386 | """ 387 | Input: 388 | x: pointcloud data, [B, num_points, C] 389 | Return: 390 | voxel: N x N x N x (K x 3 + 1) 391 | index: N x N x N x K 392 | """ 393 | batch_size = x.shape[0] 394 | Data = np.zeros((batch_size, N, N, N, NUM_FEATURES), dtype=np.float32) 395 | Index = np.zeros((batch_size, N, N, N, K), dtype=np.float32) 396 | x = scale_to_unit_cube(x) 397 | for b in range(batch_size): 398 | pc = x[b] 399 | data, index = pc2voxel(pc) 400 | Data[b] = data 401 | Index[b] = index 402 | return Data, Index 403 | 404 | 405 | def pc2image(X, axis, RESOLUTION=32): 406 | """ 407 | Input: 408 | X: point cloud [N, C] 409 | axis: axis to do projection about 410 | Return: 411 | Y: image projected by 'X' along 'axis'. 
[32, 32] 412 | """ 413 | 414 | n = RESOLUTION 415 | d = 2 / n 416 | X_clip = np.clip(X, -0.99999999, 0.99999999) # [N, C] 417 | Y = np.zeros((n, n), dtype=np.float32) # label matrix [n, n] 418 | if axis == 'x': 419 | for y in range(n): 420 | for z in range(n): 421 | # lt= lower threshold, ut = upper threshold 422 | y_axis_lt = -1 + y * d < X_clip[:, 1] # [N] 423 | y_axis_ut = X_clip[:, 1] < -1 + (y + 1) * d # [N] 424 | z_axis_lt = -1 + z * d < X_clip[:, 2] # [N] 425 | z_axis_ut = X_clip[:, 2] < -1 + (z + 1) * d # [N] 426 | # get a mask indicating for each coordinate of each point of each shape whether 427 | # it falls inside the current inspected ranges 428 | in_range = np.concatenate([y_axis_lt, y_axis_ut, z_axis_lt, z_axis_ut], 0).reshape(4, -1) # [4, N] 429 | # per each point decide if it falls in the current region only if in all 430 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 431 | mask = np.min(in_range, 0) # [N]: [False, ..., True, ...] 432 | if np.sum(mask) == 0: 433 | continue 434 | Y[y, z] = (X_clip[mask, 0] + 1).mean() 435 | if axis == 'y': 436 | for x in range(n): 437 | for z in range(n): 438 | # lt= lower threshold, ut = upper threshold 439 | x_axis_lt = -1 + x * d < X_clip[:, 0] # [N] 440 | x_axis_ut = X_clip[:, 0] < -1 + (x + 1) * d # [N] 441 | z_axis_lt = -1 + z * d < X_clip[:, 2] # [N] 442 | z_axis_ut = X_clip[:, 2] < -1 + (z + 1) * d # [N] 443 | # get a mask indicating for each coordinate of each point of each shape whether 444 | # it falls inside the current inspected ranges 445 | in_range = np.concatenate([x_axis_lt, x_axis_ut, z_axis_lt, z_axis_ut], 0).reshape(4, -1) # [4, N] 446 | # per each point decide if it falls in the current region only if in all 447 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 448 | mask = np.min(in_range, 0) # [N] 449 | if np.sum(mask) == 0: 450 | continue 451 | Y[x, z] = (X_clip[mask, 1] + 1).mean() 452 | if axis == 'z': 453 | for x in range(n): 454 | for y 
in range(n): 455 | # lt= lower threshold, ut = upper threshold 456 | x_axis_lt = -1 + x * d < X_clip[:, 0] # [N] 457 | x_axis_ut = X_clip[:, 0] < -1 + (x + 1) * d # [N] 458 | y_axis_lt = -1 + y * d < X_clip[:, 1] # [N] 459 | y_axis_ut = X_clip[:, 1] < -1 + (y + 1) * d # [N] 460 | # get a mask indicating for each coordinate of each point of each shape whether 461 | # it falls inside the current inspected ranges 462 | in_range = np.concatenate([x_axis_lt, x_axis_ut, y_axis_lt, y_axis_ut], 0).reshape(4, -1) # [4, N] 463 | # per each point decide if it falls in the current region only if in all 464 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 465 | mask = np.min(in_range, 0) # [N] 466 | if np.sum(mask) == 0: 467 | continue 468 | Y[x, y] = (X_clip[mask, 2] + 1).mean() 469 | 470 | return Y 471 | 472 | 473 | def pc2image_B(X, axis, device='cuda:0', RESOLUTION=32): 474 | """ 475 | Input: 476 | X: point cloud [B, C, N] 477 | axis: axis to do projection about 478 | Return: 479 | Y: image projected by 'X' along 'axis'. 
[B, 32, 32] 480 | """ 481 | n = RESOLUTION 482 | B = X.size(0) 483 | X = X.permute(0, 2, 1) # [B, N, C] 484 | X = X.cpu().numpy() 485 | Y = np.zeros((B, n, n), dtype=np.float32) # label matrix [B, n, n] 486 | for b in range(B): 487 | Y[b] = pc2image(X[b], axis, n) 488 | Y = torch.from_numpy(Y).to(device) 489 | return Y 490 | 491 | 492 | class PointcloudToTensor(object): 493 | def __call__(self, points): 494 | return torch.from_numpy(points).float() 495 | 496 | 497 | class PointcloudScale(object): 498 | def __init__(self, lo=0.8, hi=1.25): 499 | self.lo, self.hi = lo, hi 500 | 501 | def __call__(self, points): 502 | scaler = np.random.uniform(self.lo, self.hi) 503 | points[:, 0:3] *= scaler 504 | return points 505 | 506 | 507 | def angle_axis(angle, axis): 508 | u = axis / np.linalg.norm(axis) 509 | cosval, sinval = np.cos(angle), np.sin(angle) 510 | 511 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 512 | [u[2], 0.0, -u[0]], 513 | [-u[1], u[0], 0.0]]) 514 | 515 | R = torch.from_numpy( 516 | cosval * np.eye(3) 517 | + sinval * cross_prod_mat 518 | + (1.0 - cosval) * np.outer(u, u) 519 | ) 520 | return R.float() 521 | 522 | 523 | class PointcloudRotate(object): 524 | def __init__(self, axis=np.array([0.0, 1.0, 0.0])): 525 | self.axis = axis 526 | 527 | def __call__(self, points): 528 | rotation_angle = np.random.uniform() * 2 * np.pi 529 | rotation_matrix = angle_axis(rotation_angle, self.axis) 530 | 531 | normals = points.size(1) > 3 532 | if not normals: 533 | return torch.matmul(points, rotation_matrix.t()) 534 | else: 535 | pc_xyz = points[:, 0:3] 536 | pc_normals = points[:, 3:] 537 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 538 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 539 | 540 | return points 541 | 542 | 543 | class PointcloudRotatePerturbation(object): 544 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 545 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip 546 | 547 | def _get_angles(self): 548 | 
angles = np.clip( 549 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip 550 | ) 551 | 552 | return angles 553 | 554 | def __call__(self, points): 555 | angles = self._get_angles() 556 | Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0])) 557 | Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0])) 558 | Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0])) 559 | 560 | rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx) 561 | 562 | normals = points.size(1) > 3 563 | if not normals: 564 | return torch.matmul(points, rotation_matrix.t()) 565 | else: 566 | pc_xyz = points[:, 0:3] 567 | pc_normals = points[:, 3:] 568 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 569 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 570 | 571 | return points 572 | 573 | 574 | class PointcloudTranslate(object): 575 | def __init__(self, translate_range=0.1): 576 | self.translate_range = translate_range 577 | 578 | def __call__(self, points): 579 | translation = np.random.uniform(-self.translate_range, self.translate_range) 580 | points[:, 0:3] += translation 581 | return points 582 | 583 | 584 | class PointcloudJitter(object): 585 | def __init__(self, std=0.01, clip=0.05): 586 | self.std, self.clip = std, clip 587 | 588 | def __call__(self, points): 589 | jittered_data = ( 590 | points.new(points.size(0), 3) 591 | .normal_(mean=0.0, std=self.std) 592 | .clamp_(-self.clip, self.clip) 593 | ) 594 | points[:, 0:3] += jittered_data 595 | return points 596 | 597 | 598 | def normal_pc(pc): 599 | pc_mean = pc.mean(axis=0) 600 | pc = pc - pc_mean 601 | pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1))) 602 | pc = pc/pc_L_max 603 | return pc 604 | 605 | 606 | if __name__ == '__main__': 607 | lookup = region_mean(3) 608 | print(lookup.shape) 609 | x = torch.randn([2, 3, 4]) 610 | print(x) 611 | dropout_points(x, 2) 612 | print(x) 613 | -------------------------------------------------------------------------------- 
/utils/pc_utils_Norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pdb 5 | 6 | eps = 10e-4 7 | eps2 = 10e-6 8 | KL_SCALER = 10.0 9 | MIN_POINTS = 20 10 | RADIUS = 0.5 11 | NREGIONS = 3 12 | NROTATIONS = 4 13 | N = 16 14 | K = 4 15 | NUM_FEATURES = K * 3 + 1 16 | 17 | 18 | def region_mean(num_regions): 19 | """ 20 | Input: 21 | num_regions - number of regions 22 | Return: 23 | means of regions 24 | """ 25 | 26 | n = num_regions 27 | lookup = [] 28 | d = 2 / n # the cube size length 29 | # construct all possibilities on the line [-1, 1] in the 3 axes 30 | for i in range(n - 1, -1, -1): 31 | for j in range(n - 1, -1, -1): 32 | for k in range(n-1, -1, -1): 33 | lookup.append([1 - d * (i + 0.5), 1 - d * (j + 0.5), 1 - d * (k + 0.5)]) 34 | lookup = np.array(lookup) # n**3 x 3 35 | return lookup 36 | 37 | 38 | def assign_region_to_point(X, device='cuda:0', NREGIONS=3): 39 | """ 40 | Input: 41 | X: point cloud [B, C, N] 42 | device: cuda:0, cpu 43 | Return: 44 | Y: Region assignment per point [B, N] 45 | """ 46 | 47 | n = NREGIONS 48 | d = 2 / n 49 | X_clip = torch.clamp(X, -0.99999999, 0.99999999) # [B, C, N] 50 | batch_size, _, num_points = X.shape 51 | Y = torch.zeros((batch_size, num_points), device=device, dtype=torch.long) # label matrix [B, N] 52 | 53 | # The code below partitions all points in the shape to voxels. 54 | # At each iteration find per axis the lower threshold and the upper threshold values 55 | # of the range according to n (e.g., if n=3, then: -1, -1/3, 1/3, 1 - there are 3 ranges) 56 | # and save points in the corresponding voxel if they fall in the examined range for all axis. 
57 | region_id = 0 58 | for x in range(n): 59 | for y in range(n): 60 | for z in range(n): 61 | # lt= lower threshold, ut = upper threshold 62 | x_axis_lt = -1 + x * d < X_clip[:, 0, :] # [B, 1, N] 63 | x_axis_ut = X_clip[:, 0, :] < -1 + (x + 1) * d # [B, 1, N] 64 | y_axis_lt = -1 + y * d < X_clip[:, 1, :] # [B, 1, N] 65 | y_axis_ut = X_clip[:, 1, :] < -1 + (y + 1) * d # [B, 1, N] 66 | z_axis_lt = -1 + z * d < X_clip[:, 2, :] # [B, 1, N] 67 | z_axis_ut = X_clip[:, 2, :] < -1 + (z + 1) * d # [B, 1, N] 68 | # get a mask indicating for each coordinate of each point of each shape whether 69 | # it falls inside the current inspected ranges 70 | in_range = torch.cat([x_axis_lt, x_axis_ut, y_axis_lt, y_axis_ut, 71 | z_axis_lt, z_axis_ut], dim=1).view(batch_size, 6, -1) # [B, 6, N] 72 | # per each point decide if it falls in the current region only if in all 73 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 74 | mask, _ = torch.min(in_range, dim=1) # [B, N] 75 | Y[mask] = region_id # label each point with the region id 76 | region_id += 1 77 | return Y 78 | 79 | 80 | def collapse_to_point(x, device): 81 | """ 82 | Input: 83 | X: point cloud [C, N] 84 | device: cuda:0, cpu 85 | Return: 86 | x: A deformed point cloud. Randomly sample a point and cluster all point 87 | within a radius of RADIUS around it with some Gaussian noise. 
88 | indices: the points that were clustered around x 89 | """ 90 | # get pairwise distances 91 | inner = -2 * torch.matmul(x.transpose(1, 0), x) 92 | xx = torch.sum(x ** 2, dim=0, keepdim=True) 93 | pairwise_distance = xx + inner + xx.transpose(1, 0) 94 | 95 | # get mask of points in threshold 96 | mask = pairwise_distance.clone() 97 | mask[mask > RADIUS ** 2] = 100 98 | mask[mask <= RADIUS ** 2] = 1 99 | mask[mask == 100] = 0 100 | 101 | # Choose only from points that have more than MIN_POINTS within a RADIUS of them 102 | pts_pass = torch.sum(mask, dim=1) 103 | pts_pass[pts_pass < MIN_POINTS] = 0 104 | pts_pass[pts_pass >= MIN_POINTS] = 1 105 | indices = (pts_pass != 0).nonzero() 106 | 107 | # pick a point from the ones that passed the threshold 108 | point_ind = np.random.choice(indices.squeeze().cpu().numpy()) 109 | point = x[:, point_ind] # get point 110 | point_mask = mask[point_ind, :] # get point mask 111 | 112 | # draw a gaussian centered at the point for points falling in the region 113 | indices = (point_mask != 0).nonzero().squeeze() 114 | x[:, indices] = torch.tensor(draw_from_gaussian(point.cpu().numpy(), len(indices)), dtype=torch.float).to(device) 115 | return x, indices 116 | 117 | 118 | def draw_from_gaussian(mean, num_points): 119 | """ 120 | Input: 121 | mean: a numpy vector 122 | num_points: number of points to sample 123 | Return: 124 | points sampled around the mean with small std 125 | """ 126 | return np.random.multivariate_normal(mean, np.eye(3) * 0.1, num_points).T # 0.001 127 | 128 | 129 | def draw_from_uniform(gap, region_mean, num_points): 130 | """ 131 | Input: 132 | gap: a numpy vector of region x,y,z length in each direction from the mean 133 | region_mean: 134 | num_points: number of points to sample 135 | Return: 136 | points sampled uniformly in the region 137 | """ 138 | return np.random.uniform(region_mean - gap, region_mean + gap, (num_points, 3)).T 139 | 140 | 141 | def farthest_point_sample(xyz, npoint, device='cuda:0'): 
142 | """ 143 | Input: 144 | xyz: pointcloud data, [B, C, N] 145 | npoint: number of samples 146 | Return: 147 | centroids: sampled pointcloud index, [B, npoint] 148 | """ 149 | 150 | B, C, N = xyz.shape 151 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # B x npoint 152 | distance = torch.ones(B, N).to(device) * 1e10 153 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 154 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 155 | centroids_vals = torch.zeros(B, C, npoint).to(device) 156 | for i in range(npoint): 157 | centroids[:, i] = farthest # save current chosen point index 158 | centroid = xyz[batch_indices, :, farthest].view(B, C, 1) # get the current chosen point value 159 | centroids_vals[:, :, i] = centroid[:, :, 0].clone() 160 | dist = torch.sum((xyz - centroid) ** 2, 1) # euclidean distance of points from the current centroid 161 | mask = dist < distance # save index of all point that are closer than the current max distance 162 | distance[mask] = dist[mask] # save the minimal distance of each point from all points that were chosen until now 163 | farthest = torch.max(distance, -1)[1] # get the index of the point farthest away 164 | return centroids, centroids_vals 165 | 166 | 167 | def farthest_point_sample_np(xyz, norm_curv, npoint): 168 | """ 169 | Input: 170 | xyz: pointcloud data, [B, C, N] 171 | npoint: number of samples 172 | Return: 173 | centroids: sampled pointcloud index, [B, npoint] 174 | """ 175 | 176 | B, C, N = xyz.shape 177 | centroids = np.zeros((B, npoint), dtype=np.int64) 178 | distance = np.ones((B, N)) * 1e10 179 | farthest = np.random.randint(0, N, (B,), dtype=np.int64) 180 | batch_indices = np.arange(B, dtype=np.int64) 181 | centroids_vals = np.zeros((B, C, npoint)) 182 | centroids_norm_curv_vals = np.zeros((B, 4, npoint)) 183 | for i in range(npoint): 184 | centroids[:, i] = farthest # save current chosen point index 185 | centroid = xyz[batch_indices, :, farthest].reshape(B, C, 1) 
# get the current chosen point value 186 | centroid_norm_curv = norm_curv[batch_indices, :, farthest].reshape(B, 4, 1) 187 | centroids_vals[:, :, i] = centroid[:, :, 0].copy() 188 | centroids_norm_curv_vals[:, :, i] = centroid_norm_curv[:, :, 0].copy() 189 | dist = np.sum((xyz - centroid) ** 2, 1) # euclidean distance of points from the current centroid 190 | mask = dist < distance # save index of all point that are closer than the current max distance 191 | distance[mask] = dist[mask] # save the minimal distance of each point from all points that were chosen until now 192 | farthest = np.argmax(distance, axis=1) # get the index of the point farthest away 193 | return centroids, centroids_vals, centroids_norm_curv_vals 194 | 195 | 196 | def farthest_point_sample_no_curv_np(xyz, npoint): 197 | """ 198 | Input: 199 | xyz: pointcloud data, [B, C, N] 200 | npoint: number of samples 201 | Return: 202 | centroids: sampled pointcloud index, [B, npoint] 203 | """ 204 | 205 | B, C, N = xyz.shape 206 | centroids = np.zeros((B, npoint), dtype=np.int64) 207 | distance = np.ones((B, N)) * 1e10 208 | farthest = np.random.randint(0, N, (B,), dtype=np.int64) 209 | batch_indices = np.arange(B, dtype=np.int64) 210 | centroids_vals = np.zeros((B, C, npoint)) 211 | for i in range(npoint): 212 | centroids[:, i] = farthest # save current chosen point index 213 | centroid = xyz[batch_indices, :, farthest].reshape(B, C, 1) # get the current chosen point value 214 | centroids_vals[:, :, i] = centroid[:, :, 0].copy() 215 | dist = np.sum((xyz - centroid) ** 2, 1) # euclidean distance of points from the current centroid 216 | mask = dist < distance # save index of all point that are closer than the current max distance 217 | distance[mask] = dist[mask] # save the minimal distance of each point from all points that were chosen until now 218 | farthest = np.argmax(distance, axis=1) # get the index of the point farthest away 219 | return centroids, centroids_vals 220 | 221 | 222 | def 
rotate_shape(x, axis, angle): 223 | """ 224 | Input: 225 | x: pointcloud data, [B, C, N] 226 | axis: axis to do rotation about 227 | angle: rotation angle 228 | Return: 229 | A rotated shape 230 | """ 231 | R_x = np.asarray([[1, 0, 0], [0, np.cos(angle), -np.sin(angle)], [0, np.sin(angle), np.cos(angle)]]) 232 | R_y = np.asarray([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) 233 | R_z = np.asarray([[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0], [0, 0, 1]]) 234 | 235 | if axis == "x": 236 | return x.dot(R_x).astype('float32') 237 | elif axis == "y": 238 | return x.dot(R_y).astype('float32') 239 | else: 240 | return x.dot(R_z).astype('float32') 241 | 242 | 243 | def random_rotate_one_axis(X, axis): 244 | """ 245 | Apply random rotation about one axis 246 | Input: 247 | x: pointcloud data, [B, C, N] 248 | axis: axis to do rotation about 249 | Return: 250 | A rotated shape 251 | """ 252 | rotation_angle = np.random.uniform() * 2 * np.pi 253 | cosval = np.cos(rotation_angle) 254 | sinval = np.sin(rotation_angle) 255 | if axis == 'x': 256 | R_x = [[1, 0, 0], [0, cosval, -sinval], [0, sinval, cosval]] 257 | X = np.matmul(X, R_x) 258 | elif axis == 'y': 259 | R_y = [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]] 260 | X = np.matmul(X, R_y) 261 | else: 262 | R_z = [[cosval, -sinval, 0], [sinval, cosval, 0], [0, 0, 1]] 263 | X = np.matmul(X, R_z) 264 | return X.astype('float32') 265 | 266 | 267 | def translate_pointcloud(pointcloud): 268 | """ 269 | Input: 270 | pointcloud: pointcloud data, [B, C, N] 271 | Return: 272 | A translated shape 273 | """ 274 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 275 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 276 | 277 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 278 | return translated_pointcloud 279 | 280 | 281 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 282 | """ 283 | Input: 284 | 
pointcloud: pointcloud data, [B, C, N] 285 | sigma: 286 | clip: 287 | Return: 288 | A jittered shape 289 | """ 290 | N, C = pointcloud.shape 291 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 292 | return pointcloud.astype('float32') 293 | 294 | 295 | def scale_to_unit_cube(x): 296 | """ 297 | Input: 298 | x: pointcloud data, [B, C, N] 299 | Return: 300 | A point cloud scaled to unit cube 301 | """ 302 | if len(x) == 0: 303 | return x 304 | 305 | centroid = np.mean(x, axis=0) 306 | x -= centroid 307 | furthest_distance = np.max(np.sqrt(np.sum(abs(x) ** 2, axis=-1))) 308 | x /= furthest_distance 309 | return x 310 | 311 | 312 | def dropout_points(x, norm_curv, num_points): 313 | """ 314 | Randomly dropout num_points, and randomly duplicate num_points 315 | Input: 316 | x: pointcloud data, [B, C, N] 317 | Return: 318 | A point cloud dropouted num_points 319 | """ 320 | ind = random.sample(range(0, x.shape[1]), num_points) 321 | ind_dpl = random.sample(range(0, x.shape[1]), num_points) 322 | x[:, ind, :] = x[:, ind_dpl, :] 323 | norm_curv[:, ind, :] = norm_curv[:, ind_dpl, :] 324 | return x, norm_curv 325 | 326 | 327 | def remove_region_points(x, norm_curv, device): 328 | """ 329 | Remove all points of a randomly selected region in the point cloud. 
330 | Input: 331 | X - Point cloud [B, N, C] 332 | norm_curv: norm and curvature, [B, N, C] 333 | Return: 334 | X - Point cloud where points in a certain region are removed 335 | """ 336 | # get points' regions 337 | regions = assign_region_to_point(x, device) # [B, N] N:the number of region_id 338 | n = NREGIONS 339 | region_ids = np.random.permutation(n ** 3) 340 | for b in range(x.shape[0]): 341 | for i in region_ids: 342 | ind = regions[b, :] == i # [N] 343 | # if there are enough points in the region 344 | if torch.sum(ind) >= 50: 345 | num_points = int(torch.sum(ind)) 346 | rnd_ind = random.sample(range(0, x.shape[1]), num_points) 347 | x[b, ind, :] = x[b, rnd_ind, :] 348 | norm_curv[b, ind, :] = norm_curv[b, rnd_ind, :] 349 | break # move to the next shape in the batch 350 | return x, norm_curv 351 | 352 | 353 | def extract_feature_points(x, norm_curv, num_points, device="cuda:0"): 354 | """ 355 | Input: 356 | x: pointcloud data, [B, N, C] 357 | norm_curv: norm and curvature, [B, N, C] 358 | Return: 359 | Feature points, [B, num_points, C] 360 | """ 361 | IND = torch.zeros([x.size(0), num_points]).to(device) 362 | fea_pc = torch.zeros([x.size(0), num_points, x.size(2)]).to(device) 363 | for b in range(x.size(0)): 364 | curv = norm_curv[b, :, -1] 365 | curv = abs(curv) 366 | ind = torch.argsort(curv) 367 | ind = ind[:num_points] 368 | IND[b] = ind 369 | fea_pc[b] = x[b, ind, :] 370 | return fea_pc 371 | 372 | 373 | def pc2voxel(x): 374 | # Args: 375 | # x: size n x F where n is the number of points and F is feature size 376 | # Returns: 377 | # voxel: N x N x N x (K x 3 + 1) 378 | # index: N x N x N x K 379 | num_points = x.shape[0] 380 | data = np.zeros((N, N, N, NUM_FEATURES), dtype=np.float32) 381 | index = np.zeros((N, N, N, K), dtype=np.float32) 382 | x /= 1.05 383 | idx = np.floor((x + 1.0) / 2.0 * N) 384 | L = [[] for _ in range(N * N * N)] 385 | for p in range(num_points): 386 | k = int(idx[p, 0] * N * N + idx[p, 1] * N + idx[p, 2]) 387 | 
L[k].append(p) 388 | for i in range(N): 389 | for j in range(N): 390 | for k in range(N): 391 | u = int(i * N * N + j * N + k) 392 | if not L[u]: 393 | data[i, j, k, :] = np.zeros((NUM_FEATURES), dtype=np.float32) 394 | elif len(L[u]) >= K: 395 | choice = np.random.choice(L[u], size=K, replace=False) 396 | local_points = x[choice, :] - np.array( 397 | [-1.0 + (i + 0.5) * 2.0 / N, -1.0 + (j + 0.5) * 2.0 / N, 398 | -1.0 + (k + 0.5) * 2.0 / N], dtype=np.float32) 399 | data[i, j, k, 0: K * 3] = np.reshape(local_points, (K * 3)) 400 | data[i, j, k, K * 3] = 1.0 401 | index[i, j, k, :] = choice 402 | else: 403 | choice = np.random.choice(L[u], size=K, replace=True) 404 | local_points = x[choice, :] - np.array( 405 | [-1.0 + (i + 0.5) * 2.0 / N, -1.0 + (j + 0.5) * 2.0 / N, 406 | -1.0 + (k + 0.5) * 2.0 / N], dtype=np.float32) 407 | data[i, j, k, 0: K * 3] = np.reshape(local_points, (K * 3)) 408 | data[i, j, k, K * 3] = 1.0 409 | index[i, j, k, :] = choice 410 | return data, index 411 | 412 | 413 | def pc2voxel_B(x): 414 | """ 415 | Input: 416 | x: pointcloud data, [B, num_points, C] 417 | Return: 418 | voxel: N x N x N x (K x 3 + 1) 419 | index: N x N x N x K 420 | """ 421 | batch_size = x.shape[0] 422 | Data = np.zeros((batch_size, N, N, N, NUM_FEATURES), dtype=np.float32) 423 | Index = np.zeros((batch_size, N, N, N, K), dtype=np.float32) 424 | x = scale_to_unit_cube(x) 425 | for b in range(batch_size): 426 | pc = x[b] 427 | data, index = pc2voxel(pc) 428 | Data[b] = data 429 | Index[b] = index 430 | return Data, Index 431 | 432 | 433 | def pc2image(X, axis, RESOLUTION=32): 434 | """ 435 | Input: 436 | X: point cloud [N, C] 437 | axis: axis to do projection about 438 | Return: 439 | Y: image projected by 'X' along 'axis'. 
[32, 32] 440 | """ 441 | 442 | n = RESOLUTION 443 | d = 2 / n 444 | X_clip = np.clip(X, -0.99999999, 0.99999999) # [N, C] 445 | Y = np.zeros((n, n), dtype=np.float32) # label matrix [n, n] 446 | if axis == 'x': 447 | for y in range(n): 448 | for z in range(n): 449 | # lt= lower threshold, ut = upper threshold 450 | y_axis_lt = -1 + y * d < X_clip[:, 1] # [N] 451 | y_axis_ut = X_clip[:, 1] < -1 + (y + 1) * d # [N] 452 | z_axis_lt = -1 + z * d < X_clip[:, 2] # [N] 453 | z_axis_ut = X_clip[:, 2] < -1 + (z + 1) * d # [N] 454 | # get a mask indicating for each coordinate of each point of each shape whether 455 | # it falls inside the current inspected ranges 456 | in_range = np.concatenate([y_axis_lt, y_axis_ut, z_axis_lt, z_axis_ut], 0).reshape(4, -1) # [4, N] 457 | # per each point decide if it falls in the current region only if in all 458 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 459 | mask = np.min(in_range, 0) # [N]: [False, ..., True, ...] 460 | if np.sum(mask) == 0: 461 | continue 462 | Y[y, z] = (X_clip[mask, 0] + 1).mean() 463 | if axis == 'y': 464 | for x in range(n): 465 | for z in range(n): 466 | # lt= lower threshold, ut = upper threshold 467 | x_axis_lt = -1 + x * d < X_clip[:, 0] # [N] 468 | x_axis_ut = X_clip[:, 0] < -1 + (x + 1) * d # [N] 469 | z_axis_lt = -1 + z * d < X_clip[:, 2] # [N] 470 | z_axis_ut = X_clip[:, 2] < -1 + (z + 1) * d # [N] 471 | # get a mask indicating for each coordinate of each point of each shape whether 472 | # it falls inside the current inspected ranges 473 | in_range = np.concatenate([x_axis_lt, x_axis_ut, z_axis_lt, z_axis_ut], 0).reshape(4, -1) # [4, N] 474 | # per each point decide if it falls in the current region only if in all 475 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 476 | mask = np.min(in_range, 0) # [N] 477 | if np.sum(mask) == 0: 478 | continue 479 | Y[x, z] = (X_clip[mask, 1] + 1).mean() 480 | if axis == 'z': 481 | for x in range(n): 482 | for y 
in range(n): 483 | # lt= lower threshold, ut = upper threshold 484 | x_axis_lt = -1 + x * d < X_clip[:, 0] # [N] 485 | x_axis_ut = X_clip[:, 0] < -1 + (x + 1) * d # [N] 486 | y_axis_lt = -1 + y * d < X_clip[:, 1] # [N] 487 | y_axis_ut = X_clip[:, 1] < -1 + (y + 1) * d # [N] 488 | # get a mask indicating for each coordinate of each point of each shape whether 489 | # it falls inside the current inspected ranges 490 | in_range = np.concatenate([x_axis_lt, x_axis_ut, y_axis_lt, y_axis_ut], 0).reshape(4, -1) # [4, N] 491 | # per each point decide if it falls in the current region only if in all 492 | # ranges the value is 1 (i.e., it falls inside all the inspected ranges) 493 | mask = np.min(in_range, 0) # [N] 494 | if np.sum(mask) == 0: 495 | continue 496 | Y[x, y] = (X_clip[mask, 2] + 1).mean() 497 | 498 | return Y 499 | 500 | 501 | def pc2image_B(X, axis, device='cuda:0', RESOLUTION=32): 502 | """ 503 | Input: 504 | X: point cloud [B, C, N] 505 | axis: axis to do projection about 506 | Return: 507 | Y: image projected by 'X' along 'axis'. 
[B, 32, 32] 508 | """ 509 | n = RESOLUTION 510 | B = X.size(0) 511 | X = X.permute(0, 2, 1) # [B, N, C] 512 | X = X.cpu().numpy() 513 | Y = np.zeros((B, n, n), dtype=np.float32) # label matrix [B, n, n] 514 | for b in range(B): 515 | Y[b] = pc2image(X[b], axis, n) 516 | Y = torch.from_numpy(Y).to(device) 517 | return Y 518 | 519 | 520 | # Down sampling: critical point layer 521 | def CPL(x, ratio): 522 | """ 523 | Input: 524 | x: points feature [N, C] 525 | ratio: down sampling ratio 526 | Return: 527 | f_out: down sampled points feature, [M, C] 528 | """ 529 | num_sample = int(np.size(x, 0) / ratio) 530 | fs = np.array([]) 531 | fr = np.array([]).astype(int) 532 | fmax = x.max(0) 533 | idx = x.argmax(0) 534 | _, d = np.unique(idx, return_index=True) 535 | uidx = np.argsort(d) 536 | for i in uidx: 537 | mask = (i == idx) 538 | val = fmax[mask].sum() 539 | fs = np.append(fs, val) 540 | fr = np.append(fr, mask.sum()) 541 | sidx = np.argsort(-fs) 542 | suidx = uidx[sidx] 543 | fr = fr[sidx] 544 | midx = np.array([]).astype(int) 545 | t = 0 546 | for i in fr: 547 | for j in range(int(i)): 548 | midx = np.append(midx, suidx[t]) 549 | t += 1 550 | rmidx = np.resize(midx, num_sample) 551 | fout = x[rmidx] 552 | return fout 553 | 554 | 555 | def CPL_B(X, ratio, device='cuda:0',): 556 | """ 557 | Input: 558 | X: points feature [B, C, N] 559 | ratio: down sampling ratio 560 | Return: 561 | F: down sampled points feature, [B, C, M] 562 | """ 563 | B, C, N = X.size() 564 | M = int(N / ratio) 565 | X = X.permute(0, 2, 1) # [B, N, C] 566 | X = X.cpu().numpy() 567 | F = np.zeros((B, M, C), dtype=np.float32) 568 | for b in range(B): 569 | F[b] = CPL(X[b], ratio) 570 | F = torch.from_numpy(F).to(device) 571 | F = F.permute(0, 2, 1) 572 | return F 573 | 574 | 575 | def sample_gumbel(shape, eps=1e-20): 576 | U = torch.rand(shape) 577 | return -torch.log(-torch.log(U + eps) + eps) 578 | 579 | 580 | def gumbel_softmax_sample(logits, temperature): 581 | y = logits + 
sample_gumbel(logits.size()) 582 | return torch.nn.functional.softmax(y / temperature, dim=-1) 583 | 584 | 585 | def gumbel_softmax(logits, temperature, hard=False): 586 | """ 587 | ST-gumple-softmax 588 | input: [*, n_class] 589 | return: flatten --> [*, n_class] an one-hot vector 590 | """ 591 | y = gumbel_softmax_sample(logits, temperature) 592 | 593 | if not hard: 594 | return y 595 | 596 | shape = y.size() 597 | _, ind = y.max(dim=-1) 598 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 599 | y_hard.scatter_(1, ind.view(-1, 1), 1) 600 | y_hard = y_hard.view(*shape) 601 | # Set gradients w.r.t. y_hard gradients w.r.t. y 602 | y_hard = (y_hard - y).detach() + y 603 | return y_hard 604 | 605 | 606 | def square_distance(src, dst): 607 | """ 608 | Calculate Euclid distance between each two points. 609 | src^T * dst = xn * xm + yn * ym + zn * zm; 610 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 611 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 612 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 613 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 614 | Input: 615 | src: source points, [B, N, C] 616 | dst: target points, [B, M, C] 617 | Output: 618 | dist: per-point square distance, [B, N, M] 619 | """ 620 | B, N, _ = src.shape 621 | _, M, _ = dst.shape 622 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 623 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 624 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 625 | return dist 626 | 627 | 628 | def index_points(points, idx): 629 | """ 630 | Input: 631 | points: input points data, [B, N, C] 632 | idx: sample index data, [B, S] 633 | Return: 634 | new_points:, indexed points data, [B, S, C] 635 | """ 636 | device = points.device 637 | B = points.shape[0] 638 | view_shape = list(idx.shape) 639 | view_shape[1:] = [1] * (len(view_shape) - 1) 640 | repeat_shape = list(idx.shape) 641 | repeat_shape[0] = 1 642 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 643 
| new_points = points[batch_indices, idx, :] 644 | return new_points 645 | 646 | 647 | def query_ball_point(radius, nsample, xyz, new_xyz): 648 | """ 649 | Input: 650 | radius: local region radius 651 | nsample: max sample number in local region 652 | xyz: all points, [B, N, 3] 653 | new_xyz: query points, [B, S, 3] 654 | Return: 655 | group_idx: grouped points index, [B, S, nsample] 656 | """ 657 | device = xyz.device 658 | B, N, C = xyz.shape 659 | _, S, _ = new_xyz.shape 660 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 661 | sqrdists = square_distance(new_xyz, xyz) 662 | group_idx[sqrdists > radius ** 2] = N 663 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 664 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 665 | mask = group_idx == N 666 | group_idx[mask] = group_first[mask] 667 | return group_idx 668 | 669 | 670 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 671 | """ 672 | Input: 673 | npoint: 674 | radius: 675 | nsample: 676 | xyz: input points position data, [B, N, 3] 677 | points: input points data, [B, N, D] 678 | Return: 679 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 680 | new_points: sampled points data, [B, npoint, nsample, 3+D] 681 | """ 682 | B, N, C = xyz.shape 683 | S = npoint 684 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 685 | new_xyz = index_points(xyz, fps_idx) 686 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 687 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 688 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 689 | 690 | if points is not None: 691 | grouped_points = index_points(points, idx) 692 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 693 | else: 694 | new_points = grouped_xyz_norm 695 | if returnfps: 696 | return new_xyz, new_points, grouped_xyz, fps_idx 697 | else: 698 | return new_xyz, 
new_points 699 | 700 | 701 | def sample_and_group_all(xyz, points): 702 | """ 703 | Input: 704 | xyz: input points position data, [B, N, 3] 705 | points: input points data, [B, N, D] 706 | Return: 707 | new_xyz: sampled points position data, [B, 1, 3] 708 | new_points: sampled points data, [B, 1, N, 3+D] 709 | """ 710 | device = xyz.device 711 | B, N, C = xyz.shape 712 | new_xyz = torch.zeros(B, 1, C).to(device) 713 | grouped_xyz = xyz.view(B, 1, N, C) 714 | if points is not None: 715 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 716 | else: 717 | new_points = grouped_xyz 718 | return new_xyz, new_points 719 | 720 | 721 | if __name__ == '__main__': 722 | lookup = region_mean(3) 723 | print(lookup.shape) 724 | x = np.random.rand(2, 3, 6) # [B, C, N] 725 | print(x) 726 | x = scale_to_unit_cube(x) 727 | x = torch.from_numpy(x) 728 | print(x) 729 | #dropout_points(x, 2) 730 | y = pc2image_B(x, "x", RESOLUTION=6) 731 | print(y.shape) 732 | x = torch.stack((pc2image_B(x, "x", RESOLUTION=6), pc2image_B(x, "y", RESOLUTION=6), pc2image_B(x, "z", RESOLUTION=6)), dim=3) 733 | print(x) 734 | 735 | --------------------------------------------------------------------------------